|
- import sys
- import torch
- import functools
- import inspect
- from typing import Any, Callable, TypeVar, cast
- __all__ = ['no_grad', 'enable_grad', 'set_grad_enabled',
- 'inference_mode']
- # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
- # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
- FuncType = Callable[..., Any]
- F = TypeVar('F', bound=FuncType)
- class _DecoratorContextManager:
- """Allow a context manager to be used as a decorator"""
- def __call__(self, func: F) -> F:
- if inspect.isgeneratorfunction(func):
- return self._wrap_generator(func)
- @functools.wraps(func)
- def decorate_context(*args, **kwargs):
- with self.clone():
- return func(*args, **kwargs)
- return cast(F, decorate_context)
- def _wrap_generator(self, func):
- """Wrap each generator invocation with the context manager"""
- @functools.wraps(func)
- def generator_context(*args, **kwargs):
- gen = func(*args, **kwargs)
- # Generators are suspended and unsuspended at `yield`, hence we
- # make sure the grad mode is properly set every time the execution
- # flow returns into the wrapped generator and restored when it
- # returns through our `yield` to our caller (see PR #49017).
- try:
- # Issuing `None` to a generator fires it up
- with self.clone():
- response = gen.send(None)
- while True:
- try:
- # Forward the response to our caller and get its next request
- request = yield response
- except GeneratorExit:
- # Inform the still active generator about its imminent closure
- with self.clone():
- gen.close()
- raise
- except BaseException:
- # Propagate the exception thrown at us by the caller
- with self.clone():
- response = gen.throw(*sys.exc_info())
- else:
- # Pass the last request to the generator and get its response
- with self.clone():
- response = gen.send(request)
- # We let the exceptions raised above by the generator's `.throw` or
- # `.send` methods bubble up to our caller, except for StopIteration
- except StopIteration as e:
- # The generator informed us that it is done: take whatever its
- # returned value (if any) was and indicate that we're done too
- # by returning it (see docs for python's return-statement).
- return e.value
- return generator_context
- def __enter__(self) -> None:
- raise NotImplementedError
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- raise NotImplementedError
- def clone(self):
- # Override this method if your child class takes __init__ parameters
- return self.__class__()
- class no_grad(_DecoratorContextManager):
- r"""Context-manager that disables gradient calculation.
- Disabling gradient calculation is useful for inference, when you are sure
- that you will not call :meth:`Tensor.backward()`. It will reduce memory
- consumption for computations that would otherwise have `requires_grad=True`.
- In this mode, the result of every computation will have
- `requires_grad=False`, even when the inputs have `requires_grad=True`.
- This context manager is thread local; it will not affect computation
- in other threads.
- Also functions as a decorator. (Make sure to instantiate with parentheses.)
- .. note::
- No-grad is one of several mechanisms that can enable or
- disable gradients locally; see :ref:`locally-disable-grad-doc` for
- more information on how they compare.
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
- If you want to disable forward AD for a computation, you can unpack
- your dual tensors.
- Example::
- >>> x = torch.tensor([1.], requires_grad=True)
- >>> with torch.no_grad():
- ... y = x * 2
- >>> y.requires_grad
- False
- >>> @torch.no_grad()
- ... def doubler(x):
- ... return x * 2
- >>> z = doubler(x)
- >>> z.requires_grad
- False
- """
- def __init__(self) -> None:
- if not torch._jit_internal.is_scripting():
- super().__init__()
- self.prev = False
- def __enter__(self) -> None:
- self.prev = torch.is_grad_enabled()
- torch.set_grad_enabled(False)
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch.set_grad_enabled(self.prev)
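The docstring states that the mode is thread local. As a rough illustration of what that means, here is a toy sketch that uses `threading.local` in place of the real autograd state; `_local`, `is_enabled`, `local_flag_off` and `flag_seen_by_worker` are hypothetical names, and nothing here reflects how torch stores the flag internally.

```python
import threading

# Hypothetical per-thread flag standing in for the grad mode.
_local = threading.local()


def is_enabled() -> bool:
    # A thread that has never set the flag sees the default, True.
    return getattr(_local, "enabled", True)


class local_flag_off:
    """Toy analogue of a 'disable' context manager with save/restore."""

    def __enter__(self):
        self.prev = is_enabled()
        _local.enabled = False

    def __exit__(self, exc_type, exc_value, traceback):
        _local.enabled = self.prev


def flag_seen_by_worker() -> bool:
    """Start a worker thread and return the flag value it observes."""
    seen = []
    worker = threading.Thread(target=lambda: seen.append(is_enabled()))
    worker.start()
    worker.join()
    return seen[0]
```

Inside `with local_flag_off():` the calling thread reads `False`, while a worker started from within the block still reads the default.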
- class enable_grad(_DecoratorContextManager):
- r"""Context-manager that enables gradient calculation.
- Enables gradient calculation, if it has been disabled via :class:`~no_grad`
- or :class:`~set_grad_enabled`.
- This context manager is thread local; it will not affect computation
- in other threads.
- Also functions as a decorator. (Make sure to instantiate with parentheses.)
- .. note::
- enable_grad is one of several mechanisms that can enable or
- disable gradients locally; see :ref:`locally-disable-grad-doc` for
- more information on how they compare.
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
- Example::
- >>> x = torch.tensor([1.], requires_grad=True)
- >>> with torch.no_grad():
- ... with torch.enable_grad():
- ... y = x * 2
- >>> y.requires_grad
- True
- >>> y.backward()
- >>> x.grad
- tensor([2.])
- >>> @torch.enable_grad()
- ... def doubler(x):
- ... return x * 2
- >>> with torch.no_grad():
- ... z = doubler(x)
- >>> z.requires_grad
- True
- """
- def __enter__(self) -> None:
- self.prev = torch.is_grad_enabled()
- torch._C._set_grad_enabled(True)
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch._C._set_grad_enabled(self.prev)
- class set_grad_enabled(_DecoratorContextManager):
- r"""Context-manager that sets gradient calculation to on or off.
- ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
- It can be used as a context-manager or as a function.
- This context manager is thread local; it will not affect computation
- in other threads.
- Args:
- mode (bool): Flag whether to enable grad (``True``), or disable
- (``False``). This can be used to conditionally enable
- gradients.
- .. note::
- set_grad_enabled is one of several mechanisms that can enable or
- disable gradients locally; see :ref:`locally-disable-grad-doc` for
- more information on how they compare.
- .. note::
- This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
- Example::
- >>> x = torch.tensor([1.], requires_grad=True)
- >>> is_train = False
- >>> with torch.set_grad_enabled(is_train):
- ... y = x * 2
- >>> y.requires_grad
- False
- >>> torch.set_grad_enabled(True)
- >>> y = x * 2
- >>> y.requires_grad
- True
- >>> torch.set_grad_enabled(False)
- >>> y = x * 2
- >>> y.requires_grad
- False
- """
- def __init__(self, mode: bool) -> None:
- self.prev = torch.is_grad_enabled()
- torch._C._set_grad_enabled(mode)
- self.mode = mode
- def __enter__(self) -> None:
- pass
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- torch._C._set_grad_enabled(self.prev)
- def clone(self):
- return self.__class__(self.mode)
- class inference_mode(_DecoratorContextManager):
- r"""Context-manager that enables or disables inference mode.
- InferenceMode is a new context manager analogous to :class:`~no_grad`
- to be used when you are certain your operations will have no interactions
- with autograd (e.g., model training). Code run under this mode gets better
- performance by disabling view tracking and version counter bumps. Note that
- unlike some other mechanisms that locally enable or disable grad,
- entering inference_mode also disables :ref:`forward-mode AD <forward-mode-ad>`.
- This context manager is thread local; it will not affect computation
- in other threads.
- Also functions as a decorator. (Make sure to instantiate with parentheses.)
- .. note::
- Inference mode is one of several mechanisms that can enable or
- disable gradients locally; see :ref:`locally-disable-grad-doc` for
- more information on how they compare.
- Args:
- mode (bool): Flag whether to enable or disable inference mode.
- Example::
- >>> import torch
- >>> x = torch.ones(1, 2, 3, requires_grad=True)
- >>> with torch.inference_mode():
- ... y = x * x
- >>> y.requires_grad
- False
- >>> y._version
- Traceback (most recent call last):
- File "<stdin>", line 1, in <module>
- RuntimeError: Inference tensors do not track version counter.
- >>> @torch.inference_mode()
- ... def func(x):
- ... return x * x
- >>> out = func(x)
- >>> out.requires_grad
- False
- """
- def __init__(self, mode=True):
- if not torch._jit_internal.is_scripting():
- super().__init__()
- # Holds a python binding to a RAII guard that can enable or disable
- # inference mode
- self._inference_mode_raii_guard = None
- self.mode = mode
- def __enter__(self):
- self._inference_mode_raii_guard = torch._C._InferenceMode(self.mode)
- def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
- del self._inference_mode_raii_guard
- def clone(self):
- return self.__class__(self.mode)
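The comment in `__init__` describes the stored attribute as a binding to an RAII guard, and `__exit__` releases it with `del`. The sketch below imitates that pattern in pure Python under two stated assumptions: CPython reference counting, so that dropping the last reference runs `__del__` right away, and a plain dict in place of the real inference-mode state. `FlagGuard` and `toy_inference_mode` are hypothetical names and say nothing about how `torch._C._InferenceMode` is implemented.

```python
class FlagGuard:
    """Hypothetical RAII-style guard: set on creation, restore on destruction."""

    def __init__(self, state, mode):
        self._state = state
        self._prev = state["inference"]
        state["inference"] = mode

    def __del__(self):
        self._state["inference"] = self._prev


class toy_inference_mode:
    """Toy context manager that owns a guard between enter and exit."""

    def __init__(self, state, mode=True):
        self.state = state
        self.mode = mode
        self._guard = None

    def __enter__(self):
        self._guard = FlagGuard(self.state, self.mode)

    def __exit__(self, exc_type, exc_value, traceback):
        # Dropping the only reference destroys the guard immediately on
        # CPython; other interpreters may delay the call to __del__.
        del self._guard
```

An explicit restore call in `__exit__` would be more portable in pure Python; the guard form is shown here only because it mirrors the shape of the code above.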
|