ソースを参照

conflict resolved

bkerbl 2 年 前
コミット
ef918fc9e3
2 ファイル変更8 行追加9 行削除
  1. 2 2
      scene/dataset_readers.py
  2. 6 7
      train.py

+ 2 - 2
scene/dataset_readers.py

@@ -205,8 +205,8 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
             image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
 
             fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
-            FovY = fovx 
-            FovX = fovy
+            FovY = fovy 
+            FovX = fovx
 
             cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
                             image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))

+ 6 - 7
train.py

@@ -154,20 +154,19 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
 
         for config in validation_configs:
             if config['cameras'] and len(config['cameras']) > 0:
-                images = torch.tensor([], device="cuda")
-                gts = torch.tensor([], device="cuda")
+                l1_test = 0
+                psnr_test = 0
                 for idx, viewpoint in enumerate(config['cameras']):
                     image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
                     gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
-                    images = torch.cat((images, image.unsqueeze(0)), dim=0)
-                    gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
                     if tb_writer and (idx < 5):
                         tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
                         if iteration == testing_iterations[0]:
                             tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
-
-                l1_test = l1_loss(images, gts)
-                psnr_test = psnr(images, gts).mean()            
+                    l1_test += l1_loss(image, gt_image).mean()
+                    psnr_test += psnr(image, gt_image).mean()
+                psnr_test /= len(config['cameras'])
+                l1_test /= len(config['cameras'])          
                 print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
                 if tb_writer:
                     tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)