__init__.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #
  2. # Copyright (C) 2023, Inria
  3. # GRAPHDECO research group, https://team.inria.fr/graphdeco
  4. # All rights reserved.
  5. #
  6. # This software is free for non-commercial, research and evaluation use
  7. # under the terms of the LICENSE.md file.
  8. #
  9. # For inquiries contact george.drettakis@inria.fr
  10. #
  11. from argparse import ArgumentParser, Namespace
  12. import sys
  13. import os
  14. class GroupParams:
  15. pass
  16. class ParamGroup:
  17. def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
  18. group = parser.add_argument_group(name)
  19. for key, value in vars(self).items():
  20. shorthand = False
  21. if key.startswith("_"):
  22. shorthand = True
  23. key = key[1:]
  24. t = type(value)
  25. value = value if not fill_none else None
  26. if shorthand:
  27. if t == bool:
  28. group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
  29. else:
  30. group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
  31. else:
  32. if t == bool:
  33. group.add_argument("--" + key, default=value, action="store_true")
  34. else:
  35. group.add_argument("--" + key, default=value, type=t)
  36. def extract(self, args):
  37. group = GroupParams()
  38. for arg in vars(args).items():
  39. if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
  40. setattr(group, arg[0], arg[1])
  41. return group
  42. class ModelParams(ParamGroup):
  43. def __init__(self, parser, sentinel=False):
  44. self.sh_degree = 3
  45. self._source_path = ""
  46. self._model_path = ""
  47. self._images = "images"
  48. self._resolution = -1
  49. self._white_background = False
  50. self.data_device = "cuda"
  51. self.eval = False
  52. super().__init__(parser, "Loading Parameters", sentinel)
  53. def extract(self, args):
  54. g = super().extract(args)
  55. g.source_path = os.path.abspath(g.source_path)
  56. return g
  57. class PipelineParams(ParamGroup):
  58. def __init__(self, parser):
  59. self.convert_SHs_python = False
  60. self.compute_cov3D_python = False
  61. super().__init__(parser, "Pipeline Parameters")
  62. class OptimizationParams(ParamGroup):
  63. def __init__(self, parser):
  64. self.iterations = 30_000
  65. self.position_lr_init = 0.00016
  66. self.position_lr_final = 0.0000016
  67. self.position_lr_delay_mult = 0.01
  68. self.position_lr_max_steps = 30_000
  69. self.feature_lr = 0.0025
  70. self.opacity_lr = 0.05
  71. self.scaling_lr = 0.001
  72. self.rotation_lr = 0.001
  73. self.percent_dense = 0.01
  74. self.lambda_dssim = 0.2
  75. self.densification_interval = 100
  76. self.opacity_reset_interval = 3000
  77. self.densify_from_iter = 500
  78. self.densify_until_iter = 15_000
  79. self.densify_grad_threshold = 0.0002
  80. super().__init__(parser, "Optimization Parameters")
  81. def get_combined_args(parser : ArgumentParser):
  82. cmdlne_string = sys.argv[1:]
  83. cfgfile_string = "Namespace()"
  84. args_cmdline = parser.parse_args(cmdlne_string)
  85. try:
  86. cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
  87. print("Looking for config file in", cfgfilepath)
  88. with open(cfgfilepath) as cfg_file:
  89. print("Config file found: {}".format(cfgfilepath))
  90. cfgfile_string = cfg_file.read()
  91. except TypeError:
  92. print("Config file not found at")
  93. pass
  94. args_cfgfile = eval(cfgfile_string)
  95. merged_dict = vars(args_cfgfile).copy()
  96. for k,v in vars(args_cmdline).items():
  97. if v != None:
  98. merged_dict[k] = v
  99. return Namespace(**merged_dict)