#pragma once #ifdef USE_C10D_NCCL #include #include #include #include #include #include #include namespace { // Provides additional detail into NCCL error codes based on when these are // thrown in the NCCL codebase. const inline char* getNcclErrorDetailStr(ncclResult_t error, c10::optional processGroupFailureReason = c10::nullopt) { // Prioritize failure reason provided by PG NCCL first, as it can abort // communicators when it encounters collective timeouts, etc. if (processGroupFailureReason != c10::nullopt) { return (*processGroupFailureReason).c_str(); } switch (error) { case ncclUnhandledCudaError: return "ncclUnhandledCudaError: Call to CUDA function failed."; case ncclSystemError: return "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. " "It can be also caused by unexpected exit of a remote peer, you can check NCCL warnings for failure reason and see if there is connection closure by a peer."; case ncclInternalError: return "ncclInternalError: Internal check failed. This is either a bug in NCCL or due to memory corruption"; case ncclInvalidArgument: return "ncclInvalidArgument: Invalid value for an argument (such as invalid pointer, device count, ip:host pair, etc)."; case ncclInvalidUsage: return "ncclInvalidUsage: This usually reflects invalid usage of NCCL library (such as too many async ops, too many collectives at once, mixing streams in a group, etc)."; default: break; } return "Unknown NCCL error"; } } // namespace // Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort() // and ncclCommGetAsyncError() are not supported in earlier versions. #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 4) #define ENABLE_NCCL_ERROR_CHECKING #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) #define ENABLE_NCCL_ERROR_CHECKING #endif // P2P is enabled only for NCCL versions 2.7+ since ncclSend() // and ncclRecv() are not supported in earlier versions. #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 7) #define ENABLE_NCCL_P2P_SUPPORT #elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3) #define ENABLE_NCCL_P2P_SUPPORT #endif // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ ncclResult_t result = cmd; \ if (result != ncclSuccess) { \ std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ "\n" + getNcclErrorDetailStr(result, failureReason); \ TORCH_CHECK(false, err); \ } \ } while (0) // Macro to print and abort on a non-successful NCCL return value. #define C10D_NCCL_ASSERT(cmd) \ do { \ ncclResult_t result = cmd; \ if (result != ncclSuccess) { \ std::string err = ncclGetErrorWithVersion(result); \ fprintf( \ stderr, \ "NCCL error in: %s:%d, %s\n", \ __FILE__, \ __LINE__, \ err.c_str()); \ abort(); \ } \ } while (0) namespace c10d { std::string getNcclVersion(); std::string ncclGetErrorWithVersion(ncclResult_t error); // RAII wrapper for NCCL communicator class NCCLComm { public: explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm), aborted_(false), ncclAsyncErr_(ncclSuccess), commFailureReason_(c10::nullopt) {} NCCLComm() : NCCLComm(nullptr) {} ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory // barrier here. std::unique_lock lock(mutex_); if (ncclComm_ && !aborted_) { #ifdef ENABLE_NCCL_ERROR_CHECKING // Use ncclCommAbort instead of ncclCommDestroy here since // ncclCommDestroy could block forever waiting for work to complete on // the communicator. C10D_NCCL_ASSERT(::ncclCommAbort(ncclComm_)); #else C10D_NCCL_ASSERT(::ncclCommDestroy(ncclComm_)); #endif } } static std::shared_ptr create( int numRanks, int rank, ncclUniqueId commId) { auto comm = std::make_shared(); C10D_NCCL_CHECK( ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt); comm->ncclId_ = commId; comm->rank_ = rank; return comm; } ncclUniqueId getNcclId() { return ncclId_; } // Must not be copyable NCCLComm(const NCCLComm&) = delete; NCCLComm& operator=(const NCCLComm&) = delete; // Do not support move assignment as there is no valid use case NCCLComm& operator=(NCCLComm&& other) = delete; // Move constructable NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. std::unique_lock lock(other.mutex_); std::swap(ncclComm_, other.ncclComm_); std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); } ncclComm_t getNcclComm(); c10::optional getNcclCommFailureReason() const { std::unique_lock lock(mutex_); return commFailureReason_; } void ncclCommAbort( c10::optional commFailureReason = c10::nullopt) { std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (aborted_) { // Should not abort twice. return; } // Set true failure reason if provided by ProcessGroupNCCL (e.g. work // timeout) commFailureReason_ = commFailureReason; C10D_NCCL_CHECK(::ncclCommAbort(ncclComm_), commFailureReason_); aborted_ = true; ncclComm_ = nullptr; // Set an appropriate error so that we avoid using the communicator. if (ncclAsyncErr_ == ncclSuccess) { ncclAsyncErr_ = ncclSystemError; } #else // This is a NOOP, if error checks are disabled. return; #endif } bool isAborted() const { std::unique_lock lock(mutex_); return aborted_; } ncclResult_t checkForNcclError() { std::unique_lock lock(mutex_); #ifdef ENABLE_NCCL_ERROR_CHECKING if (ncclAsyncErr_ != ncclSuccess) { return ncclAsyncErr_; } C10D_NCCL_CHECK(ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_); return ncclAsyncErr_; #else // Always return success, if error checks are disabled. return ncclSuccess; #endif } protected: ncclComm_t ncclComm_; // Unique nccl_id for this communicator. ncclUniqueId ncclId_; bool aborted_; ncclResult_t ncclAsyncErr_; mutable std::mutex mutex_; // Rank that this communicator corresponds to. int rank_; // Optional reason for communicator failure, provided by ProcessGroupNCCL for // better error messaging. c10::optional commFailureReason_; }; } // namespace c10d #endif // USE_C10D_NCCL