__init__.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import random
  3. import json
  4. from utils.system_utils import searchForMaxIteration
  5. from scene.dataset_readers import sceneLoadTypeCallbacks
  6. from scene.gaussian_model import GaussianModel
  7. from arguments import ModelParams
  8. from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
  9. class Scene:
  10. gaussians : GaussianModel
  11. def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
  12. """b
  13. :param path: Path to colmap scene main folder.
  14. """
  15. self.model_path = args.model_path
  16. self.loaded_iter = None
  17. self.gaussians = gaussians
  18. if load_iteration:
  19. if load_iteration == -1:
  20. self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
  21. else:
  22. self.loaded_iter = load_iteration
  23. print("Loading trained model at iteration {}".format(self.loaded_iter))
  24. self.train_cameras = {}
  25. self.test_cameras = {}
  26. if os.path.exists(os.path.join(args.source_path, "sparse")):
  27. scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
  28. elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
  29. print("Found transforms_train.json file, assuming Blender data set!")
  30. scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
  31. else:
  32. assert False, "Could not recognize scene type!"
  33. if not self.loaded_iter:
  34. with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
  35. dest_file.write(src_file.read())
  36. json_cams = []
  37. camlist = []
  38. if scene_info.test_cameras:
  39. camlist.extend(scene_info.test_cameras)
  40. if scene_info.train_cameras:
  41. camlist.extend(scene_info.train_cameras)
  42. for id, cam in enumerate(camlist):
  43. json_cams.append(camera_to_JSON(id, cam))
  44. with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
  45. json.dump(json_cams, file)
  46. if shuffle:
  47. random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
  48. random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
  49. self.cameras_extent = scene_info.nerf_normalization["radius"]
  50. for resolution_scale in resolution_scales:
  51. print("Loading Training Cameras")
  52. self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
  53. print("Loading Test Cameras")
  54. self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
  55. if self.loaded_iter:
  56. self.gaussians.load_ply(os.path.join(self.model_path,
  57. "point_cloud",
  58. "iteration_" + str(self.loaded_iter),
  59. "point_cloud.ply"),
  60. og_number_points=len(scene_info.point_cloud.points))
  61. else:
  62. self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
  63. def save(self, iteration):
  64. point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
  65. self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
  66. def getTrainCameras(self, scale=1.0):
  67. return self.train_cameras[scale]
  68. def getTestCameras(self, scale=1.0):
  69. return self.test_cameras[scale]