metrics.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. from pathlib import Path
  2. import os
  3. from PIL import Image
  4. import torch
  5. import torchvision.transforms.functional as tf
  6. from utils.loss_utils import ssim
  7. from lpipsPyTorch import lpips
  8. import json
  9. from tqdm import tqdm
  10. from utils.image_utils import psnr
  11. from argparse import ArgumentParser
  12. def readImages(renders_dir, gt_dir):
  13. renders = []
  14. gts = []
  15. image_names = []
  16. for fname in os.listdir(renders_dir):
  17. render = Image.open(renders_dir / fname)
  18. gt = Image.open(gt_dir / fname)
  19. renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
  20. gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
  21. image_names.append(fname)
  22. return renders, gts, image_names
  23. def evaluate(model_paths):
  24. full_dict = {}
  25. per_view_dict = {}
  26. full_dict_polytopeonly = {}
  27. per_view_dict_polytopeonly = {}
  28. for scene_dir in model_paths:
  29. try:
  30. print("\nScene:", scene_dir)
  31. full_dict[scene_dir] = {}
  32. per_view_dict[scene_dir] = {}
  33. full_dict_polytopeonly[scene_dir] = {}
  34. per_view_dict_polytopeonly[scene_dir] = {}
  35. test_dir = Path(scene_dir) / "test"
  36. for method in os.listdir(test_dir):
  37. print("Method:", method)
  38. full_dict[scene_dir][method] = {}
  39. per_view_dict[scene_dir][method] = {}
  40. full_dict_polytopeonly[scene_dir][method] = {}
  41. per_view_dict_polytopeonly[scene_dir][method] = {}
  42. method_dir = test_dir / method
  43. gt_dir = method_dir/ "gt"
  44. renders_dir = method_dir / "renders"
  45. renders, gts, image_names = readImages(renders_dir, gt_dir)
  46. ssims = []
  47. psnrs = []
  48. lpipss = []
  49. for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
  50. ssims.append(ssim(renders[idx], gts[idx]))
  51. psnrs.append(psnr(renders[idx], gts[idx]))
  52. lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
  53. print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5"))
  54. print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5"))
  55. print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5"), "\n")
  56. full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
  57. "PSNR": torch.tensor(psnrs).mean().item(),
  58. "LPIPS": torch.tensor(lpipss).mean().item()})
  59. per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
  60. "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
  61. "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
  62. with open(scene_dir + "/results.json", 'w') as fp:
  63. json.dump(full_dict[scene_dir], fp, indent=True)
  64. with open(scene_dir + "/per_view.json", 'w') as fp:
  65. json.dump(per_view_dict[scene_dir], fp, indent=True)
  66. except:
  67. print("Unable to compute metrics for model", scene_dir)
  68. if __name__ == "__main__":
  69. device = torch.device("cuda:0")
  70. torch.cuda.set_device(device)
  71. # Set up command line argument parser
  72. parser = ArgumentParser(description="Training script parameters")
  73. parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
  74. args = parser.parse_args()
  75. evaluate(args.model_paths)