#pragma once

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

// *************************************************************************
// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
// versions 1.7 and 1.8.
// PLEASE DO NOT ADD ANY DEPENDENCIES.
// SEE RFC: https://github.com/pytorch/pytorch/issues/39662
// *************************************************************************

constexpr auto kNoTimeout = std::chrono::milliseconds(0);
constexpr auto kProcessGroupDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

constexpr const char* const kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";

enum class OpType : std::uint8_t {
  BROADCAST = 0,
  ALLREDUCE = 1,
  ALLREDUCE_COALESCED = 2,
  REDUCE = 3,
  ALLGATHER = 4,
  _ALLGATHER_BASE = 5,
  ALLGATHER_COALESCED = 6,
  GATHER = 7,
  SCATTER = 8,
  REDUCE_SCATTER = 9,
  ALLTOALL_BASE = 10,
  ALLTOALL = 11,
  SEND = 12,
  RECV = 13,
  RECVANYSOURCE = 14,
  BARRIER = 15,
  _REDUCE_SCATTER_BASE = 16,
  UNKNOWN = 100,
};

// Converts OpType to a human-readable string.
TORCH_API std::string opTypeToString(OpType opType);

// Whether or not an op is a p2p op (SEND, RECV, RECVANYSOURCE).
TORCH_API bool isP2POp(OpType opType, bool batchP2P = false);

// ProcessGroup is a base class that captures collective and point to
// point communication in a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). They return an
// object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// hereon)
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  // Please do not use ProcessGroup::Work API, it is going away, to be
  // replaced by ivalue::Future.
  // Python binding for this class might change, please do not assume
  // this will be bound using pybind.
  class TORCH_API Work : public torch::CustomClassHolder {
   public:
    Work(
        int rank = -1,
        OpType opType = OpType::UNKNOWN,
        const char* profilingTitle = nullptr,
        const c10::optional<std::vector<at::Tensor>>& inputTensors =
            c10::nullopt);

    virtual ~Work();

    // Checks if request has completed. Non-blocking operation.
    virtual bool isCompleted();

    // Returns whether the work completed successfully.
    // If false, the exception function can be called to get details.
    virtual bool isSuccess() const;

    // Returns exception if isSuccess() returned false.
    virtual std::exception_ptr exception() const;

    // Returns source rank if this object represents a recv-from-any.
    virtual int sourceRank() const;

    // Returns result tensors, if applicable.
    // If the work is not supposed to have a result, an empty list is
    // returned.
    virtual std::vector<at::Tensor> result();

    // Ensures that operations on the output tensors that are invoked
    // after this function returns are correctly sequenced after the
    // asynchronous completion of this work.
    //
    // For CUDA tensors, it inserts stream synchronization such that
    // the streams of the caller wait for completion of the
    // asynchronous operations on the destination tensors.
    //
    // For CPU tensors, it is currently a nop.
    //
    // This function should only be used if the caller polls for
    // completion through the `isCompleted` function, it has returned
    // true, and the `isSuccess` function also has returned true.
    //
    virtual void synchronize();

    // Waits until request completes. Blocking operation.
    // Throws if the work completed with an exception.
    // Returns false if the work is aborted.
    // Otherwise, it always returns true, indicating the work is completed.
    //
    // Functionally equivalent to:
    //
    //   while (!isCompleted()) { /* nop */ }
    //   auto success = isSuccess();
    //   if (!success) { std::rethrow_exception(exception()); }
    //   return success;
    //
    virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);

    virtual void abort();

    // Returns a Future object that will be associated with the completion of
    // work. Only NCCL backend is currently supported.
    virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();

    OpType retrieveOpType();

   protected:
    // Completes the work object and optionally sets the exception in a
    // thread-safe manner. Notifies all waiting condition variables as well.
    void finish(std::exception_ptr exception = nullptr);

    // Similar to finish, but throws an exception if one is already set or
    // provided by the user.
    void finishAndThrow(std::exception_ptr exception);

    mutable std::mutex mutex_;
    std::condition_variable cv_;
    bool completed_ = false;
    std::exception_ptr exception_;

    // Current rank of the node.
    const int rank_;

    // Operation type that this work object refers to.
    OpType opType_;

    // When profiling, the callback to record end of operation event. This
    // callback needs to be called when collective operation is complete.
    std::function<void()> recordFunctionEndCallback_;
  };

  // ProcessGroup Options is a base struct that defines the basic options
  // when constructing a ProcessGroup. Each ProcessGroup subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond basic ones defined here) to end user.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
        : timeout(timeout), backend(backend) {}
    virtual ~Options() = default;

    std::chrono::milliseconds timeout;

    // backend name
    const std::string backend;
  };

  explicit ProcessGroup(int rank, int size);
  virtual ~ProcessGroup();

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  // Subclasses must override this method to return the backend name
  virtual const std::string getBackendName() const = 0;

  virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support _allgather_base"));
  }

  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& opts = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            backendName,
            " does not yet support sequence numbers."));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("ProcessGroup ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("ProcessGroup ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ",
            getBackendName(),
            " does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "ProcessGroup ", getBackendName(), " does not support barrier"));
  }

 protected:
  // Implementations of this interface need to call this to set up
  // appropriate logging etc.
  void init();

  const int rank_;
  const int size_;
  // Optional sequence number structure for matching collectives.
  c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
  // Debug level setting. It is parsed once when ProcessGroup is constructed and
  // remains the same across use of this process group.
  DebugLevel dist_debug_level_;
};

} // namespace c10d
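/*
Illustrative sketch only, not part of the c10d API. The comments on
Work::wait() and Work::finish() above describe a completion protocol built on
a mutex, a condition variable, a completion flag and a stored exception. The
standalone snippet below shows one way such a protocol might look using only
the standard library. `SketchWork` and `finishThenWait` are hypothetical names
introduced for this example. Throwing on timeout is an assumption of the
sketch, as is treating a zero timeout as "wait indefinitely" (mirroring
kNoTimeout).

```cpp
#include <cassert>
#include <chrono>
#include <condition_variable>
#include <exception>
#include <mutex>
#include <stdexcept>

// Hypothetical stand-in for ProcessGroup::Work (standard library only).
class SketchWork {
 public:
  // Non-blocking completion check.
  bool isCompleted() {
    std::lock_guard<std::mutex> lock(mutex_);
    return completed_;
  }

  // True when no exception has been recorded.
  bool isSuccess() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return !exception_;
  }

  // Blocks until finish() has been called, then rethrows any stored
  // exception. A zero timeout means "no timeout" in this sketch.
  bool wait(std::chrono::milliseconds timeout = std::chrono::milliseconds(0)) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (timeout == std::chrono::milliseconds(0)) {
      cv_.wait(lock, [&] { return completed_; });
    } else if (!cv_.wait_for(lock, timeout, [&] { return completed_; })) {
      throw std::runtime_error("operation timed out");
    }
    if (exception_) {
      std::rethrow_exception(exception_);
    }
    return true;
  }

  // Marks the work as completed, optionally records an exception, and
  // wakes up every waiter.
  void finish(std::exception_ptr exception = nullptr) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      completed_ = true;
      exception_ = exception;
    }
    cv_.notify_all();
  }

 private:
  mutable std::mutex mutex_;
  std::condition_variable cv_;
  bool completed_ = false;
  std::exception_ptr exception_;
};

// Completes a SketchWork (with an error when `fail` is set) and reports
// whether wait() returned normally.
inline bool finishThenWait(bool fail) {
  SketchWork work;
  std::exception_ptr error;
  if (fail) {
    error = std::make_exception_ptr(std::runtime_error("collective failed"));
  }
  work.finish(error);
  try {
    return work.wait();
  } catch (const std::runtime_error&) {
    return false;
  }
}
```

In a real backend the call to finish() would come from the thread that runs
the collective, while callers block in wait() or poll isCompleted().
*/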