| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
import contextlib
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

import torch
from torch import Tensor
# Public API of this module: only ``functional_call`` is exported by ``import *``.
__all__ = [
    "functional_call",
]
- # We avoid typing module here because module attributes are declared as Union[Parameter, Tensor] by default
- # and using other types causes mypy errors
- def _change_class(module, params_and_buffers) -> None:
- cls = module.__class__
- attr_to_path : Dict[str, str] = module._attr_to_path
- def _getattribute(self, name: str) -> Any:
- if name in attr_to_path:
- return params_and_buffers[attr_to_path[name]]
- return cls.__getattribute__(self, name)
- def _setattr(self, name: str, value: Any) -> None:
- if name in attr_to_path:
- params_and_buffers[attr_to_path[name]] = value
- else:
- return cls.__setattr__(self, name, value)
- param_cls = type(
- f"StatelessReplacer{cls.__name__}",
- (cls,),
- {
- "__getattribute__": _getattribute,
- "__setattr__": _setattr,
- },
- )
- module.__class__ = param_cls
- module._orig_class = cls
- def _create_swap_params(params_and_buffers):
- def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
- # Changes the module class to get a new __getattr__ dunder method
- # that looks for the reparametrized tensor
- if hasattr(module, "_attr_to_path"):
- module._attr_to_path[tensor_name] = full_path
- else:
- module._attr_to_path = {}
- module._attr_to_path[tensor_name] = full_path
- _change_class(module, params_and_buffers)
- return _swap_parameters
- def _remove_swap(module, name: str, full_path: str) -> None:
- if hasattr(module, "_orig_class"):
- module.__class__ = module._orig_class
- delattr(module, "_orig_class")
- delattr(module, "_attr_to_path")
@contextlib.contextmanager
def _reparametrize_module(
    module: 'torch.nn.Module',
    parameters_and_buffers: Dict[str, Tensor],
) -> Iterator[None]:
    """Temporarily route attributes of ``module`` to ``parameters_and_buffers``.

    Each key is a dotted attribute path relative to ``module`` (it is split on
    ``"."``; the last segment is the attribute name and the preceding segments
    are resolved with ``getattr``). While the context is active, reading or
    assigning that attribute on the owning submodule goes through
    ``parameters_and_buffers`` instead of the submodule's own state.

    Swaps are applied inside the ``try`` and recorded as each one succeeds.
    If setup fails partway (for example, a key names a submodule that does
    not exist, so ``getattr`` raises), the swaps already applied are still
    undone on the way out. Only keys that were actually swapped are undone,
    so cleanup does not re-traverse the failing path and raise again.

    Yields:
        None
    """
    # One callback serves every key; it closes over the same dict each time.
    swap = _create_swap_params(parameters_and_buffers)
    swapped: List[str] = []
    try:
        for name, tensor in parameters_and_buffers.items():
            _apply_func_submodules(swap, module, name.split("."), name, (tensor,))
            swapped.append(name)
        yield
    finally:
        for name in swapped:
            _apply_func_submodules(_remove_swap, module, name.split("."), name, ())
- def _apply_func_submodules(
- func: Callable[..., None],
- module: 'torch.nn.Module',
- path: List[str],
- full_path: str,
- args: Tuple,
- ):
- if len(path) == 1:
- func(module, path[0], full_path, *args)
- else:
- _apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
def functional_call(
    module: 'torch.nn.Module',
    parameters_and_buffers: Dict[str, Tensor],
    args: Tuple,
    kwargs: Optional[Dict[str, Any]] = None,
):
    r"""Performs a functional call on the module by replacing the module parameters
    and buffers with the provided ones.

    .. note:: If the module has active parametrizations, passing a value in the
        :attr:`parameters_and_buffers` argument with the name set to the regular parameter
        name will completely disable the parametrization.
        If you want to apply the parametrization function to the value passed
        please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.

    .. note:: While the call runs, reads of the replaced attributes return the
        tensors in :attr:`parameters_and_buffers`, and assignments to them are
        stored back into that dict. So both in-place operations on these
        parameters/buffers and reassignments of them are reflected in the
        :attr:`parameters_and_buffers` input, not in the module's own state.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call. Keys are dotted attribute paths relative to ``module``
            (e.g. ``"layer.weight"``).
        args (tuple or Any): arguments to be passed to the module call. A tuple is
            unpacked into positional arguments; any other value is passed as the
            single positional argument.
        kwargs (dict, optional): keyword arguments to be passed to the module call

    Returns:
        Any: the result of calling ``module``.

    Raises:
        RuntimeError: if called while JIT tracing or scripting, or if ``module``
            is a ``RecursiveScriptModule``, ``ScriptModule`` or ``ScriptFunction``.
    """
    # TODO allow kwargs such as unsafe and others for parametrization
    # Reject TorchScript up front (presumably because scripted/traced code does
    # not route attribute access through the Python-level class swap used by
    # the reparametrization helpers).
    if (
        torch.jit.is_tracing()
        or torch.jit.is_scripting()
        or isinstance(module, (
            torch.jit.RecursiveScriptModule,
            torch.jit.ScriptModule,
            torch.jit.ScriptFunction)
        )
    ):
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if kwargs is None:
        kwargs = {}
    # A bare (non-tuple) ``args`` is treated as a single positional argument.
    positional = args if isinstance(args, tuple) else (args,)
    with _reparametrize_module(module, parameters_and_buffers):
        out = module(*positional, **kwargs)
    return out
|