poisson.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from numbers import Number
  2. import torch
  3. from torch.distributions import constraints
  4. from torch.distributions.exp_family import ExponentialFamily
  5. from torch.distributions.utils import broadcast_all
  6. class Poisson(ExponentialFamily):
  7. r"""
  8. Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
  9. Samples are nonnegative integers, with a pmf given by
  10. .. math::
  11. \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
  12. Example::
  13. >>> m = Poisson(torch.tensor([4]))
  14. >>> m.sample()
  15. tensor([ 3.])
  16. Args:
  17. rate (Number, Tensor): the rate parameter
  18. """
  19. arg_constraints = {'rate': constraints.nonnegative}
  20. support = constraints.nonnegative_integer
  21. @property
  22. def mean(self):
  23. return self.rate
  24. @property
  25. def mode(self):
  26. return self.rate.floor()
  27. @property
  28. def variance(self):
  29. return self.rate
  30. def __init__(self, rate, validate_args=None):
  31. self.rate, = broadcast_all(rate)
  32. if isinstance(rate, Number):
  33. batch_shape = torch.Size()
  34. else:
  35. batch_shape = self.rate.size()
  36. super(Poisson, self).__init__(batch_shape, validate_args=validate_args)
  37. def expand(self, batch_shape, _instance=None):
  38. new = self._get_checked_instance(Poisson, _instance)
  39. batch_shape = torch.Size(batch_shape)
  40. new.rate = self.rate.expand(batch_shape)
  41. super(Poisson, new).__init__(batch_shape, validate_args=False)
  42. new._validate_args = self._validate_args
  43. return new
  44. def sample(self, sample_shape=torch.Size()):
  45. shape = self._extended_shape(sample_shape)
  46. with torch.no_grad():
  47. return torch.poisson(self.rate.expand(shape))
  48. def log_prob(self, value):
  49. if self._validate_args:
  50. self._validate_sample(value)
  51. rate, value = broadcast_all(self.rate, value)
  52. return value.xlogy(rate) - rate - (value + 1).lgamma()
  53. @property
  54. def _natural_params(self):
  55. return (torch.log(self.rate), )
  56. def _log_normalizer(self, x):
  57. return torch.exp(x)