ufunc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. from dataclasses import dataclass
  2. from typing import Union, Optional, List, Tuple, Dict, Sequence
  3. from torchgen.api.translate import translate
  4. from torchgen.model import (
  5. NativeFunctionsGroup,
  6. ScalarType,
  7. UfuncKey,
  8. DispatchKey,
  9. BaseType,
  10. BaseTy,
  11. Argument,
  12. )
  13. import torchgen.api.ufunc as ufunc
  14. from torchgen.api.ufunc import UfunctorBindings
  15. from torchgen.api.types import (
  16. StructuredImplSignature,
  17. scalar_t,
  18. opmath_t,
  19. Binding,
  20. CType,
  21. BaseCType,
  22. Expr,
  23. NamedCType,
  24. ScalarTypeToCppMapping,
  25. VectorizedCType,
  26. )
  27. from torchgen.context import with_native_function
  28. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  29. #
  30. # CUDA STUFF
  31. #
  32. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
  35. # TODO: use BackendIndex
  36. # dispatch_key: DispatchKey # only CPU/CUDA right now
# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because when USERS instantiate functors
# they are templated.  A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
  53. @dataclass(frozen=True)
  54. class UfunctorSignature:
  55. g: NativeFunctionsGroup
  56. scalar_tensor_idx: Optional[int]
  57. name: str
  58. def arguments(self) -> UfunctorBindings:
  59. return ufunc.ufunctor_arguments(
  60. self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
  61. )
  62. def fields(self) -> List[Binding]:
  63. # fields are renamed to have a trailing underscore, as is conventional
  64. return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
  65. def returns_type(self) -> CType:
  66. # TODO: don't hardcode; return type will be inferred based on tags on
  67. # the native function
  68. return BaseCType(scalar_t)
  69. def decl_fields(self) -> str:
  70. return "\n".join(f"{f.type} {f.name};" for f in self.fields())
  71. def inline_defn_ctor(self) -> str:
  72. args_str = ", ".join(a.decl() for a in self.arguments().ctor)
  73. # NB: hypothetically could do this with translate but the
  74. # transition here is very regular
  75. init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
  76. return f"{self.name}({args_str}) : {init_str} {{}}"
  77. def decl_apply(self) -> str:
  78. args_str = ", ".join(a.decl() for a in self.arguments().apply)
  79. return f"{self.returns_type().cpp_type()} operator()({args_str}) const"
  80. @dataclass(frozen=True)
  81. class UfuncSignature:
  82. g: NativeFunctionsGroup
  83. name: str
  84. compute_t: CType
  85. def arguments(self) -> List[Binding]:
  86. return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
  87. def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
  88. return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to a template signature.  this establishes
