context.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import string
  2. from typing import Callable, Sequence, Any, Dict
  3. from itertools import chain
  4. import torch
  5. from torch.fx.graph import Graph, Node
  6. import torch.overrides
  7. from torch._prims.utils import TensorMeta
  8. import torch._refs as refs
  9. # TODO: automap torch operations to references
  10. # (need to throw a good assertion if the mapping doesn't exist)
  11. _torch_to_reference_map = {
  12. torch.add: refs.add,
  13. # torch.div: refs.div,
  14. torch.mul: refs.mul,
  15. torch.ge: refs.ge,
  16. torch.gt: refs.gt,
  17. torch.le: refs.le,
  18. torch.lt: refs.lt,
  19. }
  20. class PrimContext(torch.overrides.TorchFunctionMode):
  21. """
  22. The prototype prim tracing context.
  23. Example usage:
  24. import torch._prims.utils as utils
  25. from torch._prims.context import PrimContext
  26. from torch._prims.executor import execute
  27. from torch.overrides import push_torch_function_mode
  28. a = torch.randn((2, 2))
  29. b = torch.randn((2, 2))
  30. with push_torch_function_mode(PrimContext):
  31. meta_a = ctx.placeholder(utils.TensorMeta(a))
  32. meta_b = ctx.placeholder(utils.TensorMeta(b))
  33. result = torch.add(meta_a, meta_b)
  34. ctx.output(result)
  35. exc_result = execute(ctx, a, b)
  36. Currently this only acquires a trace of prims, and
  37. it does not account for control flow. As such,
  38. execute must be called with tensors that have the
  39. same metadata (dtype, device, shape...) as
  40. the tensors used to trace the operations.
  41. The tracing context's FX graph can be acquired
  42. using its graph attribute.
  43. """
  44. def __init__(self):
  45. self.graph = Graph()
  46. # Private attributes for generating names
  47. self._tensor_name_counter = 0
  48. self._dim_name_counter = 0
  49. self._shape_name_counter = 0
  50. self._lowercase = tuple(string.ascii_lowercase)
  51. self._uppercase = tuple(string.ascii_uppercase)
  52. @staticmethod
  53. def _create_name(idx, chars):
  54. name = ""
  55. while idx >= len(chars):
  56. name = chars[idx % len(chars)] + name
  57. idx = idx - len(chars)
  58. name = chars[idx] + name
  59. return name
  60. def _tensor_name(self):
  61. idx = self._tensor_name_counter
  62. self._tensor_name_counter = self._tensor_name_counter + 1
  63. return self._create_name(idx, self._lowercase)
  64. def _add_user(self, tm: TensorMeta, node: Node) -> None:
  65. assert tm.node is not None
  66. tm.node.users[node] = None
  67. def placeholder(self, a: Any):
  68. name = self._tensor_name()
  69. node = self.graph.placeholder(name)
  70. if isinstance(a, TensorMeta):
  71. if a.node is not None:
  72. raise ValueError("Attempting to reuse a TensorMeta in a new trace!")
  73. a.tname = name
  74. a.node = node
  75. return a
  76. def output(self, tm: TensorMeta):
  77. # TODO: allow other output types
  78. assert isinstance(tm, TensorMeta)
  79. node = self.graph.output(tm)
  80. self._add_user(tm, node)
  81. def __torch_function__(
  82. self,
  83. func: Callable,
  84. types: Sequence,
  85. args: Sequence[Any] = (),
  86. kwargs: Dict = None,
  87. ):
  88. """
  89. Determines which function to call. The order of which
  90. function is called is determined by:
  91. - func's "meta" attribute, if it exists
  92. - if func is a torch operation, its corresponding reference
  93. - func
  94. """
  95. if kwargs is None:
  96. kwargs = {}
  97. if hasattr(func, "meta"):
  98. # TODO: add check that all args/kwargs are 'registered' properly
  99. # to this trace
  100. output = func.meta(*args, **kwargs) # type: ignore[attr-defined]
  101. # Updates graph
  102. # TODO: handle outputs with multiple tensors
  103. # TODO: handle non-tensor outputs
  104. assert isinstance(output, TensorMeta)
  105. output_name = self._tensor_name()
  106. node = self.graph.create_node(
  107. "call_function", func, name=output_name, args=args, kwargs=kwargs
  108. )
  109. output.tname = output_name
  110. output.node = node
  111. # Marks uses
  112. for x in (
  113. x for x in chain(args, kwargs.values()) if isinstance(x, TensorMeta)
  114. ):
  115. self._add_user(x, node)
  116. return output
  117. # Remaps torch operations to their references
  118. if func in _torch_to_reference_map:
  119. fn = _torch_to_reference_map[func]
  120. with torch.overrides.enable_torch_function_mode(self, replace=self.inner):
  121. return fn(*args, **kwargs) # type: ignore[operator]
  122. return func(*args, **kwargs)