interpreter.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459
  1. from .graph_module import GraphModule
  2. from .graph import Graph
  3. from .node import Argument, Node, Target, map_arg, map_aggregate
  4. from .proxy import Proxy
  5. from ._symbolic_trace import Tracer
  6. from ._compatibility import compatibility
  7. from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
  8. import inspect
  9. @compatibility(is_backward_compatible=True)
  10. class Interpreter:
  11. """
  12. An Interpreter executes an FX graph Node-by-Node. This pattern
  13. can be useful for many things, including writing code
  14. transformations as well as analysis passes.
  15. Methods in the Interpreter class can be overridden to customize
  16. the behavior of execution. The map of overrideable methods
  17. in terms of call hierarchy::
  18. run()
  19. +-- run_node
  20. +-- placeholder()
  21. +-- get_attr()
  22. +-- call_function()
  23. +-- call_method()
  24. +-- call_module()
  25. +-- output()
  26. Example:
  27. Suppose we want to swap all instances of ``torch.neg`` with
  28. ``torch.sigmoid`` and vice versa (including their ``Tensor``
  29. method equivalents). We could subclass Interpreter like so::
  30. class NegSigmSwapInterpreter(Interpreter):
  31. def call_function(self, target : Target,
  32. args : Tuple, kwargs : Dict) -> Any:
  33. if target == torch.sigmoid:
  34. return torch.neg(*args, **kwargs)
  35. return super().call_function(n)
  36. def call_method(self, target : Target,
  37. args : Tuple, kwargs : Dict) -> Any:
  38. if target == 'neg':
  39. call_self, *args_tail = args
  40. return call_self.sigmoid(*args_tail, **kwargs)
  41. return super().call_method(n)
  42. def fn(x):
  43. return torch.sigmoid(x).neg()
  44. gm = torch.fx.symbolic_trace(fn)
  45. input = torch.randn(3, 4)
  46. result = NegSigmSwapInterpreter(gm).run(input)
  47. torch.testing.assert_allclose(result, torch.neg(input).sigmoid())
  48. Args:
  49. module (GraphModule): The module to be executed
  50. garbage_collect_values (bool): Whether to delete values after their last
  51. use within the Module's execution. This ensures optimal memory usage during
  52. execution. This can be disabled to, for example, examine all of the intermediate
  53. values in the execution by looking at the ``Interpreter.env`` attribute.
  54. """
  55. @compatibility(is_backward_compatible=True)
  56. def __init__(self, module : GraphModule, garbage_collect_values : bool = True):
  57. assert isinstance(module, GraphModule)
  58. self.module = module
  59. self.submodules = dict(self.module.named_modules())
  60. self.env : Dict[Node, Any] = {}
  61. self.garbage_collect_values = garbage_collect_values
  62. if self.garbage_collect_values:
  63. # Run through reverse nodes and record the first instance of a use
  64. # of a given node. This represents the *last* use of the node in the
  65. # execution order of the program, which we will use to free unused
  66. # values
  67. node_to_last_use : Dict[Node, Node] = {}
  68. self.user_to_last_uses : Dict[Node, List[Node]] = {}
  69. def register_last_uses(n : Node, user : Node):
  70. if n not in node_to_last_use:
  71. node_to_last_use[n] = user
  72. self.user_to_last_uses.setdefault(user, []).append(n)
  73. for node in reversed(self.module.graph.nodes):
  74. map_arg(node.args, lambda n: register_last_uses(n, node))
  75. map_arg(node.kwargs, lambda n: register_last_uses(n, node))
  76. @compatibility(is_backward_compatible=True)
  77. def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any:
  78. """
  79. Run `module` via interpretation and return the result.
  80. Args:
  81. *args: The arguments to the Module to run, in positional order
  82. initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution.
  83. This is a dict mapping `Node` to any value. This can be used, for example, to
  84. pre-populate results for certain `Nodes` so as to do only partial evaluation within
  85. the interpreter.
  86. enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and
  87. process_outputs function first before using them.
  88. Returns:
  89. Any: The value returned from executing the Module
  90. """
  91. self.env = initial_env if initial_env else {}
  92. # Positional function args are consumed left-to-right by
  93. # `placeholder` nodes. Use an iterator to keep track of
  94. # position and extract those values.
  95. if enable_io_processing:
  96. args = self.module.graph.process_inputs(*args)
  97. self.args_iter : Iterator[Any] = iter(args)
  98. for node in self.module.graph.nodes:
  99. if node in self.env:
  100. # Short circuit if we have this value. This could
  101. # be used, for example, for partial evaluation
  102. # where the caller has pre-populated `env` with
  103. # values for a subset of the program.
  104. continue
  105. self.env[node] = self.run_node(node)
  106. if self.garbage_collect_values:
  107. for to_delete in self.user_to_last_uses.get(node, []):
  108. del self.env[to_delete]
  109. if node.op == 'output':
  110. output_val = self.env[node]
  111. return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val
  112. @compatibility(is_backward_compatible=True)
  113. def run_node(self, n : Node) -> Any:
  114. """
  115. Run a specific node ``n`` and return the result.
  116. Calls into placeholder, get_attr, call_function,
  117. call_method, call_module, or output depending
  118. on ``node.op``
  119. Args:
  120. n (Node): The Node to execute
  121. Returns:
  122. Any: The result of executing ``n``
  123. """
  124. args, kwargs = self.fetch_args_kwargs_from_env(n)
  125. assert isinstance(args, tuple)
  126. assert isinstance(kwargs, dict)
  127. return getattr(self, n.op)(n.target, args, kwargs)
  128. # Main Node running APIs
  129. @compatibility(is_backward_compatible=True)
  130. def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  131. """
  132. Execute a ``placeholder`` node. Note that this is stateful:
  133. ``Interpreter`` maintains an internal iterator over
  134. arguments passed to ``run`` and this method returns
  135. next() on that iterator.
  136. Args:
  137. target (Target): The call target for this node. See
  138. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  139. details on semantics
  140. args (Tuple): Tuple of positional args for this invocation
  141. kwargs (Dict): Dict of keyword arguments for this invocation
  142. Returns:
  143. Any: The argument value that was retrieved.
  144. """
  145. assert isinstance(target, str)
  146. if target.startswith('*'):
  147. # For a starred parameter e.g. `*args`, retrieve all
  148. # remaining values from the args list.
  149. return list(self.args_iter)
  150. else:
  151. try:
  152. return next(self.args_iter)
  153. except StopIteration as si:
  154. if len(args) > 0:
  155. return args[0]
  156. else:
  157. raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!')
  158. @compatibility(is_backward_compatible=True)
  159. def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  160. """
  161. Execute a ``get_attr`` node. Will retrieve an attribute
  162. value from the ``Module`` hierarchy of ``self.module``.
  163. Args:
  164. target (Target): The call target for this node. See
  165. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  166. details on semantics
  167. args (Tuple): Tuple of positional args for this invocation
  168. kwargs (Dict): Dict of keyword arguments for this invocation
  169. Return:
  170. Any: The value of the attribute that was retrieved
  171. """
  172. assert isinstance(target, str)
  173. return self.fetch_attr(target)
  174. @compatibility(is_backward_compatible=True)
  175. def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  176. """
  177. Execute a ``call_function`` node and return the result.
  178. Args:
  179. target (Target): The call target for this node. See
  180. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  181. details on semantics
  182. args (Tuple): Tuple of positional args for this invocation
  183. kwargs (Dict): Dict of keyword arguments for this invocation
  184. Return
  185. Any: The value returned by the function invocation
  186. """
  187. assert not isinstance(target, str)
  188. # Execute the function and return the result
  189. return target(*args, **kwargs)
  190. @compatibility(is_backward_compatible=True)
  191. def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  192. """
  193. Execute a ``call_method`` node and return the result.
  194. Args:
  195. target (Target): The call target for this node. See
  196. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  197. details on semantics
  198. args (Tuple): Tuple of positional args for this invocation
  199. kwargs (Dict): Dict of keyword arguments for this invocation
  200. Return
  201. Any: The value returned by the method invocation
  202. """
  203. # args[0] is the `self` object for this method call
  204. self_obj, *args_tail = args
  205. # Execute the method and return the result
  206. assert isinstance(target, str)
  207. return getattr(self_obj, target)(*args_tail, **kwargs)
  208. @compatibility(is_backward_compatible=True)
  209. def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  210. """
  211. Execute a ``call_module`` node and return the result.
  212. Args:
  213. target (Target): The call target for this node. See
  214. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  215. details on semantics
  216. args (Tuple): Tuple of positional args for this invocation
  217. kwargs (Dict): Dict of keyword arguments for this invocation
  218. Return
  219. Any: The value returned by the module invocation
  220. """
  221. # Retrieve executed args and kwargs values from the environment
  222. # Execute the method and return the result
  223. assert isinstance(target, str)
  224. submod = self.fetch_attr(target)
  225. return submod(*args, **kwargs)
  226. @compatibility(is_backward_compatible=True)
  227. def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  228. """
  229. Execute an ``output`` node. This really just retrieves
  230. the value referenced by the ``output`` node and returns it.
  231. Args:
  232. target (Target): The call target for this node. See
  233. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  234. details on semantics
  235. args (Tuple): Tuple of positional args for this invocation
  236. kwargs (Dict): Dict of keyword arguments for this invocation
  237. Return:
  238. Any: The return value referenced by the output node
  239. """
  240. return args[0]
  241. # Helper methods
  242. @compatibility(is_backward_compatible=True)
  243. def fetch_attr(self, target : str):
  244. """
  245. Fetch an attribute from the ``Module`` hierarchy of ``self.module``.
  246. Args:
  247. target (str): The fully-qualfiied name of the attribute to fetch
  248. Return:
  249. Any: The value of the attribute.
  250. """
  251. target_atoms = target.split('.')
  252. attr_itr = self.module
  253. for i, atom in enumerate(target_atoms):
  254. if not hasattr(attr_itr, atom):
  255. raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
  256. attr_itr = getattr(attr_itr, atom)
  257. return attr_itr
  258. @compatibility(is_backward_compatible=True)
  259. def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]:
  260. """
  261. Fetch the concrete values of ``args`` and ``kwargs`` of node ``n``
  262. from the current execution environment.
  263. Args:
  264. n (Node): The node for which ``args`` and ``kwargs`` should be fetched.
  265. Return:
  266. Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``.
  267. """
  268. args = self.map_nodes_to_values(n.args, n)
  269. assert isinstance(args, tuple)
  270. kwargs = self.map_nodes_to_values(n.kwargs, n)
  271. assert isinstance(kwargs, dict)
  272. return args, kwargs
  273. @compatibility(is_backward_compatible=True)
  274. def map_nodes_to_values(self, args : Argument, n : Node) -> Argument:
  275. """
  276. Recursively descend through ``args`` and look up the concrete value
  277. for each ``Node`` in the current execution environment.
  278. Args:
  279. args (Argument): Data structure within which to look up concrete values
  280. n (Node): Node to which ``args`` belongs. This is only used for error reporting.
  281. """
  282. def load_arg(n_arg : Node) -> Any:
  283. if n_arg not in self.env:
  284. raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() '
  285. f'to diagnose such issues')
  286. return self.env[n_arg]
  287. return map_arg(args, load_arg)
  288. @compatibility(is_backward_compatible=True)
  289. class Transformer(Interpreter):
  290. """
  291. ``Transformer`` is a special type of interpreter that produces a
  292. new ``Module``. It exposes a ``transform()`` method that returns
  293. the transformed ``Module``. ``Transformer`` does not require
  294. arguments to run, as ``Interpreter`` does. ``Transformer`` works
  295. entirely symbolically.
  296. Example:
  297. Suppose we want to swap all instances of ``torch.neg`` with
  298. ``torch.sigmoid`` and vice versa (including their ``Tensor``
  299. method equivalents). We could subclass ``Transformer`` like so::
  300. class NegSigmSwapXformer(Transformer):
  301. def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  302. if target == torch.sigmoid:
  303. return torch.neg(*args, **kwargs)
  304. return super().call_function(n)
  305. def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  306. if target == 'neg':
  307. call_self, *args_tail = args
  308. return call_self.sigmoid(*args_tail, **kwargs)
  309. return super().call_method(n)
  310. def fn(x):
  311. return torch.sigmoid(x).neg()
  312. gm = torch.fx.symbolic_trace(fn)
  313. transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform()
  314. input = torch.randn(3, 4)
  315. torch.testing.assert_allclose(transformed(input), torch.neg(input).sigmoid())
  316. Args:
  317. module (GraphModule): The ``Module`` to be transformed.
  318. """
  319. @compatibility(is_backward_compatible=True)
  320. def __init__(self, module):
  321. super().__init__(module)
  322. self.new_graph = Graph()
  323. self.new_graph.set_codegen(module.graph._codegen)
  324. class TransformerTracer(Tracer):
  325. def __init__(self, graph: Graph):
  326. super().__init__()
  327. self.graph = graph
  328. def is_leaf_module(self, _, __) -> bool:
  329. return True
  330. self.tracer = TransformerTracer(self.new_graph)
  331. self.tracer.root = module
  332. @compatibility(is_backward_compatible=True)
  333. def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
  334. """
  335. Execute a ``placeholder`` node. In ``Transformer``, this is
  336. overridden to insert a new ``placeholder`` into the output
  337. graph.
  338. Args:
  339. target (Target): The call target for this node. See
  340. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  341. details on semantics
  342. args (Tuple): Tuple of positional args for this invocation
  343. kwargs (Dict): Dict of keyword arguments for this invocation
  344. """
  345. assert isinstance(target, str)
  346. default_value = next(iter(args)) if args else inspect.Signature.empty
  347. return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer)
  348. @compatibility(is_backward_compatible=True)
  349. def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy:
  350. """
  351. Execute a ``get_attr`` node. In ``Transformer``, this is
  352. overridden to insert a new ``get_attr`` node into the output
  353. graph.
  354. Args:
  355. target (Target): The call target for this node. See
  356. `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
  357. details on semantics
  358. args (Tuple): Tuple of positional args for this invocation
  359. kwargs (Dict): Dict of keyword arguments for this invocation
  360. """
  361. assert isinstance(target, str)
  362. return Proxy(self.new_graph.get_attr(target), self.tracer)
  363. @compatibility(is_backward_compatible=True)
  364. def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  365. # Override so that the leaf module policy from `self.tracer` is respected.
  366. assert isinstance(target, str)
  367. submod = self.fetch_attr(target)
  368. return self.tracer.call_module(submod, submod.forward, args, kwargs)
  369. @compatibility(is_backward_compatible=True)
  370. def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any:
  371. # Override so that functions that were wrapped are still wrapped.
  372. return self.tracer.create_proxy('call_function', target, args, kwargs)
  373. @compatibility(is_backward_compatible=True)
  374. def transform(self) -> GraphModule:
  375. """
  376. Transform ``self.module`` and return the transformed
  377. ``GraphModule``.
  378. """
  379. result = super().run(enable_io_processing=False)
  380. if result is not None:
  381. def strip_proxy(a : Union[Argument, Proxy]) -> Any:
  382. return a.node if isinstance(a, Proxy) else a
  383. self.new_graph.output(map_aggregate(result, strip_proxy))
  384. return GraphModule(self.module, self.new_graph)