cpp.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. from torchgen.model import (
  2. Argument,
  3. Arguments,
  4. BaseTy,
  5. BaseType,
  6. FunctionSchema,
  7. ListType,
  8. NativeFunction,
  9. OptionalType,
  10. Return,
  11. SelfArgument,
  12. TensorOptionsArguments,
  13. Type,
  14. )
  15. from torchgen.api.types import (
  16. ArgName,
  17. BaseCType,
  18. Binding,
  19. ConstRefCType,
  20. NamedCType,
  21. CType,
  22. MutRefCType,
  23. ArrayCType,
  24. ListCType,
  25. VectorCType,
  26. ArrayRefCType,
  27. OptionalCType,
  28. TupleCType,
  29. SpecialArgName,
  30. boolT,
  31. scalarT,
  32. tensorListT,
  33. dimnameListT,
  34. tensorT,
  35. voidT,
  36. longT,
  37. BaseTypeToCppMapping,
  38. intArrayRefT,
  39. optionalIntArrayRefT,
  40. tensorOptionsT,
  41. symIntArrayRefT,
  42. )
  43. from torchgen import local
  44. from torchgen.utils import assert_never
  45. from typing import Optional, Sequence, Union, List, Set
  46. # This file describes the translation of JIT schema to the public C++
  47. # API, which is what people use when they call functions like at::add.
  48. #
  49. # Prominent characteristics of the C++ API:
  50. #
  51. # - dtype, layout, device and pin_memory are collected into
  52. # a single C++ type TensorOptions (the native functions API
  53. # also has this, but tensor options is really most relevant
  54. # for the C++ API; it makes calling kwarg factory functions
  55. # pleasant)
  56. #
  57. # - defaulting lives here (in fact, the dispatcher is completely
  58. # oblivious of defaults!)
  59. #
  60. # BTW: policy on name collisions: we try not to have types with
  61. # collisions, but functions are fair game to collide
  62. def name(func: FunctionSchema, *, faithful_name_for_out_overloads: bool = False) -> str:
  63. name = str(func.name.name)
  64. if func.is_functional_fn():
  65. name += "_functional"
  66. elif func.is_out_fn():
  67. if faithful_name_for_out_overloads:
  68. name += "_outf"
  69. else:
  70. name += "_out"
  71. return name
  72. # Translation of "value types" in JIT schema to C++ API type. Value
  73. # types look the same no matter if they are argument types or return
  74. # types. Returns None if the type in question is not a value type.
  75. def valuetype_type(
  76. t: Type, *, binds: ArgName, remove_non_owning_ref_types: bool = False
  77. ) -> Optional[NamedCType]:
  78. if isinstance(t, BaseType):
  79. if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
  80. return None
  81. if remove_non_owning_ref_types:
  82. if t.name == BaseTy.str:
  83. raise AssertionError(
  84. "string ref->value conversion: not implemented yet"
  85. )
  86. # All other BaseType currently map directly to BaseCppTypes.
  87. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
  88. elif isinstance(t, OptionalType):
  89. elem = valuetype_type(t.elem, binds=binds)
  90. if elem is None:
  91. return None
  92. return NamedCType(binds, OptionalCType(elem.type))
  93. elif isinstance(t, ListType):
  94. if str(t.elem) == "bool":
  95. assert t.size is not None
  96. return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
  97. else:
  98. return None
  99. else:
  100. raise AssertionError(f"unrecognized type {repr(t)}")
  101. # Translation of types occuring in JIT arguments to a C++ argument type.
  102. # If remove_non_owning_ref_types is set, we'll guarantee that the outputed CType is not a non-owning reference type.
  103. # For example, we'll return std::vector<int> instead of IntArrayRef.
  104. # See Note [translation from C++ reference to value types]
  105. def argumenttype_type(
  106. t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
  107. ) -> NamedCType:
  108. # If it's a value type, do the value type translation
  109. r = valuetype_type(
  110. t, binds=binds, remove_non_owning_ref_types=remove_non_owning_ref_types
  111. )
  112. if r is not None:
  113. return r
  114. if isinstance(t, BaseType):
  115. if t.name == BaseTy.Tensor:
  116. if mutable and not local.use_const_ref_for_mutable_tensors():
  117. return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
  118. else:
  119. return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
  120. elif t.name == BaseTy.Scalar:
  121. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  122. else:
  123. raise AssertionError(f"base type should have been value type {t}")
  124. elif isinstance(t, OptionalType):
  125. if str(t.elem) == "Tensor":
  126. if mutable and not local.use_const_ref_for_mutable_tensors():
  127. return NamedCType(
  128. binds, MutRefCType(BaseCType(tensorT))
  129. ) # TODO: fix this discrepancy
  130. else:
  131. return NamedCType(
  132. binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
  133. )
  134. elif str(t.elem) == "Scalar":
  135. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  136. elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
  137. return NamedCType(binds, BaseCType(optionalIntArrayRefT))
  138. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
  139. return NamedCType(binds, OptionalCType(elem.type))
  140. elif isinstance(t, ListType):
  141. # TODO: remove these special cases, ArrayRef fallthrough works fine
  142. if str(t.elem) == "int":
  143. if remove_non_owning_ref_types:
  144. return NamedCType(binds, VectorCType(BaseCType(longT)))
  145. else:
  146. return NamedCType(binds, BaseCType(intArrayRefT))
  147. elif str(t.elem) == "Tensor":
  148. return NamedCType(binds, BaseCType(tensorListT))
  149. elif str(t.elem) == "Scalar":
  150. return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
  151. elif str(t.elem) == "SymInt":
  152. return NamedCType(binds, BaseCType(symIntArrayRefT))
  153. elif str(t.elem) == "Dimname":
  154. return NamedCType(binds, BaseCType(dimnameListT))
  155. elif str(t.elem) == "Tensor?":
  156. return NamedCType(
  157. binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
  158. )
  159. elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
  160. return NamedCType(binds, ArrayRefCType(elem.type))
  161. else:
  162. raise AssertionError(f"unrecognized type {repr(t)}")
  163. # Translate a JIT argument into its C++ type
  164. def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
  165. return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
  166. # Translation of a (non-multi) return type from JIT to C++
  167. # N.B: returntype_type returns a CType, not a NamedCType.
  168. # This is mostly because of the mismatch between return types and return names.
  169. # e.g. a function with a return type of 'void' has 0 return names,
  170. # and a function with a return type of 'std::tuple' has >1 return name.
  171. def returntype_type(t: Type, *, mutable: bool) -> CType:
  172. # placeholder is ignored
  173. r = valuetype_type(t, binds="__placeholder__")
  174. if r is not None:
  175. return r.type
  176. if isinstance(t, BaseType):
  177. if t.name == BaseTy.Tensor:
  178. if mutable:
  179. if local.use_const_ref_for_mutable_tensors():
  180. return ConstRefCType(BaseCType(tensorT))
  181. else:
  182. return MutRefCType(BaseCType(tensorT))
  183. else:
  184. # Note [Tensor Copy Returns]
  185. # Currently, we use "Argument.is_write" to determine
  186. # whether or not Tensor return types should be copies or references.
  187. # If that ever changes, take a look at other locations of this note!
  188. return BaseCType(tensorT)
  189. elif t.name == BaseTy.Scalar:
  190. return BaseCType(scalarT)
  191. elif isinstance(t, ListType):
  192. assert (
  193. not mutable
  194. ), "Native functions should never return a mutable tensor list. They should return void."
  195. elem = returntype_type(t.elem, mutable=False)
  196. assert t.size is None, f"fixed size list returns not supported: {t}"
  197. return VectorCType(elem)
  198. raise AssertionError(f"unrecognized return type {t}")
  199. # Translation of a single return to its C++ type
  200. def return_type(r: Return) -> CType:
  201. return returntype_type(r.type, mutable=r.is_write)
  202. # Translation of a full (possibly multi) return from JIT to its C++ type
  203. def returns_type(rs: Sequence[Return]) -> CType:
  204. if len(rs) == 0:
  205. return BaseCType(voidT)
  206. elif len(rs) == 1:
  207. return return_type(rs[0])
  208. else:
  209. return TupleCType([return_type(r) for r in rs])
  210. def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
  211. returns: List[str] = []
  212. for i, r in enumerate(f.func.returns):
  213. # If we have an inplace function, the return argument is
  214. # implicitly named self.
  215. # TODO: Consider incorporating this into the data model
  216. if f.func.name.name.inplace:
  217. assert i == 0, "illegal inplace function with multiple returns"
  218. name = "self"
  219. # If we are out function, the name is the name of the
  220. # corresponding output function (r.name will get recorded
  221. # in field_name later.)
  222. elif f.func.is_out_fn():
  223. name = f.func.arguments.out[i].name
  224. # If the return argument is explicitly named...
  225. elif r.name:
  226. name_conflict = any(
  227. r.name == a.name for a in f.func.schema_order_arguments()
  228. )
  229. if name_conflict and not f.func.is_out_fn():
  230. name = f"{r.name}_return"
  231. else:
  232. name = r.name
  233. # If there is no explicit name and no fallback name was passed in, we just name the output result,
  234. # unless it's a multi-return, in which case it's result0,
  235. # result1, etc (zero-indexed)
  236. else:
  237. name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
  238. returns.append(name)
  239. return returns
  240. JIT_TO_CPP_DEFAULT = {
  241. "False": "false",
  242. "True": "true",
  243. "None": "c10::nullopt", # UGH this one is type directed
  244. "Mean": "at::Reduction::Mean",
  245. "[]": "{}",
  246. "contiguous_format": "MemoryFormat::Contiguous",
  247. "long": "at::kLong",
  248. }
  249. # Convert a JIT default into C++ expression representing the default
  250. def default_expr(d: str, t: Type) -> str:
  251. if d == "None" and str(t) == "Tensor?":
  252. return "{}"
  253. if isinstance(t, BaseType) and t.name is BaseTy.str:
  254. # Schema allows single quotes but C++ needs double
  255. if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
  256. s = ""
  257. i = 1
  258. while i + 1 < len(d):
  259. if d[i] != "\\":
  260. if d[i] == '"':
  261. s += '\\"'
  262. else:
  263. s += d[i]
  264. i += 1
  265. else:
  266. if d[i + 1] == "'":
  267. s += "'"
  268. else:
  269. s += d[i : i + 2]
  270. i += 2
  271. return f'"{s}"'
  272. if isinstance(t, OptionalType):
  273. if d == "None":
  274. return "c10::nullopt"
  275. return default_expr(d, t.elem)
  276. if isinstance(t, ListType):
  277. if d.startswith("[") and d.endswith("]"):
  278. return "{" + d[1:-1] + "}"
  279. elif t.size is None:
  280. # NOTE: Sized lists can have scalar defaults
  281. raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
  282. return JIT_TO_CPP_DEFAULT.get(d, d)
  283. # Convert an argument into its C++ API form
  284. def argument(
  285. a: Union[Argument, TensorOptionsArguments, SelfArgument],
  286. *,
  287. cpp_no_default_args: Set[str],
  288. method: bool,
  289. faithful: bool,
  290. has_tensor_options: bool,
  291. ) -> List[Binding]:
  292. def sub_argument(
  293. a: Union[Argument, TensorOptionsArguments, SelfArgument]
  294. ) -> List[Binding]:
  295. return argument(
  296. a,
  297. cpp_no_default_args=cpp_no_default_args,
  298. method=method,
  299. faithful=faithful,
  300. has_tensor_options=has_tensor_options,
  301. )
  302. if isinstance(a, Argument):
  303. binds: ArgName
  304. if a.name == "memory_format" and has_tensor_options:
  305. binds = SpecialArgName.possibly_redundant_memory_format
  306. else:
  307. binds = a.name
  308. default: Optional[str] = None
  309. if a.name not in cpp_no_default_args and a.default is not None:
  310. default = default_expr(a.default, a.type)
  311. return [
  312. Binding(
  313. nctype=argument_type(a, binds=binds),
  314. name=a.name,
  315. default=default,
  316. argument=a,
  317. )
  318. ]
  319. elif isinstance(a, TensorOptionsArguments):
  320. if faithful:
  321. return (
  322. sub_argument(a.dtype)
  323. + sub_argument(a.layout)
  324. + sub_argument(a.device)
  325. + sub_argument(a.pin_memory)
  326. )
  327. else:
  328. default = None
  329. # Enforced by NativeFunction.__post_init__
  330. assert "options" not in cpp_no_default_args
  331. if all(x.default == "None" for x in a.all()):
  332. default = "{}"
  333. elif a.dtype.default == "long":
  334. default = "at::kLong" # TODO: this is wrong
  335. return [
  336. Binding(
  337. nctype=NamedCType("options", BaseCType(tensorOptionsT)),
  338. name="options",
  339. default=default,
  340. argument=a,
  341. )
  342. ]
  343. elif isinstance(a, SelfArgument):
  344. if method:
  345. # Caller is responsible for installing implicit this in context!
  346. return []
  347. else:
  348. return sub_argument(a.argument)
  349. else:
  350. assert_never(a)
  351. def arguments(
  352. arguments: Arguments, *, faithful: bool, method: bool, cpp_no_default_args: Set[str]
  353. ) -> List[Binding]:
  354. args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
  355. if faithful:
  356. args.extend(arguments.non_out)
  357. args.extend(arguments.out)
  358. else:
  359. args.extend(arguments.out)
  360. args.extend(arguments.non_out)
  361. return [
  362. r.no_default() if faithful else r
  363. for a in args
  364. for r in argument(
  365. a,
  366. faithful=faithful,
  367. method=method,
  368. has_tensor_options=arguments.tensor_options is not None,
  369. cpp_no_default_args=cpp_no_default_args,
  370. )
  371. ]