proxy.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. import dis
  2. import torch
  3. import inspect
  4. import operator
  5. import traceback
  6. from .graph import magic_methods, reflectable_magic_methods, Graph
  7. from typing import Tuple, Dict, Optional, Iterable, Any, Iterator, Callable
  8. from .node import Target, Node, Argument, base_types, map_aggregate
  9. from ._compatibility import compatibility
  10. from .operator_schemas import check_for_mutable_operation
  11. @compatibility(is_backward_compatible=True)
  12. class TracerBase:
  13. graph: Graph
  14. record_stack_traces : bool = False
  15. # Feature flag for mutable schema checking
  16. # Enableby default in 1.12
  17. check_mutable_operations : bool = False
  18. # Feature flag for assert tracing
  19. trace_asserts : bool = False
  20. # Feature flag for proxying accesses to buffer values
  21. proxy_buffer_attributes : bool = False
  22. # Name of the function to be traced. It will only be used when
  23. # ``root`` is an instance of ``nn.Module``
  24. traced_func_name: str = "forward"
  25. @compatibility(is_backward_compatible=True)
  26. def create_node(self, kind : str, target : Target,
  27. args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
  28. type_expr : Optional[Any] = None) -> Node:
  29. """
  30. Inserts a graph node given target, args, kwargs, and name.
  31. This method can be overridden to do extra checking, validation, or
  32. modification of values used in node creation. For example, one might
  33. want to disallow in-place operations from being recorded.
  34. """
  35. if kind == 'call_function' and self.check_mutable_operations:
  36. check_for_mutable_operation(target, args, kwargs)
  37. return self.graph.create_node(kind, target, args, kwargs, name, type_expr)
  38. @compatibility(is_backward_compatible=True)
  39. def proxy(self, node: Node) -> 'Proxy':
  40. return Proxy(node, self)
  41. @compatibility(is_backward_compatible=True)
  42. def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
  43. name: Optional[str] = None, type_expr : Optional[Any] = None,
  44. proxy_factory_fn: Callable[[Node], 'Proxy'] = None):
  45. '''
  46. Create a Node from the given arguments, then return the Node
  47. wrapped in a Proxy object.
  48. If kind = 'placeholder', then we're creating a Node that
  49. represents the parameter of a function. If we need to encode
  50. a default parameter, we use the ``args`` tuple. ``args`` is
  51. otherwise empty for ``placeholder`` Nodes.
  52. '''
  53. args_ = self.create_arg(args)
  54. kwargs_ = self.create_arg(kwargs)
  55. assert isinstance(args_, tuple)
  56. assert isinstance(kwargs_, dict)
  57. node = self.create_node(kind, target, args_, kwargs_, name, type_expr)
  58. if not proxy_factory_fn:
  59. proxy = self.proxy(node)
  60. else:
  61. proxy = proxy_factory_fn(node)
  62. # Optionally set stack trace on the created Node for debugging purposes
  63. if self.record_stack_traces:
  64. user_frame = self._find_user_frame()
  65. if user_frame:
  66. walk_stack_gen = traceback.walk_stack(user_frame)
  67. summary = traceback.StackSummary.extract(walk_stack_gen) # type: ignore[arg-type]
  68. tb_lines = summary.format()
  69. proxy.node.stack_trace = ''.join(tb_lines)
  70. return proxy
  71. def _find_user_frame(self):
  72. """
  73. Find the Python stack frame executing the user code during
  74. symbolic tracing.
  75. """
  76. # We have to do a little dance here. Basically, walk up the callstack and
  77. # record the first frame not in the FX source. This is the frame executing
  78. # the user code during tracing.
  79. frame = inspect.currentframe()
  80. fx_files = ['torch/fx/proxy.py', 'torch/fx/symbolic_trace.py']
  81. while frame:
  82. frame = frame.f_back
  83. if frame and all(not frame.f_code.co_filename.endswith(file) for file in fx_files):
  84. break
  85. if not frame:
  86. return None
  87. return frame
  88. @compatibility(is_backward_compatible=True)
  89. def create_arg(self, a: Any) -> Argument:
  90. """
  91. A method that lowers the objects seen as arguments during symbolic evaluation
  92. into Argument types that can be stored in IR.
  93. Can be override to support more trace-specific types.
  94. """
  95. if not isinstance(a, Proxy) and hasattr(a, '__fx_create_arg__'):
  96. return a.__fx_create_arg__(self)
  97. # aggregates
  98. elif isinstance(a, tuple) and hasattr(a, '_fields'):
  99. # NamedTuple constructors don't seem to like getting a generator
  100. # expression as an argument to their constructor, so build this
  101. # intermediate tuple and unpack it into the NamedTuple constructor
  102. args = tuple(self.create_arg(elem) for elem in a)
  103. return type(a)(*args) # type: ignore[arg-type]
  104. elif isinstance(a, (tuple, list)):
  105. return type(a)(self.create_arg(elem) for elem in a)
  106. elif isinstance(a, dict):
  107. r = {}
  108. for k, v in a.items():
  109. # Check for invalid dict keys. We do not want a Proxy to appear
  110. # anywhere within the key. Since keys can be collection types,
  111. # we iterate through the key with map_aggregate
  112. k = self.create_arg(k)
  113. def no_node(arg):
  114. if isinstance(arg, Node):
  115. raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
  116. "Node. Got key: {k}")
  117. map_aggregate(k, no_node)
  118. r[k] = self.create_arg(v)
  119. return r
  120. elif isinstance(a, slice):
  121. return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
  122. if isinstance(a, Proxy):
  123. # base case: we unwrap the Proxy object
  124. return a.node
  125. elif isinstance(a, base_types) or a is None or a is ...:
  126. return a
  127. raise NotImplementedError(f"argument of type: {type(a)}")
  128. @compatibility(is_backward_compatible=True)
  129. def to_bool(self, obj: 'Proxy') -> bool:
  130. """Called when a proxy object is being converted to a boolean, such as
  131. when used in control flow. Normally we don't know what to do because
  132. we don't know the value of the proxy, but a custom tracer can attach more
  133. information to the graph node using create_node and can choose to return a value.
  134. """
  135. raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
  136. @compatibility(is_backward_compatible=True)
  137. def iter(self, obj: 'Proxy') -> Iterator:
  138. """Called when a proxy object is being iterated over, such as
  139. when used in control flow. Normally we don't know what to do because
  140. we don't know the value of the proxy, but a custom tracer can attach more
  141. information to the graph node using create_node and can choose to return an iterator.
  142. """
  143. raise TraceError('Proxy object cannot be iterated. This can be '
  144. 'attempted when the Proxy is used in a loop or'
  145. ' as a *args or **kwargs function argument. '
  146. 'See the torch.fx docs on pytorch.org for a '
  147. 'more detailed explanation of what types of '
  148. 'control flow can be traced, and check out the'
  149. ' Proxy docstring for help troubleshooting '
  150. 'Proxy iteration errors')
  151. @compatibility(is_backward_compatible=True)
  152. def keys(self, obj: 'Proxy') -> Any:
  153. """Called when a proxy object is has the keys() method called.
  154. This is what happens when ** is called on a proxy. This should return an
  155. iterator it ** is suppose to work in your custom tracer.
  156. """
  157. return Attribute(obj, 'keys')()
  158. # used in Proxy object when just appending to the graph while not tracing.
  159. @compatibility(is_backward_compatible=True)
  160. class GraphAppendingTracer(TracerBase):
  161. def __init__(self, graph: Graph):
  162. super().__init__()
  163. self.graph = graph
