distance.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from .module import Module
  2. from .. import functional as F
  3. from torch import Tensor
  4. class PairwiseDistance(Module):
  5. r"""
  6. Computes the pairwise distance between vectors :math:`v_1`, :math:`v_2` using the p-norm:
  7. .. math ::
  8. \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.
  9. Args:
  10. p (real): the norm degree. Default: 2
  11. eps (float, optional): Small value to avoid division by zero.
  12. Default: 1e-6
  13. keepdim (bool, optional): Determines whether or not to keep the vector dimension.
  14. Default: False
  15. Shape:
  16. - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension`
  17. - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1
  18. - Output: :math:`(N)` or :math:`()` based on input dimension.
  19. If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension.
  20. Examples::
  21. >>> pdist = nn.PairwiseDistance(p=2)
  22. >>> input1 = torch.randn(100, 128)
  23. >>> input2 = torch.randn(100, 128)
  24. >>> output = pdist(input1, input2)
  25. """
  26. __constants__ = ['norm', 'eps', 'keepdim']
  27. norm: float
  28. eps: float
  29. keepdim: bool
  30. def __init__(self, p: float = 2., eps: float = 1e-6, keepdim: bool = False) -> None:
  31. super(PairwiseDistance, self).__init__()
  32. self.norm = p
  33. self.eps = eps
  34. self.keepdim = keepdim
  35. def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
  36. return F.pairwise_distance(x1, x2, self.norm, self.eps, self.keepdim)
  37. class CosineSimilarity(Module):
  38. r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`.
  39. .. math ::
  40. \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.
  41. Args:
  42. dim (int, optional): Dimension where cosine similarity is computed. Default: 1
  43. eps (float, optional): Small value to avoid division by zero.
  44. Default: 1e-8
  45. Shape:
  46. - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`
  47. - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`,
  48. and broadcastable with x1 at other dimensions.
  49. - Output: :math:`(\ast_1, \ast_2)`
  50. Examples::
  51. >>> input1 = torch.randn(100, 128)
  52. >>> input2 = torch.randn(100, 128)
  53. >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
  54. >>> output = cos(input1, input2)
  55. """
  56. __constants__ = ['dim', 'eps']
  57. dim: int
  58. eps: float
  59. def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
  60. super(CosineSimilarity, self).__init__()
  61. self.dim = dim
  62. self.eps = eps
  63. def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
  64. return F.cosine_similarity(x1, x2, self.dim, self.eps)