reducer.hpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
#pragma once

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <ATen/core/ivalue_inl.h>
#include <c10/macros/Macros.h>
#include <c10/util/intrusive_ptr.h>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Utils.hpp>
#include <c10d/comm.hpp>
#include <c10d/debug.h>
#include <c10d/reducer_timer.hpp>
#include <c10d/default_comm_hooks.hpp>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/variable.h>
#ifndef _WIN32
#include <torch/csrc/distributed/autograd/context/context.h>
#endif
  23. namespace c10d {
  24. constexpr int kDefaultFirstBucketBytes = int(1024 * 1024);
  25. constexpr int kDefaultBucketBytesCap = int(25 * 1024 * 1024);
  26. // Collect runtime stats once for every kDDPRuntimeLoggingSampleRate iterations.
  27. constexpr int kDDPRuntimeLoggingSampleRate = 100;
  28. // Forward declaration
  29. class Logger;
  30. // Local accumulator type for a single bucket.
  31. struct BucketAccumulator {
  32. std::vector<size_t> indices;
  33. size_t size = 0;
  34. size_t size_limit = 0;
  35. };
  36. class TORCH_API Reducer {
  37. public:
  38. // The constructor takes a list of variables (i.e. parameters) for this
  39. // process's single model replica (as DDP assumes single-process
  40. // single-device). The bucket assignment for this reducer, `bucket_indices`,
  41. // is specified as a list of buckets, each of which is specified as a list of
  42. // indices into the bucket's `variables` list.
  43. explicit Reducer(
  44. std::vector<at::Tensor> params,
  45. std::vector<std::vector<size_t>> bucket_indices,
  46. std::vector<size_t> per_bucket_size_limits,
  47. c10::intrusive_ptr<c10d::ProcessGroup> process_group,
  48. std::vector<bool> expect_sparse_gradients,
  49. int64_t bucket_bytes_cap,
  50. bool find_unused_parameters,
  51. bool gradient_as_bucket_view,
  52. std::unordered_map<size_t, std::string> param_names,
  53. int64_t first_bucket_bytes_cap);
  54. ~Reducer() noexcept(false);
  55. // To (re-)initialize bucket assignment, pass a list of buckets, each of
  56. // which is specified by a list of indices in the bucket's `variables` list.
  57. // This function performs validation that the variables within a bucket
  58. // all live on the same device and have the same dimensionality.
  59. void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
  60. // This function is called when the forward function has produced an output,
  61. // and the user wishes to reduce gradients in the backwards pass.
  62. // If they don't, and wish to accumulate gradients before reducing them,
  63. // a call to this function can simply be omitted.
  64. void prepare_for_backward(const std::vector<at::Tensor>& outputs);
  65. // Called at the begginning of forward() inside DistributedDataParallel,
  66. // right now it caputures the starting time of forward in each iteration.
  67. void prepare_for_forward();
  68. // Returns the relative time in nanoseconds when gradients were ready,
  69. // with respect to the time `prepare_for_backward` was called. The
  70. // vector is for parameters for a single model replica.
  71. std::vector<int64_t> get_backward_stats() const {
  72. return backward_stats_;
  73. }
  74. // Registers a hook to the reducer. The hook is `CommHookInterface`
  75. // type to allow both Python and CPP hooks. This function can only
  76. // be called once before calling backward.
  77. // Cannot combine with the call of `register_builtin_comm_hook`.
  78. void register_comm_hook(std::unique_ptr<CommHookInterface> iface);
  79. // Registers a built-in C++ comm hook to the reducer. This function can only
  80. // be called once before calling backward.
  81. // Cannot combine with the call of `register_comm_hook`.
  82. void register_builtin_comm_hook(c10d::BuiltinCommHookType comm_hook_type);
  83. // Runs allreduce or installed communication hook given GradBucket instance.
  84. c10::intrusive_ptr<c10::ivalue::Future> run_comm_hook(
  85. GradBucket& grad_bucket);
  86. // Returns gradient buckets in sequential order of buckets_. This is the order
  87. // in which buckets are reduced across processes. If return_zero_tensors=true,
  88. // will return zero tensors of the same shape instead of the true tensors.
  89. std::vector<c10d::GradBucket> get_grad_buckets(
  90. bool return_zero_tensors = true) const;
  91. // Rebuild buckets based on rebuilt_params_ and rebuilt_param_indices_
  92. // according to when tensors received grads in the backward pass.
  93. // TODO this function makes broadcast communication call and
  94. // could be overlapped with next forward() call, thus
  95. // it could be async. Will make it async when rebuilding buckets for
  96. // find_unused_parameters = true case, as we could rebuild buckets more than
  97. // once for find_unused_parameters = true case, where subgraphs are trained
  98. // and parameter indices order may change more frequently.
  99. // For find_unused_parameters = false case, buckets are only rebuilt once,
  100. // the performance cost is negligible. Returns true if the buckets were
  101. // rebuilt.
  102. bool rebuild_buckets();
  103. // Install futures that should be awaited at end of backwards. Currently these
  104. // are only used by user-defined custom buffer reduction hooks, but can be generalized
  105. // to any user-originating futures that need to be awaited.
  106. void install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>> futs);
  107. // Returns true if we should rebuild buckets, else false. We only rebuild
  108. // buckets once after the first iteration and never rebuild them if
  109. // find_unused_parameters_.
  110. inline bool should_rebuild_buckets() const {
  111. return (static_graph_ || !find_unused_parameters_) && !has_rebuilt_bucket_;
  112. }
  113. // Pushes all parameters to be rebuilt.
  114. void push_rebuilt_params_for_all_indices();
  115. // Creates and sets ForwardPassWorkHandle given a ProcessGroup::Work and the
  116. // corresponding tensor being reduced.
  117. void set_forward_pass_work_handle(
  118. c10::intrusive_ptr<c10d::ProcessGroup::Work> forwardPassWorkHandle,
  119. bool useStaticWorldSize);
  120. // Retrieve on-device tensors used to track locally unused parameters. It is
  121. // a tensor where index i = 1 if the Variable with that index has been used.
  122. at::Tensor get_local_used_map_on_device() const;
  123. // An function for users to set sample_rate of collecting
  124. // runtime stats. The time stats will be recorded for the
  125. // first 10 iterations, after 10 iteratons time stats will be
  126. // recorded once every "sample_rate" training iterations.
  127. void set_ddp_runtime_logging_sample_rate(int sample_rate);
  128. // Specify the training graph is static.
  129. void set_static_graph();
  130. // Delay all reduce to be after all gradients' calculation is complete.
  131. void delay_all_reduce();
  132. // Weak reference to associated DDP logger. The reference is weak to avoid
  133. // refcycle between reducer and logger.
  134. void set_logger(std::weak_ptr<c10d::Logger> logger);
  135. // When graph is not explicitly set by user as static and has unused
  136. // parameters, this will return whether the graph has been static until the
  137. // current iteration, which means unused params set has not changed.
  138. bool ddp_graph_static();
  139. protected:
  140. // Forward declaration.
  141. struct Bucket;
  142. void push_rebuilt_params(const size_t& index);
  143. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  144. mutable std::mutex mutex_;
  145. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  146. const std::vector<at::Tensor> params_;
  147. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  148. const c10::intrusive_ptr<::c10d::ProcessGroup> process_group_;
  149. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  150. std::vector<bool> expect_sparse_gradients_;
  151. std::vector<std::shared_ptr<torch::autograd::Node>>
  152. grad_accumulators_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)
  153. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  154. std::unordered_map<torch::autograd::Node*, size_t> gradAccToVariableMap_;
  155. std::vector<std::pair<uintptr_t, std::shared_ptr<torch::autograd::Node>>>
  156. hooks_; // NOLINT(cppcoreguidelines-non-private-member-variables-in-classes)
  157. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  158. bool expect_autograd_hooks_;
  159. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  160. bool require_finalize_;
  161. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  162. size_t next_bucket_;
  163. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  164. bool has_marked_unused_parameters_;
  165. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  166. const bool find_unused_parameters_;
  167. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  168. const bool gradient_as_bucket_view_;
  169. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  170. std::vector<size_t> unused_parameters_;
  171. // Previous iteration's unused params, used for checking if unused parameters
  172. // change between iterations. Only filled during the first backwards call.
  173. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  174. std::vector<size_t> prev_iteration_unused_parameters_;
  175. // Whether graph is static or not. When user does not explicitly set static
  176. // graph, the only possible dynamism is set of unused parameters changing
  177. // between iterations which is tracked by this flag.
  178. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  179. bool ddp_graph_static_{true};
  180. // Locally used parameter maps indicating if parameters are used locally
  181. // during the current iteration or no_sync session if no_sync is on.
  182. // Each map is a one-dim int32 tensor of number of parameters. These tensors
  183. // are marked in autograd_hook to indicate the corresponding param has been
  184. // used, and get allreduced in the end of backward step of current iteration
  185. // or no_sync session for figuring out the globally unused parameters.
  186. //
  187. // local_used_map_: CPU tensor for bookkeeping locally used params
  188. // local_used_map_dev_: dev tensor for reducing globally unused params
  189. at::Tensor local_used_map_;
  190. at::Tensor local_used_map_dev_;
  191. // Indicate that reduction is done and D2H copy is done as well.
  192. bool local_used_map_reduced_;
  193. // Weak pointer to associated DDP logger.
  194. std::weak_ptr<c10d::Logger> logger_;
  195. // List of futures installed by Reducer::install_futures that should be awaited
  196. // at the end of backwards pass.
  197. c10::optional<c10::List<c10::intrusive_ptr<c10::ivalue::Future>>> installed_futures_{c10::nullopt};
  198. // Work handle for allreduce on local_used_map_
  199. c10::intrusive_ptr<c10d::ProcessGroup::Work> local_used_work_;
  200. void mark_variable_ready_dense(size_t variable_index);
  201. void mark_variable_ready_sparse(size_t variable_index);
  202. void mark_variable_ready(size_t variable_index);
  203. void autograd_hook(size_t index);
  204. void mark_bucket_ready(size_t bucket_index);
  205. void finalize_bucket_dense(Bucket& bucket);
  206. void finalize_backward();
  207. // Returns list of model parameters corresponding to the given bucket.
  208. // bucket_index is a key to cache after buckets are rebuilt, after which this
  209. // mapping never changes.
  210. std::vector<at::Tensor> get_variables_for_bucket(
  211. size_t bucket_index, const Bucket& bucket) const;
  212. // Asserts that the reduction for the previous iteration has finished before
  213. // rebuilding buckets or kicking off the next one.
  214. void ensure_prior_reduction_finished();
  215. // Broadcast rebuilt buckets from rank 0 to other ranks before initializing
  216. // the buckets
  217. void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);
  218. // We'd like to use DistAutogradContext::GradCallback here but dist autograd
  219. // doesn't exist under Windows. So we just directly use the concrete type but
  220. // to preserve and enforce our original intent we do a static assert when dist
  221. // autograd is available.
  222. using GradCallback = std::function<bool(at::Tensor&)>;
  223. #ifndef _WIN32
  224. static_assert(
  225. std::is_same<
  226. GradCallback,
  227. torch::distributed::autograd::DistAutogradContext::GradCallback>::
  228. value,
  229. "");
  230. #endif
  231. void runGradCallbackForVariable(at::Tensor& variable, GradCallback&& cb);
  232. // This function is called inside `initialize_buckets()`. It initializes both
  233. // `bucket_views_in` and `bucket_views_out` with views for each variable's
  234. // gradient into the bucket's flattened `gradients` tensor. Views serve as
  235. // entry points to `copy_()` each grad's data in/out of the flattened
  236. // `gradients` tensor.
  237. void initialize_bucket_views(Bucket& bucket);
  238. // This function is called inside `finalize_backward`, it happens only if
  239. // DDP communication hook was registered to recreate just bucket_views_out
  240. // with the result of `future_work`.
  241. void populate_bucket_views_out(Bucket& bucket, at::Tensor& tensor);
  242. // If gradient_as_bucket_view_ is false, after allreduce buckets,
  243. // copy bucket results back to grads.
  244. void copy_bucket_to_grad(
  245. at::Tensor& variable,
  246. Reducer::Bucket& bucket,
  247. size_t intra_bucket_index,
  248. bool global_unused);
  249. // Check layout of grad and bucket_view before copying the grad to bucket.
  250. void check_grad_layout(const at::Tensor& grad, const at::Tensor& bucket_view);
  251. // A bucket contains [1..N] gradients to be reduced, where the gradients
  252. // have the same dtype and device.
  253. // Coalescing gradients together before reducing can result in lower overhead
  254. // and/or faster time to completion. Coalescing requires the constituent
  255. // gradients to have the same dtype and device, and the resulting flattened
  256. // tensor uses that common dtype and device. The flattened tensor is filled
  257. // as the corresponding gradients are computed (triggered by autograd hooks),
  258. // and the buckets are reduced in a predetermined order consistent across
  259. // processes.
  260. struct Bucket {
  261. // Gradients of the bucket flattened into a 1-dimensional tensor
  262. at::Tensor gradients;
  263. // Views into the `gradients` tensor for each individual gradient
  264. // Each view is created with layout (size and stride) matching the
  265. // gradient's expected layout (see the "Gradient Layout Contract" in
  266. // torch/csrc/autograd/functions/accumulate_grad.h).
  267. // `bucket_views_in[i].copy_(grad)` and `grad.copy_(bucket_views_out[i])`
  268. // provide convenient ways to copy gradient data in/out of `gradients`,
  269. // respectively.
  270. // We keep both `bucket_views_in` and `bucket_views_out` because
  271. // registering a DDP communication hook may re-initialize
  272. // `bucket_views_out` with the value of the hook's `future_work` but we
  273. // still need separate views into the bucket's original flattened gradient
  274. // to copy in gradient data.
  275. std::vector<at::Tensor> bucket_views_in;
  276. std::vector<at::Tensor> bucket_views_out;
  277. // Variables whose gradients are held in this bucket
  278. // We use refcounted tensors here so that we can easily unflatten the
  279. // bucket's flattened `gradients` tensor into the participating variables
  280. // after reduction has completed.
  281. std::vector<at::Tensor> variables;
  282. // Per-variable offset/length into the flattened `gradients` tensor and
  283. // the corresponding `GradBucket` instance for communication hooks
  284. std::vector<size_t> offsets;
  285. std::vector<size_t> lengths;
  286. // Per-variable sizes slicing into the bucket's `gradients` tensor
  287. std::vector<c10::IntArrayRef> sizes_vec;
  288. // Number of gradients left to be computed before the bucket is ready to
  289. // be reduced
  290. size_t pending;
  291. // Global indices of participating variables in the bucket
  292. std::vector<size_t> variable_indices;
  293. // Future work handle for DDP communication hook
  294. // If no hook is registered, a temporary vanilla allreduce hook is used.
  295. c10::intrusive_ptr<at::ivalue::Future> future_work;
  296. // If this bucket should expect a single sparse gradient
  297. // If `true`, then this implies that `bucket.variables.size() == 1`.
  298. bool expect_sparse_gradient = false;
  299. // TODO(@pietern)
  300. // Memory copies from gradient tensors into the bucket are potentially
  301. // done on different CUDA streams. We record an event for every copy
  302. // so that we can synchronize with them prior to kicking off the reduction.
  303. // std::vector<at::cuda::CUDAEvent> events;
  304. };
  305. std::vector<Bucket> buckets_;
  306. // A variable locator locates a particular variable in the reducer's buckets
  307. struct VariableLocator {
  308. // Index of the bucket containing the variable in the `buckets_` vector
  309. size_t bucket_index;
  310. // Index of the variable in the bucket, which may be used consistently
  311. // across `bucket_views_in`, `bucket_views_out`, `variables`, `offsets`,
  312. // `lengths`, `sizes_vec`, and `variable_indices` in `Bucket`
  313. size_t intra_bucket_index;
  314. VariableLocator() = default;
  315. VariableLocator(size_t bucket_index_, size_t intra_bucket_index_) {
  316. bucket_index = bucket_index_;
  317. intra_bucket_index = intra_bucket_index_;
  318. }
  319. };
  320. // Map the index of a variable to its location in the bucket structure.
  321. std::vector<VariableLocator> variable_locators_;
  322. // track the number of iterations to synchronize grads in training so far.
  323. long num_iterations_;
  324. // track the number of buckets that have been ready for
  325. // communication calls like allReduce or communication hooks.
  326. int num_buckets_ready_;
  327. // Timing information.
  328. int64_t backward_compute_start_time_ = -1;
  329. std::unique_ptr<Timer> timer_;
  330. // We collect the relative timestamp of every gradient being ready
  331. // when executing autograd. This can be used to derive a timeline of
  332. // the point in time buckets were ready, or ideal bucket assignment/ordering.
  333. std::vector<int64_t> backward_stats_;
  334. bool should_collect_runtime_stats();
  335. void record_forward_compute_start_time();
  336. void record_backward_compute_start_time();
  337. void record_backward_compute_end_time();
  338. void record_backward_comm_start_time();
  339. void record_backward_comm_end_time();
  340. int get_ddp_runtime_logging_sample_rate();
  341. int ddp_runtime_logging_sample_rate_ = kDDPRuntimeLoggingSampleRate;
  342. bool is_multi_device_module_ = false;
  343. // Following variables are to help build dynamic bucket order
  344. bool has_rebuilt_bucket_;
  345. std::vector<at::Tensor> rebuilt_params_;
  346. std::vector<int64_t> rebuilt_param_indices_;
  347. const int64_t bucket_bytes_cap_;
  348. #ifndef _WIN32
  349. struct RpcContext {
  350. using ContextPtr = torch::distributed::autograd::ContextPtr;
  351. // The shared_ptr is to hold the context instance.
  352. ContextPtr context_ptr_holder;
  353. std::atomic<ContextPtr::element_type*> context_ptr{nullptr};
  354. void set(ContextPtr&& new_context_ptr);
  355. };
  356. RpcContext rpc_context_;
  357. #endif
  358. // A struct containing work handle and tensor for allreduce scheduled in
  359. // forward pass, if applicable.
  360. struct ForwardPassAllreduceWork {
  361. c10::intrusive_ptr<c10d::ProcessGroup::Work> workHandle;
  362. at::Tensor resultTensor;
  363. // whether we should divide by the initial world_size or the no. of
  364. // remaining DDP ranks.
  365. bool useStaticWorldSize;
  366. };
  367. // Handle for the currently scheduled allreduce in the forward pass, if
  368. // applicable.
  369. ForwardPassAllreduceWork forwardPassWorkHandle_;
  370. // Division factor for reduction of gradients.
  371. // Equal to the process group size, with an exception of handling uneven
  372. // input.
  373. int div_factor_;
  374. bool static_graph_;
  375. // Key: size_t (index), Value: the number of times that a variable's
  376. // autograd_hook() should be triggered before marking this variable's grad as
  377. // ready for communication. Map will not change after 1st iteration.
  378. std::unordered_map<size_t, int> numGradHooksTriggeredMap_;
  379. // Key: size_t (index), Value: the number of times that a variable's
  380. // autograd_hook() are left to be triggered before marking this variable's
  381. // grad as ready for communication. Map will change after 1st iteration to
  382. // track a grad is ready for communication or not.
  383. std::unordered_map<size_t, int> numGradHooksTriggeredMapPerIteration_;
  384. private:
  385. // reset counting for buckets before backward starts
  386. void reset_bucket_counting();
  387. // search unused parameters beore backward starts
  388. void search_unused_parameters(
  389. const std::vector<torch::autograd::Variable>& outputs);
  390. void set_divide_factor();
  391. // kick off all reduce for the ready bucket
  392. void all_reduce_bucket(Bucket& bucket);
  393. // kick off all reduce to local used map, it can help find global unused
  394. // parameters
  395. void all_reduce_local_used_map();
  396. // initialize locally used parameter maps
  397. void initialize_local_used_map();
  398. // get current cuda stream
  399. const c10::Stream get_current_stream();
  400. bool dynamic_graph_find_unused();
  401. bool static_graph_first_iteration();
  402. bool static_graph_after_first_iteration();
  403. // comm_hook_ is used to access the DDP communication hook if registered.
  404. std::unique_ptr<CommHookInterface> comm_hook_;
  405. // Debug level setting. It is parsed once when Reducer is constructed, and
  406. // remains the same across a single invocation of DDP training.
  407. DebugLevel ddp_debug_level_;
  408. // Mapping of variable index to fully qualified name of model to notify users
  409. // about errors when certain parameters do not get gradient.
  410. std::unordered_map<size_t, std::string> param_names_;
  411. // Variable indices stored sequentially in order of when the gradient is ready
  412. // for the current backwards pass.
  413. std::vector<int> grad_ready_order_indices_;
  414. // Bytes capacity of first bucket, can be configured by user
  415. int64_t first_bucket_bytes_cap_;
  416. // Per iteration set of parameter indices that have been marked ready.
  417. std::unordered_set<size_t> perIterationReadyParams_;
  418. // Retrieves parameter names that have not been marked as ready as part of
  419. // previous iteration.
  420. std::vector<std::string> getUnmarkedParamsForIteration();
  421. // Retrives parameter indices that have not been marked as ready as part of
  422. // previous iteration.
  423. std::vector<size_t> getUnmarkedParamIndicesForIteration();
  424. // Raises appropriate error if mark_variable_ready is called on the same
  425. // variable twice, which is unexpected.
  426. void checkAndRaiseMarkedTwiceError(size_t curVariableIndex);
  427. // Retrieves parameter corresponding to the given VariableIndex.
  428. at::Tensor& get_param_from_index(size_t index);
  429. // Cached bucket index to model parameter mapping. Populated after buckets
  430. // are rebuilt after which this mapping is static.
  431. mutable std::unordered_map<size_t, std::vector<at::Tensor>> cached_variables_for_bucket_;
  432. friend class Logger;
  433. };
  434. // This is equivalent to take_tensors but returns indices into the
  435. // tensor list argument for bucket assignment. Also, it is aware
  436. // of device placement and will not allow buckets to span devices.
  437. // The index of tensors[i] assigned to bucket is tensor_indices[i],
  438. // when tensor_indices is empty, the index of tensors[i] assigned to
  439. // bucket is i.
  440. TORCH_API std::tuple<std::vector<std::vector<size_t>>, std::vector<size_t>>
  441. compute_bucket_assignment_by_size(
  442. const std::vector<at::Tensor>& tensors,
  443. const std::vector<size_t>& bucket_size,
  444. const std::vector<bool>& expect_sparse_gradient = {},
  445. const std::vector<int64_t>& tensor_indices = {},
  446. const c10::optional<std::weak_ptr<c10d::Logger>>& logger = {});
  447. // Verify models across all processes are the same as model on rank 0 with
  448. // respect to no. of params and matching dtype/size/layout.
  449. TORCH_API void verify_params_across_processes(
  450. const c10::intrusive_ptr<c10d::ProcessGroup>& process_group,
  451. const std::vector<at::Tensor>& params,
  452. const c10::optional<std::weak_ptr<c10d::Logger>>& logger);
  453. } // namespace c10d