| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- import importlib
- import inspect
- import itertools
- import warnings
- from typing import Any, Callable, Dict, Tuple, Union
- from torch import _C
- from torch.onnx import _constants
- _SymbolicFunction = Callable[..., Union[_C.Value, Tuple[_C.Value]]]
- """
- The symbolic registry "_registry" is a dictionary that maps operators
- (for a specific domain and opset version) to their symbolic functions.
- An operator is defined by its domain, opset version, and opname.
- The keys are tuples (domain, version), (where domain is a string, and version is an int),
- and the operator's name (string).
- The map's entries are as follows : _registry[(domain, version)][op_name] = op_symbolic
- """
- _registry: Dict[
- Tuple[str, int],
- Dict[str, _SymbolicFunction],
- ] = {}
- _symbolic_versions: Dict[Union[int, str], Any] = {}
- def _import_symbolic_opsets():
- for opset_version in itertools.chain(
- _constants.onnx_stable_opsets, [_constants.onnx_main_opset]
- ):
- module = importlib.import_module(
- "torch.onnx.symbolic_opset{}".format(opset_version)
- )
- global _symbolic_versions
- _symbolic_versions[opset_version] = module
- def register_version(domain: str, version: int):
- if not is_registered_version(domain, version):
- global _registry
- _registry[(domain, version)] = {}
- register_ops_in_version(domain, version)
- def register_ops_helper(domain: str, version: int, iter_version: int):
- for domain, op_name, op_func in get_ops_in_version(iter_version):
- if not is_registered_op(op_name, domain, version):
- register_op(op_name, op_func, domain, version)
- def register_ops_in_version(domain: str, version: int):
- # iterates through the symbolic functions of
- # the specified opset version, and the previous
- # opset versions for operators supported in
- # previous versions.
- # Opset 9 is the base version. It is selected as the base version because
- # 1. It is the first opset version supported by PyTorch export.
- # 2. opset 9 is more robust than previous opset versions. Opset versions like 7/8 have limitations
- # that certain basic operators cannot be expressed in ONNX. Instead of basing on these limitations,
- # we chose to handle them as special cases separately.
- # Backward support for opset versions beyond opset 7 is not in our roadmap.
- # For opset versions other than 9, by default they will inherit the symbolic functions defined in
- # symbolic_opset9.py.
- # To extend support for updated operators in different opset versions on top of opset 9,
- # simply add the updated symbolic functions in the respective symbolic_opset{version}.py file.
- # Checkout topk in symbolic_opset10.py, and upsample_nearest2d in symbolic_opset8.py for example.
- iter_version = version
- while iter_version != 9:
- register_ops_helper(domain, version, iter_version)
- if iter_version > 9:
- iter_version = iter_version - 1
- else:
- iter_version = iter_version + 1
- register_ops_helper(domain, version, 9)
- def get_ops_in_version(version: int):
- if not _symbolic_versions:
- _import_symbolic_opsets()
- members = inspect.getmembers(_symbolic_versions[version])
- domain_opname_ops = []
- for obj in members:
- if isinstance(obj[1], type) and hasattr(obj[1], "domain"):
- ops = inspect.getmembers(obj[1], predicate=inspect.isfunction)
- for op in ops:
- domain_opname_ops.append((obj[1].domain, op[0], op[1])) # type: ignore[attr-defined]
- elif inspect.isfunction(obj[1]):
- if obj[0] == "_len":
- obj = ("len", obj[1])
- if obj[0] == "_list":
- obj = ("list", obj[1])
- if obj[0] == "_any":
- obj = ("any", obj[1])
- if obj[0] == "_all":
- obj = ("all", obj[1])
- domain_opname_ops.append(("", obj[0], obj[1]))
- return domain_opname_ops
- def is_registered_version(domain: str, version: int):
- global _registry
- return (domain, version) in _registry
- def register_op(opname, op, domain, version):
- if domain is None or version is None:
- warnings.warn(
- "ONNX export failed. The ONNX domain and/or version to register are None."
- )
- global _registry
- if not is_registered_version(domain, version):
- _registry[(domain, version)] = {}
- _registry[(domain, version)][opname] = op
- def is_registered_op(opname: str, domain: str, version: int):
- if domain is None or version is None:
- warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
- global _registry
- return (domain, version) in _registry and opname in _registry[(domain, version)]
- def unregister_op(opname: str, domain: str, version: int):
- global _registry
- if is_registered_op(opname, domain, version):
- del _registry[(domain, version)][opname]
- if not _registry[(domain, version)]:
- del _registry[(domain, version)]
- else:
- warnings.warn("The opname " + opname + " is not registered.")
- def get_op_supported_version(opname: str, domain: str, version: int):
- iter_version = version
- while iter_version <= _constants.onnx_main_opset:
- ops = [(op[0], op[1]) for op in get_ops_in_version(iter_version)]
- if (domain, opname) in ops:
- return iter_version
- iter_version += 1
- return None
- def get_registered_op(opname: str, domain: str, version: int) -> _SymbolicFunction:
- if domain is None or version is None:
- warnings.warn("ONNX export failed. The ONNX domain and/or version are None.")
- global _registry
- if not is_registered_op(opname, domain, version):
- raise UnsupportedOperatorError(domain, opname, version)
- return _registry[(domain, version)][opname]
- class UnsupportedOperatorError(RuntimeError):
- def __init__(self, domain: str, opname: str, version: int):
- supported_version = get_op_supported_version(opname, domain, version)
- if domain in {"", "aten", "prim", "quantized"}:
- msg = f"Exporting the operator {domain}::{opname} to ONNX opset version {version} is not supported. "
- if supported_version is not None:
- msg += (
- f"Support for this operator was added in version {supported_version}, "
- "try exporting with this version."
- )
- else:
- msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
- else:
- msg = (
- f"ONNX export failed on an operator with unrecognized namespace {domain}::{opname}. "
- "If you are trying to export a custom operator, make sure you registered "
- "it with the right domain and version."
- )
- super().__init__(msg)
|