NCCLUtils.hpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. #pragma once
  2. #ifdef USE_C10D_NCCL
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <memory>
  6. #include <mutex>
  7. #include <nccl.h>
  8. #include <c10/util/Exception.h>
  9. #include <c10/util/Optional.h>
  10. namespace {
  11. // Provides additional detail into NCCL error codes based on when these are
  12. // thrown in the NCCL codebase.
  13. const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::string> processGroupFailureReason = c10::nullopt) {
  14. // Prioritize failure reason provided by PG NCCL first, as it can abort
  15. // communicators when it encounters collective timeouts, etc.
  16. if (processGroupFailureReason != c10::nullopt) {
  17. return (*processGroupFailureReason).c_str();
  18. }
  19. switch (error) {
  20. case ncclUnhandledCudaError:
  21. return "ncclUnhandledCudaError: Call to CUDA function failed.";
  22. case ncclSystemError:
  23. return "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. "
  24. "It can be also caused by unexpected exit of a remote peer, you can check NCCL warnings for failure reason and see if there is connection closure by a peer.";
  25. case ncclInternalError:
  26. return "ncclInternalError: Internal check failed. This is either a bug in NCCL or due to memory corruption";
  27. case ncclInvalidArgument:
  28. return "ncclInvalidArgument: Invalid value for an argument (such as invalid pointer, device count, ip:host pair, etc).";
  29. case ncclInvalidUsage:
  30. return "ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).";
  31. default:
  32. break;
  33. }
  34. return "Unknown NCCL error";
  35. }
  36. } // namespace
  // ncclCommAbort() and ncclCommGetAsyncError() are not supported before
  // NCCL 2.4, so error checking is switched on only for 2.4+ (including any
  // major version >= 3).
  #if defined(NCCL_MAJOR) &&                                          \
      ((NCCL_MAJOR >= 3) ||                                           \
       ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
  #define ENABLE_NCCL_ERROR_CHECKING
  #endif
  // ncclSend() and ncclRecv() are not supported before NCCL 2.7, so P2P is
  // switched on only for 2.7+ (including any major version >= 3).
  #if defined(NCCL_MAJOR) &&                                          \
      ((NCCL_MAJOR >= 3) ||                                           \
       ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
  #define ENABLE_NCCL_P2P_SUPPORT
  #endif
  // Macro to throw on a non-successful NCCL return value.
  //
  //   cmd           - expression yielding an ncclResult_t; evaluated exactly
  //                   once.
  //   failureReason - value forwarded to getNcclErrorDetailStr() as the
  //                   optional failure reason; evaluated only on failure.
  //
  // On failure the message contains the file/line of the expansion site, the
  // output of ncclGetErrorWithVersion() and the detail string, and is raised
  // through TORCH_CHECK.
  //
  // The locals use deliberately unusual names: with plain names such as
  // `result` or `err`, a `cmd` or `failureReason` expression mentioning a
  // caller variable of the same name would silently bind to the macro's own
  // (not yet initialized) local instead.
  #define C10D_NCCL_CHECK(cmd, failureReason)                                \
    do {                                                                     \
      const ncclResult_t c10dNcclCheckResult_ = (cmd);                       \
      if (c10dNcclCheckResult_ != ncclSuccess) {                             \
        std::string c10dNcclCheckErr_ = "NCCL error in: " +                  \
            std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " +  \
            ncclGetErrorWithVersion(c10dNcclCheckResult_) + "\n" +           \
            getNcclErrorDetailStr(c10dNcclCheckResult_, failureReason);      \
        TORCH_CHECK(false, c10dNcclCheckErr_);                               \
      }                                                                      \
    } while (0)
  // Macro to print and abort on a non-successful NCCL return value.
  //
  //   cmd - expression yielding an ncclResult_t; evaluated exactly once.
  //
  // On failure the file/line of the expansion site and the output of
  // ncclGetErrorWithVersion() are written to stderr and abort() is called, so
  // no exception is thrown from this macro itself.
  //
  // The locals use deliberately unusual names so that a `cmd` expression
  // mentioning a caller variable called `result` or `err` cannot be captured
  // by the macro's own locals.
  #define C10D_NCCL_ASSERT(cmd)                                 \
    do {                                                        \
      const ncclResult_t c10dNcclAssertResult_ = (cmd);         \
      if (c10dNcclAssertResult_ != ncclSuccess) {               \
        const std::string c10dNcclAssertErr_ =                  \
            ncclGetErrorWithVersion(c10dNcclAssertResult_);     \
        fprintf(                                                \
            stderr,                                             \
            "NCCL error in: %s:%d, %s\n",                       \
            __FILE__,                                           \
            __LINE__,                                           \
            c10dNcclAssertErr_.c_str());                        \
        abort();                                                \
      }                                                         \
    } while (0)
  79. namespace c10d {
  80. std::string getNcclVersion();
  81. std::string ncclGetErrorWithVersion(ncclResult_t error);
  82. // RAII wrapper for NCCL communicator
  83. class NCCLComm {
  84. public:
  85. explicit NCCLComm(ncclComm_t ncclComm)
  86. : ncclComm_(ncclComm),
  87. aborted_(false),
  88. ncclAsyncErr_(ncclSuccess),
  89. commFailureReason_(c10::nullopt) {}
  90. NCCLComm() : NCCLComm(nullptr) {}
  91. ~NCCLComm() noexcept {
  92. // Add lock in this destructor, as aborted_ needs to be read after memory
  93. // barrier here.
  94. std::unique_lock<std::mutex> lock(mutex_);
  95. if (ncclComm_ && !aborted_) {
  96. #ifdef ENABLE_NCCL_ERROR_CHECKING
  97. // Use ncclCommAbort instead of ncclCommDestroy here since
  98. // ncclCommDestroy could block forever waiting for work to complete on
  99. // the communicator.
  100. C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
  101. #else
  102. C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
  103. #endif
  104. }
  105. }
  106. static std::shared_ptr<NCCLComm> create(
  107. int numRanks,
  108. int rank,
  109. ncclUniqueId commId) {
  110. auto comm = std::make_shared<NCCLComm>();
  111. C10D_NCCL_CHECK(
  112. ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
  113. comm->ncclId_ = commId;
  114. comm->rank_ = rank;
  115. return comm;
  116. }
  117. ncclUniqueId getNcclId() {
  118. return ncclId_;
  119. }
  120. // Must not be copyable
  121. NCCLComm(const NCCLComm&) = delete;
  122. NCCLComm& operator=(const NCCLComm&) = delete;
  123. // Do not support move assignment as there is no valid use case
  124. NCCLComm& operator=(NCCLComm&& other) = delete;
  125. // Move constructable
  126. NCCLComm(NCCLComm&& other) {
  127. // Using other's lock, as it reads other's states
  128. // Can not use this.mutex_, as this object is being constructed.
  129. std::unique_lock<std::mutex> lock(other.mutex_);
  130. std::swap(ncclComm_, other.ncclComm_);
  131. std::swap(aborted_, other.aborted_);
  132. std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
  133. }
  134. ncclComm_t getNcclComm();
  135. c10::optional<std::string> getNcclCommFailureReason() const {
  136. std::unique_lock<std::mutex> lock(mutex_);
  137. return commFailureReason_;
  138. }
  139. void ncclCommAbort(
  140. c10::optional<std::string> commFailureReason = c10::nullopt) {
  141. std::unique_lock<std::mutex> lock(mutex_);
  142. #ifdef ENABLE_NCCL_ERROR_CHECKING
  143. if (aborted_) {
  144. // Should not abort twice.
  145. return;
  146. }
  147. // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
  148. // timeout)
  149. commFailureReason_ = commFailureReason;
  150. C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
  151. aborted_ = true;
  152. ncclComm_ = nullptr;
  153. // Set an appropriate error so that we avoid using the communicator.
  154. if (ncclAsyncErr_ == ncclSuccess) {
  155. ncclAsyncErr_ = ncclSystemError;
  156. }
  157. #else
  158. // This is a NOOP, if error checks are disabled.
  159. return;
  160. #endif
  161. }
  162. bool isAborted() const {
  163. std::unique_lock<std::mutex> lock(mutex_);
  164. return aborted_;
  165. }
  166. ncclResult_t checkForNcclError() {
  167. std::unique_lock<std::mutex> lock(mutex_);
  168. #ifdef ENABLE_NCCL_ERROR_CHECKING
  169. if (ncclAsyncErr_ != ncclSuccess) {
  170. return ncclAsyncErr_;
  171. }
  172. C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
  173. return ncclAsyncErr_;
  174. #else
  175. // Always return success, if error checks are disabled.
  176. return ncclSuccess;
  177. #endif
  178. }
  179. protected:
  180. ncclComm_t ncclComm_;
  181. // Unique nccl_id for this communicator.
  182. ncclUniqueId ncclId_;
  183. bool aborted_;
  184. ncclResult_t ncclAsyncErr_;
  185. mutable std::mutex mutex_;
  186. // Rank that this communicator corresponds to.
  187. int rank_;
  188. // Optional reason for communicator failure, provided by ProcessGroupNCCL for
  189. // better error messaging.
  190. c10::optional<std::string> commFailureReason_;
  191. };
  192. } // namespace c10d
  193. #endif // USE_C10D_NCCL