multinomial.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import torch
  2. from torch._six import inf
  3. from torch.distributions.binomial import Binomial
  4. from torch.distributions.distribution import Distribution
  5. from torch.distributions import Categorical
  6. from torch.distributions import constraints
  7. from torch.distributions.utils import broadcast_all
  8. class Multinomial(Distribution):
  9. r"""
  10. Creates a Multinomial distribution parameterized by :attr:`total_count` and
  11. either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
  12. :attr:`probs` indexes over categories. All other dimensions index over batches.
  13. Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
  14. called (see example below)
  15. .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
  16. and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
  17. will return this normalized value.
  18. The `logits` argument will be interpreted as unnormalized log probabilities
  19. and can therefore be any real number. It will likewise be normalized so that
  20. the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
  21. will return this normalized value.
  22. - :meth:`sample` requires a single shared `total_count` for all
  23. parameters and samples.
  24. - :meth:`log_prob` allows different `total_count` for each parameter and
  25. sample.
  26. Example::
  27. >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
  28. >>> x = m.sample() # equal probability of 0, 1, 2, 3
  29. tensor([ 21., 24., 30., 25.])
  30. >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
  31. tensor([-4.1338])
  32. Args:
  33. total_count (int): number of trials
  34. probs (Tensor): event probabilities
  35. logits (Tensor): event log probabilities (unnormalized)
  36. """
  37. arg_constraints = {'probs': constraints.simplex,
  38. 'logits': constraints.real_vector}
  39. total_count: int
  40. @property
  41. def mean(self):
  42. return self.probs * self.total_count
  43. @property
  44. def variance(self):
  45. return self.total_count * self.probs * (1 - self.probs)
  46. def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
  47. if not isinstance(total_count, int):
  48. raise NotImplementedError('inhomogeneous total_count is not supported')
  49. self.total_count = total_count
  50. self._categorical = Categorical(probs=probs, logits=logits)
  51. self._binomial = Binomial(total_count=total_count, probs=self.probs)
  52. batch_shape = self._categorical.batch_shape
  53. event_shape = self._categorical.param_shape[-1:]
  54. super(Multinomial, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  55. def expand(self, batch_shape, _instance=None):
  56. new = self._get_checked_instance(Multinomial, _instance)
  57. batch_shape = torch.Size(batch_shape)
  58. new.total_count = self.total_count
  59. new._categorical = self._categorical.expand(batch_shape)
  60. super(Multinomial, new).__init__(batch_shape, self.event_shape, validate_args=False)
  61. new._validate_args = self._validate_args
  62. return new
  63. def _new(self, *args, **kwargs):
  64. return self._categorical._new(*args, **kwargs)
  65. @constraints.dependent_property(is_discrete=True, event_dim=1)
  66. def support(self):
  67. return constraints.multinomial(self.total_count)
  68. @property
  69. def logits(self):
  70. return self._categorical.logits
  71. @property
  72. def probs(self):
  73. return self._categorical.probs
  74. @property
  75. def param_shape(self):
  76. return self._categorical.param_shape
  77. def sample(self, sample_shape=torch.Size()):
  78. sample_shape = torch.Size(sample_shape)
  79. samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
  80. # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
  81. # (sample_shape, batch_shape, total_count)
  82. shifted_idx = list(range(samples.dim()))
  83. shifted_idx.append(shifted_idx.pop(0))
  84. samples = samples.permute(*shifted_idx)
  85. counts = samples.new(self._extended_shape(sample_shape)).zero_()
  86. counts.scatter_add_(-1, samples, torch.ones_like(samples))
  87. return counts.type_as(self.probs)
  88. def entropy(self):
  89. n = torch.tensor(self.total_count)
  90. cat_entropy = self._categorical.entropy()
  91. term1 = n * cat_entropy - torch.lgamma(n + 1)
  92. support = self._binomial.enumerate_support(expand=False)[1:]
  93. binomial_probs = torch.exp(self._binomial.log_prob(support))
  94. weights = torch.lgamma(support + 1)
  95. term2 = (binomial_probs * weights).sum([0, -1])
  96. return term1 + term2
  97. def log_prob(self, value):
  98. if self._validate_args:
  99. self._validate_sample(value)
  100. logits, value = broadcast_all(self.logits, value)
  101. logits = logits.clone(memory_format=torch.contiguous_format)
  102. log_factorial_n = torch.lgamma(value.sum(-1) + 1)
  103. log_factorial_xs = torch.lgamma(value + 1).sum(-1)
  104. logits[(value == 0) & (logits == -inf)] = 0
  105. log_powers = (logits * value).sum(-1)
  106. return log_factorial_n - log_factorial_xs + log_powers