| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- #pragma once
- #include <c10d/ProcessGroup.hpp>
- #include <torch/csrc/utils/pybind.h>
- namespace c10d {
- // PyProcessGroup is a pybind11 trampoline class to allow a Python
- // class to inherit from torch.distributed.ProcessGroup
- class PyProcessGroup : public ProcessGroup {
- public:
- // PyWork is a pybind11 trampoline class to allow a Python
- // class to inherit from torch.distributed.Work
- class PyWork : public ProcessGroup::Work {
- public:
- PyWork() = default;
- bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
- PYBIND11_OVERRIDE(
- bool, /* Return type */
- ProcessGroup::Work, /* Parent class */
- wait, /* Name of function in C++ */
- timeout);
- }
- };
- using ProcessGroup::ProcessGroup;
- const std::string getBackendName() const override {
- PYBIND11_OVERRIDE_PURE(
- std::string, /* Return type */
- ProcessGroup, /* Parent class */
- getBackendName, /* Name of function in C++ */
- );
- }
- c10::intrusive_ptr<ProcessGroup::Work> allgather(
- std::vector<std::vector<at::Tensor>>& outputTensors,
- std::vector<at::Tensor>& inputTensors,
- const AllgatherOptions& opts = AllgatherOptions()) override {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- allgather, /* Name of function in C++ */
- outputTensors,
- inputTensors,
- opts);
- }
- c10::intrusive_ptr<ProcessGroup::Work> allreduce(
- std::vector<at::Tensor>& tensors,
- const AllreduceOptions& opts = AllreduceOptions()) override {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- allreduce, /* Name of function in C++ */
- tensors,
- opts);
- }
- c10::intrusive_ptr<ProcessGroup::Work> barrier(
- const BarrierOptions& opts = BarrierOptions()) {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- barrier, /* Name of function in C++ */
- opts);
- }
- c10::intrusive_ptr<ProcessGroup::Work> broadcast(
- std::vector<at::Tensor>& tensors,
- const BroadcastOptions& opts = BroadcastOptions()) override {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- broadcast, /* Name of function in C++ */
- tensors,
- opts);
- }
- c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
- std::vector<at::Tensor>& outputTensors,
- std::vector<std::vector<at::Tensor>>& inputTensors,
- const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- reduce_scatter, /* Name of function in C++ */
- outputTensors,
- inputTensors,
- opts);
- }
- c10::intrusive_ptr<ProcessGroup::Work> send(
- std::vector<at::Tensor>& tensors,
- int dstRank,
- int tag) override {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- send, /* Name of function in C++ */
- tensors,
- dstRank,
- tag);
- }
- c10::intrusive_ptr<ProcessGroup::Work> recv(
- std::vector<at::Tensor>& tensors,
- int srcRank,
- int tag) override {
- PYBIND11_OVERRIDE(
- c10::intrusive_ptr<ProcessGroup::Work>, /* Return type */
- ProcessGroup, /* Parent class */
- recv, /* Name of function in C++ */
- tensors,
- srcRank,
- tag);
- }
- };
- } // namespace c10d
|