Bladeren bron

Add sort to ensure right order of properties read from ply file

Pythonix Huang 2 jaren geleden
bovenliggende
commit
8c75a90f56
1 gewijzigde bestanden met toevoegingen van 3 en 0 verwijderingen
  1. 3 0
      scene/gaussian_model.py

+ 3 - 0
scene/gaussian_model.py

@@ -226,6 +226,7 @@ class GaussianModel:
         features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
         features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
 
 
         extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
         extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
+        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
         assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
         assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
         features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
         features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
         for idx, attr_name in enumerate(extra_f_names):
         for idx, attr_name in enumerate(extra_f_names):
@@ -234,11 +235,13 @@ class GaussianModel:
         features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
         features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
 
 
         scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
         scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
+        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
         scales = np.zeros((xyz.shape[0], len(scale_names)))
         scales = np.zeros((xyz.shape[0], len(scale_names)))
         for idx, attr_name in enumerate(scale_names):
         for idx, attr_name in enumerate(scale_names):
             scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
             scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
 
 
         rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
         rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
+        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
         rots = np.zeros((xyz.shape[0], len(rot_names)))
         rots = np.zeros((xyz.shape[0], len(rot_names)))
         for idx, attr_name in enumerate(rot_names):
         for idx, attr_name in enumerate(rot_names):
             rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
             rots[:, idx] = np.asarray(plydata.elements[0][attr_name])