train.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #
  2. # Copyright (C) 2023, Inria
  3. # GRAPHDECO research group, https://team.inria.fr/graphdeco
  4. # All rights reserved.
  5. #
  6. # This software is free for non-commercial, research and evaluation use
  7. # under the terms of the LICENSE.md file.
  8. #
  9. # For inquiries contact george.drettakis@inria.fr
  10. #
  11. import os
  12. import torch
  13. from random import randint
  14. from utils.loss_utils import l1_loss, ssim
  15. from gaussian_renderer import render, network_gui
  16. import sys
  17. from scene import Scene, GaussianModel
  18. from utils.general_utils import safe_state
  19. import uuid
  20. from tqdm import tqdm
  21. from utils.image_utils import psnr
  22. from argparse import ArgumentParser, Namespace
  23. from arguments import ModelParams, PipelineParams, OptimizationParams
  24. try:
  25. from torch.utils.tensorboard import SummaryWriter
  26. TENSORBOARD_FOUND = True
  27. except ImportError:
  28. TENSORBOARD_FOUND = False
  29. def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from):
  30. first_iter = 0
  31. tb_writer = prepare_output_and_logger(dataset)
  32. gaussians = GaussianModel(dataset.sh_degree)
  33. scene = Scene(dataset, gaussians)
  34. gaussians.training_setup(opt)
  35. if checkpoint:
  36. (model_params, first_iter) = torch.load(checkpoint)
  37. gaussians.restore(model_params, opt)
  38. bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
  39. background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
  40. iter_start = torch.cuda.Event(enable_timing = True)
  41. iter_end = torch.cuda.Event(enable_timing = True)
  42. viewpoint_stack = None
  43. ema_loss_for_log = 0.0
  44. progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
  45. first_iter += 1
  46. for iteration in range(first_iter, opt.iterations + 1):
  47. if network_gui.conn == None:
  48. network_gui.try_connect()
  49. while network_gui.conn != None:
  50. try:
  51. net_image_bytes = None
  52. custom_cam, do_training, pipe.do_shs_python, pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive()
  53. if custom_cam != None:
  54. net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
  55. net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
  56. network_gui.send(net_image_bytes, dataset.source_path)
  57. if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
  58. break
  59. except Exception as e:
  60. network_gui.conn = None
  61. iter_start.record()
  62. gaussians.update_learning_rate(iteration)
  63. # Every 1000 its we increase the levels of SH up to a maximum degree
  64. if iteration % 1000 == 0:
  65. gaussians.oneupSHdegree()
  66. # Pick a random Camera
  67. if not viewpoint_stack:
  68. viewpoint_stack = scene.getTrainCameras().copy()
  69. viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
  70. # Render
  71. if (iteration - 1) == debug_from:
  72. pipe.debug = True
  73. render_pkg = render(viewpoint_cam, gaussians, pipe, background)
  74. image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
  75. # Loss
  76. gt_image = viewpoint_cam.original_image.cuda()
  77. Ll1 = l1_loss(image, gt_image)
  78. loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
  79. loss.backward()
  80. iter_end.record()
  81. with torch.no_grad():
  82. # Progress bar
  83. ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
  84. if iteration % 10 == 0:
  85. progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
  86. progress_bar.update(10)
  87. if iteration == opt.iterations:
  88. progress_bar.close()
  89. # Log and save
  90. training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
  91. if (iteration in saving_iterations):
  92. print("\n[ITER {}] Saving Gaussians".format(iteration))
  93. scene.save(iteration)
  94. # Densification
  95. if iteration < opt.densify_until_iter:
  96. # Keep track of max radii in image-space for pruning
  97. gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
  98. gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
  99. if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
  100. size_threshold = 20 if iteration > opt.opacity_reset_interval else None
  101. gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
  102. if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
  103. gaussians.reset_opacity()
  104. # Optimizer step
  105. if iteration < opt.iterations:
  106. gaussians.optimizer.step()
  107. gaussians.optimizer.zero_grad(set_to_none = True)
  108. if (iteration in checkpoint_iterations):
  109. print("\n[ITER {}] Saving Checkpoint".format(iteration))
  110. torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
  111. def prepare_output_and_logger(args):
  112. if not args.model_path:
  113. if os.getenv('OAR_JOB_ID'):
  114. unique_str=os.getenv('OAR_JOB_ID')
  115. else:
  116. unique_str = str(uuid.uuid4())
  117. args.model_path = os.path.join("./output/", unique_str[0:10])
  118. # Set up output folder
  119. print("Output folder: {}".format(args.model_path))
  120. os.makedirs(args.model_path, exist_ok = True)
  121. with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
  122. cfg_log_f.write(str(Namespace(**vars(args))))
  123. # Create Tensorboard writer
  124. tb_writer = None
  125. if TENSORBOARD_FOUND:
  126. tb_writer = SummaryWriter(args.model_path)
  127. else:
  128. print("Tensorboard not available: not logging progress")
  129. return tb_writer
  130. def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
  131. if tb_writer:
  132. tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
  133. tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
  134. tb_writer.add_scalar('iter_time', elapsed, iteration)
  135. # Report test and samples of training set
  136. if iteration in testing_iterations:
  137. torch.cuda.empty_cache()
  138. validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
  139. {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
  140. for config in validation_configs:
  141. if config['cameras'] and len(config['cameras']) > 0:
  142. l1_test = 0.0
  143. psnr_test = 0.0
  144. for idx, viewpoint in enumerate(config['cameras']):
  145. image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
  146. gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
  147. if tb_writer and (idx < 5):
  148. tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
  149. if iteration == testing_iterations[0]:
  150. tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
  151. l1_test += l1_loss(image, gt_image).mean().double()
  152. psnr_test += psnr(image, gt_image).mean().double()
  153. psnr_test /= len(config['cameras'])
  154. l1_test /= len(config['cameras'])
  155. print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
  156. if tb_writer:
  157. tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
  158. tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
  159. if tb_writer:
  160. tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
  161. tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
  162. torch.cuda.empty_cache()
  163. if __name__ == "__main__":
  164. # Set up command line argument parser
  165. parser = ArgumentParser(description="Training script parameters")
  166. lp = ModelParams(parser)
  167. op = OptimizationParams(parser)
  168. pp = PipelineParams(parser)
  169. parser.add_argument('--ip', type=str, default="127.0.0.1")
  170. parser.add_argument('--port', type=int, default=6009)
  171. parser.add_argument('--debug_from', type=int, default=-1)
  172. parser.add_argument('--detect_anomaly', action='store_true', default=False)
  173. parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
  174. parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
  175. parser.add_argument("--quiet", action="store_true")
  176. parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
  177. parser.add_argument("--start_checkpoint", type=str, default = None)
  178. args = parser.parse_args(sys.argv[1:])
  179. args.save_iterations.append(args.iterations)
  180. print("Optimizing " + args.model_path)
  181. # Initialize system state (RNG)
  182. safe_state(args.quiet)
  183. # Start GUI server, configure and run training
  184. network_gui.init(args.ip, args.port)
  185. torch.autograd.set_detect_anomaly(args.detect_anomaly)
  186. training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
  187. # All done
  188. print("\nTraining complete.")