grad_mode.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import sys
  2. import torch
  3. import functools
  4. import inspect
  5. from typing import Any, Callable, TypeVar, cast
  6. __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
  7. 'inference_mode']
  8. # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
  9. # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
  10. FuncType = Callable[..., Any]
  11. F = TypeVar('F', bound=FuncType)
  12. class _DecoratorContextManager:
  13. """Allow a context manager to be used as a decorator"""
  14. def __call__(self, func: F) -> F:
  15. if inspect.isgeneratorfunction(func):
  16. return self._wrap_generator(func)
  17. @functools.wraps(func)
  18. def decorate_context(*args, **kwargs):
  19. with self.clone():
  20. return func(*args, **kwargs)
  21. return cast(F, decorate_context)
  22. def _wrap_generator(self, func):
  23. """Wrap each generator invocation with the context manager"""
  24. @functools.wraps(func)
  25. def generator_context(*args, **kwargs):
  26. gen = func(*args, **kwargs)
  27. # Generators are suspended and unsuspended at `yield`, hence we
  28. # make sure the grad mode is properly set every time the execution
  29. # flow returns into the wrapped generator and restored when it
  30. # returns through our `yield` to our caller (see PR #49017).
  31. try:
  32. # Issuing `None` to a generator fires it up
  33. with self.clone():
  34. response = gen.send(None)
  35. while True:
  36. try:
  37. # Forward the response to our caller and get its next request
  38. request = yield response
  39. except GeneratorExit:
  40. # Inform the still active generator about its imminent closure
  41. with self.clone():
  42. gen.close()
  43. raise
  44. except BaseException:
  45. # Propagate the exception thrown at us by the caller
  46. with self.clone():
  47. response = gen.throw(*sys.exc_info())
  48. else:
  49. # Pass the last request to the generator and get its response
  50. with self.clone():
  51. response = gen.send(request)
  52. # We let the exceptions raised above by the generator's `.throw` or
  53. # `.send` methods bubble up to our caller, except for StopIteration
  54. except StopIteration as e:
  55. # The generator informed us that it is done: take whatever its
  56. # returned value (if any) was and indicate that we're done too
  57. # by returning it (see docs for python's return-statement).
  58. return e.value
  59. return generator_context
  60. def __enter__(self) -> None:
  61. raise NotImplementedError
  62. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  63. raise NotImplementedError
  64. def clone(self):
  65. # override this method if your children class takes __init__ parameters
  66. return self.__class__()
