ProcessGroupWrapper.hpp 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. #pragma once
  2. #ifdef USE_C10D_GLOO
  3. #include <c10d/ProcessGroup.hpp>
  4. #include <c10d/ProcessGroupGloo.hpp>
  5. #include <c10d/Types.hpp>
  6. #include <c10d/Utils.hpp>
  7. namespace c10d {
  8. class TORCH_API ProcessGroupWrapper : public ProcessGroup {
  9. public:
  10. explicit ProcessGroupWrapper(
  11. c10::intrusive_ptr<ProcessGroup> pg,
  12. c10::intrusive_ptr<ProcessGroupGloo> glooPg);
  13. const std::string getBackendName() const override;
  14. c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  15. std::vector<at::Tensor>& data,
  16. const BroadcastOptions& opts = BroadcastOptions()) override;
  17. c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  18. std::vector<at::Tensor>& data,
  19. const AllreduceOptions& opts = AllreduceOptions()) override;
  20. c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
  21. std::vector<at::Tensor>& tensors,
  22. const AllreduceCoalescedOptions& opts =
  23. AllreduceCoalescedOptions()) override;
  24. c10::intrusive_ptr<ProcessGroup::Work> reduce(
  25. std::vector<at::Tensor>& tensors,
  26. const ReduceOptions& opts = ReduceOptions()) override;
  27. c10::intrusive_ptr<ProcessGroup::Work> allgather(
  28. std::vector<std::vector<at::Tensor>>& outputTensors,
  29. std::vector<at::Tensor>& inputTensors,
  30. const AllgatherOptions& opts = AllgatherOptions()) override;
  31. c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
  32. at::Tensor& outputBuffer,
  33. at::Tensor& inputBuffer,
  34. const AllgatherOptions& opts = AllgatherOptions()) override;
  35. // This function is deprecated and will be moved out of ProcessGroup to comms:
  36. // * do not add dependencies on this function,
  37. // * do not implement it in your ProcessGroup, implement _allgather_base
  38. // instead.
  39. c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
  40. std::vector<std::vector<at::Tensor>>& outputTensorLists,
  41. std::vector<at::Tensor>& inputTensors,
  42. const AllgatherOptions& opts = AllgatherOptions()) override;
  43. c10::intrusive_ptr<ProcessGroup::Work> gather(
  44. std::vector<std::vector<at::Tensor>>& outputTensors,
  45. std::vector<at::Tensor>& inputTensors,
  46. const GatherOptions& opts = GatherOptions()) override;
  47. c10::intrusive_ptr<ProcessGroup::Work> scatter(
  48. std::vector<at::Tensor>& outputTensors,
  49. std::vector<std::vector<at::Tensor>>& inputTensors,
  50. const ScatterOptions& opts = ScatterOptions()) override;
  51. c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  52. std::vector<at::Tensor>& outputTensors,
  53. std::vector<std::vector<at::Tensor>>& inputTensors,
  54. const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
  55. c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
  56. at::Tensor& outputTensor,
  57. at::Tensor& inputTensor,
  58. std::vector<int64_t>& outputSplitSizes,
  59. std::vector<int64_t>& inputSplitSizes,
  60. const AllToAllOptions& opts = AllToAllOptions()) override;
  61. c10::intrusive_ptr<ProcessGroup::Work> alltoall(
  62. std::vector<at::Tensor>& outputTensors,
  63. std::vector<at::Tensor>& inputTensors,
  64. const AllToAllOptions& opts = AllToAllOptions()) override;
  65. void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
  66. override;
  67. // Agrees on an initial sequence number for the whole group by having rank 0
  68. // create it and broadcast it to other ranks using the store. Only implemented
  69. // for GLOO and NCCL backends currently.
  70. // dont implement this
  71. void setSequenceNumberForGroup() override;
  72. // Retrieves the current sequence number for the whole group, which should be
  73. // in sync. If the returned number is not consistent across the group, it
  74. // may indicate that there is some sort of collective desynchronization.
  75. uint64_t getSequenceNumberForGroup() override; // just call underlying
  76. c10::intrusive_ptr<ProcessGroup::Work> send(
  77. std::vector<at::Tensor>& tensors,
  78. int dstRank,
  79. int tag) override;
  80. c10::intrusive_ptr<ProcessGroup::Work> recv(
  81. std::vector<at::Tensor>& tensors,
  82. int srcRank,
  83. int tag) override;
  84. c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
  85. std::vector<at::Tensor>& tensors,
  86. int tag) override;
  87. c10::intrusive_ptr<ProcessGroup::Work> barrier(
  88. const BarrierOptions& opts = BarrierOptions()) override;
  89. c10::intrusive_ptr<ProcessGroup> getWrappedPg() const;
  90. private:
  91. // Underlying process group that actual application collectives will be
  92. // dispatched to
  93. c10::intrusive_ptr<ProcessGroup> pg_;
  94. // Gloo process group responsible for internal coordination such as monitored
  95. // barrier, sequence number checking, collective fingerprint collecting.
  96. c10::intrusive_ptr<ProcessGroupGloo> glooPg_;
  97. // Conducts several checks to ensure that the underlying collective is well
  98. // formed with the goal of notifying the user about incorrect collective use
  99. // in the application.
  100. void runCollectiveChecks(
  101. OpType op_type,
  102. const std::vector<at::Tensor>& tensors) const;
  103. };
  104. } // namespace c10d
  105. #endif // USE_C10D_GLOO