__init__.py 4.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import os
  2. import sys
  3. import warnings
  4. try:
  5. from caffe2.proto import caffe2_pb2
  6. except ImportError:
  7. warnings.warn('Caffe2 support is not enabled in this PyTorch build. '
  8. 'Please enable Caffe2 by building PyTorch from source with `BUILD_CAFFE2=1` flag.')
  9. raise
  10. # TODO: refactor & remove the following alias
  11. caffe2_pb2.CPU = caffe2_pb2.PROTO_CPU
  12. caffe2_pb2.CUDA = caffe2_pb2.PROTO_CUDA
  13. caffe2_pb2.MKLDNN = caffe2_pb2.PROTO_MKLDNN
  14. caffe2_pb2.OPENGL = caffe2_pb2.PROTO_OPENGL
  15. caffe2_pb2.OPENCL = caffe2_pb2.PROTO_OPENCL
  16. caffe2_pb2.IDEEP = caffe2_pb2.PROTO_IDEEP
  17. caffe2_pb2.HIP = caffe2_pb2.PROTO_HIP
  18. caffe2_pb2.COMPILE_TIME_MAX_DEVICE_TYPES = caffe2_pb2.PROTO_COMPILE_TIME_MAX_DEVICE_TYPES
  19. if sys.platform == "win32":
  20. is_conda = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
  21. py_dll_path = os.path.join(os.path.dirname(sys.executable), 'Library', 'bin')
  22. th_root = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'torch')
  23. th_dll_path = os.path.join(th_root, 'lib')
  24. if not os.path.exists(os.path.join(th_dll_path, 'nvToolsExt64_1.dll')) and \
  25. not os.path.exists(os.path.join(py_dll_path, 'nvToolsExt64_1.dll')):
  26. nvtoolsext_dll_path = os.path.join(
  27. os.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt'), 'bin', 'x64')
  28. else:
  29. nvtoolsext_dll_path = ''
  30. import importlib.util
  31. import glob
  32. spec = importlib.util.spec_from_file_location('torch_version', os.path.join(th_root, 'version.py'))
  33. torch_version = importlib.util.module_from_spec(spec)
  34. spec.loader.exec_module(torch_version)
  35. if torch_version.cuda and len(glob.glob(os.path.join(th_dll_path, 'cudart64*.dll'))) == 0 and \
  36. len(glob.glob(os.path.join(py_dll_path, 'cudart64*.dll'))) == 0:
  37. cuda_version = torch_version.cuda
  38. cuda_version_1 = cuda_version.replace('.', '_')
  39. cuda_path_var = 'CUDA_PATH_V' + cuda_version_1
  40. default_path = 'C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v' + cuda_version
  41. cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin')
  42. else:
  43. cuda_path = ''
  44. import ctypes
  45. kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
  46. dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, nvtoolsext_dll_path, cuda_path]))
  47. with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
  48. prev_error_mode = kernel32.SetErrorMode(0x0001)
  49. kernel32.LoadLibraryW.restype = ctypes.c_void_p
  50. if with_load_library_flags:
  51. kernel32.AddDllDirectory.restype = ctypes.c_void_p
  52. kernel32.LoadLibraryExW.restype = ctypes.c_void_p
  53. for dll_path in dll_paths:
  54. if sys.version_info >= (3, 8):
  55. os.add_dll_directory(dll_path)
  56. elif with_load_library_flags:
  57. res = kernel32.AddDllDirectory(dll_path)
  58. if res is None:
  59. err = ctypes.WinError(ctypes.get_last_error())
  60. err.strerror += ' Error adding "{}" to the DLL directories.'.format(dll_path)
  61. raise err
  62. dlls = glob.glob(os.path.join(th_dll_path, '*.dll'))
  63. path_patched = False
  64. for dll in dlls:
  65. is_loaded = False
  66. if with_load_library_flags:
  67. res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
  68. last_error = ctypes.get_last_error()
  69. if res is None and last_error != 126:
  70. err = ctypes.WinError(last_error)
  71. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(dll)
  72. raise err
  73. elif res is not None:
  74. is_loaded = True
  75. if not is_loaded:
  76. if not path_patched:
  77. os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']])
  78. path_patched = True
  79. res = kernel32.LoadLibraryW(dll)
  80. if res is None:
  81. err = ctypes.WinError(ctypes.get_last_error())
  82. err.strerror += ' Error loading "{}" or one of its dependencies.'.format(dll)
  83. raise err
  84. kernel32.SetErrorMode(prev_error_mode)