# parametrize.py (32 KB)
import collections
import copyreg
from contextlib import contextmanager
from copy import deepcopy
from typing import Union, Optional, Dict, Tuple, Sequence

import torch
from torch import Tensor
from torch.nn.modules.container import ModuleList, ModuleDict, Module
from torch.nn.parameter import Parameter
  8. _cache_enabled = 0
  9. _cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
  10. @contextmanager
  11. def cached():
  12. r"""Context manager that enables the caching system within parametrizations
  13. registered with :func:`register_parametrization`.
  14. The value of the parametrized objects is computed and cached the first time
  15. they are required when this context manager is active. The cached values are
  16. discarded when leaving the context manager.
  17. This is useful when using a parametrized parameter more than once in the forward pass.
  18. An example of this is when parametrizing the recurrent kernel of an RNN or when
  19. sharing weights.
  20. The simplest way to activate the cache is by wrapping the forward pass of the neural network
  21. .. code-block:: python
  22. import torch.nn.utils.parametrize as P
  23. ...
  24. with P.cached():
  25. output = model(inputs)
  26. in training and evaluation. One may also wrap the parts of the modules that use
  27. several times the parametrized tensors. For example, the loop of an RNN with a
  28. parametrized recurrent kernel:
  29. .. code-block:: python
  30. with P.cached():
  31. for x in xs:
  32. out_rnn = self.rnn_cell(x, out_rnn)
  33. """
  34. global _cache
  35. global _cache_enabled
  36. _cache_enabled += 1
  37. try:
  38. yield
  39. finally:
  40. _cache_enabled -= 1
  41. if not _cache_enabled:
  42. _cache = {}
  43. def _register_parameter_or_buffer(module, name, X):
  44. if isinstance(X, Parameter):
  45. module.register_parameter(name, X)
  46. else:
  47. module.register_buffer(name, X)
  48. class ParametrizationList(ModuleList):
  49. r"""A sequential container that holds and manages the ``original`` or ``original0``, ``original1``, ...
  50. parameters or buffers of a parametrized :class:`torch.nn.Module`.
  51. It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]``
  52. has been parametrized with :func:`register_parametrization`.
  53. If the first registered parmetrization has a ``right_inverse`` that returns one tensor or
  54. does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity),
  55. it will hold the tensor under the name ``original``.
  56. If it has a ``right_inverse`` that returns more than one tensor, these will be registered as
  57. ``original0``, ``original1``, ...
  58. .. warning::
  59. This class is used internally by :func:`register_parametrization`. It is documented
  60. here for completeness. It shall not be instantiated by the user.
  61. Args:
  62. modules (sequence): sequence of modules representing the parametrizations
  63. original (Parameter or Tensor): parameter or buffer that is parametrized
  64. unsafe (bool): a boolean flag that denotes whether the parametrization
  65. may change the dtype and shape of the tensor. Default: `False`
  66. Warning: the parametrization is not checked for consistency upon registration.
  67. Enable this flag at your own risk.
  68. """
  69. original: Tensor
  70. unsafe: bool
  71. def __init__(
  72. self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False
  73. ) -> None:
  74. # We require this because we need to treat differently the first parametrization
  75. # This should never throw, unless this class is used from the outside
  76. if len(modules) == 0:
  77. raise ValueError("ParametrizationList requires one or more modules.")
  78. super().__init__(modules)
  79. self.unsafe = unsafe
  80. # In plain words:
  81. # module.weight must keep its dtype and shape.
  82. # Furthermore, if there is no right_inverse or the right_inverse returns a tensor,
  83. # this should be of the same dtype as the original tensor
  84. #
  85. # We check that the following invariants hold:
  86. # X = module.weight
  87. # Y = param.right_inverse(X)
  88. # assert isinstance(Y, Tensor) or
  89. # (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y))
  90. # Z = param(Y) if isisntance(Y, Tensor) else param(*Y)
  91. # # Consistency checks
  92. # assert X.dtype == Z.dtype and X.shape == Z.shape
  93. # # If it has one input, this allows to be able to use set_ to be able to
  94. # # move data to/from the original tensor without changing its id (which is what the
  95. # # optimiser uses to track parameters)
  96. # if isinstance(Y, Tensor)
  97. # assert X.dtype == Y.dtype
  98. # Below we use original = X, new = Y
  99. original_shape = original.shape
  100. original_dtype = original.dtype
  101. # Compute new
  102. with torch.no_grad():
  103. new = original
  104. for module in reversed(self): # type: ignore[call-overload]
  105. if hasattr(module, "right_inverse"):
  106. try:
  107. new = module.right_inverse(new)
  108. except NotImplementedError:
  109. pass
  110. # else, or if it throws, we assume that right_inverse is the identity
  111. if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence):
  112. raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). "
  113. f"Got {type(new).__name__}")
  114. # Set the number of original tensors
  115. self.is_tensor = isinstance(new, Tensor)
  116. self.ntensors = 1 if self.is_tensor else len(new)
  117. # Register the tensor(s)
  118. if self.is_tensor:
  119. if original.dtype != new.dtype:
  120. raise ValueError(
  121. "When `right_inverse` outputs one tensor, it may not change the dtype.\n"
  122. f"original.dtype: {original.dtype}\n"
  123. f"right_inverse(original).dtype: {new.dtype}"
  124. )
  125. # Set the original to original so that the user does not need to re-register the parameter
  126. # manually in the optimiser
  127. with torch.no_grad():
  128. original.set_(new) # type: ignore[call-overload]
  129. _register_parameter_or_buffer(self, "original", original)
  130. else:
  131. for i, originali in enumerate(new):
  132. if not isinstance(originali, Tensor):
  133. raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors "
  134. "(list, tuple...). "
  135. f"Got element {i} of the sequence with type {type(originali).__name__}.")
  136. # If the original tensor was a Parameter that required grad, we expect the user to
  137. # add the new parameters to the optimizer after registering the parametrization
  138. # (this is documented)
  139. if isinstance(original, Parameter):
  140. originali = Parameter(originali)
  141. originali.requires_grad_(original.requires_grad)
  142. _register_parameter_or_buffer(self, f"original{i}", originali)
  143. if not self.unsafe:
  144. # Consistency checks:
  145. # Since f : A -> B, right_inverse : B -> A, Z and original should live in B
  146. # Z = forward(right_inverse(original))
  147. Z = self()
  148. if not isinstance(Z, Tensor):
  149. raise ValueError(
  150. f"A parametrization must return a tensor. Got {type(Z).__name__}."
  151. )
  152. if Z.dtype != original_dtype:
  153. raise ValueError(
  154. "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n"
  155. f"unparametrized dtype: {original_dtype}\n"
  156. f"parametrized dtype: {Z.dtype}"
  157. )
  158. if Z.shape != original_shape:
  159. raise ValueError(
  160. "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n"
  161. f"unparametrized shape: {original_shape}\n"
  162. f"parametrized shape: {Z.shape}"
  163. )
  164. def right_inverse(self, value: Tensor) -> None:
  165. r"""Calls the methods ``right_inverse`` (see :func:`register_parametrization`)
  166. of the parametrizations in the inverse order they were registered in.
  167. Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor
  168. or in ``self.original0``, ``self.original1``, ... if it outputs several.
  169. Args:
  170. value (Tensor): Value to which initialize the module
  171. """
  172. # All the exceptions in this function should almost never throw.
  173. # They could throw if, for example, right_inverse function returns a different
  174. # dtype when given a different input, which should most likely be caused by a
  175. # bug in the user's code
  176. with torch.no_grad():
  177. # See https://github.com/pytorch/pytorch/issues/53103
  178. for module in reversed(self): # type: ignore[call-overload]
  179. if hasattr(module, "right_inverse"):
  180. value = module.right_inverse(value)
  181. else:
  182. raise RuntimeError(f"parametrization {type(module).__name__} does not implement "
  183. "right_inverse.")
  184. if self.is_tensor:
  185. # These exceptions should only throw when a right_inverse function does not
  186. # return the same dtype for every input, which should most likely be caused by a bug
  187. if not isinstance(value, Tensor):
  188. raise ValueError(
  189. f"`right_inverse` should return a tensor. Got {type(value).__name__}"
  190. )
  191. if value.dtype != self.original.dtype:
  192. raise ValueError(
  193. f"The tensor returned by `right_inverse` has dtype {value.dtype} "
  194. f"while `original` has dtype {self.original.dtype}"
  195. )
  196. # We know that the result is going to have the same dtype
  197. self.original.set_(value) # type: ignore[call-overload]
  198. else:
  199. if not isinstance(value, collections.abc.Sequence):
  200. raise ValueError(
  201. "'right_inverse' must return a sequence of tensors. "
  202. f"Got {type(value).__name__}."
  203. )
  204. if len(value) != self.ntensors:
  205. raise ValueError(
  206. "'right_inverse' must return a sequence of tensors of length "
  207. f"{self.ntensors}. Got a sequence of lenght {len(value)}."
  208. )
  209. for i, tensor in enumerate(value):
  210. original_i = getattr(self, f"original{i}")
  211. if not isinstance(tensor, Tensor):
  212. raise ValueError(
  213. f"`right_inverse` must return a sequence of tensors. "
  214. f"Got element {i} of type {type(tensor).__name__}"
  215. )
  216. if original_i.dtype != tensor.dtype:
  217. raise ValueError(
  218. f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} "
  219. f"while `original{i}` has dtype {original_i.dtype}"
  220. )
  221. original_i.set_(tensor)
  222. def forward(self) -> Tensor:
  223. # Unpack the originals for the first parametrization
  224. if self.is_tensor:
  225. x = self[0](self.original)
  226. else:
  227. originals = (getattr(self, f"original{i}") for i in range(self.ntensors))
  228. x = self[0](*originals)
  229. # It's not possible to call self[1:] here, so we have to be a bit more cryptic
  230. # Also we want to skip all non-integer keys
  231. curr_idx = 1
  232. while hasattr(self, str(curr_idx)):
  233. x = self[curr_idx](x)
  234. curr_idx += 1
  235. return x
  236. def _inject_new_class(module: Module) -> None:
  237. r"""Sets up a module to be parametrized.
  238. This works by substituting the class of the module by a class
  239. that extends it to be able to inject a property
  240. Args:
  241. module (nn.Module): module into which to inject the property
  242. """
  243. cls = module.__class__
  244. def getstate(self):
  245. raise RuntimeError(
  246. "Serialization of parametrized modules is only "
  247. "supported through state_dict(). See:\n"
  248. "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
  249. "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
  250. )
  251. param_cls = type(
  252. f"Parametrized{cls.__name__}",
  253. (cls,),
  254. {
  255. "__getstate__": getstate,
  256. },
  257. )
  258. module.__class__ = param_cls
  259. def _inject_property(module: Module, tensor_name: str) -> None:
  260. r"""Injects a property into module[tensor_name].
  261. It assumes that the class in the module has already been modified from its
  262. original one using _inject_new_class and that the tensor under :attr:`tensor_name`
  263. has already been moved out
  264. Args:
  265. module (nn.Module): module into which to inject the property
  266. tensor_name (str): name of the name of the property to create
  267. """
  268. # We check the precondition.
  269. # This should never fire if register_parametrization is correctly implemented
  270. assert not hasattr(module, tensor_name)
  271. @torch.jit.unused
  272. def get_cached_parametrization(parametrization) -> Tensor:
  273. global _cache
  274. key = (id(module), tensor_name)
  275. tensor = _cache.get(key)
  276. if tensor is None:
  277. tensor = parametrization()
  278. _cache[key] = tensor
  279. return tensor
  280. def get_parametrized(self) -> Tensor:
  281. parametrization = self.parametrizations[tensor_name]
  282. if _cache_enabled:
  283. if torch.jit.is_scripting():
  284. # Scripting
  285. raise RuntimeError('Caching is not implemented for scripting. '
  286. 'Either disable caching or avoid scripting.')
  287. elif torch._C._get_tracing_state() is not None:
  288. # Tracing
  289. raise RuntimeError('Cannot trace a model while caching parametrizations.')
  290. else:
  291. return get_cached_parametrization(parametrization)
  292. else:
  293. # If caching is not active, this function just evaluates the parametrization
  294. return parametrization()
  295. def set_original(self, value: Tensor) -> None:
  296. self.parametrizations[tensor_name].right_inverse(value)
  297. setattr(module.__class__, tensor_name, property(get_parametrized, set_original))
  298. def register_parametrization(
  299. module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False,
  300. ) -> Module:
  301. r"""Adds a parametrization to a tensor in a module.
  302. Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``,
  303. the module will return the parametrized version ``parametrization(module.weight)``.
  304. If the original tensor requires a gradient, the backward pass will differentiate
  305. through :attr:`parametrization`, and the optimizer will update the tensor accordingly.
  306. The first time that a module registers a parametrization, this function will add an attribute
  307. ``parametrizations`` to the module of type :class:`~ParametrizationList`.
  308. The list of parametrizations on the tensor ``weight`` will be accessible under
  309. ``module.parametrizations.weight``.
  310. The original tensor will be accessible under
  311. ``module.parametrizations.weight.original``.
  312. Parametrizations may be concatenated by registering several parametrizations
  313. on the same attribute.
  314. The training mode of a registered parametrization is updated on registration
  315. to match the training mode of the host module
  316. Parametrized parameters and buffers have an inbuilt caching system that can be activated
  317. using the context manager :func:`cached`.
  318. A :attr:`parametrization` may optionally implement a method with signature
  319. .. code-block:: python
  320. def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]
  321. This method is called on the unparametrized tensor when the first parametrization
  322. is registered to compute the initial value of the original tensor.
  323. If this method is not implemented, the original tensor will be just the unparametrized tensor.
  324. If all the parametrizations registered on a tensor implement `right_inverse` it is possible
  325. to initialize a parametrized tensor by assigning to it, as shown in the example below.
  326. It is possible for the first parametrization to depend on several inputs.
  327. This may be implemented returning a tuple of tensors from ``right_inverse``
  328. (see the example implementation of a ``RankOne`` parametrization below).
  329. In this case, the unconstrained tensors are also located under ``module.parametrizations.weight``
  330. with names ``original0``, ``original1``,...
  331. .. note::
  332. If unsafe=False (default) both the forward and right_inverse methods will be called
  333. once to perform a number of consistency checks.
  334. If unsafe=True, then right_inverse will be called if the tensor is not parametrized,
  335. and nothing will be called otherwise.
  336. .. note::
  337. In most situations, ``right_inverse`` will be a function such that
  338. ``forward(right_inverse(X)) == X`` (see
  339. `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_).
  340. Sometimes, when the parametrization is not surjective, it may be reasonable
  341. to relax this.
  342. .. warning::
  343. If a parametrization depends on several inputs, :func:`~register_parametrization`
  344. will register a number of new parameters. If such parametrization is registered
  345. after the optimizer is created, these new parameters will need to be added manually
  346. to the optimizer. See :meth:`torch.Optimizer.add_param_group`.
  347. Args:
  348. module (nn.Module): module on which to register the parametrization
  349. tensor_name (str): name of the parameter or buffer on which to register
  350. the parametrization
  351. parametrization (nn.Module): the parametrization to register
  352. Keyword args:
  353. unsafe (bool): a boolean flag that denotes whether the parametrization
  354. may change the dtype and shape of the tensor. Default: `False`
  355. Warning: the parametrization is not checked for consistency upon registration.
  356. Enable this flag at your own risk.
  357. Raises:
  358. ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name`
  359. Examples:
  360. >>> import torch
  361. >>> import torch.nn as nn
  362. >>> import torch.nn.utils.parametrize as P
  363. >>>
  364. >>> class Symmetric(nn.Module):
  365. >>> def forward(self, X):
  366. >>> return X.triu() + X.triu(1).T # Return a symmetric matrix
  367. >>>
  368. >>> def right_inverse(self, A):
  369. >>> return A.triu()
  370. >>>
  371. >>> m = nn.Linear(5, 5)
  372. >>> P.register_parametrization(m, "weight", Symmetric())
  373. >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric
  374. True
  375. >>> A = torch.rand(5, 5)
  376. >>> A = A + A.T # A is now symmetric
  377. >>> m.weight = A # Initialize the weight to be the symmetric matrix A
  378. >>> print(torch.allclose(m.weight, A))
  379. True
  380. >>> class RankOne(nn.Module):
  381. >>> def forward(self, x, y):
  382. >>> # Form a rank 1 matrix multiplying two vectors
  383. >>> return x.unsqueeze(-1) @ y.unsqueeze(-2)
  384. >>>
  385. >>> def right_inverse(self, Z):
  386. >>> # Project Z onto the rank 1 matrices
  387. >>> U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
  388. >>> # Return rescaled singular vectors
  389. >>> s0_sqrt = S[0].sqrt().unsqueeze(-1)
  390. >>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
  391. >>>
  392. >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
  393. >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
  394. 1
  395. """
  396. parametrization.train(module.training)
  397. if is_parametrized(module, tensor_name):
  398. # Correctness checks.
  399. # If A is the space of tensors with shape and dtype equal to module.weight
  400. # we check that parametrization.forward and parametrization.right_inverse are
  401. # functions from A to A
  402. if not unsafe:
  403. Y = getattr(module, tensor_name)
  404. X = parametrization(Y)
  405. if not isinstance(X, Tensor):
  406. raise ValueError(
  407. f"A parametrization must return a tensor. Got {type(X).__name__}."
  408. )
  409. if X.dtype != Y.dtype:
  410. raise ValueError(
  411. "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n"
  412. f"module.{tensor_name}.dtype: {Y.dtype}\n"
  413. f"parametrization(module.{tensor_name}).dtype: {X.dtype}"
  414. )
  415. if X.shape != Y.shape:
  416. raise ValueError(
  417. "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n"
  418. f"module.{tensor_name}.shape: {Y.shape}\n"
  419. f"parametrization(module.{tensor_name}).shape: {X.shape}"
  420. )
  421. if hasattr(parametrization, "right_inverse"):
  422. try:
  423. Z = parametrization.right_inverse(X) # type: ignore[operator]
  424. except NotImplementedError:
  425. pass
  426. else:
  427. if not isinstance(Z, Tensor):
  428. raise ValueError(
  429. f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}"
  430. )
  431. if Z.dtype != Y.dtype:
  432. raise ValueError(
  433. "The tensor returned by parametrization.right_inverse must have the same dtype "
  434. f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
  435. f"module.{tensor_name}.dtype: {Y.dtype}\n"
  436. f"returned dtype: {Z.dtype}"
  437. )
  438. if Z.shape != Y.shape:
  439. raise ValueError(
  440. "The tensor returned by parametrization.right_inverse must have the same shape "
  441. f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n"
  442. f"module.{tensor_name}.shape: {Y.shape}\n"
  443. f"returned shape: {Z.shape}"
  444. )
  445. # else right_inverse is assumed to be the identity
  446. # add the new parametrization to the parametrization list
  447. assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
  448. module.parametrizations[tensor_name].append(parametrization)
  449. # If unsafe was True in previous parametrization, keep it enabled
  450. module.parametrizations[tensor_name].unsafe |= unsafe # type: ignore[index, union-attr]
  451. elif tensor_name in module._buffers or tensor_name in module._parameters:
  452. # Set the parametrization mechanism
  453. # Fetch the original buffer or parameter
  454. original = getattr(module, tensor_name)
  455. # We create this early to check for possible errors
  456. parametrizations = ParametrizationList([parametrization], original, unsafe=unsafe)
  457. # Delete the previous parameter or buffer
  458. delattr(module, tensor_name)
  459. # If this is the first parametrization registered on the module,
  460. # we prepare the module to inject the property
  461. if not is_parametrized(module):
  462. # Change the class
  463. _inject_new_class(module)
  464. # Inject a ``ModuleDict`` into the instance under module.parametrizations
  465. module.parametrizations = ModuleDict()
  466. # Add a property into the class
  467. _inject_property(module, tensor_name)
  468. # Add a ParametrizationList
  469. assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
  470. module.parametrizations[tensor_name] = parametrizations
  471. else:
  472. raise ValueError(
  473. f"Module '{module}' does not have a parameter, a buffer, or a "
  474. f"parametrized element with name '{tensor_name}'"
  475. )
  476. return module
  477. def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool:
  478. r"""Returns ``True`` if module has an active parametrization.
  479. If the argument :attr:`tensor_name` is specified, returns ``True`` if
  480. ``module[tensor_name]`` is parametrized.
  481. Args:
  482. module (nn.Module): module to query
  483. name (str, optional): attribute in the module to query
  484. Default: ``None``
  485. """
  486. parametrizations = getattr(module, "parametrizations", None)
  487. if parametrizations is None or not isinstance(parametrizations, ModuleDict):
  488. return False
  489. if tensor_name is None:
  490. # Check that there is at least one parametrized buffer or Parameter
  491. return len(parametrizations) > 0
  492. else:
  493. return tensor_name in parametrizations
  494. def remove_parametrizations(
  495. module: Module, tensor_name: str, leave_parametrized: bool = True
  496. ) -> Module:
  497. r"""Removes the parametrizations on a tensor in a module.
  498. - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to
  499. its current output. In this case, the parametrization shall not change the ``dtype``
  500. of the tensor.
  501. - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
  502. the unparametrised tensor in ``module.parametrizations[tensor_name].original``.
  503. This is only possible when the parametrization depends on just one tensor.
  504. Args:
  505. module (nn.Module): module from which remove the parametrization
  506. tensor_name (str): name of the parametrization to be removed
  507. leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized.
  508. Default: ``True``
  509. Returns:
  510. Module: module
  511. Raises:
  512. ValueError: if ``module[tensor_name]`` is not parametrized
  513. ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors
  514. """
  515. if not is_parametrized(module, tensor_name):
  516. raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}")
  517. # Fetch the original tensor
  518. assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
  519. parametrizations = module.parametrizations[tensor_name]
  520. if parametrizations.is_tensor:
  521. original = parametrizations.original
  522. if leave_parametrized:
  523. with torch.no_grad():
  524. t = getattr(module, tensor_name)
  525. # We know they have the same dtype because we have checked this when registering the
  526. # parametrizations. As such, we can use set_
  527. # We do this so that the parameter does not to change the id()
  528. # This way the user does not need to update the optimizer
  529. with torch.no_grad():
  530. if type(original) is torch.Tensor:
  531. original.set_(t)
  532. else:
  533. try:
  534. original.set_(t)
  535. except RuntimeError as e:
  536. # TODO: Fix this for tensor subclasses that are parameters:
  537. # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
  538. raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
  539. "for a parameter that is an instance of a tensor subclass requires "
  540. "set_() to be implemented correctly for the tensor subclass. Either "
  541. "set leave_parametrized=False or provide a working implementation for "
  542. "set_() in the tensor subclass.")
  543. else:
  544. if leave_parametrized:
  545. # We cannot use no_grad because we need to know whether one or more
  546. # original tensors required grad
  547. t = getattr(module, tensor_name)
  548. # We'll have to trust the user to add it to the optimizer
  549. original = Parameter(t) if t.requires_grad else t
  550. else:
  551. raise ValueError("Cannot leave unparametrized (`leave_parametrized=False`) a tensor "
  552. "that is parametrized in terms of a sequence of tensors.")
  553. # Delete the property that manages the parametrization
  554. delattr(module.__class__, tensor_name)
  555. # Delete the ParametrizationList
  556. del module.parametrizations[tensor_name]
  557. # Restore the parameter / buffer into the main class
  558. _register_parameter_or_buffer(module, tensor_name, original)
  559. # Roll back the parametrized class if no other buffer or parameter
  560. # is currently parametrized in this class
  561. if not is_parametrized(module):
  562. delattr(module, "parametrizations")
  563. # Restore class
  564. orig_cls = module.__class__.__bases__[0]
  565. module.__class__ = orig_cls
  566. return module
  567. def type_before_parametrizations(module: Module) -> type:
  568. r"""Returns the module type before parametrizations were applied and if not,
  569. then it returns the module type.
  570. Args:
  571. module (nn.Module): module to get type of
  572. """
  573. if is_parametrized(module):
  574. return module.__class__.__bases__[0]
  575. else:
  576. return type(module)
  577. def transfer_parametrizations_and_params(
  578. from_module: Module, to_module: Module, tensor_name: Optional[str] = None
  579. ) -> Module:
  580. r"""Transfers parametrizations and the parameters they parametrize from from_module
  581. to to_module. If tensor_name is specified, only transfers the specified parameter, otherwise
  582. transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them.
  583. Does nothing if from_module is not parametrized.
  584. Args:
  585. from_module (nn.Module): module to transfer from
  586. to_module (nn.Module): module to transfer to
  587. tensor_name (str, optional): parameter to transfer
  588. Returns:
  589. Module: to_module
  590. """
  591. if is_parametrized(from_module):
  592. assert isinstance(from_module.parametrizations, ModuleDict) # for mypy
  593. # get list of all params or the single param to transfer
  594. parameters_to_transfer: Union[list, ModuleDict] = (
  595. from_module.parametrizations if tensor_name is None else [tensor_name]
  596. )
  597. assert hasattr(parameters_to_transfer, "__iter__") # for mypy
  598. for parameter_name in parameters_to_transfer:
  599. # initialize the to-be-transfered param in to_module if it doesn't exist already
  600. if not hasattr(to_module, parameter_name):
  601. setattr(
  602. to_module,
  603. parameter_name,
  604. Parameter(getattr(from_module, parameter_name)),
  605. )
  606. # apply the params's parametrizations to to_module
  607. for param_func in from_module.parametrizations[parameter_name]:
  608. register_parametrization(to_module, parameter_name, param_func)
  609. assert isinstance(to_module.parametrizations, ModuleDict) # for mypy
  610. # make values match, original values can be stored in either original or
  611. # original0, original1..., need to check both cases
  612. if hasattr(from_module.parametrizations[parameter_name], "original"):
  613. to_module.parametrizations[parameter_name].original = \
  614. from_module.parametrizations[parameter_name].original
  615. else:
  616. num = 0
  617. orig_num = "original" + str(num)
  618. # loop through each original# until all values have been set
  619. while hasattr(from_module.parametrizations[parameter_name], orig_num):
  620. setattr(
  621. to_module.parametrizations[parameter_name],
  622. orig_num,
  623. getattr(from_module.parametrizations[parameter_name], orig_num),
  624. )
  625. num = num + 1
  626. orig_num = "original" + str(num)
  627. return to_module