__init__.py 635 B

123456789101112131415161718192021
  1. import torch
  2. from .modules.lpips import LPIPS
  3. def lpips(x: torch.Tensor,
  4. y: torch.Tensor,
  5. net_type: str = 'alex',
  6. version: str = '0.1'):
  7. r"""Function that measures
  8. Learned Perceptual Image Patch Similarity (LPIPS).
  9. Arguments:
  10. x, y (torch.Tensor): the input tensors to compare.
  11. net_type (str): the network type to compare the features:
  12. 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
  13. version (str): the version of LPIPS. Default: 0.1.
  14. """
  15. device = x.device
  16. criterion = LPIPS(net_type, version).to(device)
  17. return criterion(x, y)