# feature_extraction.py (25 KB)
  1. import inspect
  2. import math
  3. import re
  4. import warnings
  5. from collections import OrderedDict
  6. from copy import deepcopy
  7. from itertools import chain
  8. from typing import Dict, Callable, List, Union, Optional, Tuple, Any
  9. import torch
  10. import torchvision
  11. from torch import fx
  12. from torch import nn
  13. from torch.fx.graph_module import _copy_attr
  14. __all__ = ["create_feature_extractor", "get_graph_node_names"]
  15. class LeafModuleAwareTracer(fx.Tracer):
  16. """
  17. An fx.Tracer that allows the user to specify a set of leaf modules, ie.
  18. modules that are not to be traced through. The resulting graph ends up
  19. having single nodes referencing calls to the leaf modules' forward methods.
  20. """
  21. def __init__(self, *args, **kwargs):
  22. self.leaf_modules = {}
  23. if "leaf_modules" in kwargs:
  24. leaf_modules = kwargs.pop("leaf_modules")
  25. self.leaf_modules = leaf_modules
  26. super().__init__(*args, **kwargs)
  27. def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool:
  28. if isinstance(m, tuple(self.leaf_modules)):
  29. return True
  30. return super().is_leaf_module(m, module_qualname)
  31. class NodePathTracer(LeafModuleAwareTracer):
  32. """
  33. NodePathTracer is an FX tracer that, for each operation, also records the
  34. name of the Node from which the operation originated. A node name here is
  35. a `.` separated path walking the hierarchy from top level module down to
  36. leaf operation or leaf module. The name of the top level module is not
  37. included as part of the node name. For example, if we trace a module whose
  38. forward method applies a ReLU module, the name for that node will simply
  39. be 'relu'.
  40. Some notes on the specifics:
  41. - Nodes are recorded to `self.node_to_qualname` which is a dictionary
  42. mapping a given Node object to its node name.
  43. - Nodes are recorded in the order which they are executed during
  44. tracing.
  45. - When a duplicate node name is encountered, a suffix of the form
  46. _{int} is added. The counter starts from 1.
  47. """
  48. def __init__(self, *args, **kwargs):
  49. super().__init__(*args, **kwargs)
  50. # Track the qualified name of the Node being traced
  51. self.current_module_qualname = ""
  52. # A map from FX Node to the qualified name\#
  53. # NOTE: This is loosely like the "qualified name" mentioned in the
  54. # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted
  55. # for the purposes of the torchvision feature extractor
  56. self.node_to_qualname = OrderedDict()
  57. def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs):
  58. """
  59. Override of `fx.Tracer.call_module`
  60. This override:
  61. 1) Stores away the qualified name of the caller for restoration later
  62. 2) Adds the qualified name of the caller to
  63. `current_module_qualname` for retrieval by `create_proxy`
  64. 3) Once a leaf module is reached, calls `create_proxy`
  65. 4) Restores the caller's qualified name into current_module_qualname
  66. """
  67. old_qualname = self.current_module_qualname
  68. try:
  69. module_qualname = self.path_of_module(m)
  70. self.current_module_qualname = module_qualname
  71. if not self.is_leaf_module(m, module_qualname):
  72. out = forward(*args, **kwargs)
  73. return out
  74. return self.create_proxy("call_module", module_qualname, args, kwargs)
  75. finally:
  76. self.current_module_qualname = old_qualname
  77. def create_proxy(
  78. self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_
  79. ) -> fx.proxy.Proxy:
  80. """
  81. Override of `Tracer.create_proxy`. This override intercepts the recording
  82. of every operation and stores away the current traced module's qualified
  83. name in `node_to_qualname`
  84. """
  85. proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr)
  86. self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node)
  87. return proxy
  88. def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str:
  89. node_qualname = module_qualname
  90. if node.op != "call_module":
  91. # In this case module_qualname from torch.fx doesn't go all the
  92. # way to the leaf function/op so we need to append it
  93. if len(node_qualname) > 0:
  94. # Only append '.' if we are deeper than the top level module
  95. node_qualname += "."
  96. node_qualname += str(node)
  97. # Now we need to add an _{index} postfix on any repeated node names
  98. # For modules we do this from scratch
  99. # But for anything else, torch.fx already has a globally scoped
  100. # _{index} postfix. But we want it locally (relative to direct parent)
  101. # scoped. So first we need to undo the torch.fx postfix
  102. if re.match(r".+_[0-9]+$", node_qualname) is not None:
  103. node_qualname = node_qualname.rsplit("_", 1)[0]
  104. # ... and now we add on our own postfix
  105. for existing_qualname in reversed(self.node_to_qualname.values()):
  106. # Check to see if existing_qualname is of the form
  107. # {node_qualname} or {node_qualname}_{int}
  108. if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None:
  109. postfix = existing_qualname.replace(node_qualname, "")
  110. if len(postfix):
  111. # existing_qualname is of the form {node_qualname}_{int}
  112. next_index = int(postfix[1:]) + 1
  113. else:
  114. # existing_qualname is of the form {node_qualname}
  115. next_index = 1
  116. node_qualname += f"_{next_index}"
  117. break
  118. return node_qualname
  119. def _is_subseq(x, y):
  120. """Check if y is a subseqence of x
  121. https://stackoverflow.com/a/24017747/4391249
  122. """
  123. iter_x = iter(x)
  124. return all(any(x_item == y_item for x_item in iter_x) for y_item in y)
  125. def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer):
  126. """
  127. Utility function for warning the user if there are differences between
  128. the train graph nodes and the eval graph nodes.
  129. """
  130. train_nodes = list(train_tracer.node_to_qualname.values())
  131. eval_nodes = list(eval_tracer.node_to_qualname.values())
  132. if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)):
  133. return
  134. suggestion_msg = (
  135. "When choosing nodes for feature extraction, you may need to specify "
  136. "output nodes for train and eval mode separately."
  137. )
  138. if _is_subseq(train_nodes, eval_nodes):
  139. msg = (
  140. "NOTE: The nodes obtained by tracing the model in eval mode "
  141. "are a subsequence of those obtained in train mode. "
  142. )
  143. elif _is_subseq(eval_nodes, train_nodes):
  144. msg = (
  145. "NOTE: The nodes obtained by tracing the model in train mode "
  146. "are a subsequence of those obtained in eval mode. "
  147. )
  148. else:
  149. msg = "The nodes obtained by tracing the model in train mode are different to those obtained in eval mode. "
  150. warnings.warn(msg + suggestion_msg)
  151. def _get_leaf_modules_for_ops() -> List[type]:
  152. members = inspect.getmembers(torchvision.ops)
  153. result = []
  154. for _, obj in members:
  155. if inspect.isclass(obj) and issubclass(obj, torch.nn.Module):
  156. result.append(obj)
  157. return result
  158. def _set_default_tracer_kwargs(original_tr_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
  159. default_autowrap_modules = (math, torchvision.ops)
  160. default_leaf_modules = _get_leaf_modules_for_ops()
  161. result_tracer_kwargs = {} if original_tr_kwargs is None else original_tr_kwargs
  162. result_tracer_kwargs["autowrap_modules"] = (
  163. tuple(set(result_tracer_kwargs["autowrap_modules"] + default_autowrap_modules))
  164. if "autowrap_modules" in result_tracer_kwargs
  165. else default_autowrap_modules
  166. )
  167. result_tracer_kwargs["leaf_modules"] = (
  168. list(set(result_tracer_kwargs["leaf_modules"] + default_leaf_modules))
  169. if "leaf_modules" in result_tracer_kwargs
  170. else default_leaf_modules
  171. )
  172. return result_tracer_kwargs
  173. def get_graph_node_names(
  174. model: nn.Module,
  175. tracer_kwargs: Optional[Dict[str, Any]] = None,
  176. suppress_diff_warning: bool = False,
  177. ) -> Tuple[List[str], List[str]]:
  178. """
  179. Dev utility to return node names in order of execution. See note on node
  180. names under :func:`create_feature_extractor`. Useful for seeing which node
  181. names are available for feature extraction. There are two reasons that
  182. node names can't easily be read directly from the code for a model:
  183. 1. Not all submodules are traced through. Modules from ``torch.nn`` all
  184. fall within this category.
  185. 2. Nodes representing the repeated application of the same operation
  186. or leaf module get a ``_{counter}`` postfix.
  187. The model is traced twice: once in train mode, and once in eval mode. Both
  188. sets of node names are returned.
  189. For more details on the node naming conventions used here, please see the
  190. :ref:`relevant subheading <about-node-names>` in the
  191. `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
  192. Args:
  193. model (nn.Module): model for which we'd like to print node names
  194. tracer_kwargs (dict, optional): a dictionary of keyword arguments for
  195. ``NodePathTracer`` (they are eventually passed onto
  196. `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
  197. By default it will be set to wrap and make leaf nodes all torchvision ops:
  198. {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
  199. WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
  200. provided dictionary.
  201. suppress_diff_warning (bool, optional): whether to suppress a warning
  202. when there are discrepancies between the train and eval version of
  203. the graph. Defaults to False.
  204. Returns:
  205. tuple(list, list): a list of node names from tracing the model in
  206. train mode, and another from tracing the model in eval mode.
  207. Examples::
  208. >>> model = torchvision.models.resnet18()
  209. >>> train_nodes, eval_nodes = get_graph_node_names(model)
  210. """
  211. tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
  212. is_training = model.training
  213. train_tracer = NodePathTracer(**tracer_kwargs)
  214. train_tracer.trace(model.train())
  215. eval_tracer = NodePathTracer(**tracer_kwargs)
  216. eval_tracer.trace(model.eval())
  217. train_nodes = list(train_tracer.node_to_qualname.values())
  218. eval_nodes = list(eval_tracer.node_to_qualname.values())
  219. if not suppress_diff_warning:
  220. _warn_graph_differences(train_tracer, eval_tracer)
  221. # Restore training state
  222. model.train(is_training)
  223. return train_nodes, eval_nodes
  224. class DualGraphModule(fx.GraphModule):
  225. """
  226. A derivative of `fx.GraphModule`. Differs in the following ways:
  227. - Requires a train and eval version of the underlying graph
  228. - Copies submodules according to the nodes of both train and eval graphs.
  229. - Calling train(mode) switches between train graph and eval graph.
  230. """
  231. def __init__(
  232. self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule"
  233. ):
  234. """
  235. Args:
  236. root (nn.Module): module from which the copied module hierarchy is
  237. built
  238. train_graph (fx.Graph): the graph that should be used in train mode
  239. eval_graph (fx.Graph): the graph that should be used in eval mode
  240. """
  241. super(fx.GraphModule, self).__init__()
  242. self.__class__.__name__ = class_name
  243. self.train_graph = train_graph
  244. self.eval_graph = eval_graph
  245. # Copy all get_attr and call_module ops (indicated by BOTH train and
  246. # eval graphs)
  247. for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)):
  248. if node.op in ["get_attr", "call_module"]:
  249. if not isinstance(node.target, str):
  250. raise TypeError(f"node.target should be of type str instead of {type(node.target)}")
  251. _copy_attr(root, self, node.target)
  252. # train mode by default
  253. self.train()
  254. self.graph = train_graph
  255. # (borrowed from fx.GraphModule):
  256. # Store the Tracer class responsible for creating a Graph separately as part of the
  257. # GraphModule state, except when the Tracer is defined in a local namespace.
  258. # Locally defined Tracers are not pickleable. This is needed because torch.package will
  259. # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
  260. # to re-create the Graph during deserialization.
  261. if self.eval_graph._tracer_cls != self.train_graph._tracer_cls:
  262. raise TypeError(
  263. f"Train mode and eval mode should use the same tracer class. Instead got {self.eval_graph._tracer_cls} for eval vs {self.train_graph._tracer_cls} for train"
  264. )
  265. self._tracer_cls = None
  266. if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
  267. self._tracer_cls = self.graph._tracer_cls
  268. def train(self, mode=True):
  269. """
  270. Swap out the graph depending on the selected training mode.
  271. NOTE this should be safe when calling model.eval() because that just
  272. calls this with mode == False.
  273. """
  274. # NOTE: Only set self.graph if the current graph is not the desired
  275. # one. This saves us from recompiling the graph where not necessary.
  276. if mode and not self.training:
  277. self.graph = self.train_graph
  278. elif not mode and self.training:
  279. self.graph = self.eval_graph
  280. return super().train(mode=mode)
  281. def create_feature_extractor(
  282. model: nn.Module,
  283. return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
  284. train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
  285. eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None,
  286. tracer_kwargs: Optional[Dict[str, Any]] = None,
  287. suppress_diff_warning: bool = False,
  288. ) -> fx.GraphModule:
  289. """
  290. Creates a new graph module that returns intermediate nodes from a given
  291. model as dictionary with user specified keys as strings, and the requested
  292. outputs as values. This is achieved by re-writing the computation graph of
  293. the model via FX to return the desired nodes as outputs. All unused nodes
  294. are removed, together with their corresponding parameters.
  295. Desired output nodes must be specified as a ``.`` separated
  296. path walking the module hierarchy from top level module down to leaf
  297. operation or leaf module. For more details on the node naming conventions
  298. used here, please see the :ref:`relevant subheading <about-node-names>`
  299. in the `documentation <https://pytorch.org/vision/stable/feature_extraction.html>`_.
  300. Not all models will be FX traceable, although with some massaging they can
  301. be made to cooperate. Here's a (not exhaustive) list of tips:
  302. - If you don't need to trace through a particular, problematic
  303. sub-module, turn it into a "leaf module" by passing a list of
  304. ``leaf_modules`` as one of the ``tracer_kwargs`` (see example below).
  305. It will not be traced through, but rather, the resulting graph will
  306. hold a reference to that module's forward method.
  307. - Likewise, you may turn functions into leaf functions by passing a
  308. list of ``autowrap_functions`` as one of the ``tracer_kwargs`` (see
  309. example below).
  310. - Some inbuilt Python functions can be problematic. For instance,
  311. ``int`` will raise an error during tracing. You may wrap them in your
  312. own function and then pass that in ``autowrap_functions`` as one of
  313. the ``tracer_kwargs``.
  314. For further information on FX see the
  315. `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_.
  316. Args:
  317. model (nn.Module): model on which we will extract the features
  318. return_nodes (list or dict, optional): either a ``List`` or a ``Dict``
  319. containing the names (or partial names - see note above)
  320. of the nodes for which the activations will be returned. If it is
  321. a ``Dict``, the keys are the node names, and the values
  322. are the user-specified keys for the graph module's returned
  323. dictionary. If it is a ``List``, it is treated as a ``Dict`` mapping
  324. node specification strings directly to output names. In the case
  325. that ``train_return_nodes`` and ``eval_return_nodes`` are specified,
  326. this should not be specified.
  327. train_return_nodes (list or dict, optional): similar to
  328. ``return_nodes``. This can be used if the return nodes
  329. for train mode are different than those from eval mode.
  330. If this is specified, ``eval_return_nodes`` must also be specified,
  331. and ``return_nodes`` should not be specified.
  332. eval_return_nodes (list or dict, optional): similar to
  333. ``return_nodes``. This can be used if the return nodes
  334. for train mode are different than those from eval mode.
  335. If this is specified, ``train_return_nodes`` must also be specified,
  336. and `return_nodes` should not be specified.
  337. tracer_kwargs (dict, optional): a dictionary of keyword arguments for
  338. ``NodePathTracer`` (which passes them onto it's parent class
  339. `torch.fx.Tracer <https://pytorch.org/docs/stable/fx.html#torch.fx.Tracer>`_).
  340. By default it will be set to wrap and make leaf nodes all torchvision ops:
  341. {"autowrap_modules": (math, torchvision.ops,),"leaf_modules": _get_leaf_modules_for_ops(),}
  342. WARNING: In case the user provides tracer_kwargs, above default arguments will be appended to the user
  343. provided dictionary.
  344. suppress_diff_warning (bool, optional): whether to suppress a warning
  345. when there are discrepancies between the train and eval version of
  346. the graph. Defaults to False.
  347. Examples::
  348. >>> # Feature extraction with resnet
  349. >>> model = torchvision.models.resnet18()
  350. >>> # extract layer1 and layer3, giving as names `feat1` and feat2`
  351. >>> model = create_feature_extractor(
  352. >>> model, {'layer1': 'feat1', 'layer3': 'feat2'})
  353. >>> out = model(torch.rand(1, 3, 224, 224))
  354. >>> print([(k, v.shape) for k, v in out.items()])
  355. >>> [('feat1', torch.Size([1, 64, 56, 56])),
  356. >>> ('feat2', torch.Size([1, 256, 14, 14]))]
  357. >>> # Specifying leaf modules and leaf functions
  358. >>> def leaf_function(x):
  359. >>> # This would raise a TypeError if traced through
  360. >>> return int(x)
  361. >>>
  362. >>> class LeafModule(torch.nn.Module):
  363. >>> def forward(self, x):
  364. >>> # This would raise a TypeError if traced through
  365. >>> int(x.shape[0])
  366. >>> return torch.nn.functional.relu(x + 4)
  367. >>>
  368. >>> class MyModule(torch.nn.Module):
  369. >>> def __init__(self):
  370. >>> super().__init__()
  371. >>> self.conv = torch.nn.Conv2d(3, 1, 3)
  372. >>> self.leaf_module = LeafModule()
  373. >>>
  374. >>> def forward(self, x):
  375. >>> leaf_function(x.shape[0])
  376. >>> x = self.conv(x)
  377. >>> return self.leaf_module(x)
  378. >>>
  379. >>> model = create_feature_extractor(
  380. >>> MyModule(), return_nodes=['leaf_module'],
  381. >>> tracer_kwargs={'leaf_modules': [LeafModule],
  382. >>> 'autowrap_functions': [leaf_function]})
  383. """
  384. tracer_kwargs = _set_default_tracer_kwargs(tracer_kwargs)
  385. is_training = model.training
  386. if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
  387. raise ValueError(
  388. "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
  389. )
  390. if (train_return_nodes is None) ^ (eval_return_nodes is None):
  391. raise ValueError(
  392. "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
  393. )
  394. if not ((return_nodes is None) ^ (train_return_nodes is None)):
  395. raise ValueError("If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified")
  396. # Put *_return_nodes into Dict[str, str] format
  397. def to_strdict(n) -> Dict[str, str]:
  398. if isinstance(n, list):
  399. return {str(i): str(i) for i in n}
  400. return {str(k): str(v) for k, v in n.items()}
  401. if train_return_nodes is None:
  402. return_nodes = to_strdict(return_nodes)
  403. train_return_nodes = deepcopy(return_nodes)
  404. eval_return_nodes = deepcopy(return_nodes)
  405. else:
  406. train_return_nodes = to_strdict(train_return_nodes)
  407. eval_return_nodes = to_strdict(eval_return_nodes)
  408. # Repeat the tracing and graph rewriting for train and eval mode
  409. tracers = {}
  410. graphs = {}
  411. mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes}
  412. for mode in ["train", "eval"]:
  413. if mode == "train":
  414. model.train()
  415. elif mode == "eval":
  416. model.eval()
  417. # Instantiate our NodePathTracer and use that to trace the model
  418. tracer = NodePathTracer(**tracer_kwargs)
  419. graph = tracer.trace(model)
  420. name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__
  421. graph_module = fx.GraphModule(tracer.root, graph, name)
  422. available_nodes = list(tracer.node_to_qualname.values())
  423. # FIXME We don't know if we should expect this to happen
  424. if len(set(available_nodes)) != len(available_nodes):
  425. raise ValueError(
  426. "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
  427. )
  428. # Check that all outputs in return_nodes are present in the model
  429. for query in mode_return_nodes[mode].keys():
  430. # To check if a query is available we need to check that at least
  431. # one of the available names starts with it up to a .
  432. if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]):
  433. raise ValueError(
  434. f"node: '{query}' is not present in model. Hint: use "
  435. "`get_graph_node_names` to make sure the "
  436. "`return_nodes` you specified are present. It may even "
  437. "be that you need to specify `train_return_nodes` and "
  438. "`eval_return_nodes` separately."
  439. )
  440. # Remove existing output nodes (train mode)
  441. orig_output_nodes = []
  442. for n in reversed(graph_module.graph.nodes):
  443. if n.op == "output":
  444. orig_output_nodes.append(n)
  445. if not orig_output_nodes:
  446. raise ValueError("No output nodes found in graph_module.graph.nodes")
  447. for n in orig_output_nodes:
  448. graph_module.graph.erase_node(n)
  449. # Find nodes corresponding to return_nodes and make them into output_nodes
  450. nodes = [n for n in graph_module.graph.nodes]
  451. output_nodes = OrderedDict()
  452. for n in reversed(nodes):
  453. module_qualname = tracer.node_to_qualname.get(n)
  454. if module_qualname is None:
  455. # NOTE - Know cases where this happens:
  456. # - Node representing creation of a tensor constant - probably
  457. # not interesting as a return node
  458. # - When packing outputs into a named tuple like in InceptionV3
  459. continue
  460. for query in mode_return_nodes[mode]:
  461. depth = query.count(".")
  462. if ".".join(module_qualname.split(".")[: depth + 1]) == query:
  463. output_nodes[mode_return_nodes[mode][query]] = n
  464. mode_return_nodes[mode].pop(query)
  465. break
  466. output_nodes = OrderedDict(reversed(list(output_nodes.items())))
  467. # And add them in the end of the graph
  468. with graph_module.graph.inserting_after(nodes[-1]):
  469. graph_module.graph.output(output_nodes)
  470. # Remove unused modules / parameters
  471. graph_module.graph.eliminate_dead_code()
  472. graph_module.recompile()
  473. # Keep track of the tracer and graph so we can choose the main one
  474. tracers[mode] = tracer
  475. graphs[mode] = graph
  476. # Warn user if there are any discrepancies between the graphs of the
  477. # train and eval modes
  478. if not suppress_diff_warning:
  479. _warn_graph_differences(tracers["train"], tracers["eval"])
  480. # Build the final graph module
  481. graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name)
  482. # Restore original training mode
  483. model.train(is_training)
  484. graph_module.train(is_training)
  485. return graph_module