gen.py 94 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538
  1. import os
  2. from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
  3. from typing_extensions import Literal
  4. import yaml
  5. from collections import OrderedDict, defaultdict, namedtuple
  6. import argparse
  7. import pathlib
  8. import json
  9. from dataclasses import dataclass
  10. import functools
  11. from torchgen.model import (
  12. STRUCTURED_DISPATCH_KEYS,
  13. Argument,
  14. DispatchKey,
  15. FunctionSchema,
  16. Location,
  17. NativeFunction,
  18. NativeFunctionsGroup,
  19. OperatorName,
  20. BackendIndex,
  21. BackendMetadata,
  22. OptionalType,
  23. SchemaKind,
  24. SelfArgument,
  25. TensorOptionsArguments,
  26. Type,
  27. Variant,
  28. is_cuda_dispatch_key,
  29. is_generic_dispatch_key,
  30. is_ufunc_dispatch_key,
  31. NativeFunctionsViewGroup,
  32. ViewSchemaKind,
  33. BaseOperatorName,
  34. )
  35. from torchgen.native_function_generation import (
  36. pre_group_native_functions,
  37. add_generated_native_functions,
  38. )
  39. from torchgen.api.types import (
  40. Binding,
  41. CppSignatureGroup,
  42. DispatcherSignature,
  43. NamedCType,
  44. NativeSignature,
  45. SpecialArgName,
  46. )
  47. from torchgen.api import cpp
  48. import torchgen.api.dispatcher as dispatcher
  49. import torchgen.api.native as native
  50. import torchgen.api.meta as meta
  51. import torchgen.api.structured as structured
  52. from torchgen.api.translate import translate
  53. from torchgen.code_template import CodeTemplate
  54. from torchgen.selective_build.selector import SelectiveBuilder
  55. from torchgen.utils import (
  56. Target,
  57. concatMap,
  58. context,
  59. mapMaybe,
  60. YamlDumper,
  61. YamlLoader,
  62. FileManager,
  63. assert_never,
  64. make_file_manager,
  65. )
  66. from torchgen.context import (
  67. method_with_native_function,
  68. native_function_manager,
  69. with_native_function_and_indices,
  70. with_native_function,
  71. )
  72. import torchgen.dest as dest
  73. from torchgen.gen_functionalization_type import (
  74. gen_functionalization_definition,
  75. gen_functionalization_registration,
  76. gen_functionalization_view_inverse_declaration,
  77. gen_composite_view_copy_kernel,
  78. gen_composite_functional_kernel,
  79. )
  80. T = TypeVar("T")
  81. # Welcome to the ATen code generator v2! The ATen code generator is
  82. # responsible for parsing native_functions.yaml and then generating
  83. # various generated files (e.g., TypeDefault.cpp) based on the operators
  84. # defined in this file. This means that the code generator knows how to
  85. # parse function schema, and then translate this into various C++ types
  86. # and boilerplate code.
  87. #
  88. # Some things to know about this file when you modify it:
  89. #
  90. # - This file has STRICT mypy typechecking. Typecheck it with
  91. # `mypy --config mypy-strict.ini` in the root source directory
  92. #
  93. # - Most of the heavy lifting lives in external modules:
  94. # - 'model' has the data model for native_functions.yaml. The classes
  95. # in those file represent what you see when you look at
  96. # a native_functions.yaml
  97. # - 'api' has conversions for how to translate JIT schema into
  98. # the various C++ APIs that the codegen interacts with. There
  99. # are in fact THREE different C++ APIs: the public C++ API,
  100. # the dispatcher API, and the legacy dispatcher API. See each
  101. # of these respective files for more information
  102. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  103. #
  104. # HELPER FUNCTIONS
  105. #
  106. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  class NamespaceHelper:
      """Build the opening and closing C++ text for a chain of nested namespaces.

      For a namespace_str of torch::lazy the two pieces are:

      prologue:
          namespace torch {
          namespace lazy {
      epilogue:
          } // namespace lazy
          } // namespace torch
      """

      def __init__(self, namespace_str: str):
          # Namespaces arrive "::"-joined, outermost first (e.g. torch::lazy).
          parts = namespace_str.split("::")
          opening = [f"namespace {part} {{" for part in parts]
          # Close in the opposite order so the braces pair up with the prologue.
          closing = [f"}} // namespace {part}" for part in parts[::-1]]
          self.prologue_ = "\n".join(opening)
          self.epilogue_ = "\n".join(closing)

      @property
      def prologue(self) -> str:
          """Lines that open every namespace, outermost first."""
          return self.prologue_

      @property
      def epilogue(self) -> str:
          """Lines that close every namespace, innermost first."""
          return self.epilogue_
  130. # A custom loader for YAML to let us also keep track of line numbers
  131. # of each entry in the YAML file
  class LineLoader(YamlLoader):
      """YAML loader that records, on every mapping it builds, the line it started on."""

      def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
          result = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
          # Shift by one so that reported line numbers start at 1.
          line_number = node.start_mark.line + 1
          result["__line__"] = line_number
          return result
  138. _GLOBAL_PARSE_NATIVE_YAML_CACHE = {}
  139. # Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
  140. ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
  def parse_native_yaml_struct(
      es: object,
      valid_tags: Set[str],
      ignore_keys: Optional[Set[DispatchKey]] = None,
      path: str = "<stdin>",
  ) -> ParsedYaml:
      """Turn already-loaded YAML entries into NativeFunctions and Backend Indices.

      `es` must be a list of mappings, each carrying an integer `__line__`
      entry; `path` is only used to build the Location attached to each entry.
      """
      assert isinstance(es, list)
      native_functions: List[NativeFunction] = []
      kernels_by_key: Dict[
          DispatchKey, Dict[OperatorName, BackendMetadata]
      ] = defaultdict(dict)

      for entry in es:
          line = entry.get("__line__")
          assert isinstance(line, int), entry
          loc = Location(path, line)
          funcs = entry.get("func")
          with context(lambda: f"in {loc}:\n {funcs}"):
              func, m = NativeFunction.from_yaml(entry, loc, valid_tags, ignore_keys)
              native_functions.append(func)
              BackendIndex.grow_index(kernels_by_key, m)

      # Cross-function checks run on the functions as written in the YAML,
      # before any generated functions are added.
      error_check_native_functions(native_functions)
      add_generated_native_functions(native_functions, kernels_by_key)

      def undefined_index() -> BackendIndex:
          return BackendIndex(
              dispatch_key=DispatchKey.Undefined,
              use_out_as_primary=True,
              external=False,
              device_guard=False,
              index={},
          )

      # Default dict is to prevent the codegen from barfing when we have a
      # dispatch key that has no kernels yet.
      indices: Dict[DispatchKey, BackendIndex] = defaultdict(undefined_index)
      for dispatch_key, kernels in kernels_by_key.items():
          # All structured in-tree operators are implemented in terms of their out operator.
          indices[dispatch_key] = BackendIndex(
              dispatch_key=dispatch_key,
              use_out_as_primary=True,
              external=False,
              # Only cuda-like devices in tree require device guards
              device_guard=is_cuda_dispatch_key(dispatch_key),
              index=kernels,
          )
      return ParsedYaml(native_functions, indices)
  def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
      """Collect the tag names declared by already-loaded tags YAML entries.

      `es` must be a list of mappings, each with an integer `__line__`, a
      `tag` name and a non-empty `desc`.
      """
      assert isinstance(es, list)
      names: Set[str] = set()
      for entry in es:
          line = entry.get("__line__")
          assert isinstance(line, int), entry
          loc = Location(path, line)
          tags = entry.get("tag")
          with context(lambda: f"in {loc}:\n {tags}"):
              name = entry["tag"]
              # ensure that each tag has a non-empty description
              assert entry.get("desc", "") != ""
              names.add(name)
      return names
  @functools.lru_cache(maxsize=None)
  def parse_tags_yaml(path: str) -> Set[str]:
      """Read the tags YAML file at `path` and return the tag names it declares.

      The result is memoized per path by lru_cache.
      """
      # TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
      with open(path, "r") as stream:
          entries = yaml.load(stream, Loader=LineLoader)
          return parse_tags_yaml_struct(entries, path=path)
  def parse_native_yaml(
      path: str, tags_yaml_path: str, ignore_keys: Optional[Set[DispatchKey]] = None
  ) -> ParsedYaml:
      """Parse the native functions YAML at `path`, reusing a cached result if present.

      NOTE: the cache is keyed by `path` alone, so a later call with the same
      `path` gets the first result even if `tags_yaml_path` or `ignore_keys`
      differ.
      """
      # TODO: parse tags.yaml and create a tags database (a dict of tag name mapping to a Tag object)
      cached = _GLOBAL_PARSE_NATIVE_YAML_CACHE.get(path)
      if cached is not None:
          return cached

      valid_tags = parse_tags_yaml(tags_yaml_path)
      with open(path, "r") as stream:
          entries = yaml.load(stream, Loader=LineLoader)
      parsed = parse_native_yaml_struct(entries, valid_tags, ignore_keys, path=path)
      _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parsed
      return parsed
  216. # Some assertions are already performed during parsing, but those are only within a single NativeFunction.
  217. # Assertions here are meant to be performed across NativeFunctions.
  def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
      """Run the checks that need to look across several NativeFunctions.

      Two things are verified:
        * an operator named as a `structured_delegate` is itself marked
          structured;
        * an operator tagged `inplace_view` has an in-place base name, and at
          least one operator exists with the matching out-of-place base name.

      Raises AssertionError when a check fails, and KeyError when a
      `structured_delegate` names an operator that is not in `funcs`.
      """
      func_map: Dict[OperatorName, NativeFunction] = {}
      base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
      for f in funcs:
          func_map[f.func.name] = f
          base_func_map[f.func.name.name].append(f)

      for f in funcs:
          if f.structured_delegate is not None:
              delegate_func = func_map[f.structured_delegate]
              assert delegate_func.structured, (
                  f"{f.func.name} is marked as a structured_delegate pointing to "
                  f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                  f"Consider adding 'structured=True' to the delegated operator"
              )
          if "inplace_view" in f.tags:
              base_name = f.func.name.name
              assert base_name.inplace, (
                  f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                  "convention for inplace ops - the codegen expects the base name to have a trailing underscore. "
              )
              out_of_place_base_name = BaseOperatorName(
                  base_name.base, False, base_name.dunder_method
              )
              # Report the name that was actually looked up (the out-of-place
              # one), not the in-place name of the op being checked.
              assert len(base_func_map[out_of_place_base_name]) > 0, (
                  f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                  f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, "
                  "but it didn't find one. "
              )
  def cpp_string(s: str) -> str:
      """Render a Python string as a double-quoted C++ string literal.

      Backslash, double quote, and the \\a, \\b, \\f, \\n, \\v and \\t
      characters are replaced by their escape sequences; every other
      character is copied through unchanged.
      """
      # A single translate pass maps each source character independently, so
      # the backslashes introduced by one escape are never escaped again.
      escapes = str.maketrans(
          {
              "\\": "\\\\",
              '"': '\\"',
              "\a": "\\a",
              "\b": "\\b",
              "\f": "\\f",
              "\n": "\\n",
              "\v": "\\v",
              "\t": "\\t",
          }
      )
      return '"' + s.translate(escapes) + '"'
  257. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  258. #
  259. # C++ CODE GENERATION
  260. #
  261. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  262. # Most functions in this section are curried: they consist of a function
  263. # that takes some parameters (e.g., what is to be generated) which itself
  264. # returns a function that actually maps NativeFunction to the code
  265. # to be generated. This pattern makes it convenient to use map, concatMap
  266. # and similar functional combinators.
  def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
      """List the dispatch keys involved in static dispatch for `backends`.

      Returns an empty list when no backends are given; otherwise the key of
      each backend, in order, followed by the two composite keys.
      """
      if not backends:
          return []
      keys = [index.dispatch_key for index in backends]
      keys.append(DispatchKey.CompositeImplicitAutograd)
      keys.append(DispatchKey.CompositeExplicitAutograd)
      return keys
  def get_static_dispatch_backend(
      f: NativeFunction, backend_index: BackendIndex
  ) -> Optional[DispatchKey]:
      """Pick the dispatch key that `f` is statically dispatched to for one backend.

      Returns the backend's own key when `f` has a structured delegate or a
      kernel in `backend_index`, otherwise the composite key `f` has a kernel
      for (explicit checked before implicit), otherwise None.
      """
      # TODO: for ops with structured_delegate it should check the dispatch table of
      # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
      # so we always dispatch to the `backend`, but this could be wrong when we
      # migrate math/default_backend ops to use structured delegate.
      if f.structured_delegate is not None or backend_index.has_kernel(f):
          return backend_index.dispatch_key
      if f.has_composite_explicit_autograd_kernel:
          return DispatchKey.CompositeExplicitAutograd
      if f.has_composite_implicit_autograd_kernel:
          return DispatchKey.CompositeImplicitAutograd
      return None
  def static_dispatch_ops_header(
      f: NativeFunction, backend_index: List[BackendIndex]
  ) -> Optional[str]:
      """Build the per-operator `#include` lines needed to statically dispatch `f`.

      Returns None when `backend_index` is None or `f` uses manual kernel
      registration; otherwise one include line per backend that resolves to a
      dispatch key, joined by newlines (an empty string if none do).
      """
      if backend_index is None or f.manual_kernel_registration:
          return None

      resolved = (get_static_dispatch_backend(f, index) for index in backend_index)
      return "\n".join(
          f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
          for dispatch_key in resolved
          if dispatch_key is not None
      )
  def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
      """Return one `<ATen/...Functions.h>` include line per static dispatch key."""
      headers: List[str] = []
      for dispatch_key in static_dispatch_keys(backends):
          headers.append(f"#include <ATen/{dispatch_key}Functions.h>")
      return headers
  307. # Translates the arguments of a native function from DispatcherSignature form to CppSignature form, including
  308. # the case where there is a memory_format argument alongside tensor_options arguments.
  309. # This use case is not covered by torchgen.api.translate.translate() yet, as its application is limited to static dispatch.
  def translate_args_dispatcher_to_cpp(
      f: NativeFunction,
  ) -> str:
      """Return the comma-separated expressions that pass `f`'s dispatcher
      arguments to its C++ signature.
      """

      def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
          # Rebuild every binding named "memory_format" so that its NamedCType
          # carries SpecialArgName.possibly_redundant_memory_format; all other
          # bindings are passed through untouched.
          return [
              Binding(
                  nctype=NamedCType(
                      SpecialArgName.possibly_redundant_memory_format,
                      binding.nctype.type,
                  ),
                  name=binding.name,
                  default=binding.default,
                  argument=binding.argument,
              )
              if binding.name == "memory_format"
              else binding
              for binding in input_bindings
          ]

      disp_sig = DispatcherSignature.from_schema(f.func)
      cpp_sig = CppSignatureGroup.from_native_function(
          f, method=False, fallback_binding=False
      ).signature

      # When the C++ signature has an argument whose NamedCType name is
      # SpecialArgName.possibly_redundant_memory_format, give the dispatcher's
      # memory_format bindings the same name.
      needs_special_binding = any(
          arg.nctype.name == SpecialArgName.possibly_redundant_memory_format
          for arg in cpp_sig.arguments()
      )
      disp_bindings = disp_sig.arguments()
      if needs_special_binding:
          disp_bindings = add_spl_memory_format_binding(disp_sig.arguments())

      translated = translate(disp_bindings, cpp_sig.arguments())
      return ", ".join(item.expr for item in translated)
  def generate_static_dispatch_backend_call(
      f: NativeFunction,
      backend_index: BackendIndex,
  ) -> str:
      """Return the C++ `return` statement that calls `f` in the namespace of
      `backend_index`'s dispatch key.
      """
      name = DispatcherSignature.from_schema(f.func).name()
      exprs = translate_args_dispatcher_to_cpp(f)
      backend_namespace = backend_index.dispatch_key.lower()
      return f"return at::{backend_namespace}::{name}({exprs});"
  def generate_static_dispatch_fallback_call(
      f: NativeFunction,
      backend_indices: List[BackendIndex],
  ) -> str:
      """Return the C++ statement used when no backend has a kernel for `f`.

      This is a call into the composite explicit autograd namespace if `f` has
      such a kernel, else into the composite implicit autograd namespace, else
      a TORCH_CHECK(false, ...) that names the operator and the dispatch keys
      of `backend_indices`.
      """
      name = DispatcherSignature.from_schema(f.func).name()
      exprs = translate_args_dispatcher_to_cpp(f)
      if f.has_composite_explicit_autograd_kernel:
          return f"return at::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
      if f.has_composite_implicit_autograd_kernel:
          return f"return at::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
      # The message is assembled on one line: the previous backslash
      # continuation glued "for" onto the first key and left a stray space
      # before the closing quote.
      keys = ", ".join(str(index.dispatch_key) for index in backend_indices)
      return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {keys}");'
  def static_dispatch(
      f: NativeFunction,
      backend_indices: List[BackendIndex],
  ) -> str:
      """Generate the C++ body that statically dispatches `f`.

      Returns "" when there are no backends or `f` uses manual kernel
      registration. With exactly one candidate backend the body is a direct
      call into it; with none it is the fallback call; otherwise it is a
      switch over the dispatch key computed from the arguments.
      """
      if not backend_indices or f.manual_kernel_registration:
          return ""

      # Backends that either have a kernel for f, or are structured dispatch
      # keys while f has a structured delegate.
      candidates = [
          index
          for index in backend_indices
          if index.has_kernel(f)
          or (
              f.structured_delegate is not None
              and index.dispatch_key in STRUCTURED_DISPATCH_KEYS
          )
      ]
      if len(candidates) == 1:
          return generate_static_dispatch_backend_call(f, candidates[0])
      if len(candidates) == 0:
          return generate_static_dispatch_fallback_call(f, backend_indices)

      sig = DispatcherSignature.from_schema(f.func)
      tensor_args = ", ".join(
          a.name
          for a in sig.arguments()
          if isinstance(a.argument, SelfArgument)
          or (isinstance(a.argument, Argument) and a.argument.type.is_tensor_like())
      )
      tensor_opts = f.func.arguments.tensor_options

      # Sources that the dispatch key set is computed from.
      subexprs: List[str] = []
      if tensor_opts is not None:
          subexprs.append(
              "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
          )
      if tensor_args != "":
          subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")

      stmts = [
          f"DispatchKeySet _dk_set = {' | '.join(subexprs)};",
          "DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);",
      ]

      dispatch_code = []
      for index in candidates:
          dispatch_code.append(f"case DispatchKey::{index.dispatch_key}:")
          dispatch_code.append(f"\t{generate_static_dispatch_backend_call(f, index)};")

      fallback = generate_static_dispatch_fallback_call(f, backend_indices)
      connector = "\n\t\t"
      return "\n".join(
          [
              "",
              connector.join(stmts),
              "switch (_dk) {",
              connector.join(dispatch_code),
              "default:",
              fallback,
              "}",
              "",
          ]
      )
  419. # Generates RegisterSchema.cpp. Depending on the selector, either
  420. # all schemas are registered, or only some are (in the case of
  421. # selective build)
  @dataclass(frozen=True)
  class RegisterSchema:
      """Callable that emits the `m.def(...)` schema registration for a NativeFunction."""

      # Decides which native functions get a registration line.
      selector: SelectiveBuilder

      @method_with_native_function
      def __call__(self, f: NativeFunction) -> Optional[str]:
          """Return the registration line for `f`, or None if the selector rejects it."""
          if self.selector.is_native_function_selected(f):
              return f"m.def({cpp_string(str(f.func))});\n"
          return None
  430. # Generates Operators.h and Operators.cpp.
  431. # These provide macros that, given an operator and overload name, allow users
  432. # to access an "un-overloaded" function version of the operator. This
  433. # is useful for extension writers who (1) want to decltype the operator
  434. # and (2) don't want to worry about method-only operators.
  @dataclass(frozen=True)
  class ComputeOperators:
      """Callable that emits the `at::_ops` struct declaration or definitions
      for a NativeFunction, depending on `target`.
      """

      target: Union[Literal[Target.DECLARATION], Literal[Target.DEFINITION]]
      # When non-empty, the generated call() goes through static dispatch.
      static_dispatch_backend_indices: List[BackendIndex]

      @method_with_native_function
      def __call__(self, f: NativeFunction) -> str:
          sig = DispatcherSignature.from_schema(f.func)
          name = f.func.name.unambiguous_name()
          if self.target is Target.DECLARATION:
              return self._declaration(f, sig, name)
          if self.target is Target.DEFINITION:
              return self._definition(f, sig, name)
          assert_never(self.target)

      def _declaration(
          self, f: NativeFunction, sig: DispatcherSignature, name: str
      ) -> str:
          # Emit the struct holding the schema metadata and the call/redispatch
          # entry points for one operator.
          #
          # Note [The ATen Operators API]
          # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
          # metadata about each operator + entry points into the Dispatcher.
          # The C++ function, method, and redispatch API's are all implemented as wrappers
          # into various bits of the structs defined here.
          #
          # Important characteristics about the Operators API:
          # (1) It follows the Dispatcher API.
          #     This is kind of necessary to avoid overhead.
          #     For example: if it followed the C++ API, then all of the faithful C++ factory functions
          #     would need to wrap their arguments into TensorOptions only to unwrap them again.
          # (2) Overload names are disambiguated.
          #     This is helpful for pytorch extenders who would like to decltype() an aten operator,
          #     that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
          # (3) No argument defaulting is allowed.
          #     This is more of an implementation detail to avoid #include cycles,
          #     since TensorBody.h (which defines the Tensor class) needs to include this file.
          # (4) manual_cpp_bindings and faithful names are not included in the API.
          #     This applies to stuff like __dispatch__is_complex(), and add_outf().
          #     These aren't "real aten ops", they're just additional functions provided by the C++ API.
          #     They're implemented as wrappers in Functions.h that call into the actual operators
          #     defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
          #     This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
          call_method_name = "call"
          redispatch_method_name = "redispatch"
          lines = [
              "",
              f"struct TORCH_API {name} {{",
              f"using schema = {sig.type()};",
              "using ptr_schema = schema*;",
              "// See Note [static constexpr char* members for windows NVCC]",
              f'STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}")',
              f'STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}")',
              f"STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))})",
              f"static {sig.defn(name=call_method_name, is_redispatching_fn=False)};",
              f"static {sig.defn(name=redispatch_method_name, is_redispatching_fn=True)};",
              "};",
          ]
          return "\n".join(lines)

      def _definition(
          self, f: NativeFunction, sig: DispatcherSignature, name: str
      ) -> str:
          # Emit the out-of-line string constants, the typed-handle factory, and
          # the bodies of call() and redispatch() for one operator.
          call_method_name = "call"
          redispatch_method_name = "redispatch"
          header = "\n".join(
              [
                  "",
                  f'STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}")',
                  f'STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}")',
                  f"STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))})",
                  f"// aten::{f.func}",
                  f"static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{",
                  "return c10::Dispatcher::singleton()",
                  f".findSchemaOrThrow({name}::name, {name}::overload_name)",
                  f".typed<{name}::schema>();",
                  "}",
                  "",
              ]
          )
          pieces = [header]

          arg_names = [a.name for a in sig.arguments()]
          use_static_dispatch = len(self.static_dispatch_backend_indices) > 0
          for is_redispatching_fn in (False, True):
              if is_redispatching_fn:
                  method_name = f"{name}::{redispatch_method_name}"
                  dispatcher_call = "redispatch"
                  dispatcher_exprs_str = ", ".join(["dispatchKeySet"] + arg_names)
              else:
                  method_name = f"{name}::{call_method_name}"
                  dispatcher_call = "call"
                  dispatcher_exprs_str = ", ".join(arg_names)

              if use_static_dispatch and not is_redispatching_fn:
                  # call() should go through static dispatch
                  fn_body = static_dispatch(
                      f, backend_indices=self.static_dispatch_backend_indices
                  )
              else:
                  fn_body = "\n".join(
                      [
                          "",
                          f"static auto op = create_{name}_typed_handle();",
                          f"return op.{dispatcher_call}({dispatcher_exprs_str});",
                      ]
                  )

              pieces.append(
                  "\n".join(
                      [
                          "",
                          f"// aten::{f.func}",
                          f"{sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{",
                          fn_body,
                          "}",
                          "",
                      ]
                  )
              )
          return "".join(pieces)
  # Generates Functions.h: the free-function flavour of the public C++ API.
  # Each emitted wrapper is an inline function that forwards its arguments to
  # the operator's at::_ops::<name>::call entry point.
  @dataclass(frozen=True)
  class ComputeFunction:
      @method_with_native_function
      def __call__(self, f: NativeFunction) -> Optional[str]:
          # Operators that have no function variant contribute nothing here.
          if Variant.function not in f.variants:
              return None

          group = CppSignatureGroup.from_native_function(
              f, method=False, fallback_binding=f.manual_cpp_binding
          )

          def emit_wrapper(faithful: bool) -> str:
              # Pick the C++ signature this wrapper exposes.
              if faithful:
                  cpp_sig = group.faithful_signature
                  assert cpp_sig is not None
              else:
                  cpp_sig = group.signature

              # See Note [The ATen Operators API]
              dispatcher_sig = DispatcherSignature.from_schema(f.func)
              bindings = translate(cpp_sig.arguments(), dispatcher_sig.arguments())
              call_args = ", ".join([b.expr for b in bindings])
              return (
                  f"\n// aten::{f.func}\n"
                  f"TORCH_API inline {cpp_sig.decl()} {{\n"
                  f"    return at::_ops::{f.func.name.unambiguous_name()}::call({call_args});\n"
                  "}\n"
              )

          # The regular overload is always emitted; the faithful one only when
          # the signature group provides it.
          wrappers = [emit_wrapper(False)]
          if group.faithful_signature is not None:
              wrappers.append(emit_wrapper(True))
          return "".join(wrappers)
  # Generates TensorBody.h: the method-based (object-oriented) flavour of the
  # public C++ API. Declarations are `const` member prototypes; definitions are
  # inline `Tensor::` members that forward to at::_ops::<name>::call.
  @dataclass(frozen=True)
  class ComputeTensorMethod:
      target: Union[Literal[Target.DECLARATION], Literal[Target.DEFINITION]]
      static_dispatch_backend_indices: List[BackendIndex]

      @method_with_native_function
      def __call__(self, f: NativeFunction) -> Optional[str]:
          # Only operators with a method variant are emitted.
          if Variant.method not in f.variants:
              return None

          assert not f.func.is_out_fn()
          assert f.func.arguments.self_arg is not None

          group = CppSignatureGroup.from_native_function(
              f, method=True, fallback_binding=f.manual_cpp_binding
          )

          if self.target is Target.DECLARATION:
              prototypes = [f"{group.signature.decl()} const;\n"]
              if group.faithful_signature is not None:
                  prototypes.append(f"{group.faithful_signature.decl()} const;\n")
              return "".join(prototypes)

          if self.target is not Target.DEFINITION:
              assert_never(self.target)

          def emit_method(faithful: bool) -> str:
              if faithful:
                  cpp_sig = group.faithful_signature
                  assert cpp_sig is not None
              else:
                  cpp_sig = group.signature

              dispatcher_sig = DispatcherSignature.from_schema(f.func)
              bindings = translate(
                  cpp_sig.arguments(), dispatcher_sig.arguments(), method=True
              )
              call_args = ", ".join([b.expr for b in bindings])
              return (
                  f"\n// aten::{f.func}\n"
                  f'inline {cpp_sig.defn(prefix="Tensor::")} const {{\n'
                  f"    return at::_ops::{f.func.name.unambiguous_name()}::call({call_args});\n"
                  "}\n"
              )

          definitions = [emit_method(faithful=False)]
          if group.faithful_signature is not None:
              definitions.append(emit_method(faithful=True))
          return "".join(definitions)
  # Generates RedispatchFunctions.h.
  # The wrappers mirror the C++ API defined in Functions.h, but go through the
  # dispatcher's redispatch API: each one passes `dispatchKeySet` as the first
  # argument to at::_ops::<name>::redispatch.
  @dataclass(frozen=True)
  class ComputeRedispatchFunction:
      @method_with_native_function
      def __call__(self, f: NativeFunction) -> Optional[str]:
          # We unconditionally generate function variants of the redispatch API.
          # This is mainly because we can namespace functions separately, but not methods.
          group = CppSignatureGroup.from_native_function(
              f, method=False, fallback_binding=f.manual_cpp_binding
          )

          def emit_wrapper(faithful: bool) -> str:
              if faithful:
                  cpp_sig = group.faithful_signature
                  assert cpp_sig is not None
              else:
                  cpp_sig = group.signature

              dispatcher_sig = DispatcherSignature.from_schema(f.func)
              bindings = translate(cpp_sig.arguments(), dispatcher_sig.arguments())
              # The key set always leads the forwarded argument list.
              call_args = ", ".join(["dispatchKeySet"] + [b.expr for b in bindings])
              return (
                  f"\n// aten::{f.func}\n"
                  f"TORCH_API inline {cpp_sig.decl(is_redispatching_fn=True)} {{\n"
                  f"    return at::_ops::{f.func.name.unambiguous_name()}::redispatch({call_args});\n"
                  "}\n"
              )

          wrappers = [emit_wrapper(False)]
          if group.faithful_signature is not None:
              wrappers.append(emit_wrapper(True))
          return "".join(wrappers)
  # Generates ATenOpList.cpp, a runtime accessible list of all aten
  # operators. Each entry is a C++ brace initializer holding the qualified
  # operator name and its overload name, followed by a trailing comma.
  # TODO: This was historically used to help some JIT interop code
  # figure out whether or not to treat aten namespace'd operators
  # one way or another, we should reevaluate if this is actually needed.
  @with_native_function
  def compute_aten_op(f: NativeFunction) -> str:
      op_name = f.func.name
      return f'{{"aten::{op_name.name}", "{op_name.overload_name}"}},'
  # Generates MetaFunctions.h
  # For a structured group this emits the `structured_<name>` struct whose
  # `meta()` member takes the group's meta arguments. When the out variant
  # declares precomputed elements, a nested `precompute_out` template is
  # emitted too and `meta()` returns an instance of it instead of `void`.
  def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
      if not g.structured:
          return None

      with native_function_manager(g.out):
          name = meta.name(g)
          args_str = ", ".join(a.decl() for a in structured.meta_arguments(g))

          parent_class = g.out.structured_inherits
          if parent_class is None:
              parent_class = "at::impl::MetaBase"

          # g.structured is known to hold past the guard at the top.
          precomputed = g.out.precomputed
          if not precomputed:
              meta_return = "void"
              meta_return_typedef = ""
              precomputed_decl = ""
          else:
              # All precomputed elements, flattened: every replacement list
              # first, then the added elements.
              elements = [
                  element
                  for group in [*precomputed.replace.values(), precomputed.add]
                  for element in group
              ]

              # One bool template parameter per element (same position); it is
              # true once the corresponding element has been set.
              flags = [element.name.upper() for element in elements]
              template_params = ", ".join(f"bool {flag} = false" for flag in flags)
              template_decl = f"template <{template_params}>"

              # Member declarations for every precomputed element.
              typed_elements = [
                  structured.argument_type(element, binds=element.name)
                  for element in elements
              ]
              members_decl = ";\n".join(
                  f"{binding.cpp_type(strip_ref=True)} {binding.name}"
                  for binding in typed_elements
              )

              # One setter per element. A setter returns a new precompute_out
              # whose template parameter for that element is flipped to true.
              setters = []
              for idx, (element, binding) in enumerate(zip(elements, typed_elements)):
                  return_params = ", ".join(
                      flags[:idx] + ["true"] + flags[idx + 1 :]
                  )
                  return_ty = f"precompute_out<{return_params}>"
                  value_ty = binding.cpp_type(strip_ref=True)
                  signature = f"{return_ty} set_{element.name}({value_ty} value)"

                  # The static_assert requires the flag to still be false on
                  # `this`, so each element can be set only once.
                  assert_stmt = (
                      f"static_assert({flags[idx]} == false, "
                      f'"{element.name} already set");'
                  )

                  # Copy every member from `this`, except the one being set,
                  # which takes the method parameter.
                  statements = [f"{return_ty} ret;"]
                  statements.extend(
                      f"ret.{other.name} = value;"
                      if pos == idx
                      else f"ret.{other.name} = this->{other.name};"
                      for pos, other in enumerate(elements)
                  )
                  statements.append("return ret;")
                  construction_block = "\n".join(statements)

                  setters.append(
                      "\n"
                      f"                    {signature} {{\n"
                      f"                        {assert_stmt}\n"
                      f"                        {construction_block}\n"
                      "                    }\n"
                      "                "
                  )
              setters_decl = "\n".join(setters)

              # Meta should return an instance of the struct containing the
              # precomputed elements, with every flag set.
              all_set = ", ".join(["true"] * len(flags))
              # This typedef (actually a using statement) is needed so that
              # TORCH_META_FUNC can reuse the return type (which has a variable
              # number of template parameters).
              meta_return_typedef = (
                  f"using meta_return_ty = precompute_out <{all_set}>;"
              )
              meta_return = "meta_return_ty"

              precomputed_decl = (
                  "\n"
                  f"                {template_decl}\n"
                  "                struct TORCH_API precompute_out {\n"
                  "\n"
                  f"                    {setters_decl}\n"
                  f"                    {members_decl};\n"
                  "            };"
              )

          return (
              f"struct TORCH_API structured_{name} : public {parent_class} {{\n"
              f"    {precomputed_decl}\n"
              f"    {meta_return_typedef}\n"
              f"    {meta_return} meta({args_str});\n"
              "};\n"
          )
  # Decides whether a BackendSelect kernel is generated for `f`: the operator
  # must not be a `*_like` / `new_*` function, must take tensor options, and
  # must be selected by `selector`.
  def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
      op_name = str(f.func.name.name)
      excluded_by_name = op_name.endswith("_like") or op_name.startswith("new_")
      if excluded_by_name or f.func.arguments.tensor_options is None:
          return False
      return selector.is_native_function_selected(f)
  # Generates RegisterBackendSelect.cpp, a series of kernels which provide
  # specialized computation of dispatch key for operator signatures which cannot
  # be easily done automatically using templating.
  @dataclass(frozen=True)
  class ComputeBackendSelect:
      target: Union[Literal[Target.DEFINITION], Literal[Target.REGISTRATION]]

      # Selector object to determine which operators to generate
      # registration code for.
      selector: SelectiveBuilder

      @method_with_native_function
      def __call__(self, f: NativeFunction) -> Optional[str]:
          if not needs_backend_select(f, self.selector):
              return None

          name = native.name(f.func)
          native_sig = NativeSignature(f.func)
          # Tensor-like arguments also contribute to the computed key set.
          native_tensor_args = [
              binding
              for binding in native_sig.arguments()
              if isinstance(binding.argument, Argument)
              and binding.argument.type.is_tensor_like()
          ]
          dispatcher_sig = DispatcherSignature.from_schema(f.func)
          dispatcher_exprs = dispatcher_sig.exprs()
          dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

          if self.target is Target.REGISTRATION:
              return f'm.impl("aten::{f.func.name}", TORCH_FN({name}));'
          if self.target is not Target.DEFINITION:
              assert_never(self.target)

          # I don't think there's actually a good reason to generate
          # these two cases differently
          # The first case could probably be improved though- it calls computeDispatchKeySet(),
          # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
          if native_tensor_args:
              tensor_args = ", ".join(a.name for a in native_tensor_args)
              compute_dk = (
                  f"DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});\n"
                  "DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);\n"
                  "DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"
              )
          else:
              compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"

          redispatch_args = ", ".join(a.expr for a in dispatcher_exprs)
          return (
              f"// aten::{f.func}\n"
              "C10_ALWAYS_INLINE\n"
              f"{dispatcher_sig.defn(name)} {{\n"
              f"  {compute_dk}\n"
              f"  return at::_ops::{f.func.name.unambiguous_name()}::redispatch(\n"
              f"      _dk, {redispatch_args});\n"
              "}\n"
          )
  806. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  807. #
  808. # YAML CODE GENERATION
  809. #
  810. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  # Serializes `data` to a YAML string using YamlDumper, with aliases
  # disabled, OrderedDict support, and block (non-flow) style.
  def format_yaml(data: object) -> str:
      # Ignore alias in Dumper
      YamlDumper.ignore_aliases = lambda _dumper, _node: True  # type: ignore[assignment]

      # Support serializing OrderedDict
      def represent_ordered_dict(dumper: Any, mapping: Any) -> Any:
          return dumper.represent_dict(mapping.items())

      YamlDumper.add_representer(OrderedDict, represent_ordered_dict)  # type: ignore[no-untyped-call]

      # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
      # width=1e9 turns off optional line breaks and improves
      # the portability of the outputted yaml.
      dumped = yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
      return dumped
  # For some reason, some defaults we write to YAML are written as native
  # YAML objects, rather than doing them uniformly as strings. This
  # function detects those cases and converts them into native Python
  # objects: "true"/"false" become bools, anything int() or float() accepts
  # becomes that number (int is tried first), and every other string is
  # returned unchanged.
  def pythonify_default(s: str) -> object:
      bool_literals = {"true": True, "false": False}
      if s in bool_literals:
          return bool_literals[s]

      # Try the narrower numeric type first so "3" stays an int.
      for convert in (int, float):
          try:
              return convert(s)
          except ValueError:
              continue
      return s
  # What is a dynamic type? Over time, the semantic meaning of
  # dynamic type has degraded to meaninglessness (in the old days,
  # it captured dtype-ness of types, but that has gone away with
  # the removal of TH). These days, it's mostly the same thing as
  # the C++ API argument type, except that Tensor and Tensor?
  # arguments simply present as Tensor.
  #
  # TODO: Get rid of dynamic_type, after getting tools/autograd
  # to use the new codegen framework
  def dynamic_type(t: Type) -> str:
      # Peel off every layer of optionality before classifying the type.
      inner = t
      while isinstance(inner, OptionalType):
          inner = inner.elem

      # Note we don't use is_tensor_like() here because it would
      # also include Tensor[]
      if str(inner) == "Tensor":
          return "at::Tensor"

      cpp_arg_type = cpp.argumenttype_type(
          inner, mutable=False, binds="__placeholder__"
      )
      return cpp_arg_type.cpp_type()
  # Builds the `method_of` list for a YAML entry: always "Type", then
  # "Tensor" for a method variant, then "namespace" for a function variant.
  def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
      # The pairs are listed explicitly so that Tensor and namespace
      # always land in the list in this order.
      ordered_labels = (
          (Variant.method, "Tensor"),
          (Variant.function, "namespace"),
      )
      return ["Type"] + [
          label for variant, label in ordered_labels if variant in variants
      ]
  # Computes the `returns` list of a YAML entry, plus a mapping from out
  # argument name to the field name of the return at the same position.
  def compute_returns_yaml(
      f: NativeFunction,
  ) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
      # Note [name and field_name]
      # ~~~~~~~~~~~~~~~~~~~~~~~~~~
      # To understand name_to_field_name, we must first talk about this
      # schema:
      #
      #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
      #
      # There is something very odd about this schema: it is an out
      # variant of the function (that is to say, it will convert into
      # at::lstsq_out() in the C++ API), but the names of the output
      # return arguments don't match the keyword argument names of
      # the inputs.  It TURNS OUT that in this situation, the historical
      # Declarations.yaml we want to output is this (abbreviated to
      # only show relevant fields):
      #
      #   arguments:
      #     ...
      #   - field_name: solution
      #     name: X
      #   - field_name: QR
      #     name: qr
      #     ...
      #
      #   returns:
      #   - field_name: solution
      #     name: X
      #   - field_name: QR
      #     name: qr
      #
      # The name of the return fields is stored in 'field_name', and the
      # name of the arguments is stored in 'name'.  So when we process
      # arguments, we need a way to get at the corresponding return.  At
      # the moment, this is most conveniently done by constructing a
      # mapping from name (the argument concept) to field_name (the
      # return concept) while processing return arguments, since we don't
      # directly maintain this correspondence in the modeling of function
      # schema itself.
      #
      # See also https://github.com/pytorch/pytorch/issues/43114
      name_to_field_name: Dict[str, str] = {}
      returns = []

      # Walk the schema returns in lockstep with their C++ return names.
      cpp_names = cpp.return_names(f)
      for position, (schema_ret, cpp_name) in enumerate(zip(f.func.returns, cpp_names)):
          entry = {
              "dynamic_type": dynamic_type(schema_ret.type),
              "name": cpp_name,
              "type": cpp.return_type(schema_ret).cpp_type(),
          }
          if schema_ret.name:
              # See Note [name and field_name]
              entry["field_name"] = schema_ret.name
              if f.func.is_out_fn():
                  out_arg = f.func.arguments.out[position]
                  name_to_field_name[out_arg.name] = schema_ret.name
          returns.append(entry)

      return returns, name_to_field_name
  # arguments in yaml roughly corresponds to the public C++ API
  # A TensorOptions binding gets a fixed-shape entry; a plain Argument is
  # delegated to compute_argument_yaml; a SelfArgument is rejected.
  def compute_cpp_argument_yaml(
      cpp_a: Binding,
      *,
      schema_order: bool,
      kwarg_only_set: Set[str],
      out_arg_set: Set[str],
      name_to_field_name: Dict[str, str],
  ) -> object:
      bound = cpp_a.argument

      if isinstance(bound, TensorOptionsArguments):
          entry: Dict[str, object] = {
              "annotation": None,
              "dynamic_type": "at::TensorOptions",
              "is_nullable": False,
              "name": cpp_a.name,
              "type": cpp_a.type,
              "kwarg_only": True,
          }
          if cpp_a.default is not None:
              entry["default"] = cpp_a.default
          return entry

      if isinstance(bound, SelfArgument):
          raise AssertionError()

      if isinstance(bound, Argument):
          return compute_argument_yaml(
              bound,
              schema_order=schema_order,
              kwarg_only_set=kwarg_only_set,
              out_arg_set=out_arg_set,
              name_to_field_name=name_to_field_name,
          )

      # Any other kind of binding yields no entry.
      return None
  # Builds the YAML dictionary describing one schema argument. Optional keys
  # (default, kwarg_only, output/allocate, field_name, size) are added only
  # when they apply. `schema_order` is accepted but not consulted here.
  def compute_argument_yaml(
      a: Argument,
      *,
      schema_order: bool,
      kwarg_only_set: Set[str],
      out_arg_set: Set[str],
      name_to_field_name: Dict[str, str],
  ) -> object:
      entry: Dict[str, object] = {
          "annotation": str(a.annotation) if a.annotation else None,
          "dynamic_type": dynamic_type(a.type),
          "is_nullable": a.type.is_nullable(),
          "name": a.name,
          "type": cpp.argument_type(a, binds="__placeholder__").cpp_type(),
      }

      if a.default is not None:
          default_expr = cpp.default_expr(a.default, a.type)
          entry["default"] = pythonify_default(default_expr)

      if a.name in kwarg_only_set:
          entry["kwarg_only"] = True

      if a.name in out_arg_set:
          entry["output"] = True
          entry["allocate"] = True
          # See Note [name and field_name]
          if a.name in name_to_field_name:
              entry["field_name"] = name_to_field_name[a.name]

      # Historically, booleans don't get their size recorded, because it
      # is already built into the cpp type (e.g., std::array<bool, 4>)
      list_type = a.type.is_list_like()
      has_recorded_size = (
          list_type is not None
          and list_type.size is not None
          and str(list_type.elem) != "bool"
      )
      if has_recorded_size:
          entry["size"] = list_type.size

      return entry
  # Assembles the full Declarations.yaml-style entry for one native function
  # as an OrderedDict, so the fields keep the order they are listed in below.
  @with_native_function
  def compute_declaration_yaml(f: NativeFunction) -> object:
      returns, name_to_field_name = compute_returns_yaml(f)

      # These sets are used to conveniently test if an argument is a
      # kwarg-only or out argument
      kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
      out_arg_set = {a.name for a in f.func.arguments.out}

      sig_group = CppSignatureGroup.from_native_function(
          f, method=False, fallback_binding=False
      )
      cpp_args = sig_group.signature.arguments()

      # Options shared by both argument listings below.
      shared_kwargs = dict(
          kwarg_only_set=kwarg_only_set,
          out_arg_set=out_arg_set,
          name_to_field_name=name_to_field_name,
      )

      arguments = [
          compute_cpp_argument_yaml(cpp_a, schema_order=False, **shared_kwargs)
          for cpp_a in cpp_args
      ]

      schema_order_jit_arguments = list(f.func.schema_order_arguments())
      schema_order_arguments = [
          compute_argument_yaml(a, schema_order=True, **shared_kwargs)
          for a in schema_order_jit_arguments
      ]

      # C++ types of the arguments, in schema order.
      cpp_schema_order_types = []
      for a in schema_order_jit_arguments:
          # NB: method here doesn't matter
          bindings = cpp.argument(
              a,
              method=False,
              cpp_no_default_args=set(),
              faithful=False,
              has_tensor_options=False,
          )
          cpp_schema_order_types.extend(binding.type for binding in bindings)

      cpp_returns = cpp.returns_type(f.func.returns).cpp_type()
      schema_order_cpp_signature = (
          f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
      )

      takes_tensor_options = any(
          isinstance(a.argument, TensorOptionsArguments) for a in cpp_args
      )
      is_factory_method = takes_tensor_options and Variant.method not in f.variants

      category_override = (
          f.category_override if f.category_override is not None else ""
      )
      python_module = "" if f.python_module is None else f.python_module

      return OrderedDict(
          [
              ("name", cpp.name(f.func)),
              ("operator_name", str(f.func.name.name)),
              ("overload_name", str(f.func.name.overload_name)),
              ("manual_kernel_registration", f.manual_kernel_registration),
              ("category_override", category_override),
              ("schema_string", f"aten::{f.func}"),
              ("arguments", arguments),
              ("schema_order_cpp_signature", schema_order_cpp_signature),
              ("schema_order_arguments", schema_order_arguments),
              ("method_of", compute_method_of_yaml(f.variants)),
              ("mode", "native"),
              ("python_module", python_module),
              ("returns", returns),
              ("inplace", f.func.name.name.inplace),
              ("is_factory_method", is_factory_method),
              ("abstract", f.is_abstract),
              ("device_guard", f.device_guard),
              ("with_gil", False),
              ("deprecated", False),
              ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
          ]
      )
  # See Note [Auto generated composite kernels]
  # True when `f` is structured (or delegates to a structured function) and
  # its schema kind is functional or inplace.
  def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
      is_structured = f.structured or f.structured_delegate is not None
      return is_structured and f.func.kind() in (
          SchemaKind.functional,
          SchemaKind.inplace,
      )
  # Emits one line of registration declarations: the dispatcher-level C++
  # prototype of `f`, followed by a `//` comment carrying a JSON object with
  # the schema string and the "dispatch" / "default" flags.
  @with_native_function_and_indices
  def compute_registration_declarations(
      f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
  ) -> str:
      name = dispatcher.name(f.func)
      returns_type = dispatcher.returns_type(
          f.func.returns
      ).cpp_type_registration_declarations()
      args_str = ", ".join(
          a.no_default().decl_registration_declarations()
          for a in dispatcher.arguments(f.func)
      )

      # TODO: What exactly is the semantics of the 'dispatch' field?
      keys_with_kernel = {
          key for key, index in backend_indices.items() if index.has_kernel(f)
      }
      has_dispatch = keys_with_kernel != {DispatchKey.CompositeImplicitAutograd}
      has_default = f.has_composite_kernel or has_autogenerated_composite_kernel(f)

      comment_data: Dict[str, str] = {
          "schema": f"aten::{f.func}",
          "dispatch": str(has_dispatch),
          "default": str(has_default),
      }
      return f"{returns_type} {name}({args_str}); // {json.dumps(comment_data)}\n"
  1088. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1089. #
  1090. # RUN IT ALL
  1091. #
  1092. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  # Builds the SelectiveBuilder for this run from at most one of two sources:
  # a legacy operator allowlist or an operator-selection YAML file. With
  # neither given, the no-op selector is returned.
  def get_custom_build_selector(
      provided_op_registration_allowlist: Optional[List[str]],
      op_selection_yaml_path: Optional[str],
  ) -> SelectiveBuilder:
      # The two sources are mutually exclusive.
      assert (
          provided_op_registration_allowlist is None or op_selection_yaml_path is None
      ), (
          "Both provided_op_registration_allowlist and "
          "op_selection_yaml_path can NOT be provided at the "
          "same time."
      )

      if provided_op_registration_allowlist is not None:
          return SelectiveBuilder.from_legacy_op_registration_allow_list(
              set(provided_op_registration_allowlist),
              True,
              False,
          )
      if op_selection_yaml_path is not None:
          return SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
      return SelectiveBuilder.get_nop_selector()
  # Buckets native functions by their view signature and, for every bucket
  # containing an aliasing view op, bundles the view / view_inplace /
  # view_copy ops into a NativeFunctionsViewGroup. Functions that do not end
  # up in a view group are returned individually.
  def get_grouped_by_view_native_functions(
      native_functions: Sequence[NativeFunction],
  ) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
      def maybe_create_view_group(
          d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
      ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
          emitted: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
          if ViewSchemaKind.aliasing in d:
              # Popping removes the group's members from `d`, so they are not
              # emitted a second time below.
              view = d.pop(ViewSchemaKind.aliasing)
              view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
              view_copy = d.pop(SchemaKind.functional, None)
              emitted.append(
                  NativeFunctionsViewGroup(
                      view=view,
                      view_copy=view_copy,
                      view_inplace=view_inplace,
                  )
              )
          # Take the remaining functions that weren't part of the view group
          # and emit them separately
          emitted.extend(d.values())
          return emitted

      grouped_by_views: Dict[
          FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
      ] = defaultdict(dict)
      for f in native_functions:
          bucket = grouped_by_views[f.func.view_signature()]
          view_kind: ViewSchemaKind = f.view_schema_kind
          # We need to group up ops relevant to the same "view", consisting of:
          # view op (ViewSchemaKind.aliasing)
          # view_inplace op (ViewSchemaKind.aliasing_inplace)
          # view_copy op (SchemaKind.functional)
          # Non-aliasing ops are keyed by their schema kind, the rest by view kind.
          key: Union[SchemaKind, ViewSchemaKind]
          if view_kind == ViewSchemaKind.non_aliasing:
              key = f.func.kind()
          else:
              key = view_kind
          assert key not in bucket
          bucket[key] = f

      return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
  # Pre-groups the native functions and turns each pre-group into a single
  # NativeFunctionsGroup where possible; pre-groups that cannot form a group
  # are returned as their individual functions.
  def get_grouped_native_functions(
      native_functions: Sequence[NativeFunction],
  ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
      def flatten_pre_group(
          d: Dict[SchemaKind, NativeFunction]
      ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
          group = NativeFunctionsGroup.from_dict(d)
          if group is not None:
              return [group]
          # Invariant: any NativeFunctions that are code-generated
          # should have been grouped into NativeFunctionsGroup objects
          assert not any("generated" in fn.tags for fn in d.values())
          return list(d.values())

      # TODO: how come ValuesView isn't a Sequence lol
      pre_groups = list(pre_group_native_functions(native_functions).values())
      return list(concatMap(flatten_pre_group, pre_groups))
  # Writes the aggregated (single-file) operator headers:
  # NativeMetaFunctions.h, MethodOperators.h, Operators.h, Functions.h and
  # NativeFunctions.h through `cpu_fm`, then, for every dispatch key listed in
  # `functions_keys`, the <Key>Functions.h / <Key>Functions_inl.h pair through
  # `cuda_fm` or `cpu_fm` depending on is_cuda_dispatch_key().
  def gen_aggregated_headers(
      *,
      native_functions: Sequence[NativeFunction],
      grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
      structured_native_functions: Sequence[NativeFunctionsGroup],
      static_dispatch_idx: List[BackendIndex],
      selector: SelectiveBuilder,
      backend_indices: Dict[DispatchKey, BackendIndex],
      cpu_fm: FileManager,
      cuda_fm: FileManager,
      functions_keys: Set[DispatchKey],
      dispatch_keys: Sequence[DispatchKey],
      rocm: bool,
  ) -> None:
      # Buck doesn't support dynamic output files, so we aggregate all operator
      # headers into a single file
      cpu_fm.write(
          "NativeMetaFunctions.h",
          lambda: {
              "NativeMetaFunctions_includes": [],
              "NativeMetaFunctions_declarations": list(
                  mapMaybe(compute_meta_function_declaration, structured_native_functions)
              ),
          },
      )

      # Split the operators by whether they have a method variant. Both lists
      # test the variant directly; checking membership of each function in the
      # other list would cost a linear scan per function.
      method_native_functions = [
          fn for fn in native_functions if Variant.method in fn.variants
      ]
      non_method_native_functions = [
          fn for fn in native_functions if Variant.method not in fn.variants
      ]

      cpu_fm.write(
          "MethodOperators.h",
          lambda: {
              "MethodOperators_includes": [],
              "MethodOperators_declarations": list(
                  mapMaybe(
                      ComputeOperators(
                          Target.DECLARATION,
                          static_dispatch_backend_indices=static_dispatch_idx,
                      ),
                      method_native_functions,
                  )
              ),
          },
      )
      cpu_fm.write(
          "Operators.h",
          lambda: {
              "Operators_includes": ["#include <ATen/MethodOperators.h>"],
              "Operators_declarations": list(
                  mapMaybe(
                      ComputeOperators(
                          Target.DECLARATION,
                          static_dispatch_backend_indices=static_dispatch_idx,
                      ),
                      non_method_native_functions,
                  )
              ),
          },
      )
      cpu_fm.write(
          "Functions.h",
          lambda: {
              "static_dispatch_extra_headers": static_dispatch_extra_headers(
                  static_dispatch_idx
              ),
              "Functions_includes": ["#include <ATen/Operators.h>"],
              "Functions_declarations": list(
                  mapMaybe(
                      ComputeFunction(),
                      native_functions,
                  )
              ),
          },
      )
      cpu_fm.write(
          "NativeFunctions.h",
          lambda: {
              "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
              "NativeFunctions_declarations": list(
                  concatMap(
                      # De-duplicate the declarations while keeping their first-seen
                      # order (OrderedDict.fromkeys). Backends are allowed to repeat
                      # kernel names; only generate the declaration once!
                      lambda f: list(
                          OrderedDict.fromkeys(
                              concatMap(
                                  lambda backend_idx: dest.compute_native_function_declaration(
                                      f, backend_idx
                                  ),
                                  backend_indices.values(),
                              )
                          )
                      ),
                      grouped_native_functions,
                  )
              ),
          },
      )

      for dispatch_key in dispatch_keys:
          # CUDA dispatch keys are written through the CUDA file manager.
          fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
          if dispatch_key in functions_keys:
              inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

              fm.write_with_template(
                  f"{dispatch_key}Functions.h",
                  "DispatchKeyFunctions.h",
                  lambda: {
                      "dispatch_key": str(dispatch_key),
                      "inline_headers": inl_headers,
                  },
              )
              fm.write_with_template(
                  f"{dispatch_key}Functions_inl.h",
                  "DispatchKeyFunctions_inl.h",
                  lambda: {
                      "DispatchKeyFunctions_inl_includes": [],
                      "dispatch_namespace": dispatch_key.lower(),
                      "dispatch_namespaced_declarations": list(
                          concatMap(
                              dest.RegisterDispatchKey(
                                  backend_indices[dispatch_key],
                                  Target.NAMESPACED_DECLARATION,
                                  selector,
                                  rocm=rocm,
                                  cpp_namespace="at::native",
                                  class_method_name=None,
                                  skip_dispatcher_op_registration=False,
                              ),
                              grouped_native_functions,
                          )
                      ),
                  },
              )

          # Drop the per-iteration binding so it cannot be reused by accident.
          del fm
  1313. def gen_per_operator_headers(
  1314. *,
  1315. native_functions: Sequence[NativeFunction],
  1316. grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
  1317. static_dispatch_idx: List[BackendIndex],
  1318. selector: SelectiveBuilder,
  1319. backend_indices: Dict[DispatchKey, BackendIndex],
  1320. cpu_fm: FileManager,
  1321. cuda_fm: FileManager,
  1322. ops_fm: FileManager,
  1323. functions_keys: Set[DispatchKey],
  1324. dispatch_keys: Sequence[DispatchKey],
  1325. rocm: bool,
  1326. ) -> None:
  1327. # For CMake builds, split operator declarations into separate headers in
  1328. # the ATen/ops folder to split up header dependencies
  1329. functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(lambda: [])
  1330. for fn in native_functions:
  1331. functions_by_root_name[fn.root_name].append(fn)
  1332. grouped_functions_by_root_name: Dict[
  1333. str, List[Union[NativeFunction, NativeFunctionsGroup]]
  1334. ] = defaultdict(lambda: [])
  1335. for group in grouped_native_functions:
  1336. name = group.root_name
  1337. grouped_functions_by_root_name[name].append(group)
  1338. for name, functions in functions_by_root_name.items():
  1339. ops_fm.write_with_template(
  1340. f"{name}_ops.h",
  1341. "Operator.h",
  1342. lambda: {
  1343. "declarations": list(
  1344. mapMaybe(
  1345. ComputeOperators(
  1346. Target.DECLARATION,
  1347. static_dispatch_backend_indices=static_dispatch_idx,
  1348. ),
  1349. functions,
  1350. )
  1351. ),
  1352. },
  1353. )
  1354. ops_fm.write_with_template(
  1355. f"{name}.h",
  1356. "Function.h",
  1357. lambda: {
  1358. "static_dispatch_ops_headers": list(
  1359. mapMaybe(
  1360. lambda fn: static_dispatch_ops_header(
  1361. fn, backend_index=static_dispatch_idx
  1362. ),
  1363. functions,
  1364. )
  1365. ),
  1366. "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
  1367. "function_definitions": list(
  1368. mapMaybe(
  1369. ComputeFunction(),
  1370. functions,
  1371. )
  1372. ),
  1373. },
  1374. )
  1375. grouped_functions = grouped_functions_by_root_name.get(name, [])
  1376. structured_functions = [
  1377. fn
  1378. for fn in grouped_functions
  1379. if isinstance(fn, NativeFunctionsGroup) and fn.structured
  1380. ]
  1381. is_structured = len(structured_functions) > 0
  1382. if is_structured:
  1383. ops_fm.write_with_template(
  1384. f"{name}_meta.h",
  1385. "NativeMetaFunction.h",
  1386. lambda: {
  1387. "meta_function_declarations": list(
  1388. mapMaybe(
  1389. compute_meta_function_declaration, structured_functions
  1390. )
  1391. ),
  1392. },
  1393. )
  1394. ops_fm.write_with_template(
  1395. f"{name}_native.h",
  1396. "NativeFunction.h",
  1397. lambda: {
  1398. "extra_includes": (
  1399. f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
  1400. ),
  1401. "native_function_declarations": list(
  1402. concatMap(
  1403. # Convert to a set first to remove duplicate kernel names.
  1404. # Backends are allowed to repeat kernel names; only generate the declaration once!
  1405. lambda f: list(
  1406. OrderedDict.fromkeys(
  1407. concatMap(
  1408. lambda backend_idx: dest.compute_native_function_declaration(
  1409. f, backend_idx
  1410. ),
  1411. backend_indices.values(),
  1412. )
  1413. )
  1414. ),
  1415. grouped_functions,
  1416. )
  1417. ),
  1418. },
  1419. )
  1420. for category, suffix in [
  1421. ("Functions", ""),
  1422. ("Operators", "_ops"),
  1423. ("NativeMetaFunctions", "_meta"),
  1424. ("NativeFunctions", "_native"),
  1425. ]:
  1426. cpu_fm.write(
  1427. f"{category}.h",
  1428. lambda: {
  1429. f"{category}_includes": [
  1430. f"#include <ATen/ops/{name}{suffix}.h>"
  1431. for name in sorted(functions_by_root_name.keys())
  1432. ],
  1433. f"{category}_declarations": [],
  1434. },
  1435. )
  1436. for dispatch_key in dispatch_keys:
  1437. if dispatch_key not in functions_keys:
  1438. continue
  1439. dispatch_namespace = dispatch_key.lower()
  1440. dispatch_names = []
  1441. for name, functions in functions_by_root_name.items():
  1442. grouped_functions = grouped_functions_by_root_name.get(name, [])
  1443. declarations = list(
  1444. concatMap(
  1445. dest.RegisterDispatchKey(
  1446. backend_indices[dispatch_key],
  1447. Target.NAMESPACED_DECLARATION,
  1448. selector,
  1449. rocm=rocm,
  1450. cpp_namespace="at::native",
  1451. class_method_name=None,
  1452. skip_dispatcher_op_registration=False,
  1453. ),
  1454. grouped_functions,
  1455. )
  1456. )
  1457. if len(declarations) == 0:
  1458. continue
  1459. dispatch_names.append(name)
  1460. ops_fm.write_with_template(
  1461. f"{name}_{dispatch_namespace}_dispatch.h",
  1462. "DispatchKeyFunction.h",
  1463. lambda: {
  1464. "dispatch_namespace": dispatch_namespace,
  1465. "dispatch_namespaced_declarations": declarations,
  1466. },
  1467. )
  1468. fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm
  1469. inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
  1470. fm.write_with_template(
  1471. f"{dispatch_key}Functions.h",
  1472. "DispatchKeyFunctions.h",
  1473. lambda: {
  1474. "dispatch_key": str(dispatch_key),
  1475. "inline_headers": inl_headers,
  1476. },
  1477. )
  1478. fm.write_with_template(
  1479. f"{dispatch_key}Functions_inl.h",
  1480. "DispatchKeyFunctions_inl.h",
  1481. lambda: {
  1482. "dispatch_namespace": dispatch_namespace,
  1483. "DispatchKeyFunctions_inl_includes": [
  1484. f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
  1485. for name in sorted(dispatch_names)
  1486. ],
  1487. "dispatch_namespaced_declarations": [],
  1488. },
  1489. )
  1490. del fm
  1491. cpu_fm.write(
  1492. "MethodOperators.h",
  1493. lambda: {
  1494. "MethodOperators_includes": sorted(
  1495. f"#include <ATen/ops/{name}_ops.h>"
  1496. for name, functions in functions_by_root_name.items()
  1497. if any(Variant.method in fn.variants for fn in functions)
  1498. ),
  1499. "MethodOperators_declarations": [],
  1500. },
  1501. )
  1502. def gen_headers(
  1503. *,
  1504. native_functions: Sequence[NativeFunction],
  1505. grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
  1506. structured_native_functions: Sequence[NativeFunctionsGroup],
  1507. static_dispatch_idx: List[BackendIndex],
  1508. selector: SelectiveBuilder,
  1509. backend_indices: Dict[DispatchKey, BackendIndex],
  1510. core_fm: FileManager,
  1511. cpu_fm: FileManager,
  1512. cuda_fm: FileManager,
  1513. ops_fm: FileManager,
  1514. dispatch_keys: Sequence[DispatchKey],
  1515. functions_keys: Set[DispatchKey],
  1516. rocm: bool,
  1517. per_operator_headers: bool,
  1518. ) -> None:
  1519. if per_operator_headers:
  1520. gen_per_operator_headers(
  1521. native_functions=native_functions,
  1522. grouped_native_functions=grouped_native_functions,
  1523. static_dispatch_idx=static_dispatch_idx,
  1524. selector=selector,
  1525. backend_indices=backend_indices,
  1526. cpu_fm=cpu_fm,
  1527. cuda_fm=cuda_fm,
  1528. ops_fm=ops_fm,
  1529. dispatch_keys=dispatch_keys,
  1530. functions_keys=functions_keys,
  1531. rocm=rocm,
  1532. )
  1533. else:
  1534. gen_aggregated_headers(
  1535. native_functions=native_functions,
  1536. grouped_native_functions=grouped_native_functions,
  1537. structured_native_functions=structured_native_functions,
  1538. static_dispatch_idx=static_dispatch_idx,
  1539. selector=selector,
  1540. backend_indices=backend_indices,
  1541. cpu_fm=cpu_fm,
  1542. cuda_fm=cuda_fm,
  1543. dispatch_keys=dispatch_keys,
  1544. functions_keys=functions_keys,
  1545. rocm=rocm,
  1546. )
  1547. core_fm.write(
  1548. "TensorBody.h",
  1549. lambda: {
  1550. "tensor_method_declarations": list(
  1551. mapMaybe(
  1552. ComputeTensorMethod(
  1553. target=Target.DECLARATION,
  1554. static_dispatch_backend_indices=static_dispatch_idx,
  1555. ),
  1556. native_functions,
  1557. )
  1558. ),
  1559. "tensor_method_definitions": list(
  1560. mapMaybe(
  1561. ComputeTensorMethod(
  1562. target=Target.DEFINITION,
  1563. static_dispatch_backend_indices=static_dispatch_idx,
  1564. ),
  1565. native_functions,
  1566. )
  1567. ),
  1568. },
  1569. )
  1570. cpu_fm.write(
  1571. "RedispatchFunctions.h",
  1572. lambda: {
  1573. "function_redispatch_definitions": list(
  1574. mapMaybe(ComputeRedispatchFunction(), native_functions)
  1575. ),
  1576. },
  1577. )
  1578. cpu_fm.write(
  1579. "RegistrationDeclarations.h",
  1580. lambda: {
  1581. "registration_declarations": [
  1582. compute_registration_declarations(f, backend_indices)
  1583. for f in native_functions
  1584. ],
  1585. },
  1586. )
  1587. def gen_aten_interned_strings() -> Dict[str, str]:
  1588. attrs = set() # All function argument names
  1589. names = set() # All ATen function names
  1590. for func in native_functions:
  1591. names.add(str(func.func.name.name))
  1592. # Some operators don't have a functional variant but we still create a
  1593. # symbol without the underscore
  1594. names.add(func.func.name.name.base)
  1595. for arg in func.func.schema_order_arguments():
  1596. attrs.add(arg.name)
  1597. # These are keywords in C++, so aren't valid symbol names
  1598. # https://en.cppreference.com/w/cpp/language/operator_alternative
  1599. names -= set(
  1600. [
  1601. "and",
  1602. "and_eq",
  1603. "bitand",
  1604. "bitor",
  1605. "compl",
  1606. "not",
  1607. "not_eq",
  1608. "or",
  1609. "or_eq",
  1610. "xor",
  1611. "xor_eq",
  1612. ]
  1613. )
  1614. return {
  1615. "aten_symbols": " \\\n".join(
  1616. [f"_(aten, {name})" for name in sorted(names)]
  1617. ),
  1618. "attr_symbols": " \\\n".join(
  1619. [f"_(attr, {name})" for name in sorted(attrs)]
  1620. ),
  1621. }
  1622. core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: List[BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cpu_vec_fm: FileManager,
    cuda_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
) -> None:
    """Write all generated source (.cpp / .cu) files.

    Per dispatch key this writes ``Register{dispatch_key}.cpp`` and, for
    structured groups with a ufunc inner loop on a ufunc dispatch key, the
    ``UfuncCPU_*.cpp`` / ``UfuncCPUKernel_*.cpp`` / ``UfuncCUDA_*.cu`` files.
    It then writes ``RegisterBackendSelect.cpp``, ``RegisterSchema.cpp``,
    sharded ``Operators.cpp``, ``Functions.cpp``, ``TensorMethods.cpp``,
    ``ATenOpList.cpp``, sharded ``RegisterFunctionalization.cpp``,
    ``FunctionalInverses.h`` and ``CompositeViewCopyKernels.cpp``.

    Args:
        native_functions: All native functions.
        grouped_native_functions: Functions/groups fed to
            ``dest.RegisterDispatchKey``.
        structured_native_functions: Groups used for ufunc files,
            functionalization and composite functional kernels.
        view_groups: View groups used for functionalization and view-copy
            kernels.
        selector: Selective-build filter forwarded to the generators.
        static_dispatch_idx: Forwarded to ``ComputeOperators`` and
            ``static_dispatch_extra_headers``.
        backend_indices: Backend index per dispatch key.
        core_fm: File manager for ``TensorMethods.cpp`` and ``ATenOpList.cpp``.
        cpu_fm: File manager for non-CUDA dispatch keys and most other files.
        cpu_vec_fm: File manager for ``UfuncCPUKernel_*.cpp``.
        cuda_fm: File manager used when ``is_cuda_dispatch_key`` is true.
        dispatch_keys: Dispatch keys to generate registrations for.
        functions_keys: Dispatch keys whose ``*Functions`` / ``*_dispatch``
            headers are included from the registration file.
        rocm: Selects HIP headers instead of CUDA headers; also forwarded to
            the generators.
        force_schema_registration: When true, schemas are registered with a
            no-op selector instead of ``selector``.
        per_operator_headers: Selects per-operator includes vs. aggregate
            includes in the registration files.
        skip_dispatcher_op_registration: When true, registration bodies and
            schema registrations are emitted empty.
    """
    # NOTE(review): whitespace inside the multi-line string literals in this
    # function could not be recovered from the flattened listing -- confirm
    # them against the canonical file before applying.
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""

    for dispatch_key in dispatch_keys:
        fm = cuda_fm if is_cuda_dispatch_key(dispatch_key) else cpu_fm

        # Both variants of operator_headers() are closures that read
        # `backend_index` / `dispatch_namespace`, which are assigned further
        # down in this same iteration.  They are only called from the lambda
        # handed to fm.write_with_template() below, after those assignments.
        if per_operator_headers:

            def operator_headers() -> List[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                    # The above has_kernel test on a group will only test for
                    # the existence of out dispatch, because that's how
                    # structured kernels work. But sometimes functions can be
                    # grouped but not be structured, and then you need to check
                    # each individual piece, as they may have manual dispatch
                    # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                    # TODO: this condition is a bit questionable
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutograd,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue

                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if dispatch_key == DispatchKey.CompositeExplicitAutograd:
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )

                # De-duplicate and emit in a stable order.
                return sorted(set(headers))

        else:

            def operator_headers() -> List[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutograd:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers

        backend_index = backend_indices[dispatch_key]
        # Empty registration body when dispatcher registration is skipped.
        dispatch_registrations_body = (
            ""
            if skip_dispatcher_op_registration
            else "\n".join(
                list(
                    concatMap(
                        dest.RegisterDispatchKey(
                            backend_index,
                            Target.REGISTRATION,
                            selector,
                            rocm=rocm,
                            cpp_namespace="at::native",
                            class_method_name=None,
                            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
                        ),
                        grouped_native_functions,
                    )
                )
            )
        )
        static_template = CodeTemplate(
            """\
TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
$dispatch_registrations_body
};"""
        )
        static_init_dispatch_registrations = static_template.substitute(
            dispatch_key=dispatch_key,
            dispatch_registrations_body=dispatch_registrations_body,
        )
        dispatch_namespace = str(dispatch_key).lower()
        fm.write_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            lambda: {
                "extra_cuda_headers": extra_cuda_headers
                if is_cuda_dispatch_key(dispatch_key)
                else "",
                "external_backend_headers": "",
                "dispatch_headers": dest.gen_registration_headers(
                    backend_index, per_operator_headers, rocm
                ),
                "ops_headers": operator_headers(),
                "DispatchKey": dispatch_key,
                # NOTE(review): uses dispatch_key.lower() while the local
                # `dispatch_namespace` is str(dispatch_key).lower() --
                # presumably the same value; confirm.
                "dispatch_namespace": dispatch_key.lower(),
                "dispatch_helpers": dest.gen_registration_helpers(backend_index),
                "dispatch_namespaced_definitions": list(
                    concatMap(
                        dest.RegisterDispatchKey(
                            backend_index,
                            Target.NAMESPACED_DEFINITION,
                            selector,
                            rocm=rocm,
                            cpp_namespace="at::native",
                            class_method_name=None,
                            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
                        ),
                        grouped_native_functions,
                    )
                ),
                "dispatch_anonymous_definitions": list(
                    concatMap(
                        dest.RegisterDispatchKey(
                            backend_index,
                            Target.ANONYMOUS_DEFINITION,
                            selector,
                            rocm=rocm,
                            cpp_namespace="at::native",
                            class_method_name=None,
                            skip_dispatcher_op_registration=skip_dispatcher_op_registration,
                        ),
                        grouped_native_functions,
                    )
                ),
                "static_init_dispatch_registrations": static_init_dispatch_registrations,
                "deferred_dispatch_registrations": "",
            },
        )

        # Ufunc files: only for groups whose out variant has a ufunc inner
        # loop, and only on ufunc dispatch keys.
        for g in structured_native_functions:
            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
                continue
            name = g.functional.func.name.name
            if dispatch_key is DispatchKey.CPU:
                assert fm is cpu_fm
                fm.write_with_template(
                    f"UfuncCPU_{name}.cpp",
                    "UfuncCPU.cpp",
                    lambda: {
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cpu(g),
                    },
                )
                # The CPU kernel file goes through the separate cpu_vec_fm.
                cpu_vec_fm.write_with_template(
                    f"UfuncCPUKernel_{name}.cpp",
                    "UfuncCPUKernel.cpp",
                    lambda: {
                        "name": name,
                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
                    },
                )
            elif dispatch_key is DispatchKey.CUDA:
                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
                if rocm:
                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
                fm.write_with_template(
                    f"UfuncCUDA_{name}.cu",
                    "UfuncCUDA.cu",
                    lambda: {
                        "name": name,
                        "cuda_headers": cuda_headers,
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cuda(g),
                    },
                )
            else:
                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")

        del fm

    # BackendSelect is generated specially
    def gen_backend_select() -> Dict[str, List[str]]:
        # Only functions that needs_backend_select() accepts are emitted.
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }

    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)

    # With force_schema_registration the selective-build filter is replaced
    # by a no-op selector for schema registration only.
    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()

    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else list(mapMaybe(RegisterSchema(schema_selector), native_functions)),
        },
    )

    # Shard key shared by both write_sharded() calls below.
    def key_func(
        fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )

    # These two files are written with an empty environment.
    cpu_fm.write("Functions.cpp", lambda: {})

    core_fm.write("TensorMethods.cpp", lambda: {})

    core_fm.write(
        "ATenOpList.cpp",
        lambda: {
            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
        },
    )

    def functionalization_env_callable(
        g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> Dict[str, List[str]]:
        # Per-item environment for the RegisterFunctionalization.cpp shards.
        def gen_op_headers(
            g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
        ) -> List[str]:
            # Returns the `_native.h` / `_ops.h` include pair for every
            # function that is present in `g`.
            if isinstance(g, NativeFunctionsViewGroup):
                # view ops always get a functionalization kernel
                headers = [
                    f"#include <ATen/ops/{g.view.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view.root_name}_ops.h>",
                ]
                if g.view_copy is not None:
                    headers += [
                        f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                        f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                    ]
                return headers
            elif isinstance(g, NativeFunctionsGroup):
                headers = [
                    f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                    f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                    f"#include <ATen/ops/{g.out.root_name}_native.h>",
                    f"#include <ATen/ops/{g.out.root_name}_ops.h>",
                ]
                if g.inplace is not None:
                    headers += [
                        f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                        f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                    ]
                if g.mutable is not None:
                    headers += [
                        f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                        f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                    ]
                return headers
            else:
                return [
                    f"#include <ATen/ops/{g.root_name}_native.h>",
                    f"#include <ATen/ops/{g.root_name}_ops.h>",
                ]

        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
                selector,
                g,
            ),
            "func_registrations": gen_functionalization_registration(
                selector,
                g,
                backend_indices[DispatchKey.CompositeImplicitAutograd],
            ),
        }

    all_groups: List[
        Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    # Append every function that belongs to neither a structured group nor a
    # view group, so each operator appears in all_groups.
    for f in native_functions:
        if f.func.name not in structured_map and f.func.name not in view_map:
            all_groups.append(f)

    # NOTE(review): sharded_keys names "func_add_back_views_definitions" and
    # "func_add_back_views_registrations", which
    # functionalization_env_callable never returns -- confirm they are still
    # needed.
    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )

    cpu_fm.write(
        "FunctionalInverses.h",
        lambda: {
            "view_inverse_declarations": list(
                mapMaybe(
                    lambda g: gen_functionalization_view_inverse_declaration(
                        selector, g
                    ),
                    view_groups,
                )
            )
        },
    )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
    # Backends that use functionalization and don't know how to handle aliasing ops
    # are expected to implement kernels for these {view}_copy kernels instead.
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
    # so we codegen the following:
    # (1) A CompositeExplicitAutograd kernel for every {view}_copy operator.
    #     These are never explicitly invoked by the functionalization pass,
    #     but they could theoretically be called from user code (I added these kernels for completeness,
    #     since the ops are part of the public API).
    # (2) A derivative formula for every {view}_copy operator
    #     {view}_copy operators can re-use the same derivative formulas as their {view} op counterparts,
    #     so rather than stamping all of the entries out in derivatives.yaml,
    #     we codegen them in.
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            # One newline-joined include string per view group, then one per
            # structured group (inplace/mutable variants that exist and are
            # not tagged "generated").
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>"
                    for f in [g.inplace, g.mutable]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(gen_composite_view_copy_kernel, view_groups)
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
        },
    )
  2037. def gen_declarations_yaml(
  2038. cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
  2039. ) -> None:
  2040. cpu_fm.write(
  2041. "Declarations.yaml",
  2042. lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
  2043. )
  2044. def get_torchgen_root() -> pathlib.Path:
  2045. """
  2046. If you're depending on torchgen out-of-tree, you can use the root to figure
  2047. out the path to native_functions.yaml
  2048. """
  2049. return pathlib.Path(__file__).parent.resolve()
  2050. def main() -> None:
  2051. parser = argparse.ArgumentParser(description="Generate ATen source files")
  2052. parser.add_argument(
  2053. "-s",
  2054. "--source-path",
  2055. help="path to source directory for ATen",
  2056. default="aten/src/ATen",
  2057. )
  2058. parser.add_argument(
  2059. "-o",
  2060. "--output-dependencies",
  2061. help="output a list of dependencies into the given file and exit",
  2062. )
  2063. parser.add_argument(
  2064. "--dry-run",
  2065. action="store_true",
  2066. help="run without writing any files (still updates outputs)",
  2067. )
  2068. parser.add_argument(
  2069. "--per-operator-headers",
  2070. action="store_true",
  2071. help="generate separate headers per operator in ATen/ops",
  2072. )
  2073. parser.add_argument(
  2074. "-d", "--install_dir", help="output directory", default="build/aten/src/ATen"
  2075. )
  2076. parser.add_argument(
  2077. "--rocm",
  2078. action="store_true",
  2079. help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
  2080. )
  2081. parser.add_argument(
  2082. "--mps",
  2083. action="store_true",
  2084. help="Generate MPS registration code when set",
  2085. )
  2086. # TODO: --op_registration_whitelist will be removed when all call-sites
  2087. # for gen.py are moved over to using the operator YAML file for mobile
  2088. # custom build.
  2089. parser.add_argument(
  2090. "--op_registration_whitelist",
  2091. nargs="*",
  2092. help="filter op registrations by the whitelist (if set); "
  2093. "each item is `namespace`::`operator name` without overload name; "
  2094. "e.g.: aten::empty aten::conv2d ...",
  2095. )
  2096. parser.add_argument(
  2097. "--op_selection_yaml_path",
  2098. help="Provide a path to the operator selection (for custom build) YAML "
  2099. "that contains the information about the set of selected operators "
  2100. "and their categories (training, ...). Each operator is either a "
  2101. "full operator name with overload or just a bare operator name. "
  2102. "The operator names also contain the namespace prefix (e.g. aten::)",
  2103. )
  2104. parser.add_argument(
  2105. "--backend_whitelist",
  2106. nargs="*",
  2107. help="filter dispatch backend by the whitelist (if set), "
  2108. "e.g.: CPU CUDA QuantizedCPU ...",
  2109. )
  2110. parser.add_argument(
  2111. "--static_dispatch_backend",
  2112. nargs="*",
  2113. help="generate static dispatch code for the specific backend (if set)",
  2114. )
  2115. parser.add_argument(
  2116. "--skip_dispatcher_op_registration",
  2117. action="store_true",
  2118. help="Avoid registering operators into the dispatcher.",
  2119. )
  2120. parser.add_argument(
  2121. "--force_schema_registration",
  2122. action="store_true",
  2123. help="force it to generate schema-only registrations for all ops, including"
  2124. "those that are not listed on --op_registration_whitelist",
  2125. )
  2126. parser.add_argument(
  2127. "--generate",
  2128. type=str,
  2129. nargs="*",
  2130. choices=["headers", "sources", "declarations_yaml"],
  2131. default=["headers", "sources", "declarations_yaml"],
  2132. help="Generate only a subset of files",
  2133. )
  2134. options = parser.parse_args()
  2135. selector = get_custom_build_selector(
  2136. options.op_registration_whitelist,
  2137. options.op_selection_yaml_path,
  2138. )
  2139. native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
  2140. tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
  2141. from torchgen.model import dispatch_keys
  2142. # TODO: stop generating CUDA kernels for non-CUDA builds
  2143. ignore_keys = set()
  2144. if not options.mps:
  2145. ignore_keys.add(DispatchKey.MPS)
  2146. if DispatchKey.MPS in dispatch_keys:
  2147. del dispatch_keys[dispatch_keys.index(DispatchKey.MPS)]
  2148. parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
  2149. native_functions, backend_indices = (
  2150. parsed_yaml.native_functions,
  2151. parsed_yaml.backend_indices,
  2152. )
  2153. grouped_native_functions = get_grouped_native_functions(native_functions)
  2154. structured_native_functions = [
  2155. g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
  2156. ]
  2157. native_functions_with_view_groups = get_grouped_by_view_native_functions(
  2158. native_functions
  2159. )
  2160. view_groups = [
  2161. g
  2162. for g in native_functions_with_view_groups
  2163. if isinstance(g, NativeFunctionsViewGroup)
  2164. ]
  2165. template_dir = os.path.join(options.source_path, "templates")
  2166. # NB: It is mandatory to NOT use os.path.join here, as the install directory
  2167. # will eventually be ingested by cmake, which does not respect Windows style
  2168. # path slashes. If you switch this to use os.path.join, you'll get an error
  2169. # like:
  2170. #
  2171. # Syntax error in cmake code when parsing string
  2172. #
  2173. # C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
  2174. #
  2175. # Invalid character escape '\c'.
  2176. core_install_dir = f"{options.install_dir}/core"
  2177. pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
  2178. ops_install_dir = f"{options.install_dir}/ops"
  2179. pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
  2180. core_fm = make_file_manager(options=options, install_dir=core_install_dir)
  2181. cpu_fm = make_file_manager(options=options)
  2182. cpu_vec_fm = make_file_manager(options=options)
  2183. cuda_fm = make_file_manager(options=options)
  2184. ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
  2185. extra_cuda_headers = """\
  2186. #include <c10/cuda/CUDAGuard.h>
  2187. #include <ATen/cuda/ATenCUDAGeneral.h>
  2188. #include <ATen/cuda/CUDADevice.h>
  2189. #include <ATen/cuda/CUDAContext.h>"""
  2190. if options.rocm:
  2191. extra_cuda_headers = """\
  2192. #include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
  2193. #include <ATen/hip/ATenHIPGeneral.h>
  2194. #include <ATen/hip/HIPDevice.h>
  2195. #include <ATen/hip/HIPContext.h>"""
  2196. # Only a limited set of dispatch keys get CPUFunctions.h headers generated
  2197. # for them; this is the set
  2198. functions_keys = {
  2199. DispatchKey.CPU,
  2200. DispatchKey.CUDA,
  2201. DispatchKey.CompositeImplicitAutograd,
  2202. DispatchKey.CompositeExplicitAutograd,
  2203. DispatchKey.Meta,
  2204. }
  2205. if options.mps:
  2206. functions_keys.add(DispatchKey.MPS)
  2207. if options.backend_whitelist:
  2208. dispatch_keys = [
  2209. k
  2210. for k in dispatch_keys
  2211. if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
  2212. ]
  2213. static_dispatch_idx: List[BackendIndex] = []
  2214. if options.static_dispatch_backend:
  2215. static_dispatch_idx = [
  2216. backend_indices[DispatchKey.parse(key)]
  2217. for key in options.static_dispatch_backend
  2218. ]
  2219. for key in options.static_dispatch_backend:
  2220. dp_key = DispatchKey.parse(key)
  2221. if dp_key not in functions_keys:
  2222. functions_keys.add(dp_key)
  2223. if "sources" in options.generate:
  2224. gen_source_files(
  2225. native_functions=native_functions,
  2226. grouped_native_functions=grouped_native_functions,
  2227. structured_native_functions=structured_native_functions,
  2228. view_groups=view_groups,
  2229. selector=selector,
  2230. static_dispatch_idx=static_dispatch_idx,
  2231. backend_indices=backend_indices,
  2232. core_fm=core_fm,
  2233. cpu_fm=cpu_fm,
  2234. cpu_vec_fm=cpu_vec_fm,
  2235. cuda_fm=cuda_fm,
  2236. dispatch_keys=dispatch_keys,
  2237. functions_keys=functions_keys,
  2238. rocm=options.rocm,
  2239. force_schema_registration=options.force_schema_registration,
  2240. per_operator_headers=options.per_operator_headers,
  2241. skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
  2242. )
  2243. if "headers" in options.generate:
  2244. gen_headers(
  2245. native_functions=native_functions,
  2246. grouped_native_functions=grouped_native_functions,
  2247. structured_native_functions=structured_native_functions,
  2248. static_dispatch_idx=static_dispatch_idx,
  2249. selector=selector,
  2250. backend_indices=backend_indices,
  2251. core_fm=core_fm,
  2252. cpu_fm=cpu_fm,
  2253. cuda_fm=cuda_fm,
  2254. ops_fm=ops_fm,
  2255. dispatch_keys=dispatch_keys,
  2256. functions_keys=functions_keys,
  2257. rocm=options.rocm,
  2258. per_operator_headers=options.per_operator_headers,
  2259. )
  2260. if "declarations_yaml" in options.generate:
  2261. gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
  2262. if options.output_dependencies:
  2263. depfile_path = pathlib.Path(options.output_dependencies).resolve()
  2264. depfile_name = depfile_path.name
  2265. depfile_stem = depfile_path.stem
  2266. for fm, prefix in [
  2267. (cpu_fm, ""),
  2268. (cpu_vec_fm, "cpu_vec_"),
  2269. (core_fm, "core_"),
  2270. (cuda_fm, "cuda_"),
  2271. (ops_fm, "ops_"),
  2272. ]:
  2273. varname = prefix + depfile_stem
  2274. path = depfile_path.parent / (prefix + depfile_name)
  2275. fm.write_outputs(varname, str(path))
  2276. if __name__ == "__main__":
  2277. main()