ProcessGroupNCCL.hpp 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. #pragma once
  2. #ifdef USE_C10D_NCCL
  3. #include <chrono>
  4. #include <iostream>
  5. #include <list>
  6. #include <mutex>
  7. #include <thread>
  8. #include <unordered_map>
  9. #include <c10d/NCCLUtils.hpp>
  10. #include <c10d/ProcessGroup.hpp>
  11. #include <c10d/Store.hpp>
  12. #include <c10d/UCCForNCCL.hpp>
  13. #include <ATen/DynamicLibrary.h>
  14. #include <ATen/cuda/CUDAContext.h>
  15. #include <ATen/cuda/CUDAEvent.h>
  16. #include <c10/core/Stream.h>
  17. #include <c10/core/StreamGuard.h>
  18. #include <c10/cuda/CUDACachingAllocator.h>
  19. #include <c10/cuda/CUDAGuard.h>
  20. #include <c10/cuda/CUDAStream.h>
  21. #include <torch/custom_class.h>
  22. namespace c10d {
  // Environment variable which controls whether we perform a NCCL health check
  // which ensures communicators are healthy at the beginning of init.
  constexpr const char* ENABLE_NCCL_HEALTH_CHECK = "ENABLE_NCCL_HEALTH_CHECK";

  // Environment variable which controls whether or not wait() is blocking or
  // non-blocking.
  constexpr const char* NCCL_BLOCKING_WAIT = "NCCL_BLOCKING_WAIT";

  // Environment variable which controls whether or not we perform Async Error
  // Handling with NCCL.
  constexpr const char* NCCL_ASYNC_ERROR_HANDLING = "NCCL_ASYNC_ERROR_HANDLING";

  // Environment variable to control whether Desync Debug is enabled.
  // This variable must be set together with NCCL_ASYNC_ERROR_HANDLING.
  constexpr const char* NCCL_DESYNC_DEBUG = "NCCL_DESYNC_DEBUG";

  // Backend name string; returned by ProcessGroupNCCL::getBackendName().
  constexpr const char* NCCL_BACKEND_NAME = "nccl";
  36. // ProcessGroupNCCL implements NCCL bindings for c10d.
  37. //
  38. // All functions of the class are expected to be called in the same order
  39. // across all processes in the process group. This is the only way that we
  40. // can guarantee to match up the same calls among all processes.
  41. //
  42. // All NCCL functions provided by this class are asynchronous functions. More
  43. // specifically, each NCCL call is scheduled on a separate CUDA stream that is
  44. // different from the current CUDA stream. This is for the purpose of
  45. // achieving potentially concurrency and better performance. As a result,
  46. // it is the callers' responsibility to make sure that the CUDA stream their
  47. // code works on needs to wait for the NCCL operation from
  48. // this class.
  49. //
  50. // This can be done by calling:
  51. //
  52. // either WorkNCCL::wait() or WorkNCCL::synchronize(), both achieves the same
  53. // functionality and are synonyms.
  54. //
  55. // Also note that WorkNCCL::finishedGPUExecution() is a helper function only
  56. // provided by ProcessGroupNCCL to check if the NCCL operation of WorkNCCL has
  57. // finished execution on the GPU (not just scheduled).
  58. //
  59. // Example on using the NCCL process group
  60. //
  61. // ProcessGroupNCCL pg(store, rank, size);
  62. // std::shared_ptr<WorkNCCL> work = pg.allreduce(tensors);
  63. //
  64. // // At this point, NCCL kernel has already by queued successfully
  65. // // Now, let current stream wait for the NCCL to finish, this function is
  66. // // async operation as well
  67. //
  68. // work->wait()
  69. //
  70. // // Now continue on other work in the current stream.
  71. class TORCH_API ProcessGroupNCCL : public ProcessGroup {
  72. public:
  73. class WorkNCCL : public ProcessGroup::Work,
  74. public std::enable_shared_from_this<WorkNCCL> {
  75. public:
  76. // Constructor takes a list of CUDA devices
  77. WorkNCCL(
  78. const std::vector<at::Device>& devices,
  79. int rank,
  80. OpType opType,
  81. uint64_t seq,
  82. const char* profilingTitle = nullptr,
  83. const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt,
  84. bool desyncDebug = false);
  85. // Copy constructor doing partial copy without outputs_. Cleanup thread
  86. // monitors and removes finished works. However it will deadlock when
  87. // destructs outputs_ tensors who are view tensors in autograd graph.
  88. WorkNCCL(const WorkNCCL& w);
  89. virtual ~WorkNCCL();
  90. // Checks if the NCCL kernel has started to execute.
  91. bool isStarted();
  92. // Checks if request has completed. In this specific case of NCCL, it checks
  93. // if the NCCL operation has completed on the GPU in its own NCCL stream.
  94. // Non-blocking operation.
  95. bool isCompleted() override;
  96. bool isSuccess() const override;
  97. // Same as calling synchronize() for NCCL work.
  98. bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
  99. void abort() override;
  100. // Let current stream wait on the completing of the NCCL work
  101. // Throws on exceptions. Blocking operation, which will wait for work
  102. // completion.
  103. void synchronize() override;
  104. // Synchronize streams by blocking each on the NCCL stream
  105. void synchronizeStreams();
  106. // Helper function used in CUDA Stream callbacks to complete WorkNCCL
  107. // objects and throw exceptions when neeeded.
  108. void handleNCCLGuard();
  109. // Helper function that checks if the NCCL kernels have finished
  110. // execution on the GPUs
  111. bool finishedGPUExecution();
  112. // Get a Future object that will be marked as completed internally.
  113. c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
  114. // Helper function that sets an exception_ptr on the WorkNCCL object.
  115. void setException(std::exception_ptr exception_ptr);
  116. // Helper function that returns True if the WorkNCCL object has timed out
  117. // and False otherwise.
  118. bool timedOut();
  119. std::vector<at::Tensor> result() override;
  120. protected:
  121. // The cached list of CUDA devices to operate on
  122. std::vector<at::Device> devices_;
  123. // The start CUDA events of NCCL operator tracking this work item on
  124. // multiple CUDA devices. These start CUDA events are needed by desync
  125. // debugging if enabled.
  126. std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclStartEvents_;
  127. // The end CUDA events of NCCL operator tracking this work item on
  128. // multiple CUDA devices.
  129. std::shared_ptr<std::vector<at::cuda::CUDAEvent>> ncclEndEvents_;
  130. // The NCCL communicators used for this work item.
  131. std::vector<std::shared_ptr<NCCLComm>> ncclComms_;
  132. // Tensors used for barrier op
  133. std::vector<at::Tensor> barrierTensors_;
  134. // Clone of blockingWait_ from ProcessGroupNCCL.
  135. bool blockingWait_ = false;
  136. // Clone of opTimeout_ from ProcessGroupNCCL.
  137. std::chrono::milliseconds opTimeout_;
  138. // Time point representing when the work started.
  139. std::chrono::time_point<std::chrono::steady_clock> workStartTime_;
  140. // Record the collective sequential number.
  141. uint64_t seq_;
  142. // Indicates if the nccl start event has been updated to the store trace.
  143. // This will be used by desync debug.
  144. bool startTraceUpdated_{false};
  145. // Wrapper method for the static checkForNCCLErrors which can be overridden
  146. // for tests.
  147. virtual std::exception_ptr checkForNCCLErrors(
  148. const std::vector<std::shared_ptr<NCCLComm>>& ncclComms) const;
  149. friend std::ostream& operator<<(
  150. std::ostream& output,
  151. const WorkNCCL& workNCCL);
  152. private:
  153. // Helper function for synchronize
  154. void synchronizeInternal(std::chrono::milliseconds timeout);
  155. // Checks for NCCL errors and sets an appropriate exception_ptr.
  156. void checkAndSetException();
  157. // Checks for NCCL errors and throws an appropriate exception.
  158. void checkAndThrowException();
  159. // Just checks whether GPU execution has started, without modifying
  160. // exception_ptr.
  161. bool startedGPUExecutionInternal() const;
  162. // Just checks whether GPU execution has completed, without modifying
  163. // exception_ptr.
  164. bool finishedGPUExecutionInternal() const;
  165. // Reference to the store so that we can write aborted communicators
  166. // to the store.
  167. c10::intrusive_ptr<Store> store_;
  168. // Store a reference to NCCL collective's outputs, used by result and to
  169. // give a more descriptive message when representing the Work as a string.
  170. std::shared_ptr<std::vector<at::Tensor>> outputs_;
  171. // The future returned by getFuture.
  172. c10::intrusive_ptr<at::ivalue::Future> future_;
  173. friend class ProcessGroupNCCL;
  174. };
  175. struct Options : ProcessGroup::Options {
  176. // NOTE: timeout in ProcessGroupNCCL::Options denote the timeout for
  177. // operations. This is only used when blockingWait_ is enabled.
  178. explicit Options(
  179. bool is_high_priority_stream = false);
  180. // return intrusive_ptr of the object
  181. static c10::intrusive_ptr<Options> create(
  182. bool is_high_priority_stream = false) {
  183. return c10::make_intrusive<Options>(is_high_priority_stream);
  184. }
  185. // Schedule NCCL operations on high priority CUDA streams
  186. bool is_high_priority_stream;
  187. };
  188. // If you wish to create multiple process groups, each with a potentially
  189. // different rank and size, you can do so by passing a new store instance
  190. // to each one. If you have only a single store object, you can
  191. // use the `c10d::PrefixStore` to derive scoped instances.
  192. // This is also what the Python API in torch.distributed does.
  193. //
  194. // The process group instance keeps a reference to the store because
  195. // it may be used long after the constructor runs. In fact, the constructor
  196. // doesn't create any NCCL communicators. A single NCCL communicator can
  197. // only be used on a specific set of devices, and are therefore created
  198. // on-demand when a collective runs. If another collective is executed later,
  199. // against a different set of devices, the process group creates another NCCL
  200. // communicator. These NCCL communicators are cached and reused if possible.
  201. //
  202. ProcessGroupNCCL(
  203. const c10::intrusive_ptr<Store>& store,
  204. int rank,
  205. int size,
  206. c10::intrusive_ptr<Options> options = Options::create());
  207. // This constructor includes the deprecated `groupName` argument.
  208. // If you have existing code that uses the `groupName`, you can replace
  209. // it by specifying a `c10d::PrefixStore(groupName, store)` for store.
  210. C10_DEPRECATED ProcessGroupNCCL(
  211. const c10::intrusive_ptr<Store>& store,
  212. int rank,
  213. int size,
  214. const std::string& groupName,
  215. c10::intrusive_ptr<Options> options = Options::create())
  216. : ProcessGroupNCCL(store, rank, size, options) {}
  217. virtual ~ProcessGroupNCCL();
  218. c10::intrusive_ptr<Options> getOptions() {
  219. return options_;
  220. }
  221. const std::string getBackendName() const override {
  222. return std::string(NCCL_BACKEND_NAME);
  223. }
  224. c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  225. std::vector<at::Tensor>& tensors,
  226. const BroadcastOptions& opts = BroadcastOptions()) override;
  227. c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  228. std::vector<at::Tensor>& tensors,
  229. const AllreduceOptions& opts = AllreduceOptions()) override;
  230. c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
  231. std::vector<at::Tensor>& tensors,
  232. const AllreduceCoalescedOptions& opts =
  233. AllreduceCoalescedOptions()) override;
  234. c10::intrusive_ptr<ProcessGroup::Work> reduce(
  235. std::vector<at::Tensor>& tensors,
  236. const ReduceOptions& opts = ReduceOptions()) override;
  237. c10::intrusive_ptr<ProcessGroup::Work> allgather(
  238. std::vector<std::vector<at::Tensor>>& outputTensors,
  239. std::vector<at::Tensor>& inputTensors,
  240. const AllgatherOptions& opts = AllgatherOptions()) override;
  241. c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
  242. at::Tensor& outputbuffer,
  243. at::Tensor& inputbuffer,
  244. const AllgatherOptions& opts = AllgatherOptions()) override;
  245. c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
  246. std::vector<std::vector<at::Tensor>>& outputTensorLists,
  247. std::vector<at::Tensor>& inputTensors,
  248. const AllgatherOptions& opts = AllgatherOptions()) override;
  249. c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  250. std::vector<at::Tensor>& outputTensors,
  251. std::vector<std::vector<at::Tensor>>& inputTensors,
  252. const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
  253. c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
  254. at::Tensor& outputTensor,
  255. at::Tensor& inputTensor,
  256. const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
  257. c10::intrusive_ptr<ProcessGroup::Work> barrier(
  258. const BarrierOptions& opts = BarrierOptions()) override;
  259. c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
  260. at::Tensor& outputTensor,
  261. at::Tensor& inputTensor,
  262. std::vector<int64_t>& outputSplitSizes,
  263. std::vector<int64_t>& inputSplitSizes,
  264. const AllToAllOptions& opts = AllToAllOptions()) override;
  265. c10::intrusive_ptr<ProcessGroup::Work> alltoall(
  266. std::vector<at::Tensor>& outputTensors,
  267. std::vector<at::Tensor>& inputTensors,
  268. const AllToAllOptions& opts = AllToAllOptions()) override;
  269. c10::intrusive_ptr<ProcessGroup::Work> send(
  270. std::vector<at::Tensor>& tensors,
  271. int dstRank,
  272. int tag) override;
  273. c10::intrusive_ptr<ProcessGroup::Work> recv(
  274. std::vector<at::Tensor>& tensors,
  275. int srcRank,
  276. int tag) override;
  277. static void groupStart();
  278. static void groupEnd();
  279. // Unsupported Ops
  280. c10::intrusive_ptr<ProcessGroup::Work> gather(
  281. std::vector<std::vector<at::Tensor>>& outputTensors,
  282. std::vector<at::Tensor>& inputTensors,
  283. const GatherOptions& opts = GatherOptions()) override;
  284. c10::intrusive_ptr<ProcessGroup::Work> scatter(
  285. std::vector<at::Tensor>& outputTensors,
  286. std::vector<std::vector<at::Tensor>>& inputTensors,
  287. const ScatterOptions& opts = ScatterOptions()) override;
  288. c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
  289. std::vector<at::Tensor>& tensors,
  290. int tag) override;
  291. // Agrees on an initial sequence number for the whole group by having rank 0
  292. // create it and broadcast it to other ranks using the store.
  293. void setSequenceNumberForGroup() override;
  294. // Retrieves the current sequence number for the whole group, which should be
  295. // in sync. If the returned number is not consistent across the group, it
  296. // may indicate that there is some sort of collective desynchronization.
  297. uint64_t getSequenceNumberForGroup() override;
  298. // Tests if the UCC fallback path is available
  299. bool isUCCAvailable() const;
  300. protected:
  301. // Helper that broadcasts nccl unique ID to all ranks through the store
  302. void broadcastUniqueNCCLID(
  303. ncclUniqueId* ncclID,
  304. bool isSingleP2POp,
  305. const std::string& devicesKey,
  306. int p2pRank);
  307. // Helper that either looks up the cached NCCL communicators or creates
  308. // a new set of NCCL communicators as a cache entry
  309. std::vector<std::shared_ptr<NCCLComm>>& getNCCLComm(
  310. const std::string& devicesKey,
  311. const std::vector<at::Device>& devices,
  312. OpType opType,
  313. int p2pRank = 0,
  314. bool isSendRecvSelf = false);
  315. // Wrapper method which can be overridden for tests.
  316. virtual std::exception_ptr checkForNCCLErrors(
  317. const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
  318. virtual c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
  319. std::vector<at::Device> devices,
  320. int rank,
  321. OpType opType,
  322. const char* profilingTitle=nullptr,
  323. const c10::optional<std::vector<at::Tensor>>& inputs = c10::nullopt);
  324. private:
  325. // Helper that encapsulates work shared across all collective communication
  326. // primitives. The callbacks have the following signatures:
  327. //
  328. // ncclResult_t fn(at::Tensor& input, at::Tensor& output,
  329. // ncclComm_t, at::cuda::CUDAStream&);
  330. // void {pre,post}(std::vector<at::cuda::CUDAStream&>);
  331. template <typename Fn>
  332. c10::intrusive_ptr<ProcessGroup::Work> collective(
  333. std::vector<at::Tensor>& input,
  334. std::vector<at::Tensor>& output,
  335. Fn fn,
  336. OpType opType,
  337. const char* profilingTitle = nullptr);
  338. template <typename Fn, typename PreProcess, typename PostProcess>
  339. c10::intrusive_ptr<ProcessGroup::Work> collective(
  340. std::vector<at::Tensor>& input,
  341. std::vector<at::Tensor>& output,
  342. Fn fn,
  343. PreProcess pre,
  344. PostProcess post,
  345. OpType opType,
  346. const char* profilingTitle = nullptr);
  347. // Helper that encapsulates work shared across point-to-point communication
  348. // primitives. It is the same structure as the helper used for collective
  349. // communicaiton primitives.
  350. template <typename Fn>
  351. c10::intrusive_ptr<ProcessGroup::Work> pointToPoint(
  352. std::vector<at::Tensor>& tensor,
  353. Fn fn,
  354. int peer,
  355. OpType opType,
  356. const char* profilingTitle = nullptr);
  357. template <typename Fn, typename PreProcess, typename PostProcess>
  358. c10::intrusive_ptr<ProcessGroup::Work> pointToPoint(
  359. std::vector<at::Tensor>& tensor,
  360. Fn fn,
  361. int peer,
  362. OpType opType,
  363. PreProcess pre,
  364. PostProcess post,
  365. const char* profilingTitle);
  366. c10::intrusive_ptr<ProcessGroup::Work> allreduce_impl(
  367. std::vector<at::Tensor>& tensors,
  368. const AllreduceOptions& opts = AllreduceOptions());
  369. // Checks for NCCL errors on each of the communicators and returns an
  370. // appropriate exception_ptr (nullptr if no errors).
  371. static std::exception_ptr checkForNCCLErrorsInternal(
  372. const std::vector<std::shared_ptr<NCCLComm>>& ncclComms);
  373. // Function that runs as part of a separate thread and checks for errors on
  374. // NCCL communicators. We need a separate thread to check for NCCL errors
  375. // since we can't rely on the user calling certain methods like wait(),
  376. // isCompleted() etc. to detect and remediate errors. In addition to this, we
  377. // need a mechanism to safely abort and remove NCCL communicators from our
  378. // cache. This can be done cleanly by having a thread for the ProcessGroupNCCL
  379. // class. Attempting to modify the communicator cache from the WorkNCCL class
  380. // might run into issues with object lifetime since the ProcessGroupNCCL
  381. // object might get destroyed before the WorkNCCL object.
  382. void ncclCommWatchdog();
  383. void ncclCommWatchdogInternal();
  384. // This function iterates through the list of WorkNCCL objects in the
  385. // workList_ corresponding to incomplete collectives and then aborts NCCL
  386. // communicators associated with timed out collectives.
  387. void abortTimedOutCollectives(
  388. std::unordered_set<std::string>& abortedCommIds);
  389. // Performs a health check by initializing dummy NCCL communicators and then
  390. // destroying them. This will help indicate and signal any NCCL-related issues
  391. // prior to the first collective. The actual initialization and subsequent
  392. // destruction is ran on a separate thread and the main thread is signalled
  393. // about timeouts/errors to report to the application.
  394. void runHealthCheck();
  395. // Destroys initialized NCCL communicators in devNCCLComMap_ given by input
  396. // key. Throws if there are no communicators to destroy. Also removes
  397. // communicators from the cache and clears used device indices.
  398. void destroyNCCLComms(const std::string& devNCCLCommMapKey);
  399. void workCleanupLoop();
  400. protected:
  401. static const int64_t kWatchdogThreadSleepMillis;
  402. static const int64_t kWorkCleanupThreadSleepMillis;
  403. // The store is used to broadcast the NCCL unique ID of rank 0.
  404. c10::intrusive_ptr<Store> store_;
  405. bool storeError_{false};
  406. const c10::intrusive_ptr<Options> options_;
  407. // The number of NCCL communicators that have been created during
  408. // the lifetime of this process group. This sequence number is
  409. // used to scope keys used in the store.
  410. uint64_t ncclCommCounter_{0};
  411. // The store keys to trace the last NCCL collective kernel CUDA events - start
  412. // event and end event respectively. These are used to do desync root cause
  413. // analysis.
  414. const std::string traceKeyStart_;
  415. const std::string traceKeyEnd_;
  416. // The NCCL communicator that the process group has cached.
  417. //
  418. // For collective operations:
  419. // The key is a list of GPU devices that an operation is operating on
  420. // The GPU devices are stored in a device sequence and the cache NCCL
  421. // communicator is associated with this GPU device sequence
  422. //
  423. // e.g. If the process group op only uses device 0, then the value of
  424. // the used device string stored (value of the hashmap) would be "0".
  425. //
  426. // If the process group op uses device 0 - 7 and the each tensor of the
  427. // input tensor list is on device, 0, 1, 2, 3, 4, 5, 6, 7 separately,
  428. // then the value of the used device string (key) stored would be
  429. // "0,1,2,3,4,5,6,7"
  430. //
  431. // If the process group op uses device 0 - 7 and the each tensor of the
  432. // input tensor list is on device, 0, 4, 5, 6, 7, 1, 2, 3 separately,
  433. // then the value of the used device string stored would be
  434. // "0,4,5,6,7,1,2,3"
  435. //
  436. // Note that the order of the device for the tensor list matters.
  437. //
  438. // For point-to-point operations:
  439. // The key is a string of my current rank and the peer process rank.
  440. // e.g. If process 1 and process 2 are involved in a point-to-point
  441. // communication, the key will be "1:2" on both processes. Note: this is for
  442. // the scenario where there is only 1 GPU per process. When it comes to
  443. // multiple GPUs per process, this part may need to redesigned.
  444. std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
  445. devNCCLCommMap_;
  446. // Map from ncclUniqueId to appropriate communicator.
  447. std::unordered_map<std::string, std::vector<std::shared_ptr<NCCLComm>>>
  448. ncclIdToCommMap_;
  449. // Mutex to guard maps like devNCCLCommMap_ and ncclIdToCommMap_.
  450. std::mutex mutex_;
  451. // Watchdog thread which looks for errors on the cached NCCL communicators.
  452. std::thread ncclCommWatchdogThread_;
  453. // Whether or not we should terminate the watchdog and workCleanup threads.
  454. std::atomic<bool> terminateProcessGroup_;
  455. // Condition variable to control how long the watchdog thread waits.
  456. std::condition_variable watchdogCV_;
  457. // Mutex for watchdog.
  458. std::mutex watchdogCVMutex_;
  459. // Thread that removes NCCL Work upon timeout
  460. std::thread workCleanupThread_;
  461. // Mutex to Guard workMetaList_
  462. std::mutex workMetaListMutex_;
  463. // Condition Variable for timeout thread sleep
  464. std::condition_variable workMetaListCV_;
  465. // Vector to Store WorkNCCL pointers
  466. std::list<ProcessGroupNCCL::WorkNCCL> workMetaList_;
  467. // Add Work Pointer to workVector
  468. void workEnqueue(c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>);
  469. // The CUDA steams used by NCCL kernels
  470. std::unordered_map<std::string, std::vector<at::cuda::CUDAStream>>
  471. ncclStreams_;
  472. // The CUDA events used to sync NCCL streams
  473. std::unordered_map<std::string, std::vector<at::cuda::CUDAEvent>> ncclEvents_;
  474. // Device Indexes used for all collectives in this group
  475. std::set<int> usedDeviceIdxs_;
  476. // map from the key: "group name + pg counter (ID)" to the
  477. // unique NCCL ID count. This needs to be group and pg specific
  478. //
  479. // For each process group, we need a uniform unique NCCL ID counter to ensure
  480. // that NCCL operation in this process group can be completed successfully.
  481. // Since each process group ID belongs to a group name, the key to this map
  482. // is a combination of group name and ProcessGroupNCCL ID.
  483. static std::unordered_map<std::string, ssize_t> pgUniqueNCCLIDCnt_;
  484. // map from group name to the pg counter (ID) within that group
  485. //
  486. // For each group with the "group name" (which is the key), we need to
  487. // keep track of a unique process group ID when creating a new
  488. // ProcessGroupNCCL for this "group name". Therefore, the value of this
  489. // map keeps the unique ProcessGroupNCCL's ID for a specific group with
  490. // the "group name". The reason we need a per-group process group ID counter
  491. // is that different group can have different ranks and we need ensure that
  492. // each group has its own uniform process group ID for all its ranks.
  493. static std::unordered_map<std::string, ssize_t> processGroupCounterMap_;
  494. // Whether or not wait() and synchronize() are blocking operations that wait
  495. // for the operation to complete.
  496. bool blockingWait_ = false;
  497. // Whether or not the workCleanupThread is used to perform async error
  498. // handling.
  499. bool asyncErrorHandling_ = false;
  500. // Whether or not to enable timeout root cause analysis.
  501. bool desyncDebug_;
  502. // Set of communicators that this process group has aborted and their
  503. // ncclUniqueId has been written to the store. We don't need a lock
  504. // for this map since only the watchdog thread accesses this set. The
  505. // set contains the string representation of ncclUniqueId.
  506. std::unordered_set<std::string> abortedComms_;
  507. // The number of active ncclGroupStart() calls. This counter will be increased
  508. // by 1 when ncclGroupStart() is called and decreased by 1 when ncclGroupEnd()
  509. // is called.
  510. static thread_local uint64_t ncclActiveGroupCounter_;
  511. // Counting for the sequential number of NCCL collective call.
  512. uint64_t seq_{0};
  513. #ifdef USE_NCCL_WITH_UCC
  514. // ProcessGroupUCC shared library handle and ProcessGroup pointer
  515. static std::shared_ptr<at::DynamicLibrary> uccLib_;
  516. c10::intrusive_ptr<ProcessGroup> uccPG_;
  517. #endif
  518. };
  519. } // namespace c10d
  520. #endif // USE_C10D_NCCL