| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- #pragma once
- #ifdef USE_C10D_GLOO
- #include <c10d/ProcessGroup.hpp>
- #include <c10d/ProcessGroupGloo.hpp>
- #include <c10d/Types.hpp>
- #include <c10d/Utils.hpp>
- namespace c10d {
- class TORCH_API ProcessGroupWrapper : public ProcessGroup { // Wraps a ProcessGroup together with a helper ProcessGroupGloo (see pg_ and glooPg_ below).
- public:
- explicit ProcessGroupWrapper(
- c10::intrusive_ptr<ProcessGroup> pg,
- c10::intrusive_ptr<ProcessGroupGloo> glooPg); // presumably stored as pg_ / glooPg_ -- confirm in the .cpp
- const std::string getBackendName() const override;
- c10::intrusive_ptr<ProcessGroup::Work> broadcast(
- std::vector<at::Tensor>& data,
- const BroadcastOptions& opts = BroadcastOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allreduce(
- std::vector<at::Tensor>& data,
- const AllreduceOptions& opts = AllreduceOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
- std::vector<at::Tensor>& tensors,
- const AllreduceCoalescedOptions& opts =
- AllreduceCoalescedOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> reduce(
- std::vector<at::Tensor>& tensors,
- const ReduceOptions& opts = ReduceOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> allgather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts = AllgatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
- at::Tensor& outputBuffer,
- at::Tensor& inputBuffer,
- const AllgatherOptions& opts = AllgatherOptions()) override;
- // allgather_coalesced (declared next) is deprecated and will be moved out of
- // ProcessGroup to comms:
- // * do not add dependencies on this function,
- // * do not implement it in your ProcessGroup; implement _allgather_base instead.
- c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
- std::vector<std::vector<at::Tensor>>& outputTensorLists,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts = AllgatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> gather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const GatherOptions& opts = GatherOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ScatterOptions& opts = ScatterOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
- at::Tensor& outputTensor,
- at::Tensor& inputTensor,
- std::vector<int64_t>& outputSplitSizes,
- std::vector<int64_t>& inputSplitSizes,
- const AllToAllOptions& opts = AllToAllOptions()) override;
- c10::intrusive_ptr<ProcessGroup::Work> alltoall(
- std::vector<at::Tensor>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllToAllOptions& opts = AllToAllOptions()) override;
- void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
- override;
- // Agrees on an initial sequence number for the whole group by having rank 0
- // create it and broadcast it to other ranks using the store. Only implemented
- // for GLOO and NCCL backends currently.
- // NOTE(review): the author's note here read "dont implement this"; presumably the wrapper adds no logic of its own -- confirm in the .cpp.
- void setSequenceNumberForGroup() override;
- // Retrieves the current sequence number for the whole group, which should be
- // in sync across ranks. If the returned number is not consistent across the
- // group, it may indicate some sort of collective desynchronization.
- uint64_t getSequenceNumberForGroup() override; // NOTE(review): author note "just call underlying" -- presumably forwards to the wrapped group; confirm in the .cpp
- c10::intrusive_ptr<ProcessGroup::Work> send(
- std::vector<at::Tensor>& tensors,
- int dstRank,
- int tag) override;
- c10::intrusive_ptr<ProcessGroup::Work> recv(
- std::vector<at::Tensor>& tensors,
- int srcRank,
- int tag) override;
- c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
- std::vector<at::Tensor>& tensors,
- int tag) override;
- c10::intrusive_ptr<ProcessGroup::Work> barrier(
- const BarrierOptions& opts = BarrierOptions()) override;
- c10::intrusive_ptr<ProcessGroup> getWrappedPg() const; // accessor for the wrapped process group (presumably pg_ -- confirm in the .cpp)
- private:
- // Underlying process group that the actual application collectives are
- // dispatched to.
- c10::intrusive_ptr<ProcessGroup> pg_;
- // Gloo process group responsible for internal coordination such as monitored
- // barrier, sequence number checking, and collective fingerprint collection.
- c10::intrusive_ptr<ProcessGroupGloo> glooPg_;
- // Conducts several checks to ensure that the underlying collective is well
- // formed, with the goal of notifying the user about incorrect collective use
- // in the application. Takes the collective's op type and its tensors.
- void runCollectiveChecks(
- OpType op_type,
- const std::vector<at::Tensor>& tensors) const;
- };
- } // namespace c10d
- #endif // USE_C10D_GLOO
|