wishart.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. import math
  2. import warnings
  3. from numbers import Number
  4. from typing import Union
  5. import torch
  6. from torch._six import nan
  7. from torch.distributions import constraints
  8. from torch.distributions.exp_family import ExponentialFamily
  9. from torch.distributions.utils import lazy_property
  10. from torch.distributions.multivariate_normal import _precision_to_scale_tril
# Natural logarithm of 2, evaluated once at import time; reused by the
# log-density, entropy and log-normalizer formulas of the Wishart class.
_log_2 = math.log(2)
  12. def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
  13. assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
  14. return torch.digamma(
  15. x.unsqueeze(-1)
  16. - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
  17. ).sum(-1)
  18. def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
  19. # We assume positive input for this function
  20. return x.clamp(min=torch.finfo(x.dtype).eps)
  21. class Wishart(ExponentialFamily):
  22. r"""
  23. Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
  24. or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`
  25. Example:
  26. >>> m = Wishart(torch.eye(2), torch.Tensor([2]))
  27. >>> m.sample() # Wishart distributed with mean=`df * I` and
  28. # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j
  29. Args:
  30. covariance_matrix (Tensor): positive-definite covariance matrix
  31. precision_matrix (Tensor): positive-definite precision matrix
  32. scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
  33. df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
  34. Note:
  35. Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
  36. :attr:`scale_tril` can be specified.
  37. Using :attr:`scale_tril` will be more efficient: all computations internally
  38. are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
  39. :attr:`precision_matrix` is passed instead, it is only used to compute
  40. the corresponding lower triangular matrices using a Cholesky decomposition.
  41. 'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]
  42. **References**
  43. [1] `On equivalence of the LKJ distribution and the restricted Wishart distribution`,
  44. Zhenxun Wang, Yunan Wu, Haitao Chu.
  45. """
  46. arg_constraints = {
  47. 'covariance_matrix': constraints.positive_definite,
  48. 'precision_matrix': constraints.positive_definite,
  49. 'scale_tril': constraints.lower_cholesky,
  50. 'df': constraints.greater_than(0),
  51. }
  52. support = constraints.positive_definite
  53. has_rsample = True
  54. _mean_carrier_measure = 0
  55. def __init__(self,
  56. df: Union[torch.Tensor, Number],
  57. covariance_matrix: torch.Tensor = None,
  58. precision_matrix: torch.Tensor = None,
  59. scale_tril: torch.Tensor = None,
  60. validate_args=None):
  61. assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
  62. "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."
  63. param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)
  64. if param.dim() < 2:
  65. raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")
  66. if isinstance(df, Number):
  67. batch_shape = torch.Size(param.shape[:-2])
  68. self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
  69. else:
  70. batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
  71. self.df = df.expand(batch_shape)
  72. event_shape = param.shape[-2:]
  73. if self.df.le(event_shape[-1] - 1).any():
  74. raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.")
  75. if scale_tril is not None:
  76. self.scale_tril = param.expand(batch_shape + (-1, -1))
  77. elif covariance_matrix is not None:
  78. self.covariance_matrix = param.expand(batch_shape + (-1, -1))
  79. elif precision_matrix is not None:
  80. self.precision_matrix = param.expand(batch_shape + (-1, -1))
  81. self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
  82. if self.df.lt(event_shape[-1]).any():
  83. warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")
  84. super(Wishart, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  85. self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]
  86. if scale_tril is not None:
  87. self._unbroadcasted_scale_tril = scale_tril
  88. elif covariance_matrix is not None:
  89. self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
  90. else: # precision_matrix is not None
  91. self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
  92. # Chi2 distribution is needed for Bartlett decomposition sampling
  93. self._dist_chi2 = torch.distributions.chi2.Chi2(
  94. df=(
  95. self.df.unsqueeze(-1)
  96. - torch.arange(
  97. self._event_shape[-1],
  98. dtype=self._unbroadcasted_scale_tril.dtype,
  99. device=self._unbroadcasted_scale_tril.device,
  100. ).expand(batch_shape + (-1,))
  101. )
  102. )
  103. def expand(self, batch_shape, _instance=None):
  104. new = self._get_checked_instance(Wishart, _instance)
  105. batch_shape = torch.Size(batch_shape)
  106. cov_shape = batch_shape + self.event_shape
  107. new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
  108. new.df = self.df.expand(batch_shape)
  109. new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]
  110. if 'covariance_matrix' in self.__dict__:
  111. new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
  112. if 'scale_tril' in self.__dict__:
  113. new.scale_tril = self.scale_tril.expand(cov_shape)
  114. if 'precision_matrix' in self.__dict__:
  115. new.precision_matrix = self.precision_matrix.expand(cov_shape)
  116. # Chi2 distribution is needed for Bartlett decomposition sampling
  117. new._dist_chi2 = torch.distributions.chi2.Chi2(
  118. df=(
  119. new.df.unsqueeze(-1)
  120. - torch.arange(
  121. self.event_shape[-1],
  122. dtype=new._unbroadcasted_scale_tril.dtype,
  123. device=new._unbroadcasted_scale_tril.device,
  124. ).expand(batch_shape + (-1,))
  125. )
  126. )
  127. super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
  128. new._validate_args = self._validate_args
  129. return new
  130. @lazy_property
  131. def scale_tril(self):
  132. return self._unbroadcasted_scale_tril.expand(
  133. self._batch_shape + self._event_shape)
  134. @lazy_property
  135. def covariance_matrix(self):
  136. return (
  137. self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(-2, -1)
  138. ).expand(self._batch_shape + self._event_shape)
  139. @lazy_property
  140. def precision_matrix(self):
  141. identity = torch.eye(
  142. self._event_shape[-1],
  143. device=self._unbroadcasted_scale_tril.device,
  144. dtype=self._unbroadcasted_scale_tril.dtype,
  145. )
  146. return torch.cholesky_solve(
  147. identity, self._unbroadcasted_scale_tril
  148. ).expand(self._batch_shape + self._event_shape)
  149. @property
  150. def mean(self):
  151. return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix
  152. @property
  153. def mode(self):
  154. factor = self.df - self.covariance_matrix.shape[-1] - 1
  155. factor[factor <= 0] = nan
  156. return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix
  157. @property
  158. def variance(self):
  159. V = self.covariance_matrix # has shape (batch_shape x event_shape)
  160. diag_V = V.diagonal(dim1=-2, dim2=-1)
  161. return self.df.view(self._batch_shape + (1, 1)) * (V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V))
  162. def _bartlett_sampling(self, sample_shape=torch.Size()):
  163. p = self._event_shape[-1] # has singleton shape
  164. # Implemented Sampling using Bartlett decomposition
  165. noise = _clamp_above_eps(
  166. self._dist_chi2.rsample(sample_shape).sqrt()
  167. ).diag_embed(dim1=-2, dim2=-1)
  168. i, j = torch.tril_indices(p, p, offset=-1)
  169. noise[..., i, j] = torch.randn(
  170. torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
  171. dtype=noise.dtype,
  172. device=noise.device,
  173. )
  174. chol = self._unbroadcasted_scale_tril @ noise
  175. return chol @ chol.transpose(-2, -1)
  176. def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
  177. r"""
  178. .. warning::
  179. In some cases, sampling algorithn based on Bartlett decomposition may return singular matrix samples.
  180. Several tries to correct singular samples are performed by default, but it may end up returning
  181. singular matrix samples. Sigular samples may return `-inf` values in `.log_prob()`.
  182. In those cases, the user should validate the samples and either fix the value of `df`
  183. or adjust `max_try_correction` value for argument in `.rsample` accordingly.
  184. """
  185. if max_try_correction is None:
  186. max_try_correction = 3 if torch._C._get_tracing_state() else 10
  187. sample_shape = torch.Size(sample_shape)
  188. sample = self._bartlett_sampling(sample_shape)
  189. # Below part is to improve numerical stability temporally and should be removed in the future
  190. is_singular = self.support.check(sample)
  191. if self._batch_shape:
  192. is_singular = is_singular.amax(self._batch_dims)
  193. if torch._C._get_tracing_state():
  194. # Less optimized version for JIT
  195. for _ in range(max_try_correction):
  196. sample_new = self._bartlett_sampling(sample_shape)
  197. sample = torch.where(is_singular, sample_new, sample)
  198. is_singular = ~self.support.check(sample)
  199. if self._batch_shape:
  200. is_singular = is_singular.amax(self._batch_dims)
  201. else:
  202. # More optimized version with data-dependent control flow.
  203. if is_singular.any():
  204. warnings.warn("Singular sample detected.")
  205. for _ in range(max_try_correction):
  206. sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
  207. sample[is_singular] = sample_new
  208. is_singular_new = ~self.support.check(sample_new)
  209. if self._batch_shape:
  210. is_singular_new = is_singular_new.amax(self._batch_dims)
  211. is_singular[is_singular.clone()] = is_singular_new
  212. if not is_singular.any():
  213. break
  214. return sample
  215. def log_prob(self, value):
  216. if self._validate_args:
  217. self._validate_sample(value)
  218. nu = self.df # has shape (batch_shape)
  219. p = self._event_shape[-1] # has singleton shape
  220. return (
  221. - nu * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
  222. - torch.mvlgamma(nu / 2, p=p)
  223. + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
  224. - torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2
  225. )
  226. def entropy(self):
  227. nu = self.df # has shape (batch_shape)
  228. p = self._event_shape[-1] # has singleton shape
  229. V = self.covariance_matrix # has shape (batch_shape x event_shape)
  230. return (
  231. (p + 1) * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
  232. + torch.mvlgamma(nu / 2, p=p)
  233. - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
  234. + nu * p / 2
  235. )
  236. @property
  237. def _natural_params(self):
  238. nu = self.df # has shape (batch_shape)
  239. p = self._event_shape[-1] # has singleton shape
  240. return - self.precision_matrix / 2, (nu - p - 1) / 2
  241. def _log_normalizer(self, x, y):
  242. p = self._event_shape[-1]
  243. return (
  244. (y + (p + 1) / 2) * (- torch.linalg.slogdet(- 2 * x).logabsdet + _log_2 * p)
  245. + torch.mvlgamma(y + (p + 1) / 2, p=p)
  246. )