comm.hpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
#pragma once

#include <cstddef>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/Export.h>
  6. namespace c10d {
  7. // Broadcast many tensors to all processes in the process group.
  8. TORCH_API void broadcast_coalesced(
  9. c10::intrusive_ptr<c10d::ProcessGroup> process_group,
  10. at::TensorList tensors,
  11. size_t buffer_size,
  12. int rank = 0);
  13. // This class passes bucket contents tensor to DDP communication hook.
  14. class TORCH_API GradBucket {
  15. public:
  16. explicit GradBucket(
  17. size_t index,
  18. size_t bucket_count,
  19. const at::Tensor& tensor,
  20. const std::vector<size_t>& offsets,
  21. const std::vector<size_t>& lengths,
  22. const std::vector<c10::IntArrayRef>& sizes_vec,
  23. const std::vector<at::Tensor>& parameters)
  24. : index_(index),
  25. bucket_count_(bucket_count),
  26. buffer_(tensor),
  27. offsets_(offsets),
  28. lengths_(lengths),
  29. sizes_vec_(sizes_vec),
  30. parameters_(parameters) {}
  31. // Returns the index of the bucket, which is unique across all the buckets.
  32. size_t getIndex() const {
  33. return index_;
  34. }
  35. const at::Tensor& getBuffer() const {
  36. return buffer_;
  37. }
  38. // Returns a mutable buffer compared with the above method.
  39. at::Tensor& getBufferRef() {
  40. return buffer_;
  41. }
  42. // Overwrites the buffer at a specific index.
  43. void setBuffer(at::Tensor& buffer) {
  44. buffer_ = buffer;
  45. }
  46. // Each tensor in the list that getGradients corresponds to a
  47. // parameter.
  48. std::vector<at::Tensor> getGradients() const;
  49. // Returns model parameters belonging to this bucket. They are returned in the
  50. // same order as gradient tensors via getGradients(). For example,
  51. // getParameters[i] will have its gradient stored in
  52. // getGradients[i]
  53. const std::vector<at::Tensor> getParameters() const {
  54. return parameters_;
  55. }
  56. // Returns whther this bucket is the last bucket to allreduce in an iteration.
  57. bool isLast() const {
  58. return index_ == bucket_count_ - 1;
  59. }
  60. private:
  61. size_t index_;
  62. size_t bucket_count_;
  63. at::Tensor buffer_;
  64. // Per-variable info in buffer_.
  65. std::vector<size_t> offsets_;
  66. std::vector<size_t> lengths_;
  67. std::vector<c10::IntArrayRef> sizes_vec_;
  68. // Model parameters for this bucket.
  69. const std::vector<at::Tensor> parameters_;
  70. };
  71. // Base class of both `PythonCommHook` and `CppCommHook`.
  72. // Requires implementing 1) `runHook` method that communicates gradients
  73. // asynchronously, and 2) `parseHookResult` method that converts the hook
  74. // result into a tensor.
  75. class TORCH_API CommHookInterface {
  76. public:
  77. virtual ~CommHookInterface() = default;
  78. // Passes the input grad bucket to the registered communication hook.
  79. // Once the tensor in the bucket are ready, kicks off the hook asynchronously
  80. // and returns a future that holds the communication results.
  81. virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
  82. GradBucket& bucket) = 0;
  83. // Returns the resulting tensor once the communication hook result is
  84. // ready. The resulting tensor will then be copied to the grads of
  85. // individual parameters.
  86. virtual at::Tensor parseHookResult(
  87. const c10::IValue& result) = 0;
  88. };
  89. namespace detail {
  90. // This helper function is called both by CppCommHookInterface below and inside
  91. // reducer.
  92. inline at::Tensor parseCppCommHookResult(
  93. const c10::IValue& result) {
  94. TORCH_INTERNAL_ASSERT(
  95. result.isTensor() || result.isTensorList(),
  96. "expected the hook result is either a Tensor or a TensorList");
  97. if (result.isTensor()) {
  98. return result.toTensor();
  99. }
  100. return result.toTensorVector()[0];
  101. }
  102. } // namespace detail
  103. // This CppCommHook interface only requires implementing runHook method that
  104. // potentially uses a state.
  105. template <typename T>
  106. class CppCommHookInterface : public CommHookInterface {
  107. public:
  108. explicit CppCommHookInterface(T& state) : state_(state) {}
  109. ~CppCommHookInterface() override = default;
  110. at::Tensor parseHookResult(const c10::IValue& result) override {
  111. return detail::parseCppCommHookResult(result);
  112. }
  113. protected:
  114. T state_;
  115. };
  116. } // namespace c10d