#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#
  105. def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
  106. num_tensors = sum(
  107. 1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
  108. )
  109. return num_tensors == 2
  110. def compute_ufunc_cuda_functors(
  111. g: NativeFunctionsGroup,
  112. ) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
  113. # First, build the functors.
  114. ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
  115. ufunctors: List[str] = []
  116. loops = g.out.ufunc_inner_loop
  117. scalar_tensor_idx_lookup = {
  118. UfuncKey.CUDAFunctorOnSelf: 1,
  119. UfuncKey.CUDAFunctorOnOther: 0,
  120. UfuncKey.CUDAFunctor: None,
  121. }
  122. if eligible_for_binary_scalar_specialization(g):
  123. keys = [
  124. UfuncKey.CUDAFunctorOnSelf,
  125. UfuncKey.CUDAFunctorOnOther,
  126. UfuncKey.CUDAFunctor,
  127. ]
  128. else:
  129. keys = [UfuncKey.CUDAFunctor]
  130. for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
  131. assert k not in loops, f"cannot use {k} on non-binary function"
  132. for k in keys:
  133. # If the key was directly defined, skip functor codegen; we assume the
  134. # user already done it for us
  135. if k in loops:
  136. ufunctor_sig = UfunctorSignature(
  137. g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
  138. )
  139. for dtype in loops[k].supported_dtypes:
  140. ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
  141. continue
  142. # Note [ScalarOnly and Generic must match names for CUDA]
  143. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  144. # Otherwise, look in ANY of the generic entries. For simplicity of
  145. # codegen, both ScalarOnly and Generic are defined, the ufunc name
  146. # must match (if they didn't match, we'd have to generate distinct
  147. # functors per dtype, which is awful, so we're not going to do it unless
  148. # someone really forces us to)
  149. ufunc_name = None
  150. supported_dtypes = set()
  151. for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
  152. if lk not in loops:
  153. continue
  154. if ufunc_name is None:
  155. ufunc_name = loops[lk].name
  156. else:
  157. # See Note [ScalarOnly and Generic must match names for CUDA]
  158. assert (
  159. ufunc_name == loops[lk].name
  160. ), "ScalarOnly and Generic must have same ufunc name"
  161. supported_dtypes |= loops[lk].supported_dtypes
  162. assert ufunc_name is not None
  163. name = f"{k}_{ufunc_name}"
  164. ufunctor_sig = UfunctorSignature(
  165. g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
  166. )
  167. for dtype in supported_dtypes:
  168. ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
  169. ufunc_sig = UfuncSignature(
  170. g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
  171. )
  172. apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
  173. ufunctors.append(
  174. f"""
  175. template <typename scalar_t>
  176. struct {ufunctor_sig.name} {{
  177. using opmath_t = at::opmath_type<scalar_t>;
  178. {ufunctor_sig.decl_fields()}
  179. {ufunctor_sig.inline_defn_ctor()}
  180. __device__ {ufunctor_sig.decl_apply()} {{
  181. return {ufunc_sig.call(apply_ctx)};
  182. }}
  183. }};
  184. """
  185. )
  186. return ufunctor_sigs, "\n".join(ufunctors)
  187. @dataclass(frozen=True)
  188. class BinaryScalarSpecializationConfig:
  189. scalar_idx: int
  190. ctor_tensor: str
  191. ufunc_key: UfuncKey
  192. BinaryScalarSpecializationConfigs = [
  193. BinaryScalarSpecializationConfig(
  194. scalar_idx=0,
  195. ctor_tensor="self",
  196. ufunc_key=UfuncKey.CUDAFunctorOnOther,
  197. ),
  198. BinaryScalarSpecializationConfig(
  199. scalar_idx=1,
  200. ctor_tensor="other",
  201. ufunc_key=UfuncKey.CUDAFunctorOnSelf,
  202. ),
  203. ]
  204. def compute_ufunc_cuda_dtype_body(
  205. g: NativeFunctionsGroup,
  206. dtype: ScalarType,
  207. inner_loops: Dict[UfuncKey, UfunctorSignature],
  208. parent_ctx: Sequence[Binding],
  209. ) -> str:
  210. body = "using opmath_t = at::opmath_type<scalar_t>;"
  211. body += "if (false) {}\n" # for ease of codegen
  212. for config in BinaryScalarSpecializationConfigs:
  213. if config.ufunc_key not in inner_loops:
  214. continue
  215. ufunctor_sig = inner_loops[config.ufunc_key]
  216. scalar_idx = config.scalar_idx + 1
  217. # Make a copy and at the same time widen the type (not permissible
  218. # without copy; we don't want to mutate the input argument anyway)
  219. ctx: List[Union[Expr, Binding]] = list(parent_ctx)
  220. ctx.append(
  221. Expr(
  222. expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
  223. type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
  224. )
  225. )
  226. ufunctor_ctor_exprs_str = ", ".join(
  227. a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
  228. )
  229. # NB: ufunctor must be allocated before iter.remove_operand is called,
  230. # as it relies on iter
  231. body += f"""\
  232. else if (iter.is_cpu_scalar({scalar_idx})) {{
  233. {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  234. iter.remove_operand({scalar_idx});
  235. gpu_kernel(iter, ufunctor);
  236. }}"""
  237. ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
  238. ufunctor_ctor_exprs_str = ", ".join(
  239. a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
  240. )
  241. body += f"""
  242. else {{
  243. gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
  244. }}
  245. """
  246. return body
  247. @with_native_function
  248. def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
  249. # First, build the functors, indexing them by dtype
  250. ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)
  251. # Next, build the conditionals
  252. sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
  253. dtype_cases = []
  254. for dtype, inner_ufunctor_sigs in ufunctor_sigs.items():
  255. dtype_cases.append(
  256. f"""
  257. AT_PRIVATE_CASE_TYPE("{sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]},
  258. [&]() {{
  259. {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunctor_sigs, sig.arguments())}
  260. }}
  261. )
  262. """
  263. )
  264. dtype_cases_str = "\n".join(dtype_cases)
  265. stub_sig = StubSignature(g)
  266. return f"""
  267. {ufunctors}
  268. {stub_sig.type_defn()};
  269. {stub_sig.dispatch_decl()};
  270. {stub_sig.kernel_defn()} {{
  271. at::ScalarType st = iter.common_dtype();
  272. RECORD_KERNEL_FUNCTION_DTYPE("{sig.name}", st);
  273. switch (st) {{
  274. {dtype_cases_str}
  275. default:
  276. TORCH_CHECK(false, "{sig.name}", " not implemented for '", toString(st), "'");
  277. }}
  278. }}
  279. REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
  280. {sig.defn()} {{
  281. {stub_sig.direct_call(sig.arguments())};
  282. }}
  283. """
  284. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  285. #
  286. # CPU STUFF
  287. #
  288. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  289. @dataclass(frozen=True)
  290. class StubSignature:
  291. g: NativeFunctionsGroup
  292. @property
  293. def name(self) -> str:
  294. return f"{str(self.g.functional.func.name.name)}_stub"
  295. @property
  296. def kernel_name(self) -> str:
  297. return f"{str(self.g.functional.func.name.name)}_kernel"
  298. @property
  299. def type_name(self) -> str:
  300. return f"{str(self.g.functional.func.name.name)}_fn"
  301. def arguments(self) -> List[Binding]:
  302. return ufunc.stub_arguments(self.g)
  303. def type(self) -> str:
  304. cpp_args = self.arguments()
  305. return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"
  306. def dispatch_decl(self) -> str:
  307. return f"DECLARE_DISPATCH({self.type_name}, {self.name})"
  308. def dispatch_defn(self) -> str:
  309. return f"DEFINE_DISPATCH({self.name})"
  310. def kernel_defn(self) -> str:
  311. return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"
  312. def type_defn(self) -> str:
  313. return f"using {self.type_name} = {self.type()}"
  314. # must be called from context where this is TensorIteratorBase*
  315. def call(self, ctx: Sequence[Binding]) -> str:
  316. return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
  317. # used in CUDA to skip the unnecessary dynamic dispatch
  318. def direct_call(self, ctx: Sequence[Binding]) -> str:
  319. return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
  320. @with_native_function
  321. def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
  322. stub_sig = StubSignature(g)
  323. sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
  324. return f"""
  325. {stub_sig.type_defn()};
  326. {stub_sig.dispatch_decl()};
  327. {stub_sig.dispatch_defn()};
  328. {sig.defn()} {{
  329. {stub_sig.call(sig.arguments())};
  330. }}
  331. """
  332. def compute_ufunc_cpu_dtype_body(
  333. g: NativeFunctionsGroup,
  334. dtype: ScalarType,
  335. inner_loops: Dict[UfuncKey, UfuncSignature],
  336. parent_ctx: Sequence[Binding],
  337. ) -> str:
  338. assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
  339. assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
  340. scalar_loop = inner_loops[UfuncKey.CPUScalar]
  341. vec_loop = None
  342. if UfuncKey.CPUVector in inner_loops:
  343. vec_loop = inner_loops[UfuncKey.CPUVector]
  344. # NB: We DON'T use translate here, because translate is
  345. # incapable of CSE'ing the scalar accesses in case it is also
  346. # used by Vectorized; also, the unpacking here is very simple
  347. # and only affects Scalar; everything else is implicitly captured
  348. # by the lambda
  349. # Setup scalar in scope
  350. body = []
  351. ctx = []
  352. for b in parent_ctx:
  353. if isinstance(b.argument, Argument) and b.argument.type != BaseType(
  354. BaseTy.Scalar
  355. ):
  356. continue
  357. body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
  358. ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
  359. if vec_loop is not None:
  360. for b in parent_ctx:
  361. if isinstance(b.argument, Argument) and b.argument.type != BaseType(
  362. BaseTy.Scalar
  363. ):
  364. continue
  365. body.append(
  366. f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
  367. )
  368. ctx.append(
  369. Expr(
  370. f"_v_{b.name}",
  371. NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
  372. )
  373. )
  374. # Setup lambda signature
  375. # NB: simplified version of ufunctor_arguments
  376. scalar_bindings = []
  377. vec_bindings = []
  378. for a in g.functional.func.arguments.flat_non_out:
  379. if not a.type.is_tensor_like():
  380. continue
  381. assert a.type == BaseType(BaseTy.Tensor)
  382. scalar_bindings.append(
  383. Binding(
  384. name=a.name,
  385. nctype=NamedCType(a.name, BaseCType(scalar_t)),
  386. argument=a,
  387. )
  388. )
  389. if vec_loop is not None:
  390. vec_bindings.append(
  391. Binding(
  392. name=a.name,
  393. nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
  394. argument=a,
  395. )
  396. )
  397. def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
  398. r: List[Union[Expr, Binding]] = []
  399. r.extend(ctx)
  400. r.extend(b)
  401. return r
  402. body_str = "\n".join(body)
  403. if vec_loop is not None:
  404. return f"""
  405. {body_str}
  406. cpu_kernel_vec(iter,
  407. [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  408. [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
  409. );
  410. """
  411. else:
  412. return f"""
  413. {body_str}
  414. cpu_kernel(iter,
  415. [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
  416. );
  417. """
  418. @with_native_function
  419. def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
  420. stub_sig = StubSignature(g)
  421. # Reindex the ufunc by dtypes; processing generic/scalaronly as well
  422. loops = g.out.ufunc_inner_loop
  423. ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
  424. for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
  425. lks = []
  426. # ORDER MATTERS: this specifies overriding precedence
  427. if k in loops: # should happen rarely
  428. lks.append(k)
  429. if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
  430. lks.append(UfuncKey.ScalarOnly)
  431. if UfuncKey.Generic in loops:
  432. lks.append(UfuncKey.Generic)
  433. # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
  434. for lk in lks:
  435. for dtype in loops[lk].supported_dtypes:
  436. compute_t: CType
  437. if k is UfuncKey.CPUScalar:
  438. compute_t = BaseCType(scalar_t)
  439. elif k is UfuncKey.CPUVector:
  440. compute_t = VectorizedCType(BaseCType(scalar_t))
  441. else:
  442. raise AssertionError()
  443. inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
  444. if k not in inner_ufunc_sigs:
  445. inner_ufunc_sigs[k] = UfuncSignature(
  446. g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
  447. )
  448. # Build the conditionals
  449. dtype_cases = []
  450. for dtype, inner_ufunc_sigs in ufunc_sigs.items():
  451. dtype_cases.append(
  452. f"""
  453. AT_PRIVATE_CASE_TYPE("{stub_sig.name}", at::ScalarType::{dtype}, {ScalarTypeToCppMapping[dtype]},
  454. [&]() {{
  455. {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  456. }}
  457. )
  458. """
  459. )
  460. dtype_cases_str = "\n".join(dtype_cases)
  461. return f"""
  462. namespace {{
  463. {stub_sig.kernel_defn()} {{
  464. at::ScalarType st = iter.common_dtype();
  465. RECORD_KERNEL_FUNCTION_DTYPE("{stub_sig.name}", st);
  466. switch (st) {{
  467. {dtype_cases_str}
  468. default:
  469. TORCH_CHECK(false, "{stub_sig.name}", " not implemented for '", toString(st), "'");
  470. }}
  471. }}
  472. }} // anonymous namespace
  473. {stub_sig.type_defn()};
  474. {stub_sig.dispatch_decl()};
  475. REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
  476. """