PyProcessGroup.hpp 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #pragma once
  2. #include <c10d/ProcessGroup.hpp>
  3. #include <torch/csrc/utils/pybind.h>
  4. namespace c10d {
  5. // PyProcessGroup is a pybind11 trampoline class to allow a Python
  6. // class to inherit from torch.distributed.ProcessGroup
  7. class PyProcessGroup : public ProcessGroup {
  8. public:
  9. // PyWork is a pybind11 trampoline class to allow a Python
  10. // class to inherit from torch.distributed.Work
  11. class PyWork : public ProcessGroup::Work {
  12. public:
  13. PyWork() = default;
  14. bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
  15. PYBIND11_OVERRIDE(
  16. bool, /* Return type */
  17. ProcessGroup::Work, /* Parent class */
  18. wait, /* Name of function in C++ */
  19. timeout);
  20. }
  21. };
  22. using ProcessGroup::ProcessGroup;
  23. const std::string getBackendName() const override {
  24. PYBIND11_OVERRIDE_PURE(
  25. std::string, /* Return type */
  26. ProcessGroup, /* Parent class */
  27. getBackendName, /* Name of function in C++ */
  28. );
  29. }
  30. c10::intrusive_ptr<ProcessGroup::Work> allgather(
  31. std::vector<std::vector<at::Tensor>>& outputTensors,
  32. std::vector<at::Tensor>& inputTensors,
  33. const AllgatherOptions& opts = AllgatherOptions()) override {
  34. PYBIND11_OVERRIDE(
  35. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  36. ProcessGroup, /* Parent class */
  37. allgather, /* Name of function in C++ */
  38. outputTensors,
  39. inputTensors,
  40. opts);
  41. }
  42. c10::intrusive_ptr<ProcessGroup::Work> allreduce(
  43. std::vector<at::Tensor>& tensors,
  44. const AllreduceOptions& opts = AllreduceOptions()) override {
  45. PYBIND11_OVERRIDE(
  46. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  47. ProcessGroup, /* Parent class */
  48. allreduce, /* Name of function in C++ */
  49. tensors,
  50. opts);
  51. }
  52. c10::intrusive_ptr<ProcessGroup::Work> barrier(
  53. const BarrierOptions& opts = BarrierOptions()) {
  54. PYBIND11_OVERRIDE(
  55. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  56. ProcessGroup, /* Parent class */
  57. barrier, /* Name of function in C++ */
  58. opts);
  59. }
  60. c10::intrusive_ptr<ProcessGroup::Work> broadcast(
  61. std::vector<at::Tensor>& tensors,
  62. const BroadcastOptions& opts = BroadcastOptions()) override {
  63. PYBIND11_OVERRIDE(
  64. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  65. ProcessGroup, /* Parent class */
  66. broadcast, /* Name of function in C++ */
  67. tensors,
  68. opts);
  69. }
  70. c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
  71. std::vector<at::Tensor>& outputTensors,
  72. std::vector<std::vector<at::Tensor>>& inputTensors,
  73. const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
  74. PYBIND11_OVERRIDE(
  75. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  76. ProcessGroup, /* Parent class */
  77. reduce_scatter, /* Name of function in C++ */
  78. outputTensors,
  79. inputTensors,
  80. opts);
  81. }
  82. c10::intrusive_ptr<ProcessGroup::Work> send(
  83. std::vector<at::Tensor>& tensors,
  84. int dstRank,
  85. int tag) override {
  86. PYBIND11_OVERRIDE(
  87. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  88. ProcessGroup, /* Parent class */
  89. send, /* Name of function in C++ */
  90. tensors,
  91. dstRank,
  92. tag);
  93. }
  94. c10::intrusive_ptr<ProcessGroup::Work> recv(
  95. std::vector<at::Tensor>& tensors,
  96. int srcRank,
  97. int tag) override {
  98. PYBIND11_OVERRIDE(
  99. c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
  100. ProcessGroup, /* Parent class */
  101. recv, /* Name of function in C++ */
  102. tensors,
  103. srcRank,
  104. tag);
  105. }
  106. };
  107. } // namespace c10d