_globals.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. """Globals used internally by the ONNX exporter.
  2. Do not use this module outside of `torch.onnx` and its tests.
  3. Be very judicious when adding any new global variables. Do not create new global
  4. variables unless they are absolutely necessary.
  5. """
  6. from typing import Optional
  7. import torch._C._onnx as _C_onnx
  8. # This module should only depend on _constants and nothing else in torch.onnx to keep
  9. # dependency direction clean.
  10. from torch.onnx import _constants
  11. class _InternalGlobals:
  12. """Globals used internally by ONNX exporter.
  13. NOTE: Be very judicious when adding any new variables. Do not create new
  14. global variables unless they are absolutely necessary.
  15. """
  16. def __init__(self):
  17. self._export_onnx_opset_version = _constants.onnx_default_opset
  18. self.operator_export_type: Optional[_C_onnx.OperatorExportTypes] = None
  19. self.training_mode: Optional[_C_onnx.TrainingMode] = None
  20. self.onnx_shape_inference: bool = False
  21. @property
  22. def export_onnx_opset_version(self):
  23. return self._export_onnx_opset_version
  24. @export_onnx_opset_version.setter
  25. def export_onnx_opset_version(self, value: int):
  26. supported_versions = [_constants.onnx_main_opset]
  27. supported_versions.extend(_constants.onnx_stable_opsets)
  28. if value not in supported_versions:
  29. raise ValueError(f"Unsupported ONNX opset version: {value}")
  30. self._export_onnx_opset_version = value
# Module-level instance of ``_InternalGlobals`` holding the exporter's internal
# state; per the module docstring, only ``torch.onnx`` and its tests should use it.
GLOBALS = _InternalGlobals()