ProcessGroupGloo.hpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. #pragma once
  2. #ifdef USE_C10D_GLOO
  3. #include <condition_variable>
  4. #include <deque>
  5. #include <mutex>
  6. #include <thread>
  7. #include <unordered_map>
  8. #include <vector>
  9. #include <gloo/rendezvous/store.h>
  10. #include <gloo/algorithm.h>
  11. #include <gloo/common/error.h>
  12. #include <gloo/context.h>
  13. #include <gloo/rendezvous/store.h>
  14. #include <gloo/transport/device.h>
  15. #include <c10/util/hash.h>
  16. #include <c10d/ProcessGroup.hpp>
  17. #include <c10d/Store.hpp>
  18. #include <c10d/Types.hpp>
  19. #include <c10d/Utils.hpp>
  20. namespace c10d {
  // Name under which the Gloo backend identifies itself
  // (returned by ProcessGroupGloo::getBackendName()).
  // `auto` deduces `const char*`; `constexpr` makes the pointer itself const.
  constexpr auto GLOO_BACKEND_NAME = "gloo";
  22. // ProcessGroupGloo implements Gloo bindings for c10d.
  23. //
  24. // All functions on this class are expected to be called in the same
  25. // order across processes in the group. This is the only way that we
  26. // can guarantee to match up the same calls across processes. For
  27. // multi-threaded usage of process groups, you can use consider using
  28. // multiple process group instances.
  29. //
  30. // The Gloo algorithms that this class calls into are cached by their
  31. // signature (see description of AlgorithmKey above). This cache works
  32. // as follows: every function call instantiates an AlgorithmKey and
  33. // looks in the cache for existing entries. If there is one, it is
  34. // removed from the cache and returned to the caller. If there are
  35. // none, a new entry is created and returned. If an entry was created
  36. // before, but is still in use, the call will block and wait until the
  37. // entry is returned to the cache.
  38. //
  39. // In the future, we hope to extend this to allow multiple entries per
  40. // key, to enable parallelism for a single key. The number of entries
  41. // per key must always be identical for all processes. This maximum
  42. // number can be automatically tuned, but only if we let a single
  43. // process take charge, and have it broadcast the limits.
  44. //
  45. class TORCH_API ProcessGroupGloo : public ProcessGroup {
  46. public:
  47. // AsyncWork is the Gloo specific superclass for asynchronous work items.
  48. // We can split asynchronous work into 3 phases:
  49. // 1) Sanity checks and prepare input (e.g. memcpy)
  50. // 2) Run operation on background thread
  51. // 3) Synchronize with completion on foreground thread
  52. //
  53. // There is state to be shared between these 3 phases and all of this state
  54. // is captured in the AsyncWork class and its derivatives.
  55. //
  56. // Note: while we are porting operations to use new style collectives, there
  57. // is a split between operations using the existing caching approach and
  58. // operations using the new AsyncWork base class. Over time we will port
  59. // all operations and perform needed cleanup.
  60. //
  61. // FIXME: This probably should be called WorkGloo since the work is executed in sync mode
  62. // by a background thread.
  63. class TORCH_API AsyncWork : public ProcessGroup::Work {
  64. public:
  65. explicit AsyncWork(
  66. std::vector<std::vector<at::Tensor>> outputTensors,
  67. const char* profilingTitle = nullptr,
  68. const c10::optional<std::vector<at::Tensor>>& inputTensors = c10::nullopt);
  69. ~AsyncWork() override = default;
  70. static void execute(c10::intrusive_ptr<AsyncWork> work);
  71. virtual void run() = 0;
  72. std::vector<at::Tensor> result() override;
  73. c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
  74. protected:
  75. friend class ProcessGroupGloo;
  76. private:
  77. void finishWorkGloo();
  78. void finishWorkGlooError(std::exception_ptr eptr);
  79. inline void recordAsyncWorkProfilingInfo(
  80. const char* profilingTitle,
  81. const c10::optional<std::vector<at::Tensor>>& inputTensors);
  82. const std::vector<std::vector<at::Tensor>> outputTensors_;
  83. c10::intrusive_ptr<at::ivalue::Future> future_;
  84. std::function<void()> recordFunctionBeforeCallback_;
  85. };
  86. // Wrap c10d store as Gloo store
  87. class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
  88. public:
  89. GlooStore(const c10::intrusive_ptr<::c10d::Store>& store) : store_(store) {}
  90. void setUint(const std::string& key, const std::vector<uint8_t>& value) {
  91. store_->set(key, value);
  92. }
  93. void set(const std::string& key, const std::vector<char>& value) override {
  94. std::vector<uint8_t> tmp(value.begin(), value.end());
  95. store_->set(key, tmp);
  96. }
  97. std::vector<uint8_t> getUint(const std::string& key) {
  98. auto value = store_->get(key);
  99. return value;
  100. }
  101. std::vector<char> get(const std::string& key) override {
  102. auto value = store_->get(key);
  103. return std::vector<char>(value.begin(), value.end());
  104. }
  105. void wait(const std::vector<std::string>& keys) override {
  106. store_->wait(keys, Store::kDefaultTimeout);
  107. }
  108. void wait(
  109. const std::vector<std::string>& keys,
  110. const std::chrono::milliseconds& timeout) override {
  111. store_->wait(keys, timeout);
  112. }
  113. protected:
  114. c10::intrusive_ptr<::c10d::Store> store_;
  115. };
  116. // For send and recv operations there is no need to pass them to the
  117. // thread pool as they are entirely completed by the device thread.
  118. // This work object is used to synchronize completion of the send or
  119. // recv operation. It keeps a reference to the tensor it is
  120. // operating on to prevent it from being deallocated while the
  121. // operation is still in flight.
  122. class TORCH_API SendWork : public ProcessGroup::Work {
  123. public:
  124. explicit SendWork(
  125. at::Tensor& tensor,
  126. std::unique_ptr<::gloo::transport::UnboundBuffer> buffer);
  127. bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
  128. void abort() override;
  129. protected:
  130. at::Tensor tensor_;
  131. std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
  132. };
  133. class TORCH_API RecvWork : public ProcessGroup::Work {
  134. public:
  135. explicit RecvWork(
  136. at::Tensor& tensor,
  137. std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
  138. const char* profilingTitle = nullptr);
  139. int sourceRank() const override;
  140. bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;
  141. void abort() override;
  142. protected:
  143. at::Tensor tensor_;
  144. std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
  145. int srcRank_;
  146. };
  147. struct TORCH_API Options : public ProcessGroup::Options {
  148. explicit Options(
  149. std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout);
  150. // return intrusive_ptr of the object
  151. static c10::intrusive_ptr<Options> create(
  152. std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout) {
  153. return c10::make_intrusive<Options>(timeout);
  154. }
  155. std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
  156. int threads;
  157. };
  158. const std::string getBackendName() const override {
  159. return std::string(GLOO_BACKEND_NAME);
  160. }
  161. // Helper functions to create a new device object.
  162. // They are static functions on this class to keep them logically
  163. // separate from the rest of the code base (e.g. torch/csrc/distributed).
  164. // Create new device instance for specific interface.
  165. static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
  166. const std::string& interface);
  167. // Create new device instance for specific hostname or address.
  168. static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
  169. const std::string& hostname);
  170. // Create new device instance.
  171. // It tries to resolve this machine's hostname and bind to that address.
  172. // If that fails (i.e. the hostname doesn't resolve to an address), it
  173. // falls back to binding to the loopback address.
  174. static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
  175. explicit ProcessGroupGloo(
  176. const c10::intrusive_ptr<Store>& store,
  177. int rank,
  178. int size,
  179. c10::intrusive_ptr<Options> options = Options::create());
  180. virtual ~ProcessGroupGloo();
  181. c10::intrusive_ptr<Options> getOptions() {
  182. return options_;
  183. }
  184. c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  185. std::vector<at::Tensor>& tensors,
  186. const BroadcastOptions& opts = BroadcastOptions()) override;
  187. c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  188. std::vector<at::Tensor>& tensors,
  189. const AllreduceOptions& opts = AllreduceOptions()) override;
  190. c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
  191. std::vector<at::Tensor>& tensors,
  192. const AllreduceCoalescedOptions& opts =
  193. AllreduceCoalescedOptions()) override;
  194. c10::intrusive_ptr<ProcessGroup::Work> reduce(
  195. std::vector<at::Tensor>& tensors,
  196. const ReduceOptions& opts = ReduceOptions()) override;
  197. c10::intrusive_ptr<ProcessGroup::Work> allgather(
  198. std::vector<std::vector<at::Tensor>>& outputs,
  199. std::vector<at::Tensor>& inputs,
  200. const AllgatherOptions& opts = AllgatherOptions()) override;
  201. c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
  202. at::Tensor& outputBuffer,
  203. at::Tensor& inputBuffer,
  204. const AllgatherOptions& opts = AllgatherOptions()) override;
  205. c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
  206. std::vector<std::vector<at::Tensor>>& output_lists,
  207. std::vector<at::Tensor>& input_list,
  208. const AllgatherOptions& opts = AllgatherOptions()) override;
  209. c10::intrusive_ptr<ProcessGroup::Work> gather(
  210. std::vector<std::vector<at::Tensor>>& outputs,
  211. std::vector<at::Tensor>& inputs,
  212. const GatherOptions& opts = GatherOptions()) override;
  213. c10::intrusive_ptr<ProcessGroup::Work> scatter(
  214. std::vector<at::Tensor>& outputs,
  215. std::vector<std::vector<at::Tensor>>& inputs,
  216. const ScatterOptions& opts = ScatterOptions()) override;
  217. c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  218. std::vector<at::Tensor>& outputs,
  219. std::vector<std::vector<at::Tensor>>& inputs,
  220. const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
  221. c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
  222. at::Tensor& outputTensor,
  223. at::Tensor& inputTensor,
  224. std::vector<int64_t>& outputCounts,
  225. std::vector<int64_t>& inputCounts,
  226. const AllToAllOptions& opts = AllToAllOptions()) override;
  227. c10::intrusive_ptr<ProcessGroup::Work> send(
  228. std::vector<at::Tensor>& tensors,
  229. int dstRank,
  230. int tag) override;
  231. c10::intrusive_ptr<ProcessGroup::Work> recv(
  232. std::vector<at::Tensor>& tensors,
  233. int srcRank,
  234. int tag) override;
  235. c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
  236. std::vector<at::Tensor>& tensors,
  237. int tag) override;
  238. c10::intrusive_ptr<ProcessGroup::Work> barrier(
  239. const BarrierOptions& opts = BarrierOptions()) override;
  240. const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
  241. return store_;
  242. }
  243. // Similar to barrier(), but blocks rank 0 until all other ranks have
  244. // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  245. // is able to report all failed ranks if waitAllRanks = true, otherwise
  246. // reports the first rank it detected as failed.
  247. void monitoredBarrier(
  248. const BarrierOptions& opts = BarrierOptions(),
  249. bool waitAllRanks = false) override;
  250. // Agrees on an initial sequence number for the whole group by having rank 0
  251. // create it and broadcast it to other ranks using the store.
  252. void setSequenceNumberForGroup() override;
  253. // Retrieves the current sequence number for the whole group, which should be
  254. // in sync. If the returned number is not consistent across the group, it
  255. // may indicate that there is some sort of collective desynchronization.
  256. uint64_t getSequenceNumberForGroup() override;
  257. int getNumThreads() {
  258. return options_->threads;
  259. }
  260. protected:
  261. std::unique_ptr<::gloo::rendezvous::Store> store_;
  262. const c10::intrusive_ptr<Options> options_;
  263. // Every Gloo context represents a set of connections to its peers.
  264. // In order to use more than one device (or allow for parallelism on
  265. // a single device), you need multiple contexts.
  266. std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  267. std::vector<std::thread> threads_;
  268. bool stop_;
  269. // Incremented for every collective we kick off.
  270. // The value is used as tag for collective operations. Collectives are kicked
  271. // off in identical order across processes. Therefore the tag can be used
  272. // to match up operations during concurrent execution.
  273. uint32_t collectiveCounter_;
  274. // Returns next collective tag to use (uses collectiveCounter_).
  275. uint32_t nextTag();
  276. // Returns the context to use for the specified tag.
  277. // With `nextTag` returning an increasing number, this should lead
  278. // to contexts being used in a round-robin fashion.
  279. std::shared_ptr<::gloo::Context> getContext(uint32_t tag);
  280. // Entrypoint for worker threads.
  281. void runLoop(int workerIndex);
  282. // Queue work to run on worker thread.
  283. void enqueue(c10::intrusive_ptr<AsyncWork> work);
  284. // Keep both a queue of pending work, and a vector with in progress work.
  285. // Both of these can only be mutated when holding the queue lock.
  286. // We keep both around instead of just the queue, so we can grab a weak_ptr
  287. // to all in progress and pending work when executing a barrier.
  288. // When executing a barrier, we need to ensure that all prior work
  289. // has completed before completing itself.
  290. std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  291. std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  292. std::mutex workMutex_;
  293. std::condition_variable workProduceCV_;
  294. std::condition_variable workConsumeCV_;
  295. };
  296. } // namespace c10d
  297. #endif // USE_C10D_GLOO