lpips.py 1.1 KB

123456789101112131415161718192021222324252627282930313233343536
  1. import torch
  2. import torch.nn as nn
  3. from .networks import get_network, LinLayers
  4. from .utils import get_state_dict
  5. class LPIPS(nn.Module):
  6. r"""Creates a criterion that measures
  7. Learned Perceptual Image Patch Similarity (LPIPS).
  8. Arguments:
  9. net_type (str): the network type to compare the features:
  10. 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
  11. version (str): the version of LPIPS. Default: 0.1.
  12. """
  13. def __init__(self, net_type: str = 'alex', version: str = '0.1'):
  14. assert version in ['0.1'], 'v0.1 is only supported now'
  15. super(LPIPS, self).__init__()
  16. # pretrained network
  17. self.net = get_network(net_type)
  18. # linear layers
  19. self.lin = LinLayers(self.net.n_channels_list)
  20. self.lin.load_state_dict(get_state_dict(net_type, version))
  21. def forward(self, x: torch.Tensor, y: torch.Tensor):
  22. feat_x, feat_y = self.net(x), self.net(y)
  23. diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
  24. res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
  25. return torch.sum(torch.cat(res, 0), 0, True)