subgraph_rewriter.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. from .graph_module import GraphModule
  2. from .graph import Graph
  3. from .node import Node
  4. from ._symbolic_trace import symbolic_trace
  5. from ._compatibility import compatibility
  6. import copy
  7. from typing import Callable, Dict, List, NamedTuple, Optional, Set
  8. import torch
  9. @compatibility(is_backward_compatible=True)
  10. class Match(NamedTuple):
  11. # Node from which the match was found
  12. anchor: Node
  13. # Maps nodes in the pattern subgraph to nodes in the larger graph
  14. nodes_map: Dict[Node, Node]
  15. class _SubgraphMatcher:
  16. def __init__(self, pattern: Graph) -> None:
  17. self.pattern = pattern
  18. if len(pattern.nodes) == 0:
  19. raise ValueError("_SubgraphMatcher cannot be initialized with an "
  20. "empty pattern")
  21. # `self.pattern_anchor` is the output Node in `pattern`
  22. self.pattern_anchor = next(iter(reversed(pattern.nodes)))
  23. # Ensure that there is only a single output value in the pattern
  24. # since we don't support multiple outputs
  25. assert len(self.pattern_anchor.all_input_nodes) == 1, \
  26. "Pattern matching on multiple outputs is not supported"
  27. # Maps nodes in the pattern subgraph to nodes in the larger graph
  28. self.nodes_map: Dict[Node, Node] = {}
  29. def matches_subgraph_from_anchor(self, anchor: Node) -> bool:
  30. """
  31. Checks if the whole pattern can be matched starting from
  32. ``anchor`` in the larger graph.
  33. Pattern matching is done by recursively comparing the pattern
  34. node's use-def relationships against the graph node's.
  35. """
  36. self.nodes_map = {}
  37. return self._match_nodes(self.pattern_anchor, anchor)
  38. # Compare the pattern node `pn` against the graph node `gn`
  39. def _match_nodes(self, pn: Node, gn: Node) -> bool:
  40. # Check if we've already matched these nodes in the current
  41. # traversal
  42. if pn in self.nodes_map:
  43. return self.nodes_map[pn] == gn
  44. def attributes_are_equal(pn: Node, gn: Node) -> bool:
  45. # Use placeholder and output nodes as wildcards. The
  46. # only exception is that an output node can't match
  47. # a placeholder
  48. if (pn.op == "placeholder"
  49. or (pn.op == "output" and gn.op != "placeholder")):
  50. return True
  51. return pn.op == gn.op and pn.target == gn.target
  52. # Terminate early if the node attributes are not equal
  53. if not attributes_are_equal(pn, gn):
  54. return False
  55. # Optimistically mark `pn` as a match for `gn`
  56. self.nodes_map[pn] = gn
  57. # Traverse the use-def relationships to ensure that `pn` is a true
  58. # match for `gn`
  59. if pn.op == "placeholder":
  60. return True
  61. if (pn.op != "output"
  62. and len(pn.all_input_nodes) != len(gn.all_input_nodes)):
  63. return False
  64. if pn.op == "output":
  65. match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_)
  66. for gn_ in gn.all_input_nodes)
  67. else:
  68. match_found = (len(pn.all_input_nodes) == len(gn.all_input_nodes)
  69. and all(self._match_nodes(pn_, gn_) for pn_, gn_
  70. in zip(pn.all_input_nodes, gn.all_input_nodes)))
  71. if not match_found:
  72. self.nodes_map.pop(pn)
  73. return False
  74. return True
  75. def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
  76. gm.delete_all_unused_submodules()
  77. if isinstance(replacement, GraphModule):
  78. replacement.graph.lint()
  79. def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]:
  80. try:
  81. mod_match = mod.get_submodule(target)
  82. return mod_match
  83. except AttributeError:
  84. return None
  85. for node in gm.graph.nodes:
  86. if node.op == "call_module" or node.op == "get_attr":
  87. gm_submod = try_get_submodule(gm, node.target)
  88. replacement_submod = try_get_submodule(replacement, node.target)
  89. # CASE 1: This target already exists as a submodule in our
  90. # result GraphModule. Whether or not it exists in
  91. # `replacement`, the existing submodule takes precedence.
  92. if gm_submod is not None:
  93. continue
  94. # CASE 2: The target exists as a submodule in `replacement`
  95. # only, so we need to copy it over.
  96. elif replacement_submod is not None:
  97. new_submod = copy.deepcopy(getattr(replacement, node.target))
  98. gm.add_submodule(node.target, new_submod)
  99. # CASE 3: The target doesn't exist as a submodule in `gm`
  100. # or `replacement`
  101. else:
  102. raise RuntimeError("Attempted to create a \"", node.op,
  103. "\" node during subgraph rewriting "
  104. f"with target {node.target}, but "
  105. "the referenced submodule does not "
  106. "exist in either the original "
  107. "GraphModule `gm` or the replacement"
  108. " GraphModule `replacement`")
  109. gm.graph.lint()
  110. @compatibility(is_backward_compatible=True)
  111. def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
  112. """
  113. Matches all possible non-overlapping sets of operators and their
  114. data dependencies (``pattern``) in the Graph of a GraphModule
  115. (``gm``), then replaces each of these matched subgraphs with another
  116. subgraph (``replacement``).
  117. Args:
  118. ``gm``: The GraphModule that wraps the Graph to operate on
  119. ``pattern``: The subgraph to match in ``gm`` for replacement
  120. ``replacement``: The subgraph to replace ``pattern`` with
  121. Returns:
  122. List[Match]: A list of ``Match`` objects representing the places
  123. in the original graph that ``pattern`` was matched to. The list
  124. is empty if there are no matches. ``Match`` is defined as:
  125. .. code-block:: python
  126. class Match(NamedTuple):
  127. # Node from which the match was found
  128. anchor: Node
  129. # Maps nodes in the pattern subgraph to nodes in the larger graph
  130. nodes_map: Dict[Node, Node]
  131. Examples:
  132. .. code-block:: python
  133. import torch
  134. from torch.fx import symbolic_trace, subgraph_rewriter
  135. class M(torch.nn.Module):
  136. def __init__(self):
  137. super().__init__()
  138. def forward(self, x, w1, w2):
  139. m1 = torch.cat([w1, w2]).sum()
  140. m2 = torch.cat([w1, w2]).sum()
  141. return x + torch.max(m1) + torch.max(m2)
  142. def pattern(w1, w2):
  143. return torch.cat([w1, w2]).sum()
  144. def replacement(w1, w2):
  145. return torch.stack([w1, w2])
  146. traced_module = symbolic_trace(M())
  147. subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
  148. The above code will first match ``pattern`` in the ``forward``
  149. method of ``traced_module``. Pattern-matching is done based on
  150. use-def relationships, not node names. For example, if you had
  151. ``p = torch.cat([a, b])`` in ``pattern``, you could match
  152. ``m = torch.cat([a, b])`` in the original ``forward`` function,
  153. despite the variable names being different (``p`` vs ``m``).
  154. The ``return`` statement in ``pattern`` is matched based on its
  155. value only; it may or may not match to the ``return`` statement in
  156. the larger graph. In other words, the pattern doesn't have to extend
  157. to the end of the larger graph.
  158. When the pattern is matched, it will be removed from the larger
  159. function and replaced by ``replacement``. If there are multiple
  160. matches for ``pattern`` in the larger function, each non-overlapping
  161. match will be replaced. In the case of a match overlap, the first
  162. found match in the set of overlapping matches will be replaced.
  163. ("First" here being defined as the first in a topological ordering
  164. of the Nodes' use-def relationships. In most cases, the first Node
  165. is the parameter that appears directly after ``self``, while the
  166. last Node is whatever the function returns.)
  167. One important thing to note is that the parameters of the
  168. ``pattern`` Callable must be used in the Callable itself,
  169. and the parameters of the ``replacement`` Callable must match
  170. the pattern. The first rule is why, in the above code block, the
  171. ``forward`` function has parameters ``x, w1, w2``, but the
  172. ``pattern`` function only has parameters ``w1, w2``. ``pattern``
  173. doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
  174. As an example of the second rule, consider replacing
  175. .. code-block:: python
  176. def pattern(x, y):
  177. return torch.neg(x) + torch.relu(y)
  178. with
  179. .. code-block:: python
  180. def replacement(x, y):
  181. return torch.relu(x)
  182. In this case, ``replacement`` needs the same number of parameters
  183. as ``pattern`` (both ``x`` and ``y``), even though the parameter
  184. ``y`` isn't used in ``replacement``.
  185. After calling ``subgraph_rewriter.replace_pattern``, the generated
  186. Python code looks like this:
  187. .. code-block:: python
  188. def forward(self, x, w1, w2):
  189. stack_1 = torch.stack([w1, w2])
  190. sum_1 = stack_1.sum()
  191. stack_2 = torch.stack([w1, w2])
  192. sum_2 = stack_2.sum()
  193. max_1 = torch.max(sum_1)
  194. add_1 = x + max_1
  195. max_2 = torch.max(sum_2)
  196. add_2 = add_1 + max_2
  197. return add_2
  198. """
  199. # Get the graphs for `gm`, `pattern`, `replacement`
  200. original_graph = gm.graph
  201. pattern_graph = symbolic_trace(pattern).graph
  202. replacement_graph = symbolic_trace(replacement).graph
  203. # Find all possible pattern matches in original_graph. Note that
  204. # pattern matches may overlap with each other.
  205. matcher = _SubgraphMatcher(pattern_graph)
  206. matches: List[Match] = []
  207. # Consider each node as an "anchor" (deepest matching graph node)
  208. for anchor in original_graph.nodes:
  209. if matcher.matches_subgraph_from_anchor(anchor):
  210. def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
  211. # `lookup` represents all the nodes in `original_graph`
  212. # that are part of `pattern`
  213. lookup: Dict[Node, Node] = {v: k for k, v in nodes_map.items()}
  214. for n in lookup.keys():
  215. # Nodes that can "leak"...
  216. # Placeholders (by definition)
  217. if n.op == "placeholder":
  218. continue
  219. # Pattern output (acts as a container)
  220. if lookup[n].op == "output":
  221. continue
  222. # Result contained by pattern output (what we'll
  223. # hook in to the new Graph, thus what we'll
  224. # potentially use in other areas of the Graph as
  225. # an input Node)
  226. if (len(lookup[n].users) == 1
  227. and list(lookup[n].users.keys())[0].op == "output"):
  228. continue
  229. for user in n.users:
  230. # If this node has users that were not in
  231. # `lookup`, then it must leak out of the
  232. # pattern subgraph
  233. if user not in lookup:
  234. return False
  235. return True
  236. # It's not a match if the pattern leaks out into the rest
  237. # of the graph
  238. if pattern_is_contained(matcher.nodes_map):
  239. # Shallow copy nodes_map
  240. matches.append(Match(anchor=anchor,
  241. nodes_map=copy.copy({
  242. key: value
  243. for key, value in matcher.nodes_map.items()
  244. })))
  245. # The set of all nodes in `original_graph` that we've seen thus far
  246. # as part of a pattern match
  247. replaced_nodes: Set[Node] = set()
  248. # As we progressively replace nodes, we'll need to keep track of how the match results should change
  249. match_changed_node: Dict[Node, Node] = dict()
  250. # Return True if one of the nodes in the current match has already
  251. # been used as part of another match
  252. def overlaps_with_prev_match(match: Match) -> bool:
  253. for pn, gn in match.nodes_map.items():
  254. if pn.op in ["placeholder", "output"]:
  255. continue
  256. if gn in replaced_nodes and gn.op != "placeholder":
  257. return True
  258. return False
  259. for match in matches:
  260. # Skip overlapping matches
  261. if overlaps_with_prev_match(match):
  262. continue
  263. # Map replacement graph nodes to their copy in `original_graph`
  264. val_map: Dict[Node, Node] = {}
  265. pattern_placeholders = [n for n in pattern_graph.nodes
  266. if n.op == "placeholder"]
  267. assert len(pattern_placeholders) > 0
  268. replacement_placeholders = [n for n in replacement_graph.nodes
  269. if n.op == "placeholder"]
  270. assert len(pattern_placeholders) == len(replacement_placeholders)
  271. placeholder_map = {r: p for r, p
  272. in zip(replacement_placeholders, pattern_placeholders)}
  273. # node from `original_graph` that matched with the output node
  274. # in `pattern`
  275. subgraph_output: Node = match.anchor
  276. def mark_node_as_replaced(n: Node) -> None:
  277. if n not in match.nodes_map.values():
  278. return
  279. for n_ in n.all_input_nodes:
  280. mark_node_as_replaced(n_)
  281. replaced_nodes.add(n)
  282. for input_node in subgraph_output.all_input_nodes:
  283. mark_node_as_replaced(input_node)
  284. # Initialize `val_map` with mappings from placeholder nodes in
  285. # `replacement` to their corresponding node in `original_graph`
  286. for replacement_node in replacement_placeholders:
  287. # Get the `original_graph` placeholder node
  288. # corresponding to the current `replacement_node`
  289. pattern_node = placeholder_map[replacement_node]
  290. original_graph_node = match_changed_node.get(match.nodes_map[pattern_node], match.nodes_map[pattern_node])
  291. # Populate `val_map`
  292. val_map[replacement_node] = original_graph_node
  293. # Copy the replacement graph over
  294. with original_graph.inserting_before(subgraph_output):
  295. copied_output = original_graph.graph_copy(replacement_graph,
  296. val_map)
  297. # Hook the output Node of the replacement subgraph in to the
  298. # original Graph at the correct location
  299. # CASE 1: We need to hook the replacement subgraph in somewhere
  300. # in the middle of the graph. We replace the Node in the
  301. # original graph that corresponds to the end of the pattern
  302. # subgraph
  303. if subgraph_output.op != "output":
  304. pattern_outputs = [n for n in pattern_graph.nodes
  305. if n.op == "output"]
  306. assert len(pattern_outputs) > 0
  307. replacement_outputs = [n for n in replacement_graph.nodes
  308. if n.op == "output"]
  309. assert len(replacement_outputs) == len(pattern_outputs)
  310. outputs_map = {p: r for r, p
  311. in zip(replacement_outputs, pattern_outputs)}
  312. for pn, gn in match.nodes_map.items():
  313. if gn.op == "placeholder":
  314. continue
  315. # Search for the node corresponding to the output of the pattern
  316. if pn.op != "output":
  317. continue
  318. assert subgraph_output == gn
  319. # Update all anchor inputs to the new nodes
  320. rn = outputs_map[pn]
  321. for pn_input, rn_input in zip(pn.all_input_nodes, rn.all_input_nodes):
  322. gn_input = match.nodes_map[pn_input]
  323. rn_input_in_original_graph = val_map[rn_input]
  324. gn_input.replace_all_uses_with(rn_input_in_original_graph)
  325. # We store the updated node point in case other nodes want to use it
  326. match_changed_node[gn_input] = rn_input_in_original_graph
  327. assert subgraph_output.op != "output"
  328. # CASE 2: The pattern subgraph match extends to the end of the
  329. # original graph, so we need to change the current graph's
  330. # output Node to reflect the insertion of the replacement graph.
  331. # We'll keep the current output Node, but update its args and
  332. # `_input_nodes` as necessary
  333. else:
  334. subgraph_output.args = ((copied_output,))
  335. if isinstance(copied_output, Node):
  336. subgraph_output._input_nodes = {copied_output: None}
  337. assert isinstance(copied_output, Node)
  338. # Erase the `pattern` nodes
  339. for node in reversed(original_graph.nodes):
  340. if len(node.users) == 0 and node.op != "output":
  341. original_graph.erase_node(node)
  342. # Update the passed-in GraphModule to reflect the new state of
  343. # `original_graph`
  344. gm.recompile()
  345. # If `replacement` was an nn.Module, we'll need to make sure that
  346. # all the submodules have been copied over correctly
  347. if isinstance(replacement, torch.nn.Module):
  348. _replace_submodules(gm, replacement)
  349. return matches