ProcessGroupMPI.hpp 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. #pragma once
  2. #ifdef USE_C10D_MPI
  3. #include <condition_variable>
  4. #include <deque>
  5. #include <exception>
  6. #include <memory>
  7. #include <mutex>
  8. #include <thread>
  9. #include <vector>
  10. #include <ATen/core/ivalue.h>
  11. #include <ATen/core/ivalue_inl.h>
  12. #include <c10d/ProcessGroup.hpp>
  13. #include <c10d/Types.hpp>
  14. #include <c10d/Utils.hpp>
  15. #include <mpi.h>
  16. namespace c10d {
  17. constexpr const char* MPI_BACKEND_NAME = "mpi";
  18. // WorkEntry is the state associated with a single MPI run instance.
  19. // It includes the source Tensor list and destination Tensor list, as well as
  20. // the actual run function that will operate on either src or dst or both.
  21. struct WorkEntry {
  22. explicit WorkEntry(
  23. std::vector<at::Tensor>* srcPtr,
  24. std::vector<at::Tensor>* dstPtr,
  25. std::function<void(std::unique_ptr<WorkEntry>&)> run)
  26. : dst(dstPtr ? *dstPtr : std::vector<at::Tensor>()),
  27. run(std::move(run)) {
  28. if (srcPtr) {
  29. src = *srcPtr;
  30. }
  31. }
  32. // Not copyable
  33. WorkEntry(const WorkEntry&) = delete;
  34. // Not copy assignable
  35. WorkEntry& operator=(const WorkEntry&) = delete;
  36. // For input and output tensors (in-place), we will always use src
  37. std::vector<at::Tensor> src;
  38. // Copy of user provided outputs.
  39. const std::vector<at::Tensor> dst;
  40. // src rank returned, for recv only
  41. int* srcRank = nullptr;
  42. std::function<void(std::unique_ptr<WorkEntry>&)> run;
  43. };
  44. // ProcessGroupMPI implements MPI bindings for c10d.
  45. //
  46. // All functions on this class are expected to be called in the same
  47. // order across processes in the group. This is the only way that we
  48. // can guarantee to match up the same calls across processes.
  49. //
  50. // All MPI functions provided by this class are asynchronously scheduled on a
  51. // worker thread. Therefore, ProcessGroupMPI requires the MPI implementation
  52. // that is used to have a minimum thread support value of MPI_THREAD_SERIALIZED.
  53. // That is, the process may be multi-threaded, and multiple threads may make
  54. // MPI calls, but only one at a time: MPI calls are not made concurrently from
  55. // two distinct threads (all MPI calls are serialized). However, with
  56. // MPI_THREAD_SERIALIZED, ProcessGroupMPI will only support a single process
  57. // group. In other words, no more than 1 process group can be created globally.
  58. //
  59. // If you would like to use multiple ProcessGroupMPI instances, it requires your
  60. // MPI implementation to have a thread support value of MPI_THREAD_MULTIPLE,
  61. // that is, multiple threads may call MPI, with no restriction.
  62. //
  63. // Also note that ProcessGroupMPI only supports a single Tensor operation. In
  64. // other words, the size of the input Tensor vector should always be 1.
  65. //
  66. // CUDA tensors can be supported if the MPI used is CUDA-aware MPI, and
  67. // ProcessGroupMPI will automatically detect this support.
  68. class TORCH_API ProcessGroupMPI : public ProcessGroup {
  69. public:
  70. class WorkMPI : public ProcessGroup::Work {
  71. public:
  72. explicit WorkMPI(
  73. std::vector<at::Tensor> outputTensors,
  74. const char* profilingTitle = nullptr,
  75. const c10::optional<std::vector<at::Tensor>>& inputTensors =
  76. c10::nullopt)
  77. : ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle, inputTensors),
  78. outputTensors_(std::move(outputTensors)),
  79. future_(c10::make_intrusive<at::ivalue::Future>(
  80. c10::ListType::create(c10::TensorType::get()))) {}
  81. std::vector<at::Tensor> result() override;
  82. c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
  83. protected:
  84. friend class ProcessGroupMPI;
  85. private:
  86. void finishWorkMPI();
  87. void finishWorkMPIError(std::exception_ptr eptr);
  88. std::vector<at::Tensor> outputTensors_;
  89. c10::intrusive_ptr<at::ivalue::Future> future_;
  90. };
  91. class AsyncWork : public ProcessGroup::Work {
  92. public:
  93. AsyncWork(
  94. MPI_Request request,
  95. std::vector<at::Tensor> outputTensors,
  96. const char* profilingTitle = nullptr,
  97. const c10::optional<std::vector<at::Tensor>>& inputTensors =
  98. c10::nullopt);
  99. virtual ~AsyncWork();
  100. bool isCompleted() override;
  101. bool isSuccess() const override;
  102. int sourceRank() const override;
  103. bool wait(std::chrono::milliseconds timeout = kUnsetTimeout) override;
  104. void abort() override;
  105. std::vector<at::Tensor> result() override;
  106. protected:
  107. void populateException();
  108. private:
  109. const std::vector<at::Tensor> outputTensors_;
  110. MPI_Request request_;
  111. MPI_Status status_;
  112. };
  113. // Constructor will spawn up the worker thread loop
  114. explicit ProcessGroupMPI(int rank, int size, MPI_Comm pgComm);
  115. virtual ~ProcessGroupMPI();
  116. // Abort the MPI program, needs to be called when exception is detected
  117. void abort();
  118. const std::string getBackendName() const override {
  119. return std::string(MPI_BACKEND_NAME);
  120. }
  121. c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  122. std::vector<at::Tensor>& data,
  123. const BroadcastOptions& opts = BroadcastOptions()) override;
  124. c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  125. std::vector<at::Tensor>& tensors,
  126. const AllreduceOptions& opts = AllreduceOptions()) override;
  127. c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
  128. std::vector<at::Tensor>& tensors,
  129. const AllreduceCoalescedOptions& opts =
  130. AllreduceCoalescedOptions()) override;
  131. c10::intrusive_ptr<ProcessGroup::Work> reduce(
  132. std::vector<at::Tensor>& tensors,
  133. const ReduceOptions& opts = ReduceOptions()) override;
  134. c10::intrusive_ptr<ProcessGroup::Work> allgather(
  135. std::vector<std::vector<at::Tensor>>& outputTensors,
  136. std::vector<at::Tensor>& inputTensors,
  137. const AllgatherOptions& opts = AllgatherOptions()) override;
  138. c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
  139. at::Tensor& outputbuffer,
  140. at::Tensor& inputbuffer,
  141. const AllgatherOptions& opts = AllgatherOptions()) override;
  142. c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
  143. std::vector<std::vector<at::Tensor>>& outputTensorLists,
  144. std::vector<at::Tensor>& inputTensors,
  145. const AllgatherOptions& opts = AllgatherOptions()) override;
  146. c10::intrusive_ptr<ProcessGroup::Work> gather(
  147. std::vector<std::vector<at::Tensor>>& outputTensors,
  148. std::vector<at::Tensor>& inputTensors,
  149. const GatherOptions& opts = GatherOptions()) override;
  150. c10::intrusive_ptr<ProcessGroup::Work> scatter(
  151. std::vector<at::Tensor>& outputTensors,
  152. std::vector<std::vector<at::Tensor>>& inputTensors,
  153. const ScatterOptions& opts = ScatterOptions()) override;
  154. c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  155. std::vector<at::Tensor>& outputTensors,
  156. std::vector<std::vector<at::Tensor>>& inputTensors,
  157. const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
  158. c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
  159. at::Tensor& outputTensor,
  160. at::Tensor& inputTensor,
  161. std::vector<int64_t>& outputSplitSizes,
  162. std::vector<int64_t>& inputSplitSizes,
  163. const AllToAllOptions& opts = AllToAllOptions()) override;
  164. c10::intrusive_ptr<ProcessGroup::Work> alltoall(
  165. std::vector<at::Tensor>& outputTensors,
  166. std::vector<at::Tensor>& inputTensors,
  167. const AllToAllOptions& opts = AllToAllOptions()) override;
  168. c10::intrusive_ptr<ProcessGroup::Work> send(
  169. std::vector<at::Tensor>& tensors,
  170. int dstRank,
  171. int tag) override;
  172. c10::intrusive_ptr<ProcessGroup::Work> recv(
  173. std::vector<at::Tensor>& tensors,
  174. int srcRank,
  175. int tag) override;
  176. c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
  177. std::vector<at::Tensor>& tensor,
  178. int tag) override;
  179. c10::intrusive_ptr<ProcessGroup::Work> barrier(
  180. const BarrierOptions& opts = BarrierOptions()) override;
  181. // Creating a new ProcessGroupMPI, will initiialize MPI if not initialized
  182. static c10::intrusive_ptr<ProcessGroupMPI> createProcessGroupMPI(
  183. std::vector<int> ranks = {});
  184. protected:
  185. using WorkType =
  186. std::tuple<std::unique_ptr<WorkEntry>, c10::intrusive_ptr<WorkMPI>>;
  187. // Worker thread loop
  188. void runLoop();
  189. // Helper function that is called by the destructor
  190. void destroy();
  191. c10::intrusive_ptr<ProcessGroup::Work> enqueue(
  192. std::unique_ptr<WorkEntry> entry,
  193. const char* profilingTitle = nullptr,
  194. const c10::optional<std::vector<at::Tensor>>& inputTensors = c10::nullopt);
  195. bool stop_;
  196. std::mutex pgMutex_;
  197. std::thread workerThread_;
  198. std::deque<WorkType> queue_;
  199. std::condition_variable queueProduceCV_;
  200. std::condition_variable queueConsumeCV_;
  201. // Global states
  202. static void initMPIOnce();
  203. static void mpiExit();
  204. static std::once_flag onceFlagInitMPI;
  205. static std::mutex pgGlobalMutex_;
  206. static int mpiThreadSupport_;
  207. MPI_Comm pgComm_;
  208. };
  209. } // namespace c10d
  210. #endif // USE_C10D_MPI