瀏覽代碼

Initial commit

bkerbl 2 年之前
當前提交
15d64e6781

+ 8 - 0
.gitignore

@@ -0,0 +1,8 @@
+*.pyc
+.vscode
+output
+build
+diff_rasterization/diff_rast.egg-info
+diff_rasterization/dist
+tensorboard_3d
+screenshots

+ 12 - 0
.gitmodules

@@ -0,0 +1,12 @@
+[submodule "submodules/diff-gaussian-rasterization"]
+	path = submodules/diff-gaussian-rasterization
+	url = https://gitlab.inria.fr/bkerbl/diff-gaussian-rasterization.git
+[submodule "submodules/simple-knn"]
+	path = submodules/simple-knn
+	url = https://gitlab.inria.fr/bkerbl/simple-knn.git
+[submodule "SIBR_viewers_windows"]
+	path = SIBR_viewers_windows
+	url = https://gitlab.inria.fr/sibr/sibr_core.git
+[submodule "SIBR_viewers_linux"]
+	path = SIBR_viewers_linux
+	url = https://gitlab.inria.fr/sibr/sibr_core.git

文件差異過大導致無法顯示
+ 203 - 0
README.md


+ 93 - 0
arguments/__init__.py

@@ -0,0 +1,93 @@
+from argparse import ArgumentParser, Namespace
+import sys
+import os
+
+class GroupParams:
+    pass
+
+class ParamGroup:
+    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
+        group = parser.add_argument_group(name)
+        for key, value in vars(self).items():
+            shorthand = False
+            if key.startswith("_"):
+                shorthand = True
+                key = key[1:]
+            t = type(value)
+            value = value if not fill_none else None 
+            if shorthand:
+                if t == bool:
+                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
+                else:
+                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
+            else:
+                if t == bool:
+                    group.add_argument("--" + key, default=value, action="store_true")
+                else:
+                    group.add_argument("--" + key, default=value, type=t)
+
+    def extract(self, args):
+        group = GroupParams()
+        for arg in vars(args).items():
+            if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
+                setattr(group, arg[0], arg[1])
+        return group
+
+class ModelParams(ParamGroup): 
+    def __init__(self, parser, sentinel=False):
+        self.sh_degree = 3
+        self._source_path = ""
+        self._model_path = ""
+        self._images = "images"
+        self._resolution = 1
+        self._white_background = False
+        self.eval = False
+        super().__init__(parser, "Loading Parameters", sentinel)
+
+class PipelineParams(ParamGroup):
+    def __init__(self, parser):
+        self.convert_SHs_python = False
+        self.compute_cov3D_python = False
+        super().__init__(parser, "Pipeline Parameters")
+
+class OptimizationParams(ParamGroup):
+    def __init__(self, parser):
+        self.iterations = 30_000
+        self.position_lr_init = 0.00016
+        self.position_lr_final = 0.0000016
+        self.position_lr_delay_mult = 0.01
+        self.posititon_lr_max_steps = 30_000
+        self.feature_lr = 0.0025
+        self.opacity_lr = 0.05
+        self.scaling_lr = 0.001
+        self.rotation_lr = 0.001
+        self.percent_dense = 0.01
+        self.lambda_dssim = 0.2
+        self.densification_interval = 100
+        self.opacity_reset_interval = 3000
+        self.densify_from_iter = 500
+        self.densify_until_iter = 15_000
+        self.densify_grad_threshold = 0.0002
+        super().__init__(parser, "Optimization Parameters")
+
+def get_combined_args(parser : ArgumentParser):
+    cmdlne_string = sys.argv[1:]
+    cfgfile_string = "Namespace()"
+    args_cmdline = parser.parse_args(cmdlne_string)
+
+    try:
+        cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
+        print("Looking for config file in", cfgfilepath)
+        with open(cfgfilepath) as cfg_file:
+            print("Config file found: {}".format(cfgfilepath))
+            cfgfile_string = cfg_file.read()
+    except TypeError:
+        print("Config file not found at")
+        pass
+    args_cfgfile = eval(cfgfile_string)
+
+    merged_dict = vars(args_cfgfile).copy()
+    for k,v in vars(args_cmdline).items():
+        if v != None:
+            merged_dict[k] = v
+    return Namespace(**merged_dict)

二進制
assets/logo_graphdeco.png


二進制
assets/logo_inria.png


二進制
assets/logo_mpi.png


文件差異過大導致無法顯示
+ 488 - 0
assets/logo_mpi.svg


二進制
assets/logo_uca.png


二進制
assets/teaser.png


+ 85 - 0
convert.py

@@ -0,0 +1,85 @@
+import os
+from argparse import ArgumentParser
+import shutil
+
+# This Python script is based on the shell converter script provided in the MipNerF 360 repository.
+parser = ArgumentParser("Colmap converter")
+parser.add_argument("--no_gpu", action='store_true')
+parser.add_argument("--source_path", "-s", required=True, type=str)
+parser.add_argument("--camera", default="OPENCV", type=str)
+parser.add_argument("--colmap_executable", default="", type=str)
+parser.add_argument("--resize", action="store_true")
+parser.add_argument("--magick_executable", default="", type=str)
+args = parser.parse_args()
+colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
+magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
+use_gpu = 1 if not args.no_gpu else 0
+
+os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
+
+## Feature extraction
+os.system(colmap_command + " feature_extractor "\
+    "--database_path " + args.source_path + "/distorted/database.db \
+    --image_path " + args.source_path + "/input \
+    --ImageReader.single_camera 1 \
+    --ImageReader.camera_model " + args.camera + " \
+    --SiftExtraction.use_gpu " + str(use_gpu))
+
+## Feature matching
+os.system(colmap_command + " exhaustive_matcher \
+    --database_path " + args.source_path + "/distorted/database.db \
+    --SiftMatching.use_gpu " + str(use_gpu))
+
+### Bundle adjustment
+# The default Mapper tolerance is unnecessarily large,
+# decreasing it speeds up bundle adjustment steps.
+os.system(colmap_command + " mapper \
+    --database_path " + args.source_path + "/distorted/database.db \
+    --image_path "  + args.source_path + "/input \
+    --output_path "  + args.source_path + "/distorted/sparse \
+    --Mapper.ba_global_function_tolerance=0.000001")
+
+### Image undistortion
+## We need to undistort our images into ideal pinhole intrinsics.
+os.system(colmap_command + " image_undistorter \
+    --image_path " + args.source_path + "/input \
+    --input_path " + args.source_path + "/distorted/sparse/0 \
+    --output_path " + args.source_path + "\
+    --output_type COLMAP")
+
+files = os.listdir(args.source_path + "/sparse")
+os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
+# Copy each file from the source directory to the destination directory
+for file in files:
+    if file == '0':
+        continue
+    source_file = os.path.join(args.source_path, "sparse", file)
+    destination_file = os.path.join(args.source_path, "sparse", "0", file)
+    shutil.move(source_file, destination_file)
+
+if(args.resize):
+    print("Copying and resizing...")
+
+    # Resize images.
+    os.makedirs(args.source_path + "/images_2", exist_ok=True)
+    os.makedirs(args.source_path + "/images_4", exist_ok=True)
+    os.makedirs(args.source_path + "/images_8", exist_ok=True)
+    # Get the list of files in the source directory
+    files = os.listdir(args.source_path + "/images")
+    # Copy each file from the source directory to the destination directory
+    for file in files:
+        source_file = os.path.join(args.source_path, "images", file)
+
+        destination_file = os.path.join(args.source_path, "images_2", file)
+        shutil.copy2(source_file, destination_file)
+        os.system(magick_command + " mogrify -resize 50% " + destination_file)
+
+        destination_file = os.path.join(args.source_path, "images_4", file)
+        shutil.copy2(source_file, destination_file)
+        os.system(magick_command + " mogrify -resize 25% " + destination_file)
+
+        destination_file = os.path.join(args.source_path, "images_8", file)
+        shutil.copy2(source_file, destination_file)
+        os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
+
+print("Done.")

+ 19 - 0
environment_full.yml

@@ -0,0 +1,19 @@
+name: gaussian_splatting
+channels:
+  - pytorch
+  - conda-forge
+  - defaults
+dependencies:
+  - cudatoolkit=11.6
+  - cudatoolkit-dev=11.6
+  - cxx-compiler=1.3.0
+  - plyfile=0.8.1
+  - python=3.7.13
+  - pip=22.3.1
+  - pytorch=1.12.1
+  - torchaudio=0.12.1
+  - torchvision=0.13.1
+  - tqdm
+  - pip:
+    - submodules/diff-gaussian-rasterization
+    - submodules/simple-knn

+ 17 - 0
environment_light.yml

@@ -0,0 +1,17 @@
+name: gaussian_splatting
+channels:
+  - pytorch
+  - conda-forge
+  - defaults
+dependencies:
+  - cudatoolkit=11.6
+  - plyfile=0.8.1
+  - python=3.7.13
+  - pip=22.3.1
+  - pytorch=1.12.1
+  - torchaudio=0.12.1
+  - torchvision=0.13.1
+  - tqdm
+  - pip:
+    - submodules/diff-gaussian-rasterization
+    - submodules/simple-knn

