Browse Source

Provide --data_device option to put data on CPU to save VRAM for training (#14)

* Provide --data_on_cpu option to save VRAM for training

When there are many training images, such as in a large scene, most of the VRAM is used to store training data. Using --data_on_cpu can help reduce VRAM usage and make it possible to train on a GPU with less VRAM.

* Fix data_on_cpu  effect on default mask

* --data_on_cpu to --data_device

* update readme

* format warning infos
Pythonix Huang 2 năm trước cách đây
mục cha
commit
989320fdf2
4 tập tin đã thay đổi với 15 bổ sung5 xóa
  1. 2 0
      README.md
  2. 1 0
      arguments/__init__.py
  3. 11 4
      scene/cameras.py
  4. 1 1
      utils/camera_utils.py

+ 2 - 0
README.md

@@ -165,6 +165,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
   Space-separated iterations at which the training script saves the Gaussian model, ```7000 30000 <iterations>``` by default.
   #### --quiet 
   Flag to omit any text written to standard out pipe. 
+  #### --data_device
+  Specifies where to put the training image data, ```cuda``` by default. Using ```cpu``` is recommended when training on a large-scale or high-resolution dataset: it saves a lot of the VRAM required for training, but slightly slows down training.
 
 </details>
 <br>

+ 1 - 0
arguments/__init__.py

@@ -52,6 +52,7 @@ class ModelParams(ParamGroup):
         self._images = "images"
         self._resolution = -1
         self._white_background = False
+        self.data_device = "cuda"
         self.eval = False
         super().__init__(parser, "Loading Parameters", sentinel)
 

+ 11 - 4
scene/cameras.py

@@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                  ):
         super(Camera, self).__init__()
 
@@ -29,14 +29,21 @@ class Camera(nn.Module):
         self.FoVy = FoVy
         self.image_name = image_name
 
-        self.original_image = image.clamp(0.0, 1.0).cuda()
+        try:
+            self.data_device = torch.device(data_device)
+        except Exception as e:
+            print(e)
+            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
+            self.data_device = torch.device("cuda")
+
+        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
         self.image_width = self.original_image.shape[2]
         self.image_height = self.original_image.shape[1]
 
         if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.cuda()
+            self.original_image *= gt_alpha_mask.to(self.data_device)
         else:
-            self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
+            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
 
         self.zfar = 100.0
         self.znear = 0.01

+ 1 - 1
utils/camera_utils.py

@@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY, 
                   image=gt_image, gt_alpha_mask=loaded_mask,
-                  image_name=cam_info.image_name, uid=id)
+                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []