_meta_registrations.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import torch
  2. from torch._prims import utils
  3. meta_lib = torch.library.Library("aten", "IMPL", "Meta")
  4. def check(b, s):
  5. if not b:
  6. raise RuntimeError(s)
  7. def toRealValueType(dtype):
  8. from_complex = {
  9. torch.complex32: torch.half,
  10. torch.cfloat: torch.float,
  11. torch.cdouble: torch.double
  12. }
  13. return from_complex.get(dtype, dtype)
  14. # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
  15. @torch.library.impl(meta_lib, "index_select")
  16. def meta_index_select(self, dim, index):
  17. result_size = list(self.size())
  18. if self.dim() > 0:
  19. result_size[dim] = index.numel()
  20. return self.new_empty(result_size)
  21. @torch.library.impl(meta_lib, "index_select.out")
  22. def meta_index_select_out(self, dim, index, out):
  23. torch._resize_output_(out, self.size(), self.device)
  24. return out.copy_(torch.index_select(self, dim, index))
  25. @torch.library.impl(meta_lib, "abs")
  26. def meta_abs(self):
  27. if self.is_complex():
  28. float_type = toRealValueType(self.dtype)
  29. return self.new_empty(self.size(), dtype=float_type)
  30. else:
  31. return self.new_empty(self.size())
  32. @torch.library.impl(meta_lib, "abs.out")
  33. def meta_abs_out(self, out):
  34. torch._resize_output_(out, self.size(), self.device)
  35. return out.copy_(torch.abs(self))
  36. @torch.library.impl(meta_lib, "max")
  37. def meta_max(self):
  38. return self.new_empty(())
  39. @torch.library.impl(meta_lib, "min")
  40. def meta_min(self):
  41. return self.new_empty(())
  42. def squareCheckInputs(self, f_name):
  43. assert self.dim() >= 2, f"{f_name}: The input tensor must have at least 2 dimensions."
  44. # TODO: I think the error message has the -2 and -1 swapped. If you fix
  45. # it fix the C++ squareCheckInputs too
  46. assert self.size(-1) == self.size(-2), \
  47. f"{f_name}: A must be batches of square matrices, but they are {self.size(-1)} by {self.size(-2)} matrices"
  48. def checkUplo(uplo: str):
  49. uplo_uppercase = uplo.upper()
  50. assert len(uplo) == 1 and uplo_uppercase == 'U' or uplo_uppercase == 'L', \
  51. f"Expected UPLO argument to be 'L' or 'U', but got {uplo}"
  52. @torch.library.impl(meta_lib, "linalg_eigh")
  53. def meta_linalg_eigh(self, uplo="L"):
  54. squareCheckInputs(self, "linalg_eigh")
  55. checkUplo(uplo)
  56. real_dtype = toRealValueType(self.dtype)
  57. assert self.dim() >= 2
  58. values = self.new_empty(self.shape, dtype=real_dtype)
  59. values.transpose_(-2, -1)
  60. vectors = self.new_empty(self.shape[:-1])
  61. return (values, vectors)
  62. @torch.library.impl(meta_lib, "reflection_pad2d")
  63. def meta_pad2d(self, padding):
  64. valid_dims = self.size(1) != 0 and self.size(2) != 0
  65. check(
  66. (self.ndim == 3 and valid_dims)
  67. or (self.ndim == 4 and valid_dims and self.size(3) != 0),
  68. f"3D or 4D (batch mode) tensor expected for input, but got: {self}"
  69. )
  70. if self.ndim == 4:
  71. nbatch, nplane, input_h, input_w = self.shape
  72. else:
  73. nbatch = 1
  74. nplane, input_h, input_w = self.shape
  75. pad_l, pad_r, pad_t, pad_b = padding
  76. output_h = input_h + pad_t + pad_b
  77. output_w = input_w + pad_l + pad_r
  78. if self.ndim == 3:
  79. return self.new_empty((nplane, output_h, output_w))
  80. else:
  81. return self.new_empty((nbatch, nplane, output_h, output_w))
  82. @torch.library.impl(meta_lib, "dot")
  83. def meta_dot(self, tensor):
  84. check(
  85. self.dim() == 1 and tensor.dim() == 1,
  86. f"1D tensors expected, but got {self.dim()}D and {tensor.dim()}D tensors"
  87. )
  88. return self.new_empty(())
  89. @torch.library.impl(meta_lib, "var_mean.correction")
  90. def meta_var_mean_correction(self, dim, *, correction, keepdim=False):
  91. dim = utils.reduction_dims(self.shape, dim)
  92. if keepdim:
  93. output_shape = tuple(self.shape[i] if i not in dim else 1 for i in range(self.ndim))
  94. else:
  95. output_shape = utils.compute_reduction_output_shape(self.shape, dim)
  96. result1 = self.new_empty(output_shape, dtype=toRealValueType(self.dtype))
  97. result2 = self.new_empty(output_shape)
  98. return result1, result2
  99. @torch.library.impl(meta_lib, "inverse")
  100. def meta_inverse(self):
  101. # Bug: https://github.com/pytorch/pytorch/issues/77498
  102. if self.numel() == 0:
  103. return torch.empty_like(self)
  104. r = self.new_empty(self.shape)
  105. r.transpose_(-2, -1)
  106. return r
  107. @torch.library.impl(meta_lib, "bernoulli.out")
  108. def meta_bernoulli(self, *, generator=None, out):
  109. torch._resize_output_(out, self.size(), self.device)
  110. return out
  111. @torch.library.impl(meta_lib, "_adaptive_avg_pool2d")
  112. def meta_adaptive_avg_pool2d(self, output_size):
  113. check(self.ndim == 3 or self.ndim == 4, f"Expected 3D or 4D tensor, but got {self.shape}")
  114. return self.new_empty(self.shape[:-2] + tuple(output_size))
  115. @torch.library.impl(meta_lib, "_adaptive_avg_pool3d")
  116. def meta_adaptive_avg_pool3d(self, output_size):
  117. check(self.ndim == 4 or self.ndim == 5, f"Expected 4D or 5D tensor, but got {self.shape}")
  118. return self.new_empty(self.shape[:-3] + tuple(output_size))