general_utils.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import torch
  2. import sys
  3. from datetime import datetime
  4. import numpy as np
  5. import random
  6. def inverse_sigmoid(x):
  7. return torch.log(x/(1-x))
  8. def PILtoTorch(pil_image, resolution):
  9. resized_image_PIL = pil_image.resize(resolution)
  10. resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
  11. if len(resized_image.shape) == 3:
  12. return resized_image.permute(2, 0, 1)
  13. else:
  14. return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
  15. def get_expon_lr_func(
  16. lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
  17. ):
  18. """
  19. Copied from Plenoxels
  20. Continuous learning rate decay function. Adapted from JaxNeRF
  21. The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
  22. is log-linearly interpolated elsewhere (equivalent to exponential decay).
  23. If lr_delay_steps>0 then the learning rate will be scaled by some smooth
  24. function of lr_delay_mult, such that the initial learning rate is
  25. lr_init*lr_delay_mult at the beginning of optimization but will be eased back
  26. to the normal learning rate when steps>lr_delay_steps.
  27. :param conf: config subtree 'lr' or similar
  28. :param max_steps: int, the number of steps during optimization.
  29. :return HoF which takes step as input
  30. """
  31. def helper(step):
  32. if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
  33. # Disable this parameter
  34. return 0.0
  35. if lr_delay_steps > 0:
  36. # A kind of reverse cosine decay.
  37. delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
  38. 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
  39. )
  40. else:
  41. delay_rate = 1.0
  42. t = np.clip(step / max_steps, 0, 1)
  43. log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
  44. return delay_rate * log_lerp
  45. return helper
  46. def strip_lowerdiag(L):
  47. uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
  48. uncertainty[:, 0] = L[:, 0, 0]
  49. uncertainty[:, 1] = L[:, 0, 1]
  50. uncertainty[:, 2] = L[:, 0, 2]
  51. uncertainty[:, 3] = L[:, 1, 1]
  52. uncertainty[:, 4] = L[:, 1, 2]
  53. uncertainty[:, 5] = L[:, 2, 2]
  54. return uncertainty
  55. def strip_symmetric(sym):
  56. return strip_lowerdiag(sym)
  57. def build_rotation(r):
  58. norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
  59. q = r / norm[:, None]
  60. R = torch.zeros((q.size(0), 3, 3), device='cuda')
  61. r = q[:, 0]
  62. x = q[:, 1]
  63. y = q[:, 2]
  64. z = q[:, 3]
  65. R[:, 0, 0] = 1 - 2 * (y*y + z*z)
  66. R[:, 0, 1] = 2 * (x*y - r*z)
  67. R[:, 0, 2] = 2 * (x*z + r*y)
  68. R[:, 1, 0] = 2 * (x*y + r*z)
  69. R[:, 1, 1] = 1 - 2 * (x*x + z*z)
  70. R[:, 1, 2] = 2 * (y*z - r*x)
  71. R[:, 2, 0] = 2 * (x*z - r*y)
  72. R[:, 2, 1] = 2 * (y*z + r*x)
  73. R[:, 2, 2] = 1 - 2 * (x*x + y*y)
  74. return R
  75. def build_scaling_rotation(s, r):
  76. L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
  77. R = build_rotation(r)
  78. L[:,0,0] = s[:,0]
  79. L[:,1,1] = s[:,1]
  80. L[:,2,2] = s[:,2]
  81. L = R @ L
  82. return L
  83. def safe_state(silent):
  84. old_f = sys.stdout
  85. class F:
  86. def __init__(self, silent):
  87. self.silent = silent
  88. def write(self, x):
  89. if not self.silent:
  90. if x.endswith("\n"):
  91. old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
  92. else:
  93. old_f.write(x)
  94. def flush(self):
  95. old_f.flush()
  96. sys.stdout = F(silent)
  97. random.seed(0)
  98. np.random.seed(0)
  99. torch.manual_seed(0)
  100. torch.cuda.set_device(torch.device("cuda:0"))