native_function_generation.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. from torchgen.model import (
  2. Argument,
  3. DispatchKey,
  4. FunctionSchema,
  5. BaseType,
  6. BaseTy,
  7. Return,
  8. Annotation,
  9. NativeFunction,
  10. OperatorName,
  11. BackendIndex,
  12. BackendMetadata,
  13. DeviceCheckType,
  14. SchemaKind,
  15. Variant,
  16. )
  17. from torchgen.utils import (
  18. concatMap,
  19. )
  20. from typing import List, Tuple, Sequence, Dict
  21. from collections import defaultdict
  22. # See Note: [Out ops with functional variants that don't get grouped properly]
  23. OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
  24. # This has a functional variant, but it's currently marked private.
  25. # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
  26. "adaptive_avg_pool3d_backward.grad_input",
  27. # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
  28. # Maybe we can kill this operator in favor of convolution_backward?
  29. "_slow_conv2d_backward.grad_input",
  30. ]
  31. # See Note: [Mutable ops that cannot get an out variant]
  32. MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT = [
  33. # should be out=?
  34. "_cummax_helper",
  35. # should be out=?
  36. "_cummin_helper",
  37. ]
  38. INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
  39. # polygamma and polygamma.out both exist, but have a
  40. # pre-self arg (while polygamma_ does not)
  41. # We should either fix this schema so it can be grouped properly,
  42. # or allow the codegen to generate new functional/out= NativeFunctions for this op
  43. # (which would require changing its overload name to prevent overload ambiguity).
  44. "polygamma_"
  45. ]
  46. # Groups "similar" NativeFunctions together
  47. # example add.Tensor, add_.Tensor, add.out
  48. # "similar" NativeFunctions are all expected to have an identical `signature()`,
  49. # But have differing SchemaKinds.
  50. def pre_group_native_functions(
  51. native_functions: Sequence[NativeFunction],
  52. ) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
  53. pre_grouped_native_functions: Dict[
  54. FunctionSchema, Dict[SchemaKind, NativeFunction]
  55. ] = defaultdict(dict)
  56. for f in native_functions:
  57. d = pre_grouped_native_functions[f.func.signature()]
  58. assert f.func.kind() not in d
  59. d[f.func.kind()] = f
  60. return pre_grouped_native_functions
  61. # Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
  62. # Example before:
  63. # _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
  64. # Example after:
  65. # _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
  66. def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  67. # Generating an out= schema from an inplace schema.
  68. assert func.kind() == SchemaKind.inplace
  69. assert func.arguments.self_arg is not None
  70. # The new out= schema has:
  71. # - a new out argument with the same type as "func" (but with a mutable annotation)
  72. # - The returns (if any) now alias the out= argument instead of "func"
  73. # - an "out" overload name
  74. return FunctionSchema(
  75. name=func.name.remove_inplace().with_overload(
  76. "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
  77. ),
  78. arguments=func.arguments.remove_self_annotation().with_out_args(
  79. [
  80. Argument(
  81. name="out",
  82. type=func.arguments.self_arg.argument.type,
  83. default=None,
  84. annotation=func.arguments.self_arg.argument.annotation,
  85. )
  86. ]
  87. ),
  88. returns=func.returns,
  89. )
  90. # Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
  91. # Example before:
  92. # _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
  93. # Example after:
  94. # _fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950
  95. def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  96. # Generating an out= schema from a mutable schema.
  97. assert func.kind() == SchemaKind.mutable
  98. # The new out= schema has:
  99. # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
  100. # (if the argument is a tensor then we also return it for method chaining,
  101. # otherwise we return nothing)
  102. # - an "out" overload name
  103. #
  104. # Note that:
  105. # (1) This also means that we can *only* generate an out= variant from a mutable schema
  106. # if the mutable schema has at least one tensor-like non-aliasing return.
  107. # (2) The generated out= variant still has mutable positional arguments,
  108. # but if necessary we could probably add another out= variant that also
  109. # functionalizes the mutable arguments (a functional_out variant)
  110. # More of a sanity check - our existing restrictions on schemas should enforce that
  111. # mutable schema kinds never return their mutable arguments.
  112. assert not any(
  113. r.annotation is not None and r.annotation.is_write for r in func.returns
  114. )
  115. tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
  116. assert len(tensorlike_rets) > 0
  117. used_annotations = concatMap(
  118. lambda a: [] if a.annotation is None else a.annotation.alias_set,
  119. func.arguments.flat_all,
  120. )
  121. valid_annotations = [
  122. x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
  123. ]
  124. all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
  125. new_out_args: List[Argument] = []
  126. # The end result of new_returns is that:
  127. # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
  128. # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
  129. new_returns: List[Return] = []
  130. for (i, r) in enumerate(func.returns):
  131. if r.type.is_tensor_like():
  132. new_out = Argument(
  133. name=f"out{i}",
  134. type=r.type,
  135. default=None,
  136. annotation=Annotation.parse(f"{valid_annotations[i]}!"),
  137. )
  138. new_out_args.append(new_out)
  139. if all_rets_are_tensors:
  140. # The convention for out= schemas is that they only return their out arguments
  141. # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
  142. new_ret = Return(
  143. name=None, type=new_out.type, annotation=new_out.annotation
  144. )
  145. new_returns.append(new_ret)
  146. else:
  147. new_returns.append(r)
  148. return FunctionSchema(
  149. name=func.name.remove_inplace().with_overload(
  150. "out" if not func.name.overload_name else f"{func.name.overload_name}_out"
  151. ),
  152. arguments=func.arguments.with_out_args(new_out_args),
  153. returns=tuple(new_returns),
  154. )
  155. # This function, given function of one SchemaKind, as well as a target SchemaKind,
  156. # generates a new NativeFunction with the same properties, but using the target SchemaKind.
  157. # We only actually generate functions for either functional or out= SchemaKinds.
  158. # This function returns a tuple, with:
  159. # - The generated NativeFunction
  160. # - a dictionary of `BackendIndex` objects, describing which dispatch keys
  161. # we will generate kernels for, for the new NativeFunction.
  162. # Details are in the function, but we only generate composite kernels (in some cases) today.
  163. def generate_function(
  164. f: NativeFunction, k: SchemaKind
  165. ) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
  166. from torchgen.api import cpp
  167. if k == SchemaKind.functional:
  168. assert f.func.kind() != SchemaKind.functional
  169. gets_composite_kernel = True
  170. # The new "functional" NativeFunction has:
  171. # - any mutable arguments have been converted into (immutable) returns.
  172. # (if a mutable argument was not also a return, it gets converted to one)
  173. # - a "functional" overload name.
  174. # The default grouping logic in signature() actually already does this,
  175. # so we can piggy-back off it (but we still want return names)
  176. func = f.func.signature(keep_return_names=True).with_name(
  177. f.func.name.remove_inplace().with_overload(
  178. "functional"
  179. if not f.func.name.overload_name
  180. else f"{f.func.name.overload_name}_functional"
  181. )
  182. )
  183. elif k == SchemaKind.out:
  184. # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
  185. # but at least today, there is no good reason to actually use them.
  186. # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
  187. gets_composite_kernel = False
  188. if f.func.kind() == SchemaKind.inplace:
  189. func = self_to_out_signature(f.func)
  190. elif f.func.kind() == SchemaKind.mutable:
  191. func = mutable_to_out_signature(f.func)
  192. else:
  193. raise AssertionError(
  194. "We only bother generating out= functions from either inplace or mutable variants"
  195. )
  196. else:
  197. raise AssertionError(
  198. "We currently only generate either functional or out= NativeFunctions"
  199. )
  200. if gets_composite_kernel:
  201. backend_metadata = {
  202. DispatchKey.CompositeExplicitAutograd: {
  203. func.name: BackendMetadata(cpp.name(func), structured=False)
  204. }
  205. }
  206. else:
  207. backend_metadata = {}
  208. return (
  209. NativeFunction(
  210. func=func,
  211. use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
  212. # These generated fn's aren't meant to be user friendly- don't generate methods.
  213. variants=set([Variant.function]),
  214. structured=False,
  215. structured_delegate=None,
  216. structured_inherits=None,
  217. precomputed=None,
  218. autogen=[],
  219. ufunc_inner_loop={},
  220. manual_kernel_registration=False,
  221. manual_cpp_binding=False,
  222. python_module=None,
  223. category_override=None,
  224. device_guard=False,
  225. device_check=DeviceCheckType.NoCheck,
  226. loc=f.loc,
  227. cpp_no_default_args=set(),
  228. is_abstract=f.is_abstract,
  229. has_composite_implicit_autograd_kernel=False,
  230. has_composite_explicit_autograd_kernel=gets_composite_kernel,
  231. # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
  232. # which NativeFunction objects did not come directly from native_functions.yaml.
  233. tags=set(["generated"]),
  234. ),
  235. backend_metadata,
  236. )
  237. # This function is responsible for adding generated NativeFunctions which don't appear
  238. # explicitly in the codegen.
  239. # You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
  240. # torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
  241. # (Maybe we should make a friendly API for this)
  242. #
  243. # Note: this function *mutates* its two inputs,
  244. # adding the new NativeFunctions / BackendMetadata to them
  245. def add_generated_native_functions(
  246. rs: List[NativeFunction],
  247. indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
  248. ) -> None:
  249. # The main code for gnerating new NativeFunctions
  250. # First we group of NaitveFunctions by schema kind,
  251. # then we detect which ones are missing and generate them.
  252. pre_grouped_native_functions = pre_group_native_functions(rs)
  253. for k, d in pre_grouped_native_functions.items():
  254. has_functional = SchemaKind.functional in d
  255. has_inplace = SchemaKind.inplace in d
  256. has_mutable = SchemaKind.mutable in d
  257. has_out = SchemaKind.out in d
  258. # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
  259. # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
  260. # a simple functional variant that the functionalization pass can consume.
  261. # (2) If an operator has an inplace and functional but no out= variant, we generate an out=
  262. # variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
  263. # while maintaining the constraint that the out= variant is "required".
  264. #
  265. # For now, we don't bother generated NativeFunctions for existing operators
  266. # that only have a functional variant.
  267. if has_mutable or has_inplace or has_out:
  268. # Don't bother generating functions trio's for native functions that bypass the dispatcher.
  269. are_manual = all(f.manual_cpp_binding for f in d.values())
  270. # Don't bother generating functional + out= variants for view operators
  271. has_view_ops = (
  272. has_inplace and "inplace_view" in d[SchemaKind.inplace].tags
  273. ) or any(f.is_view_op for f in d.values())
  274. # Don't generate the other variants for CompositeImplicitAutograd operators.
  275. # We could probably do this, but the main benefit of generating the function triplets
  276. # is for transforms that need them, and transforms don't need to act directly
  277. # on CompositeImplicitAutograd operators (since we let them decompose).
  278. are_composite_implicit = all(
  279. f.has_composite_implicit_autograd_kernel for f in d.values()
  280. )
  281. if are_manual or has_view_ops or are_composite_implicit:
  282. continue
  283. if has_out and len(d.values()) == 1:
  284. # Note: [Out ops with functional variants that don't get grouped properly]
  285. # In theory we could validly have an out= operator in native_functions.yaml
  286. # that has no other variants.
  287. # But today, all of the operators where that's the case actually do have
  288. # functional variants, that we are just unable to pair up properly.
  289. # I think banning this all together is probably safer
  290. # (you can always add a functional variant yourself if you want to add a new out= operator).
  291. #
  292. # We should probably fix the existing cases; this check is to prevent us from adding more over time.
  293. if (
  294. str(d[SchemaKind.out].func.name)
  295. not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
  296. ):
  297. raise AssertionError(
  298. f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
  299. )
  300. continue
  301. # Some inplace ops that have problematic schemas (that we should fix), which prevent us
  302. # from generating out= and functional variants
  303. if (
  304. has_inplace
  305. and str(d[SchemaKind.inplace].func.name)
  306. in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
  307. ):
  308. continue
  309. base_fn = (
  310. d[SchemaKind.inplace]
  311. if has_inplace
  312. else d[SchemaKind.mutable]
  313. if has_mutable
  314. else d[SchemaKind.out]
  315. )
  316. # Note: [Mutable ops that cannot get an out variant]
  317. # We can only generate an out= variant if either:
  318. # - the original function has tensor-like returns (since we can convert them to out kwargs)
  319. # - or it's inplace (since we can convert `self` to an out kwarg)
  320. # There are only two functions that don't fit this criteria today though,
  321. # and they both look like they should be fixed to be out= variants,
  322. # so if feels safer to ban this schema all-together
  323. gets_out_variant = not has_out and (
  324. base_fn.func.kind() == SchemaKind.inplace
  325. or any(r.type.is_tensor_like() for r in base_fn.func.returns)
  326. )
  327. if not has_out and not gets_out_variant:
  328. if (
  329. str(base_fn.func.name)
  330. not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
  331. ):
  332. raise AssertionError(
  333. f"""Found a mutable operator that we could not generate an out= variant for: {str(base_fn.func)}.
  334. These operators are problematic, because we can't easily auto-generate functionalization code for them. If you really need
  335. the operator have the schema mentioned, that add the name of the operator to the allow-list. Otherwise if possible,
  336. please convert it to an inplace operator"""
  337. )
  338. # Generate an out= variant
  339. if gets_out_variant:
  340. fn, metadata = generate_function(base_fn, SchemaKind.out)
  341. d[SchemaKind.out] = fn
  342. BackendIndex.grow_index(indices, metadata)
  343. rs.append(fn)
  344. # Generate a functional variant, but only do it if the operator got an out= variant
  345. # (Functional variants are only useful if we can group up the variants,
  346. # which we can only do if they have an out= variant)
  347. if not has_functional and (has_out or gets_out_variant):
  348. fn, metadata = generate_function(base_fn, SchemaKind.functional)
  349. d[SchemaKind.functional] = fn
  350. BackendIndex.grow_index(indices, metadata)
  351. rs.append(fn)