distributed.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742
  1. import sys
  2. import copy
  3. from dataclasses import dataclass
  4. from typing import Callable, Any, Type
  5. from enum import Enum, auto
  6. import inspect
  7. import itertools
  8. import logging
  9. import os
  10. import warnings
  11. from contextlib import contextmanager
  12. import torch
  13. import torch.distributed as dist
  14. from torch.autograd import Function, Variable
  15. from torch.distributed.algorithms.join import (
  16. Join,
  17. Joinable,
  18. JoinHook,
  19. )
  20. from torch.utils._pytree import tree_flatten, tree_unflatten
  21. RPC_AVAILABLE = False
  22. if dist.is_available():
  23. from torch.distributed.utils import (
  24. _verify_param_shape_across_processes,
  25. _sync_module_states,
  26. _to_kwargs,
  27. )
  28. from torch.distributed.distributed_c10d import ReduceOp, _get_default_group
  29. if torch.distributed.rpc.is_available():
  30. RPC_AVAILABLE = True
  31. from torch.distributed.rpc import RRef
  32. from torch._utils import _get_device_index
  33. from ..modules import Module
  34. from ._replicated_tensor_ddp_utils import _ddp_with_replicated_tensor_enabled
  35. from .scatter_gather import gather, is_namedtuple, scatter_kwargs # noqa: F401
  36. logger = logging.getLogger(__name__)
  37. def _tree_flatten_with_rref(output):
  38. output_is_rref = RPC_AVAILABLE and isinstance(output, RRef)
  39. if output_is_rref:
  40. output_tensor_list, treespec = tree_flatten(output.local_value())
  41. else:
  42. output_tensor_list, treespec = tree_flatten(output)
  43. # Need to return flattened tensors, spec to re-pack them, as well
  44. # as if the return type was actually an RRef to reconstruct.
  45. return output_tensor_list, treespec, output_is_rref
  46. def _tree_unflatten_with_rref(output, treespec, output_is_rref):
  47. output = tree_unflatten(output, treespec)
  48. if output_is_rref:
  49. output = RRef(output)
  50. return output
  51. def _find_tensors(obj):
  52. r"""
  53. Recursively find all tensors contained in the specified object.
  54. """
  55. if RPC_AVAILABLE and isinstance(obj, RRef):
  56. # If the current node is the owner of the RRef, unwrap it and try to
  57. # find Tensors.
  58. # TODO: Expand to remote RRefs.
  59. if obj.is_owner():
  60. return _find_tensors(obj.local_value())
  61. if isinstance(obj, torch.Tensor):
  62. return [obj]
  63. if isinstance(obj, (list, tuple)):
  64. return itertools.chain(*map(_find_tensors, obj))
  65. if isinstance(obj, dict):
  66. return itertools.chain(*map(_find_tensors, obj.values()))
  67. return []
  68. def _dump_DDP_relevant_env_vars():
  69. relevant_env_vars = [
  70. "RANK",
  71. "LOCAL_RANK",
  72. "WORLD_SIZE",
  73. "MASTER_PORT",
  74. "MASTER_ADDR",
  75. "CUDA_VISIBLE_DEVICES",
  76. "GLOO_SOCKET_IFNAME",
  77. "GLOO_DEVICE_TRANSPORT",
  78. "NCCL_SOCKET_IFNAME",
  79. "NCCL_BLOCKING_WAIT",
  80. "NCCL_DEBUG",
  81. "NCCL_DEBUG_SUBSYS",
  82. "NCCL_IB_DISABLE",
  83. # More NCCL env vars:
  84. "NCCL_P2P_DISABLE",
  85. "NCCL_P2P_LEVEL",
  86. "NCCL_SHM_DISABLE",
  87. "NCCL_SOCKET_NTHREADS",
  88. "NCCL_NSOCKS_PERTHREAD",
  89. "NCCL_BUFFSIZE",
  90. "NCCL_NTHREADS",
  91. "NCCL_RINGS",
  92. "NCCL_MAX_NCHANNELS",
  93. "NCCL_MIN_NCHANNELS",
  94. "NCCL_CHECKS_DISABLE",
  95. "NCCL_CHECK_POINTERS",
  96. "NCCL_LAUNCH_MODE",
  97. "NCCL_IB_HCA",
  98. "NCCL_IB_TIMEOUT",
  99. "NCCL_IB_RETRY_CNT",
  100. "NCCL_IB_GID_INDEX",
  101. "NCCL_IB_SL",
  102. "NCCL_IB_TC",
  103. "NCCL_IB_AR_THRESHOLD",
  104. "NCCL_IB_CUDA_SUPPORT",
  105. "NCCL_NET_GDR_LEVEL",
  106. "NCCL_NET_GDR_READ",
  107. "NCCL_SINGLE_RING_THRESHOLD",
  108. "NCCL_LL_THRESHOLD",
  109. "NCCL_TREE_THRESHOLD",
  110. "NCCL_ALGO",
  111. "NCCL_PROTO",
  112. "NCCL_IGNORE_CPU_AFFINITY",
  113. "NCCL_DEBUG_FILE",
  114. "NCCL_COLLNET_ENABLE",
  115. "NCCL_TOPO_FILE",
  116. "NCCL_TOPO_DUMP_FILE",
  117. "NCCL_ASYNC_ERROR_HANDLING",
  118. ]
  119. formatted_output = ""
  120. for var in relevant_env_vars:
  121. value = os.environ[var] if var in os.environ else "N/A"
  122. formatted_output += "env:%s=%s\n" % (var, value)
  123. print(formatted_output)
  124. class _BufferCommHookLocation(Enum):
  125. PRE_FORWARD = auto()
  126. POST_FORWARD = auto()
  127. @dataclass
  128. class _BufferCommHook:
  129. buffer_comm_hook: Callable
  130. buffer_comm_hook_state: Any
  131. buffer_comm_hook_location: _BufferCommHookLocation
  132. # Add a DDPSink to run various functions when backwards starts, such as
  133. # queueing call back of out-most backward/graph task,
  134. # this helps call back is fired after all gradients' calculation
  135. # is completed.
  136. class _DDPSink(Function):
  137. @staticmethod
  138. def forward(ctx, reducer, state_dict, *inputs):
  139. # set_materialize_grads(False) will ensure that None gradients stay as
  140. # None and are not filled with zeros.
  141. ctx.set_materialize_grads(False)
  142. ctx.reducer = reducer
  143. ctx.state_dict = state_dict
  144. ret = tuple(
  145. inp.clone()
  146. if isinstance(inp, torch.Tensor)
  147. else inp
  148. for inp in inputs
  149. )
  150. return ret
  151. @staticmethod
  152. def backward(ctx, *grad_outputs):
  153. state_dict = ctx.state_dict
  154. # Enqueue delay allreduce for static graph training on the first
  155. # iteration.
  156. if ctx.state_dict['static_graph'] and ctx.state_dict['num_iterations'] == 1:
  157. Variable._execution_engine.queue_callback(ctx.reducer._delay_all_reduce)
  158. return (None, None, *grad_outputs)
  159. class _DDPJoinHook(JoinHook):
  160. def __init__(self, ddp, divide_by_initial_world_size):
  161. """
  162. Sets config variables for internal usage.
  163. """
  164. assert isinstance(ddp, DistributedDataParallel), (
  165. "DDP join hook requires passing in a DistributedDataParallel "
  166. "instance as the state"
  167. )
  168. ddp.logger._set_uneven_input_join()
  169. self.ddp = ddp
  170. self.ddp._divide_by_initial_world_size = divide_by_initial_world_size
  171. super().__init__()
  172. def main_hook(self):
  173. """
  174. Shadows the DDP collective communication operations in the forward and
  175. backward passes.
  176. """
  177. ddp = self.ddp
  178. # Buckets are rebuilt only once during a training period
  179. ddp.reducer._rebuild_buckets()
  180. # Schedule a broadcast if we are syncing module buffers in the
  181. # forward pass
  182. # TODO: make DDP uneven inputs context manager support buffer
  183. # comm hook (https://github.com/pytorch/pytorch/issues/65436)
  184. ddp._check_and_sync_module_buffers()
  185. # Check if need to sync in the backward pass
  186. work = ddp._check_global_requires_backward_grad_sync(is_joined_rank=True)
  187. work.wait()
  188. should_sync_backwards = work.result()[0].item() != 0
  189. # Forward parameter sync is disabled in the next iteration if we
  190. # are skipping gradient sync this iteration, so set
  191. # `require_forward_param_sync` accordingly
  192. ddp.require_forward_param_sync = should_sync_backwards
  193. if not should_sync_backwards:
  194. return
  195. # Schedule one allreduce per gradient bucket to match the backward
  196. # pass allreduce
  197. ddp._match_all_reduce_for_bwd_pass()
  198. # Check if we need to allreduce locally unused parameters
  199. if ddp.find_unused_parameters:
  200. ddp._match_unused_params_allreduce()
  201. # Rebuilt parameters are pushed only once during a training period
  202. ddp.reducer._push_all_rebuilt_params()
  203. def post_hook(self, is_last_joiner: bool):
  204. """
  205. Syncs the final model to ensure that the model is the same across all
  206. processes.
  207. """
  208. self.ddp._sync_final_model(is_last_joiner)
  209. class DistributedDataParallel(Module, Joinable):
  210. r"""Implements distributed data parallelism that is based on
  211. ``torch.distributed`` package at the module level.
  212. This container parallelizes the application of the given module by
  213. splitting the input across the specified devices by chunking in the batch
  214. dimension. The module is replicated on each machine and each device, and
  215. each such replica handles a portion of the input. During the backwards
  216. pass, gradients from each node are averaged.
  217. The batch size should be larger than the number of GPUs used locally.
  218. See also: :ref:`distributed-basics` and :ref:`cuda-nn-ddp-instead`.
  219. The same constraints on input as in :class:`torch.nn.DataParallel` apply.
  220. Creation of this class requires ``torch.distributed`` to be already
  221. initialized, by calling :func:`torch.distributed.init_process_group`.
  222. ``DistributedDataParallel`` is proven to be significantly faster than
  223. :class:`torch.nn.DataParallel` for single-node multi-GPU data
  224. parallel training.
  225. To use ``DistributedDataParallel`` on a host with N GPUs, you should spawn
  226. up ``N`` processes, ensuring that each process exclusively works on a single
  227. GPU from 0 to N-1. This can be done by either setting
  228. ``CUDA_VISIBLE_DEVICES`` for every process or by calling:
  229. >>> torch.cuda.set_device(i)
  230. where i is from 0 to N-1. In each process, you should refer to the following
  231. to construct this module:
  232. >>> torch.distributed.init_process_group(
  233. >>> backend='nccl', world_size=N, init_method='...'
  234. >>> )
  235. >>> model = DistributedDataParallel(model, device_ids=[i], output_device=i)
  236. In order to spawn up multiple processes per node, you can use either
  237. ``torch.distributed.launch`` or ``torch.multiprocessing.spawn``.
  238. .. note::
  239. Please refer to `PyTorch Distributed Overview <https://pytorch.org/tutorials/beginner/dist_overview.html>`__
  240. for a brief introduction to all features related to distributed training.
  241. .. note::
  242. ``DistributedDataParallel`` can be used in conjunction with
  243. :class:`torch.distributed.optim.ZeroRedundancyOptimizer` to reduce
  244. per-rank optimizer states memory footprint. Please refer to
  245. `ZeroRedundancyOptimizer recipe <https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html>`__
  246. for more details.
  247. .. note:: ``nccl`` backend is currently the fastest and highly recommended
  248. backend when using GPUs. This applies to both single-node and
  249. multi-node distributed training.
  250. .. note:: This module also supports mixed-precision distributed training.
  251. This means that your model can have different types of parameters such
  252. as mixed types of ``fp16`` and ``fp32``, the gradient reduction on these
  253. mixed types of parameters will just work fine.
  254. .. note:: If you use ``torch.save`` on one process to checkpoint the module,
  255. and ``torch.load`` on some other processes to recover it, make sure that
  256. ``map_location`` is configured properly for every process. Without
  257. ``map_location``, ``torch.load`` would recover the module to devices
  258. where the module was saved from.
  259. .. note:: When a model is trained on ``M`` nodes with ``batch=N``, the
  260. gradient will be ``M`` times smaller when compared to the same model
  261. trained on a single node with ``batch=M*N`` if the loss is summed (NOT
  262. averaged as usual) across instances in a batch (because the gradients
  263. between different nodes are averaged). You should take this into
  264. consideration when you want to obtain a mathematically equivalent
  265. training process compared to the local training counterpart. But in most
  266. cases, you can just treat a DistributedDataParallel wrapped model, a
  267. DataParallel wrapped model and an ordinary model on a single GPU as the
  268. same (E.g. using the same learning rate for equivalent batch size).
  269. .. note::
  270. Parameters are never broadcast between processes. The module performs
  271. an all-reduce step on gradients and assumes that they will be modified
  272. by the optimizer in all processes in the same way. Buffers
  273. (e.g. BatchNorm stats) are broadcast from the module in process of rank
  274. 0, to all other replicas in the system in every iteration.
  275. .. note::
  276. If you are using DistributedDataParallel in conjunction with the
  277. :ref:`distributed-rpc-framework`, you should always use
  278. :meth:`torch.distributed.autograd.backward` to compute gradients and
  279. :class:`torch.distributed.optim.DistributedOptimizer` for optimizing
  280. parameters.
  281. .. note::
  282. DistributedDataParallel currently offers limited support for gradient
  283. checkpointing with :meth:`torch.utils.checkpoint`. DDP will work as
  284. expected when there are no unused parameters in the model and each layer
  285. is checkpointed at most once (make sure you are not passing
  286. `find_unused_parameters=True` to DDP). We currently do not support the
  287. case where a layer is checkpointed multiple times, or when there are unused
  288. parameters in the checkpointed model.
  289. Example::
  290. >>> import torch.distributed.autograd as dist_autograd
  291. >>> from torch.nn.parallel import DistributedDataParallel as DDP
  292. >>> import torch
  293. >>> from torch import optim
  294. >>> from torch.distributed.optim import DistributedOptimizer
  295. >>> import torch.distributed.rpc as rpc
  296. >>> from torch.distributed.rpc import RRef
  297. >>>
  298. >>> t1 = torch.rand((3, 3), requires_grad=True)
  299. >>> t2 = torch.rand((3, 3), requires_grad=True)
  300. >>> rref = rpc.remote("worker1", torch.add, args=(t1, t2))
  301. >>> ddp_model = DDP(my_model)
  302. >>>
  303. >>> # Setup optimizer
  304. >>> optimizer_params = [rref]
  305. >>> for param in ddp_model.parameters():
  306. >>> optimizer_params.append(RRef(param))
  307. >>>
  308. >>> dist_optim = DistributedOptimizer(
  309. >>> optim.SGD,
  310. >>> optimizer_params,
  311. >>> lr=0.05,
  312. >>> )
  313. >>>
  314. >>> with dist_autograd.context() as context_id:
  315. >>> pred = ddp_model(rref.to_here())
  316. >>> loss = loss_func(pred, target)
  317. >>> dist_autograd.backward(context_id, [loss])
  318. >>> dist_optim.step(context_id)
  319. .. note::
  320. To let a non-DDP model load a state dict from a DDP model,
  321. :meth:`~torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`
  322. needs to be applied to strip the prefix "module." in the DDP state dict before loading.
  323. .. warning::
  324. Constructor, forward method, and differentiation of the output (or a
  325. function of the output of this module) are distributed synchronization
  326. points. Take that into account in case different processes might be
  327. executing different code.
  328. .. warning::
  329. This module assumes all parameters are registered in the model by the
  330. time it is created. No parameters should be added nor removed later.
  331. Same applies to buffers.
  332. .. warning::
  333. This module assumes all parameters are registered in the model of each
  334. distributed processes are in the same order. The module itself will
  335. conduct gradient ``allreduce`` following the reverse order of the
  336. registered parameters of the model. In other words, it is users'
  337. responsibility to ensure that each distributed process has the exact
  338. same model and thus the exact same parameter registration order.
  339. .. warning::
  340. This module allows parameters with non-rowmajor-contiguous strides.
  341. For example, your model may contain some parameters whose
  342. :class:`torch.memory_format` is ``torch.contiguous_format``
  343. and others whose format is ``torch.channels_last``. However,
  344. corresponding parameters in different processes must have the
  345. same strides.
  346. .. warning::
  347. This module doesn't work with :func:`torch.autograd.grad` (i.e. it will
  348. only work if gradients are to be accumulated in ``.grad`` attributes of
  349. parameters).
  350. .. warning::
  351. If you plan on using this module with a ``nccl`` backend or a ``gloo``
  352. backend (that uses Infiniband), together with a DataLoader that uses
  353. multiple workers, please change the multiprocessing start method to
  354. ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately
  355. Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will
  356. likely experience deadlocks if you don't change this setting.
  357. .. warning::
  358. You should never try to change your model's parameters after wrapping
  359. up your model with ``DistributedDataParallel``. Because, when
  360. wrapping up your model with ``DistributedDataParallel``, the constructor
  361. of ``DistributedDataParallel`` will register the additional gradient
  362. reduction functions on all the parameters of the model itself at the
  363. time of construction. If you change the model's parameters afterwards,
  364. gradient reduction functions no longer match the correct set of
  365. parameters.
  366. .. warning::
  367. Using ``DistributedDataParallel`` in conjunction with the
  368. :ref:`distributed-rpc-framework` is experimental and subject to change.
  369. Args:
  370. module (Module): module to be parallelized
  371. device_ids (list of int or torch.device): CUDA devices.
  372. 1) For single-device modules, ``device_ids`` can
  373. contain exactly one device id, which represents the only
  374. CUDA device where the input module corresponding to this process resides.
  375. Alternatively, ``device_ids`` can also be ``None``.
  376. 2) For multi-device modules and CPU modules,
  377. ``device_ids`` must be ``None``.
  378. When ``device_ids`` is ``None`` for both cases,
  379. both the input data for the forward pass and the actual module
  380. must be placed on the correct device.
  381. (default: ``None``)
  382. output_device (int or torch.device): Device location of output for
  383. single-device CUDA modules. For multi-device modules and
  384. CPU modules, it must be ``None``, and the module itself
  385. dictates the output location. (default: ``device_ids[0]``
  386. for single-device modules)
  387. broadcast_buffers (bool): Flag that enables syncing (broadcasting)
  388. buffers of the module at beginning of the ``forward``
  389. function. (default: ``True``)
  390. process_group: The process group to be used for distributed data
  391. all-reduction. If ``None``, the default process group, which
  392. is created by :func:`torch.distributed.init_process_group`,
  393. will be used. (default: ``None``)
  394. bucket_cap_mb: ``DistributedDataParallel`` will bucket parameters into
  395. multiple buckets so that gradient reduction of each
  396. bucket can potentially overlap with backward computation.
  397. :attr:`bucket_cap_mb` controls the bucket size in
  398. MegaBytes (MB). (default: 25)
  399. find_unused_parameters (bool): Traverse the autograd graph from all
  400. tensors contained in the return value of the
  401. wrapped module's ``forward`` function. Parameters
  402. that don't receive gradients as part of this
  403. graph are preemptively marked as being ready to
  404. be reduced. In addition, parameters that may have
  405. been used in the wrapped module's ``forward``
  406. function but were not part of loss computation and
  407. thus would also not receive gradients are
  408. preemptively marked as ready to be reduced.
  409. (default: ``False``)
  410. check_reduction: This argument is deprecated.
  411. gradient_as_bucket_view (bool): When set to ``True``, gradients will be views
  412. pointing to different offsets of ``allreduce`` communication
  413. buckets. This can reduce peak memory usage, where the
  414. saved memory size will be equal to the total gradients
  415. size. Moreover, it avoids the overhead of copying between
  416. gradients and ``allreduce`` communication buckets. When
  417. gradients are views, ``detach_()`` cannot be called on the
  418. gradients. If hitting such errors, please fix it by
  419. referring to the :meth:`~torch.optim.Optimizer.zero_grad`
  420. function in ``torch/optim/optimizer.py`` as a solution.
  421. Note that gradients will be views after first iteration, so
  422. the peak memory saving should be checked after first iteration.
  423. static_graph (bool): When set to ``True``, DDP knows the trained graph is
  424. static. Static graph means 1) The set of used and unused
  425. parameters will not change during the whole training loop; in
  426. this case, it does not matter whether users set
  427. ``find_unused_parameters = True`` or not. 2) How the graph is trained
  428. will not change during the whole training loop (meaning there is
  429. no control flow depending on iterations).
  430. When static_graph is set to be ``True``, DDP will support cases that
  431. can not be supported in the past:
  432. 1) Reentrant backwards.
  433. 2) Activation checkpointing multiple times.
  434. 3) Activation checkpointing when model has unused parameters.
  435. 4) There are model parameters that are outside of forward function.
  436. 5) Potentially improve performance when there are unused parameters,
  437. as DDP will not search graph in each iteration to detect unused
  438. parameters when static_graph is set to be ``True``.
  439. To check whether you can set static_graph to be ``True``, one way is to
  440. check ddp logging data at the end of your previous model training,
  441. if ``ddp_logging_data.get("can_set_static_graph") == True``, mostly you
  442. can set ``static_graph = True`` as well.
  443. Example::
  444. >>> model_DDP = torch.nn.parallel.DistributedDataParallel(model)
  445. >>> # Training loop
  446. >>> .....
  447. >>> ddp_logging_data = model_DDP._get_ddp_logging_data()
  448. >>> static_graph = ddp_logging_data.get("can_set_static_graph")
  449. Attributes:
  450. module (Module): the module to be parallelized.
  451. Example::
  452. >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
  453. >>> net = torch.nn.parallel.DistributedDataParallel(model, process_group=pg)
  454. """
  455. def __init__(
  456. self,
  457. module,
  458. device_ids=None,
  459. output_device=None,
  460. dim=0,
  461. broadcast_buffers=True,
  462. process_group=None,
  463. bucket_cap_mb=25,
  464. find_unused_parameters=False,
  465. check_reduction=False,
  466. gradient_as_bucket_view=False,
  467. static_graph=False,
  468. ):
  469. super(DistributedDataParallel, self).__init__()
  470. Joinable.__init__(self)
  471. self.logger = None
  472. if not any((p.requires_grad for p in module.parameters())):
  473. self._log_and_throw(
  474. RuntimeError,
  475. "DistributedDataParallel is not needed when a module "
  476. "doesn't have any parameter that requires a gradient.",
  477. )
  478. if device_ids is not None and len(device_ids) > 1:
  479. self._log_and_throw(
  480. ValueError, "device_ids can only be None or contain a single element."
  481. )
  482. self.is_multi_device_module = len({p.device for p in module.parameters()}) > 1
  483. distinct_device_types = {p.device.type for p in module.parameters()}
  484. if len(distinct_device_types) != 1:
  485. self._log_and_throw(
  486. ValueError,
  487. "DistributedDataParallel's input module must be on "
  488. "the same type of devices, but input module parameters locate in {}.".format(
  489. distinct_device_types
  490. ),
  491. )
  492. self.device_type = list(distinct_device_types)[0]
  493. if (
  494. device_ids is None
  495. or len(device_ids) == 0 # For backward compatibility.
  496. or self.device_type == "cpu"
  497. or self.is_multi_device_module
  498. ):
  499. if device_ids or output_device:
  500. self._log_and_throw(
  501. ValueError,
  502. "DistributedDataParallel device_ids and output_device arguments "
  503. "only work with single-device/multiple-device GPU modules or CPU modules, "
  504. "but got device_ids {}, output_device {}, and module parameters {}.".format(
  505. device_ids,
  506. output_device,
  507. {p.device for p in module.parameters()},
  508. ),
  509. )
  510. self.device_ids = None
  511. self.output_device = None
  512. else:
  513. self.device_ids = [_get_device_index(x, True) for x in device_ids]
  514. if output_device is None:
  515. output_device = device_ids[0]
  516. self.output_device = _get_device_index(output_device, True)
  517. if process_group is None:
  518. self.process_group = _get_default_group()
  519. else:
  520. self.process_group = process_group
  521. self.static_graph = False
  522. self.dim = dim
  523. self.module = module
  524. self.device = list(self.module.parameters())[0].device
  525. self.broadcast_buffers = broadcast_buffers
  526. self.find_unused_parameters = find_unused_parameters
  527. self.require_backward_grad_sync = True
  528. self.require_forward_param_sync = True
  529. self.gradient_as_bucket_view = gradient_as_bucket_view
  530. if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
  531. self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
  532. else:
  533. self.parameters_to_ignore = []
  534. self._use_replicated_tensor_module = _ddp_with_replicated_tensor_enabled()
  535. self._build_replicated_tensor_module()
  536. if check_reduction:
  537. # This argument is no longer used since the reducer
  538. # will ensure reduction completes even if some parameters
  539. # do not receive gradients.
  540. warnings.warn(
  541. "The `check_reduction` argument in `DistributedDataParallel` "
  542. "module is deprecated. Please avoid using it."
  543. )
  544. # Check that a module does not have Uninitialized parameters
  545. for param in module.parameters():
  546. if isinstance(param, torch.nn.parameter.UninitializedParameter):
  547. self._log_and_throw(
  548. RuntimeError,
  549. "Modules with uninitialized parameters can't be used with `DistributedDataParallel`. "
  550. "Run a dummy forward pass to correctly initialize the modules",
  551. )
  552. # used for intra-node param sync and inter-node sync as well
  553. self.broadcast_bucket_size = int(250 * 1024 * 1024)
  554. # reduction bucket size
  555. self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)
  556. # Whether to perform input tensor CPU to GPU copies on a side-stream
  557. self.use_side_stream_for_tensor_copies = (
  558. os.environ.get("PYTORCH_DDP_USE_SIDE_STREAM", "1") == "1"
  559. )
  560. # Build parameters for reducer.
  561. parameters, expect_sparse_gradient = self._build_params_for_reducer()
  562. # Verify model equivalence.
  563. _verify_param_shape_across_processes(self.process_group, parameters)
  564. # Sync params and buffers. Ensures all DDP models start off at the same value.
  565. _sync_module_states(
  566. module=self.module,
  567. process_group=self.process_group,
  568. broadcast_bucket_size=self.broadcast_bucket_size,
  569. src=0,
  570. params_and_buffers_to_ignore=self.parameters_to_ignore,
  571. )
  572. # In debug mode, build a mapping of parameter index -> parameter.
  573. param_to_name_mapping = self._build_debug_param_to_name_mapping(parameters)
  574. # Builds reducer.
  575. self._ddp_init_helper(
  576. parameters, expect_sparse_gradient, param_to_name_mapping, static_graph
  577. )
  578. self._has_rebuilt_buckets = False
  579. if static_graph:
  580. self._set_static_graph()
  581. def _build_replicated_tensor_module(self):
  582. if self._use_replicated_tensor_module:
  583. # Create a module with ReplicatedTensor without copying tensors. Avoid
  584. # registering '_replicated_tensor_module' as a submodule by directly
  585. # adding to self.__dict__.
  586. from ._replicated_tensor_ddp_interop import _replicate_module
  587. self.__dict__['_replicated_tensor_module'] = _replicate_module(self.module, self.process_group)
  588. def _log_and_throw(self, err_type, err_msg):
  589. if self.logger is not None:
  590. self.logger.set_error_and_log(f"{str(err_type)}: {err_msg}")
  591. raise err_type(err_msg)
  def _ddp_init_helper(
      self, parameters, expect_sparse_gradient, param_to_name_mapping,
      static_graph
  ):
      """
      Initialization helper that sets up the reducer and the logger.

      Steps performed, in order:
      (1) reset the iteration counter
      (2) bucket the parameters for reductions
      (3) build the ``dist.Reducer`` from those buckets
      (4) build the ``dist.Logger`` and log construction-time DDP data
      (5) pass a handle of DDP to the SyncBatchNorm layers

      Args:
          parameters: parameters handed to the bucket assignment and reducer.
          expect_sparse_gradient: per-parameter flags forwarded alongside
              ``parameters``.
          param_to_name_mapping: mapping forwarded to the reducer.
          static_graph: selects the initial bucket limits and is logged.
      """
      self.num_iterations = 0

      # The parameter order is not the order in which parameters are used,
      # especially in models with control flow. If such a model also
      #   1) has other collective comm ops in its backward graph, and
      #   2) has unused parameters on a subset of ranks,
      # bucketing could insert an ALL-REDUCE too early on the rank with the
      # unused parameter, pairing it with a different collective on other
      # ranks. To handle this corner case a single unbounded bucket is used
      # below, so only one ALL-REDUCE is inserted after all gradients of the
      # whole graph are computed.
      #
      # This only disables bucketing for the first iteration; afterwards a
      # bucket rebuild regroups parameters by their real execution order in
      # the backward graph. Can remove this branching once #73732 is landed.
      use_single_bucket = (
          static_graph is True or self.find_unused_parameters is False
      )
      if use_single_bucket:
          bucket_size_limits = [sys.maxsize]
      else:
          bucket_size_limits = [
              dist._DEFAULT_FIRST_BUCKET_BYTES,
              self.bucket_bytes_cap,
          ]
      assignment = dist._compute_bucket_assignment_by_size(
          parameters,
          bucket_size_limits,
          expect_sparse_gradient,
      )
      bucket_indices, per_bucket_size_limits = assignment

      # Reverse the buckets to approximate the order in which gradients are
      # produced, assuming parameters are used in the forward pass in the
      # order they are defined.
      reversed_indices = list(reversed(bucket_indices))
      reversed_limits = list(reversed(per_bucket_size_limits))
      self.reducer = dist.Reducer(
          parameters,
          reversed_indices,
          reversed_limits,
          self.process_group,
          expect_sparse_gradient,
          # Bucket size limit specified in the constructor.
          self.bucket_bytes_cap,
          self.find_unused_parameters,
          self.gradient_as_bucket_view,
          param_to_name_mapping,
          # A single small first bucket is allowed for the parameters that
          # are defined first, so their gradients don't spill into a much
          # larger bucket and add latency after gradient computation
          # finishes. Experiments showed 1MB is a reasonable value; users
          # can tune it through dist._DEFAULT_FIRST_BUCKET_BYTES.
          dist._DEFAULT_FIRST_BUCKET_BYTES
      )

      self.logger = dist.Logger(self.reducer)
      # Set as a weak reference to avoid a reference cycle between logger
      # and reducer.
      self.reducer.set_logger(self.logger)

      has_sync_bn = any(
          isinstance(submodule, torch.nn.SyncBatchNorm)
          for submodule in self.module.modules()
      )

      # Log the data that is already known at construction time.
      logged_device_ids = [] if self.device_ids is None else self.device_ids
      logged_output_device = (
          -1 if self.output_device is None else self.output_device
      )
      self.logger.set_construction_data_and_log(
          self.module.__class__.__name__,
          logged_device_ids,
          logged_output_device,
          self.broadcast_buffers,
          has_sync_bn,
          static_graph,
      )

      # Pass a handle to the torch.nn.SyncBatchNorm layers.
      self._passing_sync_batchnorm_handle(self.module)
  674. def __getstate__(self):
  675. self._check_default_group()
  676. attrs = copy.copy(self.__dict__)
  677. del attrs["process_group"]
  678. del attrs["reducer"]
  679. del attrs["logger"]
  680. if self._use_replicated_tensor_module:
  681. del attrs["_replicated_tensor_module"]
  682. return attrs
  def __setstate__(self, state):
      """
      Restore an instance from ``state`` and rebuild the dropped members.

      The process group is reset to the default group, the replicated-tensor
      module (if enabled) is rebuilt, and the reducer and logger are
      recreated through ``_ddp_init_helper``.
      """
      # If serializable, then the process group should be the default one.
      self.process_group = _get_default_group()
      super(DistributedDataParallel, self).__setstate__(state)
      self._build_replicated_tensor_module()

      # Supply defaults for the sync flags when the state lacks them.
      for flag_name in ("require_forward_param_sync", "require_backward_grad_sync"):
          self.__dict__.setdefault(flag_name, True)

      reducer_params, sparse_flags = self._build_params_for_reducer()
      # In debug mode, build a mapping of parameter index -> parameter.
      name_mapping = self._build_debug_param_to_name_mapping(reducer_params)
      # Builds reducer.
      self._ddp_init_helper(
          reducer_params, sparse_flags, name_mapping, self.static_graph
      )

      if not self.static_graph:
          return
      self.reducer._set_static_graph()
      self.logger._set_static_graph()
  700. def _build_params_for_reducer(self):
  701. # Build tuple of (module, parameter) for all parameters that require grads.
  702. modules_and_parameters = [
  703. (module, parameter)
  704. for module_name, module in self.module.named_modules()
  705. for parameter in [
  706. param
  707. # Note that we access module.named_parameters instead of
  708. # parameters(module). parameters(module) is only needed in the
  709. # single-process multi device case, where it accesses replicated
  710. # parameters through _former_parameters.
  711. for param_name, param in module.named_parameters(recurse=False)
  712. if param.requires_grad
  713. and f"{module_name}.{param_name}" not in self.parameters_to_ignore
  714. ]
  715. ]
  716. # Deduplicate any parameters that might be shared across child modules.
  717. memo = set()
  718. modules_and_parameters = [
  719. # "p not in memo" is the deduplication check.
  720. # "not memo.add(p)" is always True, and it's only there to cause "add(p)" if needed.
  721. (m, p) for m, p in modules_and_parameters
  722. if p not in memo and not memo.add(p)
  723. ]
  724. # Build list of parameters.
  725. parameters = list(parameter for _, parameter in modules_and_parameters)
  726. # Checks if a module will produce a sparse gradient.
  727. def produces_sparse_gradient(module):
  728. if isinstance(module, torch.nn.Embedding) or isinstance(
  729. module, torch.nn.EmbeddingBag
  730. ):
  731. return module.sparse
  732. return False
  733. # Build list of booleans indicating whether or not to expect sparse
  734. # gradients for the corresponding parameters.
  735. expect_sparse_gradient = list(produces_sparse_gradient(module) for module, _ in modules_and_parameters)
  736. self._assign_modules_buffers()
  737. return parameters, expect_sparse_gradient
  738. def _assign_modules_buffers(self):
  739. """
  740. Assigns module buffers to self.modules_buffers which are then used to
  741. broadcast across ranks when broadcast_buffers=True. Note that this
  742. must be called every time buffers need to be synced because buffers can
  743. be reassigned by user module,
  744. see https://github.com/pytorch/pytorch/issues/63916.
  745. """
  746. # Collect buffers for modules, filtering out buffers that should be ignored.
  747. named_module_buffers = [
  748. (buffer, buffer_name)
  749. for buffer_name, buffer in self.module.named_buffers()
  750. if buffer_name not in self.parameters_to_ignore
  751. ]
  752. self.modules_buffers = [
  753. buffer
  754. for (buffer, buffer_name) in named_module_buffers
  755. ]
  756. # Dict[str, tensor] representing module buffers not ignored by DDP.
  757. self.named_module_buffers = {
  758. buffer_name: buffer for (buffer, buffer_name) in named_module_buffers
  759. }
  760. def _build_debug_param_to_name_mapping(self, parameters):
  761. if dist.get_debug_level() == dist.DebugLevel.OFF:
  762. return {}
  763. param_to_param_index = {parameters[i]: i for i in range(len(parameters))}
  764. param_set = set(parameters)
  765. param_index_to_param_fqn = {}
  766. for module_name, module in self.module.named_modules():
  767. for param_name, param in module.named_parameters(recurse=False):
  768. fqn = f"{module_name}.{param_name}"
  769. # Bypass ignored parameters since those are not reduced by DDP
  770. # to begin with.
  771. if fqn not in self.parameters_to_ignore and param.requires_grad:
  772. if param not in param_set:
  773. self._log_and_throw(
  774. ValueError,
  775. f"Param with name {fqn} found in module parameters, but not DDP parameters."
  776. " This indicates a bug in DDP, please report an issue to PyTorch.",
  777. )
  778. param_index = param_to_param_index[param]
  779. param_index_to_param_fqn[param_index] = fqn
  780. # Ensure we covered all parameters
  781. if len(param_set) != len(param_index_to_param_fqn):
  782. self._log_and_throw(
  783. ValueError,
  784. (
  785. "Expected param to name mapping to cover all parameters, but"
  786. f" got conflicting lengths: {len(param_set)} vs "
  787. f"{len(param_index_to_param_fqn)}. This indicates a bug in DDP"
  788. ", please report an issue to PyTorch."
  789. ),
  790. )
  791. return param_index_to_param_fqn
  792. def _get_parameters(self, m, recurse=True):
  793. """
  794. Returns a generator of module parameters
  795. """
  796. def model_parameters(m):
  797. ps = (
  798. m._former_parameters.values()
  799. if hasattr(m, "_former_parameters")
  800. else m.parameters(recurse=False)
  801. )
  802. for p in ps:
  803. yield p
  804. for m in m.modules() if recurse else [m]:
  805. for p in model_parameters(m):
  806. yield p
  807. def _check_default_group(self):
  808. pickle_not_supported = False
  809. try:
  810. if self.process_group != _get_default_group():
  811. pickle_not_supported = True
  812. except RuntimeError:
  813. pickle_not_supported = True
  814. if pickle_not_supported:
  815. self._log_and_throw(
  816. RuntimeError,
  817. "DDP Pickling/Unpickling are only supported "
  818. "when using DDP with the default process "
  819. "group. That is, when you have called "
  820. "init_process_group and have not passed "
  821. "process_group argument to DDP constructor",
  822. )
  @contextmanager
  def no_sync(self):
      r"""
      Context manager that turns off gradient synchronization across DDP
      processes. While it is active, gradients accumulate on the module
      variables; they are synchronized in the first forward-backward pass
      that runs after the context exits.

      ``require_backward_grad_sync`` is set to ``False`` for the duration of
      the context and restored to its previous value on exit, including when
      the body raises.

      Example::

          >>> ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
          >>> with ddp.no_sync():
          >>>   for input in inputs:
          >>>     ddp(input).backward()  # no synchronization, accumulate grads
          >>> ddp(another_input).backward()  # synchronize grads
      """
      saved_flag = self.require_backward_grad_sync
      self.require_backward_grad_sync = False
      try:
          yield
      finally:
          self.require_backward_grad_sync = saved_flag
  843. def _run_ddp_forward(self, *inputs, **kwargs):
  844. module_to_run = self._replicated_tensor_module if self._use_replicated_tensor_module else self.module
  845. if self.device_ids:
  846. inputs, kwargs = _to_kwargs(
  847. inputs,
  848. kwargs,
  849. self.device_ids[0],
  850. self.use_side_stream_for_tensor_copies
  851. )
  852. return module_to_run(*inputs[0], **kwargs[0])
  853. else:
  854. return module_to_run(*inputs, **kwargs)
  def forward(self, *inputs, **kwargs):
      """
      Run one DDP forward pass and return the wrapped module's output.

      Besides calling the module (via ``_run_ddp_forward``) this prepares the
      reducer for the upcoming backward pass, notifies the join context,
      rebuilds buckets when the reducer asks for it, and syncs buffers before
      and/or after the module call depending on ``_check_sync_bufs_pre_fwd``
      and ``_check_sync_bufs_post_fwd``.

      Args:
          *inputs: positional inputs for the wrapped module.
          **kwargs: keyword inputs for the wrapped module.

      Returns:
          The module output. When unused-parameter detection is active, or on
          the first iteration of static-graph training, the output is rebuilt
          with tensors that carry a ``grad_fn`` routed through ``_DDPSink``.
      """
      with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
          if torch.is_grad_enabled() and self.require_backward_grad_sync:
              self.logger.set_runtime_stats_and_log()
              self.num_iterations += 1
              self.reducer.prepare_for_forward()

          # Notify the join context that this process has not joined, if
          # needed.
          work = Join.notify_join_context(self)
          if work:
              self.reducer._set_forward_pass_work_handle(
                  work, self._divide_by_initial_world_size
              )

          # Call _rebuild_buckets before the forward computation: it may
          # allocate new buckets before deallocating old ones, so doing it
          # before forward raises peak memory keeps overall peak usage lower.
          # This should be called only once during whole training period.
          if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
              logger.info("Reducer buckets have been rebuilt in this iteration.")
              self._has_rebuilt_buckets = True

          # Sync buffers according to the location (before/after forward) the
          # user specified as part of the hook, if a hook was specified.
          if self._check_sync_bufs_pre_fwd():
              self._sync_buffers()

          if self._join_config.enable:
              # Notify joined ranks whether they should sync in backwards pass or not.
              self._check_global_requires_backward_grad_sync(is_joined_rank=False)

          output = self._run_ddp_forward(*inputs, **kwargs)

          # Post-forward counterpart of the buffer sync above.
          if self._check_sync_bufs_post_fwd():
              self._sync_buffers()

          if torch.is_grad_enabled() and self.require_backward_grad_sync:
              self.require_forward_param_sync = True
              # The output object is returned verbatim since it is freeform.
              # Its tensors are still needed to figure out which parameters
              # were used during this forward pass, so reduction can be
              # short-circuited for unused ones. Only done if
              # `find_unused_parameters` is set.
              if self.find_unused_parameters and not self.static_graph:
                  # Do not need to populate this for static graph.
                  self.reducer.prepare_for_backward(list(_find_tensors(output)))
              else:
                  self.reducer.prepare_for_backward([])
          else:
              self.require_forward_param_sync = False

      # NOTE(review): the scraped source lost its indentation; this section
      # is placed after the profiler range - confirm against the original.
      # TODO: DDPSink is currently enabled for unused parameter detection and
      # static graph training for first iteration.
      if (self.find_unused_parameters and not self.static_graph) or (
          self.static_graph and self.num_iterations == 1
      ):
          state_dict = {
              'static_graph': self.static_graph,
              'num_iterations': self.num_iterations,
          }

          output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
              output
          )
          output_placeholders = [None] * len(output_tensor_list)
          # Do not touch tensors that have no grad_fn, which can cause issues
          # such as https://github.com/pytorch/pytorch/issues/60733
          # (A dedicated loop name avoids rebinding the function-level
          # `output` while the flattened list is walked.)
          for i, flat_output in enumerate(output_tensor_list):
              if torch.is_tensor(flat_output) and flat_output.grad_fn is None:
                  output_placeholders[i] = flat_output

          # When find_unused_parameters=True, makes tensors which require grad
          # run through the DDPSink backward pass. When not all outputs are
          # used in loss, this makes those corresponding tensors receive
          # undefined gradient which the reducer then handles to ensure
          # param.grad field is not touched and we don't error out.
          passthrough_tensor_list = _DDPSink.apply(
              self.reducer,
              state_dict,
              *output_tensor_list,
          )
          for i, placeholder in enumerate(output_placeholders):
              if placeholder is None:
                  output_placeholders[i] = passthrough_tensor_list[i]

          # Reconstruct output data structure.
          output = _tree_unflatten_with_rref(
              output_placeholders, treespec, output_is_rref
          )
      return output
  def scatter(self, inputs, kwargs, device_ids):
      """
      Forward ``inputs`` and ``kwargs`` to ``scatter_kwargs`` for
      ``device_ids``, using ``self.dim`` as the ``dim`` argument, and return
      its result unchanged.
      """
      scattered = scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
      return scattered
  def to_kwargs(self, inputs, kwargs, device_id):
      """
      Backward-compatibility wrapper around ``_to_kwargs``.

      Passes ``inputs``, ``kwargs`` and ``device_id`` through together with
      ``self.use_side_stream_for_tensor_copies`` and returns the result.
      """
      # Kept for BC
      use_side_stream = self.use_side_stream_for_tensor_copies
      return _to_kwargs(inputs, kwargs, device_id, use_side_stream)
  def gather(self, outputs, output_device):
      """
      Forward ``outputs`` to the module-level ``gather`` function for
      ``output_device``, using ``self.dim`` as the ``dim`` argument, and
      return its result unchanged.
      """
      gathered = gather(outputs, output_device, dim=self.dim)
      return gathered
  def train(self, mode=True):
      """
      Set the training mode on this wrapper and, when replicated-tensor mode
      is enabled, on the replicated-tensor module as well.

      Args:
          mode: forwarded to the parent ``train`` implementation.

      Returns:
          ``self``.
      """
      super(DistributedDataParallel, self).train(mode)
      if not self._use_replicated_tensor_module:
          return self
      self._replicated_tensor_module.train(mode)
      return self
  def _check_global_requires_backward_grad_sync(self, is_joined_rank):
      """
      Used in join mode: schedule an allreduce that tells joined ranks
      whether backward pass synchronization will run this iteration.

      This rank contributes a one-element tensor of ones when it has not
      joined and ``require_backward_grad_sync`` is set, otherwise zeros.

      Returns:
          The work handle of the asynchronous ``dist.all_reduce``.
      """
      will_sync = not is_joined_rank and self.require_backward_grad_sync
      make_flag = torch.ones if will_sync else torch.zeros
      requires_sync_tensor = make_flag(1, device=self.device)

      return dist.all_reduce(
          requires_sync_tensor, group=self.process_group, async_op=True
      )
  965. # When running in join mode, checks and performs sync of module buffers if
  966. # the models have buffers that should be synchronized in the forward pass.
  967. def _check_and_sync_module_buffers(self):
  968. if self._check_sync_bufs_pre_fwd():
  969. authoritative_rank = self._find_common_rank(self._distributed_rank, False)
  970. self._sync_module_buffers(authoritative_rank)
  def _sync_final_model(self, is_last_joiner):
      """
      Used in join mode: agree upon a common rank and broadcast the model
      parameters from it to all other ranks.

      Args:
          is_last_joiner: whether the current rank is a candidate for being
              the authoritative model copy.
      """
      # Agree upon the process that will be the authoritative model copy.
      # The current rank is a candidate if is_last_joiner=True. We break ties
      # via picking the larger rank.
      self._authoritative_rank = self._find_common_rank(
          self._distributed_rank, is_last_joiner
      )

      sync_arguments = dict(
          module=self.module,
          process_group=self.process_group,
          broadcast_bucket_size=self.broadcast_bucket_size,
          src=self._authoritative_rank,
          params_and_buffers_to_ignore=self.parameters_to_ignore,
      )
      _sync_module_states(**sync_arguments)
  987. # Schedule comm ops to match those scheduled in the reducer's backward
  988. # pass.
  989. def _match_all_reduce_for_bwd_pass(self):
  990. comm_work = []
  991. # Schedule comm in the same order as Reducer schedules them, i.e.
  992. # the order of the buckets. Retrieving the bucket order from the reducer
  993. # ensures that we keep the same order in join mode, such as when bucket
  994. # order is rebuilt dynamically.
  995. # Returns grad_buckets in order, but real tensors are substituted with
  996. # zero tensors of the same shape.
  997. grad_buckets = self.reducer._get_zeros_like_grad_buckets()
  998. for grad_bucket in grad_buckets:
  999. # Joined processes contribute zero gradient. In the case that
  1000. # divide_by_initial_world_size=True, we divide grads by the static
  1001. # world size, if not, the dividing factor is reduced by the number
  1002. # of joined processes.
  1003. work = self.reducer._run_comm_hook(grad_bucket)
  1004. comm_work.append(work)
  1005. for work in comm_work:
  1006. work.wait()
  1007. # Allreduces the used parameter mapping across ranks.
  1008. def _match_unused_params_allreduce(self):
  1009. locally_used_param_map = self.reducer._get_local_used_map()
  1010. self.process_group.allreduce(locally_used_param_map)
  def join(
      self,
      divide_by_initial_world_size: bool = True,
      enable: bool = True,
      throw_on_early_termination: bool = False,
  ):
      r"""
      Return a context manager for training a
      :class:`torch.nn.parallel.DistributedDataParallel` instance with uneven
      inputs across participating processes.

      The context manager tracks which DDP processes have already joined and
      "shadows" the forward and backward passes: it inserts collective
      communication operations that pair up with the ones issued by the
      processes that have not joined yet. Every collective call therefore has
      a matching call on already-joined processes, which prevents the hangs
      or errors that uneven inputs would otherwise cause. Alternatively, when
      ``throw_on_early_termination`` is ``True``, all trainers throw an error
      as soon as one rank runs out of inputs, so the application can catch
      and handle it.

      Once all DDP processes have joined, the model of the last joined process
      is broadcast to all processes so that the model is the same everywhere
      (which is guaranteed by DDP).

      To train with uneven inputs, simply wrap the training loop in this
      context manager. No further change to the model or the data loading is
      required.

      .. warning::
          If the model or training loop wrapped by this context manager has
          additional distributed collective operations, such as
          ``SyncBatchNorm`` in the model's forward pass, then
          ``throw_on_early_termination`` must be enabled, because this
          context manager is not aware of non-DDP collective communication.
          With the flag set, all ranks throw when any one rank exhausts its
          inputs, allowing the error to be caught and recovered from across
          all ranks.

      Args:
          divide_by_initial_world_size (bool): If ``True``, gradients are
              divided by the initial ``world_size`` DDP training was launched
              with. If ``False``, the effective world size (number of ranks
              that have not depleted their inputs yet) is computed and used
              as the divisor during allreduce. ``True`` gives every input
              sample, including the uneven ones, equal weight in the global
              gradient. ``False`` gives parity with training on a smaller
              ``world_size``, at the cost of the uneven inputs contributing
              more towards the global gradient. Typically ``True`` is the
              right choice when only the last few inputs of the training job
              are uneven; in extreme cases, with a large discrepancy in the
              number of inputs, ``False`` might provide better results.
          enable (bool): Whether to enable uneven input detection. Pass
              ``enable=False`` when the inputs are known to be even across
              participating processes. Default is ``True``.
          throw_on_early_termination (bool): Whether to throw an error or
              continue training when at least one rank has exhausted its
              inputs. If ``True``, throws upon the first rank reaching the
              end of its data. If ``False``, training continues with a
              smaller effective world size until all ranks are joined. Note
              that if this flag is specified,
              ``divide_by_initial_world_size`` is ignored. Default is
              ``False``.

      Example::

        >>> import torch
        >>> import torch.distributed as dist
        >>> import os
        >>> import torch.multiprocessing as mp
        >>> import torch.nn as nn
        >>> # On each spawned worker
        >>> def worker(rank):
        >>>     dist.init_process_group("nccl", rank=rank, world_size=2)
        >>>     torch.cuda.set_device(rank)
        >>>     model = nn.Linear(1, 1, bias=False).to(rank)
        >>>     model = torch.nn.parallel.DistributedDataParallel(
        >>>         model, device_ids=[rank], output_device=rank
        >>>     )
        >>>     # Rank 1 gets one more input than rank 0.
        >>>     inputs = [torch.tensor([1]).float() for _ in range(10 + rank)]
        >>>     with model.join():
        >>>         for _ in range(5):
        >>>             for inp in inputs:
        >>>                 loss = model(inp).sum()
        >>>                 loss.backward()
        >>>     # Without the join() API, the below synchronization will hang
        >>>     # blocking for rank 1's allreduce to complete.
        >>>     torch.cuda.synchronize(device=rank)
      """
      joinables = [self]
      return Join(
          joinables,
          enable,
          throw_on_early_termination,
          divide_by_initial_world_size=divide_by_initial_world_size,
      )
  def join_hook(
      self,
      **kwargs,
  ):
      r"""
      Return the DDP join hook, which enables training on uneven inputs by
      shadowing the collective communications in the forward and backward
      passes.

      Arguments:
          kwargs (dict): a :class:`dict` of keyword arguments that modify the
              behavior of the join hook at run time; all :class:`Joinable`
              instances sharing the same join context manager are forwarded
              the same value for ``kwargs``.

      The hook supports the following keyword arguments:
          divide_by_initial_world_size (bool, optional):
              If ``True``, gradients are divided by the initial world size
              that DDP was launched with.
              If ``False``, gradients are divided by the effective world size
              (i.e. the number of non-joined processes), meaning that the
              uneven inputs contribute more toward the global gradient.
              Typically this should be ``True`` if the degree of unevenness
              is small, but it can be set to ``False`` in extreme cases for
              possibly better results.
              Default is ``True``.
      """
      use_initial_world_size = kwargs.get("divide_by_initial_world_size", True)
      hook = _DDPJoinHook(
          self, divide_by_initial_world_size=use_initial_world_size
      )
      return hook
  @property
  def join_device(self):
      """The device exposed to the join machinery: ``self.device``."""
      device = self.device
      return device
  @property
  def join_process_group(self):
      """The process group exposed to the join machinery: ``self.process_group``."""
      group = self.process_group
      return group
  1144. def _register_buffer_comm_hook(
  1145. self,
  1146. state,
  1147. hook: callable,
  1148. comm_hook_location=_BufferCommHookLocation.POST_FORWARD
  1149. ):
  1150. r"""
  1151. Allows custom registration of hooks that define how buffer are
  1152. synchronized across ranks. The hook takes in an optional state
  1153. and is passed in a Dict[str, Tensor] corresponding to buffer names
  1154. and the buffers, and can run arbitrary reductions on buffers as
  1155. opposed to DDP's default broadcast from rank 0. This is useful for
  1156. example if a counter needs to be summed or averaged across ranks
  1157. every iteration.
  1158. Args:
  1159. state (Any): Optional state that is passed to the hook.
  1160. hook (Callable): Callable with the following signature:
  1161. ``hook(state: object, buffers: Dict[str, torch.Tensor])
  1162. -> Optional[List[torch.futures.Future[torch.Tensor]]]``
  1163. comm_hook_location (_BufferCommHookLocation): Enum value indicating
  1164. where to run the hook.
  1165. _BufferCommHookLocation.PRE_FORWARD means that the
  1166. hook will run _before_ the forward pass, and
  1167. _BufferCommHookLocation.POST_FORWARD means that the
  1168. hook will run _after_ the forward pass.
  1169. hook (callable): Callable with the following signature:
  1170. ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:
  1171. NOTE: To maximize performance, users can return a
  1172. List[torch.futures.Future] from their hook, and DDP will
  1173. install and await these hooks appropriately at the end of
  1174. the backward pass. This will ensure all buffers are
  1175. synchronized by the end of the backward pass. If this
  1176. setting is used, it is recommended to pass
  1177. comm_hook_location=_BufferCommHookLocation.POST_FORWARD,
  1178. which will trigger the hook after the forward pass.
  1179. If _BufferCommHookLocation.PRE_FORWARD is used, users must
  1180. ensure appropriate synchronization when manipulating GPU
  1181. buffers in the forward pass.
  1182. """
  1183. assert callable(hook)
  1184. self.buffer_hook = _BufferCommHook(
  1185. buffer_comm_hook=hook,
  1186. buffer_comm_hook_state=state,
  1187. buffer_comm_hook_location=comm_hook_location
  1188. )
  def register_comm_hook(self, state: object, hook: callable):
      r"""
      Registers a communication hook which is an enhancement that provides a
      flexible hook to users where they can specify how DDP aggregates gradients
      across multiple workers.

      This hook would be very useful for researchers to try out new ideas. For
      example, this hook can be used to implement several algorithms like GossipGrad
      and gradient compression which involve different communication strategies for
      parameter syncs while running Distributed DataParallel training.

      Args:
          state (object): Passed to the hook to maintain any state information during the training process.
                          Examples include error feedback in gradient compression,
                          peers to communicate with next in GossipGrad, etc.

                          It is locally stored by each worker
                          and shared by all the gradient tensors on the worker.
          hook (callable): Callable with the following signature:
                           ``hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]``:

                           This function is called once the bucket is ready. The
                           hook can perform whatever processing is needed and return
                           a Future indicating completion of any async work (ex: allreduce).
                           If the hook doesn't perform any communication, it still
                           must return a completed Future. The Future should hold the
                           new value of grad bucket's tensors. Once a bucket is ready,
                           c10d reducer would call this hook and use the tensors returned
                           by the Future and copy grads to individual parameters.
                           Note that the future's return type must be a single tensor.

                           We also provide an API called ``get_future`` to retrieve a
                           Future associated with the completion of ``c10d.ProcessGroup.Work``.
                           ``get_future`` is currently supported for NCCL and also supported for most
                           operations on GLOO and MPI, except for peer to peer operations (send/recv).

      .. warning ::
          Grad bucket's tensors will not be predivided by world_size. User is responsible
          to divide by the world_size in case of operations like allreduce.

      .. warning ::
          DDP communication hook can only be registered once and should be registered
          before calling backward.

      .. warning ::
          The Future object that hook returns should contain a single tensor
          that has the same shape with the tensors inside grad bucket.

      .. warning ::
          ``get_future`` API supports NCCL, and partially GLOO and MPI backends (no support
          for peer-to-peer operations like send/recv) and will return a ``torch.futures.Future``.

      Example::
          Below is an example of a noop hook that returns the same tensor.

          >>> def noop(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
          >>>     fut = torch.futures.Future()
          >>>     fut.set_result(bucket.buffer())
          >>>     return fut

          >>> ddp.register_comm_hook(state=None, hook=noop)

      Example::
          Below is an example of a Parallel SGD algorithm where gradients are encoded before
          allreduce, and then decoded after allreduce.

          >>> def encode_and_decode(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
          >>>     encoded_tensor = encode(bucket.buffer())  # encode gradients
          >>>     # ``async_op=True`` is required to get a work handle back.
          >>>     fut = torch.distributed.all_reduce(encoded_tensor, async_op=True).get_future()
          >>>     # Define the then callback to decode.
          >>>     def decode_callback(fut):
          >>>         decoded_tensor = decode(fut.value()[0])  # decode gradients
          >>>         return decoded_tensor
          >>>     return fut.then(decode_callback)

          >>> ddp.register_comm_hook(state=None, hook=encode_and_decode)
      """
      self._check_comm_hook(hook)
      # ``functools.partial`` objects and callable class instances do not
      # expose ``__qualname__``; fall back to the qualified name of their type
      # so the logger still receives a usable hook name.
      hook_name = getattr(hook, "__qualname__", type(hook).__qualname__)
      self.logger._set_comm_hook_name(hook_name)
      dist._register_comm_hook(self.reducer, state, hook)
  def _register_builtin_comm_hook(self, comm_hook_type):
      r"""
      Registers a built-in communication hook that specifies how DDP
      aggregates gradients across multiple workers.

      The built-in hooks aim to provide efficient C++ implementations for certain hooks,
      which might not be as efficient if implemented in Python using a Python communication hook.

      Args:
          comm_hook_type (dist.BuiltinCommHookType): type of communication hook, such as ALLREDUCE, FP16_COMPRESS, etc.

      .. warning ::
          DDP communication hook can only be registered once and should be registered
          before calling backward.

      Example::
          Below is an example of a FP16 compression where gradients are
          compressed into 16-bit floating-point numbers before allreduce, and
          then decompressed after allreduce.

          >>> ddp._register_builtin_comm_hook(dist.BuiltinCommHookType.FP16_COMPRESS)
      """
      # The logger records the string form of the hook type as the hook name.
      hook_name = str(comm_hook_type)
      self.logger._set_comm_hook_name(hook_name)
      dist._register_builtin_comm_hook(self.reducer, comm_hook_type)
  def _register_fused_optim(self, optim: Type, *args, optim_params=None, **kwargs):
      r"""
      Registers an optimizer with DDP such that the optimization for a
      parameter will run immediately when that parameter's gradient is
      finished with reduction, instead of waiting for all parameters'
      gradients to finish reduction. This can result in a training speedup
      depending on your workload since the optimizer can run while gradient
      reduction for other parameters are still ongoing. In addition, this has
      the potential to reduce peak memory consumption during training, as it
      only needs to load the per-parameter optimizer states of a single
      parameter at a time, instead of loading all per-parameter optimizer
      states at once.

      Args:
          optim (Type): a ``torch.optim.Optimizer`` class to be registered
              as a fused optimizer.
          *args (Sequence[Any]): Arguments to forward to `optim`.
          optim_params (Optional[Iterable[torch.Tensor]]): Set of parameters
              to optimize, similar to `params` argument of traditional `torch.optim`
              Optimizers. If this is omitted, all DDP model parameters will be
              optimized.
          **kwargs: (Dict[str, Any]): Keyword arguments to forward to `optim`.

      Raises:
          RuntimeError: if registering the overlapped optimizer with this DDP
              instance raises ``NotImplementedError``. The original error is
              attached as the cause.

      .. warning ::
          _register_fused_optim should only be called once on a DDP instance,
          and registering multiple fused optimizers for the same DDP model
          is not currently supported. Please ping
          https://github.com/pytorch/pytorch/issues/71595 if this is necessary
          for your use case.

      .. warning ::
          _register_fused_optim and register_comm_hook currently do not
          compose together, meaning that custom DDP communication hooks are
          not supported with overlapped optimizers. Please ping
          https://github.com/pytorch/pytorch/issues/71595 if this is necessary
          for your use case.

      .. warning ::
          Gradient accumulation and DDP `no_sync` are currently not supported
          with overlapped optimizer. Please ping
          https://github.com/pytorch/pytorch/issues/71595 if this is necessary
          for your use case.

      Example::

          >>> torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')
          >>> net = torch.nn.parallel.DistributedDataParallel(model, pg)
          >>> lr = 1e-2
          >>> betas = (0.9, 0.99)
          >>> eps = 1e-6
          >>> net._register_fused_optim(torch.optim.Adam, lr, betas=betas, eps=eps)
          >>> # Example with subset of parameters
          >>> params_to_opt = [list(net.parameters())[0]]
          >>> net._register_fused_optim(
          ...     torch.optim.Adam, lr, optim_params=params_to_opt, betas=betas, eps=eps
          ... )
      """
      # Note: importing in function, otherwise this will cause a circular
      # import as optimizer_overlap module needs to import DistributedDataParallel.
      from torch.distributed.algorithms._optimizer_overlap import _as_overlapped_optim

      overlapped_optim = _as_overlapped_optim(optim, optim_params, *args, **kwargs)
      try:
          overlapped_optim.register_ddp(self)
      except NotImplementedError as err:
          # Chain explicitly so the traceback shows the underlying error as
          # the direct cause of the user-facing one.
          raise RuntimeError(
              f"{optim} does not support overlapped DDP. Please file an issue to PyTorch or the respective owner of {optim}."
          ) from err
  def _distributed_broadcast_coalesced(
      self, tensors, buffer_size, authoritative_rank=0
  ):
      """
      Forwards ``tensors``, ``buffer_size`` and ``authoritative_rank`` to
      ``dist._broadcast_coalesced`` using this instance's process group.
      """
      group = self.process_group
      dist._broadcast_coalesced(group, tensors, buffer_size, authoritative_rank)
  1340. def _check_sync_bufs_post_fwd(self):
  1341. return (
  1342. self.will_sync_module_buffers() and
  1343. hasattr(self, 'buffer_hook') and
  1344. self.buffer_hook.buffer_comm_hook_location ==
  1345. _BufferCommHookLocation.POST_FORWARD
  1346. )
  1347. def _check_sync_bufs_pre_fwd(self):
  1348. return self.will_sync_module_buffers() and (
  1349. not hasattr(self, 'buffer_hook') or
  1350. self.buffer_hook.buffer_comm_hook_location
  1351. == _BufferCommHookLocation.PRE_FORWARD
  1352. )
  def will_sync_module_buffers(self):
      """
      Tells whether module buffers are due to be synchronized: forward
      parameter sync must be required, buffer broadcasting must be enabled,
      and there must be at least one module buffer.
      """
      param_sync = self.require_forward_param_sync
      if not param_sync:
          return param_sync
      broadcast = self.broadcast_buffers
      if not broadcast:
          return broadcast
      return len(self.modules_buffers) > 0
  def _find_common_rank(self, input_rank, rank_cond):
      """
      Agrees on a single rank across the process group via a MAX all-reduce.

      Args:
          input_rank: the rank this process proposes when ``rank_cond`` holds.
          rank_cond: whether this process is a candidate; when falsy the
              process contributes ``-1`` instead of ``input_rank``.

      Returns:
          The maximum proposed rank across ``self.process_group``.

      Raises:
          ValueError (through ``self._log_and_throw``) when the reduced value
          is ``-1``, i.e. no process proposed a rank.
      """
      # -1 indicates that this rank is not under consideration to be the
      # common_rank
      rank_to_use = torch.tensor(
          [input_rank if rank_cond else -1],
          device=self.device,
      )
      dist.all_reduce(rank_to_use, op=ReduceOp.MAX, group=self.process_group)
      # Read the reduced value back once and reuse it, rather than calling
      # ``item()`` separately for the check and for the return value.
      common_rank = rank_to_use.item()
      if common_rank == -1:
          self._log_and_throw(
              ValueError,
              "BUG! Expected rank_cond to be true for at least one process."
              " This indicates a bug in PyTorch, please report an issue.",
          )
      return common_rank
  1374. def _sync_buffers(self):
  1375. with torch.no_grad():
  1376. # module buffer sync
  1377. # Synchronize buffers across processes.
  1378. # If we are running DDP with the join manager, we have to agree
  1379. # upon a rank to sync module buffers from, since rank 0 may
  1380. # already have been joined and have stale module buffers.
  1381. if self._join_config.enable:
  1382. authoritative_rank = self._find_common_rank(
  1383. self._distributed_rank, True
  1384. )
  1385. else:
  1386. # The process with rank 0 is considered the authoritative copy.
  1387. authoritative_rank = 0
  1388. # Update self.modules_buffers incase any buffers were
  1389. # reassigned.
  1390. self._assign_modules_buffers()
  1391. self._sync_module_buffers(authoritative_rank)
  1392. def _sync_module_buffers(self, authoritative_rank):
  1393. if not hasattr(self, 'buffer_hook'):
  1394. self._default_broadcast_coalesced(authoritative_rank=authoritative_rank)
  1395. else:
  1396. hook = self.buffer_hook.buffer_comm_hook
  1397. state = self.buffer_hook.buffer_comm_hook_state
  1398. futs = hook(state, self.named_module_buffers)
  1399. if futs is not None:
  1400. self.reducer._install_post_backward_futures(futs)
  1401. def _default_broadcast_coalesced(
  1402. self, bufs=None, bucket_size=None, authoritative_rank=0
  1403. ):
  1404. """
  1405. Broadcasts buffers from rank 0 to rest of workers. If bufs, bucket_size
  1406. are None, default values self.modules_buffers and
  1407. self.broadcast_bucket_size are used instead.
  1408. """
  1409. if bufs is None:
  1410. bufs = self.modules_buffers
  1411. if bucket_size is None:
  1412. bucket_size = self.broadcast_bucket_size
  1413. self._distributed_broadcast_coalesced(
  1414. bufs,
  1415. bucket_size,
  1416. authoritative_rank
  1417. )
  1418. def _passing_sync_batchnorm_handle(self, module):
  1419. for layer in module.modules():
  1420. if isinstance(layer, torch.nn.modules.SyncBatchNorm):
  1421. if self.device_type == "cpu":
  1422. self._log_and_throw(
  1423. ValueError, "SyncBatchNorm layers only work with GPU modules"
  1424. )
  1425. def _check_comm_hook(self, hook):
  1426. if not callable(hook):
  1427. self._log_and_throw(TypeError, "Communication hook must be callable.")
  1428. sig = inspect.signature(hook)
  1429. if (
  1430. sig.parameters["bucket"].annotation != inspect._empty
  1431. and sig.parameters["bucket"].annotation != dist.GradBucket
  1432. ):
  1433. self._log_and_throw(
  1434. ValueError,
  1435. "Communication hook: bucket annotation should be dist.GradBucket.",
  1436. )
  1437. if (
  1438. sig.return_annotation != inspect._empty
  1439. and sig.return_annotation != torch.futures.Future[torch.Tensor]
  1440. ):
  1441. self._log_and_throw(
  1442. ValueError,
  1443. "Communication hook: return annotation should be torch.futures.Future[torch.Tensor].",
  1444. )
  1445. if (
  1446. hook.__name__ in ["bf16_compress_hook", "bf16_compress_wrapper_hook"]
  1447. and
  1448. (
  1449. torch.version.cuda is None
  1450. or int(torch.version.cuda.split('.')[0]) < 11
  1451. or not dist.is_available()
  1452. or not dist.is_nccl_available()
  1453. or torch.cuda.nccl.version() < (2, 10)
  1454. )
  1455. ):
  1456. self._log_and_throw(TypeError, "BF16 all reduce communication hook required CUDA 11+ and NCCL 2.10+.")
  @property
  def _distributed_rank(self):
      """Rank of the current process as reported by ``dist.get_rank`` for ``self.process_group``."""
      group = self.process_group
      return dist.get_rank(group)
  1460. @staticmethod
  1461. def _set_params_and_buffers_to_ignore_for_model(
  1462. module, params_and_buffers_to_ignore
  1463. ):
  1464. """
  1465. Sets parameters and buffers to be ignored by DDP. Expected format for
  1466. parameters is the fully qualified name: {module_name}.{param_name}, and
  1467. similarly, {module_name}.{buffer_name} for buffers. For example:
  1468. params_to_ignore = []
  1469. # NB: model here is vanilla PyTorch module, not yet wrapped with DDP.
  1470. for module_name, module in model.named_modules():
  1471. for param_name, param in module.named_parameters(recurse=False):
  1472. if should_ignore(param):
  1473. # Create expected format
  1474. fqn = f"{module_name}.{param_name}"
  1475. params_to_ignore.append(fqn)
  1476. torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
  1477. model,
  1478. params_to_ignore
  1479. )
  1480. """
  1481. # This is a workaround to set parameters and buffers DDP should ignore
  1482. # during synchronization. It will be removed when the API is finalized
  1483. # as part of addressing https://github.com/pytorch/pytorch/issues/43690.
  1484. module._ddp_params_and_buffers_to_ignore = params_and_buffers_to_ignore
  1485. def _get_ddp_logging_data(self):
  1486. r"""
  1487. This interface can be called after DistributedDataParallel() is
  1488. constructed. It returns a dictionary of logging data. It could help
  1489. for debugging and analysis. The loggind data includes DistributedDataParallel
  1490. constructor input parameters, some internal states of DistributedDataParallel
  1491. and performance metrics. Simply print the dictorinary and see what
  1492. these metrics are.
  1493. This is a prototype interface and subject to change in the future.
  1494. """
  1495. ddp_logging_data = self.logger._get_ddp_logging_data()
  1496. return {**ddp_logging_data.strs_map, **ddp_logging_data.ints_map}
  1497. def _set_ddp_runtime_logging_sample_rate(self, sample_rate):
  1498. r"""
  1499. This interface allows users to set sample_rate of collecting
  1500. runtime stats. The runtime stats will be recorded for the
  1501. first 10 iterations, after 10 iteratons runtime stats will be
  1502. recorded once every "sample_rate" training iterations. In
  1503. default, runtime stats are recorded for the first 10 iterations,
  1504. after 10 iterations runtime stats are recorded once every
  1505. "kDDPRuntimeLoggingSampleRate=100" training iterations.
  1506. This is a prototype interface and subject to change in the future.
  1507. """
  1508. if sample_rate < 1:
  1509. self._log_and_throw(
  1510. ValueError,
  1511. "DDP runtime logging sample rate should be equal or greater than 1",
  1512. )
  1513. self.reducer._set_ddp_runtime_logging_sample_rate(sample_rate)
  1514. def _set_static_graph(self):
  1515. """
  1516. It is recommended to set static graph in the DDP constructor, which will
  1517. call this private API internally.
  1518. """
  1519. # If self.static_graph has been set, no need to set it again
  1520. if self.static_graph:
  1521. warnings.warn(
  1522. "You've set static_graph to be True, no need to set it again."
  1523. )
  1524. return
  1525. self.static_graph = True
  1526. self.reducer._set_static_graph()
  1527. self.logger._set_static_graph()
  1528. if self.find_unused_parameters:
  1529. warnings.warn(
  1530. "You passed find_unused_parameters=true to DistributedDataParallel, "
  1531. "`_set_static_graph` will detect unused parameters automatically, so "
  1532. "you do not need to set find_unused_parameters=true, just be sure these "
  1533. "unused parameters will not change during training loop while calling "
  1534. "`_set_static_graph`."
  1535. )