+ 52 - 0
full_eval.py

@@ -0,0 +1,52 @@
+import os
+from argparse import ArgumentParser
+
+mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
+mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
+tanks_and_temples_scenes = ["truck", "train"]
+deep_blending_scenes = ["drjohnson", "playroom"]
+
+parser = ArgumentParser(description="Full evaluation script parameters")
+parser.add_argument("--skip_training", action="store_true")
+parser.add_argument("--skip_rendering", action="store_true")
+parser.add_argument("--skip_metrics", action="store_true")
+args, _ = parser.parse_known_args()
+
+if not args.skip_training:
+    parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
+    parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
+    parser.add_argument("--deepblending", "-db", required=True, type=str)
+    args = parser.parse_args()
+
+    common_args = " --quiet --eval --test_iterations -1"
+    for scene in tanks_and_temples_scenes:
+        source = args.tanksandtemples + "/" + scene
+        os.system("python train.py -s " + source + " -m ./eval/" + scene + common_args)
+    for scene in deep_blending_scenes:
+        source = args.deepblending + "/" + scene
+        os.system("python train.py -s " + source + " -m ./eval/" + scene + common_args)
+    for scene in mipnerf360_outdoor_scenes:
+        source = args.mipnerf360 + "/" + scene
+        os.system("python train.py -s " + source + " -i images_4 -m ./eval/" + scene + common_args)
+    for scene in mipnerf360_indoor_scenes:
+        source = args.mipnerf360 + "/" + scene
+        os.system("python train.py -s " + source + " -i images_2 -m ./eval/" + scene + common_args)
+
+all_scenes = []
+all_scenes.extend(mipnerf360_outdoor_scenes)
+all_scenes.extend(mipnerf360_indoor_scenes)
+all_scenes.extend(tanks_and_temples_scenes)
+all_scenes.extend(deep_blending_scenes)
+
+if not args.skip_rendering:
+    for scene in all_scenes:
+        os.system("python render.py --quiet --skip_train --eval --iteration 7000 -m ./eval/" + scene)
+    for scene in all_scenes:
+        os.system("python render.py --quiet --skip_train --eval --iteration 30000 -m ./eval/" + scene)
+
+if not args.skip_metrics:
+    scenes_string = ""
+    for scene in all_scenes:
+        scenes_string += "\"" + "./eval/" + scene + "\" "
+
+    os.system("python metrics.py -m " + scenes_string)

+ 88 - 0
gaussian_renderer/__init__.py

@@ -0,0 +1,88 @@
+import torch
+import math
+from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
+from scene.gaussian_model import GaussianModel
+from utils.sh_utils import eval_sh
+
+def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):
+    """
+    Render the scene. 
+    
+    Background tensor (bg_color) must be on GPU!
+    """
+ 
+    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
+    try:
+        screenspace_points.retain_grad()
+    except:
+        pass
+
+    # Set up rasterization configuration
+    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+    raster_settings = GaussianRasterizationSettings(
+        image_height=int(viewpoint_camera.image_height),
+        image_width=int(viewpoint_camera.image_width),
+        tanfovx=tanfovx,
+        tanfovy=tanfovy,
+        bg=bg_color,
+        scale_modifier=scaling_modifier,
+        viewmatrix=viewpoint_camera.world_view_transform,
+        projmatrix=viewpoint_camera.full_proj_transform,
+        sh_degree=pc.active_sh_degree,
+        campos=viewpoint_camera.camera_center,
+        prefiltered=False
+    )
+
+    rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+    means3D = pc.get_xyz
+    means2D = screenspace_points
+    opacity = pc.get_opacity
+
+    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+    # scaling / rotation by the rasterizer.
+    scales = None
+    rotations = None
+    cov3D_precomp = None
+    if pipe.compute_cov3D_python:
+        cov3D_precomp = pc.get_covariance(scaling_modifier)
+    else:
+        scales = pc.get_scaling
+        rotations = pc.get_rotation
+
+    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+    shs = None
+    colors_precomp = None
+    if colors_precomp is None:
+        if pipe.convert_SHs_python:
+            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
+            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
+            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
+            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+        else:
+            shs = pc.get_features
+    else:
+        colors_precomp = override_color
+
+    # Rasterize visible Gaussians to image, obtain their radii (on screen). 
+    rendered_image, radii = rasterizer(
+        means3D = means3D,
+        means2D = means2D,
+        shs = shs,
+        colors_precomp = colors_precomp,
+        opacities = opacity,
+        scales = scales,
+        rotations = rotations,
+        cov3D_precomp = cov3D_precomp)
+
+    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+    # They will be excluded from value updates used in the splitting criteria.
+    return {"render": rendered_image,
+            "viewspace_points": screenspace_points,
+            "visibility_filter" : radii > 0,
+            "radii": radii}

+ 75 - 0
gaussian_renderer/network_gui.py

@@ -0,0 +1,75 @@
+import torch
+import traceback
+import socket
+import json
+from scene.cameras import MiniCam
+
+host = "127.0.0.1"
+port = 6009
+
+conn = None
+addr = None
+
+listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+def init(wish_host, wish_port):
+    global host, port, listener
+    host = wish_host
+    port = wish_port
+    listener.bind((host, port))
+    listener.listen()
+    listener.settimeout(0)
+
+def try_connect():
+    global conn, addr, listener
+    try:
+        conn, addr = listener.accept()
+        print(f"\nConnected by {addr}")
+        conn.settimeout(None)
+    except Exception as inst:
+        pass
+            
+def read():
+    global conn
+    messageLength = conn.recv(4)
+    messageLength = int.from_bytes(messageLength, 'little')
+    message = conn.recv(messageLength)
+    return json.loads(message.decode("utf-8"))
+
+def send(message_bytes, verify):
+    global conn
+    if message_bytes != None:
+        conn.sendall(message_bytes)
+    conn.sendall(len(verify).to_bytes(4, 'little'))
+    conn.sendall(bytes(verify, 'ascii'))
+
+def receive():
+    message = read()
+
+    width = message["resolution_x"]
+    height = message["resolution_y"]
+
+    if width != 0 and height != 0:
+        try:
+            do_training = bool(message["train"])
+            fovy = message["fov_y"]
+            fovx = message["fov_x"]
+            znear = message["z_near"]
+            zfar = message["z_far"]
+            do_shs_python = bool(message["shs_python"])
+            do_rot_scale_python = bool(message["rot_scale_python"])
+            keep_alive = bool(message["keep_alive"])
+            scaling_modifier = message["scaling_modifier"]
+            world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
+            world_view_transform[:,1] = -world_view_transform[:,1]
+            world_view_transform[:,2] = -world_view_transform[:,2]
+            full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
+            full_proj_transform[:,1] = -full_proj_transform[:,1]
+            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
+        except Exception as e:
+            print("")
+            traceback.print_exc()
+            raise e
+        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
+    else:
+        return None, None, None, None, None, None

+ 21 - 0
lpipsPyTorch/__init__.py

@@ -0,0 +1,21 @@
+import torch
+
+from .modules.lpips import LPIPS
+
+
+def lpips(x: torch.Tensor,
+          y: torch.Tensor,
+          net_type: str = 'alex',
+          version: str = '0.1'):
+    r"""Function that measures
+    Learned Perceptual Image Patch Similarity (LPIPS).
+
+    Arguments:
+        x, y (torch.Tensor): the input tensors to compare.
+        net_type (str): the network type to compare the features: 
+                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+        version (str): the version of LPIPS. Default: 0.1.
+    """
+    device = x.device
+    criterion = LPIPS(net_type, version).to(device)
+    return criterion(x, y)

+ 36 - 0
lpipsPyTorch/modules/lpips.py

@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+from .networks import get_network, LinLayers
+from .utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+    r"""Creates a criterion that measures
+    Learned Perceptual Image Patch Similarity (LPIPS).
+
+    Arguments:
+        net_type (str): the network type to compare the features: 
+                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+        version (str): the version of LPIPS. Default: 0.1.
+    """
+    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+        assert version in ['0.1'], 'v0.1 is only supported now'
+
+        super(LPIPS, self).__init__()
+
+        # pretrained network
+        self.net = get_network(net_type)
+
+        # linear layers
+        self.lin = LinLayers(self.net.n_channels_list)
+        self.lin.load_state_dict(get_state_dict(net_type, version))
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        feat_x, feat_y = self.net(x), self.net(y)
+
+        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+        return torch.sum(torch.cat(res, 0), 0, True)

+ 96 - 0
lpipsPyTorch/modules/networks.py

