native.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from torchgen.model import (
  2. Argument,
  3. FunctionSchema,
  4. Return,
  5. SelfArgument,
  6. TensorOptionsArguments,
  7. Type,
  8. )
  9. from torchgen.api.types import (
  10. ArgName,
  11. BaseCType,
  12. Binding,
  13. ConstRefCType,
  14. NamedCType,
  15. CType,
  16. MutRefCType,
  17. ListCType,
  18. OptionalCType,
  19. tensorT,
  20. scalarT,
  21. layoutT,
  22. deviceT,
  23. boolT,
  24. scalarTypeT,
  25. )
  26. from torchgen.api import cpp
  27. from torchgen import local
  28. from torchgen.utils import assert_never
  29. from typing import Union, Sequence, List, Optional
  30. # This file describes the translation of JIT schema to the native functions API.
  31. # This looks a lot like the C++ API (which makes historical sense, because the
  32. # idea was you wrote native functions to implement functions in the C++ API),
  33. # but over time we have evolved the C++ API without actually changing our
  34. # native:: kernels. The intention is to make native API and dispatcher API
  35. # line up as closely as possible, since this results in the least overhead
  36. # (no translation is needed from dispatcher API to native API).
  37. def name(func: FunctionSchema) -> str:
  38. name = str(func.name.name)
  39. # TODO: delete this!
  40. if func.is_out_fn():
  41. name += "_out"
  42. if func.name.overload_name:
  43. name += f"_{func.name.overload_name}"
  44. return name
  45. def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
  46. if str(t) == "Tensor?":
  47. tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
  48. if mutable and not local.use_const_ref_for_mutable_tensors():
  49. return NamedCType(binds, MutRefCType(tensor_type))
  50. else:
  51. return NamedCType(binds, ConstRefCType(tensor_type))
  52. elif str(t) == "Tensor?[]":
  53. return NamedCType(
  54. binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
  55. )
  56. elif str(t) == "Scalar":
  57. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  58. elif str(t) == "Scalar?":
  59. return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
  60. return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
  61. def returns_type(rs: Sequence[Return]) -> CType:
  62. return cpp.returns_type(rs)
  63. def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
  64. return argumenttype_type(a.type, mutable=a.is_write, binds=binds)
  65. def argument(
  66. a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool
  67. ) -> List[Binding]:
  68. # Ideally, we NEVER default native functions. However, there are a number
  69. # of functions that call native:: directly and rely on the defaulting
  70. # existing. So for BC, we generate defaults for non-out variants (but not
  71. # for out variants, where it is impossible to generate an appropriate
  72. # default)
  73. should_default = not is_out
  74. if isinstance(a, Argument):
  75. default: Optional[str] = None
  76. if should_default and a.default is not None:
  77. default = cpp.default_expr(a.default, a.type)
  78. return [
  79. Binding(
  80. nctype=argument_type(a, binds=a.name),
  81. name=a.name,
  82. default=default,
  83. argument=a,
  84. )
  85. ]
  86. elif isinstance(a, SelfArgument):
  87. # Erase SelfArgument from the distinction
  88. return argument(a.argument, is_out=is_out)
  89. elif isinstance(a, TensorOptionsArguments):
  90. default = None
  91. if should_default:
  92. default = "{}"
  93. # TODO: Not sure why the arguments assigned here are for
  94. # TensorOptionsArguments and not the constituent pieces. It seems
  95. # to matter
  96. return [
  97. Binding(
  98. nctype=NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))),
  99. name="dtype",
  100. default=default,
  101. argument=a,
  102. ),
  103. Binding(
  104. nctype=NamedCType("layout", OptionalCType(BaseCType(layoutT))),
  105. name="layout",
  106. default=default,
  107. argument=a,
  108. ),
  109. Binding(
  110. nctype=NamedCType("device", OptionalCType(BaseCType(deviceT))),
  111. name="device",
  112. default=default,
  113. argument=a,
  114. ),
  115. Binding(
  116. nctype=NamedCType("pin_memory", OptionalCType(BaseCType(boolT))),
  117. name="pin_memory",
  118. default=default,
  119. argument=a,
  120. ),
  121. ]
  122. else:
  123. assert_never(a)
  124. def arguments(func: FunctionSchema) -> List[Binding]:
  125. args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
  126. args.extend(func.arguments.non_out)
  127. args.extend(func.arguments.out)
  128. return [r for arg in args for r in argument(arg, is_out=func.is_out_fn())]