|
@@ -17,7 +17,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
|
|
|
class Camera(nn.Module):
|
|
class Camera(nn.Module):
|
|
|
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
|
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
|
|
|
image_name, uid,
|
|
image_name, uid,
|
|
|
- trans=np.array([0.0, 0.0, 0.0]), scale=1.0
|
|
|
|
|
|
|
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
|
|
|
):
|
|
):
|
|
|
super(Camera, self).__init__()
|
|
super(Camera, self).__init__()
|
|
|
|
|
|
|
@@ -29,14 +29,21 @@ class Camera(nn.Module):
|
|
|
self.FoVy = FoVy
|
|
self.FoVy = FoVy
|
|
|
self.image_name = image_name
|
|
self.image_name = image_name
|
|
|
|
|
|
|
|
- self.original_image = image.clamp(0.0, 1.0).cuda()
|
|
|
|
|
|
|
+ try:
|
|
|
|
|
+ self.data_device = torch.device(data_device)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ print(e)
|
|
|
|
|
+ print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
|
|
|
|
|
+ self.data_device = torch.device("cuda")
|
|
|
|
|
+
|
|
|
|
|
+ self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
|
|
|
self.image_width = self.original_image.shape[2]
|
|
self.image_width = self.original_image.shape[2]
|
|
|
self.image_height = self.original_image.shape[1]
|
|
self.image_height = self.original_image.shape[1]
|
|
|
|
|
|
|
|
if gt_alpha_mask is not None:
|
|
if gt_alpha_mask is not None:
|
|
|
- self.original_image *= gt_alpha_mask.cuda()
|
|
|
|
|
|
|
+ self.original_image *= gt_alpha_mask.to(self.data_device)
|
|
|
else:
|
|
else:
|
|
|
- self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
|
|
|
|
|
|
|
+ self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
|
|
|
|
|
|
|
|
self.zfar = 100.0
|
|
self.zfar = 100.0
|
|
|
self.znear = 0.01
|
|
self.znear = 0.01
|