lazy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. from typing import List, Union, Tuple, Optional
  2. from torchgen.model import (
  3. Type,
  4. BaseTy,
  5. BaseType,
  6. OptionalType,
  7. ListType,
  8. OperatorName,
  9. FunctionSchema,
  10. Return,
  11. TensorOptionsArguments,
  12. Argument,
  13. )
  14. from torchgen.api.types import (
  15. CType,
  16. BaseCppType,
  17. BaseCType,
  18. OptionalCType,
  19. NamedCType,
  20. deviceT,
  21. layoutT,
  22. VectorCType,
  23. boolT,
  24. longT,
  25. doubleT,
  26. ListCType,
  27. stringT,
  28. scalarT,
  29. scalarTypeT,
  30. memoryFormatT,
  31. SymIntT,
  32. )
  33. _valueT = None
  34. def getValueT() -> BaseCppType:
  35. global _valueT
  36. if not _valueT:
  37. raise NotImplementedError(
  38. "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
  39. )
  40. return _valueT
  41. def setValueT(val: BaseCppType) -> None:
  42. global _valueT
  43. _valueT = val
  44. # this is a bad hack. I need to refactor the data model to represent each arg in the schema as an object,
  45. # making it easier to represent special properties of an arg.
  46. tensorListValueT = BaseCppType("torch::lazy", "Value")
  47. def process_ir_type(
  48. typ: Type,
  49. ) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
  50. """
  51. This function takes a type from NativeFunctions and converts it for use with
  52. lazy tensor codegen.
  53. Type conversion for lazy currently consists of
  54. (1) changing at::Tensors into lazy::Values
  55. (2) wrapping everything in a BaseCType
  56. (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)
  57. (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, with which Lazy IR represents tensors.)
  58. There is special handling for Optional[Tensor] or List[Tensor], etc- hence 'tensor-like'
  59. This is incomplete- there are assertions in places that it's expected to need to add
  60. more types as the codegen is used with more operators.
  61. """
  62. if isinstance(typ, BaseType):
  63. if typ.name == BaseTy.Tensor:
  64. return BaseCType(getValueT())
  65. elif typ.name == BaseTy.Scalar:
  66. # at::scalar has special handling,
  67. # and is wrapped in an lazy::Value just like at::tensor
  68. return BaseCType(getValueT())
  69. elif typ.name == BaseTy.ScalarType:
  70. return BaseCType(scalarTypeT)
  71. elif typ.name == BaseTy.int:
  72. return BaseCType(longT)
  73. elif typ.name == BaseTy.SymInt:
  74. return BaseCType(getValueT())
  75. elif typ.name == BaseTy.bool:
  76. return BaseCType(boolT)
  77. elif typ.name == BaseTy.float:
  78. return BaseCType(doubleT)
  79. elif typ.name == BaseTy.str:
  80. return BaseCType(stringT)
  81. elif typ.name == BaseTy.Device:
  82. return BaseCType(deviceT)
  83. elif typ.name == BaseTy.Layout:
  84. return BaseCType(layoutT)
  85. elif typ.name == BaseTy.MemoryFormat:
  86. return BaseCType(memoryFormatT)
  87. else:
  88. raise AssertionError(f"TODO add support for type {repr(typ)}")
  89. elif isinstance(typ, OptionalType):
  90. return OptionalCType(process_ir_type(typ.elem))
  91. elif isinstance(typ, ListType):
  92. if str(typ.elem) == "Tensor?":
  93. # TODO(whc) is this actually correct? or should it use a Vector like above
  94. return ListCType(OptionalCType(BaseCType(getValueT())))
  95. elif str(typ.elem) == "Tensor":
  96. # this is a TensorList which comes in from GetTensorList as a Value
  97. return BaseCType(tensorListValueT)
  98. else:
  99. return VectorCType(process_ir_type(typ.elem))
  100. else:
  101. raise AssertionError(f"unrecognized type {repr(typ)}")
  102. def isValueType(typ: CType) -> bool:
  103. """
  104. Given a type, determine if it is a Value-like type. This is equivalent to
  105. being Tensor-like, but assumes the type has already been transformed.
  106. """
  107. if isinstance(typ, BaseCType):
  108. # I am regretting my naming conventions, but now we are wrapping at::scalar in
  109. # lazy value, while preserving other 'scalar' types as scalars in the IR
  110. return typ.type == getValueT() or typ.type == scalarT or typ.type == SymIntT
  111. elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
  112. return isValueType(typ.elem)
  113. return False
  114. def isSymIntType(typ: Type) -> bool:
  115. return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt
  116. def isWrappedScalarType(typ: Type) -> bool:
  117. """
  118. Given a type, determine if it is a c10::scalar which we will wrap in a lazy Value.
  119. Since we literally change the type from scalarT to valueT, information is lost.
  120. This function helps build a list of wrapped scalars to save that information
  121. """
  122. if isinstance(typ, BaseType):
  123. # I am regretting my naming conventions, but now we are wrapping at::scalar in
  124. # lazy value, while preserving other 'scalar' types as scalars in the IR
  125. return typ.name == BaseTy.Scalar
  126. elif isinstance(typ, (OptionalType, ListType)):
  127. return isWrappedScalarType(typ.elem)
  128. return False
  129. def isGeneratorType(typ: Type) -> bool:
  130. if isinstance(typ, BaseType):
  131. return typ.name == BaseTy.Generator
  132. elif isinstance(typ, (OptionalType)):
  133. return isGeneratorType(typ.elem)
  134. return False
  135. class LazyArgument:
  136. name: str
  137. orig_type: Type
  138. lazy_type_: Optional[CType]
  139. is_wrapped_scalar: bool
  140. is_generator: bool
  141. is_symint_or_list: bool
  142. # true if this argument is or contains a lazy IR value
  143. is_lazy_value: bool
  144. def __init__(self, arg: Argument):
  145. self.name = arg.name
  146. self.orig_type = arg.type
  147. self.is_optional = isinstance(arg.type, OptionalType)
  148. self.is_generator = isGeneratorType(arg.type)
  149. if self.is_generator:
  150. assert (
  151. self.is_optional
  152. ), "We expect all generators are optional since currently they are"
  153. # there is no handling for generators in TorchScript IR (or XLA)
  154. # so we fall back to eager if the (optional)generator has value, and otherwise
  155. # its null and safe to exclude from lazy IR
  156. self.lazy_type_ = None
  157. else:
  158. self.lazy_type_ = process_ir_type(arg.type)
  159. self.is_wrapped_scalar = isWrappedScalarType(arg.type)
  160. self.is_symint_or_list = isSymIntType(arg.type)
  161. self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type)
  162. @property
  163. def lazy_type(self) -> CType:
  164. assert (
  165. self.lazy_type_ is not None
  166. ), f"Attempted to access lazy_type for invalid argument {self.name}"
  167. return self.lazy_type_
  168. # Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
  169. # Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
  170. # but carries type information from a native FunctionSchema modified for use with IR nodes,
  171. # and preserving original argument names.
  172. class LazyIrSchema:
  173. # The name of the operator this function schema describes.
  174. name: "OperatorName"
  175. positional_args: Tuple[LazyArgument, ...]
  176. keyword_args: Tuple[LazyArgument, ...]
  177. # TODO: Need to handle collisions with argument names at some point
  178. returns: Tuple["Return", ...]
  179. # if this schema has a Generator arg, list its orig ctype/name but don't
  180. # build a LazyArgument since lazy IR doesn't support it
  181. generator_arg: Optional[NamedCType] = None
  182. def __init__(self, func: FunctionSchema):
  183. positional_args = []
  184. for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
  185. if arg_field == "self_arg" and func.arguments.self_arg is not None:
  186. arg = getattr(func.arguments, "self_arg").argument
  187. positional_args.append(LazyArgument(arg))
  188. elif getattr(func.arguments, arg_field) is not None:
  189. positional_args.extend(
  190. [LazyArgument(arg) for arg in getattr(func.arguments, arg_field)]
  191. )
  192. self.positional_args = tuple(positional_args)
  193. keyword_args = []
  194. for arg_field in [
  195. "pre_tensor_options_kwarg_only",
  196. "tensor_options",
  197. "post_tensor_options_kwarg_only",
  198. "out",
  199. ]:
  200. curr_args = getattr(func.arguments, arg_field)
  201. if curr_args is not None:
  202. if isinstance(curr_args, TensorOptionsArguments):
  203. curr_args = curr_args.all()
  204. for arg in curr_args:
  205. if isGeneratorType(arg.type):
  206. assert (
  207. self.generator_arg is None
  208. ), "We expect there is only one generator arg"
  209. self.generator_arg = NamedCType(arg.name, arg.type)
  210. keyword_args.extend([LazyArgument(arg) for arg in curr_args])
  211. self.keyword_args = tuple(keyword_args)
  212. self.name = func.name
  213. self.returns = func.returns
  214. @property
  215. def node_name(self) -> str:
  216. """
  217. Return camel-case version of op in node.
  218. Note: This function also appends any `overload_name` in the operation.
  219. For example, if the op is `bitwise_and.Tensor`, the returned name
  220. will be `BitwiseAndTensor`.
  221. """
  222. op_name = f"{self.name.name}_{self.name.overload_name}".lower()
  223. return "".join(word.capitalize() or "" for word in op_name.split("_"))
  224. @property
  225. def aten_name(self) -> str:
  226. return f"{self.name.name}"
  227. @property
  228. def base_name(self) -> str:
  229. return f"{self.name.name.base}"
  230. def filtered_args(
  231. self,
  232. positional: bool = True,
  233. keyword: bool = True,
  234. values: bool = True,
  235. scalars: bool = True,
  236. generator: bool = False,
  237. ) -> List[LazyArgument]:
  238. # This function maintains the sorted order of arguments but provides different filtered views.
  239. # Some parts of the code care about kwargs vs args (TS lowerings),
  240. # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
  241. # Generators are special cased, as they are needed for fallback/shape-inference but not supported
  242. # in TS lowerings and therefore also omitted from lazy IR.
  243. args: List[LazyArgument] = []
  244. if positional:
  245. args.extend(self.positional_args)
  246. if keyword:
  247. args.extend(self.keyword_args)
  248. if values and scalars and generator:
  249. return args
  250. elif values and scalars:
  251. return [a for a in args if not a.is_generator]
  252. elif values:
  253. return [a for a in args if a.is_lazy_value]
  254. elif scalars:
  255. return [
  256. a
  257. for a in args
  258. if not a.is_lazy_value and (generator or not a.is_generator)
  259. ]
  260. return []
  261. @property
  262. def positional_values(self) -> List[LazyArgument]:
  263. return self.filtered_args(
  264. positional=True, keyword=False, values=True, scalars=False
  265. )
  266. @property
  267. def positional_scalars(self) -> List[LazyArgument]:
  268. return self.filtered_args(
  269. positional=True, keyword=False, values=False, scalars=True
  270. )
  271. @property
  272. def keyword_values(self) -> List[LazyArgument]:
  273. return self.filtered_args(
  274. positional=False, keyword=True, values=True, scalars=False
  275. )
  276. @property
  277. def keyword_scalars(self) -> List[LazyArgument]:
  278. return self.filtered_args(
  279. positional=False, keyword=True, values=False, scalars=True
  280. )