exp_family.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import torch
  2. from torch.distributions.distribution import Distribution
  3. class ExponentialFamily(Distribution):
  4. r"""
  5. ExponentialFamily is the abstract base class for probability distributions belonging to an
  6. exponential family, whose probability mass/density function has the form is defined below
  7. .. math::
  8. p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))
  9. where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
  10. :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
  11. measure.
  12. Note:
  13. This class is an intermediary between the `Distribution` class and distributions which belong
  14. to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL
  15. divergence methods. We use this class to compute the entropy and KL divergence using the AD
  16. framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
  17. Cross-entropies of Exponential Families).
  18. """
  19. @property
  20. def _natural_params(self):
  21. """
  22. Abstract method for natural parameters. Returns a tuple of Tensors based
  23. on the distribution
  24. """
  25. raise NotImplementedError
  26. def _log_normalizer(self, *natural_params):
  27. """
  28. Abstract method for log normalizer function. Returns a log normalizer based on
  29. the distribution and input
  30. """
  31. raise NotImplementedError
  32. @property
  33. def _mean_carrier_measure(self):
  34. """
  35. Abstract method for expected carrier measure, which is required for computing
  36. entropy.
  37. """
  38. raise NotImplementedError
  39. def entropy(self):
  40. """
  41. Method to compute the entropy using Bregman divergence of the log normalizer.
  42. """
  43. result = -self._mean_carrier_measure
  44. nparams = [p.detach().requires_grad_() for p in self._natural_params]
  45. lg_normal = self._log_normalizer(*nparams)
  46. gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
  47. result += lg_normal
  48. for np, g in zip(nparams, gradients):
  49. result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
  50. return result