__init__.py 4.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #
  2. # Copyright (C) 2023, Inria
  3. # GRAPHDECO research group, https://team.inria.fr/graphdeco
  4. # All rights reserved.
  5. #
  6. # This software is free for non-commercial, research and evaluation use
  7. # under the terms of the LICENSE.md file.
  8. #
  9. # For inquiries contact george.drettakis@inria.fr
  10. #
  11. import os
  12. import random
  13. import json
  14. from utils.system_utils import searchForMaxIteration
  15. from scene.dataset_readers import sceneLoadTypeCallbacks
  16. from scene.gaussian_model import GaussianModel
  17. from arguments import ModelParams
  18. from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
  19. class Scene:
  20. gaussians : GaussianModel
  21. def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
  22. """b
  23. :param path: Path to colmap scene main folder.
  24. """
  25. self.model_path = args.model_path
  26. self.loaded_iter = None
  27. self.gaussians = gaussians
  28. if load_iteration:
  29. if load_iteration == -1:
  30. self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
  31. else:
  32. self.loaded_iter = load_iteration
  33. print("Loading trained model at iteration {}".format(self.loaded_iter))
  34. self.train_cameras = {}
  35. self.test_cameras = {}
  36. if os.path.exists(os.path.join(args.source_path, "sparse")):
  37. scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
  38. elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
  39. print("Found transforms_train.json file, assuming Blender data set!")
  40. scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
  41. else:
  42. assert False, "Could not recognize scene type!"
  43. if not self.loaded_iter:
  44. with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
  45. dest_file.write(src_file.read())
  46. json_cams = []
  47. camlist = []
  48. if scene_info.test_cameras:
  49. camlist.extend(scene_info.test_cameras)
  50. if scene_info.train_cameras:
  51. camlist.extend(scene_info.train_cameras)
  52. for id, cam in enumerate(camlist):
  53. json_cams.append(camera_to_JSON(id, cam))
  54. with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
  55. json.dump(json_cams, file)
  56. if shuffle:
  57. random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
  58. random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
  59. self.cameras_extent = scene_info.nerf_normalization["radius"]
  60. for resolution_scale in resolution_scales:
  61. print("Loading Training Cameras")
  62. self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
  63. print("Loading Test Cameras")
  64. self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
  65. if self.loaded_iter:
  66. self.gaussians.load_ply(os.path.join(self.model_path,
  67. "point_cloud",
  68. "iteration_" + str(self.loaded_iter),
  69. "point_cloud.ply"),
  70. og_number_points=len(scene_info.point_cloud.points))
  71. else:
  72. self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
  73. def save(self, iteration):
  74. point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
  75. self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
  76. def getTrainCameras(self, scale=1.0):
  77. return self.train_cameras[scale]
  78. def getTestCameras(self, scale=1.0):
  79. return self.test_cameras[scale]