torch_version.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from typing import Any, Iterable
  2. from .version import __version__ as internal_version
  3. class _LazyImport:
  4. """Wraps around classes lazy imported from packaging.version
  5. Output of the function v in following snippets are identical:
  6. from packaging.version import Version
  7. def v():
  8. return Version('1.2.3')
  9. and
  10. Version = _LazyImport('Version')
  11. def v():
  12. return Version('1.2.3')
  13. The difference here is that in later example imports
  14. do not happen until v is called
  15. """
  16. def __init__(self, cls_name: str) -> None:
  17. self._cls_name = cls_name
  18. def get_cls(self):
  19. try:
  20. import packaging.version # type: ignore[import]
  21. except ImportError:
  22. # If packaging isn't installed, try and use the vendored copy
  23. # in pkg_resources
  24. from pkg_resources import packaging # type: ignore[attr-defined]
  25. return getattr(packaging.version, self._cls_name)
  26. def __call__(self, *args, **kwargs):
  27. return self.get_cls()(*args, **kwargs)
  28. def __instancecheck__(self, obj):
  29. return isinstance(obj, self.get_cls())
# Lazy proxies for packaging.version.Version / InvalidVersion: the underlying
# module is imported only when one of them is first called, used in an
# isinstance check, or resolved through get_cls().
Version = _LazyImport("Version")
InvalidVersion = _LazyImport("InvalidVersion")
  32. class TorchVersion(str):
  33. """A string with magic powers to compare to both Version and iterables!
  34. Prior to 1.10.0 torch.__version__ was stored as a str and so many did
  35. comparisons against torch.__version__ as if it were a str. In order to not
  36. break them we have TorchVersion which masquerades as a str while also
  37. having the ability to compare against both packaging.version.Version as
  38. well as tuples of values, eg. (1, 2, 1)
  39. Examples:
  40. Comparing a TorchVersion object to a Version object
  41. TorchVersion('1.10.0a') > Version('1.10.0a')
  42. Comparing a TorchVersion object to a Tuple object
  43. TorchVersion('1.10.0a') > (1, 2) # 1.2
  44. TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
  45. Comparing a TorchVersion object against a string
  46. TorchVersion('1.10.0a') > '1.2'
  47. TorchVersion('1.10.0a') > '1.2.1'
  48. """
  49. # fully qualified type names here to appease mypy
  50. def _convert_to_version(self, inp: Any) -> Any:
  51. if isinstance(inp, Version.get_cls()):
  52. return inp
  53. elif isinstance(inp, str):
  54. return Version(inp)
  55. elif isinstance(inp, Iterable):
  56. # Ideally this should work for most cases by attempting to group
  57. # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
  58. # Examples:
  59. # * (1) -> Version("1")
  60. # * (1, 20) -> Version("1.20")
  61. # * (1, 20, 1) -> Version("1.20.1")
  62. return Version('.'.join((str(item) for item in inp)))
  63. else:
  64. raise InvalidVersion(inp)
  65. def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
  66. try:
  67. return getattr(Version(self), method)(self._convert_to_version(cmp))
  68. except BaseException as e:
  69. if not isinstance(e, InvalidVersion.get_cls()):
  70. raise
  71. # Fall back to regular string comparison if dealing with an invalid
  72. # version like 'parrot'
  73. return getattr(super(), method)(cmp)
  74. for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
  75. setattr(TorchVersion, cmp_method, lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))
  76. __version__ = TorchVersion(internal_version)