graph_module.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728
  1. import torch
  2. import torch.nn as nn
  3. import torch.overrides
  4. from torch.nn.modules.module import _addindent
  5. from torch.package import PackageImporter, PackageExporter
  6. import linecache
  7. from typing import Type, Dict, List, Any, Union, Optional, Set
  8. from .graph import Graph, _PyTreeCodeGen, _is_from_torch, _custom_builtins, PythonCode
  9. from ._compatibility import compatibility
  10. from torch.package import Importer, sys_importer
  11. import copy
  12. import itertools
  13. import sys
  14. import traceback
  15. from pathlib import Path
  16. import os
  17. import warnings
  18. # Normal exec loses the source code, however we can work with
  19. # the linecache module to recover it.
  20. # Using _exec_with_source will add it to our local cache
  21. # and then tools like TorchScript will be able to get source info.
  22. class _EvalCacheLoader(object):
  23. def __init__(self):
  24. self.eval_cache = {}
  25. self.next_id = 0
  26. def cache(self, src: str, globals: Dict[str, Any]):
  27. """Store the source in a private cache, and add a lazy entry in linecache
  28. that allows the source to be retrieved by 'filename'.
  29. Args:
  30. src (str): The module source to cache
  31. globals (dict): The module globals
  32. Returns:
  33. str: The cache key (and dummy filename) generated for src.
  34. """
  35. key = self._get_key()
  36. self.eval_cache[key] = src
  37. # Don't mutate globals so that this loader is only used
  38. # to populate linecache, and doesn't interact with other modules
  39. # that might check `__loader__`
  40. globals_copy = globals.copy()
  41. globals_copy['__file__'] = key
  42. globals_copy['__name__'] = key
  43. globals_copy['__loader__'] = self
  44. linecache.lazycache(key, globals_copy)
  45. return key
  46. # Part of the loader protocol (PEP 302)
  47. # linecache will use this method when trying to find source code
  48. def get_source(self, module_name) -> Optional[str]:
  49. if module_name in self.eval_cache:
  50. return self.eval_cache[module_name]
  51. return None
  52. def _get_key(self):
  53. key = f'<eval_with_key>.{self.next_id}'
  54. self.next_id += 1
  55. return key
  56. _loader = _EvalCacheLoader()
  57. def _exec_with_source(src: str, globals: Dict[str, Any]):
  58. key = _loader.cache(src, globals)
  59. exec(compile(src, key, 'exec'), globals)
  60. def _forward_from_src(src: str, globals: Dict[str, Any]):
  61. # avoid mutating the passed in dict
  62. globals_copy = globals.copy()
  63. _exec_with_source(src, globals_copy)
  64. forward_fn = globals_copy['forward']
  65. del globals_copy['forward']
  66. return forward_fn
  67. def _format_import_statement(name: str, obj: Any, importer: Importer) -> str:
  68. if name in _custom_builtins:
  69. return _custom_builtins[name].import_str
  70. if _is_from_torch(name):
  71. return 'import torch'
  72. module_name, attr_name = importer.get_name(obj)
  73. return f'from {module_name} import {attr_name} as {name}'
  74. def _format_import_block(globals: Dict[str, Any], importer: Importer):
  75. import_strs: Set[str] = set()
  76. for name, obj in globals.items():
  77. import_strs.add(_format_import_statement(name, obj, importer))
  78. return '\n'.join(import_strs)
  79. @compatibility(is_backward_compatible=True)
  80. def reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.Module:
  81. # BC: attribute name was changed from `code` to `_code` to facilitate
  82. # making `code` into a property and adding a docstring to it
  83. fn_src = body.get('_code') or body['code']
  84. forward = _forward_from_src(import_block + fn_src, {})
  85. return _deserialize_graph_module(forward, body)
  86. @compatibility(is_backward_compatible=True)
  87. def reduce_package_graph_module(
  88. importer: PackageImporter, body: Dict[Any, Any], generated_module_name: str
  89. ) -> torch.nn.Module:
  90. forward = importer.import_module(generated_module_name).forward
  91. return _deserialize_graph_module(forward, body)
  92. @compatibility(is_backward_compatible=True)
  93. def reduce_deploy_graph_module(
  94. importer: PackageImporter, body: Dict[Any, Any], import_block: str
  95. ) -> torch.nn.Module:
  96. ns = dict()
  97. ns["__builtins__"] = importer.patched_builtins
  98. fn_src = body.get('_code')
  99. assert fn_src is not None
  100. forward = _forward_from_src(import_block + fn_src, ns)
  101. return _deserialize_graph_module(forward, body)
  102. def _deserialize_graph_module(forward, body: Dict[Any, Any]) -> torch.nn.Module:
  103. """
  104. Deserialize a GraphModule given the dictionary of the original module,
  105. using the code to reconstruct the graph. We delete the actual graph before
  106. saving the dictionary so that changes to the in-memory graph format do not
  107. get serialized.
  108. """
  109. # We create a dummy class here because symbolic_trace pulls the forward()
  110. # function off of the class, rather than the instance
  111. class CodeOnlyModule(torch.nn.Module):
  112. def __init__(self, body):
  113. super().__init__()
  114. self.__dict__ = body
  115. # Try to retrieve the forward source in a backward-compatible way
  116. CodeOnlyModule.forward = forward
  117. tracer_cls = body.get('_tracer_cls')
  118. if tracer_cls is None:
  119. from ._symbolic_trace import Tracer
  120. tracer_cls = Tracer
  121. graphmodule_cls_name = body.get('_graphmodule_cls_name', 'GraphModule')
  122. # This is a workaround for a mypy linter issue related to
  123. # passing base class as an argument - https://github.com/python/mypy/issues/5865.
  124. cls_tracer : Any = tracer_cls
  125. class KeepModules(cls_tracer):
  126. # we shouldn't trace into any of the submodules,
  127. # because they were not traced in the original GraphModule
  128. def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool:
  129. return True
  130. com = CodeOnlyModule(body)
  131. tracer_extras = body.get('_tracer_extras', {})
  132. graph = KeepModules().trace(com, **tracer_extras)
  133. # Manually set Tracer class on the reconstructed Graph, to avoid
  134. # referencing the private local subclass KeepModules.
  135. graph._tracer_cls = tracer_cls
  136. gm = GraphModule(com, graph, class_name=graphmodule_cls_name)
  137. # The GraphModule constructor only retains attributes referenced by the graph.
  138. # In this case, our goal is return a GraphModule as close to identical as the one
  139. # put into the package. If any additional attributes were present in body,
  140. # we should keep them.
  141. for k, v in body.items():
  142. if not hasattr(gm, k):
  143. setattr(gm, k, v)
  144. return gm
  145. # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
  146. # This installs empty Modules where none exist yet if they are subpaths of target
  147. def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str):
  148. *prefix, field = target.split('.')
  149. for item in prefix:
  150. f = getattr(from_module, item)
  151. t = getattr(to_module, item, None)
  152. if f is t:
  153. # we have already installed one of its parents
  154. # (e.g. target = root.linear.weight, but we have already installed root.linear)
  155. # once we install a parent, we no longer need to copy the children
  156. # since all the needed properties will already be present
  157. return
  158. if t is None:
  159. t = torch.nn.Module()
  160. setattr(to_module, item, t)
  161. from_module, to_module = f, t
  162. orig = getattr(from_module, field)
  163. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  164. # So, we register it as a named buffer in the target module.
  165. if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter):
  166. to_module.register_buffer(field, orig)
  167. else:
  168. setattr(to_module, field, orig)
  169. # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module
  170. # This installs empty Modules where none exist yet if they are subpaths of target
  171. def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
  172. *prefix, field = target.split('.')
  173. for item in prefix:
  174. t = getattr(to_module, item, None)
  175. if t is None:
  176. t = torch.nn.Module()
  177. setattr(to_module, item, t)
  178. to_module = t
  179. # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
  180. # So, we register it as a named buffer in the target module.
  181. if isinstance(from_obj, torch.Tensor) and not isinstance(from_obj, torch.nn.Parameter):
  182. to_module.register_buffer(field, from_obj)
  183. else:
  184. setattr(to_module, field, from_obj)
  185. class _WrappedCall:
  186. def __init__(self, cls, cls_call):
  187. self.cls = cls
  188. self.cls_call = cls_call
  189. # Previously, if an error occurred when valid
  190. # symbolically-traced code was run with an invalid input, the
  191. # user would see the source of the error as coming from
  192. # `File "<eval_with_key_N">`, where N is some number. We use
  193. # this function to generate a more informative error message. We
  194. # return the traceback itself, a message explaining that the
  195. # error occurred in a traced Module's generated forward
  196. # function, and five lines of context surrounding the faulty
  197. # line
  198. @staticmethod
  199. def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
  200. # auxiliary variables (for readability)
  201. err_lineno = frame_summary.lineno
  202. assert err_lineno is not None
  203. line = frame_summary.line
  204. assert line is not None
  205. err_line_len = len(line)
  206. all_src_lines = linecache.getlines(frame_summary.filename)
  207. # constituent substrings of the error message
  208. tb_repr = traceback.format_exc()
  209. custom_msg = ("Call using an FX-traced Module, "
  210. f"line {err_lineno} of the traced Module's "
  211. "generated forward function:")
  212. before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
  213. marker = "~" * err_line_len + "~~~ <--- HERE"
  214. err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
  215. # joined message
  216. return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
  217. def __call__(self, obj, *args, **kwargs):
  218. try:
  219. if self.cls_call is not None:
  220. return self.cls_call(obj, *args, **kwargs)
  221. else:
  222. return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
  223. except Exception as e:
  224. assert e.__traceback__
  225. topmost_framesummary: traceback.FrameSummary = \
  226. traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
  227. if "eval_with_key" in topmost_framesummary.filename:
  228. print(_WrappedCall._generate_error_message(topmost_framesummary),
  229. file=sys.stderr)
  230. raise e.with_traceback(None)
  231. else:
  232. raise e
  233. @compatibility(is_backward_compatible=True)
  234. class GraphModule(torch.nn.Module):
  235. """
  236. GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a
  237. ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated
  238. from that ``graph``.
  239. .. warning::
  240. When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically
  241. regenerated. However, if you edit the contents of the ``graph`` without reassigning
  242. the ``graph`` attribute itself, you must call ``recompile()`` to update the generated
  243. code.
  244. """
  245. def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
  246. # each instance of a graph module needs its own forward method
  247. # so create a new singleton class for each instance.
  248. # it is a subclass of the user-defined class, the only difference
  249. # is an extra layer to install the forward method
  250. # address issue described at https://github.com/pytorch/pytorch/issues/63883
  251. # in other words, traverse class hierarchy to fix the redundant class definition problem
  252. for t in cls.__mro__:
  253. c = t.__qualname__.split('.')[-1]
  254. if c != 'GraphModuleImpl':
  255. cls = t
  256. break
  257. class GraphModuleImpl(cls): # type: ignore[misc, valid-type]
  258. pass
  259. return super().__new__(GraphModuleImpl)
  260. @compatibility(is_backward_compatible=True)
  261. def __init__(self,
  262. root: Union[torch.nn.Module, Dict[str, Any]],
  263. graph: Graph,
  264. class_name: str = 'GraphModule'):
  265. """
  266. Construct a GraphModule.
  267. Args:
  268. root (Union[torch.nn.Module, Dict[str, Any]):
  269. ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type.
  270. In the case that ``root`` is a Module, any references to Module-based objects (via qualified
  271. name) in the Graph's Nodes' ``target`` field will be copied over from the respective place
  272. within ``root``'s Module hierarchy into the GraphModule's module hierarchy.
  273. In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be
  274. looked up directly in the dict's keys. The object mapped to by the Dict will be copied
  275. over into the appropriate place within the GraphModule's module hierarchy.
  276. graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation
  277. class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all
  278. error messages will report as originating from ``GraphModule``. It may be helpful to set this
  279. to ``root``'s original name or a name that makes sense within the context of your transform.
  280. """
  281. super().__init__()
  282. self.__class__.__name__ = class_name
  283. if isinstance(root, torch.nn.Module):
  284. if hasattr(root, 'training'):
  285. self.training = root.training
  286. for node in graph.nodes:
  287. if node.op in ['get_attr', 'call_module']:
  288. assert isinstance(node.target, str)
  289. _copy_attr(root, self, node.target)
  290. elif isinstance(root, dict):
  291. targets_to_copy = []
  292. for node in graph.nodes:
  293. if node.op in ['get_attr', 'call_module']:
  294. assert isinstance(node.target, str)
  295. if node.target not in root:
  296. raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +
  297. ' but that target was not provided in ``root``!')
  298. targets_to_copy.append(node.target)
  299. # Sort targets in ascending order of the # of atoms.
  300. # This will ensure that less deeply nested attributes are assigned
  301. # before more deeply nested attributes. For example, foo.bar
  302. # will be assigned before foo.bar.baz. Otherwise, we might assign
  303. # the user-provided ``foo.bar`` and wipe out the previously-assigned
  304. # ``foo.bar.baz``
  305. targets_to_copy.sort(key=lambda t: t.count('.'))
  306. for target_to_copy in targets_to_copy:
  307. _assign_attr(root[target_to_copy], self, target_to_copy)
  308. else:
  309. raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')
  310. self.graph = graph
  311. # Store the Tracer class responsible for creating a Graph separately as part of the
  312. # GraphModule state, except when the Tracer is defined in a local namespace.
  313. # Locally defined Tracers are not pickleable. This is needed because torch.package will
  314. # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer
  315. # to re-create the Graph during deserialization.
  316. self._tracer_cls = None
  317. if self.graph._tracer_cls and '<locals>' not in self.graph._tracer_cls.__qualname__:
  318. self._tracer_cls = self.graph._tracer_cls
  319. self._tracer_extras = {}
  320. if self.graph._tracer_extras:
  321. self._tracer_extras = self.graph._tracer_extras
  322. # TorchScript breaks trying to compile the graph setter because of the
  323. # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
  324. #
  325. # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
  326. __jit_unused_properties__ = ['graph']
  327. @property
  328. def graph(self) -> Graph:
  329. """
  330. Return the ``Graph`` underlying this ``GraphModule``
  331. """
  332. return self._graph
  333. @graph.setter
  334. def graph(self, g : Graph) -> None:
  335. """
  336. Set the underlying ``Graph`` for this ``GraphModule``. This will internally
  337. recompile the ``GraphModule`` so that the generated ``forward()`` function
  338. corresponds to ``g``
  339. """
  340. assert isinstance(g, Graph), f'Expected a Graph instance, but got {type(g)}'
  341. self._graph = g
  342. g.owning_module = self
  343. self.recompile()
  344. @compatibility(is_backward_compatible=False)
  345. def to_folder(self, folder: Union[str, os.PathLike], module_name : str = "FxModule"):
  346. """Dumps out module to ``folder`` with ``module_name`` so that it can be
  347. imported with ``from <folder> import <module_name>``
  348. Args:
  349. folder (Union[str, os.PathLike]): The folder to write the code out to
  350. module_name (str): Top-level name to use for the ``Module`` while
  351. writing out the code
  352. """
  353. folder = Path(folder)
  354. Path(folder).mkdir(exist_ok=True)
  355. torch.save(self.state_dict(), folder / 'state_dict.pt')
  356. tab = " " * 4
  357. model_str = f"""
  358. import torch
  359. from torch.nn import *
  360. class {module_name}(torch.nn.Module):
  361. def __init__(self):
  362. super().__init__()
  363. """
  364. def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
  365. safe_reprs = [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]
  366. if type(module) in safe_reprs:
  367. return f"{module.__repr__()}"
  368. else:
  369. return None
  370. blobified_modules = []
  371. for module_name, module in self.named_children():
  372. module_str = _gen_model_repr(module_name, module)
  373. if module_str is None:
  374. module_file = folder / f'{module_name}.pt'
  375. torch.save(module, module_file)
  376. blobified_modules.append(module_name)
  377. module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
  378. module_str = f"torch.load(r'{module_file}') # {module_repr}"
  379. model_str += f"{tab*2}self.{module_name} = {module_str}\n"
  380. for buffer_name, buffer in self._buffers.items():
  381. if buffer is None:
  382. continue
  383. model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n"
  384. for param_name, param in self._parameters.items():
  385. if param is None:
  386. continue
  387. model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n"
  388. model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
  389. model_str += f"{_addindent(self.code, 4)}\n"
  390. module_file = folder / 'module.py'
  391. module_file.write_text(model_str)
  392. init_file = folder / '__init__.py'
  393. init_file.write_text('from .module import *')
  394. if len(blobified_modules) > 0:
  395. warnings.warn("Was not able to save the following children modules as reprs -"
  396. f"saved as pickled files instead: {blobified_modules}")
  397. @compatibility(is_backward_compatible=True)
  398. def add_submodule(self, target: str, m: torch.nn.Module) -> bool:
  399. """
  400. Adds the given submodule to ``self``.
  401. This installs empty Modules where none exist yet if they are
  402. subpaths of ``target``.
  403. Args:
  404. target: The fully-qualified string name of the new submodule
  405. (See example in ``nn.Module.get_submodule`` for how to
  406. specify a fully-qualified string.)
  407. m: The submodule itself; the actual object we want to
  408. install in the current Module
  409. Return:
  410. bool: Whether or not the submodule could be inserted. For
  411. this method to return True, each object in the chain
  412. denoted by ``target`` must either a) not exist yet,
  413. or b) reference an ``nn.Module`` (not a parameter or
  414. other attribute)
  415. """
  416. *prefix, field = target.split('.')
  417. mod: torch.nn.Module = self
  418. for item in prefix:
  419. submod = getattr(mod, item, None)
  420. if submod is None:
  421. submod = torch.nn.Module()
  422. setattr(mod, item, submod)
  423. if not isinstance(submod, torch.nn.Module):
  424. return False
  425. mod = submod
  426. mod.add_module(field, m)
  427. return True
  428. @compatibility(is_backward_compatible=True)
  429. def delete_submodule(self, target: str) -> bool:
  430. """
  431. Deletes the given submodule from ``self``.
  432. The module will not be deleted if ``target`` is not a valid
  433. target.
  434. Args:
  435. target: The fully-qualified string name of the new submodule
  436. (See example in ``nn.Module.get_submodule`` for how to
  437. specify a fully-qualified string.)
  438. Returns:
  439. bool: Whether or not the target string referenced a
  440. submodule we want to delete. A return value of ``False``
  441. means that the ``target`` was not a valid reference to
  442. a submodule.
  443. """
  444. atoms = target.split(".")
  445. path, target_submod = atoms[:-1], atoms[-1]
  446. mod: torch.nn.Module = self
  447. # Get the parent module
  448. for item in path:
  449. if not hasattr(mod, item):
  450. return False
  451. mod = getattr(mod, item)
  452. if not isinstance(mod, torch.nn.Module):
  453. return False
  454. if not hasattr(mod, target_submod):
  455. return False
  456. if not isinstance(getattr(mod, target_submod), torch.nn.Module):
  457. return False
  458. delattr(mod, target_submod)
  459. return True
  460. @compatibility(is_backward_compatible=True)
  461. def delete_all_unused_submodules(self) -> None:
  462. """
  463. Deletes all unused submodules from ``self``.
  464. A Module is considered "used" if any one of the following is
  465. true:
  466. 1. It has children that are used
  467. 2. Its forward is called directly via a ``call_module`` node
  468. 3. It has a non-Module attribute that is used from a
  469. ``get_attr`` node
  470. This method can be called to clean up an ``nn.Module`` without
  471. manually calling ``delete_submodule`` on each unused submodule.
  472. """
  473. used: List[str] = []
  474. for node in self.graph.nodes:
  475. if node.op == "call_module" or node.op == "get_attr":
  476. # A list of strings representing the different parts
  477. # of the path. For exmaple, `foo.bar.baz` gives us
  478. # ["foo", "bar", "baz"]
  479. fullpath = node.target.split(".")
  480. # If we're looking at multiple parts of a path, join
  481. # join them with a dot. Otherwise, return that single
  482. # element without doing anything to it.
  483. def join_fn(x: str, y: str) -> str:
  484. return '.'.join([x, y] if y else [x])
  485. # Progressively collect all the names of intermediate
  486. # modules. For example, if we have the target
  487. # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and
  488. # `foo.bar.baz` to the list.
  489. for path in itertools.accumulate(fullpath, join_fn):
  490. used.append(path)
  491. # For a `call_module` node, also register all recursive submodules
  492. # as used
  493. if node.op == "call_module":
  494. try:
  495. submod = self.get_submodule(node.target)
  496. for submod_name, _ in submod.named_modules():
  497. if submod_name != '':
  498. used.append('.'.join([node.target, submod_name]))
  499. except AttributeError:
  500. # Node referenced nonexistent submodule, don't need to
  501. # worry about GCing anything
  502. pass
  503. to_delete = [name for name, _ in self.named_modules()
  504. if name not in used]
  505. for name in to_delete:
  506. self.delete_submodule(name)
  507. @property
  508. def code(self) -> str:
  509. """
  510. Return the Python code generated from the ``Graph`` underlying this
  511. ``GraphModule``.
  512. """
  513. if not hasattr(self, '_code'):
  514. raise RuntimeError('Code has not been generated! Please report a bug to PyTorch')
  515. return self._code
  516. @compatibility(is_backward_compatible=True)
  517. def recompile(self) -> PythonCode:
  518. """
  519. Recompile this GraphModule from its ``graph`` attribute. This should be
  520. called after editing the contained ``graph``, otherwise the generated
  521. code of this ``GraphModule`` will be out of date.
  522. """
  523. if isinstance(self._graph._codegen, _PyTreeCodeGen):
  524. self._in_spec = self._graph._codegen.pytree_info.in_spec
  525. self._out_spec = self._graph._codegen.pytree_info.out_spec
  526. python_code = self._graph.python_code(root_module='self')
  527. self._code = python_code.src
  528. cls = type(self)
  529. cls.forward = _forward_from_src(self._code, python_code.globals)
  530. # Determine whether this class explicitly defines a __call__ implementation
  531. # to wrap. If it does, save it in order to have wrapped_call invoke it.
  532. # If it does not, wrapped_call can use a dynamic call to super() instead.
  533. # In most cases, super().__call__ should be torch.nn.Module.__call__.
  534. # We do not want to hold a reference to Module.__call__ here; doing so will
  535. # bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
  536. cls_call = cls.__call__ if "__call__" in vars(cls) else None
  537. if '_wrapped_call' not in vars(cls):
  538. cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
  539. def call_wrapped(self, *args, **kwargs):
  540. return self._wrapped_call(self, *args, **kwargs)
  541. cls.__call__ = call_wrapped
  542. return python_code
  543. # Passing Tracer as argument allows subclasses extending fx.GraphModule
  544. # define their own Tracer (extending fx.Tracer).
  545. def __reduce_deploy__(self, importer: Importer):
  546. dict_without_graph = self.__dict__.copy()
  547. dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
  548. del dict_without_graph['_graph']
  549. python_code = self.recompile()
  550. import_block = _format_import_block(python_code.globals, importer)
  551. return (reduce_deploy_graph_module, (dict_without_graph, import_block))
  552. def __reduce_package__(self, exporter: PackageExporter):
  553. dict_without_graph = self.__dict__.copy()
  554. dict_without_graph['_graphmodule_cls_name'] = self.__class__.__name__
  555. del dict_without_graph['_graph']
  556. generated_module_name = f'fx-generated._{exporter.get_unique_id()}'
  557. python_code = self.recompile()
  558. import_block = _format_import_block(python_code.globals, exporter.importer)
  559. module_code = import_block + self.code
  560. exporter.save_source_string(generated_module_name, module_code)
  561. return (reduce_package_graph_module, (dict_without_graph, generated_module_name))
  562. def __reduce__(self):
  563. """
  564. Serialization of GraphModule. We serialize only the generated code, not
  565. the underlying ``Graph``. This is because ``Graph`` does not have on-disk
  566. backward-compatibility guarantees, whereas Python source code does.
  567. On the deserialization side, we symbolically trace through the generated
  568. code to regenerate the underlying ``Graph``
  569. """
  570. dict_without_graph = self.__dict__.copy()
  571. python_code = self.recompile()
  572. import_block = _format_import_block(python_code.globals, sys_importer)
  573. del dict_without_graph['_graph']
  574. return (reduce_graph_module, (dict_without_graph, import_block))
  575. # because __reduce__ is defined for serialization,
  576. # we need to define deepcopy otherwise it will call __reduce__
  577. # and cause symbolic tracing to occur every time we try to copy the object
  578. def __deepcopy__(self, memo):
  579. fake_mod = torch.nn.Module()
  580. fake_mod.__dict__ = copy.deepcopy(self.__dict__)
  581. return GraphModule(fake_mod, fake_mod.__dict__['_graph'])
  582. def __copy__(self):
  583. return GraphModule(self, self.graph)
  584. def __str__(self) -> str:
  585. orig_str = super().__str__()
  586. return '\n'.join([orig_str, self._code])
  587. def _replicate_for_data_parallel(self):
  588. new_gm = self.__copy__()
  589. new_gm._is_replica = True
  590. return new_gm
  591. # workarounds for issues in __torch_function__
  592. # WAR for __torch_function__ not handling tensor lists,
  593. # fix is in https://github.com/pytorch/pytorch/pull/34725
  594. # orig_cat = torch.cat
  595. # def patched_cat(*args, **kwargs):
  596. # tensors = args[0]
  597. # for t in tensors:
  598. # if isinstance(t, Proxy):
  599. # return t.__torch_function__(patched_cat, (), args, kwargs)
  600. # return orig_cat(*args, **kwargs)
  601. # patched_cat.__module__ = 'torch'
  602. # patched_cat.__name__ = 'cat'
  603. # torch.cat = patched_cat