dispatcher.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. from torchgen.model import (
  2. Argument,
  3. FunctionSchema,
  4. Return,
  5. SelfArgument,
  6. TensorOptionsArguments,
  7. Type,
  8. )
  9. from torchgen.api.types import ArgName, Binding, NamedCType, CType
  10. from torchgen.api import cpp
  11. from torchgen.utils import concatMap, assert_never
  12. import itertools
  13. from typing import Sequence, List, Union
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched
# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API so that the unboxed API
# better aligns with the boxed API. The dispatcher API hooks heavily
# into our template-based boxing/unboxing machinery, so changes
# to this convention will usually need template updates too.
  22. #
  23. # Prominent characteristics of the dispatcher API:
  24. #
  25. # - dtype, layout, device and pin_memory are represented as separate
  26. # arguments.
  27. #
  28. def name(func: FunctionSchema) -> str:
  29. return cpp.name(func)
  30. def argumenttype_type(
  31. t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
  32. ) -> NamedCType:
  33. # This is a faux amis. If it makes sense in the future to add
  34. # more special cases here, or invert things so cpp.argument_type
  35. # calls this, or just completely inline the function, please do
  36. # it.
  37. return cpp.argumenttype_type(
  38. t,
  39. mutable=mutable,
  40. binds=binds,
  41. remove_non_owning_ref_types=remove_non_owning_ref_types,
  42. )
  43. def argument_type(
  44. a: Argument, *, binds: ArgName, remove_non_owning_ref_types: bool = False
  45. ) -> NamedCType:
  46. return argumenttype_type(
  47. a.type,
  48. mutable=a.is_write,
  49. binds=binds,
  50. remove_non_owning_ref_types=remove_non_owning_ref_types,
  51. )
  52. def returns_type(rs: Sequence[Return]) -> CType:
  53. # At present, there is no difference. But there could be!
  54. return cpp.returns_type(rs)
  55. def jit_arguments(func: FunctionSchema) -> List[Argument]:
  56. def to_argument(
  57. a: Union[Argument, TensorOptionsArguments, SelfArgument]
  58. ) -> List[Argument]:
  59. if isinstance(a, Argument):
  60. return [a]
  61. elif isinstance(a, SelfArgument):
  62. return [a.argument]
  63. elif isinstance(a, TensorOptionsArguments):
  64. return [a.dtype, a.layout, a.device, a.pin_memory]
  65. else:
  66. assert_never(a)
  67. return list(
  68. concatMap(
  69. to_argument,
  70. itertools.chain(
  71. func.arguments.positional, func.arguments.kwarg_only, func.arguments.out
  72. ),
  73. )
  74. )
  75. def argument(a: Argument, *, remove_non_owning_ref_types: bool = False) -> Binding:
  76. return Binding(
  77. nctype=argument_type(
  78. a, binds=a.name, remove_non_owning_ref_types=remove_non_owning_ref_types
  79. ),
  80. name=a.name,
  81. argument=a,
  82. )
  83. def arguments(func: FunctionSchema) -> List[Binding]:
  84. return [argument(a) for a in jit_arguments(func)]