| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327 |
- from typing import List, Union, Tuple, Optional
- from torchgen.model import (
- Type,
- BaseTy,
- BaseType,
- OptionalType,
- ListType,
- OperatorName,
- FunctionSchema,
- Return,
- TensorOptionsArguments,
- Argument,
- )
- from torchgen.api.types import (
- CType,
- BaseCppType,
- BaseCType,
- OptionalCType,
- NamedCType,
- deviceT,
- layoutT,
- VectorCType,
- boolT,
- longT,
- doubleT,
- ListCType,
- stringT,
- scalarT,
- scalarTypeT,
- memoryFormatT,
- SymIntT,
- )
# The C++ type used to represent lazy IR values. It is left unset here and must
# be configured with setValueT() before any type conversion that needs it.
_valueT: Optional[BaseCppType] = None


def getValueT() -> BaseCppType:
    """
    Return the C++ type used to represent lazy IR values.

    Raises:
        NotImplementedError: if the value type has not been set with setValueT().
    """
    # Compare against None explicitly: "unset" is signalled by None, not by the
    # truthiness of whatever object was registered.
    if _valueT is None:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )
    return _valueT
def setValueT(val: BaseCppType) -> None:
    """
    Register ``val`` as the C++ type used for lazy IR values.

    The registered type is stored in the module-level ``_valueT`` and is what
    getValueT() hands back afterwards.
    """
    global _valueT
    _valueT = val
