_lobpcg.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117
  1. """Locally Optimal Block Preconditioned Conjugate Gradient methods.
  2. """
  3. # Author: Pearu Peterson
  4. # Created: February 2020
  5. from typing import Dict, Tuple, Optional
  6. import torch
  7. from torch import Tensor
  8. from . import _linalg_utils as _utils
  9. from .overrides import has_torch_function, handle_torch_function
  10. __all__ = ['lobpcg']
  11. def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
  12. # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
  13. F = D.unsqueeze(-2) - D.unsqueeze(-1)
  14. F.diagonal(dim1=-2, dim2=-1).fill_(float('inf'))
  15. F.pow_(-1)
  16. # A.grad = U (D.grad + (U^T U.grad * F)) U^T
  17. Ut = U.mT.contiguous()
  18. res = torch.matmul(
  19. U,
  20. torch.matmul(
  21. torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F,
  22. Ut
  23. )
  24. )
  25. return res
  26. def _polynomial_coefficients_given_roots(roots):
  27. """
  28. Given the `roots` of a polynomial, find the polynomial's coefficients.
  29. If roots = (r_1, ..., r_n), then the method returns
  30. coefficients (a_0, a_1, ..., a_n (== 1)) so that
  31. p(x) = (x - r_1) * ... * (x - r_n)
  32. = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
  33. Note: for better performance requires writing a low-level kernel
  34. """
  35. poly_order = roots.shape[-1]
  36. poly_coeffs_shape = list(roots.shape)
  37. # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
  38. # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
  39. # but we insert one extra coefficient to enable better vectorization below
  40. poly_coeffs_shape[-1] += 2
  41. poly_coeffs = roots.new_zeros(poly_coeffs_shape)
  42. poly_coeffs[..., 0] = 1
  43. poly_coeffs[..., -1] = 1
  44. # perform the Horner's rule
  45. for i in range(1, poly_order + 1):
  46. # note that it is computationally hard to compute backward for this method,
  47. # because then given the coefficients it would require finding the roots and/or
  48. # calculating the sensitivity based on the Vieta's theorem.
  49. # So the code below tries to circumvent the explicit root finding by series
  50. # of operations on memory copies imitating the Horner's method.
  51. # The memory copies are required to construct nodes in the computational graph
  52. # by exploting the explicit (not in-place, separate node for each step)
  53. # recursion of the Horner's method.
  54. # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
  55. poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
  56. out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
  57. out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(-1, poly_order - i + 1, i + 1)
  58. poly_coeffs = poly_coeffs_new
  59. return poly_coeffs.narrow(-1, 1, poly_order + 1)
  60. def _polynomial_value(poly, x, zero_power, transition):
  61. """
  62. A generic method for computing poly(x) using the Horner's rule.
  63. Args:
  64. poly (Tensor): the (possibly batched) 1D Tensor representing
  65. polynomial coefficients such that
  66. poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
  67. poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
  68. x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
  69. zero_power (Tensor): the represenation of `x^0`. It is application-specific.
  70. transition (Callable): the function that accepts some intermediate result `int_val`,
  71. the `x` and a specific polynomial coefficient
  72. `poly[..., k]` for some iteration `k`.
  73. It basically performs one iteration of the Horner's rule
  74. defined as `x * int_val + poly[..., k] * zero_power`.
  75. Note that `zero_power` is not a parameter,
  76. because the step `+ poly[..., k] * zero_power` depends on `x`,
  77. whether it is a vector, a matrix, or something else, so this
  78. functionality is delegated to the user.
  79. """
  80. res = zero_power.clone()
  81. for k in range(poly.size(-1) - 2, -1, -1):
  82. res = transition(res, x, poly[..., k])
  83. return res
  84. def _matrix_polynomial_value(poly, x, zero_power=None):
  85. """
  86. Evaluates `poly(x)` for the (batched) matrix input `x`.
  87. Check out `_polynomial_value` function for more details.
  88. """
  89. # matrix-aware Horner's rule iteration
  90. def transition(curr_poly_val, x, poly_coeff):
  91. res = x.matmul(curr_poly_val)
  92. res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
  93. return res
  94. if zero_power is None:
  95. zero_power = torch.eye(x.size(-1), x.size(-1), dtype=x.dtype, device=x.device) \
  96. .view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
  97. return _polynomial_value(poly, x, zero_power, transition)
  98. def _vector_polynomial_value(poly, x, zero_power=None):
  99. """
  100. Evaluates `poly(x)` for the (batched) vector input `x`.
  101. Check out `_polynomial_value` function for more details.
  102. """
  103. # vector-aware Horner's rule iteration
  104. def transition(curr_poly_val, x, poly_coeff):
  105. res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
  106. return res
  107. if zero_power is None:
  108. zero_power = x.new_ones(1).expand(x.shape)
  109. return _polynomial_value(poly, x, zero_power, transition)
  110. def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
  111. # compute a projection operator onto an orthogonal subspace spanned by the
  112. # columns of U defined as (I - UU^T)
  113. Ut = U.mT.contiguous()
  114. proj_U_ortho = -U.matmul(Ut)
  115. proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)
  116. # compute U_ortho, a basis for the orthogonal complement to the span(U),
  117. # by projecting a random [..., m, m - k] matrix onto the subspace spanned
  118. # by the columns of U.
  119. #
  120. # fix generator for determinism
  121. gen = torch.Generator(A.device)
  122. # orthogonal complement to the span(U)
  123. U_ortho = proj_U_ortho.matmul(
  124. torch.randn(
  125. (*A.shape[:-1], A.size(-1) - D.size(-1)),
  126. dtype=A.dtype,
  127. device=A.device,
  128. generator=gen
  129. )
  130. )
  131. U_ortho_t = U_ortho.mT.contiguous()
  132. # compute the coefficients of the characteristic polynomial of the tensor D.
  133. # Note that D is diagonal, so the diagonal elements are exactly the roots
  134. # of the characteristic polynomial.
  135. chr_poly_D = _polynomial_coefficients_given_roots(D)
  136. # the code belows finds the explicit solution to the Sylvester equation
  137. # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
  138. # and incorporates it into the whole gradient stored in the `res` variable.
  139. #
  140. # Equivalent to the following naive implementation:
  141. # res = A.new_zeros(A.shape)
  142. # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
  143. # for k in range(1, chr_poly_D.size(-1)):
  144. # p_res.zero_()
  145. # for i in range(0, k):
  146. # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
  147. # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
  148. #
  149. # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
  150. # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
  151. # and we need to compute g(U_grad, A, U, D)
  152. #
  153. # The naive implementation is based on the paper
  154. # Hu, Qingxi, and Daizhan Cheng.
  155. # "The polynomial solution to the Sylvester matrix equation."
  156. # Applied mathematics letters 19.9 (2006): 859-864.
  157. #
  158. # We can modify the computation of `p_res` from above in a more efficient way
  159. # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
  160. # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
  161. # + ...
  162. # + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
  163. # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
  164. U_grad_projected = U_grad
  165. series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
  166. for k in range(1, chr_poly_D.size(-1)):
  167. poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
  168. series_acc += U_grad_projected * poly_D.unsqueeze(-2)
  169. U_grad_projected = A.matmul(U_grad_projected)
  170. # compute chr_poly_D(A) which essentially is:
  171. #
  172. # chr_poly_D_at_A = A.new_zeros(A.shape)
  173. # for k in range(chr_poly_D.size(-1)):
  174. # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
  175. #
  176. # Note, however, for better performance we use the Horner's rule
  177. chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
  178. # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
  179. chr_poly_D_at_A_to_U_ortho = torch.matmul(
  180. U_ortho_t,
  181. torch.matmul(
  182. chr_poly_D_at_A,
  183. U_ortho
  184. )
  185. )
  186. # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its
  187. # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
  188. # Cholesky decomposition requires the input to be positive-definite.
  189. # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
  190. # 1. `largest` == False, or
  191. # 2. `largest` == True and `k` is even
  192. # under the assumption that `A` has distinct eigenvalues.
  193. #
  194. # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
  195. chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
  196. chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
  197. chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
  198. )
  199. # compute the gradient part in span(U)
  200. res = _symeig_backward_complete_eigenspace(
  201. D_grad, U_grad, A, D, U
  202. )
  203. # incorporate the Sylvester equation solution into the full gradient
  204. # it resides in span(U_ortho)
  205. res -= U_ortho.matmul(
  206. chr_poly_D_at_A_to_U_ortho_sign * torch.cholesky_solve(
  207. U_ortho_t.matmul(series_acc),
  208. chr_poly_D_at_A_to_U_ortho_L
  209. )
  210. ).matmul(Ut)
  211. return res
  212. def _symeig_backward(D_grad, U_grad, A, D, U, largest):
  213. # if `U` is square, then the columns of `U` is a complete eigenspace
  214. if U.size(-1) == U.size(-2):
  215. return _symeig_backward_complete_eigenspace(
  216. D_grad, U_grad, A, D, U
  217. )
  218. else:
  219. return _symeig_backward_partial_eigenspace(
  220. D_grad, U_grad, A, D, U, largest
  221. )
  222. class LOBPCGAutogradFunction(torch.autograd.Function):
  223. @staticmethod
  224. def forward(ctx, # type: ignore[override]
  225. A: Tensor,
  226. k: Optional[int] = None,
  227. B: Optional[Tensor] = None,
  228. X: Optional[Tensor] = None,
  229. n: Optional[int] = None,
  230. iK: Optional[Tensor] = None,
  231. niter: Optional[int] = None,
  232. tol: Optional[float] = None,
  233. largest: Optional[bool] = None,
  234. method: Optional[str] = None,
  235. tracker: None = None,
  236. ortho_iparams: Optional[Dict[str, int]] = None,
  237. ortho_fparams: Optional[Dict[str, float]] = None,
  238. ortho_bparams: Optional[Dict[str, bool]] = None
  239. ) -> Tuple[Tensor, Tensor]:
  240. # makes sure that input is contiguous for efficiency.
  241. # Note: autograd does not support dense gradients for sparse input yet.
  242. A = A.contiguous() if (not A.is_sparse) else A
  243. if B is not None:
  244. B = B.contiguous() if (not B.is_sparse) else B
  245. D, U = _lobpcg(
  246. A, k, B, X,
  247. n, iK, niter, tol, largest, method, tracker,
  248. ortho_iparams, ortho_fparams, ortho_bparams
  249. )
  250. ctx.save_for_backward(A, B, D, U)
  251. ctx.largest = largest
  252. return D, U
  253. @staticmethod
  254. def backward(ctx, D_grad, U_grad):
  255. A_grad = B_grad = None
  256. grads = [None] * 14
  257. A, B, D, U = ctx.saved_tensors
  258. largest = ctx.largest
  259. # lobpcg.backward has some limitations. Checks for unsupported input
  260. if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
  261. raise ValueError(
  262. 'lobpcg.backward does not support sparse input yet.'
  263. 'Note that lobpcg.forward does though.'
  264. )
  265. if A.dtype in (torch.complex64, torch.complex128) or \
  266. B is not None and B.dtype in (torch.complex64, torch.complex128):
  267. raise ValueError(
  268. 'lobpcg.backward does not support complex input yet.'
  269. 'Note that lobpcg.forward does though.'
  270. )
  271. if B is not None:
  272. raise ValueError(
  273. 'lobpcg.backward does not support backward with B != I yet.'
  274. )
  275. if largest is None:
  276. largest = True
  277. # symeig backward
  278. if B is None:
  279. A_grad = _symeig_backward(
  280. D_grad, U_grad, A, D, U, largest
  281. )
  282. # A has index 0
  283. grads[0] = A_grad
  284. # B has index 2
  285. grads[2] = B_grad
  286. return tuple(grads)
  287. def lobpcg(A: Tensor,
  288. k: Optional[int] = None,
  289. B: Optional[Tensor] = None,
  290. X: Optional[Tensor] = None,
  291. n: Optional[int] = None,
  292. iK: Optional[Tensor] = None,
  293. niter: Optional[int] = None,
  294. tol: Optional[float] = None,
  295. largest: Optional[bool] = None,
  296. method: Optional[str] = None,
  297. tracker: None = None,
  298. ortho_iparams: Optional[Dict[str, int]] = None,
  299. ortho_fparams: Optional[Dict[str, float]] = None,
  300. ortho_bparams: Optional[Dict[str, bool]] = None
  301. ) -> Tuple[Tensor, Tensor]:
  302. """Find the k largest (or smallest) eigenvalues and the corresponding
  303. eigenvectors of a symmetric positive definite generalized
  304. eigenvalue problem using matrix-free LOBPCG methods.
  305. This function is a front-end to the following LOBPCG algorithms
  306. selectable via `method` argument:
  307. `method="basic"` - the LOBPCG method introduced by Andrew
  308. Knyazev, see [Knyazev2001]. A less robust method, may fail when
  309. Cholesky is applied to singular input.
  310. `method="ortho"` - the LOBPCG method with orthogonal basis
  311. selection [StathopoulosEtal2002]. A robust method.
  312. Supported inputs are dense, sparse, and batches of dense matrices.
  313. .. note:: In general, the basic method spends least time per
  314. iteration. However, the robust methods converge much faster and
  315. are more stable. So, the usage of the basic method is generally
  316. not recommended but there exist cases where the usage of the
  317. basic method may be preferred.
  318. .. warning:: The backward method does not support sparse and complex inputs.
  319. It works only when `B` is not provided (i.e. `B == None`).
  320. We are actively working on extensions, and the details of
  321. the algorithms are going to be published promptly.
  322. .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
  323. To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
  324. in first-order optimization routines, prior to running `lobpcg`
  325. we do the following symmetrization map: `A -> (A + A.t()) / 2`.
  326. The map is performed only when the `A` requires gradients.
  327. Args:
  328. A (Tensor): the input tensor of size :math:`(*, m, m)`
  329. B (Tensor, optional): the input tensor of size :math:`(*, m,
  330. m)`. When not specified, `B` is interpereted as
  331. identity matrix.
  332. X (tensor, optional): the input tensor of size :math:`(*, m, n)`
  333. where `k <= n <= m`. When specified, it is used as
  334. initial approximation of eigenvectors. X must be a
  335. dense tensor.
  336. iK (tensor, optional): the input tensor of size :math:`(*, m,
  337. m)`. When specified, it will be used as preconditioner.
  338. k (integer, optional): the number of requested
  339. eigenpairs. Default is the number of :math:`X`
  340. columns (when specified) or `1`.
  341. n (integer, optional): if :math:`X` is not specified then `n`
  342. specifies the size of the generated random
  343. approximation of eigenvectors. Default value for `n`
  344. is `k`. If :math:`X` is specified, the value of `n`
  345. (when specified) must be the number of :math:`X`
  346. columns.
  347. tol (float, optional): residual tolerance for stopping
  348. criterion. Default is `feps ** 0.5` where `feps` is
  349. smallest non-zero floating-point number of the given
  350. input tensor `A` data type.
  351. largest (bool, optional): when True, solve the eigenproblem for
  352. the largest eigenvalues. Otherwise, solve the
  353. eigenproblem for smallest eigenvalues. Default is
  354. `True`.
  355. method (str, optional): select LOBPCG method. See the
  356. description of the function above. Default is
  357. "ortho".
  358. niter (int, optional): maximum number of iterations. When
  359. reached, the iteration process is hard-stopped and
  360. the current approximation of eigenpairs is returned.
  361. For infinite iteration but until convergence criteria
  362. is met, use `-1`.
  363. tracker (callable, optional) : a function for tracing the
  364. iteration process. When specified, it is called at
  365. each iteration step with LOBPCG instance as an
  366. argument. The LOBPCG instance holds the full state of
  367. the iteration process in the following attributes:
  368. `iparams`, `fparams`, `bparams` - dictionaries of
  369. integer, float, and boolean valued input
  370. parameters, respectively
  371. `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
  372. of integer, float, boolean, and Tensor valued
  373. iteration variables, respectively.
  374. `A`, `B`, `iK` - input Tensor arguments.
  375. `E`, `X`, `S`, `R` - iteration Tensor variables.
  376. For instance:
  377. `ivars["istep"]` - the current iteration step
  378. `X` - the current approximation of eigenvectors
  379. `E` - the current approximation of eigenvalues
  380. `R` - the current residual
  381. `ivars["converged_count"]` - the current number of converged eigenpairs
  382. `tvars["rerr"]` - the current state of convergence criteria
  383. Note that when `tracker` stores Tensor objects from
  384. the LOBPCG instance, it must make copies of these.
  385. If `tracker` sets `bvars["force_stop"] = True`, the
  386. iteration process will be hard-stopped.
  387. ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
  388. various parameters to LOBPCG algorithm when using
  389. `method="ortho"`.
  390. Returns:
  391. E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
  392. X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`
  393. References:
  394. [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
  395. Preconditioned Eigensolver: Locally Optimal Block Preconditioned
  396. Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
  397. 517-541. (25 pages)
  398. https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
  399. [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
  400. Wu. (2002) A Block Orthogonalization Procedure with Constant
  401. Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
  402. 2165-2182. (18 pages)
  403. https://epubs.siam.org/doi/10.1137/S1064827500370883
  404. [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
  405. Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
  406. SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
  407. https://epubs.siam.org/doi/abs/10.1137/17M1129830
  408. """
  409. if not torch.jit.is_scripting():
  410. tensor_ops = (A, B, X, iK)
  411. if (not set(map(type, tensor_ops)).issubset((torch.Tensor, type(None))) and has_torch_function(tensor_ops)):
  412. return handle_torch_function(
  413. lobpcg, tensor_ops, A, k=k,
  414. B=B, X=X, n=n, iK=iK, niter=niter, tol=tol,
  415. largest=largest, method=method, tracker=tracker,
  416. ortho_iparams=ortho_iparams,
  417. ortho_fparams=ortho_fparams,
  418. ortho_bparams=ortho_bparams)
  419. if not torch._jit_internal.is_scripting():
  420. if A.requires_grad or (B is not None and B.requires_grad):
  421. # While it is expected that `A` is symmetric,
  422. # the `A_grad` might be not. Therefore we perform the trick below,
  423. # so that `A_grad` becomes symmetric.
  424. # The symmetrization is important for first-order optimization methods,
  425. # so that (A - alpha * A_grad) is still a symmetric matrix.
  426. # Same holds for `B`.
  427. A_sym = (A + A.mT) / 2
  428. B_sym = (B + B.mT) / 2 if (B is not None) else None
  429. return LOBPCGAutogradFunction.apply(
  430. A_sym, k, B_sym, X, n, iK, niter, tol, largest,
  431. method, tracker, ortho_iparams, ortho_fparams, ortho_bparams
  432. )
  433. else:
  434. if A.requires_grad or (B is not None and B.requires_grad):
  435. raise RuntimeError(
  436. 'Script and require grads is not supported atm.'
  437. 'If you just want to do the forward, use .detach()'
  438. 'on A and B before calling into lobpcg'
  439. )
  440. return _lobpcg(
  441. A, k, B, X,
  442. n, iK, niter, tol, largest, method, tracker,
  443. ortho_iparams, ortho_fparams, ortho_bparams
  444. )
  445. def _lobpcg(A: Tensor,
  446. k: Optional[int] = None,
  447. B: Optional[Tensor] = None,
  448. X: Optional[Tensor] = None,
  449. n: Optional[int] = None,
  450. iK: Optional[Tensor] = None,
  451. niter: Optional[int] = None,
  452. tol: Optional[float] = None,
  453. largest: Optional[bool] = None,
  454. method: Optional[str] = None,
  455. tracker: None = None,
  456. ortho_iparams: Optional[Dict[str, int]] = None,
  457. ortho_fparams: Optional[Dict[str, float]] = None,
  458. ortho_bparams: Optional[Dict[str, bool]] = None
  459. ) -> Tuple[Tensor, Tensor]:
  460. # A must be square:
  461. assert A.shape[-2] == A.shape[-1], A.shape
  462. if B is not None:
  463. # A and B must have the same shapes:
  464. assert A.shape == B.shape, (A.shape, B.shape)
  465. dtype = _utils.get_floating_dtype(A)
  466. device = A.device
  467. if tol is None:
  468. feps = {torch.float32: 1.2e-07,
  469. torch.float64: 2.23e-16}[dtype]
  470. tol = feps ** 0.5
  471. m = A.shape[-1]
  472. k = (1 if X is None else X.shape[-1]) if k is None else k
  473. n = (k if n is None else n) if X is None else X.shape[-1]
  474. if (m < 3 * n):
  475. raise ValueError(
  476. 'LPBPCG algorithm is not applicable when the number of A rows (={})'
  477. ' is smaller than 3 x the number of requested eigenpairs (={})'
  478. .format(m, n))
  479. method = 'ortho' if method is None else method
  480. iparams = {
  481. 'm': m,
  482. 'n': n,
  483. 'k': k,
  484. 'niter': 1000 if niter is None else niter,
  485. }
  486. fparams = {
  487. 'tol': tol,
  488. }
  489. bparams = {
  490. 'largest': True if largest is None else largest
  491. }
  492. if method == 'ortho':
  493. if ortho_iparams is not None:
  494. iparams.update(ortho_iparams)
  495. if ortho_fparams is not None:
  496. fparams.update(ortho_fparams)
  497. if ortho_bparams is not None:
  498. bparams.update(ortho_bparams)
  499. iparams['ortho_i_max'] = iparams.get('ortho_i_max', 3)
  500. iparams['ortho_j_max'] = iparams.get('ortho_j_max', 3)
  501. fparams['ortho_tol'] = fparams.get('ortho_tol', tol)
  502. fparams['ortho_tol_drop'] = fparams.get('ortho_tol_drop', tol)
  503. fparams['ortho_tol_replace'] = fparams.get('ortho_tol_replace', tol)
  504. bparams['ortho_use_drop'] = bparams.get('ortho_use_drop', False)
  505. if not torch.jit.is_scripting():
  506. LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[assignment]
  507. if len(A.shape) > 2:
  508. N = int(torch.prod(torch.tensor(A.shape[:-2])))
  509. bA = A.reshape((N,) + A.shape[-2:])
  510. bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
  511. bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
  512. bE = torch.empty((N, k), dtype=dtype, device=device)
  513. bXret = torch.empty((N, m, k), dtype=dtype, device=device)
  514. for i in range(N):
  515. A_ = bA[i]
  516. B_ = bB[i] if bB is not None else None
  517. X_ = torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
  518. assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
  519. iparams['batch_index'] = i
  520. worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
  521. worker.run()
  522. bE[i] = worker.E[:k]
  523. bXret[i] = worker.X[:, :k]
  524. if not torch.jit.is_scripting():
  525. LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[assignment]
  526. return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))
  527. X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
  528. assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))
  529. worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
  530. worker.run()
  531. if not torch.jit.is_scripting():
  532. LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[assignment]
  533. return worker.E[:k], worker.X[:, :k]
  534. class LOBPCG(object):
  535. """Worker class of LOBPCG methods.
  536. """
  537. def __init__(self,
  538. A: Optional[Tensor],
  539. B: Optional[Tensor],
  540. X: Tensor,
  541. iK: Optional[Tensor],
  542. iparams: Dict[str, int],
  543. fparams: Dict[str, float],
  544. bparams: Dict[str, bool],
  545. method: str,
  546. tracker: None
  547. ) -> None:
  548. # constant parameters
  549. self.A = A
  550. self.B = B
  551. self.iK = iK
  552. self.iparams = iparams
  553. self.fparams = fparams
  554. self.bparams = bparams
  555. self.method = method
  556. self.tracker = tracker
  557. m = iparams['m']
  558. n = iparams['n']
  559. # variable parameters
  560. self.X = X
  561. self.E = torch.zeros((n, ), dtype=X.dtype, device=X.device)
  562. self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
  563. self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
  564. self.tvars: Dict[str, Tensor] = {}
  565. self.ivars: Dict[str, int] = {'istep': 0}
  566. self.fvars: Dict[str, float] = {'_': 0.0}
  567. self.bvars: Dict[str, bool] = {'_': False}
  568. def __str__(self):
  569. lines = ['LOPBCG:']
  570. lines += [' iparams={}'.format(self.iparams)]
  571. lines += [' fparams={}'.format(self.fparams)]
  572. lines += [' bparams={}'.format(self.bparams)]
  573. lines += [' ivars={}'.format(self.ivars)]
  574. lines += [' fvars={}'.format(self.fvars)]
  575. lines += [' bvars={}'.format(self.bvars)]
  576. lines += [' tvars={}'.format(self.tvars)]
  577. lines += [' A={}'.format(self.A)]
  578. lines += [' B={}'.format(self.B)]
  579. lines += [' iK={}'.format(self.iK)]
  580. lines += [' X={}'.format(self.X)]
  581. lines += [' E={}'.format(self.E)]
  582. r = ''
  583. for line in lines:
  584. r += line + '\n'
  585. return r
  586. def update(self):
  587. """Set and update iteration variables.
  588. """
  589. if self.ivars['istep'] == 0:
  590. X_norm = float(torch.norm(self.X))
  591. iX_norm = X_norm ** -1
  592. A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
  593. B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
  594. self.fvars['X_norm'] = X_norm
  595. self.fvars['A_norm'] = A_norm
  596. self.fvars['B_norm'] = B_norm
  597. self.ivars['iterations_left'] = self.iparams['niter']
  598. self.ivars['converged_count'] = 0
  599. self.ivars['converged_end'] = 0
  600. if self.method == 'ortho':
  601. self._update_ortho()
  602. else:
  603. self._update_basic()
  604. self.ivars['iterations_left'] = self.ivars['iterations_left'] - 1
  605. self.ivars['istep'] = self.ivars['istep'] + 1
  606. def update_residual(self):
  607. """Update residual R from A, B, X, E.
  608. """
  609. mm = _utils.matmul
  610. self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
  611. def update_converged_count(self):
  612. """Determine the number of converged eigenpairs using backward stable
  613. convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].
  614. Users may redefine this method for custom convergence criteria.
  615. """
  616. # (...) -> int
  617. prev_count = self.ivars['converged_count']
  618. tol = self.fparams['tol']
  619. A_norm = self.fvars['A_norm']
  620. B_norm = self.fvars['B_norm']
  621. E, X, R = self.E, self.X, self.R
  622. rerr = torch.norm(R, 2, (0, )) * (torch.norm(X, 2, (0, )) * (A_norm + E[:X.shape[-1]] * B_norm)) ** -1
  623. converged = rerr < tol
  624. count = 0
  625. for b in converged:
  626. if not b:
  627. # ignore convergence of following pairs to ensure
  628. # strict ordering of eigenpairs
  629. break
  630. count += 1
  631. assert count >= prev_count, 'the number of converged eigenpairs ' \
  632. '(was {}, got {}) cannot decrease'.format(prev_count, count)
  633. self.ivars['converged_count'] = count
  634. self.tvars['rerr'] = rerr
  635. return count
  636. def stop_iteration(self):
  637. """Return True to stop iterations.
  638. Note that tracker (if defined) can force-stop iterations by
  639. setting ``worker.bvars['force_stop'] = True``.
  640. """
  641. return (self.bvars.get('force_stop', False)
  642. or self.ivars['iterations_left'] == 0
  643. or self.ivars['converged_count'] >= self.iparams['k'])
  644. def run(self):
  645. """Run LOBPCG iterations.
  646. Use this method as a template for implementing LOBPCG
  647. iteration scheme with custom tracker that is compatible with
  648. TorchScript.
  649. """
  650. self.update()
  651. if not torch.jit.is_scripting() and self.tracker is not None:
  652. self.call_tracker()
  653. while not self.stop_iteration():
  654. self.update()
  655. if not torch.jit.is_scripting() and self.tracker is not None:
  656. self.call_tracker()
  657. @torch.jit.unused
  658. def call_tracker(self):
  659. """Interface for tracking iteration process in Python mode.
  660. Tracking the iteration process is disabled in TorchScript
  661. mode. In fact, one should specify tracker=None when JIT
  662. compiling functions using lobpcg.
  663. """
  664. # do nothing when in TorchScript mode
  665. pass
  666. # Internal methods
  667. def _update_basic(self):
  668. """
  669. Update or initialize iteration variables when `method == "basic"`.
  670. """
  671. mm = torch.matmul
  672. ns = self.ivars['converged_end']
  673. nc = self.ivars['converged_count']
  674. n = self.iparams['n']
  675. largest = self.bparams['largest']
  676. if self.ivars['istep'] == 0:
  677. Ri = self._get_rayleigh_ritz_transform(self.X)
  678. M = _utils.qform(_utils.qform(self.A, self.X), Ri)
  679. E, Z = _utils.symeig(M, largest)
  680. self.X[:] = mm(self.X, mm(Ri, Z))
  681. self.E[:] = E
  682. np = 0
  683. self.update_residual()
  684. nc = self.update_converged_count()
  685. self.S[..., :n] = self.X
  686. W = _utils.matmul(self.iK, self.R)
  687. self.ivars['converged_end'] = ns = n + np + W.shape[-1]
  688. self.S[:, n + np:ns] = W
  689. else:
  690. S_ = self.S[:, nc:ns]
  691. Ri = self._get_rayleigh_ritz_transform(S_)
  692. M = _utils.qform(_utils.qform(self.A, S_), Ri)
  693. E_, Z = _utils.symeig(M, largest)
  694. self.X[:, nc:] = mm(S_, mm(Ri, Z[:, :n - nc]))
  695. self.E[nc:] = E_[:n - nc]
  696. P = mm(S_, mm(Ri, Z[:, n:2 * n - nc]))
  697. np = P.shape[-1]
  698. self.update_residual()
  699. nc = self.update_converged_count()
  700. self.S[..., :n] = self.X
  701. self.S[:, n:n + np] = P
  702. W = _utils.matmul(self.iK, self.R[:, nc:])
  703. self.ivars['converged_end'] = ns = n + np + W.shape[-1]
  704. self.S[:, n + np:ns] = W
  705. def _update_ortho(self):
  706. """
  707. Update or initialize iteration variables when `method == "ortho"`.
  708. """
  709. mm = torch.matmul
  710. ns = self.ivars['converged_end']
  711. nc = self.ivars['converged_count']
  712. n = self.iparams['n']
  713. largest = self.bparams['largest']
  714. if self.ivars['istep'] == 0:
  715. Ri = self._get_rayleigh_ritz_transform(self.X)
  716. M = _utils.qform(_utils.qform(self.A, self.X), Ri)
  717. E, Z = _utils.symeig(M, largest)
  718. self.X = mm(self.X, mm(Ri, Z))
  719. self.update_residual()
  720. np = 0
  721. nc = self.update_converged_count()
  722. self.S[:, :n] = self.X
  723. W = self._get_ortho(self.R, self.X)
  724. ns = self.ivars['converged_end'] = n + np + W.shape[-1]
  725. self.S[:, n + np:ns] = W
  726. else:
  727. S_ = self.S[:, nc:ns]
  728. # Rayleigh-Ritz procedure
  729. E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)
  730. # Update E, X, P
  731. self.X[:, nc:] = mm(S_, Z[:, :n - nc])
  732. self.E[nc:] = E_[:n - nc]
  733. P = mm(S_, mm(Z[:, n - nc:], _utils.basis(_utils.transpose(Z[:n - nc, n - nc:]))))
  734. np = P.shape[-1]
  735. # check convergence
  736. self.update_residual()
  737. nc = self.update_converged_count()
  738. # update S
  739. self.S[:, :n] = self.X
  740. self.S[:, n:n + np] = P
  741. W = self._get_ortho(self.R[:, nc:], self.S[:, :n + np])
  742. ns = self.ivars['converged_end'] = n + np + W.shape[-1]
  743. self.S[:, n + np:ns] = W
  744. def _get_rayleigh_ritz_transform(self, S):
  745. """Return a transformation matrix that is used in Rayleigh-Ritz
  746. procedure for reducing a general eigenvalue problem :math:`(S^TAS)
  747. C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
  748. S^TAS Ri) Z = Z E` where `C = Ri Z`.
  749. .. note:: In the original Rayleight-Ritz procedure in
  750. [DuerschEtal2018], the problem is formulated as follows::
  751. SAS = S^T A S
  752. SBS = S^T B S
  753. D = (<diagonal matrix of SBS>) ** -1/2
  754. R^T R = Cholesky(D SBS D)
  755. Ri = D R^-1
  756. solve symeig problem Ri^T SAS Ri Z = Theta Z
  757. C = Ri Z
  758. To reduce the number of matrix products (denoted by empty
  759. space between matrices), here we introduce element-wise
  760. products (denoted by symbol `*`) so that the Rayleight-Ritz
  761. procedure becomes::
  762. SAS = S^T A S
  763. SBS = S^T B S
  764. d = (<diagonal of SBS>) ** -1/2 # this is 1-d column vector
  765. dd = d d^T # this is 2-d matrix
  766. R^T R = Cholesky(dd * SBS)
  767. Ri = R^-1 * d # broadcasting
  768. solve symeig problem Ri^T SAS Ri Z = Theta Z
  769. C = Ri Z
  770. where `dd` is 2-d matrix that replaces matrix products `D M
  771. D` with one element-wise product `M * dd`; and `d` replaces
  772. matrix product `D M` with element-wise product `M *
  773. d`. Also, creating the diagonal matrix `D` is avoided.
  774. Args:
  775. S (Tensor): the matrix basis for the search subspace, size is
  776. :math:`(m, n)`.
  777. Returns:
  778. Ri (tensor): upper-triangular transformation matrix of size
  779. :math:`(n, n)`.
  780. """
  781. B = self.B
  782. mm = torch.matmul
  783. SBS = _utils.qform(B, S)
  784. d_row = SBS.diagonal(0, -2, -1) ** -0.5
  785. d_col = d_row.reshape(d_row.shape[0], 1)
  786. # TODO use torch.linalg.cholesky_solve once it is implemented
  787. R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
  788. return torch.linalg.solve_triangular(R, d_row.diag_embed(), upper=True, left=False)
  789. def _get_svqb(self,
  790. U: Tensor, # Tensor
  791. drop: bool, # bool
  792. tau: float # float
  793. ) -> Tensor:
  794. """Return B-orthonormal U.
  795. .. note:: When `drop` is `False` then `svqb` is based on the
  796. Algorithm 4 from [DuerschPhD2015] that is a slight
  797. modification of the corresponding algorithm
  798. introduced in [StathopolousWu2002].
  799. Args:
  800. U (Tensor) : initial approximation, size is (m, n)
  801. drop (bool) : when True, drop columns that
  802. contribution to the `span([U])` is small.
  803. tau (float) : positive tolerance
  804. Returns:
  805. U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
  806. is (m, n1), where `n1 = n` if `drop` is `False,
  807. otherwise `n1 <= n`.
  808. """
  809. if torch.numel(U) == 0:
  810. return U
  811. UBU = _utils.qform(self.B, U)
  812. d = UBU.diagonal(0, -2, -1)
  813. # Detect and drop exact zero columns from U. While the test
  814. # `abs(d) == 0` is unlikely to be True for random data, it is
  815. # possible to construct input data to lobpcg where it will be
  816. # True leading to a failure (notice the `d ** -0.5` operation
  817. # in the original algorithm). To prevent the failure, we drop
  818. # the exact zero columns here and then continue with the
  819. # original algorithm below.
  820. nz = torch.where(abs(d) != 0.0)
  821. assert len(nz) == 1, nz
  822. if len(nz[0]) < len(d):
  823. U = U[:, nz[0]]
  824. if torch.numel(U) == 0:
  825. return U
  826. UBU = _utils.qform(self.B, U)
  827. d = UBU.diagonal(0, -2, -1)
  828. nz = torch.where(abs(d) != 0.0)
  829. assert len(nz[0]) == len(d)
  830. # The original algorithm 4 from [DuerschPhD2015].
  831. d_col = (d ** -0.5).reshape(d.shape[0], 1)
  832. DUBUD = (UBU * d_col) * _utils.transpose(d_col)
  833. E, Z = _utils.symeig(DUBUD)
  834. t = tau * abs(E).max()
  835. if drop:
  836. keep = torch.where(E > t)
  837. assert len(keep) == 1, keep
  838. E = E[keep[0]]
  839. Z = Z[:, keep[0]]
  840. d_col = d_col[keep[0]]
  841. else:
  842. E[(torch.where(E < t))[0]] = t
  843. return torch.matmul(U * _utils.transpose(d_col), Z * E ** -0.5)
    def _get_ortho(self, U, V):
        """Return B-orthonormal U whose columns are B-orthogonal to V.

        .. note:: When `bparams["ortho_use_drop"] == False` then
                  `_get_ortho` is based on the Algorithm 3 from
                  [DuerschPhD2015] that is a slight modification of
                  the corresponding algorithm introduced in
                  [StathopolousWu2002]. Otherwise, the method
                  implements Algorithm 6 from [DuerschPhD2015]

        .. note:: If all U columns are B-collinear to V then the
                  returned tensor U will be empty.

        The relative errors measured during the iterations are stored
        in `fvars` under keys of the form `ortho_UBUmI_rerr[i, j]` and
        `ortho_VBU_rerr[i]`; the last outer and inner iteration indices
        are stored in `ivars['ortho_i']` and `ivars['ortho_j']`.

        Args:

          U (Tensor) : initial approximation, size is (m, n)
          V (Tensor) : B-orthogonal external basis, size is (m, k)

        Returns:

          U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
                       such that :math:`V^T B U=0`, size is (m, n1),
                       where `n1 = n` if `drop` is `False`, otherwise
                       `n1 <= n`.

        Raises:

          ValueError : when an outer iteration ends without reaching
                       `ortho_tol` and `m` is smaller than the number
                       of U columns plus the number of V columns.
        """
        mm = torch.matmul
        mm_B = _utils.matmul
        m = self.iparams['m']
        tau_ortho = self.fparams['ortho_tol']
        tau_drop = self.fparams['ortho_tol_drop']
        tau_replace = self.fparams['ortho_tol_replace']
        i_max = self.iparams['ortho_i_max']
        j_max = self.iparams['ortho_j_max']
        # when use_drop==True, enable dropping U columns that have
        # small contribution to the `span([U, V])`.
        use_drop = self.bparams['ortho_use_drop']

        # clean up variables from the previous call
        for vkey in list(self.fvars.keys()):
            if vkey.startswith('ortho_') and vkey.endswith('_rerr'):
                self.fvars.pop(vkey)
        self.ivars.pop('ortho_i', 0)
        self.ivars.pop('ortho_j', 0)

        BV_norm = torch.norm(mm_B(self.B, V))
        BU = mm_B(self.B, U)
        VBU = mm(_utils.transpose(V), BU)
        # i and j are pre-set so that they are defined after the loops
        # even when i_max or j_max is zero
        i = j = 0
        stats = ''
        for i in range(i_max):
            # remove from U its component along V: U <- U - V (V^T B U)
            U = U - mm(V, VBU)
            drop = False
            tau_svqb = tau_drop
            for j in range(j_max):
                if use_drop:
                    # NOTE(review): the first pass calls svqb with
                    # drop=False and tau_drop, later passes with
                    # drop=True and tau_replace -- confirm this pairing
                    # against Algorithm 6 of [DuerschPhD2015].
                    U = self._get_svqb(U, drop, tau_svqb)
                    drop = True
                    tau_svqb = tau_replace
                else:
                    U = self._get_svqb(U, False, tau_replace)
                if torch.numel(U) == 0:
                    # all initial U columns are B-collinear to V
                    self.ivars['ortho_i'] = i
                    self.ivars['ortho_j'] = j
                    return U
                BU = mm_B(self.B, U)
                UBU = mm(_utils.transpose(U), BU)
                U_norm = torch.norm(U)
                BU_norm = torch.norm(BU)
                # deviation of U^T B U from the identity
                R = UBU - torch.eye(UBU.shape[-1],
                                    device=UBU.device,
                                    dtype=UBU.dtype)
                R_norm = torch.norm(R)
                # https://github.com/pytorch/pytorch/issues/33810 workaround:
                rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
                vkey = 'ortho_UBUmI_rerr[{}, {}]'.format(i, j)
                self.fvars[vkey] = rerr
                if rerr < tau_ortho:
                    # U is B-orthonormal within tolerance
                    break
            # relative size of V^T B U, reusing BU from the inner loop
            VBU = mm(_utils.transpose(V), BU)
            VBU_norm = torch.norm(VBU)
            U_norm = torch.norm(U)
            rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
            vkey = 'ortho_VBU_rerr[{}]'.format(i)
            self.fvars[vkey] = rerr
            if rerr < tau_ortho:
                # U is B-orthogonal to V within tolerance
                break
            if m < U.shape[-1] + V.shape[-1]:
                # TorchScript needs the class var to be assigned to a local to
                # do optional type refinement
                B = self.B
                assert B is not None
                raise ValueError(
                    'Overdetermined shape of U:'
                    ' #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold'
                    .format(B.shape[-1], U.shape[-1], V.shape[-1]))
        self.ivars['ortho_i'] = i
        self.ivars['ortho_j'] = j
        return U
# Calling the tracker is kept separate from the LOBPCG class definition
# because TorchScript does not support user-defined callback arguments.
# Keep a reference to the `call_tracker` method as defined in the class.
LOBPCG_call_tracker_orig = LOBPCG.call_tracker
  938. def LOBPCG_call_tracker(self):
  939. self.tracker(self)