utils.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. from functools import update_wrapper
  2. from numbers import Number
  3. import torch
  4. import torch.nn.functional as F
  5. from typing import Dict, Any
  6. from torch.overrides import is_tensor_like
  # Euler-Mascheroni constant (commonly written as gamma).
  euler_constant = 0.57721566490153286060
  def broadcast_all(*values):
      r"""
      Broadcast a mix of tensors and Python numbers against each other.

      Rules applied to each entry of ``values``:

      - ``torch.*Tensor`` instances (and other objects implementing
        ``__torch_function__``) are broadcast following
        :ref:`broadcasting-semantics`.
      - ``numbers.Number`` scalars are converted to tensors that take the dtype
        and device of the first ``torch.Tensor`` found in ``values``; if there
        is no such tensor, they become 0-dim tensors of the default dtype.

      Args:
          values (list of `numbers.Number`, `torch.*Tensor` or objects implementing __torch_function__)

      Returns:
          The tuple produced by :func:`torch.broadcast_tensors`.

      Raises:
          ValueError: if an entry is neither a `numbers.Number`, a
              `torch.*Tensor`, nor an object implementing __torch_function__.
      """
      for candidate in values:
          if not (is_tensor_like(candidate) or isinstance(candidate, Number)):
              raise ValueError('Input arguments must all be instances of numbers.Number, '
                               'torch.Tensor or objects implementing __torch_function__.')

      if all(is_tensor_like(v) for v in values):
          return torch.broadcast_tensors(*values)

      # At least one plain number is present: borrow dtype/device from the
      # first genuine torch.Tensor, falling back to the default dtype.
      first_tensor = next((v for v in values if isinstance(v, torch.Tensor)), None)
      if first_tensor is None:
          options: Dict[str, Any] = {"dtype": torch.get_default_dtype()}
      else:
          options = {"dtype": first_tensor.dtype, "device": first_tensor.device}

      converted = [v if is_tensor_like(v) else torch.tensor(v, **options) for v in values]
      return torch.broadcast_tensors(*converted)
  36. def _standard_normal(shape, dtype, device):
  37. if torch._C._get_tracing_state():
  38. # [JIT WORKAROUND] lack of support for .normal_()
  39. return torch.normal(torch.zeros(shape, dtype=dtype, device=device),
  40. torch.ones(shape, dtype=dtype, device=device))
  41. return torch.empty(shape, dtype=dtype, device=device).normal_()
  42. def _sum_rightmost(value, dim):
  43. r"""
  44. Sum out ``dim`` many rightmost dimensions of a given tensor.
  45. Args:
  46. value (Tensor): A tensor of ``.dim()`` at least ``dim``.
  47. dim (int): The number of rightmost dims to sum out.
  48. """
  49. if dim == 0:
  50. return value
  51. required_shape = value.shape[:-dim] + (-1,)
  52. return value.reshape(required_shape).sum(-1)
  def logits_to_probs(logits, is_binary=False):
      r"""
      Map logits to probabilities.

      With ``is_binary=True`` every entry is read as log odds and passed
      through a sigmoid. Otherwise the values along the last dimension are
      (possibly unnormalized) log probabilities of the events and are
      normalized with a softmax over that dimension.
      """
      return torch.sigmoid(logits) if is_binary else F.softmax(logits, dim=-1)
  def clamp_probs(probs):
      r"""
      Clamp ``probs`` into ``[eps, 1 - eps]``, where ``eps`` is the machine
      epsilon of ``probs.dtype``, so that e.g. ``log(p)`` and ``log1p(-p)``
      stay finite.
      """
      tiny = torch.finfo(probs.dtype).eps
      return torch.clamp(probs, min=tiny, max=1 - tiny)
  def probs_to_logits(probs, is_binary=False):
      r"""
      Map probabilities to logits.

      ``probs`` is first clamped with :func:`clamp_probs` so that the
      logarithms below stay finite at 0 and 1.

      - ``is_binary=True``: each entry is the probability of the event indexed
        by ``1``; the result is the log odds ``log(p) - log(1 - p)``.
      - otherwise: values along the last dimension are the probabilities of
        each event; the result is their elementwise ``log``.
      """
      clamped = clamp_probs(probs)
      log_p = torch.log(clamped)
      if not is_binary:
          return log_p
      # log1p(-p) computes log(1 - p).
      return log_p - torch.log1p(-clamped)
  class lazy_property:
      r"""
      Used as a decorator for lazy loading of class attributes. This uses a
      non-data descriptor that calls the wrapped method to compute the property on
      first call; thereafter replacing the wrapped method into an instance
      attribute.
      """

      def __init__(self, wrapped):
          self.wrapped = wrapped
          # Copy __name__, __doc__, etc. from the wrapped function.
          update_wrapper(self, wrapped)

      def __get__(self, instance, obj_type=None):
          if instance is None:
              # Class-level access: hand back an object that is both a
              # lazy_property and a property (see the class below).
              return _lazy_property_and_property(self.wrapped)
          # Run the wrapped method with gradient tracking enabled, even if the
          # first access happens inside a torch.no_grad() block.
          with torch.enable_grad():
              value = self.wrapped(instance)
          # This class defines no __set__, so the instance attribute set here
          # shadows the descriptor and later lookups return the cached value.
          setattr(instance, self.wrapped.__name__, value)
          return value


  class _lazy_property_and_property(lazy_property, property):
      """We want lazy properties to look like multiple things.

      * property when Sphinx autodoc looks
      * lazy_property when Distribution validate_args looks
      """

      def __init__(self, wrapped):
          property.__init__(self, wrapped)
          # lazy_property.__init__ is not called here, but the inherited
          # lazy_property.__get__ reads self.wrapped, so set it explicitly.
          self.wrapped = wrapped
  def tril_matrix_to_vec(mat, diag=0):
      r"""
      Flatten the lower-triangular part of a ``D x D`` matrix (or of each
      matrix in a batch) into a vector, reading the elements in row order.

      Args:
          mat (Tensor): matrix or batch of matrices; ``D`` is taken from the
              last dimension.
          diag (int): keep entries whose column index minus row index is at
              most ``diag``. ``0`` keeps the main diagonal, negative values
              exclude diagonals, and positive values include superdiagonals.

      Raises:
          ValueError: if ``diag`` lies outside ``[-D, D - 1]`` (not checked
              while JIT tracing).
      """
      n = mat.shape[-1]
      tracing = torch._C._get_tracing_state()
      if not tracing and not (-n <= diag < n):
          raise ValueError(f'diag ({diag}) provided is outside [{-n}, {n-1}].')
      idx = torch.arange(n, device=mat.device)
      # mask[i, j] is True when (j - i) <= diag.
      mask = idx - idx.unsqueeze(-1) <= diag
      return mat[..., mask]
  def vec_to_tril_matrix(vec, diag=0):
      r"""
      Convert a vector or a batch of vectors into a batched `D x D`
      lower triangular matrix containing elements from the vector in row order.

      Args:
          vec (Tensor): vector or batch of vectors; the last dimension holds the
              triangular entries.
          diag (int): which diagonals are filled. Entries are written where the
              column index minus the row index is at most ``diag``.

      Returns:
          Tensor: zeros-initialised tensor of shape ``vec.shape[:-1] + (D, D)``
          with the masked positions filled from ``vec``.

      Raises:
          ValueError: if ``vec.shape[-1]`` is not the element count of such a
              triangle for any integer ``D`` (not checked while JIT tracing).
      """
      # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0
      n = (-(1 + 2 * diag)
           + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5) / 2
      # torch.finfo only accepts floating (and complex) dtypes. For integer or
      # bool vectors, fall back to the default float dtype's epsilon.
      if vec.dtype.is_floating_point or vec.dtype.is_complex:
          eps_dtype = vec.dtype
      else:
          eps_dtype = torch.get_default_dtype()
      eps = torch.finfo(eps_dtype).eps
      # abs(): an invalid size can put the root either just below or just above
      # an integer (e.g. a length-4 vector gives D ~= 2.37, which rounds down).
      if not torch._C._get_tracing_state() and abs(round(n) - n) > eps:
          raise ValueError(f'The size of last dimension is {vec.shape[-1]} which cannot be expressed as ' +
                           'the lower triangular part of a square D x D matrix.')
      n = torch.round(n).long() if isinstance(n, torch.Tensor) else round(n)
      mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n)))
      arange = torch.arange(n, device=vec.device)
      # mask[i, j] is True when j < i + diag + 1, i.e. (j - i) <= diag.
      tril_mask = arange < arange.view(-1, 1) + (diag + 1)
      mat[..., tril_mask] = vec
      return mat