فهرست منبع

Merge branch 'release' into develop

bkerbl 2 سال پیش
والد
کامیت
7d8035ad10
4 فایل‌های تغییر یافته به همراه 15 افزوده شده و 5 حذف شده
  1. 2 0
      README.md
  2. 1 0
      arguments/__init__.py
  3. 11 4
      scene/cameras.py
  4. 1 1
      utils/camera_utils.py

+ 2 - 0
README.md

@@ -115,6 +115,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
   Add this flag to use a MipNeRF360-style training/test split for evaluation.
   #### --resolution / -r
   Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
+  #### --data_device
+  Specifies where to put the source image data; ```cuda``` by default. Using ```cpu``` is recommended when training on large/high-resolution datasets: it reduces VRAM consumption but slightly slows down training.
   #### --white_background / -w
   Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset.
   #### --sh_degree

+ 1 - 0
arguments/__init__.py

@@ -52,6 +52,7 @@ class ModelParams(ParamGroup):
         self._images = "images"
         self._resolution = -1
         self._white_background = False
+        self.data_device = "cuda"
         self.eval = False
         super().__init__(parser, "Loading Parameters", sentinel)
 

+ 11 - 4
scene/cameras.py

@@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                  ):
         super(Camera, self).__init__()
 
@@ -29,14 +29,21 @@ class Camera(nn.Module):
         self.FoVy = FoVy
         self.image_name = image_name
 
-        self.original_image = image.clamp(0.0, 1.0).cuda()
+        try:
+            self.data_device = torch.device(data_device)
+        except Exception as e:
+            print(e)
+            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
+            self.data_device = torch.device("cuda")
+
+        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
         self.image_width = self.original_image.shape[2]
         self.image_height = self.original_image.shape[1]
 
         if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.cuda()
+            self.original_image *= gt_alpha_mask.to(self.data_device)
         else:
-            self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
+            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
 
         self.zfar = 100.0
         self.znear = 0.01

+ 1 - 1
utils/camera_utils.py

@@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY, 
                   image=gt_image, gt_alpha_mask=loaded_mask,
-                  image_name=cam_info.image_name, uid=id)
+                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []