general_utils.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #
  2. # Copyright (C) 2023, Inria
  3. # GRAPHDECO research group, https://team.inria.fr/graphdeco
  4. # All rights reserved.
  5. #
  6. # This software is free for non-commercial, research and evaluation use
  7. # under the terms of the LICENSE.md file.
  8. #
  9. # For inquiries contact george.drettakis@inria.fr
  10. #
  11. import torch
  12. import sys
  13. from datetime import datetime
  14. import numpy as np
  15. import random
  def inverse_sigmoid(x):
      """Return the logit of ``x``, i.e. the inverse of ``torch.sigmoid``.

      Evaluated elementwise as ``log(x / (1 - x))``; an input of exactly 0
      maps to -inf and an input of exactly 1 maps to +inf.
      """
      odds = x / (1 - x)
      return torch.log(odds)
  def PILtoTorch(pil_image, resolution):
      """Resize an image and convert it to a channel-first tensor scaled by 1/255.

      Args:
          pil_image: object with a ``resize(resolution)`` method whose result
              ``np.array`` can convert (a PIL image in practice).
          resolution: size argument forwarded unchanged to ``resize``.

      Returns:
          Tensor laid out as (C, H, W). A 2-D (single-channel) result gets a
          channel axis of size 1, giving (1, H, W).
      """
      pixels = np.array(pil_image.resize(resolution))
      tensor = torch.from_numpy(pixels) / 255.0
      if tensor.dim() != 3:
          # No channel axis yet: append one so the permute below is valid.
          tensor = tensor.unsqueeze(-1)
      return tensor.permute(2, 0, 1)
  def get_expon_lr_func(
      lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
  ):
      """
      Copied from Plenoxels

      Continuous learning rate decay function. Adapted from JaxNeRF
      The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
      is log-linearly interpolated elsewhere (equivalent to exponential decay).
      If lr_delay_steps>0 then the learning rate will be scaled by some smooth
      function of lr_delay_mult, such that the initial learning rate is
      lr_init*lr_delay_mult at the beginning of optimization but will be eased back
      to the normal learning rate when steps>lr_delay_steps.

      :param lr_init: learning rate at step 0.
      :param lr_final: learning rate at step ``max_steps``; later steps are
          clamped to this value.
      :param lr_delay_steps: length of the optional warm-up; 0 disables it.
      :param lr_delay_mult: warm-up multiplier applied at step 0; it is eased
          back to 1 by step ``lr_delay_steps``.
      :param max_steps: int, the number of steps during optimization.
      :return: HoF which takes step as input and returns the learning rate.
          Negative steps, or lr_init == lr_final == 0, yield 0.0. If exactly
          one endpoint is 0 the limit of the geometric interpolation is used:
          each endpoint is returned at its own step, and 0 in between.
      """
      def helper(step):
          if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
              # Disable this parameter
              return 0.0
          if lr_delay_steps > 0:
              # A kind of reverse cosine decay: lr_delay_mult at step 0,
              # rising to 1 at step lr_delay_steps and staying there.
              delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                  0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
              )
          else:
              delay_rate = 1.0
          t = np.clip(step / max_steps, 0, 1)
          if lr_init == 0.0 or lr_final == 0.0:
              # log(0) = -inf, and -inf * 0 = NaN at the endpoint whose weight
              # is zero, so the log-space blend cannot be used here. Use its
              # limit instead: each endpoint at its own step, 0 in between.
              if t == 0:
                  log_lerp = lr_init
              elif t == 1:
                  log_lerp = lr_final
              else:
                  log_lerp = 0.0
          else:
              log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
          return delay_rate * log_lerp
      return helper
  def strip_lowerdiag(L):
      """Pack six entries of each 3x3 matrix in a batch into a (N, 6) tensor.

      Despite the name, the entries read are (0,0), (0,1), (0,2), (1,1),
      (1,2), (2,2), i.e. the upper triangle including the diagonal. For a
      symmetric matrix this carries the same information as the lower
      triangle.

      Args:
          L: tensor indexed as ``L[:, i, j]``; only the leading 3x3 block of
              each matrix is read.

      Returns:
          float32 tensor of shape ``(L.shape[0], 6)`` on ``L``'s device, with
          columns in the order listed above.
      """
      # Build the index tensors on L's own device instead of a hard-coded
      # "cuda", so CPU inputs work and no cross-device copy happens.
      rows = torch.tensor([0, 0, 0, 1, 1, 2], device=L.device)
      cols = torch.tensor([0, 1, 2, 1, 2, 2], device=L.device)
      # Advanced indexing gathers all six entries in one op and returns a new
      # tensor, so the result never aliases L.
      return L[:, rows, cols].to(torch.float)
  def strip_symmetric(sym):
      """Pack a batch of symmetric 3x3 matrices into six values per matrix.

      Delegates to ``strip_lowerdiag``, which keeps entries (0,0), (0,1),
      (0,2), (1,1), (1,2), (2,2); for a symmetric matrix these determine the
      remaining three.
      """
      packed = strip_lowerdiag(sym)
      return packed
  def build_rotation(r):
      """Convert a batch of quaternions into 3x3 rotation matrices.

      Column 0 of each row is the scalar part, so rows are read as
      (w, x, y, z). Each row is first divided by the Euclidean norm of its
      first four entries, so non-unit quaternions are accepted. A
      zero-length quaternion produces NaNs (division by a zero norm).

      Args:
          r: tensor of shape (N, 4), presumably; only columns 0-3 are used.

      Returns:
          Tensor of shape (N, 3, 3) in torch's default float dtype, on ``r``'s
          device.
      """
      norm = torch.sqrt(
          r[:, 0] * r[:, 0] + r[:, 1] * r[:, 1] + r[:, 2] * r[:, 2] + r[:, 3] * r[:, 3]
      )
      q = r / norm[:, None]

      # Allocate on the input's device rather than a hard-coded "cuda" so the
      # function also works for CPU tensors.
      R = torch.zeros((q.size(0), 3, 3), device=r.device)

      # Use `w` for the scalar part instead of re-binding the parameter `r`.
      w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

      R[:, 0, 0] = 1 - 2 * (y * y + z * z)
      R[:, 0, 1] = 2 * (x * y - w * z)
      R[:, 0, 2] = 2 * (x * z + w * y)
      R[:, 1, 0] = 2 * (x * y + w * z)
      R[:, 1, 1] = 1 - 2 * (x * x + z * z)
      R[:, 1, 2] = 2 * (y * z - w * x)
      R[:, 2, 0] = 2 * (x * z - w * y)
      R[:, 2, 1] = 2 * (y * z + w * x)
      R[:, 2, 2] = 1 - 2 * (x * x + y * y)
      return R
  def build_scaling_rotation(s, r):
      """Return ``R @ diag(s)`` for a batch of scales and quaternions.

      Args:
          s: per-item scales; columns 0-2 of each row become the diagonal.
          r: quaternions, converted to rotation matrices by ``build_rotation``.

      Returns:
          Tensor of shape (N, 3, 3). Column j of each matrix is column j of the
          rotation multiplied by ``s[:, j]``.
      """
      # Allocate on the input's device rather than a hard-coded "cuda".
      L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
      R = build_rotation(r)

      L[:, 0, 0] = s[:, 0]
      L[:, 1, 1] = s[:, 1]
      L[:, 2, 2] = s[:, 2]

      return R @ L
  def safe_state(silent):
      """Install a timestamping stdout wrapper and seed the RNGs with 0.

      After this call ``sys.stdout`` is a wrapper around the previous stream.
      When a written chunk ends with a newline, every newline in that chunk
      is prefixed with `` [dd/mm HH:MM:SS]`` (the current local time).
      Chunks without a trailing newline pass through unchanged. When
      ``silent`` is truthy, nothing is written at all.

      Python's ``random``, NumPy and torch are then seeded with 0, and the
      current CUDA device is set to cuda:0 (this requires a CUDA build and
      device).

      Args:
          silent: if truthy, discard everything subsequently written to stdout.
      """
      old_f = sys.stdout

      class F:
          def __init__(self, silent):
              self.silent = silent

          def write(self, x):
              if not self.silent:
                  if x.endswith("\n"):
                      stamp = datetime.now().strftime("%d/%m %H:%M:%S")
                      old_f.write(x.replace("\n", " [{}]\n".format(stamp)))
                  else:
                      old_f.write(x)
              # Text-stream contract: report how many characters of the
              # caller's string were consumed (the added timestamps are not
              # counted).
              return len(x)

          def flush(self):
              old_f.flush()

          def __getattr__(self, name):
              # Only reached for attributes F does not define (e.g. encoding,
              # isatty, fileno). Forward them to the wrapped stream so code that
              # probes sys.stdout as a file keeps working.
              return getattr(old_f, name)

      sys.stdout = F(silent)

      random.seed(0)
      np.random.seed(0)
      torch.manual_seed(0)
      torch.cuda.set_device(torch.device("cuda:0"))