class no_grad(_DecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self) -> None:
        # The base-class initializer is skipped while scripting.
        # NOTE(review): presumably `super().__init__()` cannot be compiled when
        # this class is used from TorchScript — confirm before changing.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        # Placeholder for the mode saved by __enter__, so the attribute exists
        # (as a bool) before the manager is first entered.
        self.prev = False

    def __enter__(self) -> None:
        """Save the current grad mode, then disable gradient recording."""
        self.prev = torch.is_grad_enabled()
        # NOTE(review): calls the public `torch.set_grad_enabled` rather than
        # `torch._C._set_grad_enabled`, presumably so this method stays
        # scriptable — confirm before switching.
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        """Restore the grad mode saved by __enter__; exceptions are not suppressed."""
        torch.set_grad_enabled(self.prev)
  107. class enable_grad(_DecoratorContextManager):
  108. r"""Context-manager that enables gradient calculation.
  109. Enables gradient calculation, if it has been disabled via :class:`~no_grad`
  110. or :class:`~set_grad_enabled`.
  111. This context manager is thread local; it will not affect computation
  112. in other threads.
  113. Also functions as a decorator. (Make sure to instantiate with parenthesis.)
  114. .. note::
  115. enable_grad is one of several mechanisms that can enable or
  116. disable gradients locally see :ref:`locally-disable-grad-doc` for
  117. more information on how they compare.
  118. .. note::
  119. This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
  120. Example::
  121. >>> x = torch.tensor([1.], requires_grad=True)
  122. >>> with torch.no_grad():
  123. ... with torch.enable_grad():
  124. ... y = x * 2
  125. >>> y.requires_grad
  126. True
  127. >>> y.backward()
  128. >>> x.grad
  129. >>> @torch.enable_grad()
  130. ... def doubler(x):
  131. ... return x * 2
  132. >>> with torch.no_grad():
  133. ... z = doubler(x)
  134. >>> z.requires_grad
  135. True
  136. """
  137. def __enter__(self) -> None:
  138. self.prev = torch.is_grad_enabled()
  139. torch._C._set_grad_enabled(True)
  140. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  141. torch._C._set_grad_enabled(self.prev)
  142. class set_grad_enabled(_DecoratorContextManager):
  143. r"""Context-manager that sets gradient calculation to on or off.
  144. ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
  145. It can be used as a context-manager or as a function.
  146. This context manager is thread local; it will not affect computation
  147. in other threads.
  148. Args:
  149. mode (bool): Flag whether to enable grad (``True``), or disable
  150. (``False``). This can be used to conditionally enable
  151. gradients.
  152. .. note::
  153. set_grad_enabled is one of several mechanisms that can enable or
  154. disable gradients locally see :ref:`locally-disable-grad-doc` for
  155. more information on how they compare.
  156. .. note::
  157. This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
  158. Example::
  159. >>> x = torch.tensor([1.], requires_grad=True)
  160. >>> is_train = False
  161. >>> with torch.set_grad_enabled(is_train):
  162. ... y = x * 2
  163. >>> y.requires_grad
  164. False
  165. >>> torch.set_grad_enabled(True)
  166. >>> y = x * 2
  167. >>> y.requires_grad
  168. True
  169. >>> torch.set_grad_enabled(False)
  170. >>> y = x * 2
  171. >>> y.requires_grad
  172. False
  173. """
  174. def __init__(self, mode: bool) -> None:
  175. self.prev = torch.is_grad_enabled()
  176. torch._C._set_grad_enabled(mode)
  177. self.mode = mode
  178. def __enter__(self) -> None:
  179. pass
  180. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  181. torch._C._set_grad_enabled(self.prev)
  182. def clone(self):
  183. return self.__class__(self.mode)
  184. class inference_mode(_DecoratorContextManager):
  185. r"""Context-manager that enables or disables inference mode
  186. InferenceMode is a new context manager analogous to :class:`~no_grad`
  187. to be used when you are certain your operations will have no interactions
  188. with autograd (e.g., model training). Code run under this mode gets better
  189. performance by disabling view tracking and version counter bumps. Note that
  190. unlike some other mechanisms that locally enable or disable grad,
  191. entering inference_mode also disables to :ref:`forward-mode AD <forward-mode-ad>`.
  192. This context manager is thread local; it will not affect computation
  193. in other threads.
  194. Also functions as a decorator. (Make sure to instantiate with parenthesis.)
  195. .. note::
  196. Inference mode is one of several mechanisms that can enable or
  197. disable gradients locally see :ref:`locally-disable-grad-doc` for
  198. more information on how they compare.
  199. Args:
  200. mode (bool): Flag whether to enable or disable inference mode
  201. Example::
  202. >>> import torch
  203. >>> x = torch.ones(1, 2, 3, requires_grad=True)
  204. >>> with torch.inference_mode():
  205. ... y = x * x
  206. >>> y.requires_grad
  207. False
  208. >>> y._version
  209. Traceback (most recent call last):
  210. File "<stdin>", line 1, in <module>
  211. RuntimeError: Inference tensors do not track version counter.
  212. >>> @torch.inference_mode()
  213. ... def func(x):
  214. ... return x * x
  215. >>> out = func(x)
  216. >>> out.requires_grad
  217. False
  218. """
  219. def __init__(self, mode=True):
  220. if not torch._jit_internal.is_scripting():
  221. super().__init__()
  222. # Holds a python binding to a RAII guard that can enable or disable
  223. # inference mode
  224. self._inference_mode_raii_guard = None
  225. self.mode = mode
  226. def __enter__(self):
  227. self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)
  228. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  229. del self._inference_mode_raii_guard
  230. def clone(self):
  231. return self.__class__(self.mode)