ufunc.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. from torchgen.model import (
  2. Argument,
  3. BaseTy,
  4. BaseType,
  5. FunctionSchema,
  6. NativeFunctionsGroup,
  7. Type,
  8. DispatchKey,
  9. )
  10. import torchgen.api.types as api_types
  11. from torchgen.api.types import (
  12. ArgName,
  13. BaseCType,
  14. Binding,
  15. ConstRefCType,
  16. NamedCType,
  17. scalarT,
  18. CType,
  19. BaseCppType,
  20. )
  21. from torchgen.api import cpp, structured
  22. from dataclasses import dataclass
  23. from typing import List, Optional
  24. def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
  25. assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
  26. return f"ufunc_{func.name.name}_{dispatch_key}"
  27. def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
  28. return schema_kernel_name(g.out.func, dispatch_key)
  29. # Tensors are omitted (as they are stored in TensorIterator), everything else is
  30. # passed along (technically, we can pass tensors along too, it just wastes
  31. # argument registers)
  32. #
  33. # NB: used for CPU only
  34. def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
  35. r = cpp.valuetype_type(t, binds=binds)
  36. if r is not None:
  37. return r
  38. if t == BaseType(BaseTy.Scalar):
  39. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  40. elif t == BaseType(BaseTy.Tensor):
  41. return None
  42. else:
  43. raise AssertionError(f"unrecognized type {repr(t)}")
  44. def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
  45. if scalar_t == api_types.scalar_t:
  46. return api_types.opmath_t
  47. raise NotImplementedError
  48. # NB: Tensors in constructor are stored in opmath_t, not scalar_t
  49. # because Tensor in constructor = its a scalar tensor partially applied =
  50. # it can be higher precision and we want to compute in that higher precision
  51. #
  52. # NB: CUDA only
  53. def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
  54. r = cpp.valuetype_type(t, binds=binds)
  55. if r is not None:
  56. return r
  57. if t == BaseType(BaseTy.Scalar):
  58. return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
  59. elif t == BaseType(BaseTy.Tensor):
  60. return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
  61. else:
  62. raise AssertionError(f"unrecognized type {repr(t)}")
  63. # Only Tensors ever get passed directly to operator()
  64. #
  65. # NB: CUDA only
  66. # (Actually, this works for CPU too)
  67. def ufunctor_apply_type(
  68. t: Type, *, binds: ArgName, scalar_t: BaseCppType
  69. ) -> NamedCType:
  70. if t == BaseType(BaseTy.Tensor):
  71. return NamedCType(binds, BaseCType(scalar_t))
  72. else:
  73. raise AssertionError(f"unrecognized type {repr(t)}")
  74. # The actual ufunc template function the user writes. Everything here
  75. # is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
  76. # in CPU
  77. def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
  78. r = cpp.valuetype_type(t, binds=binds)
  79. if r is not None:
  80. return r
  81. if t == BaseType(BaseTy.Scalar):
  82. return NamedCType(binds, compute_t)
  83. elif t == BaseType(BaseTy.Tensor):
  84. return NamedCType(binds, compute_t)
  85. else:
  86. raise AssertionError(f"unrecognized type {repr(t)}")
  87. def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
  88. return Binding(
  89. nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
  90. name=a.name,
  91. default=None,
  92. argument=a,
  93. )
  94. def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
  95. return Binding(
  96. nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
  97. name=a.name,
  98. default=None,
  99. argument=a,
  100. )
  101. def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
  102. return Binding(
  103. nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
  104. name=a.name,
  105. default=None,
  106. argument=a,
  107. )
  108. @dataclass(frozen=True)
  109. class UfunctorBindings:
  110. ctor: List[Binding]
  111. apply: List[Binding]
  112. # ufunctors are a CUDA-only concept representing functors that take some of
  113. # their arguments on a host-side constructor, and the rest in the device-side
  114. # apply. E.g.,
  115. #
  116. # template <typename scalar_t>
  117. # struct CUDAFunctorOnSelf_add {
  118. # using opmath_t = at::opmath_type<scalar_t>;
  119. # opmath_t other_;
  120. # opmath_t alpha_;
  121. # CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
  122. # __device__ scalar_t operator()(scalar_t self) {
  123. # return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  124. # }
  125. # };
  126. #
  127. # The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
  128. # to the operator() definition
  129. def ufunctor_arguments(
  130. g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
  131. ) -> UfunctorBindings:
  132. ctor = []
  133. apply = []
  134. for a in g.functional.func.arguments.flat_non_out:
  135. if a.type.is_tensor_like():
  136. if scalar_tensor_idx == 0:
  137. # put it in the ctor anyway
  138. ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
  139. scalar_tensor_idx = None
  140. else:
  141. if scalar_tensor_idx is not None:
  142. scalar_tensor_idx -= 1
  143. apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
  144. else:
  145. ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
  146. assert scalar_tensor_idx is None
  147. return UfunctorBindings(ctor=ctor, apply=apply)
  148. # ufuncs are the inner loop template functions that you wrote in ufunc/add.h
  149. # which do the actual computation in question. E.g.,
  150. #
  151. # template <typename T>
  152. # C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
  153. # return self + alpha * other;
  154. # }
  155. #
  156. # In this file, we refer to T as compute_t which is bound by caller
  157. def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
  158. return [
  159. ufunc_argument(a, compute_t=compute_t)
  160. for a in g.functional.func.arguments.flat_non_out
  161. ]
  162. # Stubs are the DispatchStub trampolines that CPU kernels use to get to their
  163. # vectorized versions. E.g.,
  164. #
  165. # using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
  166. # DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
  167. def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
  168. # stubs drop all tensor arguments (they are implicit in the TensorIterator
  169. # argument and keep everything else)
  170. return [
  171. r
  172. for a in g.out.func.arguments.flat_non_out
  173. if not a.type.is_tensor_like()
  174. for r in structured.argument(a)
  175. ]