optimizer.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from typing import List, Optional
  2. import logging
  3. import torch
  4. import torch.distributed.rpc as rpc
  5. import torch.jit as jit
  6. import torch.nn as nn
  7. from torch import Tensor
  8. from torch.distributed.rpc import RRef
  9. from .utils import functional_optim_map
  10. import torch.distributed.autograd as dist_autograd
  11. from collections import defaultdict
  12. from threading import Lock
# Module-level logger named after this module; used below to warn when an
# optimizer has to be created without TorchScript support.
logger = logging.getLogger(__name__)
# XXX: we define a _ScriptLocalOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as a common
# interface type for Optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we have added TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface(object):
    # Interface declaration only: ``step`` takes the autograd context id as
    # an int and returns nothing. The body is intentionally just ``pass``.
    def step(self, autograd_ctx_id: int) -> None:
        pass
class _ScriptLocalOptimizer(nn.Module):
    # ``nn.Module`` wrapper around a local optimizer instance; it is compiled
    # with ``jit.script`` by ``_new_script_local_optimizer``.
    #
    # TorchScript does not support multithread concurrent compiling.
    # request_callback might invoke concurrent compiling, so we
    # serialize the compiling with a lock
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        # optim_cls: callable invoked as optim_cls(params, *args, **kwargs).
        # local_params_rref: iterable of RRefs; each is resolved with
        # ``local_value()`` to obtain the parameter stored on this worker.
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(
            self._local_params,
            *args,
            **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        # Fetch the gradients recorded for the given autograd context.
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply functional optimizer step with a list of gradients
        # (one entry per local parameter, in the same order; ``None`` when
        # the context holds no gradient for that parameter)
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)
  49. # TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once
  50. # we have converted all to functional optimizer in distributed.optim
  51. class _LocalOptimizer(object):
  52. # Ideally we would only need to share a lock for instances of
  53. # _LocalOptimizer that deal with the same parameters. We are
  54. # making a simplifying assumption here that if there is more
  55. # than one instance of _LocalOptimizer per worker, they will
  56. # be optimizing the same parameters (e.g. each data parallel
  57. # trainer will create its own instance of _LocalOptimizer but
  58. # they will all optimize the same parameters on each worker)
  59. global_lock = Lock()
  60. def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
  61. self._local_params = [rref.local_value() for rref in local_params_rref]
  62. self.optim = optim_cls(
  63. self._local_params,
  64. *args,
  65. **kwargs)
  66. def step(self, autograd_ctx_id):
  67. all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
  68. with _LocalOptimizer.global_lock:
  69. for param, grad in all_local_grads.items():
  70. param.grad = grad
  71. self.optim.step()
  72. def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
  73. return rpc.RRef(
  74. _LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))
  75. def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
  76. local_optim = local_optim_rref.local_value()
  77. local_optim.step(autograd_ctx_id)
  78. # new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer
  79. def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
  80. optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)
  81. with _ScriptLocalOptimizer.compile_lock:
  82. script_optim = jit.script(optim)
  83. return rpc.RRef(
  84. script_optim, _ScriptLocalOptimizerInterface)