@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from .utils import normalize_activation
+
+
+def get_network(net_type: str):
+    if net_type == 'alex':
+        return AlexNet()
+    elif net_type == 'squeeze':
+        return SqueezeNet()
+    elif net_type == 'vgg':
+        return VGG16()
+    else:
+        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+    def __init__(self, n_channels_list: Sequence[int]):
+        super(LinLayers, self).__init__([
+            nn.Sequential(
+                nn.Identity(),
+                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+            ) for nc in n_channels_list
+        ])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+    def __init__(self):
+        super(BaseNet, self).__init__()
+
+        # register buffer
+        self.register_buffer(
+            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer(
+            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+    def set_requires_grad(self, state: bool):
+        for param in chain(self.parameters(), self.buffers()):
+            param.requires_grad = state
+
+    def z_score(self, x: torch.Tensor):
+        return (x - self.mean) / self.std
+
+    def forward(self, x: torch.Tensor):
+        x = self.z_score(x)
+
+        output = []
+        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+            x = layer(x)
+            if i in self.target_layers:
+                output.append(normalize_activation(x))
+            if len(output) == len(self.target_layers):
+                break
+        return output
+
+
+class SqueezeNet(BaseNet):
+    def __init__(self):
+        super(SqueezeNet, self).__init__()
+
+        self.layers = models.squeezenet1_1(True).features
+        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+        self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+    def __init__(self):
+        super(AlexNet, self).__init__()
+
+        self.layers = models.alexnet(True).features
+        self.target_layers = [2, 5, 8, 10, 12]
+        self.n_channels_list = [64, 192, 384, 256, 256]
+
+        self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+    def __init__(self):
+        super(VGG16, self).__init__()
+
+        self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
+        self.target_layers = [4, 9, 16, 23, 30]
+        self.n_channels_list = [64, 128, 256, 512, 512]
+
+        self.set_requires_grad(False)

+ 30 - 0
lpipsPyTorch/modules/utils.py

@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+    # build url
+    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+        + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+    # download
+    old_state_dict = torch.hub.load_state_dict_from_url(
+        url, progress=True,
+        map_location=None if torch.cuda.is_available() else torch.device('cpu')
+    )
+
+    # rename keys
+    new_state_dict = OrderedDict()
+    for key, val in old_state_dict.items():
+        new_key = key
+        new_key = new_key.replace('lin', '')
+        new_key = new_key.replace('model.', '')
+        new_state_dict[new_key] = val
+
+    return new_state_dict

+ 87 - 0
metrics.py

@@ -0,0 +1,87 @@
+from pathlib import Path
+import os
+from PIL import Image
+import torch
+import torchvision.transforms.functional as tf
+from utils.loss_utils import ssim
+from lpipsPyTorch import lpips
+import json
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser
+
+def readImages(renders_dir, gt_dir):
+    renders = []
+    gts = []
+    image_names = []
+    for fname in os.listdir(renders_dir):
+        render = Image.open(renders_dir / fname)
+        gt = Image.open(gt_dir / fname)
+        renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
+        gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
+        image_names.append(fname)
+    return renders, gts, image_names
+
+def evaluate(model_paths):
+
+    full_dict = {}
+    per_view_dict = {}
+    full_dict_polytopeonly = {}
+    per_view_dict_polytopeonly = {}
+
+    for scene_dir in model_paths:
+        print("Scene:", scene_dir)
+        full_dict[scene_dir] = {}
+        per_view_dict[scene_dir] = {}
+        full_dict_polytopeonly[scene_dir] = {}
+        per_view_dict_polytopeonly[scene_dir] = {}
+
+        test_dir = Path(scene_dir) / "test"
+
+        for method in os.listdir(test_dir):
+            print("Method:", method)
+
+            full_dict[scene_dir][method] = {}
+            per_view_dict[scene_dir][method] = {}
+            full_dict_polytopeonly[scene_dir][method] = {}
+            per_view_dict_polytopeonly[scene_dir][method] = {}
+
+            method_dir = test_dir / method
+            gt_dir = method_dir/ "gt"
+            renders_dir = method_dir / "renders"
+            renders, gts, image_names = readImages(renders_dir, gt_dir)
+
+            ssims = []
+            psnrs = []
+            lpipss = []
+
+            for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
+                ssims.append(ssim(renders[idx], gts[idx]))
+                psnrs.append(psnr(renders[idx], gts[idx]))
+                lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
+
+            print("SSIM: {}".format(torch.tensor(ssims).mean()))
+            print("PSNR: {}".format(torch.tensor(psnrs).mean()))
+            print("LPIPS: {}".format(torch.tensor(lpipss).mean()))
+
+            full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
+                                                    "PSNR": torch.tensor(psnrs).mean().item(),
+                                                    "LPIPS": torch.tensor(lpipss).mean().item()})
+            per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
+                                                        "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
+                                                        "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
+
+        with open(scene_dir + "/results.json", 'w') as fp:
+            json.dump(full_dict[scene_dir], fp, indent=True)
+        with open(scene_dir + "/per_view.json", 'w') as fp:
+            json.dump(per_view_dict[scene_dir], fp, indent=True)
+
+if __name__ == "__main__":
+    device = torch.device("cuda:0")
+    torch.cuda.set_device(device)
+
+    # Set up command line argument parser
+    parser = ArgumentParser(description="Training script parameters")
+    parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
+    args = parser.parse_args()
+    evaluate(args.model_paths)

+ 55 - 0
render.py

@@ -0,0 +1,55 @@
+import torch
+from scene import Scene
+import os
+from tqdm import tqdm
+from os import makedirs
+from gaussian_renderer import render
+import torchvision
+from utils.general_utils import safe_state
+from argparse import ArgumentParser
+from arguments import ModelParams, PipelineParams, get_combined_args
+from gaussian_renderer import GaussianModel
+
+def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
+    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+    gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
+
+    makedirs(render_path, exist_ok=True)
+    makedirs(gts_path, exist_ok=True)
+
+    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+        rendering = render(view, gaussians, pipeline, background)["render"]
+        gt = view.original_image[0:3, :, :]
+        torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+        torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
+
+def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
+    with torch.no_grad():
+        gaussians = GaussianModel(dataset.sh_degree)
+        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+        bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
+        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+        if not skip_train:
+             render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
+
+        if not skip_test:
+             render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
+
+if __name__ == "__main__":
+    # Set up command line argument parser
+    parser = ArgumentParser(description="Testing script parameters")
+    model = ModelParams(parser, sentinel=True)
+    pipeline = PipelineParams(parser)
+    parser.add_argument("--iteration", default=-1, type=int)
+    parser.add_argument("--skip_train", action="store_true")
+    parser.add_argument("--skip_test", action="store_true")
+    parser.add_argument("--quiet", action="store_true")
+    args = get_combined_args(parser)
+    print("Rendering " + args.model_path)
+
+    # Initialize system state (RNG)
+    safe_state(args.quiet)
+
+    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)

+ 83 - 0
scene/__init__.py

@@ -0,0 +1,83 @@
+import os
+import random
+import json
+from utils.system_utils import searchForMaxIteration
+from scene.dataset_readers import sceneLoadTypeCallbacks
+from scene.gaussian_model import GaussianModel
+from arguments import ModelParams
+from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
+
+class Scene:
+
+    gaussians : GaussianModel
+
+    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
+        """b
+        :param path: Path to colmap scene main folder.
+        """
+        self.model_path = args.model_path
+        self.loaded_iter = None
+        self.gaussians = gaussians
+
+        if load_iteration:
+            if load_iteration == -1:
+                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
+            else:
+                self.loaded_iter = load_iteration
+            print("Loading trained model at iteration {}".format(self.loaded_iter))
+
+        self.train_cameras = {}
+        self.test_cameras = {}
+
+        if os.path.exists(os.path.join(args.source_path, "sparse")):
+            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
+        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
+            print("Found transforms_train.json file, assuming Blender data set!")
+            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
+        else:
+            assert False, "Could not recognize scene type!"
+
+        if not self.loaded_iter:
+            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
+                dest_file.write(src_file.read())
+            json_cams = []
+            camlist = []
+            if scene_info.test_cameras:
+                camlist.extend(scene_info.test_cameras)
+            if scene_info.train_cameras:
+                camlist.extend(scene_info.train_cameras)
+            for id, cam in enumerate(camlist):
+                json_cams.append(camera_to_JSON(id, cam))
+            with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
+                json.dump(json_cams, file)
+
+        if shuffle:
+            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
+            random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling
+
+        self.cameras_extent = scene_info.nerf_normalization["radius"]
+
+        for resolution_scale in resolution_scales:
+            print("Loading Training Cameras")
+            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
+            print("Loading Test Cameras")
+            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
+
+        if self.loaded_iter:
+            self.gaussians.load_ply(os.path.join(self.model_path,
+                                                           "point_cloud",
+                                                           "iteration_" + str(self.loaded_iter),
+                                                           "point_cloud.ply"),
+                                              og_number_points=len(scene_info.point_cloud.points))
+        else:
+            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
+
+    def save(self, iteration):
+        point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
+        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+
+    def getTrainCameras(self, scale=1.0):
+        return self.train_cameras[scale]
+
+    def getTestCameras(self, scale=1.0):
+        return self.test_cameras[scale]

+ 53 - 0
scene/cameras.py

@@ -0,0 +1,53 @@
+import torch
+from torch import nn
+import numpy as np
+from utils.graphics_utils import getWorld2View2, getProjectionMatrix
+
+class Camera(nn.Module):
+    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
+                 image_name, uid,
+                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0
+                 ):
+        super(Camera, self).__init__()
+
+        self.uid = uid
+        self.colmap_id = colmap_id
+        self.R = R
+        self.T = T
+        self.FoVx = FoVx
+        self.FoVy = FoVy
+        self.image_name = image_name
+
+        self.original_image = image.clamp(0.0, 1.0).cuda()
+        self.image_width = self.original_image.shape[2]
+        self.image_height = self.original_image.shape[1]
+
+        if gt_alpha_mask is not None:
+            self.original_image *= gt_alpha_mask.cuda()
+        else:
+            self.original_image *= torch.ones((1, self.image_height, self.image_width), device="cuda")
+
+        self.zfar = 100.0
+        self.znear = 0.01
+
+        self.trans = trans
+        self.scale = scale
+
+        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
+        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
+        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
+        self.camera_center = self.world_view_transform.inverse()[3, :3]
+
+class MiniCam:
+    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
+        self.image_width = width
+        self.image_height = height    
+        self.FoVy = fovy
+        self.FoVx = fovx
+        self.znear = znear
+        self.zfar = zfar
+        self.world_view_transform = world_view_transform
+        self.full_proj_transform = full_proj_transform
+        view_inv = torch.inverse(self.world_view_transform)
+        self.camera_center = view_inv[3][:3]
+

+ 271 - 0
scene/colmap_loader.py

