| 123456789101112131415161718192021 |
- import torch
- from .modules.lpips import LPIPS
# Cache of constructed LPIPS criteria, keyed by (net_type, version, device).
# Building an ``LPIPS`` module constructs the feature network selected by
# ``net_type`` (and presumably loads pretrained weights — see
# ``.modules.lpips``), so it is built once per key instead of on every call.
_LPIPS_CRITERIA = {}


def _get_criterion(net_type: str, version: str, device: torch.device):
    """Return a cached ``LPIPS(net_type, version)`` module placed on ``device``.

    The module is created on the first request for a given key and reused by
    every later request with the same key. If construction raises (e.g. for an
    unsupported ``net_type``/``version``), nothing is cached.
    """
    key = (net_type, version, device)
    criterion = _LPIPS_CRITERIA.get(key)
    if criterion is None:
        criterion = LPIPS(net_type, version).to(device)
        _LPIPS_CRITERIA[key] = criterion
    return criterion


def lpips(x: torch.Tensor,
          y: torch.Tensor,
          net_type: str = 'alex',
          version: str = '0.1'):
    r"""Function that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        x, y (torch.Tensor): the input tensors to compare.
        net_type (str): the network type to compare the features:
            'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.

    Returns:
        Whatever the ``LPIPS`` module returns for ``(x, y)``.

    Notes:
        The criterion is placed on ``x.device``; ``y`` is assumed to live on
        the same device.
        The criterion module is cached per (net_type, version, device) and
        shared between calls.
        NOTE(review): if ``LPIPS`` parameters require grad, their ``.grad``
        would now accumulate across backward passes. Confirm that the module
        freezes them.
    """
    criterion = _get_criterion(net_type, version, x.device)
    return criterion(x, y)
|