independent.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import torch
  2. from torch.distributions import constraints
  3. from torch.distributions.distribution import Distribution
  4. from torch.distributions.utils import _sum_rightmost
  5. from typing import Dict
  6. class Independent(Distribution):
  7. r"""
  8. Reinterprets some of the batch dims of a distribution as event dims.
  9. This is mainly useful for changing the shape of the result of
  10. :meth:`log_prob`. For example to create a diagonal Normal distribution with
  11. the same shape as a Multivariate Normal distribution (so they are
  12. interchangeable), you can::
  13. >>> loc = torch.zeros(3)
  14. >>> scale = torch.ones(3)
  15. >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
  16. >>> [mvn.batch_shape, mvn.event_shape]
  17. [torch.Size(()), torch.Size((3,))]
  18. >>> normal = Normal(loc, scale)
  19. >>> [normal.batch_shape, normal.event_shape]
  20. [torch.Size((3,)), torch.Size(())]
  21. >>> diagn = Independent(normal, 1)
  22. >>> [diagn.batch_shape, diagn.event_shape]
  23. [torch.Size(()), torch.Size((3,))]
  24. Args:
  25. base_distribution (torch.distributions.distribution.Distribution): a
  26. base distribution
  27. reinterpreted_batch_ndims (int): the number of batch dims to
  28. reinterpret as event dims
  29. """
  30. arg_constraints: Dict[str, constraints.Constraint] = {}
  31. def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
  32. if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
  33. raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
  34. "actual {} vs {}".format(reinterpreted_batch_ndims,
  35. len(base_distribution.batch_shape)))
  36. shape = base_distribution.batch_shape + base_distribution.event_shape
  37. event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
  38. batch_shape = shape[:len(shape) - event_dim]
  39. event_shape = shape[len(shape) - event_dim:]
  40. self.base_dist = base_distribution
  41. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  42. super(Independent, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  43. def expand(self, batch_shape, _instance=None):
  44. new = self._get_checked_instance(Independent, _instance)
  45. batch_shape = torch.Size(batch_shape)
  46. new.base_dist = self.base_dist.expand(batch_shape +
  47. self.event_shape[:self.reinterpreted_batch_ndims])
  48. new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
  49. super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False)
  50. new._validate_args = self._validate_args
  51. return new
  52. @property
  53. def has_rsample(self):
  54. return self.base_dist.has_rsample
  55. @property
  56. def has_enumerate_support(self):
  57. if self.reinterpreted_batch_ndims > 0:
  58. return False
  59. return self.base_dist.has_enumerate_support
  60. @constraints.dependent_property
  61. def support(self):
  62. result = self.base_dist.support
  63. if self.reinterpreted_batch_ndims:
  64. result = constraints.independent(result, self.reinterpreted_batch_ndims)
  65. return result
  66. @property
  67. def mean(self):
  68. return self.base_dist.mean
  69. @property
  70. def mode(self):
  71. return self.base_dist.mode
  72. @property
  73. def variance(self):
  74. return self.base_dist.variance
  75. def sample(self, sample_shape=torch.Size()):
  76. return self.base_dist.sample(sample_shape)
  77. def rsample(self, sample_shape=torch.Size()):
  78. return self.base_dist.rsample(sample_shape)
  79. def log_prob(self, value):
  80. log_prob = self.base_dist.log_prob(value)
  81. return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
  82. def entropy(self):
  83. entropy = self.base_dist.entropy()
  84. return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
  85. def enumerate_support(self, expand=True):
  86. if self.reinterpreted_batch_ndims > 0:
  87. raise NotImplementedError("Enumeration over cartesian product is not implemented")
  88. return self.base_dist.enumerate_support(expand=expand)
  89. def __repr__(self):
  90. return self.__class__.__name__ + '({}, {})'.format(self.base_dist, self.reinterpreted_batch_ndims)