Browse Source

Added checkpoints

bkerbl 2 years ago
parent
commit
6b54263364
4 changed files with 77 additions and 28 deletions
  1. +5 -2
      README.md
  2. +1 -2
      scene/__init__.py
  3. +52 -15
      scene/gaussian_model.py
  4. +19 -9
      train.py

File diff suppressed because it is too large
+ 5 - 2
README.md


+ 1 - 2
scene/__init__.py

@@ -78,8 +78,7 @@ class Scene:
             self.gaussians.load_ply(os.path.join(self.model_path,
             self.gaussians.load_ply(os.path.join(self.model_path,
                                                            "point_cloud",
                                                            "point_cloud",
                                                            "iteration_" + str(self.loaded_iter),
                                                            "iteration_" + str(self.loaded_iter),
-                                                           "point_cloud.ply"),
-                                              og_number_points=len(scene_info.point_cloud.points))
+                                                           "point_cloud.ply"))
         else:
         else:
             self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
             self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
 
 

+ 52 - 15
scene/gaussian_model.py

@@ -22,17 +22,28 @@ from utils.graphics_utils import BasicPointCloud
 from utils.general_utils import strip_symmetric, build_scaling_rotation
 from utils.general_utils import strip_symmetric, build_scaling_rotation
 
 
 class GaussianModel:
 class GaussianModel:
-    def __init__(self, sh_degree : int):
 
 
+    def setup_functions(self):
         def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
         def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
             L = build_scaling_rotation(scaling_modifier * scaling, rotation)
             L = build_scaling_rotation(scaling_modifier * scaling, rotation)
             actual_covariance = L @ L.transpose(1, 2)
             actual_covariance = L @ L.transpose(1, 2)
             symm = strip_symmetric(actual_covariance)
             symm = strip_symmetric(actual_covariance)
             return symm
             return symm
+        
+        self.scaling_activation = torch.exp
+        self.scaling_inverse_activation = torch.log
 
 
+        self.covariance_activation = build_covariance_from_scaling_rotation
+
+        self.opacity_activation = torch.sigmoid
+        self.inverse_opacity_activation = inverse_sigmoid
+
+        self.rotation_activation = torch.nn.functional.normalize
+
+
+    def __init__(self, sh_degree : int):
         self.active_sh_degree = 0
         self.active_sh_degree = 0
         self.max_sh_degree = sh_degree  
         self.max_sh_degree = sh_degree  
-
         self._xyz = torch.empty(0)
         self._xyz = torch.empty(0)
         self._features_dc = torch.empty(0)
         self._features_dc = torch.empty(0)
         self._features_rest = torch.empty(0)
         self._features_rest = torch.empty(0)
@@ -41,18 +52,45 @@ class GaussianModel:
         self._opacity = torch.empty(0)
         self._opacity = torch.empty(0)
         self.max_radii2D = torch.empty(0)
         self.max_radii2D = torch.empty(0)
         self.xyz_gradient_accum = torch.empty(0)
         self.xyz_gradient_accum = torch.empty(0)
-
+        self.denom = torch.empty(0)
         self.optimizer = None
         self.optimizer = None
-
-        self.scaling_activation = torch.exp
-        self.scaling_inverse_activation = torch.log
-
-        self.covariance_activation = build_covariance_from_scaling_rotation
-
-        self.opacity_activation = torch.sigmoid
-        self.inverse_opacity_activation = inverse_sigmoid
-
-        self.rotation_activation = torch.nn.functional.normalize
+        self.percent_dense = 0
+        self.spatial_lr_scale = 0
+        self.setup_functions()
+
+    def capture(self):
+        return (
+            self.active_sh_degree,
+            self._xyz,
+            self._features_dc,
+            self._features_rest,
+            self._scaling,
+            self._rotation,
+            self._opacity,
+            self.max_radii2D,
+            self.xyz_gradient_accum,
+            self.denom,
+            self.optimizer.state_dict(),
+            self.spatial_lr_scale,
+        )
+    
+    def restore(self, model_args, training_args):
+        (self.active_sh_degree, 
+        self._xyz, 
+        self._features_dc, 
+        self._features_rest,
+        self._scaling, 
+        self._rotation, 
+        self._opacity,
+        self.max_radii2D, 
+        xyz_gradient_accum, 
+        denom,
+        opt_dict, 
+        self.spatial_lr_scale) = model_args
+        self.training_setup(training_args)
+        self.xyz_gradient_accum = xyz_gradient_accum
+        self.denom = denom
+        self.optimizer.load_state_dict(opt_dict)
 
 
     @property
     @property
     def get_scaling(self):
     def get_scaling(self):
