__init__.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from argparse import ArgumentParser, Namespace
  2. import sys
  3. import os
  4. class GroupParams:
  5. pass
  6. class ParamGroup:
  7. def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
  8. group = parser.add_argument_group(name)
  9. for key, value in vars(self).items():
  10. shorthand = False
  11. if key.startswith("_"):
  12. shorthand = True
  13. key = key[1:]
  14. t = type(value)
  15. value = value if not fill_none else None
  16. if shorthand:
  17. if t == bool:
  18. group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
  19. else:
  20. group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
  21. else:
  22. if t == bool:
  23. group.add_argument("--" + key, default=value, action="store_true")
  24. else:
  25. group.add_argument("--" + key, default=value, type=t)
  26. def extract(self, args):
  27. group = GroupParams()
  28. for arg in vars(args).items():
  29. if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
  30. setattr(group, arg[0], arg[1])
  31. return group
  32. class ModelParams(ParamGroup):
  33. def __init__(self, parser, sentinel=False):
  34. self.sh_degree = 3
  35. self._source_path = ""
  36. self._model_path = ""
  37. self._images = "images"
  38. self._resolution = 1
  39. self._white_background = False
  40. self.eval = False
  41. super().__init__(parser, "Loading Parameters", sentinel)
  42. class PipelineParams(ParamGroup):
  43. def __init__(self, parser):
  44. self.convert_SHs_python = False
  45. self.compute_cov3D_python = False
  46. super().__init__(parser, "Pipeline Parameters")
  47. class OptimizationParams(ParamGroup):
  48. def __init__(self, parser):
  49. self.iterations = 30_000
  50. self.position_lr_init = 0.00016
  51. self.position_lr_final = 0.0000016
  52. self.position_lr_delay_mult = 0.01
  53. self.posititon_lr_max_steps = 30_000
  54. self.feature_lr = 0.0025
  55. self.opacity_lr = 0.05
  56. self.scaling_lr = 0.001
  57. self.rotation_lr = 0.001
  58. self.percent_dense = 0.01
  59. self.lambda_dssim = 0.2
  60. self.densification_interval = 100
  61. self.opacity_reset_interval = 3000
  62. self.densify_from_iter = 500
  63. self.densify_until_iter = 15_000
  64. self.densify_grad_threshold = 0.0002
  65. super().__init__(parser, "Optimization Parameters")
  66. def get_combined_args(parser : ArgumentParser):
  67. cmdlne_string = sys.argv[1:]
  68. cfgfile_string = "Namespace()"
  69. args_cmdline = parser.parse_args(cmdlne_string)
  70. try:
  71. cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
  72. print("Looking for config file in", cfgfilepath)
  73. with open(cfgfilepath) as cfg_file:
  74. print("Config file found: {}".format(cfgfilepath))
  75. cfgfile_string = cfg_file.read()
  76. except TypeError:
  77. print("Config file not found at")
  78. pass
  79. args_cfgfile = eval(cfgfile_string)
  80. merged_dict = vars(args_cfgfile).copy()
  81. for k,v in vars(args_cmdline).items():
  82. if v != None:
  83. merged_dict[k] = v
  84. return Namespace(**merged_dict)