gumbel.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. from numbers import Number
  2. import math
  3. import torch
  4. from torch.distributions import constraints
  5. from torch.distributions.uniform import Uniform
  6. from torch.distributions.transformed_distribution import TransformedDistribution
  7. from torch.distributions.transforms import AffineTransform, ExpTransform
  8. from torch.distributions.utils import broadcast_all, euler_constant
  9. class Gumbel(TransformedDistribution):
  10. r"""
  11. Samples from a Gumbel Distribution.
  12. Examples::
  13. >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
  14. >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2
  15. tensor([ 1.0124])
  16. Args:
  17. loc (float or Tensor): Location parameter of the distribution
  18. scale (float or Tensor): Scale parameter of the distribution
  19. """
  20. arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
  21. support = constraints.real
  22. def __init__(self, loc, scale, validate_args=None):
  23. self.loc, self.scale = broadcast_all(loc, scale)
  24. finfo = torch.finfo(self.loc.dtype)
  25. if isinstance(loc, Number) and isinstance(scale, Number):
  26. base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
  27. else:
  28. base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
  29. torch.full_like(self.loc, 1 - finfo.eps))
  30. transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
  31. ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
  32. super(Gumbel, self).__init__(base_dist, transforms, validate_args=validate_args)
  33. def expand(self, batch_shape, _instance=None):
  34. new = self._get_checked_instance(Gumbel, _instance)
  35. new.loc = self.loc.expand(batch_shape)
  36. new.scale = self.scale.expand(batch_shape)
  37. return super(Gumbel, self).expand(batch_shape, _instance=new)
  38. # Explicitly defining the log probability function for Gumbel due to precision issues
  39. def log_prob(self, value):
  40. if self._validate_args:
  41. self._validate_sample(value)
  42. y = (self.loc - value) / self.scale
  43. return (y - y.exp()) - self.scale.log()
  44. @property
  45. def mean(self):
  46. return self.loc + self.scale * euler_constant
  47. @property
  48. def mode(self):
  49. return self.loc
  50. @property
  51. def stddev(self):
  52. return (math.pi / math.sqrt(6)) * self.scale
  53. @property
  54. def variance(self):
  55. return self.stddev.pow(2)
  56. def entropy(self):
  57. return self.scale.log() + (1 + euler_constant)