Explorar o código

Merge branch 'release' into develop

bkerbl %!s(int64=2) %!d(string=hai) anos
pai
achega
7d8035ad10
Modificáronse 4 ficheiros con 15 adicións e 5 borrados
  1. 2 0
      README.md
  2. 1 0
      arguments/__init__.py
  3. 11 4
      scene/cameras.py
  4. 1 1
      utils/camera_utils.py

+ 2 - 0
README.md

@@ -115,6 +115,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
   Add this flag to use a MipNeRF360-style training/test split for evaluation.
   Add this flag to use a MipNeRF360-style training/test split for evaluation.
   #### --resolution / -r
   #### --resolution / -r
   Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
   Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.**
+  #### --data_device
+  Specifies where to put the source image data, ```cuda``` by default, recommended to use ```cpu``` if training on large/high-resolution dataset, will reduce VRAM consumption, but slightly slow down training.
   #### --white_background / -w
   #### --white_background / -w
   Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset.
   Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset.
   #### --sh_degree
   #### --sh_degree

+ 1 - 0
arguments/__init__.py

@@ -52,6 +52,7 @@ class ModelParams(ParamGroup):
         self._images = "images"
         self._images = "images"
         self._resolution = -1
         self._resolution = -1
         self._white_background = False
         self._white_background = False
+        self.data_device = "cuda"
         self.eval = False
         self.eval = False
         super().__init__(parser, "Loading Parameters", sentinel)
         super().__init__(parser, "Loading Parameters", sentinel)
 
 

+ 11 - 4
scene/cameras.py

@@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 class Camera(nn.Module):
 class Camera(nn.Module):
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
     def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
                  image_name, uid,
                  image_name, uid,
-                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                  ):
                  ):
         super(Camera, self).__init__()
         super(Camera, self).__init__()
 
 
@@ -29,14 +29,21 @@ class Camera(nn.Module):
         self.FoVy = FoVy
         self.FoVy = FoVy
         self.image_name = image_name
         self.image_name = image_name
 
 
-        self.original_image = image.clamp(0.0, 1.0).cuda()
+        try:
+            self.data_device = torch.device(data_device)
+        except Exception as e:
+            print(e)
+            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
+            self.data_device = torch.device("cuda")
+
+        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
         self.image_width = self.original_image.shape[2]
         self.image_width = self.original_image.shape[2]
         self.image_height = self.original_image.shape[1]
         self.image_height = self.original_image.shape[1]
 
 
         if gt_alpha_mask is not None:
         if gt_alpha_mask is not None:
-            self.original_image *= gt_alpha_mask.cuda()
+            self.original_image *= gt_alpha_mask.to(self.data_device)
         else:
         else:
-            self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
+            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
 
 
         self.zfar = 100.0
         self.zfar = 100.0
         self.znear = 0.01
         self.znear = 0.01

+ 1 - 1
utils/camera_utils.py

@@ -49,7 +49,7 @@ def loadCam(args, id, cam_info, resolution_scale):
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY, 
                   FoVx=cam_info.FovX, FoVy=cam_info.FovY, 
                   image=gt_image, gt_alpha_mask=loaded_mask,
                   image=gt_image, gt_alpha_mask=loaded_mask,
-                  image_name=cam_info.image_name, uid=id)
+                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)
 
 
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
 def cameraList_from_camInfos(cam_infos, resolution_scale, args):
     camera_list = []
     camera_list = []