metrics.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from pathlib import Path
  2. import os
  3. from PIL import Image
  4. import torch
  5. import torchvision.transforms.functional as tf
  6. from utils.loss_utils import ssim
  7. from lpipsPyTorch import lpips
  8. import json
  9. from tqdm import tqdm
  10. from utils.image_utils import psnr
  11. from argparse import ArgumentParser
  12. def readImages(renders_dir, gt_dir):
  13. renders = []
  14. gts = []
  15. image_names = []
  16. for fname in os.listdir(renders_dir):
  17. render = Image.open(renders_dir / fname)
  18. gt = Image.open(gt_dir / fname)
  19. renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
  20. gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
  21. image_names.append(fname)
  22. return renders, gts, image_names
  23. def evaluate(model_paths):
  24. full_dict = {}
  25. per_view_dict = {}
  26. full_dict_polytopeonly = {}
  27. per_view_dict_polytopeonly = {}
  28. for scene_dir in model_paths:
  29. print("Scene:", scene_dir)
  30. full_dict[scene_dir] = {}
  31. per_view_dict[scene_dir] = {}
  32. full_dict_polytopeonly[scene_dir] = {}
  33. per_view_dict_polytopeonly[scene_dir] = {}
  34. test_dir = Path(scene_dir) / "test"
  35. for method in os.listdir(test_dir):
  36. print("Method:", method)
  37. full_dict[scene_dir][method] = {}
  38. per_view_dict[scene_dir][method] = {}
  39. full_dict_polytopeonly[scene_dir][method] = {}
  40. per_view_dict_polytopeonly[scene_dir][method] = {}
  41. method_dir = test_dir / method
  42. gt_dir = method_dir/ "gt"
  43. renders_dir = method_dir / "renders"
  44. renders, gts, image_names = readImages(renders_dir, gt_dir)
  45. ssims = []
  46. psnrs = []
  47. lpipss = []
  48. for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
  49. ssims.append(ssim(renders[idx], gts[idx]))
  50. psnrs.append(psnr(renders[idx], gts[idx]))
  51. lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
  52. print("SSIM: {}".format(torch.tensor(ssims).mean()))
  53. print("PSNR: {}".format(torch.tensor(psnrs).mean()))
  54. print("LPIPS: {}".format(torch.tensor(lpipss).mean()))
  55. full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
  56. "PSNR": torch.tensor(psnrs).mean().item(),
  57. "LPIPS": torch.tensor(lpipss).mean().item()})
  58. per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
  59. "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
  60. "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
  61. with open(scene_dir + "/results.json", 'w') as fp:
  62. json.dump(full_dict[scene_dir], fp, indent=True)
  63. with open(scene_dir + "/per_view.json", 'w') as fp:
  64. json.dump(per_view_dict[scene_dir], fp, indent=True)
  65. if __name__ == "__main__":
  66. device = torch.device("cuda:0")
  67. torch.cuda.set_device(device)
  68. # Set up command line argument parser
  69. parser = ArgumentParser(description="Training script parameters")
  70. parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
  71. args = parser.parse_args()
  72. evaluate(args.model_paths)