| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- #pragma once
- #include <vector>
- #include <c10d/ProcessGroup.hpp>
- namespace c10d {
- constexpr const char* ROUND_ROBIN_BACKEND_NAME = "round_robin";
- // ProcessGroupRoundRobin implements simple load balancing.
- //
- // It is constructed with multiple processes groups. Each call is dispatched to
- // one of the specified process groups in a round robin fashion. Each process
- // group instance must have the same rank and size.
- //
- // All functions of the class are expected to be called in the same order
- // across all processes in the process group. This is the only way that we
- // can guarantee to match up the same calls among all processes.
- //
- class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
- public:
- explicit ProcessGroupRoundRobin(
- int rank,
- int size,
- std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups);
- ~ProcessGroupRoundRobin() override;
- const std::string getBackendName() const override {
- return std::string(ROUND_ROBIN_BACKEND_NAME);
- }
- c10::intrusive_ptr<ProcessGroup::Work> broadcast(
- std::vector<at::Tensor>& tensors,
- const BroadcastOptions& opts = BroadcastOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allreduce(
- std::vector<at::Tensor>& tensors,
- const AllreduceOptions& opts = AllreduceOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
- std::vector<at::Tensor>& tensors,
- const AllreduceCoalescedOptions& opts =
- AllreduceCoalescedOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> reduce(
- std::vector<at::Tensor>& tensors,
- const ReduceOptions& opts = ReduceOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allgather(
- std::vector<std::vector<at::Tensor>>& outputs,
- std::vector<at::Tensor>& inputs,
- const AllgatherOptions& opts = AllgatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- const AllgatherOptions& opts = AllgatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
- std::vector<std::vector<at::Tensor>>& outputTensorLists,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts = AllgatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> gather(
- std::vector<std::vector<at::Tensor>>& outputs,
- std::vector<at::Tensor>& inputs,
- const GatherOptions& opts = GatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> scatter(
- std::vector<at::Tensor>& outputs,
- std::vector<std::vector<at::Tensor>>& inputs,
- const ScatterOptions& opts = ScatterOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
- std::vector<at::Tensor>& outputs,
- std::vector<std::vector<at::Tensor>>& inputs,
- const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
- at::Tensor& outputTensor,
- at::Tensor& inputTensor,
- std::vector<int64_t>& outputSplitSizes,
- std::vector<int64_t>& inputSplitSizes,
- const AllToAllOptions& opts = AllToAllOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> send(
- std::vector<at::Tensor>& tensors,
- int dstRank,
- int tag) override;
- c10::intrusive_ptr<ProcessGroup::Work> recv(
- std::vector<at::Tensor>& tensors,
- int srcRank,
- int tag) override;
- c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
- std::vector<at::Tensor>& tensors,
- int tag) override;
- c10::intrusive_ptr<ProcessGroup::Work> barrier(
- const BarrierOptions& opts = BarrierOptions()) override;
- private:
- std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;
- std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;
- // Returns the next ProcessGroup to use.
- const c10::intrusive_ptr<ProcessGroup>& next();
- };
- } // namespace c10d
|