| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- #pragma once
- #ifdef USE_C10D_NCCL
- #include <stdio.h>
- #include <stdlib.h>
- #include <memory>
- #include <mutex>
- #include <nccl.h>
- #include <c10/util/Exception.h>
- #include <c10/util/Optional.h>
- namespace {
- // Provides additional detail into NCCL error codes based on when these are
- // thrown in the NCCL codebase.
- const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional<std::string> processGroupFailureReason = c10::nullopt) {
- // Prioritize failure reason provided by PG NCCL first, as it can abort
- // communicators when it encounters collective timeouts, etc.
- if (processGroupFailureReason != c10::nullopt) {
- return (*processGroupFailureReason).c_str();
- }
- switch (error) {
- case ncclUnhandledCudaError:
- return "ncclUnhandledCudaError: Call to CUDA function failed.";
- case ncclSystemError:
- return "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. "
- "It can be also caused by unexpected exit of a remote peer, you can check NCCL warnings for failure reason and see if there is connection closure by a peer.";
- case ncclInternalError:
- return "ncclInternalError: Internal check failed. This is either a bug in NCCL or due to memory corruption";
- case ncclInvalidArgument:
- return "ncclInvalidArgument: Invalid value for an argument (such as invalid pointer, device count, ip:host pair, etc).";
- case ncclInvalidUsage:
- return "ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc).";
- default:
- break;
- }
- return "Unknown NCCL error";
- }
- } // namespace
// ncclCommAbort() and ncclCommGetAsyncError() are not supported before
// NCCL 2.4, so error checking is compiled in only for NCCL 2.4+ (which
// includes every release with a major version of 3 or higher).
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 4)))
#define ENABLE_NCCL_ERROR_CHECKING
#endif
// ncclSend() and ncclRecv() are not supported before NCCL 2.7, so
// point-to-point support is compiled in only for NCCL 2.7+ (which includes
// every release with a major version of 3 or higher).
#if defined(NCCL_MAJOR) && \
    ((NCCL_MAJOR >= 3) ||  \
     ((NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 7)))
#define ENABLE_NCCL_P2P_SUPPORT
#endif
// Macro to throw on a non-successful NCCL return value.
//
//   cmd           - expression yielding an ncclResult_t; evaluated exactly
//                   once.
//   failureReason - forwarded as the second argument of
//                   getNcclErrorDetailStr(); pass c10::nullopt when there is
//                   no extra context.
//
// On failure the message contains the call site (__FILE__:__LINE__), the
// output of ncclGetErrorWithVersion() and the detail string, and is raised
// through TORCH_CHECK(false, ...).
//
// The macro's locals use long, macro-specific names so that they cannot
// capture an identifier appearing in `cmd` or `failureReason`. With generic
// names such as `result` or `err`, a caller passing its own variable of the
// same name would silently refer to the macro's (not yet initialized) local
// instead. ncclGetErrorWithVersion and getNcclErrorDetailStr are looked up
// unqualified at the expansion site.
#define C10D_NCCL_CHECK(cmd, failureReason)                               \
  do {                                                                    \
    const ncclResult_t c10dNcclCheckResult_ = (cmd);                      \
    if (c10dNcclCheckResult_ != ncclSuccess) {                            \
      std::string c10dNcclCheckMsg_ = "NCCL error in: " +                 \
          std::string(__FILE__) + ":" + std::to_string(__LINE__) + ", " + \
          ncclGetErrorWithVersion(c10dNcclCheckResult_) + "\n" +          \
          getNcclErrorDetailStr(c10dNcclCheckResult_, failureReason);     \
      TORCH_CHECK(false, c10dNcclCheckMsg_);                              \
    }                                                                     \
  } while (0)
// Macro to print and abort on a non-successful NCCL return value.
//
//   cmd - expression yielding an ncclResult_t; evaluated exactly once.
//
// On failure it writes "NCCL error in: <file>:<line>, <error text>" to stderr
// (the error text comes from ncclGetErrorWithVersion()) and then calls
// abort(). Because it never throws, it can be used where exceptions must not
// escape, such as a noexcept destructor.
//
// The macro's locals use long, macro-specific names so that they cannot
// capture an identifier appearing in `cmd`; with a generic name such as
// `result`, a caller expression mentioning its own `result` would silently
// refer to the macro's (not yet initialized) local instead.
#define C10D_NCCL_ASSERT(cmd)                             \
  do {                                                    \
    const ncclResult_t c10dNcclAssertResult_ = (cmd);     \
    if (c10dNcclAssertResult_ != ncclSuccess) {           \
      const std::string c10dNcclAssertMsg_ =              \
          ncclGetErrorWithVersion(c10dNcclAssertResult_); \
      fprintf(                                            \
          stderr,                                         \
          "NCCL error in: %s:%d, %s\n",                   \
          __FILE__,                                       \
          __LINE__,                                       \
          c10dNcclAssertMsg_.c_str());                    \
      abort();                                            \
    }                                                     \
  } while (0)
