wrap.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. #
  3. # This source code is licensed under the BSD license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import contextlib
  6. from typing import (
  7. Any,
  8. Callable,
  9. Dict,
  10. Generator,
  11. Optional,
  12. Set,
  13. Tuple,
  14. Type,
  15. cast,
  16. )
  17. import torch.nn as nn
  18. from torch.nn.modules.batchnorm import _BatchNorm
  19. def always_wrap_policy(*args, **kwargs) -> bool:
  20. """
  21. A simple wrapper policy that always returns ``True``,
  22. i.e. when passed as the `auto_wrap_policy` into FSDP,
  23. this will result in all submodules being wrapped as
  24. distinct FSDP instances.
  25. """
  26. return True
  27. def transformer_auto_wrap_policy(
  28. module: nn.Module,
  29. recurse: bool,
  30. unwrapped_params: int,
  31. transformer_layer_cls: Set[Type[nn.Module]],
  32. ) -> bool:
  33. """
  34. A convenient auto wrap policy for transformer models. If the submodule
  35. is an instance of transformer_layer_cls, the submodule will be wrapped
  36. as a FSDP unit. Otherwise, all the other remainder submodules are wrapped
  37. by the outermost FSDP unit. Right now, FSDP requires submodules that share
  38. weights to be wrapped in the same FSDP unit, this auto wrap policy can
  39. conviniently wrap the shared embeddings into the same FSDP unit for transformer
  40. models. In the near future, FSDP will support submodules that share weights
  41. to be wrapped in the separated FSDP units.
  42. Return if a module should be wrapped during FSDP auto wrapping.
  43. The first three parameters are required by :func:`_recursive_wrap`.
  44. Args:
  45. module (nn.Module):
  46. The module to be considered in this decision.
  47. recurse (bool):
  48. Indicate if this is called to make a decision on whether we
  49. should recurse down a subgraph of the module structure.
  50. If False, it means this function is called to make a decision
  51. on whether we should wrap the said module.
  52. unwrapped_params (int):
  53. The number of parameters yet to be wrapped in this module.
  54. transformer_layer_cls (int):
  55. Submodules with one of the `transformer_layer_cls` names
  56. will be wrapped as seperated FSDP units
  57. """
  58. if recurse:
  59. # always recurse
  60. return True
  61. else:
  62. # if not recursing, decide whether we should wrap for the leaf node or reminder
  63. return isinstance(module, tuple(transformer_layer_cls))
  64. def _wrap_batchnorm_individually(
  65. module: nn.Module,
  66. recurse: bool,
  67. *args,
  68. **kwargs,
  69. ) -> bool:
  70. """
  71. A policy that wraps ``BatchNorm`` instances in their own FSDP unit.
  72. """
  73. if recurse:
  74. # always recurse
  75. return True
  76. else:
  77. # if not recursing, decide whether we should wrap based on whether it is a
  78. # BN layer or not.
  79. return isinstance(module, _BatchNorm)
  80. def _or_policy(
  81. module: nn.Module,
  82. recurse: bool,
  83. unwrapped_params: int,
  84. policies,
  85. ) -> bool:
  86. """
  87. A policy that wraps ``module`` if any policy in the passed in iterable of
  88. ``policies`` returns ``True``.
  89. """
  90. return any(
  91. policy(module, recurse, unwrapped_params) for policy in policies
  92. )
  93. def size_based_auto_wrap_policy(
  94. module: nn.Module,
  95. recurse: bool,
  96. unwrapped_params: int,
  97. # These are customizable for this policy function.
  98. min_num_params: int = int(1e8),
  99. force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
  100. exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
  101. ) -> bool:
  102. """A size based auto_wrap_policy function for FSDP API.
  103. Return if a module should be wrapped during FSDP auto wrapping.
  104. The first three parameters are used by :func:`_recursive_wrap`. If
  105. you write a custom version of this policy function, your version
  106. needs to at least accept the first three parameters and free
  107. to do whatever you want in the function.
  108. Args:
  109. module (nn.Module):
  110. The module to be considered in this decision.
  111. recurse (bool):
  112. Indicate if this is called to make a decision on whether we
  113. should recurse down a subgraph of the module structure.
  114. If False, it means this function is called to make a decision
  115. on whether we should wrap the said module.
  116. unwrapped_params (int):
  117. The number of parameters yet to be wrapped in this module.
  118. min_num_params (int):
  119. Customizable policy input. It controls the size threshold
  120. on how big should a module be to be considered wrapped.
  121. force_leaf_modules (Set[Type[nn.Module]]): set of module types to
  122. keep as leaves, i.e., their children will never be wrapped.
  123. exclude_wrap_modules (Set[Type[nn.Module]]):
  124. Customizable set of module types to be excluded in wrapping.
  125. """
  126. force_leaf_modules = (
  127. size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
  128. if force_leaf_modules is None
  129. else force_leaf_modules
  130. )
  131. exclude_wrap_modules = (
  132. size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
  133. if exclude_wrap_modules is None
  134. else exclude_wrap_modules
  135. )
  136. is_large = unwrapped_params >= min_num_params
  137. if recurse:
  138. # We should recurse if the module is big enough but not in force_leaf_modules list.
  139. return is_large and not isinstance(module, tuple(force_leaf_modules))
  140. else:
  141. # If we are not recursing, determine if we should wrap.
  142. return is_large and not isinstance(module, tuple(exclude_wrap_modules))
  143. # Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
  144. size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined]
  145. size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined]
  146. @contextlib.contextmanager
  147. def enable_wrap(
  148. *, wrapper_cls: Any, **wrapper_kwargs: Any
  149. ) -> Generator[None, None, None]:
  150. """
  151. Context manager to wrap modules using a wrapper.
  152. Useful for when you'd like to apply the same configuration arguments to all
  153. child modules that you wrap. A particularly important use case is wrapping
  154. large layers so that they get sharded (in-place) during initialization, to
  155. avoid running out of system memory. Large layers can indicate that they
  156. should be sharded via the ``wrap`` annotation and this context manager can
  157. provide the exact configuration for these nested instances.
  158. Usage::
  159. with enable_wrap(wrapper_cls, **params):
  160. # Wraps layer in FSDP by default if within context
  161. self.l1 = wrap(torch.nn.Linear(5, 5))
  162. Args:
  163. wrapper_cls:
  164. Class that `wrap` annotation will `wrap` modules with, such as
  165. `FullyShardedDataParallel`.
  166. **wrapper_kwargs:
  167. Configuration settings that will be passed to all ``wrap``
  168. instances inside the context
  169. """
  170. kwargs = {
  171. **{"wrapper_cls": wrapper_cls},
  172. **wrapper_kwargs,
  173. }
  174. with _ConfigAutoWrap(**kwargs):
  175. yield
  176. def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
  177. """
  178. Annotate that a module should be wrapped. Annotated modules will only be
  179. wrapped if inside of an :func:`enable_wrap` context manager. This allows
  180. a module to be initialized both with and without a wrapper without code
  181. change.
  182. The class that this function wraps the passed in ``nn.Module`` with is the
  183. passed in ``wrapper_cls`` argument into ``enable_wrap``. Both
  184. ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct
  185. the ``wrapper_cls`` instance. In the case of duplicate kwargs in
  186. ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be
  187. respected.
  188. Usage::
  189. with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
  190. # Wraps layer in FSDP by default if within context
  191. self.l1 = wrap(torch.nn.Linear(5, 5))
  192. Args:
  193. module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
  194. **wrap_overrides: configuration overrides that will take priority over
  195. the values provided by the :func:`enable_wrap` context
  196. """
  197. if _ConfigAutoWrap.in_autowrap_context:
  198. assert _ConfigAutoWrap.wrapper_cls is not None
  199. wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
  200. return _wrap(
  201. module,
  202. _ConfigAutoWrap.wrapper_cls,
  203. **wrap_overrides,
  204. )
  205. return module
  206. def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
  207. assert wrapper_cls is not None
  208. if hasattr(module, '_wrap_overrides'):
  209. # If module has a _wrap_overrides attribute, we force overriding the
  210. # FSDP config with these attributes for this module. Currently this
  211. # is only used to disable mixed precision for BatchNorm when
  212. # auto_wrapping.
  213. overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type]
  214. return wrapper_cls(module, **overrides)
  215. return wrapper_cls(module, **kwargs)
  216. def _recursive_wrap(
  217. module: nn.Module,
  218. auto_wrap_policy: Callable,
  219. wrapper_cls: Callable,
  220. ignored_modules: Set[nn.Module],
  221. ignored_params: Set[nn.Parameter],
  222. only_wrap_children: bool = False,
  223. **kwargs: Any
  224. ) -> Tuple[nn.Module, int]:
  225. """
  226. Automatically wrap child modules of *module* that meet the given
  227. criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap.
  228. Args:
  229. module (nn.Module):
  230. module to recursively wrap
  231. auto_wrap_policy (Callable):
  232. A callable specifying a policy to recursively wrap layers with FSDP.
  233. ignored_modules (Set[torch.nn.Module]): Modules to ignore when
  234. wrapping.
  235. ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
  236. wrapping; these should be the parameters contained in the modules
  237. in ``ignored_modules``.
  238. Returns:
  239. (nn.Module, int):
  240. Wrapped module and the number parameters wrapped recursively.
  241. """
  242. assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
  243. assert wrapper_cls is not None, "Must specify wrapper_cls"
  244. # Make sure no child is already wrapped.
  245. for _, child in module.named_modules():
  246. if child in ignored_modules:
  247. continue
  248. assert not isinstance(child, cast(type, wrapper_cls))
  249. # We count all params, assuming none of them are already wrapped.
  250. num_params = sum(
  251. p.numel() for p in module.parameters() if p not in ignored_params
  252. )
  253. assert auto_wrap_policy is not None
  254. if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params):
  255. total_wrapped_params = 0
  256. # Iterate through the children, recursively wrap if necessary
  257. for name, child in module.named_children():
  258. if child in ignored_modules:
  259. continue
  260. wrapped_child, num_wrapped_params = _recursive_wrap(
  261. module=child,
  262. auto_wrap_policy=auto_wrap_policy,
  263. wrapper_cls=wrapper_cls,
  264. ignored_modules=ignored_modules,
  265. ignored_params=ignored_params,
  266. **kwargs,
  267. )
  268. setattr(module, name, wrapped_child)
  269. # Keep track of how many parameters have been wrapped
  270. total_wrapped_params += num_wrapped_params
  271. # decide if we need to wrap the current module,
  272. # since the left over parameters exceed the number of params to wrap
  273. remainder = num_params - total_wrapped_params
  274. if not only_wrap_children and auto_wrap_policy(
  275. module=module, recurse=False, unwrapped_params=remainder
  276. ):
  277. # Leaf node or final wrapping of the remainder both happen here.
  278. return _wrap(module, wrapper_cls, **kwargs), num_params
  279. else:
  280. return module, total_wrapped_params
  281. return module, 0
  282. class _ConfigAutoWrap:
  283. """
  284. Helper class to wrap modules based on default config args via a context manager.
  285. See :func:`enable_wrap` for more information.
  286. """
  287. in_autowrap_context: bool = False # Context flag
  288. wrapper_cls: Optional[Callable] = None # The wrapper class
  289. kwargs: Dict[str, Any] = {} # Wrapper's args
  290. def __init__(self, **kwargs: Dict[str, Any]):
  291. self.kwargs = kwargs
  292. @staticmethod
  293. def enable_autowrap_context(kwargs: Any) -> None:
  294. if _ConfigAutoWrap.in_autowrap_context:
  295. raise NotImplementedError(
  296. "You are already within an autowrap context and we currently do not supported nested autowrap."
  297. )
  298. _ConfigAutoWrap.in_autowrap_context = True
  299. # Get and save the wrapper cls for the context.
  300. assert (
  301. "wrapper_cls" in kwargs.keys()
  302. ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
  303. _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
  304. del kwargs["wrapper_cls"]
  305. # Save the rest.
  306. _ConfigAutoWrap.kwargs = kwargs
  307. @staticmethod
  308. def disable_autowrap_context() -> None:
  309. _ConfigAutoWrap.in_autowrap_context = False
  310. _ConfigAutoWrap.wrapper_cls = None
  311. _ConfigAutoWrap.kwargs = {}
  312. def __enter__(self) -> None:
  313. self.enable_autowrap_context(self.kwargs)
  314. def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
  315. self.disable_autowrap_context()