# utils.py
  1. from collections import OrderedDict
  2. import torch
  3. def normalize_activation(x, eps=1e-10):
  4. norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
  5. return x / (norm_factor + eps)
  6. def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
  7. # build url
  8. url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
  9. + f'master/lpips/weights/v{version}/{net_type}.pth'
  10. # download
  11. old_state_dict = torch.hub.load_state_dict_from_url(
  12. url, progress=True,
  13. map_location=None if torch.cuda.is_available() else torch.device('cpu')
  14. )
  15. # rename keys
  16. new_state_dict = OrderedDict()
  17. for key, val in old_state_dict.items():
  18. new_key = key
  19. new_key = new_key.replace('lin', '')
  20. new_key = new_key.replace('model.', '')
  21. new_state_dict[new_key] = val
  22. return new_state_dict