- namespace c10d {
- std::string getNcclVersion();
- std::string ncclGetErrorWithVersion(ncclResult_t error);
- // RAII wrapper for NCCL communicator
- class NCCLComm {
- public:
- explicit NCCLComm(ncclComm_t ncclComm)
- : ncclComm_(ncclComm),
- aborted_(false),
- ncclAsyncErr_(ncclSuccess),
- commFailureReason_(c10::nullopt) {}
- NCCLComm() : NCCLComm(nullptr) {}
- ~NCCLComm() noexcept {
- // Add lock in this destructor, as aborted_ needs to be read after memory
- // barrier here.
- std::unique_lock<std::mutex> lock(mutex_);
- if (ncclComm_ && !aborted_) {
- #ifdef ENABLE_NCCL_ERROR_CHECKING
- // Use ncclCommAbort instead of ncclCommDestroy here since
- // ncclCommDestroy could block forever waiting for work to complete on
- // the communicator.
- C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_));
- #else
- C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_));
- #endif
- }
- }
- static std::shared_ptr<NCCLComm> create(
- int numRanks,
- int rank,
- ncclUniqueId commId) {
- auto comm = std::make_shared<NCCLComm>();
- C10D_NCCL_CHECK(
- ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
- comm->ncclId_ = commId;
- comm->rank_ = rank;
- return comm;
- }
- ncclUniqueId getNcclId() {
- return ncclId_;
- }
- // Must not be copyable
- NCCLComm(const NCCLComm&) = delete;
- NCCLComm& operator=(const NCCLComm&) = delete;
- // Do not support move assignment as there is no valid use case
- NCCLComm& operator=(NCCLComm&& other) = delete;
- // Move constructable
- NCCLComm(NCCLComm&& other) {
- // Using other's lock, as it reads other's states
- // Can not use this.mutex_, as this object is being constructed.
- std::unique_lock<std::mutex> lock(other.mutex_);
- std::swap(ncclComm_, other.ncclComm_);
- std::swap(aborted_, other.aborted_);
- std::swap(ncclAsyncErr_, other.ncclAsyncErr_);
- }
- ncclComm_t getNcclComm();
- c10::optional<std::string> getNcclCommFailureReason() const {
- std::unique_lock<std::mutex> lock(mutex_);
- return commFailureReason_;
- }
- void ncclCommAbort(
- c10::optional<std::string> commFailureReason = c10::nullopt) {
- std::unique_lock<std::mutex> lock(mutex_);
- #ifdef ENABLE_NCCL_ERROR_CHECKING
- if (aborted_) {
- // Should not abort twice.
- return;
- }
- // Set true failure reason if provided by ProcessGroupNCCL (e.g. work
- // timeout)
- commFailureReason_ = commFailureReason;
- C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_);
- aborted_ = true;
- ncclComm_ = nullptr;
- // Set an appropriate error so that we avoid using the communicator.
- if (ncclAsyncErr_ == ncclSuccess) {
- ncclAsyncErr_ = ncclSystemError;
- }
- #else
- // This is a NOOP, if error checks are disabled.
- return;
- #endif
- }
- bool isAborted() const {
- std::unique_lock<std::mutex> lock(mutex_);
- return aborted_;
- }
- ncclResult_t checkForNcclError() {
- std::unique_lock<std::mutex> lock(mutex_);
- #ifdef ENABLE_NCCL_ERROR_CHECKING
- if (ncclAsyncErr_ != ncclSuccess) {
- return ncclAsyncErr_;
- }
- C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
- return ncclAsyncErr_;
- #else
- // Always return success, if error checks are disabled.
- return ncclSuccess;
- #endif
- }
- protected:
- ncclComm_t ncclComm_;
- // Unique nccl_id for this communicator.
- ncclUniqueId ncclId_;
- bool aborted_;
- ncclResult_t ncclAsyncErr_;
- mutable std::mutex mutex_;
- // Rank that this communicator corresponds to.
- int rank_;
- // Optional reason for communicator failure, provided by ProcessGroupNCCL for
- // better error messaging.
- c10::optional<std::string> commFailureReason_;
- };
- } // namespace c10d
- #endif // USE_C10D_NCCL
|