stateless.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import contextlib
  2. from typing import Any, Callable, Dict, Iterator, List, Tuple
  3. import torch
  4. from torch import Tensor
  5. __all__ = ["functional_call"]
# The ``module`` argument is deliberately left unannotated below: module attributes are
# declared as Union[Parameter, Tensor] by default, and using other types causes mypy errors.
  8. def _change_class(module, params_and_buffers) -> None:
  9. cls = module.__class__
  10. attr_to_path : Dict[str, str] = module._attr_to_path
  11. def _getattribute(self, name: str) -> Any:
  12. if name in attr_to_path:
  13. return params_and_buffers[attr_to_path[name]]
  14. return cls.__getattribute__(self, name)
  15. def _setattr(self, name: str, value: Any) -> None:
  16. if name in attr_to_path:
  17. params_and_buffers[attr_to_path[name]] = value
  18. else:
  19. return cls.__setattr__(self, name, value)
  20. param_cls = type(
  21. f"StatelessReplacer{cls.__name__}",
  22. (cls,),
  23. {
  24. "__getattribute__": _getattribute,
  25. "__setattr__": _setattr,
  26. },
  27. )
  28. module.__class__ = param_cls
  29. module._orig_class = cls
  30. def _create_swap_params(params_and_buffers):
  31. def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
  32. # Changes the module class to get a new __getattr__ dunder method
  33. # that looks for the reparametrized tensor
  34. if hasattr(module, "_attr_to_path"):
  35. module._attr_to_path[tensor_name] = full_path
  36. else:
  37. module._attr_to_path = {}
  38. module._attr_to_path[tensor_name] = full_path
  39. _change_class(module, params_and_buffers)
  40. return _swap_parameters
  41. def _remove_swap(module, name: str, full_path: str) -> None:
  42. if hasattr(module, "_orig_class"):
  43. module.__class__ = module._orig_class
  44. delattr(module, "_orig_class")
  45. delattr(module, "_attr_to_path")
  46. @contextlib.contextmanager
  47. def _reparametrize_module(
  48. module: 'torch.nn.Module',
  49. parameters_and_buffers: Dict[str, Tensor],
  50. ) -> Iterator[None]:
  51. for name, tensor in parameters_and_buffers.items():
  52. _apply_func_submodules(
  53. _create_swap_params(parameters_and_buffers),
  54. module, name.split("."), name, (tensor,))
  55. try:
  56. yield
  57. finally:
  58. for name in parameters_and_buffers:
  59. _apply_func_submodules(
  60. _remove_swap,
  61. module, name.split("."), name, ())
  62. def _apply_func_submodules(
  63. func: Callable[..., None],
  64. module: 'torch.nn.Module',
  65. path: List[str],
  66. full_path: str,
  67. args: Tuple,
  68. ):
  69. if len(path) == 1:
  70. func(module, path[0], full_path, *args)
  71. else:
  72. _apply_func_submodules(func, getattr(module, path[0]), path[1:], full_path, args)
  73. def functional_call(
  74. module: 'torch.nn.Module',
  75. parameters_and_buffers: Dict[str, Tensor],
  76. args: Tuple,
  77. kwargs : Dict[str, Any] = None,
  78. ):
  79. r"""Performs a functional call on the module by replacing the module parameters
  80. and buffers with the provided ones.
  81. .. note:: If the module has active parametrizations, passing a value in the
  82. :attr:`parameters_and_buffers` argument with the name set to the regular parameter
  83. name will completely disable the parametrization.
  84. If you want to apply the parametrization function to the value passed
  85. please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
  86. .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
  87. in the `parameters_and_buffers` input.
  88. Example::
  89. >>> a = {'foo': torch.zeros(())}
  90. >>> mod = Foo() # does self.foo = self.foo + 1
  91. >>> print(mod.foo) # tensor(0.)
  92. >>> functional_call(mod, a, torch.ones(()))
  93. >>> print(mod.foo) # tensor(0.)
  94. >>> print(a['foo']) # tensor(1.)
  95. Args:
  96. module (torch.nn.Module): the module to call
  97. parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
  98. the module call.
  99. args (tuple): arguments to be passed to the module call
  100. kwargs (dict): keyword arguments to be passed to the module call
  101. Returns:
  102. Any: the result of calling ``module``.
  103. """
  104. # TODO allow kwargs such as unsafe and others for parametrization
  105. if (
  106. torch.jit.is_tracing()
  107. or torch.jit.is_scripting()
  108. or isinstance(module, (
  109. torch.jit.RecursiveScriptModule,
  110. torch.jit.ScriptModule,
  111. torch.jit.ScriptFunction)
  112. )
  113. ):
  114. raise RuntimeError("The stateless API can't be used with Jitted modules")
  115. if kwargs is None:
  116. kwargs = {}
  117. with _reparametrize_module(module, parameters_and_buffers):
  118. if isinstance(args, tuple):
  119. out = module(*args, **kwargs)
  120. else:
  121. out = module(args, **kwargs)
  122. return out