uniform.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from numbers import Number
  2. import torch
  3. from torch._six import nan
  4. from torch.distributions import constraints
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.utils import broadcast_all
  7. class Uniform(Distribution):
  8. r"""
  9. Generates uniformly distributed random samples from the half-open interval
  10. ``[low, high)``.
  11. Example::
  12. >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
  13. >>> m.sample() # uniformly distributed in the range [0.0, 5.0)
  14. tensor([ 2.3418])
  15. Args:
  16. low (float or Tensor): lower range (inclusive).
  17. high (float or Tensor): upper range (exclusive).
  18. """
  19. # TODO allow (loc,scale) parameterization to allow independent constraints.
  20. arg_constraints = {'low': constraints.dependent(is_discrete=False, event_dim=0),
  21. 'high': constraints.dependent(is_discrete=False, event_dim=0)}
  22. has_rsample = True
  23. @property
  24. def mean(self):
  25. return (self.high + self.low) / 2
  26. @property
  27. def mode(self):
  28. return nan * self.high
  29. @property
  30. def stddev(self):
  31. return (self.high - self.low) / 12**0.5
  32. @property
  33. def variance(self):
  34. return (self.high - self.low).pow(2) / 12
  35. def __init__(self, low, high, validate_args=None):
  36. self.low, self.high = broadcast_all(low, high)
  37. if isinstance(low, Number) and isinstance(high, Number):
  38. batch_shape = torch.Size()
  39. else:
  40. batch_shape = self.low.size()
  41. super(Uniform, self).__init__(batch_shape, validate_args=validate_args)
  42. if self._validate_args and not torch.lt(self.low, self.high).all():
  43. raise ValueError("Uniform is not defined when low>= high")
  44. def expand(self, batch_shape, _instance=None):
  45. new = self._get_checked_instance(Uniform, _instance)
  46. batch_shape = torch.Size(batch_shape)
  47. new.low = self.low.expand(batch_shape)
  48. new.high = self.high.expand(batch_shape)
  49. super(Uniform, new).__init__(batch_shape, validate_args=False)
  50. new._validate_args = self._validate_args
  51. return new
  52. @constraints.dependent_property(is_discrete=False, event_dim=0)
  53. def support(self):
  54. return constraints.interval(self.low, self.high)
  55. def rsample(self, sample_shape=torch.Size()):
  56. shape = self._extended_shape(sample_shape)
  57. rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device)
  58. return self.low + rand * (self.high - self.low)
  59. def log_prob(self, value):
  60. if self._validate_args:
  61. self._validate_sample(value)
  62. lb = self.low.le(value).type_as(self.low)
  63. ub = self.high.gt(value).type_as(self.low)
  64. return torch.log(lb.mul(ub)) - torch.log(self.high - self.low)
  65. def cdf(self, value):
  66. if self._validate_args:
  67. self._validate_sample(value)
  68. result = (value - self.low) / (self.high - self.low)
  69. return result.clamp(min=0, max=1)
  70. def icdf(self, value):
  71. result = value * (self.high - self.low) + self.low
  72. return result
  73. def entropy(self):
  74. return torch.log(self.high - self.low)