Parcourir la source

Debugging behavior

bkerbl il y a 2 ans
Parent
commit
91f16deb45
2 fichiers modifiés avec 6 ajouts et 3 suppressions
  1. 1 1
      submodules/diff-gaussian-rasterization
  2. 5 2
      train.py

+ 1 - 1
submodules/diff-gaussian-rasterization

@@ -1 +1 @@
-Subproject commit 73917be7cfd1694e1e61adfa803981925494be5e
+Subproject commit 8064f52ca233942bdec2d1a1451c026deedd320b

+ 5 - 2
train.py

@@ -28,7 +28,7 @@ try:
 except ImportError:
     TENSORBOARD_FOUND = False
 
-def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
+def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
     first_iter = 0
     tb_writer = prepare_output_and_logger(dataset)
     gaussians = GaussianModel(dataset.sh_degree)
@@ -78,6 +78,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoi
         viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
 
         # Render
+        if (iteration - 1) == debug_from:
+            pipe.debug = True
         render_pkg = render(viewpoint_cam, gaussians, pipe, background)
         image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
 
@@ -193,6 +195,7 @@ if __name__ == "__main__":
     pp = PipelineParams(parser)
     parser.add_argument('--ip', type=str, default="127.0.0.1")
     parser.add_argument('--port', type=int, default=6009)
+    parser.add_argument('--debug_from', type=int, default=-1)
     parser.add_argument('--detect_anomaly', action='store_true', default=False)
     parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
@@ -210,7 +213,7 @@ if __name__ == "__main__":
     # Start GUI server, configure and run training
     network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
+    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
 
     # All done
     print("\nTraining complete.")