| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428 |
- from .graph_module import GraphModule
- from .graph import Graph
- from .node import Node
- from ._symbolic_trace import symbolic_trace
- from ._compatibility import compatibility
- import copy
- from typing import Callable, Dict, List, NamedTuple, Optional, Set
- import torch
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    # Node of the larger graph at which this match was found
    anchor: Node

    # Pattern-subgraph node -> the node it was matched to in the larger graph
    nodes_map: Dict[Node, Node]
class _SubgraphMatcher:
    """
    Matches a single-output pattern ``Graph`` against a larger graph.

    A match attempt starts from a candidate "anchor" node of the larger
    graph and walks the use-def chains of pattern and graph in lockstep.
    While an attempt is in progress (and after it succeeded),
    ``nodes_map`` maps pattern nodes to the graph nodes they matched.
    """

    def __init__(self, pattern: Graph) -> None:
        """
        Args:
            pattern: The pattern graph to search for.

        Raises:
            ValueError: If ``pattern`` contains no nodes.
            AssertionError: If the last node of ``pattern`` does not have
                exactly one input node (multiple outputs are unsupported).
        """
        self.pattern = pattern
        if len(pattern.nodes) == 0:
            raise ValueError("_SubgraphMatcher cannot be initialized with an "
                             "empty pattern")
        # `self.pattern_anchor` is the last Node in `pattern`, i.e. its
        # output Node
        self.pattern_anchor = next(iter(reversed(pattern.nodes)))
        # Ensure that there is only a single output value in the pattern
        # since we don't support multiple outputs
        assert len(self.pattern_anchor.all_input_nodes) == 1, \
            "Pattern matching on multiple outputs is not supported"
        # Maps nodes in the pattern subgraph to nodes in the larger graph
        self.nodes_map: Dict[Node, Node] = {}

    def matches_subgraph_from_anchor(self, anchor: Node) -> bool:
        """
        Checks if the whole pattern can be matched starting from
        ``anchor`` in the larger graph.

        Pattern matching is done by recursively comparing the pattern
        node's use-def relationships against the graph node's.
        ``self.nodes_map`` is reset at the start of every call; it holds
        the full mapping when ``True`` is returned and is empty when
        ``False`` is returned.
        """
        self.nodes_map = {}
        return self._match_nodes(self.pattern_anchor, anchor)

    @staticmethod
    def _attributes_are_equal(pn: Node, gn: Node) -> bool:
        """Return whether pattern node ``pn`` may stand for graph node ``gn``."""
        # Use placeholder and output nodes as wildcards. The only
        # exception is that an output node can't match a placeholder
        if (pn.op == "placeholder"
                or (pn.op == "output" and gn.op != "placeholder")):
            return True
        return pn.op == gn.op and pn.target == gn.target

    def _match_nodes(self, pn: Node, gn: Node) -> bool:
        """
        Compare the pattern node ``pn`` against the graph node ``gn``.

        On success the mapping ``pn -> gn`` (and the mappings of all
        transitively matched inputs) is left in ``self.nodes_map``. On
        failure every mapping added during this attempt is removed again,
        so alternative candidates can be tried from a clean state.
        """
        # Check if we've already matched these nodes in the current
        # traversal
        if pn in self.nodes_map:
            return self.nodes_map[pn] == gn

        # Terminate early if the node attributes are not equal
        if not self._attributes_are_equal(pn, gn):
            return False

        # Remember how many mappings existed before this attempt. Entries
        # are only ever appended, and dicts pop in LIFO order, so shrinking
        # back to this size undoes exactly what this attempt added.
        checkpoint = len(self.nodes_map)

        # Optimistically mark `pn` as a match for `gn`
        self.nodes_map[pn] = gn

        # Placeholders are wildcards with nothing further to traverse
        if pn.op == "placeholder":
            return True

        # Traverse the use-def relationships to ensure that `pn` is a true
        # match for `gn`
        if pn.op == "output":
            # The pattern's returned value may be any one of `gn`'s inputs
            match_found = any(self._match_nodes(pn.all_input_nodes[0], gn_)
                              for gn_ in gn.all_input_nodes)
        else:
            pn_inputs = pn.all_input_nodes
            gn_inputs = gn.all_input_nodes
            match_found = (len(pn_inputs) == len(gn_inputs)
                           and all(self._match_nodes(pn_, gn_) for pn_, gn_
                                   in zip(pn_inputs, gn_inputs)))

        if not match_found:
            # Roll back `pn` and everything matched underneath it; leaving
            # stale entries behind would poison later candidate checks
            while len(self.nodes_map) > checkpoint:
                self.nodes_map.popitem()
            return False

        return True
