lkj_cholesky.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. """
  2. This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro).
  3. Original copyright notice:
  4. # Copyright: Contributors to the Pyro project.
  5. # SPDX-License-Identifier: Apache-2.0
  6. """
  7. import math
  8. import torch
  9. from torch.distributions import constraints, Beta
  10. from torch.distributions.distribution import Distribution
  11. from torch.distributions.utils import broadcast_all
  12. class LKJCholesky(Distribution):
  13. r"""
  14. LKJ distribution for lower Cholesky factor of correlation matrices.
  15. The distribution is controlled by ``concentration`` parameter :math:`\eta`
  16. to make the probability of the correlation matrix :math:`M` generated from
  17. a Cholesky factor propotional to :math:`\det(M)^{\eta - 1}`. Because of that,
  18. when ``concentration == 1``, we have a uniform distribution over Cholesky
  19. factors of correlation matrices. Note that this distribution samples the
  20. Cholesky factor of correlation matrices and not the correlation matrices
  21. themselves and thereby differs slightly from the derivations in [1] for
  22. the `LKJCorr` distribution. For sampling, this uses the Onion method from
  23. [1] Section 3.
  24. L ~ LKJCholesky(dim, concentration)
  25. X = L @ L' ~ LKJCorr(dim, concentration)
  26. Example::
  27. >>> l = LKJCholesky(3, 0.5)
  28. >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix
  29. tensor([[ 1.0000, 0.0000, 0.0000],
  30. [ 0.3516, 0.9361, 0.0000],
  31. [-0.1899, 0.4748, 0.8593]])
  32. Args:
  33. dimension (dim): dimension of the matrices
  34. concentration (float or Tensor): concentration/shape parameter of the
  35. distribution (often referred to as eta)
  36. **References**
  37. [1] `Generating random correlation matrices based on vines and extended onion method`,
  38. Daniel Lewandowski, Dorota Kurowicka, Harry Joe.
  39. """
  40. arg_constraints = {'concentration': constraints.positive}
  41. support = constraints.corr_cholesky
  42. def __init__(self, dim, concentration=1., validate_args=None):
  43. if dim < 2:
  44. raise ValueError(f'Expected dim to be an integer greater than or equal to 2. Found dim={dim}.')
  45. self.dim = dim
  46. self.concentration, = broadcast_all(concentration)
  47. batch_shape = self.concentration.size()
  48. event_shape = torch.Size((dim, dim))
  49. # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1].
  50. marginal_conc = self.concentration + 0.5 * (self.dim - 2)
  51. offset = torch.arange(self.dim - 1, dtype=self.concentration.dtype, device=self.concentration.device)
  52. offset = torch.cat([offset.new_zeros((1,)), offset])
  53. beta_conc1 = offset + 0.5
  54. beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset
  55. self._beta = Beta(beta_conc1, beta_conc0)
  56. super(LKJCholesky, self).__init__(batch_shape, event_shape, validate_args)
  57. def expand(self, batch_shape, _instance=None):
  58. new = self._get_checked_instance(LKJCholesky, _instance)
  59. batch_shape = torch.Size(batch_shape)
  60. new.dim = self.dim
  61. new.concentration = self.concentration.expand(batch_shape)
  62. new._beta = self._beta.expand(batch_shape + (self.dim,))
  63. super(LKJCholesky, new).__init__(batch_shape, self.event_shape, validate_args=False)
  64. new._validate_args = self._validate_args
  65. return new
  66. def sample(self, sample_shape=torch.Size()):
  67. # This uses the Onion method, but there are a few differences from [1] Sec. 3.2:
  68. # - This vectorizes the for loop and also works for heterogeneous eta.
  69. # - Same algorithm generalizes to n=1.
  70. # - The procedure is simplified since we are sampling the cholesky factor of
  71. # the correlation matrix instead of the correlation matrix itself. As such,
  72. # we only need to generate `w`.
  73. y = self._beta.sample(sample_shape).unsqueeze(-1)
  74. u_normal = torch.randn(self._extended_shape(sample_shape),
  75. dtype=y.dtype,
  76. device=y.device).tril(-1)
  77. u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True)
  78. # Replace NaNs in first row
  79. u_hypersphere[..., 0, :].fill_(0.)
  80. w = torch.sqrt(y) * u_hypersphere
  81. # Fill diagonal elements; clamp for numerical stability
  82. eps = torch.finfo(w.dtype).tiny
  83. diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt()
  84. w += torch.diag_embed(diag_elems)
  85. return w
  86. def log_prob(self, value):
  87. # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html
  88. # The probability of a correlation matrix is proportional to
  89. # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1))
  90. # Additionally, the Jacobian of the transformation from Cholesky factor to
  91. # correlation matrix is:
  92. # prod(L_ii ^ (D - i))
  93. # So the probability of a Cholesky factor is propotional to
  94. # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i)
  95. # with order_i = 2 * concentration - 2 + D - i
  96. if self._validate_args:
  97. self._validate_sample(value)
  98. diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:]
  99. order = torch.arange(2, self.dim + 1, device=self.concentration.device)
  100. order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order
  101. unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1)
  102. # Compute normalization constant (page 1999 of [1])
  103. dm1 = self.dim - 1
  104. alpha = self.concentration + 0.5 * dm1
  105. denominator = torch.lgamma(alpha) * dm1
  106. numerator = torch.mvlgamma(alpha - 0.5, dm1)
  107. # pi_constant in [1] is D * (D - 1) / 4 * log(pi)
  108. # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi)
  109. # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2
  110. pi_constant = 0.5 * dm1 * math.log(math.pi)
  111. normalize_term = pi_constant + numerator - denominator
  112. return unnormalized_log_pdf - normalize_term