lazy_ir.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. from abc import ABC
  2. from typing import List, Optional, Union
  3. from dataclasses import dataclass
  4. from torchgen.context import method_with_native_function
  5. from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
  6. from torchgen.api.types import (
  7. BaseCType,
  8. OptionalCType,
  9. VectorCType,
  10. kernel_signature,
  11. deviceT,
  12. )
  13. import torchgen.api.dispatcher as dispatcher
  14. from torchgen.api.lazy import (
  15. LazyIrSchema,
  16. LazyArgument,
  17. getValueT,
  18. isValueType,
  19. tensorListValueT,
  20. )
  21. from torchgen.dest.lazy_ts_lowering import ts_lowering_body
  def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
      """Return the C++ rvalue used to pass ``arg`` into a lazy Node constructor.

      Value-typed arguments resolve to locals declared by the generated kernel
      (``node_<name>`` for wrapped scalars, ``lazy_<name>_tensorlist`` for
      tensor lists, ``lazy_<name>`` for tensors) or, for SymInt arguments, to
      the IR node carried by the symbolic int. Non-value arguments are passed
      through, with vector types copied into ``std::vector`` containers.

      Raises:
          AssertionError: if a value type is neither ``BaseCType`` nor
              ``OptionalCType``.
      """
      lazy_type = arg.lazy_type
      name = arg.name

      if not isValueType(lazy_type):
          # Non-IR arguments. List-like ones are copied into std::vector
          # (presumably so the node owns the data instead of viewing the
          # caller's buffer).
          if isinstance(lazy_type, VectorCType) and isinstance(
              lazy_type.elem, BaseCType
          ):
              return f"std::vector<{lazy_type.elem.type}>({name}.begin(), {name}.end())"
          is_optional_vector = (
              isinstance(lazy_type, OptionalCType)
              and isinstance(lazy_type.elem, VectorCType)
              and isinstance(lazy_type.elem.elem, BaseCType)
          )
          if is_optional_vector:
              return f"torch::lazy::ToOptionalVector<{lazy_type.elem.elem.type}>({name})"
          return f"{name}"

      if isinstance(lazy_type, BaseCType):
          if arg.is_wrapped_scalar:
              return f"node_{name}"
          if lazy_type.type is tensorListValueT:
              return f"lazy_{name}_tensorlist"
          if arg.is_symint_or_list:
              # Construct the lazy C++ type from the SymbolicIntNode's
              # underlying IR node (the trailing 0 is presumably the output
              # index).
              cpp_type = lazy_type.cpp_type()
              return (
                  f"{cpp_type}(std::dynamic_pointer_cast<torch::lazy::SymbolicIntNode>"
                  f"({name}.toSymbolicIntNode())->node_, 0)"
              )
          return f"lazy_{name}->GetIrValue()"

      if isinstance(lazy_type, OptionalCType):
          if arg.is_wrapped_scalar:
              return f"node_{name}"
          return (
              f"lazy_{name} ? "
              f"c10::make_optional(lazy_{name}->GetIrValue()) : "
              "c10::nullopt"
          )

      raise AssertionError(
          f"TODO not sure if there are other valid types to handle here ({lazy_type})"
      )
  def node_ctor_inputs(schema: LazyIrSchema) -> str:
      """Return the comma-separated argument list for a Node constructor call.

      Each argument from ``schema.filtered_args()`` is rendered, in order, by
      ``node_ctor_arg_rvalue_string``.
      """
      return ", ".join(map(node_ctor_arg_rvalue_string, schema.filtered_args()))
  def gen_fallback_code(schema: LazyIrSchema, overload_name: str) -> str:
      """Return C++ that routes the call to the eager fallback when required.

      The generated guard fires when ``force_eager_fallback(<aten symbol>)`` is
      true or, if the op takes a generator, when that generator is present and
      defined. The fallback is invoked via ``call_fallback_fn`` with
      ``ATEN_OP2(name, overload)`` (or ``ATEN_OP(name)`` when ``overload_name``
      is empty) and the schema's arguments, generator included.
      """
      arg_names = [str(arg.name) for arg in schema.filtered_args(generator=True)]
      fallback_args = ",\n ".join(arg_names)

      aten_op_str = (
          f"ATEN_OP2({schema.aten_name}, {overload_name})"
          if len(overload_name) > 0
          else f"ATEN_OP({schema.aten_name})"
      )

      # generators are always optional and there is never more than one, at least currently
      generator = schema.generator_arg
      or_has_generator = (
          f" || ({generator.name}.has_value() && {generator.name}->defined())"
          if generator
          else ""
      )

      return (
          "\n"
          f"if (force_eager_fallback({aten_symbol(schema)}){or_has_generator}) {{\n"
          f"return at::native::call_fallback_fn<&ltc_eager_fallback, {aten_op_str}>::call(\n"
          f"{fallback_args}\n"
          ");\n"
          "}\n"
      )
  def aten_symbol(schema: LazyIrSchema) -> str:
      """Return a C++ expression naming the ``aten::`` symbol of ``schema``.

      Ops in the local set below are treated as lacking an interned
      ``at::aten::<name>`` constant and are built at runtime with
      ``c10::Symbol::fromQualString``; all other ops use the constant.
      """
      uninterned_ops = {"sigmoid_backward"}
      op_name = schema.aten_name
      if op_name not in uninterned_ops:
          return f"at::aten::{op_name}"
      return f'c10::Symbol::fromQualString("aten::{op_name}")'
  @dataclass(frozen=True)
  class GenLazyIR(ABC):
      """Generate the C++ IR node class declaration for a native function.

      Backends customize the generated class by overriding
      ``lowering_function``, ``can_be_reused_function`` and
      ``node_base_ctor_call``.
      """

      # Backend dispatch metadata for the functions being generated.
      backend_index: BackendIndex
      # C++ base class that every generated node class derives from.
      node_base: str

      @method_with_native_function
      def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
          """Return the generated node class declaration(s) for ``f``."""
          return self.gen(f)

      # there is no lowering functionality generated unless this IR base class is subclassed and
      # implemented as a backend-specific node
      def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
          """Return the node's lowering method; empty for the generic base class."""
          return ""

      def can_be_reused_function(
          self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str
      ) -> str:
          """Return a ``CanBeReused`` method; the base version never allows reuse."""
          return "\n".join(
              [
                  f"bool CanBeReused({node_ctor_args}) const {{",
                  "return false;",
                  "}",
              ]
          )

      def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
          """Return the base-class constructor call for the node ctor's init list.

          Backends can customize this as long as every argument can be derived
          from ``schema``. Value operands are passed as a braced list (absent
          optionals become ``kNullValue``), the number of returns is passed as
          ``num_outputs``, and the scalar arguments are combined with
          ``torch::lazy::MHash``.

          Raises:
              AssertionError: for a value argument of an unsupported C++ type.
          """
          base_ctor_value_args_list = []
          for arg in schema.filtered_args(values=True, scalars=False):
              if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                  base_ctor_value_args_list.append(f"{arg.name}")
              elif isinstance(arg.lazy_type, OptionalCType):
                  base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
              else:
                  raise AssertionError(
                      f"Unsupported type ({arg.lazy_type}) - add support if necessary"
                  )
          base_ctor_value_args = ", ".join(base_ctor_value_args_list)
          scalar_args = schema.filtered_args(values=False, scalars=True)
          scalar_hashes = ", ".join([f"{a.name}" for a in scalar_args])
          return "\n".join(
              [
                  f"{self.node_base}(torch::lazy::OpKind({aten_symbol(schema)}),",
                  f"{{{base_ctor_value_args}}}, std::move(shapes),",
                  f"/* num_outputs */ {len(schema.returns)},",
                  f"torch::lazy::MHash({scalar_hashes}))",
              ]
          )

      def gen(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
          """Return a one-element list holding the IR node class declaration.

          For a ``NativeFunctionsGroup`` the functional variant is used (not the
          out/inplace ones). The generated class stores every scalar argument
          as a member and records, per optional value operand, whether it was
          present at construction.
          """
          func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
          schema = LazyIrSchema(func)

          all_args = schema.filtered_args()
          value_args = schema.filtered_args(values=True, scalars=False)
          scalar_args = schema.filtered_args(values=False, scalars=True)

          node_ctor_args = ", ".join(
              [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
          )
          scalar_initializers = ",\n ".join(
              [f"{a.name}({a.name})" for a in scalar_args]
          )
          comma_if_scalar_initializers = ",\n" if scalar_initializers else ""
          # string_view scalars are stored as std::string members: a string_view
          # member would not own the characters it refers to.
          scalar_decls = "\n ".join(
              [
                  f"std::string {a.name};"
                  if a.lazy_type.cpp_type() == "c10::string_view"
                  else f"{a.lazy_type.cpp_type()} {a.name};"
                  for a in scalar_args
              ]
          )

          # Each optional value operand gets a one-bit `has_<name>` member that
          # the ctor body sets from the optional's truthiness.
          optional_values = [
              arg.name for arg in value_args if isinstance(arg.lazy_type, OptionalCType)
          ]
          has_optional_decls = "\n ".join(
              [f"bool has_{value}: 1;" for value in optional_values]
          )
          has_optional_defs = "\n ".join(
              [f"has_{value} = !!{value};" for value in optional_values]
          )

          members_to_string = []
          for arg in scalar_args:
              if isinstance(arg.lazy_type, OptionalCType):
                  members_to_string.append(
                      "\n".join(
                          [
                              f"if ({arg.name}.has_value()) {{",
                              f'ss << ", {arg.name}=" << {arg.name}.value();',
                              "} else {",
                              f'ss << ", {arg.name}=null";',
                              "}",
                          ]
                      )
                  )
              else:
                  members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
          members_to_string_str = "\n ".join(members_to_string)

          class_lines = [
              f"class {schema.node_name} : public {self.node_base} {{",
              "public:",
              "static torch::lazy::OpKind ClassOpKind() {",
              f"return torch::lazy::OpKind({aten_symbol(schema)});",
              "}",
              f"{schema.node_name}({node_ctor_args}, std::vector<torch::lazy::Shape>&& shapes)",
              f": {self.node_base_ctor_call(schema)}{comma_if_scalar_initializers}",
              scalar_initializers,
              "{",
              has_optional_defs,
              "}",
              "std::string ToString() const override {",
              "std::stringstream ss;",
              f"ss << {self.node_base}::ToString();",
              members_to_string_str,
              "return ss.str();",
              "}",
              self.can_be_reused_function(f, node_ctor_args),
              self.lowering_function(f),
              scalar_decls,
              has_optional_decls,
              "};",
          ]
          return ["\n".join(class_lines) + "\n"]
  @dataclass(frozen=True)
  class GenTSLazyIR(GenLazyIR):
      """IR node generator for the TorchScript lazy backend.

      Adds a ``Lower`` override built from ``ts_lowering_body`` and a
      ``CanBeReused`` that compares the node's operands and scalar members
      against candidate constructor arguments.
      """

      def lowering_function(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
          """Return the ``Lower`` override whose body comes from ``ts_lowering_body``."""
          return "\n".join(
              [
                  "torch::lazy::TSOpVector Lower(std::shared_ptr<torch::jit::GraphFunction> function,",
                  "torch::lazy::TSLoweringContext* loctx) const override {",
                  ts_lowering_body(f),
                  "}",
              ]
          )

      def can_be_reused_function(
          self, f: Union[NativeFunctionsGroup, NativeFunction], node_ctor_args: str
      ) -> str:
          """Return a ``CanBeReused`` method comparing this node to ctor arguments.

          Value operands are matched with ``operand(i++)`` in the order
          positional values, then keyword values (assumes the node stores its
          operands in that order — TODO confirm against ``filtered_args``).
          Optional value operands, positional or keyword, are compared via
          ``value_or(kNullValue)``, which is how ``node_base_ctor_call`` passes
          them to the base constructor. Scalars are compared with the stored
          ``this->`` members.
          """
          func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
          schema = LazyIrSchema(func)

          def compare_value(arg: LazyArgument) -> str:
              if isinstance(arg.lazy_type, OptionalCType):
                  return f"operand(i++) == {arg.name}.value_or(kNullValue)"
              return f"operand(i++) == {arg.name}"

          def compare_scalar(arg: LazyArgument) -> str:
              return f"this->{arg.name} == {arg.name}"

          value_comparison = [
              *map(compare_value, schema.positional_values),
              *map(compare_scalar, schema.positional_scalars),
              *map(compare_value, schema.keyword_values),
              *map(compare_scalar, schema.keyword_scalars),
          ]
          # With nothing to compare there is nothing that distinguishes two
          # candidates; emit `true` instead of the ill-formed `return ();`.
          value_comparison_str = " &&\n ".join(value_comparison) or "true"
          return "\n".join(
              [
                  f"bool CanBeReused({node_ctor_args}) const {{",
                  "size_t i = 0;",
                  f"return ({value_comparison_str});",
                  "}",
              ]
          )
  @dataclass(frozen=True)
  class GenLazyNativeFuncDefinition:
      """Generate the C++ definition of a lazy-tensor native function kernel.

      The generated body runs, in order: the optional forced eager fallback,
      the metrics counter, the common-device lookup, unwrapping of inputs into
      lazy tensors / IR values, IR node reuse-or-build (with shape inference),
      and wrapping of the result back into at::Tensor(s). Most string fields
      name backend-specific C++ helpers that the generated code calls.
      """

      # C++ class the kernels are defined on (`<class_method_name>::<kernel>`).
      class_method_name: str
      backend_index: BackendIndex
      # Not referenced by the methods below.
      tensor_class: str
      # When true, emit the force_eager_fallback guard at the top of each kernel.
      gen_forced_fallback_code: bool
      # Namespace qualifying the helper functions named by the fields below.
      backend_namespace: str
      get_tensorlist: str
      get_tensor_or_wrap_number: str
      try_get_tensor: str
      # Statement emitted (followed by `;`) to bump a metrics counter.
      metrics_counter: str
      create_tensor: str
      # When true, lazy tensors are created via a method on the first tensor arg.
      create_from_first_tensor: bool
      create_aten_from_ltc_tensor: str
      tuple_aten_from_ltc_tensors: str
      # C++ type of the lazy tensor pointers declared in the kernel body.
      lazy_tensor_ptr: str
      get_device_fn: str

      def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return C++ declarations unwrapping each value argument.

          Emits ``node_<name>`` for wrapped scalars, ``lazy_<name>_tensorlist``
          for tensor lists and ``lazy_<name>`` for (optional) tensors. SymInt
          arguments get no declaration.

          Raises:
              AssertionError: for a value argument of an unhandled type.
          """
          value_args = schema.filtered_args(values=True, scalars=False)
          lazy_tensor_decls: List[str] = []
          for arg in value_args:
              if arg.is_wrapped_scalar:
                  if isinstance(arg.lazy_type, OptionalCType):
                      decl = (
                          f"auto node_{arg.name} = {arg.name} ?\n"
                          "c10::make_optional(torch::lazy::LazyGraphExecutor::Get()"
                          f"->GetIrValueForScalarFromCodegen(*{arg.name})):\n"
                          "c10::nullopt;"
                      )
                  else:
                      decl = (
                          f"auto node_{arg.name} =\n"
                          "torch::lazy::LazyGraphExecutor::Get()"
                          f"->GetIrValueForScalarFromCodegen({arg.name});"
                      )
              elif arg.is_symint_or_list:
                  # SymInt values are read directly by node_ctor_arg_rvalue_string.
                  continue
              elif isinstance(arg.lazy_type, BaseCType):
                  if arg.lazy_type.type is tensorListValueT:
                      decl = (
                          f"auto lazy_{arg.name}_tensorlist = "
                          f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
                      )
                  else:
                      decl = (
                          f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                          f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
                      )
              elif isinstance(arg.lazy_type, OptionalCType):
                  # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
                  # until we encounter a real world example.
                  decl = (
                      f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
                      f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
                  )
              else:
                  raise AssertionError(
                      f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
                  )
              lazy_tensor_decls.append(decl)
          return "\n ".join(lazy_tensor_decls)

      def force_eager_fallback(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return the eager-fallback guard, or "" when fallback code is disabled."""
          if self.gen_forced_fallback_code:
              return gen_fallback_code(schema, overload_name=func.func.name.overload_name)
          return ""

      def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return the configured metrics-counter statement."""
          return f"{self.metrics_counter};"

      def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return C++ declaring ``common_device`` and asserting it is set.

          The device is computed by ``get_device_fn`` over the non-wrapped-scalar
          value arguments followed by any ``Device?`` scalar arguments.

          Raises:
              AssertionError: if there is neither a value nor a device argument.
          """
          value_args = schema.filtered_args(values=True, scalars=False)
          scalar_args = schema.filtered_args(values=False, scalars=True)
          value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
          optional_device = OptionalCType(BaseCType(deviceT))
          optional_devices = [
              a.name for a in scalar_args if a.lazy_type == optional_device
          ]
          assert (
              len(value_types_names) > 0 or len(optional_devices) > 0
          ), "Expected at least one Value or Device type"
          device_args = ", ".join(value_types_names + optional_devices)
          return (
              f"auto common_device = {self.get_device_fn}({device_args});\n"
              "TORCH_INTERNAL_ASSERT(common_device);\n"
          )

      def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return C++ that declares ``shapes`` for the op's outputs.

          Structured kernels (or kernels with a structured delegate) call the
          ``at::meta`` function; others call the generated
          ``compute_shape_<kernel>`` function. When symbolic shapes are enabled
          at runtime, ``applySymbolicShapesOnLT`` is then applied with the
          op's schema string.
          """
          metadata = self.backend_index.get_kernel(func)
          assert metadata is not None
          all_args = schema.filtered_args()
          returns_length = len(schema.returns)
          arg_names = ", ".join(str(a.name) for a in all_args)

          # call the meta kernel if it exists, to compute output shape/dtype for our IR
          if func.structured or func.structured_delegate is not None:
              if returns_length > 1:
                  shapes_str = ",".join(
                      f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
                      for i in range(returns_length)
                  )
                  meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
              else:
                  meta_out = (
                      "std::vector<torch::lazy::Shape> shapes{\n"
                      "torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"
                  )
              shape_str = (
                  f"auto out_meta = at::meta::{schema.aten_name}({arg_names});\n"
                  f"{meta_out}"
              )
          else:
              shape_sig = ComputeShapeSignature(metadata.kernel, func)
              shape_str = f"\nauto shapes = {shape_sig.shape_call};"

          shape_str += f"\nTORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"

          # Calculating which dimensions are symbolic. The schema text is
          # emitted as a C++ string literal: escape backslashes and double
          # quotes in case it contains them (e.g. string default values), and
          # bind it to `const char*` — a string literal cannot initialize a
          # non-const `char*` in C++11 and later.
          func_schema_str = "aten::" + str(func.func)
          escaped_schema = func_schema_str.replace("\\", "\\\\").replace('"', '\\"')
          shape_str += (
              "\nif(torch::lazy::symbolicShapeEnabled()){\n"
              f"std::vector<torch::jit::IValue> inputs = {{ {arg_names} }};\n"
              f'const char* schema_str = "{escaped_schema}";\n'
              "applySymbolicShapesOnLT(schema_str, inputs, shapes);\n"
              "}\n"
          )
          return shape_str

      def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return C++ that reuses a cached IR node or builds and caches a new one."""
          node_ctor_input_str = node_ctor_inputs(schema)
          return "\n".join(
              [
                  f"torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});",
                  "if (!node) {",
                  self.shape_inference(func, schema),
                  f"node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));",
                  "CacheNode(node);",
                  "}",
                  "",
              ]
          )

      def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
          """Return the C++ callee used to create a lazy tensor.

          Raises:
              AssertionError: if creation goes through the first tensor argument
                  and ``first_tensor_name`` is None.
          """
          # xla uses an instance method for tensor creation, for the time being
          if self.create_from_first_tensor:
              # TODO(whc) remove this if XLA switches to using static method for creation
              assert (
                  first_tensor_name is not None
              ), "Requires first tensor to create lazy tensor"
              return f"{first_tensor_name}.{self.create_tensor}"
          return f"{self.backend_namespace}::{self.create_tensor}"

      def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
          """Return C++ that wraps ``node`` into the kernel's result and returns it.

          Single outputs are wrapped with ``create_aten_from_ltc_tensor``;
          multiple outputs with ``tuple_aten_from_ltc_tensors``. In-place and
          out= variants instead update the first tensor argument's IR value and
          return that argument.

          Raises:
              AssertionError: for tuple outputs without a tensor argument, for
                  in-place/out variants with tuple outputs, or for in-place/out
                  variants without a tensor argument.
          """
          returns_length = len(schema.returns)
          value_args = schema.filtered_args(values=True, scalars=False)
          value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
          first_tensor_name = value_types_names[0] if value_types_names else None
          bridge_str = (
              f"auto result = {self.create_aten_from_ltc_tensor}(\n"
              f"{self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"
          )
          if returns_length > 1:
              assert (
                  len(value_types_names) > 0
              ), "Code below assumes there is at least one tensor arg"
              lazy_ctor = self.create_lazy_tensor(first_tensor_name)
              bridge_str = (
                  f"std::vector<{self.lazy_tensor_ptr}> lazy_tensors;\n"
                  f"for (int i = 0; i < {returns_length}; i++) {{\n"
                  f"lazy_tensors.push_back({lazy_ctor}({getValueT()}(node, i), *common_device));\n"
                  "}\n"
                  f"auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"
              )
          if schema.name.name.inplace or func.func.is_out_fn():
              assert returns_length == 1, (
                  "We assumed there was no such case where an op is an in-place variant "
                  f"and has tuple outputs, but got tuple of len {returns_length}."
              )
              # Without a tensor argument the template below would reference
              # `lazy_None`; fail at codegen time instead.
              assert (
                  first_tensor_name is not None
              ), "In-place/out variants require a tensor argument to update in place"
              bridge_str = (
                  f"lazy_{first_tensor_name}->SetInPlaceIrValue(node);\n"
                  f"auto& result = {first_tensor_name};"
              )
          bridge_str += "\nreturn result;"
          return bridge_str

      @method_with_native_function
      def __call__(self, func: NativeFunction) -> List[str]:
          """Return a one-element list holding the kernel definition for ``func``."""
          sig = kernel_signature(func, self.backend_index)
          metadata = self.backend_index.get_kernel(func)
          assert metadata is not None
          schema = LazyIrSchema(func.func)
          decl = sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")
          return [
              "\n".join(
                  [
                      f"{decl} {{",
                      self.force_eager_fallback(func, schema),
                      self.metrics(func, schema),
                      self.get_device(func, schema),
                      self.lazy_tensor_decls(func, schema),
                      self.build_ir_node(func, schema),
                      self.return_aten_tensor(func, schema),
                      "};",
                      "",
                      "",
                  ]
              )
          ]
  class ComputeShapeSignature:
      """Build the declaration and the call of a ``compute_shape_<kernel>`` function.

      The caller-supplied kernel (base) name is used as the suffix, which is
      intended to avoid generating separate shape functions for in-place
      variants. The declaration takes the dispatcher-level arguments; the call
      passes the schema's argument names, generator included.
      """

      def __init__(self, kernel_name: str, f: NativeFunction):
          self.__schema = LazyIrSchema(f.func)
          dispatcher_decls = [a.decl() for a in dispatcher.arguments(f.func)]
          self.__dispatch_args = ", ".join(dispatcher_decls)
          call_names = [
              f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)
          ]
          self.__call_args = ", ".join(call_names)
          self.__kernel_name = kernel_name

      @property
      def shape_decl(self) -> str:
          """``TORCH_API`` declaration of the shape function (no trailing ``;``)."""
          return (
              "TORCH_API std::vector<torch::lazy::Shape> compute_shape_"
              f"{self.__kernel_name}({self.__dispatch_args})"
          )

      @property
      def shape_call(self) -> str:
          """Fully qualified call expression for the shape function."""
          return f"torch::lazy::compute_shape_{self.__kernel_name}({self.__call_args})"
  @dataclass(frozen=True)
  class GenLazyShapeInferenceDefinition:
      """Generate ``compute_shape_*`` declarations for non-structured kernels.

      Structured kernels, and kernels with a structured delegate, get no
      declaration: their output shapes come from the meta function instead.
      """

      # Backend dispatch metadata used to look up each function's kernel name.
      backend_index: BackendIndex
      # Not referenced by __call__; kept so existing constructors keep working.
      tensor_class: str

      @method_with_native_function
      def __call__(self, f: NativeFunction) -> List[str]:
          """Return ``["<shape decl>;"]`` for a non-structured kernel, else ``[]``.

          Raises:
              AssertionError: if the backend has no kernel registered for ``f``.
          """
          metadata = self.backend_index.get_kernel(f)
          assert metadata is not None
          # Only generate shape/dtype fn for non-structured kernels,
          # since we just use the meta function for structured kernels
          if f.structured or f.structured_delegate is not None:
              return []
          shape_sig = ComputeShapeSignature(metadata.kernel, f)
          return [f"{shape_sig.shape_decl};"]