importer.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. import importlib
  2. from abc import ABC, abstractmethod
  3. from pickle import _getattribute, _Pickler # type: ignore[attr-defined]
  4. from pickle import whichmodule as _pickle_whichmodule # type: ignore[attr-defined]
  5. from types import ModuleType
  6. from typing import Any, List, Optional, Tuple, Dict
  7. from ._mangling import demangle, get_mangle_prefix, is_mangled
  8. class ObjNotFoundError(Exception):
  9. """Raised when an importer cannot find an object by searching for its name."""
  10. pass
  11. class ObjMismatchError(Exception):
  12. """Raised when an importer found a different object with the same name as the user-provided one."""
  13. pass
  14. class Importer(ABC):
  15. """Represents an environment to import modules from.
  16. By default, you can figure out what module an object belongs by checking
  17. __module__ and importing the result using __import__ or importlib.import_module.
  18. torch.package introduces module importers other than the default one.
  19. Each PackageImporter introduces a new namespace. Potentially a single
  20. name (e.g. 'foo.bar') is present in multiple namespaces.
  21. It supports two main operations:
  22. import_module: module_name -> module object
  23. get_name: object -> (parent module name, name of obj within module)
  24. The guarantee is that following round-trip will succeed or throw an ObjNotFoundError/ObjMisMatchError.
  25. module_name, obj_name = env.get_name(obj)
  26. module = env.import_module(module_name)
  27. obj2 = getattr(module, obj_name)
  28. assert obj1 is obj2
  29. """
  30. modules: Dict[str, ModuleType]
  31. @abstractmethod
  32. def import_module(self, module_name: str) -> ModuleType:
  33. """Import `module_name` from this environment.
  34. The contract is the same as for importlib.import_module.
  35. """
  36. pass
  37. def get_name(self, obj: Any, name: Optional[str] = None) -> Tuple[str, str]:
  38. """Given an object, return a name that can be used to retrieve the
  39. object from this environment.
  40. Args:
  41. obj: An object to get the the module-environment-relative name for.
  42. name: If set, use this name instead of looking up __name__ or __qualname__ on `obj`.
  43. This is only here to match how Pickler handles __reduce__ functions that return a string,
  44. don't use otherwise.
  45. Returns:
  46. A tuple (parent_module_name, attr_name) that can be used to retrieve `obj` from this environment.
  47. Use it like:
  48. mod = importer.import_module(parent_module_name)
  49. obj = getattr(mod, attr_name)
  50. Raises:
  51. ObjNotFoundError: we couldn't retrieve `obj by name.
  52. ObjMisMatchError: we found a different object with the same name as `obj`.
  53. """
  54. if name is None and obj and _Pickler.dispatch.get(type(obj)) is None:
  55. # Honor the string return variant of __reduce__, which will give us
  56. # a global name to search for in this environment.
  57. # TODO: I guess we should do copyreg too?
  58. reduce = getattr(obj, "__reduce__", None)
  59. if reduce is not None:
  60. try:
  61. rv = reduce()
  62. if isinstance(rv, str):
  63. name = rv
  64. except Exception:
  65. pass
  66. if name is None:
  67. name = getattr(obj, "__qualname__", None)
  68. if name is None:
  69. name = obj.__name__
  70. orig_module_name = self.whichmodule(obj, name)
  71. # Demangle the module name before importing. If this obj came out of a
  72. # PackageImporter, `__module__` will be mangled. See mangling.md for
  73. # details.
  74. module_name = demangle(orig_module_name)
  75. # Check that this name will indeed return the correct object
  76. try:
  77. module = self.import_module(module_name)
  78. obj2, _ = _getattribute(module, name)
  79. except (ImportError, KeyError, AttributeError):
  80. raise ObjNotFoundError(
  81. f"{obj} was not found as {module_name}.{name}"
  82. ) from None
  83. if obj is obj2:
  84. return module_name, name
  85. def get_obj_info(obj):
  86. assert name is not None
  87. module_name = self.whichmodule(obj, name)
  88. is_mangled_ = is_mangled(module_name)
  89. location = (
  90. get_mangle_prefix(module_name)
  91. if is_mangled_
  92. else "the current Python environment"
  93. )
  94. importer_name = (
  95. f"the importer for {get_mangle_prefix(module_name)}"
  96. if is_mangled_
  97. else "'sys_importer'"
  98. )
  99. return module_name, location, importer_name
  100. obj_module_name, obj_location, obj_importer_name = get_obj_info(obj)
  101. obj2_module_name, obj2_location, obj2_importer_name = get_obj_info(obj2)
  102. msg = (
  103. f"\n\nThe object provided is from '{obj_module_name}', "
  104. f"which is coming from {obj_location}."
  105. f"\nHowever, when we import '{obj2_module_name}', it's coming from {obj2_location}."
  106. "\nTo fix this, make sure this 'PackageExporter's importer lists "
  107. f"{obj_importer_name} before {obj2_importer_name}."
  108. )
  109. raise ObjMismatchError(msg)
  110. def whichmodule(self, obj: Any, name: str) -> str:
  111. """Find the module name an object belongs to.
  112. This should be considered internal for end-users, but developers of
  113. an importer can override it to customize the behavior.
  114. Taken from pickle.py, but modified to exclude the search into sys.modules
  115. """
  116. module_name = getattr(obj, "__module__", None)
  117. if module_name is not None:
  118. return module_name
  119. # Protect the iteration by using a list copy of self.modules against dynamic
  120. # modules that trigger imports of other modules upon calls to getattr.
  121. for module_name, module in self.modules.copy().items():
  122. if (
  123. module_name == "__main__"
  124. or module_name == "__mp_main__" # bpo-42406
  125. or module is None
  126. ):
  127. continue
  128. try:
  129. if _getattribute(module, name)[0] is obj:
  130. return module_name
  131. except AttributeError:
  132. pass
  133. return "__main__"
  134. class _SysImporter(Importer):
  135. """An importer that implements the default behavior of Python."""
  136. def import_module(self, module_name: str):
  137. return importlib.import_module(module_name)
  138. def whichmodule(self, obj: Any, name: str) -> str:
  139. return _pickle_whichmodule(obj, name)
  140. sys_importer = _SysImporter()
  141. class OrderedImporter(Importer):
  142. """A compound importer that takes a list of importers and tries them one at a time.
  143. The first importer in the list that returns a result "wins".
  144. """
  145. def __init__(self, *args):
  146. self._importers: List[Importer] = list(args)
  147. def _is_torchpackage_dummy(self, module):
  148. """Returns true iff this module is an empty PackageNode in a torch.package.
  149. If you intern `a.b` but never use `a` in your code, then `a` will be an
  150. empty module with no source. This can break cases where we are trying to
  151. re-package an object after adding a real dependency on `a`, since
  152. OrderedImportere will resolve `a` to the dummy package and stop there.
  153. See: https://github.com/pytorch/pytorch/pull/71520#issuecomment-1029603769
  154. """
  155. if not getattr(module, "__torch_package__", False):
  156. return False
  157. if not hasattr(module, "__path__"):
  158. return False
  159. if not hasattr(module, "__file__"):
  160. return True
  161. return module.__file__ is None
  162. def import_module(self, module_name: str) -> ModuleType:
  163. last_err = None
  164. for importer in self._importers:
  165. if not isinstance(importer, Importer):
  166. raise TypeError(
  167. f"{importer} is not a Importer. "
  168. "All importers in OrderedImporter must inherit from Importer."
  169. )
  170. try:
  171. module = importer.import_module(module_name)
  172. if self._is_torchpackage_dummy(module):
  173. continue
  174. return module
  175. except ModuleNotFoundError as err:
  176. last_err = err
  177. if last_err is not None:
  178. raise last_err
  179. else:
  180. raise ModuleNotFoundError(module_name)
  181. def whichmodule(self, obj: Any, name: str) -> str:
  182. for importer in self._importers:
  183. module_name = importer.whichmodule(obj, name)
  184. if module_name != "__main__":
  185. return module_name
  186. return "__main__"