python.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399
  1. from dataclasses import dataclass
  2. from typing import Optional, Union, Sequence, Set, List, Dict, Tuple
  3. from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
  4. from torchgen.api import cpp
  5. from torchgen.gen import pythonify_default
  6. from torchgen.model import (
  7. Argument,
  8. BaseTy,
  9. BaseType,
  10. ListType,
  11. NativeFunction,
  12. OptionalType,
  13. Return,
  14. Type,
  15. Variant,
  16. )
  17. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  18. #
  19. # Data Models
  20. #
  21. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  22. #
  23. # [Notes] python binding codegen
  24. #
  25. # The Python binding codegen produces code that takes the input list of
  26. # PyObjects, finds the matching ATen C++ function using PythonArgParser,
  27. # converts the PyObjects into C++ types and calls the ATen C++ function:
  28. #
  29. # +--------+ parsing +------------------------+ binding +-----------------------+
  30. # | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
  31. # +--------+ +------------------------+ +-----------------------+
  32. #
  33. # The following examples demonstrate the data models the Python binding
  34. # codegen needs to deal with and the tasks it needs to accomplish. It
  35. # helps understand the purpose of the new data types we introduced below.
  36. #
  37. # - Function Schema (source of truth)
  38. #
  39. # aten::empty.names(int[] size, *, Dimname[]? names,
  40. # ScalarType? dtype=None, Layout? layout=None,
  41. # Device? device=None, bool? pin_memory=None,
  42. # MemoryFormat? memory_format=None) -> Tensor
  43. #
  44. # - Python Signature
  45. #
  46. # It's used to generate input schema string for PythonArgParser.
  47. # Note: TensorOptions fields are reordered and the additional
  48. # 'requires_grad' field is added:
  49. #
  50. # empty(IntArrayRef size, *, DimnameList? names,
  51. # MemoryFormat? memory_format=None, ScalarType dtype=None,
  52. # Layout layout=torch.strided, Device device=None,
  53. # bool pin_memory=False, bool requires_grad=False)
  54. #
  55. # - C++ Signature
  56. #
  57. # It's used to generate C++ lambda formals & dispatch call.
  58. # Note: the scattered TensorOptions fields are packed into 'options'.
  59. #
  60. # auto dispatch_empty =
  61. # [](IntArrayRef size, c10::optional<DimnameList> names,
  62. # const TensorOptions & options,
  63. # c10::optional<MemoryFormat> memory_format) -> Tensor {
  64. # pybind11::gil_scoped_release no_gil;
  65. # return torch::empty(size, names, options, memory_format);
  66. # };
  67. #
  68. # - Binding between Python Arguments and C++ Arguments
  69. #
  70. # Given a set of Python Arguments in scope, we need produce the
  71. # binding expressions that translate the Python API into C++ API:
  72. #
  78. #   Python Args                  Cpp Args                      Binding Exprs
  79. #   -----------------------------------------------------------------
  80. #   0: size                      size                          '_r.intlist(0)'
  81. #   1: names                     names                         'names' [special init]
  82. #   2: memory_format -------+
  83. #   3: dtype         -----+-|--> options                       'options' [special packing]
  84. #   4: layout            /  |
  85. #   5: device           /   +--> memory_format                 '_r.memoryformatOptional(2)'
  86. #   6: pin_memory      /
  87. #   7: requires_grad -+
  83. #
  84. # So the full dispatch expression would look like:
  85. #
  86. # dispatch_empty(_r.intlist(0), names, options,
  87. # _r.memoryformatOptional(2))
  88. #
  89. # Where does 'names' come from? It involves special local init:
  90. #
  91. # auto __names = _r.toDimnameListOptional(1);
  92. # c10::optional<DimnameList> names =
  93. # __names ? c10::make_optional(DimnameList(__names.value()))
  94. # : c10::nullopt;
  95. #
  96. # Where does 'options' come from? It involves special local init
  97. # for TensorOptions. Note that Python side has the additional
  98. # 'requires_grad' field:
  99. #
  100. # const auto options = TensorOptions()
  101. # .dtype(_r.scalartype(3))
  102. # .device(_r.device(5))
  103. # .layout(_r.layoutOptional(4))
  104. # .requires_grad(_r.toBool(7))
  105. # .pinned_memory(_r.toBool(6));
  106. #
  107. # In some other cases one Python Argument can map to multiple C++
  108. # Arguments. For example:
  109. #
  110. # aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
  111. # -> (Tensor values, Tensor indices)
  112. #
  113. # Python Args Cpp Args Binding Exprs
  114. # ---------------------------------------------------------------------
  115. # +----> max 'out[0]'
  116. # /-----> max_values 'out[1]
  117. # 0: input / self '_r.tensor(0)'
  118. # 1: dim / dim '_r.dimname(1)'
  119. # 2: keepdim / keepdim '_r.toBool(2)'
  120. # 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)'
  121. #
  122. # As demonstrated above, the binding can involve reordering,
  123. # packing, unpacking and special local inits.
  124. #
  125. #
  126. # Let's look at a concrete example:
  127. #
  128. # static PythonArgParser parser({
  129. # "abs(Tensor input, *, Tensor out=None)",
  130. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  131. # ^
  132. # +--- Python Schema, represented by PythonSignature and PythonArgument
  133. #
  134. # }, /*traceable=*/true);
  135. #
  136. # ParsedArgs<2> parsed_args;
  137. # auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  138. #
  139. # ...
  140. #
  141. # if (_r.isNone(1)) {
  142. # ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
  143. # represented by PythonArgParserOutputExpr
  144. #
  145. # // aten::abs(Tensor self) -> Tensor
  146. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  147. # ^
  148. # +--- NativeFunction schema, base version
  149. #
  150. # auto dispatch_abs = [](const Tensor & self) -> Tensor {
  151. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  152. # ^
  153. # +--- dispatch_lambda_args / dispatch_lambda_return_str
  154. # generated from NativeFunction / CppSignature
  155. # (deprecated PythonSignature is special)
  156. # arguments are represented by DispatchLambdaArgument
  157. #
  158. # pybind11::gil_scoped_release no_gil;
  159. # return self.abs();
  160. # ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
  161. # generated from NativeFunction / CppSignature
  162. # };
  163. # return wrap(dispatch_abs(_r.tensor(0)));
  164. # ~~~~~~~~~~~~~
  165. # ^
  166. # +--- dispatch_lambda_exprs
  167. # binding PythonArgParserOutputExpr (python args)
  168. # and DispatchLambdaArgument (c++ args)
  169. #
  170. # } else {
  171. # // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  172. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  173. # ^
  174. # +--- NativeFunction schema, out-variant
  175. #
  176. # auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
  177. # pybind11::gil_scoped_release no_gil;
  178. # return at::abs_out(out, self);
  179. # };
  180. # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
  181. # }
  182. #
  183. #
  184. # [Notes] python interface codegen
  185. # The python dataclasses below are used used to generate both python binding code
  186. # and pyi type hint signatures.
  187. # In theory these two should look very similar, but there are number of differences
  188. # in how pyi signatures vs. python_arg_parser signatures are generated.
  189. # These differences have been encapsulated in signature_str() vs. signature_str_pyi()
  190. # to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
  191. # For examples, only pyi signatures include return types.
  192. @dataclass(frozen=True)
  193. class PythonReturns:
  194. returns: Tuple[Return, ...]
  195. @dataclass(frozen=True)
  196. class PythonArgument:
  197. name: str
  198. type: Type
  199. default: Optional[str]
  200. # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
  201. #
  202. # _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
  203. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  204. # ^
  205. # +--- default_init str
  206. default_init: Optional[str]
  207. # Compute argument formal for python argument parsing.
  208. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
  209. def argument_str(self, *, method: bool = False) -> str:
  210. type_str = argument_type_str(self.type).replace("const ", "").replace(" &", "")
  211. name = self.name
  212. # s/self/input/ outside method bindings
  213. # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
  214. # for the parse string
  215. if name == "self" and type_str == "Tensor" and not method:
  216. name = "input"
  217. # add default
  218. if self.default is not None:
  219. default = {
  220. "nullptr": "None",
  221. "c10::nullopt": "None",
  222. "{}": "None",
  223. }.get(self.default, self.default)
  224. return f"{type_str} {name}={default}"
  225. else:
  226. return f"{type_str} {name}"
  227. def argument_str_pyi(
  228. self, *, method: bool = False, deprecated: bool = False
  229. ) -> str:
  230. type_str = argument_type_str_pyi(self.type)
  231. name = self.name
  232. # s/self/input/ outside method bindings
  233. # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
  234. # for the parse string
  235. if name == "self" and type_str == "Tensor" and not method and not deprecated:
  236. name = "input"
  237. if name == "from": # from is a Python keyword...
  238. name += "_"
  239. # pyi merges the _out and functional variants into the same signature, with an optional out arg
  240. if name == "out" and type_str == "Tensor" and not deprecated:
  241. type_str = "Optional[" + type_str + "]"
  242. # pyi deprecated signatures don't get defaults for their out arg
  243. treat_as_no_default = (
  244. deprecated
  245. and isinstance(self, PythonOutArgument)
  246. and self.default == "None"
  247. )
  248. # add default
  249. if self.default is not None and not treat_as_no_default:
  250. if (
  251. isinstance(self.type, ListType)
  252. and self.type.elem == BaseType(BaseTy.int)
  253. and self.default.startswith("{")
  254. and self.default.endswith("}")
  255. ):
  256. default = "(" + self.default[1:-1] + ")"
  257. else:
  258. default = {
  259. "nullptr": "None",
  260. "c10::nullopt": "None",
  261. "{}": "None",
  262. "MemoryFormat::Contiguous": "contiguous_format",
  263. "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
  264. }.get(self.default, self.default)
  265. return f"{name}: {type_str}={default}"
  266. else:
  267. return f"{name}: {type_str}"
  268. @dataclass(frozen=True)
  269. class PythonOutArgument(PythonArgument):
  270. # In Python signature multiple output fields are packed into one 'out' argument.
  271. # When binding to C++, it's first binded to a local 'out' variable:
  272. # 'auto out = _r.tensorlist_n<2>(2);',
  273. # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
  274. # TODO: maybe don't need keep scattered out fields for python signature?
  275. outputs: Tuple[PythonArgument, ...]
  276. @staticmethod
  277. def from_outputs(
  278. outputs: Tuple[PythonArgument, ...]
  279. ) -> Optional["PythonOutArgument"]:
  280. if not outputs:
  281. return None
  282. size = len(outputs)
  283. if size == 1:
  284. return PythonOutArgument(
  285. name=outputs[0].name,
  286. type=outputs[0].type,
  287. default="None",
  288. default_init=None,
  289. outputs=outputs,
  290. )
  291. elif size > 1:
  292. if any(map(lambda a: not a.type.is_tensor_like(), outputs)):
  293. raise RuntimeError(f"Unsupported output type: {outputs}")
  294. return PythonOutArgument(
  295. name="out",
  296. # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
  297. type=ListType(BaseType(BaseTy.Tensor), size),
  298. default="None",
  299. default_init=None,
  300. outputs=outputs,
  301. )
  302. raise AssertionError(r"Unexpected PythonOutArgument size")
  303. @dataclass(frozen=True)
  304. class PythonSignature:
  305. # Base operator name, without inplace/outplace suffix.
  306. name: str
  307. # Positional arguments.
  308. # TODO: create a dedicated SelfArgument type for 'self'?
  309. input_args: Tuple[PythonArgument, ...]
  310. # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
  311. # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
  312. input_kwargs: Tuple[PythonArgument, ...]
  313. output_args: Optional[PythonOutArgument]
  314. # Return types, which are only used by pyi
  315. returns: PythonReturns
  316. # These are scattered kwargs arguments belonging to TensorOptions.
  317. # When binding to C++, they are packed into a TensorOptions object 'options'.
  318. # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
  319. # for out variant), in which case they will be used as scattered fields without
  320. # being packed into 'options'.
  321. # TODO: maybe create a PythonTensorOptionsArgument?
  322. tensor_options_args: Tuple[PythonArgument, ...]
  323. # method or function signature?
  324. method: bool
  325. @property
  326. def deprecated(self) -> bool:
  327. return False
  328. def arguments(
  329. self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
  330. ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
  331. result: List[Union[PythonArgument, PythonOutArgument]] = []
  332. result.extend(self.input_args)
  333. result.extend(self.input_kwargs)
  334. if self.output_args is not None and not skip_outputs:
  335. result.append(self.output_args)
  336. if not skip_tensor_options:
  337. result.extend(self.tensor_options_args)
  338. return tuple(result)
  339. def arguments_count(self) -> int:
  340. return len(self.arguments())
  341. def output_idx(self) -> int:
  342. return len(self.input_args) + len(self.input_kwargs)
  343. # [old codegen] Compute the Python function signature for argument parsing,
  344. # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
  345. # this is NOT the same type signature as specified by PEP 484
  346. # as understood by mypy; our format was independently developed
  347. # and has some quirks to make it more suitable specifically
  348. # for error parsing.
  349. #
  350. # For a translation to mypy-valid type signatures, see
  351. # signature_str_pyi().
  352. def signature_str(self, *, skip_outputs: bool = False) -> str:
  353. args = self.arguments(skip_outputs=skip_outputs)
  354. schema_formals: List[str] = list(
  355. map(lambda a: a.argument_str(method=self.method), args)
  356. )
  357. positional_argc = len(self.input_args)
  358. if len(schema_formals) > positional_argc:
  359. schema_formals.insert(positional_argc, "*")
  360. return f'{self.name}({", ".join(schema_formals)})'
  361. def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
  362. args = self.arguments(skip_outputs=skip_outputs)
  363. schema_formals: List[str] = list(
  364. map(lambda a: a.argument_str_pyi(method=self.method), args)
  365. )
  366. positional_argc = len(self.input_args)
  367. if len(schema_formals) > positional_argc:
  368. schema_formals.insert(positional_argc, "*")
  369. # only pyi signatures include returns
  370. returns_str = returns_str_pyi(self)
  371. # pyi also includes self (with no typing/defaults) for methods
  372. if self.method:
  373. schema_formals.insert(0, "self")
  374. return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
  375. def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
  376. # only pyi uses vararg signatures
  377. args = self.arguments(skip_outputs=skip_outputs)
  378. schema_formals: List[str] = list(
  379. map(lambda a: a.argument_str_pyi(method=self.method), args)
  380. )
  381. # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
  382. num_args = self.arguments_count()
  383. num_positionalargs = len(self.input_args)
  384. have_vararg_version = False
  385. if num_args > 0:
  386. vararg_type = args[0].type
  387. if (
  388. isinstance(vararg_type, ListType)
  389. and str(vararg_type.elem) == "int"
  390. and num_positionalargs == 1
  391. ):
  392. have_vararg_version = True
  393. if not have_vararg_version:
  394. return None
  395. # Below are the major changes in vararg vs. regular pyi signatures
  396. # vararg signatures also omit the asterix
  397. schema_formals[0] = "*" + args[0].name + ": _int"
  398. returns_str = returns_str_pyi(self)
  399. # pyi also includes self (with no typing/defaults) for methods
  400. if self.method:
  401. schema_formals.insert(0, "self")
  402. return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
  403. # The deprecated python signature involves some special logic, so create a
  404. # dedicated data model to store these extra properties.
  405. @dataclass(frozen=True)
  406. class PythonSignatureDeprecated(PythonSignature):
  407. # We need keep the order of arguments in deprecated signature.
  408. # Particularly, method signature might have 'self' not at the beginning, e.g.:
  409. # addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
  410. # When generating lambda function signature we need follow the exact order (even for method=True):
  411. # [](Scalar beta, const Tensor & self, const Tensor & mat1, const Tensor & mat2) -> Tensor
  412. deprecated_args_names: Tuple[str, ...]
  413. # The deprecated signature might miss some arguments that the corresponding
  414. # C++ signature expects. We need store the constant default values to pass in.
  415. # For example:
  416. # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
  417. # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  418. # [func call]: self.addmm(mat1, mat2, beta, 1)
  419. # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
  420. deprecated_args_exprs: Tuple[str, ...]
  421. @property
  422. def deprecated(self) -> bool:
  423. return True
  424. def signature_str(self, *, skip_outputs: bool = False) -> str:
  425. return (
  426. PythonSignature.signature_str(self, skip_outputs=skip_outputs)
  427. + "|deprecated"
  428. )
  429. def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
  430. args = self.arguments(skip_outputs=skip_outputs)
  431. schema_formals: List[str] = list(
  432. map(lambda a: a.argument_str_pyi(method=self.method, deprecated=True), args)
  433. )
  434. positional_argc = len(self.input_args)
  435. if len(schema_formals) > positional_argc:
  436. schema_formals.insert(positional_argc, "*")
  437. returns_str = returns_str_pyi(self)
  438. return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
  439. def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
  440. # the codegen doesn't include vararg variants for deprecated signatures
  441. return None
  442. # This struct is used to hold the PythonSignature and its corresponding
  443. # NativeFunction BEFORE grouping base and out-variant functions.
  444. # Why not store NativeFunction in PythonSignature or construct PythonSignature
  445. # from NativeFunction? Because they are not 1-1 mapped.
  446. # One native function could have both deprecated and non-deprecated python
  447. # signatures - NativeFunction doesn't contain information to construct the
  448. # deprecated python signature.
  449. # One python signature is used to handle both the base and the out-variant
  450. # function - see 'PythonSignatureGroup'.
  451. @dataclass(frozen=True)
  452. class PythonSignatureNativeFunctionPair:
  453. signature: PythonSignature
  454. function: NativeFunction
  455. # We merge pairs of functions with signatures that are equivalent mod
  456. # output arguments, and use a single entry in the python_arg_parser sig
  457. # list for both (output arguments become optional).
  458. @dataclass(frozen=True)
  459. class PythonSignatureGroup:
  460. # The signature used for Python argument parsing. The outplace signature
  461. # is preferred if exists, because it can be used to parse inputs for both
  462. # the out-place variant and the base version (with output omitted).
  463. signature: PythonSignature
  464. # The regular ATen declaration (e.g. conv2d)
  465. base: NativeFunction
  466. # The out variant (e.g. conv2d_out)
  467. outplace: Optional[NativeFunction]
  468. # C++ function dispatch is wrapped in a lambda function. The lambda function
  469. # has almost the same signature as the C++ function, only with some small
  470. # variants - see details below.
  471. # This data model is used to represent arguments of the lambda function
  472. # signature.
  473. @dataclass(frozen=True)
  474. class DispatchLambdaArgument:
  475. name: str
  476. type_str: str
  477. is_out_arg: bool
  478. # To pass PyObjects arguments to C++ function (via the lambda wrapper),
  479. # we need first convert PyObjects into simple C++ objects. This work
  480. # is done by PythonArgParser.
  481. # This data model is used to represent the output of PythonArgParser.
  482. # It has 1-1 mapping with PythonArgument in PythonSignature.
  483. @dataclass(frozen=True)
  484. class PythonArgParserOutputExpr:
  485. # argument name
  486. name: str
  487. # RHS expression to reference PythonArgParser output.
  488. expr: str
  489. # In some special cases we need create different expr, e.g.:
  490. # '_r.isNone(1)' instead of '_r.tensor(1)'.
  491. index: int
  492. # The python argument it maps to.
  493. argument: PythonArgument
  494. @property
  495. def is_none_expr(self) -> str:
  496. return f"_r.isNone({self.index})"
  497. # To pass PythonArgParser output to the lambda wrapper, we need bind
  498. # PythonArgParserOutputExpr to DispatchLambdaArgument.
  499. # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
  500. # need be packed into a TensorOptions object, which is the argument
  501. # that the lambda function wrapper takes.
  502. @dataclass(frozen=True)
  503. class DispatchLambdaArgumentExprs:
  504. # The exprs that provide the binding for lambda arguments, e.g.:
  505. #
  506. # 'self' -> '_r.tensor(0)'
  507. # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
  508. # 'options' -> 'options'
  509. #
  510. # It has 1-1 mapping with DispatchLambdaArgument.
  511. exprs: Sequence[str]
  512. # Special local inits, which might introduce new variables that
  513. # the 'exprs' above reference, e.g.:
  514. #
  515. # 'auto out = _r.tensorlist_n<2>(2);'
  516. #
  517. inits: Sequence[str]
  518. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  519. #
  520. # Helper Functions
  521. #
  522. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  523. def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
  524. return CppSignatureGroup.from_native_function(f, method=method).signature
  525. def has_tensor_options(f: NativeFunction) -> bool:
  526. return f.func.arguments.tensor_options is not None
  527. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  528. #
  529. # Python Signature
  530. #
  531. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  532. # 'simple_type' was introduced by the old codegen, which is slightly
  533. # different from the python schema type, e.g.: doesn't have '?' suffix
  534. # for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
  535. def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
  536. if isinstance(t, BaseType):
  537. if t.name == BaseTy.Tensor:
  538. return "Tensor"
  539. elif t.name == BaseTy.int:
  540. return "int64_t"
  541. elif t.name == BaseTy.float:
  542. return "double"
  543. elif t.name == BaseTy.str:
  544. return "c10::string_view"
  545. elif t.name in [
  546. BaseTy.bool,
  547. BaseTy.QScheme,
  548. BaseTy.Scalar,
  549. BaseTy.ScalarType,
  550. BaseTy.Generator,
  551. BaseTy.Storage,
  552. BaseTy.Layout,
  553. BaseTy.Device,
  554. BaseTy.MemoryFormat,
  555. BaseTy.Dimname,
  556. BaseTy.Stream,
  557. BaseTy.ConstQuantizerPtr,
  558. BaseTy.SymInt,
  559. ]:
  560. # These python schema type names line up with their function schema names
  561. return t.name.name
  562. elif isinstance(t, OptionalType):
  563. if str(t.elem) == "Tensor":
  564. # Is it desired to keep '?' for simple_type with new style dispatcher?
  565. return "Tensor?"
  566. elem = argument_type_str(t.elem, simple_type=simple_type)
  567. if elem == "Layout":
  568. # TODO: fix this special case in PythonArgParser?
  569. return "Layout"
  570. else:
  571. return f"{elem}?"
  572. elif isinstance(t, ListType):
  573. size = t.size if not simple_type else None
  574. if str(t.elem) == "bool":
  575. assert t.size is not None
  576. return f"::std::array<bool,{t.size}>"
  577. elif str(t.elem) == "int":
  578. return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
  579. elif str(t.elem) == "SymInt":
  580. return f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
  581. elif str(t.elem) == "Tensor":
  582. return f"TensorList[{size}]" if size is not None else "TensorList"
  583. elif str(t.elem) == "Scalar":
  584. return f"ScalarList[{size}]" if size is not None else "ScalarList"
  585. elif str(t.elem) == "Tensor?":
  586. if simple_type:
  587. return "c10::List<c10::optional<Tensor>>"
  588. else:
  589. return "const c10::List<c10::optional<Tensor>> &"
  590. elif str(t.elem) == "Dimname":
  591. return f"DimnameList[{size}]" if size is not None else "DimnameList"
  592. elem = argument_type_str(t.elem, simple_type=simple_type)
  593. return f"ArrayRef<{elem}>"
  594. raise RuntimeError(f"unrecognized type {repr(t)}")
  595. def argument_type_size(t: Type) -> Optional[int]:
  596. l = t.is_list_like()
  597. if l is not None and str(l.elem) != "bool":
  598. return l.size
  599. else:
  600. return None
  601. def argument(a: Argument) -> PythonArgument:
  602. return PythonArgument(
  603. name=a.name,
  604. type=a.type,
  605. # TODO: directly translate a.default to python default
  606. default=str(pythonify_default(cpp.default_expr(a.default, a.type)))
  607. if a.default is not None
  608. else None,
  609. default_init=None,
  610. )
  611. # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
  612. def signature(
  613. f: NativeFunction, *, method: bool = False, pyi: bool = False
  614. ) -> PythonSignature:
  615. args: List[Argument] = []
  616. args.extend(f.func.arguments.pre_self_positional)
  617. # Skip SelfArgument if this is method.
  618. if not method and f.func.arguments.self_arg is not None:
  619. args.append(f.func.arguments.self_arg.argument)
  620. args.extend(f.func.arguments.post_self_positional)
  621. args.extend(f.func.arguments.pre_tensor_options_kwarg_only)
  622. # Skip TensorOptionsArguments. Python side TensorOptions
  623. # arguments are created based on different rules - see below.
  624. args.extend(f.func.arguments.post_tensor_options_kwarg_only)
  625. args.extend(f.func.arguments.out)
  626. input_arg_set = set(a.name for a in f.func.arguments.flat_positional)
  627. kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only)
  628. out_arg_set = set(a.name for a in f.func.arguments.out)
  629. input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
  630. input_kwargs = tuple(
  631. map(argument, filter(lambda a: a.name in kwarg_only_set, args))
  632. )
  633. outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))
  634. # Reintroduce the scattered fields of TensorOptions for Python.
  635. # Compared to the cpp counterpart, the python arguments have new property
  636. # (default_init) and a new argument 'requires_grad', which require some
  637. # special handlings.
  638. # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
  639. # to the original versions in the yaml, this recreation is a potential
  640. # source of drift between eager and JIT. Pull this logic out to a shared place.
  641. has_tensor_input_arg = any(
  642. a.type.is_tensor_like() for a in f.func.arguments.flat_non_out
  643. )
  644. if any(a.name == "requires_grad" for a in f.func.schema_order_arguments()):
  645. raise ValueError(
  646. "argument named requires_grad is reserved, should not explicitly add it in the schema"
  647. )
  648. # [old codegen] this probably won't work if one of the returns is not a tensor,
  649. # but it will produce a compile-time error that is obvious.
  650. has_tensor_return = any(r.type.is_tensor_like() for r in f.func.returns)
  651. name: str = cpp.name(f.func)
  652. is_factory_function = f.category_override == "factory" or (
  653. has_tensor_return and not has_tensor_input_arg
  654. )
  655. is_like_or_new_function = (
  656. f.category_override in ("new", "like")
  657. or name.startswith("new_")
  658. or name.endswith("_like")
  659. )
  660. tensor_options_args: List[PythonArgument] = []
  661. if is_factory_function or is_like_or_new_function:
  662. tensor_options_args.append(
  663. PythonArgument(
  664. name="dtype",
  665. type=BaseType(BaseTy.ScalarType),
  666. default="None" if pyi else _dtype_default_type_hack(name),
  667. default_init="self.scalar_type()" if is_like_or_new_function else None,
  668. )
  669. )
  670. tensor_options_args.append(
  671. PythonArgument(
  672. name="layout",
  673. type=OptionalType(BaseType(BaseTy.Layout)),
  674. default="strided" if pyi else "torch.strided",
  675. default_init="self.layout()" if is_like_or_new_function else None,
  676. )
  677. )
  678. tensor_options_args.append(
  679. PythonArgument(
  680. name="device",
  681. type=BaseType(BaseTy.Device),
  682. default="None",
  683. default_init="self.device()" if is_like_or_new_function else None,
  684. )
  685. )
  686. tensor_options_args.append(
  687. PythonArgument(
  688. name="pin_memory",
  689. type=BaseType(BaseTy.bool),
  690. default="False",
  691. default_init=None,
  692. )
  693. )
  694. tensor_options_args.append(
  695. PythonArgument(
  696. name="requires_grad",
  697. type=BaseType(BaseTy.bool),
  698. default="False",
  699. default_init=None,
  700. )
  701. )
  702. returns = PythonReturns(returns=f.func.returns)
  703. return PythonSignature(
  704. name=str(f.func.name.name),
  705. input_args=input_args,
  706. input_kwargs=input_kwargs,
  707. output_args=PythonOutArgument.from_outputs(outputs),
  708. tensor_options_args=tuple(tensor_options_args),
  709. returns=returns,
  710. method=method,
  711. )
  712. # TODO blowtorch
  713. # note: removing this will be BC-breaking. A quick test shows that
  714. # randperm will otherwise default its dtype to torch.float64
  715. def _dtype_default_type_hack(name: str) -> str:
  716. if name.startswith("randperm") or name == "tril_indices" or name == "triu_indices":
  717. return "torch.int64"
  718. else:
  719. return "None"
  720. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  721. #
  722. # Python Interface
  723. #
  724. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  725. def namedtuple_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
  726. if len(returns) <= 1 or all(map(lambda r: r.name is None, returns)):
  727. return []
  728. else:
  729. if any(map(lambda r: r.name is None, returns)):
  730. # When building on Windows, `PyStructSequence_UnnamedField` could not be
  731. # resolved by the linker for some reason, which cause error in building:
  732. #
  733. # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
  734. # PyStructSequence_UnnamedField
  735. #
  736. # Thus, at this point in time, we do not support unnamed
  737. # fields in namedtuple; you must either name all fields,
  738. # or none of them.
  739. raise ValueError("Unnamed field is not supported by codegen")
  740. return list(map(lambda r: str(r.name), returns))
  741. def argument_type_str_pyi(t: Type) -> str:
  742. add_optional = False
  743. if isinstance(t, OptionalType):
  744. t = t.elem
  745. add_optional = True
  746. if isinstance(t, BaseType):
  747. if t.name == BaseTy.int:
  748. ret = "_int"
  749. if t.name == BaseTy.SymInt:
  750. ret = "SymInt"
  751. elif t.name == BaseTy.float:
  752. ret = "_float"
  753. elif t.name == BaseTy.str:
  754. ret = "str"
  755. elif t.name == BaseTy.Scalar:
  756. ret = "Number"
  757. elif t.name == BaseTy.ScalarType:
  758. ret = "_dtype"
  759. elif t.name == BaseTy.bool:
  760. ret = "_bool"
  761. elif t.name == BaseTy.QScheme:
  762. ret = "_qscheme"
  763. elif t.name == BaseTy.Layout:
  764. ret = "_layout"
  765. elif t.name == BaseTy.Device:
  766. ret = "Union[_device, str, None]"
  767. elif t.name == BaseTy.MemoryFormat:
  768. ret = "memory_format"
  769. elif t.name == BaseTy.Dimname:
  770. ret = "Union[str, ellipsis, None]"
  771. elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Storage, BaseTy.Stream]:
  772. # These python schema type names line up with their function schema names
  773. ret = t.name.name
  774. elif isinstance(t, ListType):
  775. if str(t.elem) == "int":
  776. ret = "Union[_int, _size]" if t.size is not None else "_size"
  777. elif t.is_tensor_like():
  778. # TODO: this doesn't seem right...
  779. # Tensor?[] currently translates to Optional[Union[Tuple[Tensor, ...], List[Tensor]]]
  780. # It should probably translate to Union[Tuple[Optional[Tensor], ...], List[Optional[Tensor]]]
  781. if isinstance(t.elem, OptionalType):
  782. add_optional = True
  783. ret = (
  784. "Union[Tensor, Tuple[Tensor, ...], List[Tensor]]"
  785. if t.size is not None
  786. else "Union[Tuple[Tensor, ...], List[Tensor]]"
  787. )
  788. elif str(t.elem) == "float":
  789. ret = "Sequence[_float]"
  790. else:
  791. elem = argument_type_str_pyi(t.elem)
  792. ret = f"Sequence[{elem}]"
  793. if add_optional:
  794. ret = "Optional[" + ret + "]"
  795. return ret
  796. raise RuntimeError(f"unrecognized type {repr(t)}")
  797. def return_type_str_pyi(t: Type) -> str:
  798. # Where arguments are open to accepting Union, return types should return
  799. # concrete types
  800. if isinstance(t, OptionalType):
  801. inner = return_type_str_pyi(t.elem)
  802. return f"Optional[{inner}]"
  803. if isinstance(t, BaseType):
  804. if t.name == BaseTy.Device:
  805. return "_device"
  806. elif t.name == BaseTy.Dimname:
  807. ret = "Optional[str]"
  808. else:
  809. return argument_type_str_pyi(t)
  810. if isinstance(t, ListType):
  811. inner = return_type_str_pyi(t.elem)
  812. return f"List[{inner}]"
  813. return argument_type_str_pyi(t)
  814. def returns_named_tuple_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
  815. python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
  816. namedtuple_name = signature.name
  817. field_names = namedtuple_fieldnames(signature.returns.returns)
  818. if field_names:
  819. tuple_args = [
  820. f'("{name}", {typ})' for name, typ in zip(field_names, python_returns)
  821. ]
  822. namedtuple_def = f'NamedTuple("{namedtuple_name}", [{", ".join(tuple_args)}])'
  823. return namedtuple_name, namedtuple_def
  824. return None
  825. def returns_str_pyi(signature: PythonSignature) -> str:
  826. field_names = namedtuple_fieldnames(signature.returns.returns)
  827. if field_names:
  828. return f"torch.return_types.{signature.name}"
  829. python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
  830. if len(python_returns) > 1:
  831. return "Tuple[" + ", ".join(python_returns) + "]"
  832. if len(python_returns) == 1:
  833. return python_returns[0]
  834. return "None"
  835. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  836. #
  837. # C++ Function Dispatch
  838. #
  839. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  840. # This section provides APIs to generate the code that does C++ function
  841. # dispatch. The C++ function call is wrapped by a lambda function.
  842. # For example:
  843. #
  844. # // aten::selu_(Tensor(a!) self) -> Tensor(a!)
  845. # auto dispatch_selu_ = [](Tensor self) -> Tensor {
  846. # pybind11::gil_scoped_release no_gil;
  847. # return at::selu_(self);
  848. # };
  849. #
  850. # The lambda function's signature follows the C++ signature in common
  851. # cases, e.g.:
  852. #
  853. # // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  854. # [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
  855. #
  856. # For out variant the 'out' argument's type is changed from 'Tensor &'
  857. # to 'Tensor'. This is because the lambda is called with the
  858. # PythonArgParser output '_r.tensor(3)', which is a stack-allocated temporary object
  859. # and therefore must be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
  860. #
  861. # // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  862. # [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
  863. #
  864. # For multi-output case it can keep using reference type because the
  865. # PythonArgParser output has been unpacked to local variables, e.g.:
  866. #
  867. # // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
  868. # // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
  869. # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
  870. #
  871. # For deprecated python signature, it should follow deprecated python arg order.
  872. # TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?
  873. def dispatch_lambda_args(
  874. ps: PythonSignature, f: NativeFunction
  875. ) -> Tuple[DispatchLambdaArgument, ...]:
  876. # Start with cpp arguments - dispatch lambda signature always include 'self'
  877. cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
  878. # Special reorder logic for deprecated python signature
  879. if isinstance(ps, PythonSignatureDeprecated):
  880. m: Dict[str, Binding] = dict((a.name, a) for a in cpp_args)
  881. # reorder according to the deprecated signature
  882. # ignore 'out' argument when binding to non-output function.
  883. ordered_args = filter(
  884. lambda n: n != "out" or f.func.is_out_fn(), ps.deprecated_args_names
  885. )
  886. cpp_args = list(map(lambda n: m[n], ordered_args))
  887. out_args: Set[str] = set(a.name for a in f.func.arguments.out)
  888. # Convert from cpp argument to lambda argument
  889. def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
  890. type_str = cpp_arg.type
  891. is_out_arg = cpp_arg.name in out_args
  892. if ps.method and cpp_arg.name == "self":
  893. # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
  894. type_str = "const at::Tensor &"
  895. else:
  896. # For other cases we need prevent dangling refs to temps (unless it's
  897. # unpacked scattered output)
  898. # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
  899. # TODO: avoid this special handling?
  900. ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
  901. if ensure_temp_safe:
  902. type_str = {
  903. "at::Tensor &": "at::Tensor",
  904. }.get(type_str, type_str)
  905. return DispatchLambdaArgument(
  906. name=cpp_arg.name,
  907. type_str=type_str,
  908. is_out_arg=is_out_arg,
  909. )
  910. return tuple(map(dispatch_lambda_arg, cpp_args))
  911. # [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
  912. # it's enough to just extend the list here. Before you do this, make sure
  913. # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
  914. SUPPORTED_RETURN_TYPES = {
  915. "at::Tensor",
  916. "::std::tuple<at::Tensor,at::Tensor>",
  917. "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
  918. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  919. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  920. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
  921. "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
  922. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
  923. "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
  924. "::std::tuple<double,int64_t>",
  925. "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
  926. "::std::vector<at::Tensor>",
  927. "at::Scalar",
  928. "bool",
  929. "int64_t",
  930. "void*",
  931. "void",
  932. "at::QScheme",
  933. "double",
  934. "at::IntArrayRef",
  935. "at::ScalarType",
  936. }
  937. def dispatch_lambda_return_str(f: NativeFunction) -> str:
  938. # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
  939. # because the dispatch lambdas take mutable arguments *by value*, not
  940. # by reference. If you then return a reference to such an argument, you
  941. # will now have a pointer to a dangling stack entry. Not good.
  942. #
  943. # You want:
  944. #
  945. # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
  946. # ^^^^^^
  947. #
  948. # *not*
  949. #
  950. # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
  951. # ^^^^^^^
  952. #
  953. # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
  954. # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
  955. # mutable reference to temporary. Maybe we could assign it to a
  956. # variable itself.)
  957. returns_without_annotation = tuple(
  958. map(lambda r: Return(r.name, r.type, None), f.func.returns)
  959. )
  960. return_str = cpp.returns_type(returns_without_annotation).cpp_type()
  961. if return_str not in SUPPORTED_RETURN_TYPES:
  962. raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
  963. return return_str
  964. def cpp_dispatch_target(f: NativeFunction) -> str:
  965. name = cpp.name(f.func)
  966. if Variant.method in f.variants:
  967. return f"self.{name}"
  968. if Variant.function in f.variants:
  969. if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
  970. namespace = "torch"
  971. else:
  972. namespace = "at"
  973. return f"{namespace}::{name}"
  974. raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
  975. def cpp_dispatch_exprs(
  976. f: NativeFunction,
  977. *,
  978. python_signature: Optional[PythonSignature] = None,
  979. ) -> Tuple[str, ...]:
  980. cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
  981. exprs: Tuple[str, ...] = tuple()
  982. if not isinstance(python_signature, PythonSignatureDeprecated):
  983. # By default the exprs are consistent with the C++ signature.
  984. exprs = tuple(map(lambda a: a.name, cpp_args))
  985. else:
  986. # For deprecated python signature we may need fill in some constants.
  987. exprs = tuple(
  988. filter(
  989. lambda n: n != "out" or f.func.is_out_fn(),
  990. python_signature.deprecated_args_exprs,
  991. )
  992. )
  993. if Variant.method in f.variants:
  994. exprs = tuple(filter("self".__ne__, exprs))
  995. return exprs
  996. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  997. #
  998. # Python / C++ Args Binding
  999. #
  1000. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1001. # We explicitly enumerate the PythonArgParser unpacking methods for all
  1002. # supported types. This might be more verbose than necessary, partially
  1003. # because of the irregularity of unpacking method naming, partially
  1004. # because we want to mimic the old codegen behavior - to reject
  1005. # unexpected and/or unsupported cases which the old codegen rejects.
  1006. # For certain cases it is intentionally more restrictive than necessary,
  1007. # e.g.: it doesn't accept doublelist with definite size.
  1008. def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
  1009. if has_default and str(t) not in ("ScalarType", "Device", "Layout?"):
  1010. raise RuntimeError(f"type '{t}' does not supported unpacking with default")
  1011. if isinstance(t, BaseType):
  1012. if t.name in [
  1013. BaseTy.Tensor,
  1014. BaseTy.Stream,
  1015. BaseTy.Storage,
  1016. BaseTy.Scalar,
  1017. BaseTy.Dimname,
  1018. ]:
  1019. # These unpack methods line up with their schema names
  1020. return t.name.name.lower()
  1021. elif t.name == BaseTy.ScalarType:
  1022. return "scalartypeWithDefault" if has_default else "scalartype"
  1023. elif t.name == BaseTy.Device:
  1024. return "deviceWithDefault" if has_default else "device"
  1025. elif t.name == BaseTy.int:
  1026. return "toInt64"
  1027. elif t.name == BaseTy.SymInt:
  1028. return "toSymInt"
  1029. elif t.name == BaseTy.bool:
  1030. return "toBool"
  1031. elif t.name == BaseTy.float:
  1032. return "toDouble"
  1033. elif t.name == BaseTy.str:
  1034. return "stringView"
  1035. elif t.name == BaseTy.Layout:
  1036. return "layout"
  1037. elif isinstance(t, OptionalType):
  1038. if str(t.elem) == "Tensor":
  1039. return "optionalTensor"
  1040. elif isinstance(t.elem, BaseType):
  1041. if t.elem.name in [
  1042. BaseTy.ScalarType,
  1043. BaseTy.Scalar,
  1044. BaseTy.int,
  1045. BaseTy.bool,
  1046. BaseTy.float,
  1047. BaseTy.str,
  1048. ]:
  1049. # Regular cases: append 'Optional' to elem's unpacking method
  1050. return arg_parser_unpack_method(t.elem, False) + "Optional"
  1051. elif t.elem.name == BaseTy.MemoryFormat:
  1052. return "memoryformatOptional"
  1053. elif t.elem.name == BaseTy.Generator:
  1054. return "generator"
  1055. elif t.elem.name == BaseTy.Layout:
  1056. return "layoutWithDefault" if has_default else "layoutOptional"
  1057. elif t.elem.name == BaseTy.Device:
  1058. return "deviceWithDefault" if has_default else "deviceOptional"
  1059. elif isinstance(t.elem, ListType):
  1060. if str(t.elem.elem) == "int":
  1061. # accept definite size
  1062. return "intlistOptional"
  1063. elif str(t.elem) == "float[]":
  1064. return "doublelistOptional"
  1065. elif str(t.elem) == "Dimname[]":
  1066. return "toDimnameListOptional"
  1067. elif isinstance(t, ListType):
  1068. if str(t.elem) == "Tensor":
  1069. # accept and use definite size
  1070. if t.size is not None:
  1071. return f"tensorlist_n<{t.size}>"
  1072. else:
  1073. return "tensorlist"
  1074. elif str(t.elem) == "Tensor?":
  1075. return "list_of_optional_tensors"
  1076. elif str(t.elem) == "Dimname":
  1077. # accept definite size
  1078. return "dimnamelist"
  1079. elif str(t.elem) == "int":
  1080. # accept definite size
  1081. return "intlist"
  1082. elif str(t) == "float[]":
  1083. return "doublelist"
  1084. elif str(t.elem) == "SymInt":
  1085. # accept definite size
  1086. return "symintlist"
  1087. elif str(t) == "Scalar[]":
  1088. return "scalarlist"
  1089. raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
  1090. # Return RHS expression for python argument using PythonArgParser output.
  1091. # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
  1092. def arg_parser_output_expr(
  1093. arg_index: int, a: PythonArgument
  1094. ) -> PythonArgParserOutputExpr:
  1095. has_default = a.default_init is not None
  1096. unpack_method = arg_parser_unpack_method(a.type, has_default)
  1097. default = f", {a.default_init}" if has_default else ""
  1098. expr = f"_r.{unpack_method}({arg_index}{default})"
  1099. return PythonArgParserOutputExpr(
  1100. name=a.name,
  1101. expr=expr,
  1102. index=arg_index,
  1103. argument=a,
  1104. )
  1105. # Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
  1106. def arg_parser_output_exprs(
  1107. ps: PythonSignature, f: NativeFunction
  1108. ) -> Dict[str, PythonArgParserOutputExpr]:
  1109. return {
  1110. e.name: e
  1111. for i, a in enumerate(ps.arguments())
  1112. for e in (arg_parser_output_expr(i, a),)
  1113. }
  1114. # argument name to type for scattered tensor options fields
  1115. TENSOR_OPTIONS_FIELDS = {
  1116. "dtype": "ScalarType",
  1117. "device": "Device",
  1118. "layout": "Layout?",
  1119. "pin_memory": "bool",
  1120. "requires_grad": "bool",
  1121. }
# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction
) -> DispatchLambdaArgumentExprs:
    """Bind PythonArgParser outputs to the dispatch lambda's arguments.

    Returns a DispatchLambdaArgumentExprs with:
      * ``inits``: C++ statements to emit before calling the lambda
        (special unpacking, TensorOptions packing, output type checks), in
        the order they were appended;
      * ``exprs``: one C++ expression per dispatch lambda argument, in the
        order returned by ``dispatch_lambda_args(ps, f)``.

    Raises:
        RuntimeError: on inconsistent tensor-options arguments: an output
            function that also has TensorOptions, an unknown or mistyped
            tensor-options field, an incomplete field set, 'dtype' without an
            output argument, or a missing 'requires_grad'.
    """
    # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
    # outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f)
    lambda_args = dispatch_lambda_args(ps, f)
    inits: List[str] = []
    lambda_args_exprs: Dict[str, str] = dict()

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why this needs to be special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            # Multiple outputs are parsed once into a local 'out' and each
            # output lambda argument is bound to one indexed element of it.
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"c10::optional<DimnameList> {name} = __{name} ? c10::make_optional(DimnameList(__{name}.value())) : c10::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = list(map(lambda a: a.name, ps.tensor_options_args))
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        # Every Python tensor-options argument must be a known field with the
        # expected schema type...
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        # ...and every known field must be present.
        if not all(
            map(lambda a: a in tensor_options_args_names, TENSOR_OPTIONS_FIELDS.keys())
        ):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        # Pack the scattered fields into a single C++ TensorOptions local,
        # which is then passed to the lambda as 'options'.
        inits.append(
            f"""\
const auto options = TensorOptions()
.dtype({arg_parser_outputs['dtype'].expr})
.device({arg_parser_outputs['device'].expr})
.layout({arg_parser_outputs['layout'].expr})
.requires_grad({arg_parser_outputs['requires_grad'].expr})
.pinned_memory({arg_parser_outputs['pin_memory'].expr});
torch::utils::maybe_initialize_cuda(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against output tensor
            if not f.func.is_out_fn():
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg"
                )
            if not all(
                map(lambda a: a in tensor_options_args_names, ("layout", "device"))
            ):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs['out'].expr}, {arg_parser_outputs['dtype'].expr},
{arg_parser_outputs['dtype'].is_none_expr}, {arg_parser_outputs['layout'].expr},
{arg_parser_outputs['device'].expr}, {arg_parser_outputs['device'].is_none_expr});
"""
            )
        # we'll set requires_grad on outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
            )

    # One expression per lambda argument, in lambda argument order; a missing
    # binding for any lambda argument surfaces here as a KeyError.
    return DispatchLambdaArgumentExprs(
        exprs=tuple(map(lambda a: lambda_args_exprs[a.name], lambda_args)),
        inits=inits,
    )