@compatibility(is_backward_compatible=False)
def assert_fn(x):
    """
    Assert that ``x`` is truthy.

    ``Proxy.__bool__`` records a ``call_function`` node targeting this
    function when assert tracing is enabled and the boolean conversion is
    recognised as part of an ``assert`` statement.
    """
    assert x
  167. @compatibility(is_backward_compatible=True)
  168. class TraceError(ValueError):
  169. pass
@compatibility(is_backward_compatible=True)
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap your own ``Proxy``
    method around a raw ``Node`` so that you can use the overloaded
    operators to add additional things to a ``Graph``.

    ``Proxy`` objects cannot be iterated. In other words, the symbolic
    tracer will throw an error if a ``Proxy`` is used in a loop or as
    an ``*args``/``**kwargs`` function argument.

    There are two main ways around this:

    1. Factor out the untraceable logic into a top-level function and
       use ``fx.wrap`` on it.
    2. If the control flow is static (i.e. the loop trip count is
       based on some hyperparameter), the code can be kept in its original
       position and refactored into something like::

           for i in range(self.some_hyperparameter):
               indexed_item = proxied_value[i]

    For a more detailed description into the Proxy internals, check out
    the "Proxy" section in `torch/fx/OVERVIEW.md`
    """

    @compatibility(is_backward_compatible=True)
    def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None):
        """
        Wrap ``node``.

        Args:
            node: the graph node this proxy stands for.
            tracer: the tracer that records operations on this proxy. When
                omitted, a ``GraphAppendingTracer`` over ``node.graph`` is used.
        """
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        """Return a lazy ``Attribute`` proxy for attribute ``k`` of this proxy."""
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __call__(self, *args, **kwargs) -> 'Proxy':
        """Record calling this proxy as a ``call_method`` node targeting ``__call__``."""
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

    def __iter__(self) -> Iterable['Proxy']:
        """
        Support tuple-unpacking of a proxy; defer any other iteration to the tracer.

        The bytecode instruction the caller is currently executing is
        inspected. If it is ``UNPACK_SEQUENCE``, the number of targets is
        known and one indexed proxy per target is yielded. Otherwise
        ``tracer.iter`` decides what happens.
        """
        # This must run directly in this method: ``f_back`` has to be the frame
        # that triggered the iteration.
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        # Maps the caller's byte offset to an instruction index assuming 2-byte
        # instructions.
        # NOTE(review): this mapping may not hold on CPython versions whose
        # bytecode contains inline cache entries -- confirm before relying on it.
        inst = list(dis.get_instructions(calling_frame.f_code))[calling_frame.f_lasti // 2]
        if inst.opname == 'UNPACK_SEQUENCE':
            # ``argval`` is the number of unpack targets.
            return (self[i] for i in range(inst.argval))  # type: ignore[index]

        return self.tracer.iter(self)

    def __bool__(self) -> bool:
        """
        Convert the proxy to a boolean.

        When the tracer's ``trace_asserts`` flag is set and the conversion is
        recognised as the condition of an ``assert`` statement, an
        ``assert_fn`` call is recorded and ``True`` is returned so that
        tracing continues past the assertion. In every other case the
        decision is delegated to ``tracer.to_bool``.
        """
        if self.tracer.trace_asserts:
            # check if this boolean is used in an assertion, bytecode pattern for assertions
            # is pretty stable for Python 3.7--3.9
            frame = inspect.currentframe()
            assert frame is not None
            calling_frame = frame.f_back
            assert calling_frame is not None
            insts = list(dis.get_instructions(calling_frame.f_code))
            # Index of the caller's current instruction (2-byte instructions assumed).
            cur = calling_frame.f_lasti // 2
            inst = insts[cur]

            if inst.opname == 'POP_JUMP_IF_TRUE':
                # ``first`` is the instruction right after the jump (start of the
                # failure path); ``last`` is the instruction just before the
                # jump target (end of the failure path).
                first = insts[cur + 1]
                assert inst.arg is not None
                last = insts[inst.arg // 2 - 1]
                starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError'
                                      or first.opname == 'LOAD_ASSERTION_ERROR')
                if starts_with_assert and last.opname == 'RAISE_VARARGS':
                    # Record the assertion in the graph and let tracing proceed
                    # as if it had passed.
                    self.tracer.create_proxy('call_function', assert_fn, (self,), {})
                    return True

        return self.tracer.to_bool(self)

    @compatibility(is_backward_compatible=True)
    def keys(self):
        """Delegate ``keys()`` to the tracer's ``keys`` hook."""
        return self.tracer.keys(self)

    def __len__(self):
        """Always raises: ``len`` on a proxy is not supported by default."""
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")

    @classmethod
    def __torch_function__(cls, orig_method, types, args=None, kwargs=None):
        """
        Record ``orig_method(*args, **kwargs)`` as a graph node.

        The tracer is taken from the proxies found in ``args``/``kwargs``.

        Raises:
            RuntimeError: if the proxies found belong to more than one tracer.
        """
        args = args if args else ()
        kwargs = kwargs if kwargs else {}

        # Dict used as a set of distinct tracers; only the keys matter.
        tracers : Dict[Any, None] = {}

        def find_tracer(a):
            if isinstance(a, cls):
                tracers[a.tracer] = None
        torch.fx.node.map_aggregate(args, find_tracer)
        torch.fx.node.map_aggregate(kwargs, find_tracer)

        if len(tracers) > 1:
            raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while '
                               f'trying to trace operations {orig_method}')
        tracer = next(iter(tracers.keys()))

        if isinstance(orig_method, torch._C.ScriptMethod):
            # Script methods are recorded as a method call on their owner, which
            # becomes the first argument.
            args = (orig_method.owner,) + args
            return tracer.create_proxy('call_method', orig_method.name, args, kwargs)
        if torch.overrides.is_tensor_method_or_property(orig_method):
            # Tensor methods/properties are recorded by name as method calls.
            return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            # Anything else is recorded as a call to the function object itself.
            return tracer.create_proxy('call_function', orig_method, args, kwargs,
                                       name=tracer.graph._target_to_str(orig_method.__name__))
  268. @compatibility(is_backward_compatible=True)
  269. class Attribute(Proxy):
  270. @compatibility(is_backward_compatible=True)
  271. def __init__(self, root: Proxy, attr: str):
  272. self.root = root
  273. self.attr = attr
  274. self.tracer = root.tracer
  275. self._node: Optional[Node] = None
  276. @property
  277. def node(self):
  278. # the node for attributes is added lazily, since most will just be method calls
  279. # which do not rely on the getitem call
  280. if self._node is None:
  281. self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
  282. return self._node
  283. def __call__(self, *args, **kwargs):
  284. return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
  285. @compatibility(is_backward_compatible=False)
  286. class ParameterProxy(Proxy):
  287. """
  288. A special proxy which lets "shape", "size", "dim", and a few other
  289. attribute accesses pass through to the underlying module parameter object,
  290. so that conditional tests on these attributes will not throw exception during tracing
  291. """
  292. def __init__(self, tracer: TracerBase, node: Node, name, param):
  293. super().__init__(node, tracer)
  294. assert(isinstance(param, torch.nn.Parameter))
  295. self.param = param
  296. self.name = name
  297. def __repr__(self) -> str:
  298. return f'ParameterProxy({self.name})'
  299. @property
  300. def shape(self):
  301. return self.param.shape
  302. def size(self):
  303. return self.param.size()
  304. def dim(self):
  305. return self.param.dim()
  306. @property
  307. def ndim(self):
  308. return self.param.ndim
  309. def numel(self):
  310. return self.param.numel()
  311. def nelement(self):
  312. return self.param.nelement()
  313. for method in magic_methods:
  314. def _scope(method):
  315. def impl(*args, **kwargs):
  316. tracer = args[0].tracer
  317. target = getattr(operator, method)
  318. return tracer.create_proxy('call_function', target, args, kwargs)
  319. impl.__name__ = method
  320. as_magic = f'__{method.strip("_")}__'
  321. setattr(Proxy, as_magic, impl)
  322. _scope(method)
  323. def _define_reflectable(orig_method_name):
  324. method_name = f'__r{orig_method_name.strip("_")}__'
  325. def impl(self, rhs):
  326. target = getattr(operator, orig_method_name)
  327. return self.tracer.create_proxy('call_function', target, (rhs, self), {})
  328. impl.__name__ = method_name
  329. impl.__qualname__ = method_name
  330. setattr(Proxy, method_name, impl)
  331. for orig_method_name in reflectable_magic_methods:
  332. _define_reflectable(orig_method_name)