|
|
@@ -28,12 +28,15 @@ try:
|
|
|
except ImportError:
|
|
|
TENSORBOARD_FOUND = False
|
|
|
|
|
|
-def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
|
|
+def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint):
|
|
|
+ first_iter = 0
|
|
|
tb_writer = prepare_output_and_logger(dataset)
|
|
|
gaussians = GaussianModel(dataset.sh_degree)
|
|
|
-
|
|
|
scene = Scene(dataset, gaussians)
|
|
|
gaussians.training_setup(opt)
|
|
|
+ if checkpoint:
|
|
|
+ (model_params, first_iter) = torch.load(checkpoint)
|
|
|
+ gaussians.restore(model_params, opt)
|
|
|
|
|
|
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
|
|
|
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
|
|
|
@@ -43,8 +46,9 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
|
|
|
|
|
viewpoint_stack = None
|
|
|
ema_loss_for_log = 0.0
|
|
|
- progress_bar = tqdm(range(opt.iterations), desc="Training progress")
|
|
|
- for iteration in range(1, opt.iterations + 1):
|
|
|
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
|
|
|
+ first_iter += 1
|
|
|
+ for iteration in range(first_iter, opt.iterations + 1):
|
|
|
if network_gui.conn == None:
|
|
|
network_gui.try_connect()
|
|
|
while network_gui.conn != None:
|
|
|
@@ -62,6 +66,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
|
|
|
|
|
iter_start.record()
|
|
|
|
|
|
+ gaussians.update_learning_rate(iteration)
|
|
|
+
|
|
|
# Every 1000 its we increase the levels of SH up to a maximum degree
|
|
|
if iteration % 1000 == 0:
|
|
|
gaussians.oneupSHdegree()
|
|
|
@@ -92,9 +98,6 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
|
|
if iteration == opt.iterations:
|
|
|
progress_bar.close()
|
|
|
|
|
|
- # Keep track of max radii in image-space for pruning
|
|
|
- gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
|
|
-
|
|
|
# Log and save
|
|
|
training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
|
|
|
if (iteration in saving_iterations):
|
|
|
@@ -103,6 +106,8 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
|
|
|
|
|
# Densification
|
|
|
if iteration < opt.densify_until_iter:
|
|
|
+ # Keep track of max radii in image-space for pruning
|
|
|
+ gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
|
|
|
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
|
|
|
|
|
|
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
|
|
|
@@ -116,7 +121,10 @@ def training(dataset, opt, pipe, testing_iterations, saving_iterations):
|
|
|
if iteration < opt.iterations:
|
|
|
gaussians.optimizer.step()
|
|
|
gaussians.optimizer.zero_grad(set_to_none = True)
|
|
|
- gaussians.update_learning_rate(iteration)
|
|
|
+
|
|
|
+ if (iteration in checkpoint_iterations):
|
|
|
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
|
|
|
+ torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
|
|
|
|
|
|
def prepare_output_and_logger(args):
|
|
|
if not args.model_path:
|
|
|
@@ -189,6 +197,8 @@ if __name__ == "__main__":
|
|
|
parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
|
|
parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
|
|
|
parser.add_argument("--quiet", action="store_true")
|
|
|
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
|
|
|
+ parser.add_argument("--start_checkpoint", type=str, default = None)
|
|
|
args = parser.parse_args(sys.argv[1:])
|
|
|
args.save_iterations.append(args.iterations)
|
|
|
|
|
|
@@ -200,7 +210,7 @@ if __name__ == "__main__":
|
|
|
# Start GUI server, configure and run training
|
|
|
network_gui.init(args.ip, args.port)
|
|
|
torch.autograd.set_detect_anomaly(args.detect_anomaly)
|
|
|
- training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
|
|
|
+ training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint)
|
|
|
|
|
|
# All done
|
|
|
print("\nTraining complete.")
|