continuous_bernoulli.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from numbers import Number
  2. import math
  3. import torch
  4. from torch.distributions import constraints
  5. from torch.distributions.exp_family import ExponentialFamily
  6. from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property, clamp_probs
  7. from torch.nn.functional import binary_cross_entropy_with_logits
  8. class ContinuousBernoulli(ExponentialFamily):
  9. r"""
  10. Creates a continuous Bernoulli distribution parameterized by :attr:`probs`
  11. or :attr:`logits` (but not both).
  12. The distribution is supported in [0, 1] and parameterized by 'probs' (in
  13. (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs'
  14. does not correspond to a probability and 'logits' does not correspond to
  15. log-odds, but the same names are used due to the similarity with the
  16. Bernoulli. See [1] for more details.
  17. Example::
  18. >>> m = ContinuousBernoulli(torch.tensor([0.3]))
  19. >>> m.sample()
  20. tensor([ 0.2538])
  21. Args:
  22. probs (Number, Tensor): (0,1) valued parameters
  23. logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs'
  24. [1] The continuous Bernoulli: fixing a pervasive error in variational
  25. autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019.
  26. https://arxiv.org/abs/1907.06845
  27. """
  28. arg_constraints = {'probs': constraints.unit_interval,
  29. 'logits': constraints.real}
  30. support = constraints.unit_interval
  31. _mean_carrier_measure = 0
  32. has_rsample = True
  33. def __init__(self, probs=None, logits=None, lims=(0.499, 0.501), validate_args=None):
  34. if (probs is None) == (logits is None):
  35. raise ValueError("Either `probs` or `logits` must be specified, but not both.")
  36. if probs is not None:
  37. is_scalar = isinstance(probs, Number)
  38. self.probs, = broadcast_all(probs)
  39. # validate 'probs' here if necessary as it is later clamped for numerical stability
  40. # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
  41. if validate_args is not None:
  42. if not self.arg_constraints['probs'].check(getattr(self, 'probs')).all():
  43. raise ValueError("The parameter {} has invalid values".format('probs'))
  44. self.probs = clamp_probs(self.probs)
  45. else:
  46. is_scalar = isinstance(logits, Number)
  47. self.logits, = broadcast_all(logits)
  48. self._param = self.probs if probs is not None else self.logits
  49. if is_scalar:
  50. batch_shape = torch.Size()
  51. else:
  52. batch_shape = self._param.size()
  53. self._lims = lims
  54. super(ContinuousBernoulli, self).__init__(batch_shape, validate_args=validate_args)
  55. def expand(self, batch_shape, _instance=None):
  56. new = self._get_checked_instance(ContinuousBernoulli, _instance)
  57. new._lims = self._lims
  58. batch_shape = torch.Size(batch_shape)
  59. if 'probs' in self.__dict__:
  60. new.probs = self.probs.expand(batch_shape)
  61. new._param = new.probs
  62. if 'logits' in self.__dict__:
  63. new.logits = self.logits.expand(batch_shape)
  64. new._param = new.logits
  65. super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False)
  66. new._validate_args = self._validate_args
  67. return new
  68. def _new(self, *args, **kwargs):
  69. return self._param.new(*args, **kwargs)
  70. def _outside_unstable_region(self):
  71. return torch.max(torch.le(self.probs, self._lims[0]),
  72. torch.gt(self.probs, self._lims[1]))
  73. def _cut_probs(self):
  74. return torch.where(self._outside_unstable_region(),
  75. self.probs,
  76. self._lims[0] * torch.ones_like(self.probs))
  77. def _cont_bern_log_norm(self):
  78. '''computes the log normalizing constant as a function of the 'probs' parameter'''
  79. cut_probs = self._cut_probs()
  80. cut_probs_below_half = torch.where(torch.le(cut_probs, 0.5),
  81. cut_probs,
  82. torch.zeros_like(cut_probs))
  83. cut_probs_above_half = torch.where(torch.ge(cut_probs, 0.5),
  84. cut_probs,
  85. torch.ones_like(cut_probs))
  86. log_norm = torch.log(torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs))) - torch.where(
  87. torch.le(cut_probs, 0.5),
  88. torch.log1p(-2.0 * cut_probs_below_half),
  89. torch.log(2.0 * cut_probs_above_half - 1.0))
  90. x = torch.pow(self.probs - 0.5, 2)
  91. taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x
  92. return torch.where(self._outside_unstable_region(), log_norm, taylor)
  93. @property
  94. def mean(self):
  95. cut_probs = self._cut_probs()
  96. mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / (torch.log1p(-cut_probs) - torch.log(cut_probs))
  97. x = self.probs - 0.5
  98. taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x
  99. return torch.where(self._outside_unstable_region(), mus, taylor)
  100. @property
  101. def stddev(self):
  102. return torch.sqrt(self.variance)
  103. @property
  104. def variance(self):
  105. cut_probs = self._cut_probs()
  106. vars = cut_probs * (cut_probs - 1.0) / torch.pow(1.0 - 2.0 * cut_probs, 2) + 1.0 / torch.pow(
  107. torch.log1p(-cut_probs) - torch.log(cut_probs), 2)
  108. x = torch.pow(self.probs - 0.5, 2)
  109. taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128. / 945.0 * x) * x
  110. return torch.where(self._outside_unstable_region(), vars, taylor)
  111. @lazy_property
  112. def logits(self):
  113. return probs_to_logits(self.probs, is_binary=True)
  114. @lazy_property
  115. def probs(self):
  116. return clamp_probs(logits_to_probs(self.logits, is_binary=True))
  117. @property
  118. def param_shape(self):
  119. return self._param.size()
  120. def sample(self, sample_shape=torch.Size()):
  121. shape = self._extended_shape(sample_shape)
  122. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  123. with torch.no_grad():
  124. return self.icdf(u)
  125. def rsample(self, sample_shape=torch.Size()):
  126. shape = self._extended_shape(sample_shape)
  127. u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device)
  128. return self.icdf(u)
  129. def log_prob(self, value):
  130. if self._validate_args:
  131. self._validate_sample(value)
  132. logits, value = broadcast_all(self.logits, value)
  133. return -binary_cross_entropy_with_logits(logits, value, reduction='none') + self._cont_bern_log_norm()
  134. def cdf(self, value):
  135. if self._validate_args:
  136. self._validate_sample(value)
  137. cut_probs = self._cut_probs()
  138. cdfs = (torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value)
  139. + cut_probs - 1.0) / (2.0 * cut_probs - 1.0)
  140. unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value)
  141. return torch.where(
  142. torch.le(value, 0.0),
  143. torch.zeros_like(value),
  144. torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs))
  145. def icdf(self, value):
  146. cut_probs = self._cut_probs()
  147. return torch.where(
  148. self._outside_unstable_region(),
  149. (torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0))
  150. - torch.log1p(-cut_probs)) / (torch.log(cut_probs) - torch.log1p(-cut_probs)),
  151. value)
  152. def entropy(self):
  153. log_probs0 = torch.log1p(-self.probs)
  154. log_probs1 = torch.log(self.probs)
  155. return self.mean * (log_probs0 - log_probs1) - self._cont_bern_log_norm() - log_probs0
  156. @property
  157. def _natural_params(self):
  158. return (self.logits, )
  159. def _log_normalizer(self, x):
  160. """computes the log normalizing constant as a function of the natural parameter"""
  161. out_unst_reg = torch.max(torch.le(x, self._lims[0] - 0.5),
  162. torch.gt(x, self._lims[1] - 0.5))
  163. cut_nat_params = torch.where(out_unst_reg,
  164. x,
  165. (self._lims[0] - 0.5) * torch.ones_like(x))
  166. log_norm = torch.log(torch.abs(torch.exp(cut_nat_params) - 1.0)) - torch.log(torch.abs(cut_nat_params))
  167. taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0
  168. return torch.where(out_unst_reg, log_norm, taylor)