| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443 |
- #pragma once
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#include <c10d/debug.h>
#include <c10d/sequence_num.hpp>
- // *************************************************************************
- // PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
- // versions 1.7 and 1.8.
- // PLEASE DO NOT ADD ANY DEPENDENCIES.
- // SEE RFC: https://github.com/pytorch/pytorch/issues/39662
- // *************************************************************************
// Zero-length timeout, named to mean "no timeout". It is the default
// argument of ProcessGroup::Work::wait(); how a zero value is interpreted is
// up to the Work implementation.
constexpr std::chrono::milliseconds kNoTimeout{0};

// Default timeout used by ProcessGroup::Options: 30 minutes, held as
// milliseconds (1'800'000 ms).
constexpr std::chrono::milliseconds kProcessGroupDefaultTimeout{
    std::chrono::minutes(30)};
- namespace c10d {
// Store key string for the group's sequence number.
// NOTE(review): presumably used by setSequenceNumberForGroup() when rank 0
// publishes the initial sequence number through the store; confirm in the .cpp.
constexpr char const* const kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";
// Kind of operation a ProcessGroup::Work object refers to. Stored in one
// byte; every enumerator carries an explicit value, and UNKNOWN (100) is the
// default used by the Work constructor.
enum class OpType : std::uint8_t {
  // Collectives.
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,
  // Point-to-point operations.
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,
  // Synchronization.
  BARRIER = 15,
  // Collective added after BARRIER; its value keeps the numbering stable.
  _REDUCE_SCATTER_BASE = 16,
  // Fallback when the operation type is not specified.
  UNKNOWN = 100,
};
- // Converts OpType to human readable string.
- TORCH_API std::string opTypeToString(OpType opType);
- // Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE)
- TORCH_API bool isP2POp(OpType opType, bool batchP2P = false);
- // ProcessGroup is a base class that captures collective and point to
- // point communication in a fixed set of processes.
- //
- // The functions specified in the class below describe the API alone;
- // implementations are provided in subclasses.
- //
- // Every function that performs I/O is executed asynchronously by a
- // thread pool owned by the ProcessGroup (by default). They return an
- // object that can be used to wait for completion or error.
- //
- // The ProcessGroup can instantiate subgroups with fewer or an equal
- // number of members. Implementations must take care that multiple
- // process groups can be used in parallel and synchronize accordingly.
- //
- // The ProcessGroup assumes a fixed set of processes. If the set
- // changes, existing instances must be destructed and instantiation
- // and initialization must start from scratch. For members of the
- // process group to find each other (referred to as rendezvous from
- // hereon)
- //
- class TORCH_API ProcessGroup : public torch::CustomClassHolder {
- public:
- // Please do not use ProcessGroup::Work API, it is going away, to be
- // replaced by ivalue::Future.
- // Python binding for this class might change, please do not assume
- // this will be bound using pybind.
- class TORCH_API Work : public torch::CustomClassHolder {
- public:
- Work(
- int rank = -1,
- OpType opType = OpType::UNKNOWN,
- const char* profilingTitle = nullptr,
- const c10::optional<std::vector<at::Tensor>>& inputTensors =
- c10::nullopt);
- virtual ~Work();
- // Checks if request has completed. Non-blocking operation.
- virtual bool isCompleted();
- // Returns if the work completed successfully.
- // If false, the exception function can be called to get details.
- virtual bool isSuccess() const;
- // Returns exception if isSuccess() returned false.
- virtual std::exception_ptr exception() const;
- // Returns source rank if this objects represents a recv-from-any.
- virtual int sourceRank() const;
- // Returns result tensors, if applicable.
- // If work is not supposed to have result, we return empty list.
- virtual std::vector<at::Tensor> result();
- // Ensures that operations on the output tensors that are invoked
- // after this function returns are correctly sequenced after the
- // asynchronous completion of this work.
- //
- // For CUDA tensors, it inserts stream synchronization such that
- // the streams of the caller wait for completion of the
- // asynchronous operations on the destination tensors.
- //
- // For CPU tensors, it is currently a nop.
- //
- // This function should only be used if the caller polls for
- // completion through the `isCompleted` function, it has returned
- // true, and the `isSuccess` function also has returned true.
- //
- virtual void synchronize();
- // Waits until request completes. Blocking operation.
- // Throws if the work completed with an exception.
- // Returns false if the work is aborted.
- // Otherwise, it always returns true, indicating the work is completed.
- //
- // Functionally equivalent to:
- //
- // while (!isCompleted()) { /* nop */ }
- // auto success = isSuccess();
- // if (!success) { std::rethrow_exception(exception()); }
- // return success;
- //
- virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);
- virtual void abort();
- // Returns a Future object that will be associated with the completion of
- // work. Only NCCL backend is currently supported.
- virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
- OpType retrieveOpType();
- protected:
- // Completes the work object and optionally sets the exception in a
- // thread-safe manner. Notifies all waiting condition variables as well.
- void finish(std::exception_ptr exception = nullptr);
- // Similar to finish, but throws an exception if one is already set or
- // provided by the user.
- void finishAndThrow(std::exception_ptr exception);
- mutable std::mutex mutex_;
- std::condition_variable cv_;
- bool completed_ = false;
- std::exception_ptr exception_;
- // Current rank of the node.
- const int rank_;
- // Operation type that this work object refers to.
- OpType opType_;
- // When profiling, the callback to record end of operation event. This
- // callback needs to be called when collective operation is complete.
- std::function<void()> recordFunctionEndCallback_;
- };
- // ProcessGroup Options is a base struct that defines the basic options
- // when constructing a ProcessGroup. Each ProcessGroup subclass should
- // extend this struct and define its options if it wants to provide more
- // config options (beyond basic ones defined here) to end user.
- struct TORCH_API Options : torch::CustomClassHolder {
- explicit Options(
- std::string backend,
- std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
- : timeout(timeout), backend(backend) {}
- virtual ~Options() = default;
- std::chrono::milliseconds timeout;
- // backend name
- const std::string backend;
- };
- explicit ProcessGroup(int rank, int size);
- virtual ~ProcessGroup();
- int getRank() const {
- return rank_;
- }
- int getSize() const {
- return size_;
- }
- // Subclasses must override this method to return the backend name
- virtual const std::string getBackendName() const = 0;
- virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
- std::vector<at::Tensor>& /* tensors */,
- const BroadcastOptions& /* opts */ = BroadcastOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ", getBackendName(), "does not support broadcast"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
- std::vector<at::Tensor>& /* tensors */,
- const AllreduceOptions& /* opts */ = AllreduceOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ", getBackendName(), "does not support allreduce"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
- std::vector<at::Tensor>& /* tensors */,
- const AllreduceCoalescedOptions& /* opts */ = AllreduceCoalescedOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support allreduce_coalesced"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
- std::vector<at::Tensor>& /* tensors */,
- const ReduceOptions& /* opts */ = ReduceOptions()) {
- TORCH_CHECK(
- false,
- c10::str("ProcessGroup ", getBackendName(), "does not support reduce"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
- std::vector<std::vector<at::Tensor>>& /* outputTensors */,
- std::vector<at::Tensor>& /* inputTensors */,
- const AllgatherOptions& /* opts */ = AllgatherOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ", getBackendName(), "does not support allgather"));
- }
- // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
- // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE.
- // For implementers of ProcessGroup API and advanced users only.
- // Note: this function will be deprecated in near future.
- virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
- at::Tensor& /* outputBuffer */,
- at::Tensor& /* inputBuffer */,
- const AllgatherOptions& /* opts */ = AllgatherOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support _allgather_base"));
- }
- // This function is deprecated and will be moved out of ProcessGroup to comms:
- // * do not add dependencies on this function,
- // * do not implement it in your ProcessGroup, implement _allgather_base
- // instead.
- virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
- std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
- std::vector<at::Tensor>& /* inputTensors */,
- const AllgatherOptions& /* opts */ = AllgatherOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support allgather_coalesced"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
- std::vector<std::vector<at::Tensor>>& /* outputTensors */,
- std::vector<at::Tensor>& /* inputTensors */,
- const GatherOptions& /* opts */ = GatherOptions()) {
- TORCH_CHECK(
- false,
- c10::str("ProcessGroup ", getBackendName(), "does not support gather"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
- std::vector<at::Tensor>& /* outputTensors */,
- std::vector<std::vector<at::Tensor>>& /* inputTensors */,
- const ScatterOptions& /* opts */ = ScatterOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ", getBackendName(), "does not support scatter"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
- std::vector<at::Tensor>& /* outputTensors */,
- std::vector<std::vector<at::Tensor>>& /* inputTensors */,
- const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support reduce_scatter"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
- at::Tensor& /* outputBuffer */,
- at::Tensor& /* inputBuffer */,
- const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support _reduce_scatter_base"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
- at::Tensor& /* outputBuffer */,
- at::Tensor& /* inputBuffer */,
- std::vector<int64_t>& /* outputSplitSizes */,
- std::vector<int64_t>& /* inputSplitSizes */,
- const AllToAllOptions& /* opts */ = AllToAllOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support alltoall_base"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
- std::vector<at::Tensor>& /* outputTensors */,
- std::vector<at::Tensor>& /* inputTensors */,
- const AllToAllOptions& opts = AllToAllOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ", getBackendName(), "does not support alltoall"));
- }
- virtual void monitoredBarrier(
- const BarrierOptions& /* unused */,
- bool /* unused */ = false) {
- auto backendName = getBackendName();
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- backendName,
- " does not support monitoredBarrier, only GLOO supports monitored barrier."));
- }
- // Agrees on an initial sequence number for the whole group by having rank 0
- // create it and broadcast it to other ranks using the store. Only implemented
- // for GLOO and NCCL backends currently.
- virtual void setSequenceNumberForGroup() {
- auto backendName = getBackendName();
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- backendName,
- " does not yet support sequence numbers."));
- }
- // Retrieves the current sequence number for the whole group, which should be
- // in sync. If the returned number is not consistent across the group, it
- // may indicate that there is some sort of collective desynchronization.
- virtual uint64_t getSequenceNumberForGroup() {
- auto backendName = getBackendName();
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- backendName,
- " does not yet support sequence numbers."));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> send(
- std::vector<at::Tensor>& /* tensors */,
- int /* dstRank */,
- int /* tag */) {
- TORCH_CHECK(
- false,
- c10::str("ProcessGroup ", getBackendName(), "does not support send"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
- std::vector<at::Tensor>& /* tensors */,
- int /* srcRank */,
- int /* tag */) {
- TORCH_CHECK(
- false,
- c10::str("ProcessGroup ", getBackendName(), "does not support recv"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
- std::vector<at::Tensor>& /* tensors */,
- int /* tag */) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ",
- getBackendName(),
- "does not support recvAnysource"));
- }
- virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
- const BarrierOptions& /* opts */ = BarrierOptions()) {
- TORCH_CHECK(
- false,
- c10::str(
- "ProcessGroup ", getBackendName(), "does not support barrier"));
- }
- protected:
- // Implementations of this interface need to call this to setup
- // appropriate logging etc.
- void init();
- const int rank_;
- const int size_;
- // Optional sequence number structure for matching collectives.
- c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
- // Debug level setting. It is parsed once when ProcessGroup is constructed and
- // remains the same across use of this process group.
- DebugLevel dist_debug_level_;
- };
- } // namespace c10d
|