ProcessGroup.hpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. #pragma once
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#include <c10d/debug.h>
#include <c10d/sequence_num.hpp>
  14. // *************************************************************************
  15. // PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
  16. // versions 1.7 and 1.8.
  17. // PLEASE DO NOT ADD ANY DEPENDENCIES.
  18. // SEE RFC: https://github.com/pytorch/pytorch/issues/39662
  19. // *************************************************************************
  20. constexpr auto kNoTimeout = std::chrono::milliseconds(0);
  21. constexpr auto kProcessGroupDefaultTimeout =
  22. std::chrono::milliseconds(30 * 60 * 1000);
  23. namespace c10d {
  24. constexpr const char* const kSeqNumStoreKey = "SEQ_NUM_STORE_KEY";
  25. enum class OpType : std::uint8_t {
  26. BROADCAST = 0,
  27. ALLREDUCE = 1,
  28. ALLREDUCE_COALESCED = 2,
  29. REDUCE = 3,
  30. ALLGATHER = 4,
  31. _ALLGATHER_BASE = 5,
  32. ALLGATHER_COALESCED = 6,
  33. GATHER = 7,
  34. SCATTER = 8,
  35. REDUCE_SCATTER = 9,
  36. ALLTOALL_BASE = 10,
  37. ALLTOALL = 11,
  38. SEND = 12,
  39. RECV = 13,
  40. RECVANYSOURCE = 14,
  41. BARRIER = 15,
  42. _REDUCE_SCATTER_BASE = 16,
  43. UNKNOWN = 100,
  44. };
  45. // Converts OpType to human readable string.
  46. TORCH_API std::string opTypeToString(OpType opType);
  47. // Whether or not an OP is an p2p op (SEND, RECV, RECVANYSOURCE)
  48. TORCH_API bool isP2POp(OpType opType, bool batchP2P = false);
  49. // ProcessGroup is a base class that captures collective and point to
  50. // point communication in a fixed set of processes.
  51. //
  52. // The functions specified in the class below describe the API alone;
  53. // implementations are provided in subclasses.
  54. //
  55. // Every function that performs I/O is executed asynchronously by a
  56. // thread pool owned by the ProcessGroup (by default). They return an
  57. // object that can be used to wait for completion or error.
  58. //
  59. // The ProcessGroup can instantiate subgroups with fewer or an equal
  60. // number of members. Implementations must take care that multiple
  61. // process groups can be used in parallel and synchronize accordingly.
  62. //
  63. // The ProcessGroup assumes a fixed set of processes. If the set
  64. // changes, existing instances must be destructed and instantiation
  65. // and initialization must start from scratch. For members of the
  66. // process group to find each other (referred to as rendezvous from
  67. // hereon)
  68. //
  69. class TORCH_API ProcessGroup : public torch::CustomClassHolder {
  70. public:
  71. // Please do not use ProcessGroup::Work API, it is going away, to be
  72. // replaced by ivalue::Future.
  73. // Python binding for this class might change, please do not assume
  74. // this will be bound using pybind.
  75. class TORCH_API Work : public torch::CustomClassHolder {
  76. public:
  77. Work(
  78. int rank = -1,
  79. OpType opType = OpType::UNKNOWN,
  80. const char* profilingTitle = nullptr,
  81. const c10::optional<std::vector<at::Tensor>>& inputTensors =
  82. c10::nullopt);
  83. virtual ~Work();
  84. // Checks if request has completed. Non-blocking operation.
  85. virtual bool isCompleted();
  86. // Returns if the work completed successfully.
  87. // If false, the exception function can be called to get details.
  88. virtual bool isSuccess() const;
  89. // Returns exception if isSuccess() returned false.
  90. virtual std::exception_ptr exception() const;
  91. // Returns source rank if this objects represents a recv-from-any.
  92. virtual int sourceRank() const;
  93. // Returns result tensors, if applicable.
  94. // If work is not supposed to have result, we return empty list.
  95. virtual std::vector<at::Tensor> result();
  96. // Ensures that operations on the output tensors that are invoked
  97. // after this function returns are correctly sequenced after the
  98. // asynchronous completion of this work.
  99. //
  100. // For CUDA tensors, it inserts stream synchronization such that
  101. // the streams of the caller wait for completion of the
  102. // asynchronous operations on the destination tensors.
  103. //
  104. // For CPU tensors, it is currently a nop.
  105. //
  106. // This function should only be used if the caller polls for
  107. // completion through the `isCompleted` function, it has returned
  108. // true, and the `isSuccess` function also has returned true.
  109. //
  110. virtual void synchronize();
  111. // Waits until request completes. Blocking operation.
  112. // Throws if the work completed with an exception.
  113. // Returns false if the work is aborted.
  114. // Otherwise, it always returns true, indicating the work is completed.
  115. //
  116. // Functionally equivalent to:
  117. //
  118. // while (!isCompleted()) { /* nop */ }
  119. // auto success = isSuccess();
  120. // if (!success) { std::rethrow_exception(exception()); }
  121. // return success;
  122. //
  123. virtual bool wait(std::chrono::milliseconds timeout = kNoTimeout);
  124. virtual void abort();
  125. // Returns a Future object that will be associated with the completion of
  126. // work. Only NCCL backend is currently supported.
  127. virtual c10::intrusive_ptr<c10::ivalue::Future> getFuture();
  128. OpType retrieveOpType();
  129. protected:
  130. // Completes the work object and optionally sets the exception in a
  131. // thread-safe manner. Notifies all waiting condition variables as well.
  132. void finish(std::exception_ptr exception = nullptr);
  133. // Similar to finish, but throws an exception if one is already set or
  134. // provided by the user.
  135. void finishAndThrow(std::exception_ptr exception);
  136. mutable std::mutex mutex_;
  137. std::condition_variable cv_;
  138. bool completed_ = false;
  139. std::exception_ptr exception_;
  140. // Current rank of the node.
  141. const int rank_;
  142. // Operation type that this work object refers to.
  143. OpType opType_;
  144. // When profiling, the callback to record end of operation event. This
  145. // callback needs to be called when collective operation is complete.
  146. std::function<void()> recordFunctionEndCallback_;
  147. };
  148. // ProcessGroup Options is a base struct that defines the basic options
  149. // when constructing a ProcessGroup. Each ProcessGroup subclass should
  150. // extend this struct and define its options if it wants to provide more
  151. // config options (beyond basic ones defined here) to end user.
  152. struct TORCH_API Options : torch::CustomClassHolder {
  153. explicit Options(
  154. std::string backend,
  155. std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout)
  156. : timeout(timeout), backend(backend) {}
  157. virtual ~Options() = default;
  158. std::chrono::milliseconds timeout;
  159. // backend name
  160. const std::string backend;
  161. };
  162. explicit ProcessGroup(int rank, int size);
  163. virtual ~ProcessGroup();
  164. int getRank() const {
  165. return rank_;
  166. }
  167. int getSize() const {
  168. return size_;
  169. }
  170. // Subclasses must override this method to return the backend name
  171. virtual const std::string getBackendName() const = 0;
  172. virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  173. std::vector<at::Tensor>& /* tensors */,
  174. const BroadcastOptions& /* opts */ = BroadcastOptions()) {
  175. TORCH_CHECK(
  176. false,
  177. c10::str(
  178. "ProcessGroup ", getBackendName(), "does not support broadcast"));
  179. }
  180. virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  181. std::vector<at::Tensor>& /* tensors */,
  182. const AllreduceOptions& /* opts */ = AllreduceOptions()) {
  183. TORCH_CHECK(
  184. false,
  185. c10::str(
  186. "ProcessGroup ", getBackendName(), "does not support allreduce"));
  187. }
  188. virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
  189. std::vector<at::Tensor>& /* tensors */,
  190. const AllreduceCoalescedOptions& /* opts */ = AllreduceCoalescedOptions()) {
  191. TORCH_CHECK(
  192. false,
  193. c10::str(
  194. "ProcessGroup ",
  195. getBackendName(),
  196. "does not support allreduce_coalesced"));
  197. }
  198. virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
  199. std::vector<at::Tensor>& /* tensors */,
  200. const ReduceOptions& /* opts */ = ReduceOptions()) {
  201. TORCH_CHECK(
  202. false,
  203. c10::str("ProcessGroup ", getBackendName(), "does not support reduce"));
  204. }
  205. virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
  206. std::vector<std::vector<at::Tensor>>& /* outputTensors */,
  207. std::vector<at::Tensor>& /* inputTensors */,
  208. const AllgatherOptions& /* opts */ = AllgatherOptions()) {
  209. TORCH_CHECK(
  210. false,
  211. c10::str(
  212. "ProcessGroup ", getBackendName(), "does not support allgather"));
  213. }
  214. // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  215. // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE.
  216. // For implementers of ProcessGroup API and advanced users only.
  217. // Note: this function will be deprecated in near future.
  218. virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
  219. at::Tensor& /* outputBuffer */,
  220. at::Tensor& /* inputBuffer */,
  221. const AllgatherOptions& /* opts */ = AllgatherOptions()) {
  222. TORCH_CHECK(
  223. false,
  224. c10::str(
  225. "ProcessGroup ",
  226. getBackendName(),
  227. "does not support _allgather_base"));
  228. }
  229. // This function is deprecated and will be moved out of ProcessGroup to comms:
  230. // * do not add dependencies on this function,
  231. // * do not implement it in your ProcessGroup, implement _allgather_base
  232. // instead.
  233. virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
  234. std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
  235. std::vector<at::Tensor>& /* inputTensors */,
  236. const AllgatherOptions& /* opts */ = AllgatherOptions()) {
  237. TORCH_CHECK(
  238. false,
  239. c10::str(
  240. "ProcessGroup ",
  241. getBackendName(),
  242. "does not support allgather_coalesced"));
  243. }
  244. virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
  245. std::vector<std::vector<at::Tensor>>& /* outputTensors */,
  246. std::vector<at::Tensor>& /* inputTensors */,
  247. const GatherOptions& /* opts */ = GatherOptions()) {
  248. TORCH_CHECK(
  249. false,
  250. c10::str("ProcessGroup ", getBackendName(), "does not support gather"));
  251. }
  252. virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
  253. std::vector<at::Tensor>& /* outputTensors */,
  254. std::vector<std::vector<at::Tensor>>& /* inputTensors */,
  255. const ScatterOptions& /* opts */ = ScatterOptions()) {
  256. TORCH_CHECK(
  257. false,
  258. c10::str(
  259. "ProcessGroup ", getBackendName(), "does not support scatter"));
  260. }
  261. virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  262. std::vector<at::Tensor>& /* outputTensors */,
  263. std::vector<std::vector<at::Tensor>>& /* inputTensors */,
  264. const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
  265. TORCH_CHECK(
  266. false,
  267. c10::str(
  268. "ProcessGroup ",
  269. getBackendName(),
  270. "does not support reduce_scatter"));
  271. }
  272. virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
  273. at::Tensor& /* outputBuffer */,
  274. at::Tensor& /* inputBuffer */,
  275. const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
  276. TORCH_CHECK(
  277. false,
  278. c10::str(
  279. "ProcessGroup ",
  280. getBackendName(),
  281. "does not support _reduce_scatter_base"));
  282. }
  283. virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
  284. at::Tensor& /* outputBuffer */,
  285. at::Tensor& /* inputBuffer */,
  286. std::vector<int64_t>& /* outputSplitSizes */,
  287. std::vector<int64_t>& /* inputSplitSizes */,
  288. const AllToAllOptions& /* opts */ = AllToAllOptions()) {
  289. TORCH_CHECK(
  290. false,
  291. c10::str(
  292. "ProcessGroup ",
  293. getBackendName(),
  294. "does not support alltoall_base"));
  295. }
  296. virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
  297. std::vector<at::Tensor>& /* outputTensors */,
  298. std::vector<at::Tensor>& /* inputTensors */,
  299. const AllToAllOptions& opts = AllToAllOptions()) {
  300. TORCH_CHECK(
  301. false,
  302. c10::str(
  303. "ProcessGroup ", getBackendName(), "does not support alltoall"));
  304. }
  305. virtual void monitoredBarrier(
  306. const BarrierOptions& /* unused */,
  307. bool /* unused */ = false) {
  308. auto backendName = getBackendName();
  309. TORCH_CHECK(
  310. false,
  311. c10::str(
  312. "ProcessGroup ",
  313. backendName,
  314. " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  315. }
  316. // Agrees on an initial sequence number for the whole group by having rank 0
  317. // create it and broadcast it to other ranks using the store. Only implemented
  318. // for GLOO and NCCL backends currently.
  319. virtual void setSequenceNumberForGroup() {
  320. auto backendName = getBackendName();
  321. TORCH_CHECK(
  322. false,
  323. c10::str(
  324. "ProcessGroup ",
  325. backendName,
  326. " does not yet support sequence numbers."));
  327. }
  328. // Retrieves the current sequence number for the whole group, which should be
  329. // in sync. If the returned number is not consistent across the group, it
  330. // may indicate that there is some sort of collective desynchronization.
  331. virtual uint64_t getSequenceNumberForGroup() {
  332. auto backendName = getBackendName();
  333. TORCH_CHECK(
  334. false,
  335. c10::str(
  336. "ProcessGroup ",
  337. backendName,
  338. " does not yet support sequence numbers."));
  339. }
  340. virtual c10::intrusive_ptr<ProcessGroup::Work> send(
  341. std::vector<at::Tensor>& /* tensors */,
  342. int /* dstRank */,
  343. int /* tag */) {
  344. TORCH_CHECK(
  345. false,
  346. c10::str("ProcessGroup ", getBackendName(), "does not support send"));
  347. }
  348. virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
  349. std::vector<at::Tensor>& /* tensors */,
  350. int /* srcRank */,
  351. int /* tag */) {
  352. TORCH_CHECK(
  353. false,
  354. c10::str("ProcessGroup ", getBackendName(), "does not support recv"));
  355. }
  356. virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
  357. std::vector<at::Tensor>& /* tensors */,
  358. int /* tag */) {
  359. TORCH_CHECK(
  360. false,
  361. c10::str(
  362. "ProcessGroup ",
  363. getBackendName(),
  364. "does not support recvAnysource"));
  365. }
  366. virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
  367. const BarrierOptions& /* opts */ = BarrierOptions()) {
  368. TORCH_CHECK(
  369. false,
  370. c10::str(
  371. "ProcessGroup ", getBackendName(), "does not support barrier"));
  372. }
  373. protected:
  374. // Implementations of this interface need to call this to setup
  375. // appropriate logging etc.
  376. void init();
  377. const int rank_;
  378. const int size_;
  379. // Optional sequence number structure for matching collectives.
  380. c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
  381. // Debug level setting. It is parsed once when ProcessGroup is constructed and
  382. // remains the same across use of this process group.
  383. DebugLevel dist_debug_level_;
  384. };
  385. } // namespace c10d