@@ -0,0 +1,271 @@
+import numpy as np
+import collections
+import struct
+
+CameraModel = collections.namedtuple(
+    "CameraModel", ["model_id", "model_name", "num_params"])
+Camera = collections.namedtuple(
+    "Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
+Point3D = collections.namedtuple(
+    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
+CAMERA_MODELS = {
+    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+    CameraModel(model_id=7, model_name="FOV", num_params=5),
+    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
+}
+CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
+                         for camera_model in CAMERA_MODELS])
+CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
+                           for camera_model in CAMERA_MODELS])
+
+
+def qvec2rotmat(qvec):
+    return np.array([
+        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
+         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
+        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
+         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
+        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
+
+def rotmat2qvec(R):
+    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+    K = np.array([
+        [Rxx - Ryy - Rzz, 0, 0, 0],
+        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+    eigvals, eigvecs = np.linalg.eigh(K)
+    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+    if qvec[0] < 0:
+        qvec *= -1
+    return qvec
+
+class Image(BaseImage):
+    def qvec2rotmat(self):
+        return qvec2rotmat(self.qvec)
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+    """Read and unpack the next bytes from a binary file.
+    :param fid:
+    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+    :param endian_character: Any of {@, =, <, >, !}
+    :return: Tuple of read and unpacked values.
+    """
+    data = fid.read(num_bytes)
+    return struct.unpack(endian_character + format_char_sequence, data)
+
+def read_points3D_text(path):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::ReadPoints3DText(const std::string& path)
+        void Reconstruction::WritePoints3DText(const std::string& path)
+    """
+    xyzs = None
+    rgbs = None
+    errors = None
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                elems = line.split()
+                xyz = np.array(tuple(map(float, elems[1:4])))
+                rgb = np.array(tuple(map(int, elems[4:7])))
+                error = np.array(float(elems[7]))
+                if xyzs is None:
+                    xyzs = xyz[None, ...]
+                    rgbs = rgb[None, ...]
+                    errors = error[None, ...]
+                else:
+                    xyzs = np.append(xyzs, xyz[None, ...], axis=0)
+                    rgbs = np.append(rgbs, rgb[None, ...], axis=0)
+                    errors = np.append(errors, error[None, ...], axis=0)
+    return xyzs, rgbs, errors
+
+def read_points3D_binary(path_to_model_file):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::ReadPoints3DBinary(const std::string& path)
+        void Reconstruction::WritePoints3DBinary(const std::string& path)
+    """
+
+
+    with open(path_to_model_file, "rb") as fid:
+        num_points = read_next_bytes(fid, 8, "Q")[0]
+
+        xyzs = np.empty((num_points, 3))
+        rgbs = np.empty((num_points, 3))
+        errors = np.empty((num_points, 1))
+
+        for p_id in range(num_points):
+            binary_point_line_properties = read_next_bytes(
+                fid, num_bytes=43, format_char_sequence="QdddBBBd")
+            xyz = np.array(binary_point_line_properties[1:4])
+            rgb = np.array(binary_point_line_properties[4:7])
+            error = np.array(binary_point_line_properties[7])
+            track_length = read_next_bytes(
+                fid, num_bytes=8, format_char_sequence="Q")[0]
+            track_elems = read_next_bytes(
+                fid, num_bytes=8*track_length,
+                format_char_sequence="ii"*track_length)
+            xyzs[p_id] = xyz
+            rgbs[p_id] = rgb
+            errors[p_id] = error
+    return xyzs, rgbs, errors
+
+def read_intrinsics_text(path):
+    """
+    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+    """
+    cameras = {}
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                elems = line.split()
+                camera_id = int(elems[0])
+                model = elems[1]
+                assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE"
+                width = int(elems[2])
+                height = int(elems[3])
+                params = np.array(tuple(map(float, elems[4:])))
+                cameras[camera_id] = Camera(id=camera_id, model=model,
+                                            width=width, height=height,
+                                            params=params)
+    return cameras
+
+def read_extrinsics_binary(path_to_model_file):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::ReadImagesBinary(const std::string& path)
+        void Reconstruction::WriteImagesBinary(const std::string& path)
+    """
+    images = {}
+    with open(path_to_model_file, "rb") as fid:
+        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+        for _ in range(num_reg_images):
+            binary_image_properties = read_next_bytes(
+                fid, num_bytes=64, format_char_sequence="idddddddi")
+            image_id = binary_image_properties[0]
+            qvec = np.array(binary_image_properties[1:5])
+            tvec = np.array(binary_image_properties[5:8])
+            camera_id = binary_image_properties[8]
+            image_name = ""
+            current_char = read_next_bytes(fid, 1, "c")[0]
+            while current_char != b"\x00":   # look for the ASCII 0 entry
+                image_name += current_char.decode("utf-8")
+                current_char = read_next_bytes(fid, 1, "c")[0]
+            num_points2D = read_next_bytes(fid, num_bytes=8,
+                                           format_char_sequence="Q")[0]
+            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
+                                       format_char_sequence="ddq"*num_points2D)
+            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+                                   tuple(map(float, x_y_id_s[1::3]))])
+            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+            images[image_id] = Image(
+                id=image_id, qvec=qvec, tvec=tvec,
+                camera_id=camera_id, name=image_name,
+                xys=xys, point3D_ids=point3D_ids)
+    return images
+
+
+def read_intrinsics_binary(path_to_model_file):
+    """
+    see: src/base/reconstruction.cc
+        void Reconstruction::WriteCamerasBinary(const std::string& path)
+        void Reconstruction::ReadCamerasBinary(const std::string& path)
+    """
+    cameras = {}
+    with open(path_to_model_file, "rb") as fid:
+        num_cameras = read_next_bytes(fid, 8, "Q")[0]
+        for _ in range(num_cameras):
+            camera_properties = read_next_bytes(
+                fid, num_bytes=24, format_char_sequence="iiQQ")
+            camera_id = camera_properties[0]
+            model_id = camera_properties[1]
+            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+            width = camera_properties[2]
+            height = camera_properties[3]
+            num_params = CAMERA_MODEL_IDS[model_id].num_params
+            params = read_next_bytes(fid, num_bytes=8*num_params,
+                                     format_char_sequence="d"*num_params)
+            cameras[camera_id] = Camera(id=camera_id,
+                                        model=model_name,
+                                        width=width,
+                                        height=height,
+                                        params=np.array(params))
+        assert len(cameras) == num_cameras
+    return cameras
+
+
+def read_extrinsics_text(path):
+    """
+    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+    """
+    images = {}
+    with open(path, "r") as fid:
+        while True:
+            line = fid.readline()
+            if not line:
+                break
+            line = line.strip()
+            if len(line) > 0 and line[0] != "#":
+                elems = line.split()
+                image_id = int(elems[0])
+                qvec = np.array(tuple(map(float, elems[1:5])))
+                tvec = np.array(tuple(map(float, elems[5:8])))
+                camera_id = int(elems[8])
+                image_name = elems[9]
+                elems = fid.readline().split()
+                xys = np.column_stack([tuple(map(float, elems[0::3])),
+                                       tuple(map(float, elems[1::3]))])
+                point3D_ids = np.array(tuple(map(int, elems[2::3])))
+                images[image_id] = Image(
+                    id=image_id, qvec=qvec, tvec=tvec,
+                    camera_id=camera_id, name=image_name,
+                    xys=xys, point3D_ids=point3D_ids)
+    return images
+
+
+def read_colmap_bin_array(path):
+    """
+    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
+
+    :param path: path to the colmap binary file.
+    :return: nd array with the floating point values in the value
+    """
+    with open(path, "rb") as fid:
+        width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
+                                                usecols=(0, 1, 2), dtype=int)
+        fid.seek(0)
+        num_delimiter = 0
+        byte = fid.read(1)
+        while True:
+            if byte == b"&":
+                num_delimiter += 1
+                if num_delimiter >= 3:
+                    break
+            byte = fid.read(1)
+        array = np.fromfile(fid, np.float32)
+    array = array.reshape((width, height, channels), order="F")
+    return np.transpose(array, (1, 0, 2)).squeeze()

+ 244 - 0
scene/dataset_readers.py

@@ -0,0 +1,244 @@
+import os
+import sys
+from PIL import Image
+from typing import NamedTuple
+from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
+    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
+from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
+import numpy as np
+import json
+from pathlib import Path
+from plyfile import PlyData, PlyElement
+from utils.sh_utils import SH2RGB
+from scene.gaussian_model import BasicPointCloud
+
+class CameraInfo(NamedTuple):
+    uid: int
+    R: np.array
+    T: np.array
+    FovY: np.array
+    FovX: np.array
+    image: np.array
+    image_path: str
+    image_name: str
+    width: int
+    height: int
+
+class SceneInfo(NamedTuple):
+    point_cloud: BasicPointCloud
+    train_cameras: list
+    test_cameras: list
+    nerf_normalization: dict
+    ply_path: str
+
+def getNerfppNorm(cam_info):
+    def get_center_and_diag(cam_centers):
+        cam_centers = np.hstack(cam_centers)
+        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
+        center = avg_cam_center
+        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
+        diagonal = np.max(dist)
+        return center.flatten(), diagonal
+
+    cam_centers = []
+
+    for cam in cam_info:
+        W2C = getWorld2View2(cam.R, cam.T)
+        C2W = np.linalg.inv(W2C)
+        cam_centers.append(C2W[:3, 3:4])
+
+    center, diagonal = get_center_and_diag(cam_centers)
+    radius = diagonal * 1.1
+
+    translate = -center
+
+    return {"translate": translate, "radius": radius}
+
+def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
+    cam_infos = []
+    for idx, key in enumerate(cam_extrinsics):
+        sys.stdout.write('\r')
+        # the exact output you're looking for:
+        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
+        sys.stdout.flush()
+
+        extr = cam_extrinsics[key]
+        intr = cam_intrinsics[extr.camera_id]
+        height = intr.height
+        width = intr.width
+
+        uid = intr.id
+        R = np.transpose(qvec2rotmat(extr.qvec))
+        T = np.array(extr.tvec)
+
+        if intr.model=="SIMPLE_PINHOLE":
+            focal_length_x = intr.params[0]
+            FovY = focal2fov(focal_length_x, height)
+            FovX = focal2fov(focal_length_x, width)
+        elif intr.model=="PINHOLE":
+            focal_length_x = intr.params[0]
+            focal_length_y = intr.params[1]
+            FovY = focal2fov(focal_length_y, height)
+            FovX = focal2fov(focal_length_x, width)
+        else:
+            assert False, "Colmap camera model not handled!"
+
+        image_path = os.path.join(images_folder, os.path.basename(extr.name))
+        image_name = os.path.basename(image_path).split(".")[0]
+        image = Image.open(image_path)
+
+        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+                              image_path=image_path, image_name=image_name, width=width, height=height)
+        cam_infos.append(cam_info)
+    sys.stdout.write('\n')
+    return cam_infos
+
+def fetchPly(path):
+    plydata = PlyData.read(path)
+    vertices = plydata['vertex']
+    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
+    colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
+    normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
+    return BasicPointCloud(points=positions, colors=colors, normals=normals)
+
+def storePly(path, xyz, rgb):
+    # Define the dtype for the structured array
+    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
+            ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
+            ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
+    
+    normals = np.zeros_like(xyz)
+
+    elements = np.empty(xyz.shape[0], dtype=dtype)
+    attributes = np.concatenate((xyz, normals, rgb), axis=1)
+    elements[:] = list(map(tuple, attributes))
+
+    # Create the PlyData object and write to file
+    vertex_element = PlyElement.describe(elements, 'vertex')
+    ply_data = PlyData([vertex_element])
+    ply_data.write(path)
+
+def readColmapSceneInfo(path, images, eval, llffhold=8):
+    try:
+        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
+        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
+        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
+        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
+    except:
+        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
+        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
+        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
+        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
+
+    reading_dir = "images" if images == None else images
+    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
+    cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
+
+    if eval:
+        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
+        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
+    else:
+        train_cam_infos = cam_infos
+        test_cam_infos = []
+
+    nerf_normalization = getNerfppNorm(train_cam_infos)
+
+    ply_path = os.path.join(path, "sparse/0/points3d.ply")
+    bin_path = os.path.join(path, "sparse/0/points3d.bin")
+    txt_path = os.path.join(path, "sparse/0/points3d.txt")
+    if not os.path.exists(ply_path):
+        print("Converting point3d.bin to .ply, will happen only the first time you open the scene.")
+        try:
+            xyz, rgb, _ = read_points3D_binary(bin_path)
+        except:
+            xyz, rgb, _ = read_points3D_text(txt_path)
+        storePly(ply_path, xyz, rgb)
+    try:
+        pcd = fetchPly(ply_path)
+    except:
+        pcd = None
+
+    scene_info = SceneInfo(point_cloud=pcd,
+                           train_cameras=train_cam_infos,
+                           test_cameras=test_cam_infos,
+                           nerf_normalization=nerf_normalization,
+                           ply_path=ply_path)
+    return scene_info
+
+def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
+    cam_infos = []
+
+    with open(os.path.join(path, transformsfile)) as json_file:
+        contents = json.load(json_file)
+        fovx = contents["camera_angle_x"]
+
+        frames = contents["frames"]
+        for idx, frame in enumerate(frames):
+            cam_name = os.path.join(path, frame["file_path"] + extension)
+
+            matrix = np.linalg.inv(np.array(frame["transform_matrix"]))
+            R = -np.transpose(matrix[:3,:3])
+            R[:,0] = -R[:,0]
+            T = -matrix[:3, 3]
+
+            image_path = os.path.join(path, cam_name)
+            image_name = Path(cam_name).stem
+            image = Image.open(image_path)
+
+            im_data = np.array(image.convert("RGBA"))
+
+            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
+
+            norm_data = im_data / 255.0
+            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
+            image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
+
+            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
+            FovY = fovx 
+            FovX = fovy
+
+            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+                            image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
+            
+    return cam_infos
+
+def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
+    print("Reading Training Transforms")
+    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
+    print("Reading Test Transforms")
+    test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
+    
+    if not eval:
+        train_cam_infos.extend(test_cam_infos)
+        test_cam_infos = []
+
+    nerf_normalization = getNerfppNorm(train_cam_infos)
+
+    ply_path = os.path.join(path, "points3d.ply")
+    if not os.path.exists(ply_path):
+        # Since this data set has no colmap data, we start with random points
+        num_pts = 100_000
+        print(f"Generating random point cloud ({num_pts})...")
+        
+        # We create random points inside the bounds of the synthetic Blender scenes
+        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
+        shs = np.random.random((num_pts, 3)) / 255.0
+        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
+
+        storePly(ply_path, xyz, SH2RGB(shs) * 255)
+    try:
+        pcd = fetchPly(ply_path)
+    except:
+        pcd = None
+
+    scene_info = SceneInfo(point_cloud=pcd,
+                           train_cameras=train_cam_infos,
+                           test_cameras=test_cam_infos,
+                           nerf_normalization=nerf_normalization,
+                           ply_path=ply_path)
+    return scene_info
+
+sceneLoadTypeCallbacks = {
+    "Colmap": readColmapSceneInfo,
+    "Blender" : readNerfSyntheticInfo
+}

+ 356 - 0
scene/gaussian_model.py

@@ -0,0 +1,356 @@
+import torch
+import numpy as np
+from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
+from torch import nn
+import os
+from utils.system_utils import mkdir_p
+from plyfile import PlyData, PlyElement
+from utils.sh_utils import RGB2SH
+from simple_knn._C import distCUDA2
+from utils.graphics_utils import BasicPointCloud
+from utils.general_utils import strip_symmetric, build_scaling_rotation
+
+class GaussianModel:
+    def __init__(self, sh_degree : int):
+
+        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
+            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+            actual_covariance = L @ L.transpose(1, 2)
+            symm = strip_symmetric(actual_covariance)
+            return symm
+
+        self.active_sh_degree = 0
+        self.max_sh_degree = sh_degree  
+
+        self._xyz = torch.empty(0)
+        self._features_dc = torch.empty(0)
+        self._features_rest = torch.empty(0)
+        self._scaling = torch.empty(0)
+        self._rotation = torch.empty(0)
+        self._opacity = torch.empty(0)
+        self.max_radii2D = torch.empty(0)
+        self.xyz_gradient_accum = torch.empty(0)
+
+        self.optimizer = None
+
+        self.scaling_activation = torch.exp
+        self.scaling_inverse_activation = torch.log
+
+        self.covariance_activation = build_covariance_from_scaling_rotation
+
+        self.opacity_activation = torch.sigmoid
+        self.inverse_opacity_activation = inverse_sigmoid
+
+        self.rotation_activation = torch.nn.functional.normalize
+
+    @property
+    def get_scaling(self):
+        return self.scaling_activation(self._scaling)
+    
+    @property
+    def get_rotation(self):
+        return self.rotation_activation(self._rotation)
+    
+    @property
+    def get_xyz(self):
+        return self._xyz
+    
+    @property
+    def get_features(self):
+        features_dc = self._features_dc
+        features_rest = self._features_rest
+        return torch.cat((features_dc, features_rest), dim=1)
+    
+    @property
+    def get_opacity(self):
+        return self.opacity_activation(self._opacity)
+    
+    def get_covariance(self, scaling_modifier = 1):
+        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
+
+    def oneupSHdegree(self):
+        if self.active_sh_degree < self.max_sh_degree:
+            self.active_sh_degree += 1
+
+    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
+        self.spatial_lr_scale = spatial_lr_scale
+        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
+        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
+        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
+        features[:, :3, 0 ] = fused_color
+        features[:, 3:, 1:] = 0.0
+
+        print("Number of points at initialisation : ", fused_point_cloud.shape[0])
+
+        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
+        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
+        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
+        rots[:, 0] = 1
+
+        opacities = inverse_sigmoid(0.5 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
+
+        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
+        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
+        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
+        self._scaling = nn.Parameter(scales.requires_grad_(True))
+        self._rotation = nn.Parameter(rots.requires_grad_(True))
+        self._opacity = nn.Parameter(opacities.requires_grad_(True))
+        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+
+    def training_setup(self, training_args):
+        self.percent_dense = training_args.percent_dense
+        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+
+        l = [
+            {'params': [self._xyz], 'lr': training_args.position_lr_init*self.spatial_lr_scale, "name": "xyz"},
+            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
+            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
+            {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
+            {'params': [self._scaling], 'lr': training_args.scaling_lr*self.spatial_lr_scale, "name": "scaling"},
+            {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
+        ]
+
+        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
+        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
+                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,
+                                                    lr_delay_mult=training_args.position_lr_delay_mult,
+                                                    max_steps=training_args.posititon_lr_max_steps)
+
+    def update_learning_rate(self, iteration):
+        ''' Learning rate scheduling per step '''
+        for param_group in self.optimizer.param_groups:
+            if param_group["name"] == "xyz":
+                lr = self.xyz_scheduler_args(iteration)
+                param_group['lr'] = lr
+                return lr
+
+    def construct_list_of_attributes(self):
+        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
+        # All channels except the 3 DC
+        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
+            l.append('f_dc_{}'.format(i))
+        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
+            l.append('f_rest_{}'.format(i))
+        l.append('opacity')
+        for i in range(self._scaling.shape[1]):
+            l.append('scale_{}'.format(i))
+        for i in range(self._rotation.shape[1]):
+            l.append('rot_{}'.format(i))
+        return l
+
+    def save_ply(self, path):
+        mkdir_p(os.path.dirname(path))
+
+        xyz = self._xyz.detach().cpu().numpy()
+        normals = np.zeros_like(xyz)
+        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+        opacities = self._opacity.detach().cpu().numpy()
+        scale = self._scaling.detach().cpu().numpy()
+        rotation = self._rotation.detach().cpu().numpy()
+
+        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
+
+        elements = np.empty(xyz.shape[0], dtype=dtype_full)
+        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
+        elements[:] = list(map(tuple, attributes))
+        el = PlyElement.describe(elements, 'vertex')
+        PlyData([el]).write(path)
+
+    def reset_opacity(self):
+        opacities_new = inverse_sigmoid(torch.ones_like(self.get_opacity)*0.01)
+        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
+        self._opacity = optimizable_tensors["opacity"]
+
+    def load_ply(self, path, og_number_points=-1):
+        self.og_number_points = og_number_points
+        plydata = PlyData.read(path)
+
+        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
+                        np.asarray(plydata.elements[0]["y"]),
+                        np.asarray(plydata.elements[0]["z"])),  axis=1)
+        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+        features_dc = np.zeros((xyz.shape[0], 3, 1))
+        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
+        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+        for idx, attr_name in enumerate(extra_f_names):
+            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
+
+        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+        scales = np.zeros((xyz.shape[0], len(scale_names)))
+        for idx, attr_name in enumerate(scale_names):
+            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+        rots = np.zeros((xyz.shape[0], len(rot_names)))
+        for idx, attr_name in enumerate(rot_names):
+            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
+
+        self.active_sh_degree = self.max_sh_degree
+
+    def replace_tensor_to_optimizer(self, tensor, name):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            if group["name"] == name:
+                stored_state = self.optimizer.state.get(group['params'][0], None)
+                stored_state["exp_avg"] = torch.zeros_like(tensor)
+                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
+
+                del self.optimizer.state[group['params'][0]]
+                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
+                self.optimizer.state[group['params'][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+        return optimizable_tensors
+
+    def _prune_optimizer(self, mask):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            stored_state = self.optimizer.state.get(group['params'][0], None)
+            if stored_state is not None:
+                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
+                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
+
+                del self.optimizer.state[group['params'][0]]
+                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
+                self.optimizer.state[group['params'][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+            else:
+                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
+                optimizable_tensors[group["name"]] = group["params"][0]
+        return optimizable_tensors
+
+    def prune_points(self, mask):
+        valid_points_mask = ~mask
+        optimizable_tensors = self._prune_optimizer(valid_points_mask)
+
+        self._xyz = optimizable_tensors["xyz"]
+        self._features_dc = optimizable_tensors["f_dc"]
+        self._features_rest = optimizable_tensors["f_rest"]
+        self._opacity = optimizable_tensors["opacity"]
+        self._scaling = optimizable_tensors["scaling"]
+        self._rotation = optimizable_tensors["rotation"]
+
+        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
+
+        self.denom = self.denom[valid_points_mask]
+        self.max_radii2D = self.max_radii2D[valid_points_mask]
+
+    def cat_tensors_to_optimizer(self, tensors_dict):
+        optimizable_tensors = {}
+        for group in self.optimizer.param_groups:
+            assert len(group["params"]) == 1
+            extension_tensor = tensors_dict[group["name"]]
+            stored_state = self.optimizer.state.get(group['params'][0], None)
+            if stored_state is not None:
+
+                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
+                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
+
+                del self.optimizer.state[group['params'][0]]
+                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+                self.optimizer.state[group['params'][0]] = stored_state
+
+                optimizable_tensors[group["name"]] = group["params"][0]
+            else:
+                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
+                optimizable_tensors[group["name"]] = group["params"][0]
+
+        return optimizable_tensors
+
+    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
+        d = {"xyz": new_xyz,
+        "f_dc": new_features_dc,
+        "f_rest": new_features_rest,
+        "opacity": new_opacities,
+        "scaling" : new_scaling,
+        "rotation" : new_rotation}
+
+        optimizable_tensors = self.cat_tensors_to_optimizer(d)
+        self._xyz = optimizable_tensors["xyz"]
+        self._features_dc = optimizable_tensors["f_dc"]
+        self._features_rest = optimizable_tensors["f_rest"]
+        self._opacity = optimizable_tensors["opacity"]
+        self._scaling = optimizable_tensors["scaling"]
+        self._rotation = optimizable_tensors["rotation"]
+
+        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+
+    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
+        n_init_points = self.get_xyz.shape[0]
+        # Extract points that satisfy the gradient condition
+        padded_grad = torch.zeros((n_init_points), device="cuda")
+        padded_grad[:grads.shape[0]] = grads.squeeze()
+        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
+        selected_pts_mask = torch.logical_and(selected_pts_mask,
+                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
+
+        stds = self.get_scaling[selected_pts_mask].repeat(N,1)
+        means =torch.zeros((stds.size(0), 3),device="cuda")
+        samples = torch.normal(mean=means, std=stds)
+        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
+        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
+        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
+        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
+        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
+        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
+        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
+
+        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
+
+        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
+        self.prune_points(prune_filter)
+
+    def densify_and_clone(self, grads, grad_threshold, scene_extent):
+        # Extract points that satisfy the gradient condition
+        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
+        selected_pts_mask = torch.logical_and(selected_pts_mask,
+                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
+        
+        new_xyz = self._xyz[selected_pts_mask]
+        new_features_dc = self._features_dc[selected_pts_mask]
+        new_features_rest = self._features_rest[selected_pts_mask]
+        new_opacities = self._opacity[selected_pts_mask]
+        new_scaling = self._scaling[selected_pts_mask]
+        new_rotation = self._rotation[selected_pts_mask]
+
+        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
+
+    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
+        grads = self.xyz_gradient_accum / self.denom
+        grads[grads.isnan()] = 0.0
+
+        self.densify_and_clone(grads, max_grad, extent)
+        self.densify_and_split(grads, max_grad, extent)
+
+        prune_mask = (self.get_opacity < min_opacity).squeeze()
+        if max_screen_size:
+            big_points_vs = self.max_radii2D > max_screen_size
+            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
+            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
+        self.prune_points(prune_mask)
+
+        torch.cuda.empty_cache()
+
+    def add_densification_stats(self, viewspace_point_tensor, update_filter):
+        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
+        self.denom[update_filter] += 1

+ 194 - 0
train.py

@@ -0,0 +1,194 @@
+import os
+import torch
+from random import randint
+from utils.loss_utils import l1_loss, ssim
+from gaussian_renderer import render, network_gui
+import sys
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+import uuid
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser, Namespace
+from arguments import ModelParams, PipelineParams, OptimizationParams
+try:
+    from torch.utils.tensorboard import SummaryWriter
+    TENSORBOARD_FOUND = True
+except ImportError:
+    TENSORBOARD_FOUND = False
+
+def training(dataset, opt, pipe, testing_iterations, saving_iterations):
+    tb_writer = prepare_output_and_logger(dataset)
+    gaussians = GaussianModel(dataset.sh_degree)
+
+    scene = Scene(dataset, gaussians)
+    gaussians.training_setup(opt)
+
+    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+    background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+    iter_start = torch.cuda.Event(enable_timing = True)
+    iter_end = torch.cuda.Event(enable_timing = True)
+
+    viewpoint_stack = None
+    ema_loss_for_log = 0.0
+    progress_bar = tqdm(range(opt.iterations), desc="Training progress")
+    for iteration in range(1, opt.iterations + 1):        
+        if network_gui.conn == None:
+            network_gui.try_connect()
+        while network_gui.conn != None:
+            try:
+                net_image_bytes = None
+                custom_cam, do_training, pipe.do_shs_python, pipe.do_cov_python, keep_alive, scaling_modifer = network_gui.receive()
+                if custom_cam != None:
+                    net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
+                    net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
+                network_gui.send(net_image_bytes, dataset.source_path)
+                if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
+                    break
+            except Exception as e:
+                network_gui.conn = None
+
+        iter_start.record()
+
+        # Every 1000 its we increase the levels of SH up to a maximum degree
+        if iteration % 1000 == 0:
+            gaussians.oneupSHdegree()
+
+        # Pick a random Camera
+        if not viewpoint_stack:
+            viewpoint_stack = scene.getTrainCameras().copy()
+        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
+
+        # Render
+        render_pkg = render(viewpoint_cam, gaussians, pipe, background)
+        image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
+
+        # Loss
+        gt_image = viewpoint_cam.original_image.cuda()
+        Ll1 = l1_loss(image, gt_image)
+        loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
+        loss.backward()
+
+        iter_end.record()
+
+        with torch.no_grad():
+            # Progress bar
+            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+            if iteration % 10 == 0:
+                progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+                progress_bar.update(10)
+            if iteration == opt.iterations:
+                progress_bar.close()
+
+            # Keep track of max radii in image-space for pruning
+            gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
+
+            # Log and save
+            training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
+            if (iteration in saving_iterations):
+                print("\n[ITER {}] Saving Gaussians".format(iteration))
+                scene.save(iteration)
+
+            # Densification
+            if iteration < opt.densify_until_iter:
+                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
+
+                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
+                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None
+                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
+                
+                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
+                    gaussians.reset_opacity()
+
+            # Optimizer step
+            if iteration < opt.iterations:
+                gaussians.optimizer.step()
+                gaussians.optimizer.zero_grad(set_to_none = True)
+                gaussians.update_learning_rate(iteration)
+
+def prepare_output_and_logger(args):    
+    if not args.model_path:
+        if os.getenv('OAR_JOB_ID'):
+            unique_str=os.getenv('OAR_JOB_ID')
+        else:
+            unique_str = str(uuid.uuid4())
+        args.model_path = os.path.join("./output/", unique_str[0:10])
+        
+    # Set up output folder
+    print("Output folder: {}".format(args.model_path))
+    os.makedirs(args.model_path, exist_ok = True)
+    with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
+        cfg_log_f.write(str(Namespace(**vars(args))))
+
+    # Create Tensorboard writer
+    tb_writer = None
+    if TENSORBOARD_FOUND:
+        tb_writer = SummaryWriter(args.model_path)
+    else:
+        print("Tensorboard not available: not logging progress")
+    return tb_writer
+
+def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
+    if tb_writer:
+        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
+        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
+        tb_writer.add_scalar('iter_time', elapsed, iteration)
+
+    # Report test and samples of training set
+    if iteration in testing_iterations:
+        torch.cuda.empty_cache()
+        validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, 
+                              {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
+
+        for config in validation_configs:
+            if config['cameras'] and len(config['cameras']) > 0:
+                images = torch.tensor([], device="cuda")
+                gts = torch.tensor([], device="cuda")
+                for idx, viewpoint in enumerate(config['cameras']):
+                    image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
+                    gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
+                    images = torch.cat((images, image.unsqueeze(0)), dim=0)
+                    gts = torch.cat((gts, gt_image.unsqueeze(0)), dim=0)
+                    if tb_writer and (idx < 5):
+                        tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image, global_step=iteration)
+                        if iteration == testing_iterations[0]:
+                            tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image, global_step=iteration)
+
+                l1_test = l1_loss(images, gts)
+                psnr_test = psnr(images, gts).mean()            
+                print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
+                if tb_writer:
+                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
+                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
+
+        if tb_writer:
+            tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
+            tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
+        torch.cuda.empty_cache()
+
+if __name__ == "__main__":
+    # Set up command line argument parser
+    parser = ArgumentParser(description="Training script parameters")
+    lp = ModelParams(parser)
+    op = OptimizationParams(parser)
+    pp = PipelineParams(parser)
+    parser.add_argument('--ip', type=str, default="127.0.0.1")
+    parser.add_argument('--port', type=int, default=6009)
+    parser.add_argument('--detect_anomaly', action='store_true', default=False)
+    parser.add_argument("--test_iterations", nargs="+", type=int, default=[7_000, 30_000])
+    parser.add_argument("--save_iterations", nargs="+", type=int, default=[7_000, 30_000])
+    parser.add_argument("--quiet", action="store_true")
+    args = parser.parse_args(sys.argv[1:])
+    print("Optimizing " + args.model_path)
+
+    # Initialize system state (RNG)
+    safe_state(args.quiet)
+
+    # Start GUI server, configure and run training
+    network_gui.init(args.ip, args.port)
+    torch.autograd.set_detect_anomaly(args.detect_anomaly)
+    training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations)
+
+    # All done
+    print("\nTraining complete.")

+ 57 - 0
utils/camera_utils.py

@@ -0,0 +1,57 @@
+from scene.cameras import Camera
+import numpy as np
+from utils.general_utils import PILtoTorch
+from utils.graphics_utils import fov2focal
+
+def loadCam(args, id, cam_info, resolution_scale):
+    orig_w, orig_h = cam_info.image.size
+
+    if args.resolution in [1, 2, 4, 8]:
+        resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
+    else:  # should be a type that converts to float
+        global_down = orig_w/args.resolution
+        scale = float(global_down) * float(resolution_scale)
+        resolution = (int(orig_w / scale), int(orig_h / scale))
+
+    resized_image_rgb = PILtoTorch(cam_info.image, resolution)
+
+    gt_image = resized_image_rgb[:3, ...]
+    loaded_mask = None
+
+    if resized_image_rgb.shape[1] == 4:
+        loaded_mask = resized_image_rgb[3:4, ...]
+
+    return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, 
+                  FoVx=cam_info.FovX, FoVy=cam_info.FovY, 
+                  image=gt_image, gt_alpha_mask=loaded_mask,
+                  image_name=cam_info.image_name, uid=id)
+
+def cameraList_from_camInfos(cam_infos, resolution_scale, args):
+    camera_list = []
+
+    for id, c in enumerate(cam_infos):
+        camera_list.append(loadCam(args, id, c, resolution_scale))
+
+    return camera_list
+
+def camera_to_JSON(id, camera : Camera):
+    Rt = np.zeros((4, 4))
+    Rt[:3, :3] = camera.R.transpose()
+    Rt[:3, 3] = camera.T
+    Rt[3, 3] = 1.0
+
+    W2C = np.linalg.inv(Rt)
+    pos = W2C[:3, 3]
+    rot = W2C[:3, :3]
+    serializable_array_2d = [x.tolist() for x in rot]
+    camera_entry = {
+        'id' : id,
+        'img_name' : camera.image_name,
+        'width' : camera.width,
+        'height' : camera.height,
+        'position': pos.tolist(),
+        'rotation': serializable_array_2d,
+        'fy' : fov2focal(camera.FovY, camera.height),
+        'fx' : fov2focal(camera.FovX, camera.width)
+    }
+    return camera_entry

+ 122 - 0
utils/general_utils.py

@@ -0,0 +1,122 @@
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+    return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+    resized_image_PIL = pil_image.resize(resolution)
+    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+    if len(resized_image.shape) == 3:
+        return resized_image.permute(2, 0, 1)
+    else:
+        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+    """
+    Copied from Plenoxels
+
+    Continuous learning rate decay function. Adapted from JaxNeRF
+    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+    is log-linearly interpolated elsewhere (equivalent to exponential decay).
+    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+    function of lr_delay_mult, such that the initial learning rate is
+    lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+    to the normal learning rate when steps>lr_delay_steps.
+    :param conf: config subtree 'lr' or similar
+    :param max_steps: int, the number of steps during optimization.
+    :return HoF which takes step as input
+    """
+
+    def helper(step):
+        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+            # Disable this parameter
+            return 0.0
+        if lr_delay_steps > 0:
+            # A kind of reverse cosine decay.
+            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+            )
+        else:
+            delay_rate = 1.0
+        t = np.clip(step / max_steps, 0, 1)
+        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+        return delay_rate * log_lerp
+
+    return helper
+
+def strip_lowerdiag(L):
+    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+    uncertainty[:, 0] = L[:, 0, 0]
+    uncertainty[:, 1] = L[:, 0, 1]
+    uncertainty[:, 2] = L[:, 0, 2]
+    uncertainty[:, 3] = L[:, 1, 1]
+    uncertainty[:, 4] = L[:, 1, 2]
+    uncertainty[:, 5] = L[:, 2, 2]
+    return uncertainty
+
+def strip_symmetric(sym):
+    return strip_lowerdiag(sym)
+
+def build_rotation(r):
+    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+    q = r / norm[:, None]
+
+    R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+    r = q[:, 0]
+    x = q[:, 1]
+    y = q[:, 2]
+    z = q[:, 3]
+
+    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+    R[:, 0, 1] = 2 * (x*y - r*z)
+    R[:, 0, 2] = 2 * (x*z + r*y)
+    R[:, 1, 0] = 2 * (x*y + r*z)
+    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+    R[:, 1, 2] = 2 * (y*z - r*x)
+    R[:, 2, 0] = 2 * (x*z - r*y)
+    R[:, 2, 1] = 2 * (y*z + r*x)
+    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+    return R
+
+def build_scaling_rotation(s, r):
+    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+    R = build_rotation(r)
+
+    L[:,0,0] = s[:,0]
+    L[:,1,1] = s[:,1]
+    L[:,2,2] = s[:,2]
+
+    L = R @ L
+    return L
+
+def safe_state(silent):
+    old_f = sys.stdout
+    class F:
+        def __init__(self, silent):
+            self.silent = silent
+
+        def write(self, x):
+            if not self.silent:
+                if x.endswith("\n"):
+                    old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+                else:
+                    old_f.write(x)
+
+        def flush(self):
+            old_f.flush()
+
+    sys.stdout = F(silent)
+
+    random.seed(0)
+    np.random.seed(0)
+    torch.manual_seed(0)
+    torch.cuda.set_device(torch.device("cuda:0"))

+ 66 - 0
utils/graphics_utils.py

@@ -0,0 +1,66 @@
+import torch
+import math
+import numpy as np
+from typing import NamedTuple
+
+class BasicPointCloud(NamedTuple):
+    points : np.array
+    colors : np.array
+    normals : np.array
+
+def geom_transform_points(points, transf_matrix):
+    P, _ = points.shape
+    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
+    points_hom = torch.cat([points, ones], dim=1)
+    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
+
+    denom = points_out[..., 3:] + 0.0000001
+    return (points_out[..., :3] / denom).squeeze(dim=0)
+
+def getWorld2View(R, t):
+    Rt = np.zeros((4, 4))
+    Rt[:3, :3] = R.transpose()
+    Rt[:3, 3] = t
+    Rt[3, 3] = 1.0
+    return np.float32(Rt)
+
+def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
+    Rt = np.zeros((4, 4))
+    Rt[:3, :3] = R.transpose()
+    Rt[:3, 3] = t
+    Rt[3, 3] = 1.0
+
+    C2W = np.linalg.inv(Rt)
+    cam_center = C2W[:3, 3]
+    cam_center = (cam_center + translate) * scale
+    C2W[:3, 3] = cam_center
+    Rt = np.linalg.inv(C2W)
+    return np.float32(Rt)
+
+def getProjectionMatrix(znear, zfar, fovX, fovY):
+    tanHalfFovY = math.tan((fovY / 2))
+    tanHalfFovX = math.tan((fovX / 2))
+
+    top = tanHalfFovY * znear
+    bottom = -top
+    right = tanHalfFovX * znear
+    left = -right
+
+    P = torch.zeros(4, 4)
+
+    z_sign = 1.0
+
+    P[0, 0] = 2.0 * znear / (right - left)
+    P[1, 1] = 2.0 * znear / (top - bottom)
+    P[0, 2] = (right + left) / (right - left)
+    P[1, 2] = (top + bottom) / (top - bottom)
+    P[3, 2] = z_sign
+    P[2, 2] = z_sign * zfar / (zfar - znear)
+    P[2, 3] = -(zfar * znear) / (zfar - znear)
+    return P
+
+def fov2focal(fov, pixels):
+    return pixels / (2 * math.tan(fov / 2))
+
+def focal2fov(focal, pixels):
+    return 2*math.atan(pixels/(2*focal))

+ 8 - 0
utils/image_utils.py

@@ -0,0 +1,8 @@
+import torch
+
+def mse(img1, img2):
+    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+
+def psnr(img1, img2):
+    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+    return 20 * torch.log10(1.0 / torch.sqrt(mse))

+ 53 - 0
utils/loss_utils.py

@@ -0,0 +1,53 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from math import exp
+
+def l1_loss(network_output, gt):
+    return torch.abs((network_output - gt)).mean()
+
+def l2_loss(network_output, gt):
+    return ((network_output - gt) ** 2).mean()
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
+    return gauss / gauss.sum()
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
+    return window
+
+def ssim(img1, img2, window_size=11, size_average=True):
+    channel = img1.size(-3)
+    window = create_window(window_size, channel)
+
+    if img1.is_cuda:
+        window = window.cuda(img1.get_device())
+    window = window.type_as(img1)
+
+    return _ssim(img1, img2, window, window_size, channel, size_average)
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
+
+    C1 = 0.01 ** 2
+    C2 = 0.03 ** 2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1).mean(1).mean(1)
+

+ 118 - 0
utils/sh_utils.py

@@ -0,0 +1,118 @@
+#  Copyright 2021 The PlenOctree Authors.
+#  Redistribution and use in source and binary forms, with or without
+#  modification, are permitted provided that the following conditions are met:
+#
+#  1. Redistributions of source code must retain the above copyright notice,
+#  this list of conditions and the following disclaimer.
+#
+#  2. Redistributions in binary form must reproduce the above copyright notice,
+#  this list of conditions and the following disclaimer in the documentation
+#  and/or other materials provided with the distribution.
+#
+#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+#  POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+    1.0925484305920792,
+    -1.0925484305920792,
+    0.31539156525252005,
+    -1.0925484305920792,
+    0.5462742152960396
+]
+C3 = [
+    -0.5900435899266435,
+    2.890611442640554,
+    -0.4570457994644658,
+    0.3731763325901154,
+    -0.4570457994644658,
+    1.445305721320277,
+    -0.5900435899266435
+]
+C4 = [
+    2.5033429417967046,
+    -1.7701307697799304,
+    0.9461746957575601,
+    -0.6690465435572892,
+    0.10578554691520431,
+    -0.6690465435572892,
+    0.47308734787878004,
+    -1.7701307697799304,
+    0.6258357354491761,
+]   
+
+
+def eval_sh(deg, sh, dirs):
+    """
+    Evaluate spherical harmonics at unit directions
+    using hardcoded SH polynomials.
+    Works with torch/np/jnp.
+    ... Can be 0 or more batch dimensions.
+    Args:
+        deg: int SH deg. Currently, 0-3 supported
+        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+        dirs: jnp.ndarray unit directions [..., 3]
+    Returns:
+        [..., C]
+    """
+    assert deg <= 4 and deg >= 0
+    coeff = (deg + 1) ** 2
+    assert sh.shape[-1] >= coeff
+
+    result = C0 * sh[..., 0]
+    if deg > 0:
+        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+        result = (result -
+                C1 * y * sh[..., 1] +
+                C1 * z * sh[..., 2] -
+                C1 * x * sh[..., 3])
+
+        if deg > 1:
+            xx, yy, zz = x * x, y * y, z * z
+            xy, yz, xz = x * y, y * z, x * z
+            result = (result +
+                    C2[0] * xy * sh[..., 4] +
+                    C2[1] * yz * sh[..., 5] +
+                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+                    C2[3] * xz * sh[..., 7] +
+                    C2[4] * (xx - yy) * sh[..., 8])
+
+            if deg > 2:
+                result = (result +
+                C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+                C3[1] * xy * z * sh[..., 10] +
+                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+                C3[5] * z * (xx - yy) * sh[..., 14] +
+                C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+                if deg > 3:
+                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+    return result
+
+def RGB2SH(rgb):
+    return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+    return sh * C0 + 0.5

+ 17 - 0
utils/system_utils.py

@@ -0,0 +1,17 @@
+from errno import EEXIST
+from os import makedirs, path
+import os
+
+def mkdir_p(folder_path):
+    # Creates a directory. equivalent to using mkdir -p on the command line
+    try:
+        makedirs(folder_path)
+    except OSError as exc: # Python >2.5
+        if exc.errno == EEXIST and path.isdir(folder_path):
+            pass
+        else:
+            raise
+
+def searchForMaxIteration(folder):
+    saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
+    return max(saved_iters)