def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """
    Make sure every submodule referenced by ``gm``'s graph exists on ``gm``.

    Unused submodules are first deleted from ``gm``. Then, for every
    ``call_module`` / ``get_attr`` node whose target does not resolve to a
    submodule of ``gm``, the submodule of the same (possibly dotted) name
    is deep-copied from ``replacement`` and added to ``gm``.

    Args:
        gm: The GraphModule whose graph was rewritten; modified in place.
        replacement: The module that the replacement subgraph was traced
            from.

    Raises:
        RuntimeError: If a node's target resolves to a submodule in
            neither ``gm`` nor ``replacement``.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]:
        # `get_submodule` raises AttributeError for an unresolvable target
        try:
            return mod.get_submodule(target)
        except AttributeError:
            return None

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        # CASE 1: This target already exists as a submodule in our
        # result GraphModule. Whether or not it exists in
        # `replacement`, the existing submodule takes precedence.
        if try_get_submodule(gm, node.target) is not None:
            continue

        replacement_submod = try_get_submodule(replacement, node.target)

        # CASE 2: The target doesn't exist as a submodule in `gm`
        # or `replacement`
        if replacement_submod is None:
            raise RuntimeError(
                f"Attempted to create a \"{node.op}\" node during subgraph "
                f"rewriting with target {node.target}, but the referenced "
                "submodule does not exist in either the original "
                "GraphModule `gm` or the replacement GraphModule "
                "`replacement`")

        # CASE 3: The target exists as a submodule in `replacement`
        # only, so we need to copy it over. Copy the module that was
        # already resolved: a plain `getattr` cannot follow dotted targets.
        gm.add_submodule(node.target, copy.deepcopy(replacement_submod))

    gm.graph.lint()
@compatibility(is_backward_compatible=True)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    # Graphs for `gm`, `pattern` and `replacement`
    original_graph = gm.graph
    pattern_graph = symbolic_trace(pattern).graph
    replacement_graph = symbolic_trace(replacement).graph

    def pattern_is_contained(nodes_map: Dict[Node, Node]) -> bool:
        # Inverse view of `nodes_map`: graph node -> pattern node. Its keys
        # are all the nodes of `original_graph` that belong to the match.
        graph_to_pattern: Dict[Node, Node] = {gn: pn for pn, gn in nodes_map.items()}
        for gn, pn in graph_to_pattern.items():
            # Nodes that are allowed to "leak" out of the match:
            # graph placeholders (by definition) and whatever the pattern
            # output matched (it only acts as a container)
            if gn.op == "placeholder" or pn.op == "output":
                continue
            # ... and the value returned by the pattern: it is what gets
            # hooked into the new Graph, so other parts of the Graph may
            # legitimately consume it
            pn_users = list(pn.users)
            if len(pn_users) == 1 and pn_users[0].op == "output":
                continue
            # Any other matched node with a user outside of the match
            # leaks out of the pattern subgraph
            if any(user not in graph_to_pattern for user in gn.users):
                return False
        return True

    # Collect every place where the pattern matches. Matches found here
    # may still overlap with each other. Each node is tried as an "anchor"
    # (the deepest matching graph node).
    matcher = _SubgraphMatcher(pattern_graph)
    matches: List[Match] = []
    for anchor in original_graph.nodes:
        if not matcher.matches_subgraph_from_anchor(anchor):
            continue
        # It's not a match if the pattern leaks out into the rest of the
        # graph
        if pattern_is_contained(matcher.nodes_map):
            # Store a shallow copy: the matcher reuses its own dict
            matches.append(Match(anchor=anchor,
                                 nodes_map=dict(matcher.nodes_map)))

    # All nodes in `original_graph` consumed so far by an applied match
    replaced_nodes: Set[Node] = set()

    # Replacing nodes invalidates later match results that refer to them;
    # this maps a replaced node to the node that took its place
    match_changed_node: Dict[Node, Node] = {}

    def overlaps_with_prev_match(match: Match) -> bool:
        # True if a non-placeholder node of this match was already consumed
        # by a previously applied match
        return any(
            pn.op not in ["placeholder", "output"]
            and gn in replaced_nodes
            and gn.op != "placeholder"
            for pn, gn in match.nodes_map.items()
        )

    # The pattern and replacement graphs are never modified below, so
    # their placeholder/output nodes can be collected once up front
    pattern_placeholders = [n for n in pattern_graph.nodes
                            if n.op == "placeholder"]
    replacement_placeholders = [n for n in replacement_graph.nodes
                                if n.op == "placeholder"]
    pattern_outputs = [n for n in pattern_graph.nodes if n.op == "output"]
    replacement_outputs = [n for n in replacement_graph.nodes
                           if n.op == "output"]
    # replacement placeholder -> pattern placeholder (paired by position)
    placeholder_map = dict(zip(replacement_placeholders, pattern_placeholders))
    # pattern output -> replacement output (paired by position)
    outputs_map = dict(zip(pattern_outputs, replacement_outputs))

    for match in matches:
        # Skip overlapping matches
        if overlaps_with_prev_match(match):
            continue

        assert len(pattern_placeholders) > 0
        assert len(pattern_placeholders) == len(replacement_placeholders)

        # Node from `original_graph` that matched with the output node
        # in `pattern`
        subgraph_output: Node = match.anchor

        # Record every matched node that feeds `subgraph_output` as
        # replaced, walking input edges for as long as they stay inside
        # the match
        matched_graph_nodes = set(match.nodes_map.values())
        pending = list(subgraph_output.all_input_nodes)
        visited: Set[Node] = set()
        while pending:
            current = pending.pop()
            if current in visited or current not in matched_graph_nodes:
                continue
            visited.add(current)
            pending.extend(current.all_input_nodes)
        replaced_nodes.update(visited)

        # Maps replacement graph nodes to their copy in `original_graph`.
        # Seed it with the placeholders of `replacement`, pointing at the
        # nodes the pattern placeholders matched (or at their substitutes
        # if an earlier replacement already swapped them out).
        val_map: Dict[Node, Node] = {}
        for replacement_node in replacement_placeholders:
            matched_node = match.nodes_map[placeholder_map[replacement_node]]
            val_map[replacement_node] = match_changed_node.get(matched_node,
                                                               matched_node)

        # Copy the replacement graph over
        with original_graph.inserting_before(subgraph_output):
            copied_output = original_graph.graph_copy(replacement_graph,
                                                      val_map)

        # Hook the output Node of the replacement subgraph in to the
        # original Graph at the correct location
        if subgraph_output.op == "output":
            # The pattern subgraph match extends to the end of the
            # original graph, so the current graph's output Node has to
            # reflect the insertion of the replacement graph. The output
            # Node is kept; its args and `_input_nodes` are updated.
            subgraph_output.args = ((copied_output,))
            if isinstance(copied_output, Node):
                subgraph_output._input_nodes = {copied_output: None}
        else:
            # The replacement subgraph is hooked in somewhere in the
            # middle of the graph: redirect the uses of the node(s) that
            # correspond to the end of the pattern subgraph
            assert len(pattern_outputs) > 0
            assert len(replacement_outputs) == len(pattern_outputs)
            for pn, gn in match.nodes_map.items():
                # Only the entry for the pattern's output node is relevant
                if gn.op == "placeholder" or pn.op != "output":
                    continue
                assert subgraph_output == gn

                # Update all anchor inputs to the new nodes
                rn = outputs_map[pn]
                for pn_input, rn_input in zip(pn.all_input_nodes,
                                              rn.all_input_nodes):
                    gn_input = match.nodes_map[pn_input]
                    rn_input_in_original_graph = val_map[rn_input]
                    gn_input.replace_all_uses_with(rn_input_in_original_graph)
                    # Remember the substitute in case later matches refer
                    # to the node that was just replaced
                    match_changed_node[gn_input] = rn_input_in_original_graph
            assert subgraph_output.op != "output"

        assert isinstance(copied_output, Node)

        # Erase the `pattern` nodes (anything left without users, except
        # for output nodes)
        for node in reversed(original_graph.nodes):
            if len(node.users) == 0 and node.op != "output":
                original_graph.erase_node(node)

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    return matches
|