# TorchScript-compiled step function: takes an RRef typed with
# ``_ScriptLocalOptimizerInterface`` plus the autograd context id, and
# invokes the referenced optimizer's ``step``.
@jit.script
def _script_local_optimizer_step(
    local_optim_rref: RRef[_ScriptLocalOptimizerInterface],
    autograd_ctx_id: int
) -> None:
    # Resolve the RRef to its local value, then run one optimizer step for
    # the given autograd context.
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)
  92. def _wait_for_all(rpc_futs):
  93. # TODO: improve error propagation
  94. exception = None
  95. results = []
  96. for fut in rpc_futs:
  97. try:
  98. results.append(fut.wait())
  99. except Exception as e:
  100. results.append(e)
  101. exception = e
  102. if exception is not None:
  103. raise exception
  104. return results
  105. class DistributedOptimizer:
  106. """
  107. DistributedOptimizer takes remote references to parameters scattered
  108. across workers and applies the given optimizer locally for each parameter.
  109. This class uses :meth:`~torch.distributed.autograd.get_gradients` in order
  110. to retrieve the gradients for specific parameters.
  111. Concurrent calls to
  112. :meth:`~torch.distributed.optim.DistributedOptimizer.step`,
  113. either from the same or different clients, will
  114. be serialized on each worker -- as each worker's optimizer can only work
  115. on one set of gradients at a time. However, there is no guarantee that
  116. the full forward-backward-optimizer sequence will execute for one client
  117. at a time. This means that the gradients being applied may not correspond
  118. to the latest forward pass executed on a given worker. Also, there is no
  119. guaranteed ordering across workers.
  120. `DistributedOptimizer` creates the local optimizer with TorchScript enabled
  121. by default, so that optimizer updates are not blocked by the Python Global
  122. Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed
  123. Model Parallel). This feature is currently enabled for most optimizers. You
  124. can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support
  125. for your own custom optimizers.
  126. Args:
  127. optimizer_class (optim.Optimizer): the class of optimizer to
  128. instantiate on each worker.
  129. params_rref (list[RRef]): list of RRefs to local or remote parameters
  130. to optimize.
  131. args: arguments to pass to the optimizer constructor on each worker.
  132. kwargs: arguments to pass to the optimizer constructor on each worker.
  133. Example::
  134. >>> import torch.distributed.autograd as dist_autograd
  135. >>> import torch.distributed.rpc as rpc
  136. >>> from torch import optim
  137. >>> from torch.distributed.optim import DistributedOptimizer
  138. >>>
  139. >>> with dist_autograd.context() as context_id:
  140. >>> # Forward pass.
  141. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
  142. >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
  143. >>> loss = rref1.to_here() + rref2.to_here()
  144. >>>
  145. >>> # Backward pass.
  146. >>> dist_autograd.backward(context_id, [loss.sum()])
  147. >>>
  148. >>> # Optimizer.
  149. >>> dist_optim = DistributedOptimizer(
  150. >>> optim.SGD,
  151. >>> [rref1, rref2],
  152. >>> lr=0.05,
  153. >>> )
  154. >>> dist_optim.step(context_id)
  155. __ https://github.com/pytorch/tutorials/pull/1465
  156. """
  157. def __init__(self, optimizer_class, params_rref, *args, **kwargs):
  158. torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer")
  159. per_worker_params_rref = defaultdict(list)
  160. for param in params_rref:
  161. per_worker_params_rref[param.owner()].append(param)
  162. if optimizer_class in functional_optim_map and jit._state._enabled:
  163. optim_ctor = functional_optim_map.get(optimizer_class)
  164. else:
  165. optim_ctor = optimizer_class
  166. self.is_functional_optim = (optim_ctor != optimizer_class)
  167. if self.is_functional_optim:
  168. optimizer_new_func = _new_script_local_optimizer
  169. else:
  170. logger.warn(
  171. f"Creating the optimizer {optimizer_class} without TorchScript support, "
  172. "this might result in slow computation time in multithreading environment"
  173. "(i.e. Distributed Model Parallel training on CPU) due to the Python's "
  174. "Global Interpreter Lock (GIL). Please file an issue if you need this "
  175. "optimizer in TorchScript. "
  176. )
  177. optimizer_new_func = _new_local_optimizer
  178. remote_optim_futs = []
  179. for worker, param_rrefs in per_worker_params_rref.items():
  180. remote_optim_rref_fut = rpc.rpc_async(
  181. worker,
  182. optimizer_new_func,
  183. args=(optim_ctor, param_rrefs) + args,
  184. kwargs=kwargs,
  185. )
  186. remote_optim_futs.append(remote_optim_rref_fut)
  187. self.remote_optimizers = _wait_for_all(remote_optim_futs)
  188. def step(self, context_id):
  189. """
  190. Performs a single optimization step.
  191. This will call :meth:`torch.optim.Optimizer.step` on each worker
  192. containing parameters to be optimized, and will block until all workers
  193. return. The provided ``context_id`` will be used to retrieve the
  194. corresponding :class:`~torch.distributed.autograd.context` that
  195. contains the gradients that should be applied to the parameters.
  196. Args:
  197. context_id: the autograd context id for which we should run the
  198. optimizer step.
  199. """
  200. dist_autograd._is_valid_context(context_id)
  201. if self.is_functional_optim:
  202. optimizer_step_func = _script_local_optimizer_step
  203. else:
  204. optimizer_step_func = _local_optimizer_step
  205. rpc_futs = []
  206. for optimizer in self.remote_optimizers:
  207. rpc_futs.append(rpc.rpc_async(
  208. optimizer.owner(),
  209. optimizer_step_func,
  210. args=(optimizer, context_id),
  211. ))
  212. _wait_for_all(rpc_futs)