@@ -174,8 +212,7 @@ class GaussianModel:
         optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
         optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
         self._opacity = optimizable_tensors["opacity"]
         self._opacity = optimizable_tensors["opacity"]
 
 
-    def load_ply(self, path, og_number_points=-1):
-        self.og_number_points = og_number_points
+    def load_ply(self, path):
         plydata = PlyData.read(path)
         plydata = PlyData.read(path)
 
 
         xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
         xyz = np.stack((np.asarray(plydata.elements[0]["x"]),

+ 19 - 9
train.py

@@ -28,12 +28,15 @@ try:
 except ImportError:
 except ImportError:
     TENSORBOARD_FOUND = False
     TENSORBOARD_FOUND = False
 
 
-def training(dataset, opt, pipe, testing_iterations, saving_iterations):
+def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
+    first_iter = 0
     tb_writer = prepare_output_and_logger(dataset)
     tb_writer = prepare_output_and_logger(dataset)
     gaussians = GaussianModel(dataset.sh_degree)
     gaussians = GaussianModel(dataset.sh_degree)
-
     scene = Scene(dataset, gaussians)
     scene = Scene(dataset, gaussians)
     gaussians.training_setup(opt)
     gaussians.training_setup(opt)
+    if checkpoint:
+        (model_params, first_iter) = torch.load(checkpoint)
+        gaussians.restore(model_params, opt)
 
 
     bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
     bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
     background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
     background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
@@ -43,8 +46,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
 
 
     viewpoint_stack = None
     viewpoint_stack = None
     ema_loss_for_log = 0.0
     ema_loss_for_log = 0.0
-    progress_bar = tqdm(range(opt.iterations), desc="Training progress")
-    for iteration in range(1, opt.iterations + 1):        
+    progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+    first_iter += 1
+    for iteration in range(first_iter, opt.iterations + 1):        
         if network_gui.conn == None:
         if network_gui.conn == None:
             network_gui.try_connect()
             network_gui.try_connect()
         while network_gui.conn != None:
         while network_gui.conn != None:
@@ -62,6 +66,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
 
 
         iter_start.record()
         iter_start.record()
 
 
+        gaussians.update_learning_rate(iteration)
+
         # Every 1000 its we increase the levels of SH up to a maximum degree
         # Every 1000 its we increase the levels of SH up to a maximum degree
         if iteration % 1000 == 0:
         if iteration % 1000 == 0:
             gaussians.oneupSHdegree()
             gaussians.oneupSHdegree()
@@ -92,9 +98,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             if iteration == opt.iterations:
             if iteration == opt.iterations:
                 progress_bar.close()
                 progress_bar.close()
 
 
-            # Keep track of max radii in image-space for pruning
-            gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
-
             # Log and save
             # Log and save
             training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
             training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
             if (iteration in saving_iterations):
             if (iteration in saving_iterations):
@@ -103,6 +106,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
 
 
             # Densification
             # Densification
             if iteration < opt.densify_until_iter:
             if iteration < opt.densify_until_iter:
+                # Keep track of max radii in image-space for pruning
+                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                 gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
                 gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
 
 
                 if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
                 if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
@@ -116,7 +121,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
             if iteration < opt.iterations:
             if iteration < opt.iterations:
                 gaussians.optimizer.step()
                 gaussians.optimizer.step()
                 gaussians.optimizer.zero_grad(set_to_none = True)
                 gaussians.optimizer.zero_grad(set_to_none = True)
-                gaussians.update_learning_rate(iteration)
+
+            if (iteration in checkpoint_iterations):
+                print("\n[ITER {}] Saving Checkpoint".format(iteration))
+                torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
 
 
 def prepare_output_and_logger(args):    
 def prepare_output_and_logger(args):    
     if not args.model_path:
     if not args.model_path:
@@ -189,6 +197,8 @@ if __name__ == "__main__":
     parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
     parser.add_argument("--quiet", action="store_true")
     parser.add_argument("--quiet", action="store_true")
+    parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+    parser.add_argument("--start_checkpoint", type=str, default = None)
     args = parser.parse_args(sys.argv[1:])
     args = parser.parse_args(sys.argv[1:])
     args.save_iterations.append(args.iterations)
     args.save_iterations.append(args.iterations)
     
     
@@ -200,7 +210,7 @@ if __name__ == "__main__":
     # Start GUI server, configure and run training
     # Start GUI server, configure and run training
     network_gui.init(args.ip, args.port)
     network_gui.init(args.ip, args.port)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
     torch.autograd.set_detect_anomaly(args.detect_anomaly)
-    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
+    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
 
 
     # All done
     # All done
     print("\nTraining complete.")
     print("\nTraining complete.")