symbolic_registry.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import importlib
  2. import inspect
  3. import itertools
  4. import warnings
  5. from typing import Any, Callable, Dict, Tuple, Union
  6. from torch import _C
  7. from torch.onnx import _constants
  8. _SymbolicFunction = Callable[..., Union[_C.Value, Tuple[_C.Value]]]
  9. """
  10. The symbolic registry "_registry" is a dictionary that maps operators
  11. (for a specific domain and opset version) to their symbolic functions.
  12. An operator is defined by its domain, opset version, and opname.
  13. The keys are tuples (domain, version), (where domain is a string, and version is an int),
  14. and the operator's name (string).
  15. The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
  16. """
  17. _registry: Dict[
  18. Tuple[str, int],
  19. Dict[str, _SymbolicFunction],
  20. ] = {}
  21. _symbolic_versions: Dict[Union[int, str], Any] = {}
  22. def _import_symbolic_opsets():
  23. for opset_version in itertools.chain(
  24. _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
  25. ):
  26. module = importlib.import_module(
  27. "torch.onnx.symbolic_opset{}".format(opset_version)
  28. )
  29. global _symbolic_versions
  30. _symbolic_versions[opset_version] = module
  31. def register_version(domain: str, version: int):
  32. if not is_registered_version(domain, version):
  33. global _registry
  34. _registry[(domain, version)] = {}
  35. register_ops_in_version(domain, version)
  36. def register_ops_helper(domain: str, version: int, iter_version: int):
  37. for domain, op_name, op_func in get_ops_in_version(iter_version):
  38. if not is_registered_op(op_name, domain, version):
  39. register_op(op_name, op_func, domain, version)
  40. def register_ops_in_version(domain: str, version: int):
  41. # iterates through the symbolic functions of
  42. # the specified opset version, and the previous
  43. # opset versions for operators supported in
  44. # previous versions.
  45. # Opset 9 is the base version. It is selected as the base version because
  46. # 1. It is the first opset version supported by PyTorch export.
  47. # 2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
  48. # that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
  49. # we chose to handle them as special cases separately.
  50. # Backward support for opset versions beyond opset 7 is not in our roadmap.
  51. # For opset versions other than 9, by default they will inherit the symbolic functions defined in
  52. # symbolic_opset9.py.
  53. # To extend support for updated operators in different opset versions on top of opset 9,
  54. # simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
  55. # Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
  56. iter_version = version
  57. while iter_version != 9:
  58. register_ops_helper(domain, version, iter_version)
  59. if iter_version > 9:
  60. iter_version = iter_version - 1
  61. else:
  62. iter_version = iter_version + 1
  63. register_ops_helper(domain, version, 9)
  64. def get_ops_in_version(version: int):
  65. if not _symbolic_versions:
  66. _import_symbolic_opsets()
  67. members = inspect.getmembers(_symbolic_versions[version])
  68. domain_opname_ops = []
  69. for obj in members:
  70. if isinstance(obj[1], type) and hasattr(obj[1], "domain"):
  71. ops = inspect.getmembers(obj[1], predicate=inspect.isfunction)
  72. for op in ops:
  73. domain_opname_ops.append((obj[1].domain, op[0], op[1])) # type: ignore[attr-defined]
  74. elif inspect.isfunction(obj[1]):
  75. if obj[0] == "_len":
  76. obj = ("len", obj[1])
  77. if obj[0] == "_list":
  78. obj = ("list", obj[1])
  79. if obj[0] == "_any":
  80. obj = ("any", obj[1])
  81. if obj[0] == "_all":
  82. obj = ("all", obj[1])
  83. domain_opname_ops.append(("", obj[0], obj[1]))
  84. return domain_opname_ops
  85. def is_registered_version(domain: str, version: int):
  86. global _registry
  87. return (domain, version) in _registry
  88. def register_op(opname, op, domain, version):
  89. if domain is None or version is None:
  90. warnings.warn(
  91. "ONNX export failed. The ONNX domain and/or version to register are None."
  92. )
  93. global _registry
  94. if not is_registered_version(domain, version):
  95. _registry[(domain, version)] = {}
  96. _registry[(domain, version)][opname] = op
  97. def is_registered_op(opname: str, domain: str, version: int):
  98. if domain is None or version is None:
  99. warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
  100. global _registry
  101. return (domain, version) in _registry and opname in _registry[(domain, version)]
  102. def unregister_op(opname: str, domain: str, version: int):
  103. global _registry
  104. if is_registered_op(opname, domain, version):
  105. del _registry[(domain, version)][opname]
  106. if not _registry[(domain, version)]:
  107. del _registry[(domain, version)]
  108. else:
  109. warnings.warn("The opname " + opname + " is not registered.")
  110. def get_op_supported_version(opname: str, domain: str, version: int):
  111. iter_version = version
  112. while iter_version <= _constants.onnx_main_opset:
  113. ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
  114. if (domain, opname) in ops:
  115. return iter_version
  116. iter_version += 1
  117. return None
  118. def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
  119. if domain is None or version is None:
  120. warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
  121. global _registry
  122. if not is_registered_op(opname, domain, version):
  123. raise UnsupportedOperatorError(domain, opname, version)
  124. return _registry[(domain, version)][opname]
  125. class UnsupportedOperatorError(RuntimeError):
  126. def __init__(self, domain: str, opname: str, version: int):
  127. supported_version = get_op_supported_version(opname, domain, version)
  128. if domain in {"", "aten", "prim", "quantized"}:
  129. msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
  130. if supported_version is not None:
  131. msg += (
  132. f"Support for this operator was added in version {supported_version}, "
  133. "try exporting with this version."
  134. )
  135. else:
  136. msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
  137. else:
  138. msg = (
  139. f"ONNX export failed on an operator with unrecognized namespace {domain}::{opname}. "
  140. "If you are trying to export a custom operator, make sure you registered "
  141. "it with the right domain and version."
  142. )
  143. super().__init__(msg)