context.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from torchgen.utils import S, T, context
  2. from torchgen.model import (
  3. NativeFunction,
  4. NativeFunctionsGroup,
  5. NativeFunctionsViewGroup,
  6. BackendIndex,
  7. DispatchKey,
  8. )
  9. import torchgen.local as local
  10. import functools
  11. from typing import TypeVar, Union, Iterator, Callable, Dict, Optional
  12. import contextlib
  13. # Helper functions for defining generators on things in the model
  14. F = TypeVar(
  15. "F",
  16. NativeFunction,
  17. NativeFunctionsGroup,
  18. NativeFunctionsViewGroup,
  19. Union[NativeFunction, NativeFunctionsGroup],
  20. Union[NativeFunction, NativeFunctionsViewGroup],
  21. )
  22. F2 = TypeVar(
  23. "F2",
  24. NativeFunction,
  25. NativeFunctionsGroup,
  26. Optional[NativeFunction],
  27. bool,
  28. )
  29. @contextlib.contextmanager
  30. def native_function_manager(
  31. g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
  32. ) -> Iterator[None]:
  33. if isinstance(g, NativeFunctionsGroup):
  34. # By default, we associate all errors with structured native functions
  35. # with the out variant. In some cases, it might be better to have
  36. # a more specific place to hang things; if so, use
  37. # native_function_manager again on the inside
  38. f = g.out
  39. elif isinstance(g, NativeFunctionsViewGroup):
  40. # We associate errors with the view operator
  41. f = g.view
  42. else:
  43. f = g
  44. with context(lambda: f"in native_functions.yaml line {f.loc}:\n {f.func}"):
  45. with local.parametrize(
  46. use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors
  47. ):
  48. yield
  49. # Given a function that operates on NativeFunction, wrap it into a new function
  50. # that sets some appropriate context managers for that native function.
  51. # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound
  52. # (you will get an error if we try to access the local variables without having
  53. # set them).
  54. def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
  55. @functools.wraps(func)
  56. def wrapper(f: F) -> T:
  57. with native_function_manager(f):
  58. return func(f)
  59. return wrapper
  60. def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
  61. @functools.wraps(func)
  62. def wrapper(f: F, f2: F2) -> T:
  63. # The first native_function is assumed to be the one with the appropriate context.
  64. with native_function_manager(f):
  65. return func(f, f2)
  66. return wrapper
  67. def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
  68. @functools.wraps(func)
  69. def wrapper(slf: S, f: F) -> T:
  70. with native_function_manager(f):
  71. return func(slf, f)
  72. return wrapper
  73. # Convenience decorator for functions that explicitly take in a BackendIndex,
  74. # instead of indirectly taking one in as a closure
  75. def with_native_function_and_index(
  76. func: Callable[[F, BackendIndex], T]
  77. ) -> Callable[[F, BackendIndex], T]:
  78. @functools.wraps(func)
  79. def wrapper(f: F, backend_index: BackendIndex) -> T:
  80. with native_function_manager(f):
  81. return func(f, backend_index)
  82. return wrapper
  83. # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
  84. def with_native_function_and_indices(
  85. func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
  86. ) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
  87. @functools.wraps(func)
  88. def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
  89. with native_function_manager(f):
  90. return func(f, backend_indices)
  91. return wrapper