ProcessGroupRoundRobin.hpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #pragma once
  2. #include <vector>
  3. #include <c10d/ProcessGroup.hpp>
  4. namespace c10d {
  5. constexpr const char* ROUND_ROBIN_BACKEND_NAME = "round_robin";
  6. // ProcessGroupRoundRobin implements simple load balancing.
  7. //
  8. // It is constructed with multiple processes groups. Each call is dispatched to
  9. // one of the specified process groups in a round robin fashion. Each process
  10. // group instance must have the same rank and size.
  11. //
  12. // All functions of the class are expected to be called in the same order
  13. // across all processes in the process group. This is the only way that we
  14. // can guarantee to match up the same calls among all processes.
  15. //
  16. class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
  17. public:
  18. explicit ProcessGroupRoundRobin(
  19. int rank,
  20. int size,
  21. std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups);
  22. ~ProcessGroupRoundRobin() override;
  23. const std::string getBackendName() const override {
  24. return std::string(ROUND_ROBIN_BACKEND_NAME);
  25. }
  26. c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  27. std::vector<at::Tensor>& tensors,
  28. const BroadcastOptions& opts = BroadcastOptions()) override;
  29. c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  30. std::vector<at::Tensor>& tensors,
  31. const AllreduceOptions& opts = AllreduceOptions()) override;
  32. c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
  33. std::vector<at::Tensor>& tensors,
  34. const AllreduceCoalescedOptions& opts =
  35. AllreduceCoalescedOptions()) override;
  36. c10::intrusive_ptr<ProcessGroup::Work> reduce(
  37. std::vector<at::Tensor>& tensors,
  38. const ReduceOptions& opts = ReduceOptions()) override;
  39. c10::intrusive_ptr<ProcessGroup::Work> allgather(
  40. std::vector<std::vector<at::Tensor>>& outputs,
  41. std::vector<at::Tensor>& inputs,
  42. const AllgatherOptions& opts = AllgatherOptions()) override;
  43. c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
  44. at::Tensor& outputBuffer,
  45. at::Tensor& inputBuffer,
  46. const AllgatherOptions& opts = AllgatherOptions()) override;
  47. c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
  48. std::vector<std::vector<at::Tensor>>& outputTensorLists,
  49. std::vector<at::Tensor>& inputTensors,
  50. const AllgatherOptions& opts = AllgatherOptions()) override;
  51. c10::intrusive_ptr<ProcessGroup::Work> gather(
  52. std::vector<std::vector<at::Tensor>>& outputs,
  53. std::vector<at::Tensor>& inputs,
  54. const GatherOptions& opts = GatherOptions()) override;
  55. c10::intrusive_ptr<ProcessGroup::Work> scatter(
  56. std::vector<at::Tensor>& outputs,
  57. std::vector<std::vector<at::Tensor>>& inputs,
  58. const ScatterOptions& opts = ScatterOptions()) override;
  59. c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  60. std::vector<at::Tensor>& outputs,
  61. std::vector<std::vector<at::Tensor>>& inputs,
  62. const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
  63. c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
  64. at::Tensor& outputTensor,
  65. at::Tensor& inputTensor,
  66. std::vector<int64_t>& outputSplitSizes,
  67. std::vector<int64_t>& inputSplitSizes,
  68. const AllToAllOptions& opts = AllToAllOptions()) override;
  69. c10::intrusive_ptr<ProcessGroup::Work> send(
  70. std::vector<at::Tensor>& tensors,
  71. int dstRank,
  72. int tag) override;
  73. c10::intrusive_ptr<ProcessGroup::Work> recv(
  74. std::vector<at::Tensor>& tensors,
  75. int srcRank,
  76. int tag) override;
  77. c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
  78. std::vector<at::Tensor>& tensors,
  79. int tag) override;
  80. c10::intrusive_ptr<ProcessGroup::Work> barrier(
  81. const BarrierOptions& opts = BarrierOptions()) override;
  82. private:
  83. std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;
  84. std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;
  85. // Returns the next ProcessGroup to use.
  86. const c10::intrusive_ptr<ProcessGroup>& next();
  87. };
  88. } // namespace c10d