_vmap_internals.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. import torch
  2. import functools
  3. from torch import Tensor
  4. from typing import Any, Callable, Optional, Tuple, Union, List
  5. from torch.utils._pytree import tree_flatten, tree_unflatten, _broadcast_to_and_flatten
  6. import warnings
  7. in_dims_t = Union[int, Tuple]
  8. out_dims_t = Union[int, Tuple[int, ...]]
  9. # Checks that all args-to-be-batched have the same batch dim size
  10. def _validate_and_get_batch_size(
  11. flat_in_dims: List[Optional[int]],
  12. flat_args: List) -> int:
  13. batch_sizes = [arg.size(in_dim) for in_dim, arg in zip(flat_in_dims, flat_args)
  14. if in_dim is not None]
  15. if batch_sizes and any([size != batch_sizes[0] for size in batch_sizes]):
  16. raise ValueError(
  17. f'vmap: Expected all tensors to have the same size in the mapped '
  18. f'dimension, got sizes {batch_sizes} for the mapped dimension')
  19. return batch_sizes[0]
  20. def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
  21. if isinstance(batched_outputs, tuple):
  22. return len(batched_outputs)
  23. return 1
  24. # If value is a tuple, check it has length `num_elements`.
  25. # If value is not a tuple, make a tuple with `value` repeated `num_elements` times
  26. def _as_tuple(value: Any, num_elements: int, error_message_lambda: Callable[[], str]) -> Tuple:
  27. if not isinstance(value, tuple):
  28. return (value,) * num_elements
  29. if len(value) != num_elements:
  30. raise ValueError(error_message_lambda())
  31. return value
  32. # Creates BatchedTensors for every Tensor in arg that should be batched.
  33. # Returns the (potentially) batched arguments and the batch_size.
  34. def _create_batched_inputs(
  35. in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable) -> Tuple[Tuple, int]:
  36. if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
  37. raise ValueError(
  38. f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
  39. f'expected `in_dims` to be int or a (potentially nested) tuple '
  40. f'matching the structure of inputs, got: {type(in_dims)}.')
  41. if len(args) == 0:
  42. raise ValueError(
  43. f'vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add '
  44. f'inputs, or you are trying to vmap over a function with no inputs. '
  45. f'The latter is unsupported.')
  46. flat_args, args_spec = tree_flatten(args)
  47. flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
  48. if flat_in_dims is None:
  49. raise ValueError(
  50. f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
  51. f'in_dims is not compatible with the structure of `inputs`. '
  52. f'in_dims has structure {tree_flatten(in_dims)[1]} but inputs '
  53. f'has structure {args_spec}.')
  54. for arg, in_dim in zip(flat_args, flat_in_dims):
  55. if not isinstance(in_dim, int) and in_dim is not None:
  56. raise ValueError(
  57. f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
  58. f'Got in_dim={in_dim} for an input but in_dim must be either '
  59. f'an integer dimension or None.')
  60. if isinstance(in_dim, int) and not isinstance(arg, Tensor):
  61. raise ValueError(
  62. f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
  63. f'Got in_dim={in_dim} for an input but the input is of type '
  64. f'{type(arg)}. We cannot vmap over non-Tensor arguments, '
  65. f'please use None as the respective in_dim')
  66. if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
  67. raise ValueError(
  68. f'vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): '
  69. f'Got in_dim={in_dim} for some input, but that input is a Tensor '
  70. f'of dimensionality {arg.dim()} so expected in_dim to satisfy '
  71. f'0 <= in_dim < {arg.dim()}.')
  72. batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
  73. # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  74. batched_inputs = [arg if in_dim is None else
  75. torch._add_batch_dim(arg, in_dim, vmap_level)
  76. for in_dim, arg in zip(flat_in_dims, flat_args)]
  77. return tree_unflatten(batched_inputs, args_spec), batch_size
  78. # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
  79. def _unwrap_batched(
  80. batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
  81. out_dims: out_dims_t, vmap_level: int, batch_size: int, func: Callable,
  82. allow_none_pass_through: bool = False) -> Tuple:
  83. num_outputs = _num_outputs(batched_outputs)
  84. out_dims_as_tuple = _as_tuple(
  85. out_dims, num_outputs,
  86. lambda: f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must '
  87. f'have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.')
  88. # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
  89. # There is something wrong with our type bindings for functions that begin
  90. # with '_', see #40397.
  91. if isinstance(batched_outputs, Tensor):
  92. out_dim = out_dims_as_tuple[0]
  93. return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim) # type: ignore[return-value]
  94. if allow_none_pass_through:
  95. return tuple((torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) if out is not None else None)
  96. for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
  97. else:
  98. return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
  99. for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
  100. # Checks that `fn` returned one or more Tensors and nothing else.
  101. # NB: A python function that return multiple arguments returns a single tuple,
  102. # so we are effectively checking that `outputs` is a single Tensor or a tuple of
  103. # Tensors.
  104. def _validate_outputs(outputs: Any, func: Callable) -> None:
  105. if isinstance(outputs, Tensor):
  106. return
  107. if not isinstance(outputs, tuple):
  108. raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
  109. f'Tensors, got type {type(outputs)} as the return.')
  110. for idx, output in enumerate(outputs):
  111. if isinstance(output, Tensor):
  112. continue
  113. raise ValueError(f'vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return '
  114. f'Tensors, got type {type(output)} for return {idx}.')
  115. def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
  116. if isinstance(out_dims, int):
  117. return
  118. if not isinstance(out_dims, tuple) or \
  119. not all([isinstance(out_dim, int) for out_dim in out_dims]):
  120. raise ValueError(
  121. f'vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be '
  122. f'an int or a tuple of int representing where in the outputs the '
  123. f'vmapped dimension should appear.')
  124. def _get_name(func: Callable):
  125. if hasattr(func, '__name__'):
  126. return func.__name__
  127. # Not all callables have __name__, in fact, only static functions/methods do.
  128. # A callable created via functools.partial or an nn.Module, to name some
  129. # examples, don't have a __name__.
  130. return repr(func)
  131. # vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
  132. # sends those into func, and then unwraps the output BatchedTensors. Operations
  133. # on BatchedTensors perform the batched operations that the user is asking for.
  134. def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
  135. """
  136. vmap is the vectorizing map. Returns a new function that maps `func` over some
  137. dimension of the inputs. Semantically, vmap pushes the map into PyTorch
  138. operations called by `func`, effectively vectorizing those operations.
  139. vmap is useful for handling batch dimensions: one can write a function `func`
  140. that runs on examples and then lift it to a function that can take batches of
  141. examples with `vmap(func)`. vmap can also be used to compute batched
  142. gradients when composed with autograd.
  143. .. note::
  144. We have moved development of vmap to
  145. `functorch. <https://github.com/pytorch/functorch>`_ functorch's
  146. vmap is able to arbitrarily compose with gradient computation
  147. and contains significant performance improvements.
  148. Please give that a try if that is what you're looking for.
  149. Furthermore, if you're interested in using vmap for your use case,
  150. please `contact us! <https://github.com/pytorch/pytorch/issues/42368>`_
  151. We're interested in gathering feedback from early adopters to inform
  152. the design.
  153. .. warning::
  154. torch.vmap is an experimental prototype that is subject to
  155. change and/or deletion. Please use at your own risk.
  156. Args:
  157. func (function): A Python function that takes one or more arguments.
  158. Must return one or more Tensors.
  159. in_dims (int or nested structure): Specifies which dimension of the
  160. inputs should be mapped over. `in_dims` should have a structure
  161. like the inputs. If the `in_dim` for a particular input is None,
  162. then that indicates there is no map dimension. Default: 0.
  163. out_dims (int or Tuple[int]): Specifies where the mapped dimension
  164. should appear in the outputs. If `out_dims` is a Tuple, then it should
  165. have one element per output. Default: 0.
  166. Returns:
  167. Returns a new "batched" function. It takes the same inputs as `func`,
  168. except each input has an extra dimension at the index specified by `in_dims`.
  169. It takes returns the same outputs as `func`, except each output has
  170. an extra dimension at the index specified by `out_dims`.
  171. .. warning:
  172. vmap works best with functional-style code. Please do not perform any
  173. side-effects in `func`, with the exception of in-place PyTorch operations.
  174. Examples of side-effects include mutating Python data structures and
  175. assigning values to variables not captured in `func`.
  176. One example of using `vmap` is to compute batched dot products. PyTorch
  177. doesn't provide a batched `torch.dot` API; instead of unsuccessfully
  178. rummaging through docs, use `vmap` to construct a new function.
  179. >>> torch.dot # [D], [D] -> []
  180. >>> batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
  181. >>> x, y = torch.randn(2, 5), torch.randn(2, 5)
  182. >>> batched_dot(x, y)
  183. `vmap` can be helpful in hiding batch dimensions, leading to a simpler
  184. model authoring experience.
  185. >>> batch_size, feature_size = 3, 5
  186. >>> weights = torch.randn(feature_size, requires_grad=True)
  187. >>>
  188. >>> def model(feature_vec):
  189. >>> # Very simple linear model with activation
  190. >>> return feature_vec.dot(weights).relu()
  191. >>>
  192. >>> examples = torch.randn(batch_size, feature_size)
  193. >>> result = torch.vmap(model)(examples)
  194. `vmap` can also help vectorize computations that were previously difficult
  195. or impossible to batch. One example is higher-order gradient computation.
  196. The PyTorch autograd engine computes vjps (vector-Jacobian products).
  197. Computing a full Jacobian matrix for some function f: R^N -> R^N usually
  198. requires N calls to `autograd.grad`, one per Jacobian row. Using `vmap`,
  199. we can vectorize the whole computation, computing the Jacobian in a single
  200. call to `autograd.grad`.
  201. >>> # Setup
  202. >>> N = 5
  203. >>> f = lambda x: x ** 2
  204. >>> x = torch.randn(N, requires_grad=True)
  205. >>> y = f(x)
  206. >>> I_N = torch.eye(N)
  207. >>>
  208. >>> # Sequential approach
  209. >>> jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
  210. >>> for v in I_N.unbind()]
  211. >>> jacobian = torch.stack(jacobian_rows)
  212. >>>
  213. >>> # vectorized gradient computation
  214. >>> def get_vjp(v):
  215. >>> return torch.autograd.grad(y, x, v)
  216. >>> jacobian = torch.vmap(get_vjp)(I_N)
  217. .. note::
  218. vmap does not provide general autobatching or handle variable-length
  219. sequences out of the box.
  220. """
  221. warnings.warn(
  222. 'Please use functorch.vmap instead of torch.vmap '
  223. '(https://github.com/pytorch/functorch). '
  224. 'We\'ve moved development on torch.vmap over to functorch; '
  225. 'functorch\'s vmap has a multitude of significant performance and '
  226. 'functionality improvements.',
  227. stacklevel=2)
  228. return _vmap(func, in_dims, out_dims)
  229. # A version of vmap but without the initial "experimental prototype" warning
  230. def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0, allow_none_pass_through: bool = False) -> Callable:
  231. # The `allow_none_pass_through` argument is a temporary workaround may be removed.
  232. # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine,
  233. # which may return None if any of the inputs are unused. See the issue discussing this:
  234. # https://github.com/facebookresearch/functorch/issues/159.
  235. @functools.wraps(func)
  236. def wrapped(*args):
  237. _check_out_dims_is_int_or_int_tuple(out_dims, func)
  238. vmap_level = torch._C._vmapmode_increment_nesting()
  239. try:
  240. batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
  241. batched_outputs = func(*batched_inputs)
  242. if not allow_none_pass_through:
  243. _validate_outputs(batched_outputs, func)
  244. return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func,
  245. allow_none_pass_through=allow_none_pass_through)
  246. finally:
  247. torch._C._vmapmode_decrement_nesting()
  248. return wrapped