#pragma once #include #include namespace c10d { constexpr const char* ROUND_ROBIN_BACKEND_NAME = "round_robin"; // ProcessGroupRoundRobin implements simple load balancing. // // It is constructed with multiple processes groups. Each call is dispatched to // one of the specified process groups in a round robin fashion. Each process // group instance must have the same rank and size. // // All functions of the class are expected to be called in the same order // across all processes in the process group. This is the only way that we // can guarantee to match up the same calls among all processes. // class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup { public: explicit ProcessGroupRoundRobin( int rank, int size, std::vector> processGroups); ~ProcessGroupRoundRobin() override; const std::string getBackendName() const override { return std::string(ROUND_ROBIN_BACKEND_NAME); } c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override; c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; c10::intrusive_ptr allgather( std::vector>& outputs, std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr _allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr gather( std::vector>& outputs, std::vector& inputs, const GatherOptions& opts = GatherOptions()) override; c10::intrusive_ptr scatter( std::vector& outputs, std::vector>& inputs, const ScatterOptions& opts = ScatterOptions()) override; c10::intrusive_ptr reduce_scatter( std::vector& outputs, std::vector>& inputs, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputSplitSizes, std::vector& inputSplitSizes, const AllToAllOptions& opts = AllToAllOptions()) override; c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; c10::intrusive_ptr recvAnysource( std::vector& tensors, int tag) override; c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; private: std::vector> processGroups_; std::vector>::const_iterator iterator_; // Returns the next ProcessGroup to use. const c10::intrusive_ptr& next(); }; } // namespace c10d