gamma.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. from numbers import Number
  2. import torch
  3. from torch.distributions import constraints
  4. from torch.distributions.exp_family import ExponentialFamily
  5. from torch.distributions.utils import broadcast_all
  6. def _standard_gamma(concentration):
  7. return torch._standard_gamma(concentration)
  8. class Gamma(ExponentialFamily):
  9. r"""
  10. Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`.
  11. Example::
  12. >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
  13. >>> m.sample() # Gamma distributed with concentration=1 and rate=1
  14. tensor([ 0.1046])
  15. Args:
  16. concentration (float or Tensor): shape parameter of the distribution
  17. (often referred to as alpha)
  18. rate (float or Tensor): rate = 1 / scale of the distribution
  19. (often referred to as beta)
  20. """
  21. arg_constraints = {'concentration': constraints.positive, 'rate': constraints.positive}
  22. support = constraints.nonnegative
  23. has_rsample = True
  24. _mean_carrier_measure = 0
  25. @property
  26. def mean(self):
  27. return self.concentration / self.rate
  28. @property
  29. def mode(self):
  30. return ((self.concentration - 1) / self.rate).clamp(min=0)
  31. @property
  32. def variance(self):
  33. return self.concentration / self.rate.pow(2)
  34. def __init__(self, concentration, rate, validate_args=None):
  35. self.concentration, self.rate = broadcast_all(concentration, rate)
  36. if isinstance(concentration, Number) and isinstance(rate, Number):
  37. batch_shape = torch.Size()
  38. else:
  39. batch_shape = self.concentration.size()
  40. super(Gamma, self).__init__(batch_shape, validate_args=validate_args)
  41. def expand(self, batch_shape, _instance=None):
  42. new = self._get_checked_instance(Gamma, _instance)
  43. batch_shape = torch.Size(batch_shape)
  44. new.concentration = self.concentration.expand(batch_shape)
  45. new.rate = self.rate.expand(batch_shape)
  46. super(Gamma, new).__init__(batch_shape, validate_args=False)
  47. new._validate_args = self._validate_args
  48. return new
  49. def rsample(self, sample_shape=torch.Size()):
  50. shape = self._extended_shape(sample_shape)
  51. value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
  52. value.detach().clamp_(min=torch.finfo(value.dtype).tiny) # do not record in autograd graph
  53. return value
  54. def log_prob(self, value):
  55. value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
  56. if self._validate_args:
  57. self._validate_sample(value)
  58. return (torch.xlogy(self.concentration, self.rate) +
  59. torch.xlogy(self.concentration - 1, value) -
  60. self.rate * value - torch.lgamma(self.concentration))
  61. def entropy(self):
  62. return (self.concentration - torch.log(self.rate) + torch.lgamma(self.concentration) +
  63. (1.0 - self.concentration) * torch.digamma(self.concentration))
  64. @property
  65. def _natural_params(self):
  66. return (self.concentration - 1, -self.rate)
  67. def _log_normalizer(self, x, y):
  68. return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())