Browse Source

fixed tensorboard update images by adding batch dimension (#9)

JonathonLuiten 2 năm trước cách đây
mục cha
commit
9490ef7612
1 tập tin đã thay đổi với 3 bổ sung3 xóa
  1. 3 3
      train.py

+ 3 - 3
train.py

@@ -162,9 +162,9 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
                     images = torch.cat((images, image.unsqueeze(0)), dim=0)
                     gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
                     if tb_writer and (idx < 5):
-                        tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration)
+                        tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
                         if iteration == testing_iterations[0]:
-                            tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration)
+                            tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
 
                 l1_test = l1_loss(images, gts)
                 psnr_test = psnr(images, gts).mean()            
@@ -204,4 +204,4 @@ if __name__ == "__main__":
     training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
 
     # All done
-    print("\nTraining complete.")
+    print("\nTraining complete.")