# parametrizations.py
  1. from enum import Enum, auto
  2. import torch
  3. from torch import Tensor
  4. from ..utils import parametrize
  5. from ..modules import Module
  6. from .. import functional as F
  7. from typing import Optional
  8. def _is_orthogonal(Q, eps=None):
  9. n, k = Q.size(-2), Q.size(-1)
  10. Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
  11. # A reasonable eps, but not too large
  12. eps = 10. * n * torch.finfo(Q.dtype).eps
  13. return torch.allclose(Q.mH @ Q, Id, atol=eps)
  14. def _make_orthogonal(A):
  15. """ Assume that A is a tall matrix.
  16. Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative
  17. """
  18. X, tau = torch.geqrf(A)
  19. Q = torch.linalg.householder_product(X, tau)
  20. # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
  21. Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
  22. return Q
  23. class _OrthMaps(Enum):
  24. matrix_exp = auto()
  25. cayley = auto()
  26. householder = auto()
  27. class _Orthogonal(Module):
  28. base: Tensor
  29. def __init__(self,
  30. weight,
  31. orthogonal_map: _OrthMaps,
  32. *,
  33. use_trivialization=True) -> None:
  34. super().__init__()
  35. # Note [Householder complex]
  36. # For complex tensors, it is not possible to compute the tensor `tau` necessary for
  37. # linalg.householder_product from the reflectors.
  38. # To see this, note that the reflectors have a shape like:
  39. # 0 0 0
  40. # * 0 0
  41. # * * 0
  42. # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
  43. # to parametrize the unitary matrices. Saving tau on its own does not work either, because
  44. # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
  45. # them as independent tensors we would not maintain the constraint
  46. # An equivalent reasoning holds for rectangular matrices
  47. if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
  48. raise ValueError("The householder parametrization does not support complex tensors.")
  49. self.shape = weight.shape
  50. self.orthogonal_map = orthogonal_map
  51. if use_trivialization:
  52. self.register_buffer("base", None)
  53. def forward(self, X: torch.Tensor) -> torch.Tensor:
  54. n, k = X.size(-2), X.size(-1)
  55. transposed = n < k
  56. if transposed:
  57. X = X.mT
  58. n, k = k, n
  59. # Here n > k and X is a tall matrix
  60. if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
  61. # We just need n x k - k(k-1)/2 parameters
  62. X = X.tril()
  63. if n != k:
  64. # Embed into a square matrix
  65. X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
  66. A = X - X.mH
  67. # A is skew-symmetric (or skew-hermitian)
  68. if self.orthogonal_map == _OrthMaps.matrix_exp:
  69. Q = torch.matrix_exp(A)
  70. elif self.orthogonal_map == _OrthMaps.cayley:
  71. # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
  72. Id = torch.eye(n, dtype=A.dtype, device=A.device)
  73. Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
  74. # Q is now orthogonal (or unitary) of size (..., n, n)
  75. if n != k:
  76. Q = Q[..., :k]
  77. # Q is now the size of the X (albeit perhaps transposed)
  78. else:
  79. # X is real here, as we do not support householder with complex numbers
  80. A = X.tril(diagonal=-1)
  81. tau = 2. / (1. + (A * A).sum(dim=-2))
  82. Q = torch.linalg.householder_product(A, tau)
  83. # The diagonal of X is 1's and -1's
  84. # We do not want to differentiate through this or update the diagonal of X hence the casting
  85. Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
  86. if hasattr(self, "base"):
  87. Q = self.base @ Q
  88. if transposed:
  89. Q = Q.mT
  90. return Q
  91. @torch.autograd.no_grad()
  92. def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
  93. if Q.shape != self.shape:
  94. raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
  95. f"Got a tensor of shape {Q.shape}.")
  96. Q_init = Q
  97. n, k = Q.size(-2), Q.size(-1)
  98. transpose = n < k
  99. if transpose:
  100. Q = Q.mT
  101. n, k = k, n
  102. # We always make sure to always copy Q in every path
  103. if not hasattr(self, "base"):
  104. # Note [right_inverse expm cayley]
  105. # If we do not have use_trivialization=True, we just implement the inverse of the forward
  106. # map for the Householder. To see why, think that for the Cayley map,
  107. # we would need to find the matrix X \in R^{n x k} such that:
  108. # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
  109. # A = Y - Y.mH
  110. # cayley(A)[:, :k]
  111. # gives the original tensor. It is not clear how to do this.
  112. # Perhaps via some algebraic manipulation involving the QR like that of
  113. # Corollary 2.2 in Edelman, Arias and Smith?
  114. if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
  115. raise NotImplementedError("It is not possible to assign to the matrix exponential "
  116. "or the Cayley parametrizations when use_trivialization=False.")
  117. # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
  118. # Here Q is always real because we do not support householder and complex matrices.
  119. # See note [Householder complex]
  120. A, tau = torch.geqrf(Q)
  121. # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
  122. # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
  123. # The diagonal of Q is the diagonal of R from the qr decomposition
  124. A.diagonal(dim1=-2, dim2=-1).sign_()
  125. # Equality with zero is ok because LAPACK returns exactly zero when it does not want
  126. # to use a particular reflection
  127. A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
  128. return A.mT if transpose else A
  129. else:
  130. if n == k:
  131. # We check whether Q is orthogonal
  132. if not _is_orthogonal(Q):
  133. Q = _make_orthogonal(Q)
  134. else: # Is orthogonal
  135. Q = Q.clone()
  136. else:
  137. # Complete Q into a full n x n orthogonal matrix
  138. N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
  139. Q = torch.cat([Q, N], dim=-1)
  140. Q = _make_orthogonal(Q)
  141. self.base = Q
  142. # It is necessary to return the -Id, as we use the diagonal for the
  143. # Householder parametrization. Using -Id makes:
  144. # householder(torch.zeros(m,n)) == torch.eye(m,n)
  145. # Poor man's version of eye_like
  146. neg_Id = torch.zeros_like(Q_init)
  147. neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
  148. return neg_Id
  149. def orthogonal(module: Module,
  150. name: str = 'weight',
  151. orthogonal_map: Optional[str] = None,
  152. *,
  153. use_trivialization: bool = True) -> Module:
  154. r"""Applies an orthogonal or unitary parametrization to a matrix or a batch of matrices.
  155. Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
  156. matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as
  157. .. math::
  158. \begin{align*}
  159. Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
  160. QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
  161. \end{align*}
  162. where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
  163. and the transpose when :math:`Q` is real-valued, and
  164. :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
  165. In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
  166. and orthonormal rows otherwise.
  167. If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.
  168. The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` in terms of the original tensor:
  169. - ``"matrix_exp"``/``"cayley"``:
  170. the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
  171. :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
  172. :math:`A` to give an orthogonal matrix.
  173. - ``"householder"``: computes a product of Householder reflectors
  174. (:func:`~torch.linalg.householder_product`).
  175. ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
  176. ``"householder"``, but they are slower to compute for very thin or very wide matrices.
  177. If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
  178. where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
  179. ``module.parametrizations.weight[0].base``. This helps the
  180. convergence of the parametrized layer at the expense of some extra memory use.
  181. See `Trivializations for Gradient-Based Optimization on Manifolds`_ .
  182. Initial value of :math:`Q`:
  183. If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
  184. of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
  185. and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
  186. Same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
  187. Otherwise, the initial value is the result of the composition of all the registered
  188. parametrizations applied to the original tensor.
  189. .. note::
  190. This function is implemented using the parametrization functionality
  191. in :func:`~torch.nn.utils.parametrize.register_parametrization`.
  192. .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
  193. .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501
  194. Args:
  195. module (nn.Module): module on which to register the parametrization.
  196. name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
  197. orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
  198. Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
  199. use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
  200. Default: ``True``.
  201. Returns:
  202. The original module with an orthogonal parametrization registered to the specified
  203. weight
  204. Example::
  205. >>> orth_linear = orthogonal(nn.Linear(20, 40))
  206. >>> orth_linear
  207. ParametrizedLinear(
  208. in_features=20, out_features=40, bias=True
  209. (parametrizations): ModuleDict(
  210. (weight): ParametrizationList(
  211. (0): _Orthogonal()
  212. )
  213. )
  214. )
  215. >>> Q = orth_linear.weight
  216. >>> torch.dist(Q.T @ Q, torch.eye(20))
  217. tensor(4.9332e-07)
  218. """
  219. weight = getattr(module, name, None)
  220. if not isinstance(weight, Tensor):
  221. raise ValueError(
  222. "Module '{}' has no parameter ot buffer with name '{}'".format(module, name)
  223. )
  224. # We could implement this for 1-dim tensors as the maps on the sphere
  225. # but I believe it'd bite more people than it'd help
  226. if weight.ndim < 2:
  227. raise ValueError("Expected a matrix or batch of matrices. "
  228. f"Got a tensor of {weight.ndim} dimensions.")
  229. if orthogonal_map is None:
  230. orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"
  231. orth_enum = getattr(_OrthMaps, orthogonal_map, None)
  232. if orth_enum is None:
  233. raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
  234. f'Got: {orthogonal_map}')
  235. orth = _Orthogonal(weight,
  236. orth_enum,
  237. use_trivialization=use_trivialization)
  238. parametrize.register_parametrization(module, name, orth, unsafe=True)
  239. return module
  240. class _SpectralNorm(Module):
  241. def __init__(
  242. self,
  243. weight: torch.Tensor,
  244. n_power_iterations: int = 1,
  245. dim: int = 0,
  246. eps: float = 1e-12
  247. ) -> None:
  248. super().__init__()
  249. ndim = weight.ndim
  250. if dim >= ndim or dim < -ndim:
  251. raise IndexError("Dimension out of range (expected to be in range of "
  252. f"[-{ndim}, {ndim - 1}] but got {dim})")
  253. if n_power_iterations <= 0:
  254. raise ValueError('Expected n_power_iterations to be positive, but '
  255. 'got n_power_iterations={}'.format(n_power_iterations))
  256. self.dim = dim if dim >= 0 else dim + ndim
  257. self.eps = eps
  258. if ndim > 1:
  259. # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
  260. self.n_power_iterations = n_power_iterations
  261. weight_mat = self._reshape_weight_to_matrix(weight)
  262. h, w = weight_mat.size()
  263. u = weight_mat.new_empty(h).normal_(0, 1)
  264. v = weight_mat.new_empty(w).normal_(0, 1)
  265. self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
  266. self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))
  267. # Start with u, v initialized to some reasonable values by performing a number
  268. # of iterations of the power method
  269. self._power_method(weight_mat, 15)
  270. def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
  271. # Precondition
  272. assert weight.ndim > 1
  273. if self.dim != 0:
  274. # permute dim to front
  275. weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))
  276. return weight.flatten(1)
  277. @torch.autograd.no_grad()
  278. def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
  279. # See original note at torch/nn/utils/spectral_norm.py
  280. # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
  281. # updated in power iteration **in-place**. This is very important
  282. # because in `DataParallel` forward, the vectors (being buffers) are
  283. # broadcast from the parallelized module to each module replica,
  284. # which is a new module object created on the fly. And each replica
  285. # runs its own spectral norm power iteration. So simply assigning
  286. # the updated vectors to the module this function runs on will cause
  287. # the update to be lost forever. And the next time the parallelized
  288. # module is replicated, the same randomly initialized vectors are
  289. # broadcast and used!
  290. #
  291. # Therefore, to make the change propagate back, we rely on two
  292. # important behaviors (also enforced via tests):
  293. # 1. `DataParallel` doesn't clone storage if the broadcast tensor
  294. # is already on correct device; and it makes sure that the
  295. # parallelized module is already on `device[0]`.
  296. # 2. If the out tensor in `out=` kwarg has correct shape, it will
  297. # just fill in the values.
  298. # Therefore, since the same power iteration is performed on all
  299. # devices, simply updating the tensors in-place will make sure that
  300. # the module replica on `device[0]` will update the _u vector on the
  301. # parallized module (by shared storage).
  302. #
  303. # However, after we update `u` and `v` in-place, we need to **clone**
  304. # them before using them to normalize the weight. This is to support
  305. # backproping through two forward passes, e.g., the common pattern in
  306. # GAN training: loss = D(real) - D(fake). Otherwise, engine will
  307. # complain that variables needed to do backward for the first forward
  308. # (i.e., the `u` and `v` vectors) are changed in the second forward.
  309. # Precondition
  310. assert weight_mat.ndim > 1
  311. for _ in range(n_power_iterations):
  312. # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
  313. # are the first left and right singular vectors.
  314. # This power iteration produces approximations of `u` and `v`.
  315. self._u = F.normalize(torch.mv(weight_mat, self._v), # type: ignore[has-type]
  316. dim=0, eps=self.eps, out=self._u) # type: ignore[has-type]
  317. self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
  318. dim=0, eps=self.eps, out=self._v) # type: ignore[has-type]
  319. def forward(self, weight: torch.Tensor) -> torch.Tensor:
  320. if weight.ndim == 1:
  321. # Faster and more exact path, no need to approximate anything
  322. return F.normalize(weight, dim=0, eps=self.eps)
  323. else:
  324. weight_mat = self._reshape_weight_to_matrix(weight)
  325. if self.training:
  326. self._power_method(weight_mat, self.n_power_iterations)
  327. # See above on why we need to clone
  328. u = self._u.clone(memory_format=torch.contiguous_format)
  329. v = self._v.clone(memory_format=torch.contiguous_format)
  330. # The proper way of computing this should be through F.bilinear, but
  331. # it seems to have some efficiency issues:
  332. # https://github.com/pytorch/pytorch/issues/58093
  333. sigma = torch.dot(u, torch.mv(weight_mat, v))
  334. return weight / sigma
  335. def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
  336. # we may want to assert here that the passed value already
  337. # satisfies constraints
  338. return value
  339. def spectral_norm(module: Module,
  340. name: str = 'weight',
  341. n_power_iterations: int = 1,
  342. eps: float = 1e-12,
  343. dim: Optional[int] = None) -> Module:
  344. r"""Applies spectral normalization to a parameter in the given module.
  345. .. math::
  346. \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
  347. \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
  348. When applied on a vector, it simplifies to
  349. .. math::
  350. \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}
  351. Spectral normalization stabilizes the training of discriminators (critics)
  352. in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
  353. of the model. :math:`\sigma` is approximated performing one iteration of the
  354. `power method`_ every time the weight is accessed. If the dimension of the
  355. weight tensor is greater than 2, it is reshaped to 2D in power iteration
  356. method to get spectral norm.
  357. See `Spectral Normalization for Generative Adversarial Networks`_ .
  358. .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
  359. .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
  360. .. note::
  361. This function is implemented using the parametrization functionality
  362. in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
  363. reimplementation of :func:`torch.nn.utils.spectral_norm`.
  364. .. note::
  365. When this constraint is registered, the singular vectors associated to the largest
  366. singular value are estimated rather than sampled at random. These are then updated
  367. performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
  368. is accessed with the module on `training` mode.
  369. .. note::
  370. If the `_SpectralNorm` module, i.e., `module.parametrization.weight[idx]`,
  371. is in training mode on removal, it will perform another power iteration.
  372. If you'd like to avoid this iteration, set the module to eval mode
  373. before its removal.
  374. Args:
  375. module (nn.Module): containing module
  376. name (str, optional): name of weight parameter. Default: ``"weight"``.
  377. n_power_iterations (int, optional): number of power iterations to
  378. calculate spectral norm. Default: ``1``.
  379. eps (float, optional): epsilon for numerical stability in
  380. calculating norms. Default: ``1e-12``.
  381. dim (int, optional): dimension corresponding to number of outputs.
  382. Default: ``0``, except for modules that are instances of
  383. ConvTranspose{1,2,3}d, when it is ``1``
  384. Returns:
  385. The original module with a new parametrization registered to the specified
  386. weight
  387. Example::
  388. >>> snm = spectral_norm(nn.Linear(20, 40))
  389. >>> snm
  390. ParametrizedLinear(
  391. in_features=20, out_features=40, bias=True
  392. (parametrizations): ModuleDict(
  393. (weight): ParametrizationList(
  394. (0): _SpectralNorm()
  395. )
  396. )
  397. )
  398. >>> torch.linalg.matrix_norm(snm.weight, 2)
  399. tensor(1.0000, grad_fn=<CopyBackwards>)
  400. """
  401. weight = getattr(module, name, None)
  402. if not isinstance(weight, Tensor):
  403. raise ValueError(
  404. "Module '{}' has no parameter or buffer with name '{}'".format(module, name)
  405. )
  406. if dim is None:
  407. if isinstance(module, (torch.nn.ConvTranspose1d,
  408. torch.nn.ConvTranspose2d,
  409. torch.nn.ConvTranspose3d)):
  410. dim = 1
  411. else:
  412. dim = 0
  413. parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
  414. return module