# extension.py
  1. import ctypes
  2. import os
  3. import sys
  4. from warnings import warn
  5. import torch
  6. from ._internally_replaced_utils import _get_extension_path
  7. _HAS_OPS = False
  8. def _has_ops():
  9. return False
  10. try:
  11. lib_path = _get_extension_path("_C")
  12. torch.ops.load_library(lib_path)
  13. _HAS_OPS = True
  14. def _has_ops(): # noqa: F811
  15. return True
  16. except (ImportError, OSError):
  17. pass
  18. def _assert_has_ops():
  19. if not _has_ops():
  20. raise RuntimeError(
  21. "Couldn't load custom C++ ops. This can happen if your PyTorch and "
  22. "torchvision versions are incompatible, or if you had errors while compiling "
  23. "torchvision from source. For further information on the compatible versions, check "
  24. "https://github.com/pytorch/vision#installation for the compatibility matrix. "
  25. "Please check your PyTorch version with torch.__version__ and your torchvision "
  26. "version with torchvision.__version__ and verify if they are compatible, and if not "
  27. "please reinstall torchvision so that it matches your PyTorch install."
  28. )
  29. def _check_cuda_version():
  30. """
  31. Make sure that CUDA versions match between the pytorch install and torchvision install
  32. """
  33. if not _HAS_OPS:
  34. return -1
  35. import torch
  36. _version = torch.ops.torchvision._cuda_version()
  37. if _version != -1 and torch.version.cuda is not None:
  38. tv_version = str(_version)
  39. if int(tv_version) < 10000:
  40. tv_major = int(tv_version[0])
  41. tv_minor = int(tv_version[2])
  42. else:
  43. tv_major = int(tv_version[0:2])
  44. tv_minor = int(tv_version[3])
  45. t_version = torch.version.cuda
  46. t_version = t_version.split(".")
  47. t_major = int(t_version[0])
  48. t_minor = int(t_version[1])
  49. if t_major != tv_major or t_minor != tv_minor:
  50. raise RuntimeError(
  51. "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
  52. f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
  53. f"CUDA Version={tv_major}.{tv_minor}. "
  54. "Please reinstall the torchvision that matches your PyTorch install."
  55. )
  56. return _version
  57. def _load_library(lib_name):
  58. lib_path = _get_extension_path(lib_name)
  59. # On Windows Python-3.8+ has `os.add_dll_directory` call,
  60. # which is called from _get_extension_path to configure dll search path
  61. # Condition below adds a workaround for older versions by
  62. # explicitly calling `LoadLibraryExW` with the following flags:
  63. # - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
  64. # - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x100)
  65. if os.name == "nt" and sys.version_info < (3, 8):
  66. _kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
  67. if hasattr(_kernel32, "LoadLibraryExW"):
  68. _kernel32.LoadLibraryExW(lib_path, None, 0x00001100)
  69. else:
  70. warn("LoadLibraryExW is missing in kernel32.dll")
  71. torch.ops.load_library(lib_path)
# Run the CUDA compatibility check at import time, so a PyTorch/torchvision CUDA
# major.minor mismatch raises RuntimeError as soon as this module is imported.
# The check is a no-op (returns -1) when the C++ ops failed to load.
_check_cuda_version()