#pragma once #include #include namespace c10d { // PyProcessGroup is a pybind11 trampoline class to allow a Python // class to inherit from torch.distributed.ProcessGroup class PyProcessGroup : public ProcessGroup { public: // PyWork is a pybind11 trampoline class to allow a Python // class to inherit from torch.distributed.Work class PyWork : public ProcessGroup::Work { public: PyWork() = default; bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { PYBIND11_OVERRIDE( bool, /* Return type */ ProcessGroup::Work, /* Parent class */ wait, /* Name of function in C++ */ timeout); } }; using ProcessGroup::ProcessGroup; const std::string getBackendName() const override { PYBIND11_OVERRIDE_PURE( std::string, /* Return type */ ProcessGroup, /* Parent class */ getBackendName, /* Name of function in C++ */ ); } c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ allgather, /* Name of function in C++ */ outputTensors, inputTensors, opts); } c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ allreduce, /* Name of function in C++ */ tensors, opts); } c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ barrier, /* Name of function in C++ */ opts); } c10::intrusive_ptr broadcast( std::vector& tensors, const BroadcastOptions& opts = BroadcastOptions()) override { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ broadcast, /* Name of function in C++ */ tensors, opts); } c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ reduce_scatter, /* Name of function in C++ */ outputTensors, inputTensors, opts); } c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ send, /* Name of function in C++ */ tensors, dstRank, tag); } c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override { PYBIND11_OVERRIDE( c10::intrusive_ptr, /* Return type */ ProcessGroup, /* Parent class */ recv, /* Name of function in C++ */ tensors, srcRank, tag); } }; } // namespace c10d