image_utils.py 268 B

12345678
  1. import torch
  def mse(img1, img2):
      """Return the per-sample mean squared error between two images/batches.

      The first dimension of ``img1`` is treated as the batch dimension;
      all remaining elements of each sample are averaged together.

      Args:
          img1: Tensor; ``img1.shape[0]`` is the batch size.
          img2: Tensor (or scalar) broadcastable against ``img1``.

      Returns:
          Tensor of shape ``(img1.shape[0], 1)`` with one MSE per sample.
          If the inputs' combined dtype is integral/bool, ``img1`` is cast
          with ``.float()`` first, so the result is float32.
      """
      # Integer images (e.g. uint8) would wrap around on subtraction, and
      # ``mean`` rejects integral dtypes, so compute in floating point.
      combined = torch.result_type(img1, img2)
      if not (combined.is_floating_point or combined.is_complex):
          img1 = img1.float()
      sq_err = (img1 - img2) ** 2
      # ``reshape`` instead of ``view``: elementwise ops preserve the
      # input's memory layout, so permuted / channels-last inputs give a
      # non-contiguous result that ``view`` cannot flatten.
      return sq_err.reshape(img1.shape[0], -1).mean(1, keepdim=True)
  def psnr(img1, img2, data_range=1.0):
      """Return the per-sample peak signal-to-noise ratio in decibels.

      Computes ``20 * log10(data_range / sqrt(MSE))`` where the MSE is taken
      over all elements of each sample (``img1.shape[0]`` is the batch size).

      Args:
          img1: Tensor; ``img1.shape[0]`` is the batch size.
          img2: Tensor (or scalar) broadcastable against ``img1``.
          data_range: Peak signal value (max minus min possible pixel
              value). The default ``1.0`` matches images scaled to [0, 1];
              pass e.g. ``255`` for 8-bit images.

      Returns:
          Tensor of shape ``(img1.shape[0], 1)``. A sample with zero error
          yields ``inf``.
      """
      # Integer images (e.g. uint8) would wrap around on subtraction, and
      # ``mean`` rejects integral dtypes, so compute in floating point.
      combined = torch.result_type(img1, img2)
      if not (combined.is_floating_point or combined.is_complex):
          img1 = img1.float()
      # Per-sample MSE, computed inline under a name that does not shadow
      # the module-level ``mse`` function. ``reshape`` (not ``view``) so
      # non-contiguous, e.g. permuted, inputs are accepted.
      sq_err = (img1 - img2) ** 2
      err = sq_err.reshape(img1.shape[0], -1).mean(1, keepdim=True)
      return 20 * torch.log10(data_range / torch.sqrt(err))