# HACK: a TensorList argument reaches lazy IR as a single value (it comes in
# from GetTensorList), so it gets this dedicated BaseCppType rather than a
# list-of-values type. A cleaner design would model each schema argument as an
# object able to carry special properties like this one.
tensorListValueT = BaseCppType(ns="torch::lazy", name="Value")
def process_ir_type(
    typ: Type,
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    Convert a type from NativeFunctions into the type used by lazy tensor codegen.

    Type conversion for lazy currently consists of:
    (1) changing at::Tensors into lazy::Values (which wrap lazy::Nodes, the
        Lazy IR representation of tensors). at::Scalar and SymInt are wrapped
        in a lazy::Value the same way;
    (2) wrapping everything in a BaseCType;
    (3) making cpp-reference types into cpp-value types (e.g. vector instead
        of IntArrayRef).

    There is special handling for "tensor-like" types such as Optional[Tensor]
    or List[Tensor]. The conversion is incomplete: unsupported types raise an
    AssertionError, and more types are expected to be added as the codegen is
    used with more operators.

    Raises:
        AssertionError: if ``typ`` (or an element type) is not supported.
        NotImplementedError: from getValueT() if a Value-like type is needed
            before the value type has been set.
    """
    if isinstance(typ, BaseType):
        # Tensor, Scalar and SymInt all become lazy values. at::Scalar has
        # special handling and is wrapped in a lazy::Value just like at::Tensor.
        if typ.name in (BaseTy.Tensor, BaseTy.Scalar, BaseTy.SymInt):
            return BaseCType(getValueT())
        # Every other supported base type maps 1:1 onto a fixed C++ type.
        plain_types = {
            BaseTy.ScalarType: scalarTypeT,
            BaseTy.int: longT,
            BaseTy.bool: boolT,
            BaseTy.float: doubleT,
            BaseTy.str: stringT,
            BaseTy.Device: deviceT,
            BaseTy.Layout: layoutT,
            BaseTy.MemoryFormat: memoryFormatT,
        }
        if typ.name in plain_types:
            return BaseCType(plain_types[typ.name])
        raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? Or should it use a VectorCType
            # like the generic list case below?
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")
def isValueType(typ: CType) -> bool:
    """
    Report whether an already-converted CType is Value-like.

    This is the post-conversion counterpart of being "tensor-like": a
    BaseCType qualifies when it holds the lazy value type, scalarT or SymIntT.
    Optional/List/Vector wrappers are looked through recursively, and any other
    type is not Value-like.
    """
    if isinstance(typ, BaseCType):
        # at::Scalar is wrapped in a lazy value, whereas the other "scalar"
        # types stay plain scalars in the IR (hence the naming confusion).
        value_like = (getValueT(), scalarT, SymIntT)
        return typ.type in value_like
    if isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem)
    return False
def isSymIntType(typ: Type) -> bool:
    """Return True only for the bare SymInt base type (wrappers do not count)."""
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.SymInt
def isWrappedScalarType(typ: Type) -> bool:
    """
    Report whether ``typ`` is (or wraps) a c10::Scalar that becomes a lazy Value.

    Converting scalarT to the value type loses the fact that the argument was
    a scalar, so this check lets callers record which arguments were wrapped.
    Optional and List wrappers are looked through.
    """
    # Strip container wrappers until we reach the innermost element type.
    inner = typ
    while isinstance(inner, (OptionalType, ListType)):
        inner = inner.elem
    return isinstance(inner, BaseType) and inner.name == BaseTy.Scalar
def isGeneratorType(typ: Type) -> bool:
    """Return True if ``typ`` is Generator, possibly wrapped in Optional layers."""
    # Only Optional wrappers are looked through; a List of generators is not matched.
    inner = typ
    while isinstance(inner, OptionalType):
        inner = inner.elem
    return isinstance(inner, BaseType) and inner.name == BaseTy.Generator
class LazyArgument:
    """
    A native-function Argument annotated with what lazy tensor codegen needs.

    It keeps the original name and type, the converted lazy IR type (absent
    for generators), and flags that classify the argument for code generation.
    """

    name: str
    orig_type: Type
    # Converted lazy IR type; None for generator arguments, which are excluded
    # from lazy IR. Read it through the ``lazy_type`` property.
    lazy_type_: Optional[CType]
    # True if the original type is an OptionalType.
    is_optional: bool
    # True if the original type is (or wraps) a c10::Scalar turned into a Value.
    is_wrapped_scalar: bool
    # True if the original type is a (possibly Optional) Generator.
    is_generator: bool
    # NOTE(review): despite the name, this is set from isSymIntType(), which only
    # matches a bare SymInt, not SymInt lists — confirm the intended semantics.
    is_symint_or_list: bool
    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument):
        self.name = arg.name
        self.orig_type = arg.type
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert (
                self.is_optional
            ), "We expect all generators are optional since currently they are"
            # There is no handling for generators in TorchScript IR (or XLA), so
            # we fall back to eager if the (optional) generator has a value;
            # otherwise it's null and safe to exclude from lazy IR.
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = isSymIntType(arg.type)
        # Short-circuit on generators: their lazy_type_ is None and the
        # lazy_type property would assert.
        self.is_lazy_value = not self.is_generator and isValueType(self.lazy_type)

    @property
    def lazy_type(self) -> CType:
        """
        The converted lazy IR type of this argument.

        Raises:
            AssertionError: if the argument has no lazy type (a generator).
        """
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# and preserves the original argument names.
class LazyIrSchema:
    """
    Schema of a Lazy IR node, derived from a native FunctionSchema.

    Arguments are split into positional and keyword LazyArguments, keeping
    schema order. The operator name and returns are copied from the schema.
    """

    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # If this schema has a Generator arg, record its original name/type here in
    # addition to its LazyArgument (which has no lazy type, since lazy IR does
    # not support generators).
    # NOTE(review): the NamedCType is built from the model Type (arg.type), not a
    # converted CType — confirm consumers expect that.
    generator_arg: Optional[NamedCType] = None

    def __init__(self, func: FunctionSchema):
        arguments = func.arguments

        positional_args: List[LazyArgument] = []
        for arg_field in ("pre_self_positional", "self_arg", "post_self_positional"):
            curr_args = getattr(arguments, arg_field)
            if curr_args is None:
                continue
            if arg_field == "self_arg":
                # self_arg wraps a single Argument in its `.argument` field.
                positional_args.append(LazyArgument(curr_args.argument))
            else:
                positional_args.extend(LazyArgument(arg) for arg in curr_args)
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in (
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ):
            curr_args = getattr(arguments, arg_field)
            if curr_args is None:
                continue
            if isinstance(curr_args, TensorOptionsArguments):
                curr_args = curr_args.all()
            # First pass: record the (single) generator argument, if any.
            for arg in curr_args:
                if isGeneratorType(arg.type):
                    assert (
                        self.generator_arg is None
                    ), "We expect there is only one generator arg"
                    self.generator_arg = NamedCType(arg.name, arg.type)
            # Second pass: wrap every argument, generators included.
            keyword_args.extend(LazyArgument(arg) for arg in curr_args)
        self.keyword_args = tuple(keyword_args)

        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return camel-case version of op in node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        # Empty segments (e.g. from an empty overload name) capitalize to "".
        return "".join(word.capitalize() for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        """The operator name as a string, without the overload name."""
        return f"{self.name.name}"

    @property
    def base_name(self) -> str:
        """The string form of the operator name's `base` attribute."""
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        """
        Return a filtered view of the arguments, preserving schema order.

        Args:
            positional: include positional arguments.
            keyword: include keyword arguments.
            values: include arguments that are or contain lazy IR values.
            scalars: include arguments that are not lazy IR values.
            generator: include generator arguments. These are never values, so
                this only has an effect when ``scalars`` is True.

        Returns:
            The selected arguments; empty if neither values nor scalars is
            requested.
        """
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value
        # or leave it alone. Generators are special cased, as they are needed for
        # fallback/shape-inference but not supported in TS lowerings and therefore
        # also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]
        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        """Positional arguments that are lazy IR values."""
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        """Positional non-value arguments, excluding generators."""
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        """Keyword arguments that are lazy IR values."""
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        """Keyword non-value arguments, excluding generators."""
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )
|