| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- #pragma once
- #include <ATen/ATen.h>
- #include <ATen/core/ivalue.h>
- #include <c10d/ProcessGroup.hpp>
- #include <torch/csrc/Export.h>
- namespace c10d {
- // Broadcast many tensors to all processes in the process group.
- TORCH_API void broadcast_coalesced(
- c10::intrusive_ptr<c10d::ProcessGroup> process_group,
- at::TensorList tensors,
- size_t buffer_size,
- int rank = 0);
- // This class passes bucket contents tensor to DDP communication hook.
- class TORCH_API GradBucket {
- public:
- explicit GradBucket(
- size_t index,
- size_t bucket_count,
- const at::Tensor& tensor,
- const std::vector<size_t>& offsets,
- const std::vector<size_t>& lengths,
- const std::vector<c10::IntArrayRef>& sizes_vec,
- const std::vector<at::Tensor>& parameters)
- : index_(index),
- bucket_count_(bucket_count),
- buffer_(tensor),
- offsets_(offsets),
- lengths_(lengths),
- sizes_vec_(sizes_vec),
- parameters_(parameters) {}
- // Returns the index of the bucket, which is unique across all the buckets.
- size_t getIndex() const {
- return index_;
- }
- const at::Tensor& getBuffer() const {
- return buffer_;
- }
- // Returns a mutable buffer compared with the above method.
- at::Tensor& getBufferRef() {
- return buffer_;
- }
- // Overwrites the buffer at a specific index.
- void setBuffer(at::Tensor& buffer) {
- buffer_ = buffer;
- }
- // Each tensor in the list that getGradients corresponds to a
- // parameter.
- std::vector<at::Tensor> getGradients() const;
- // Returns model parameters belonging to this bucket. They are returned in the
- // same order as gradient tensors via getGradients(). For example,
- // getParameters[i] will have its gradient stored in
- // getGradients[i]
- const std::vector<at::Tensor> getParameters() const {
- return parameters_;
- }
- // Returns whther this bucket is the last bucket to allreduce in an iteration.
- bool isLast() const {
- return index_ == bucket_count_ - 1;
- }
- private:
- size_t index_;
- size_t bucket_count_;
- at::Tensor buffer_;
- // Per-variable info in buffer_.
- std::vector<size_t> offsets_;
- std::vector<size_t> lengths_;
- std::vector<c10::IntArrayRef> sizes_vec_;
- // Model parameters for this bucket.
- const std::vector<at::Tensor> parameters_;
- };
- // Base class of both `PythonCommHook` and `CppCommHook`.
- // Requires implementing 1) `runHook` method that communicates gradients
- // asynchronously, and 2) `parseHookResult` method that converts the hook
- // result into a tensor.
- class TORCH_API CommHookInterface {
- public:
- virtual ~CommHookInterface() = default;
- // Passes the input grad bucket to the registered communication hook.
- // Once the tensor in the bucket are ready, kicks off the hook asynchronously
- // and returns a future that holds the communication results.
- virtual c10::intrusive_ptr<c10::ivalue::Future> runHook(
- GradBucket& bucket) = 0;
- // Returns the resulting tensor once the communication hook result is
- // ready. The resulting tensor will then be copied to the grads of
- // individual parameters.
- virtual at::Tensor parseHookResult(
- const c10::IValue& result) = 0;
- };
- namespace detail {
- // This helper function is called both by CppCommHookInterface below and inside
- // reducer.
- inline at::Tensor parseCppCommHookResult(
- const c10::IValue& result) {
- TORCH_INTERNAL_ASSERT(
- result.isTensor() || result.isTensorList(),
- "expected the hook result is either a Tensor or a TensorList");
- if (result.isTensor()) {
- return result.toTensor();
- }
- return result.toTensorVector()[0];
- }
- } // namespace detail
- // This CppCommHook interface only requires implementing runHook method that
- // potentially uses a state.
- template <typename T>
- class CppCommHookInterface : public CommHookInterface {
- public:
- explicit CppCommHookInterface(T& state) : state_(state) {}
- ~CppCommHookInterface() override = default;
- at::Tensor parseHookResult(const c10::IValue& result) override {
- return detail::parseCppCommHookResult(result);
- }
- protected:
- T state_;
- };
- } // namespace c10d
|