__init__.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
  1. r"""
  2. The torch package contains data structures for multi-dimensional
  3. tensors and defines mathematical operations over these tensors.
  4. Additionally, it provides many utilities for efficient serializing of
  5. Tensors and arbitrary types, and other useful utilities.
  6. It has a CUDA counterpart, that enables you to run your tensor computations
  7. on an NVIDIA GPU with compute capability >= 3.0.
  8. """
  9. import os
  10. import sys
  11. import platform
  12. import textwrap
  13. import ctypes
  14. import warnings
  15. import inspect
  16. if sys.version_info < (3,):
  17. raise Exception("Python 2 has reached end-of-life and is no longer supported by PyTorch.")
  18. from ._utils import _import_dotted_name, classproperty
  19. from ._utils_internal import get_file_path, prepare_multiprocessing_environment, \
  20. USE_RTLD_GLOBAL_WITH_LIBTORCH, USE_GLOBAL_DEPS
  21. # TODO(torch_deploy) figure out how to freeze version.py in fbcode build
  22. if sys.executable == 'torch_deploy':
  23. __version__ = "torch-deploy-1.8"
  24. else:
  25. from .torch_version import __version__ as __version__
  26. from ._six import string_classes as _string_classes
  27. from typing import Set, Type, TYPE_CHECKING, Union, Callable
  28. import builtins
  29. __all__ = [
  30. 'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
  31. 'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
  32. 'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
  33. 'no_grad', 'enable_grad', 'rand', 'randn', 'inference_mode',
  34. 'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
  35. 'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
  36. '_TypedStorage',
  37. 'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
  38. 'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
  39. 'lobpcg', 'use_deterministic_algorithms',
  40. 'are_deterministic_algorithms_enabled',
  41. 'is_deterministic_algorithms_warn_only_enabled',
  42. 'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
  43. 'set_float32_matmul_precision', 'get_float32_matmul_precision',
  44. 'set_warn_always', 'is_warn_always_enabled',
  45. ]
  46. ################################################################################
  47. # Load the extension module
  48. ################################################################################
  49. if sys.platform == 'win32':
  50. pfiles_path = os.getenv('ProgramFiles', 'C:\\Program Files')
  51. py_dll_path = os.path.join(sys.exec_prefix, 'Library', 'bin')
  52. th_dll_path = os.path.join(os.path.dirname(__file__), 'lib')
  53. # When users create a virtualenv that inherits the base environment,
  54. # we will need to add the corresponding library directory into
  55. # DLL search directories. Otherwise, it will rely on `PATH` which
  56. # is dependent on user settings.
  57. if sys.exec_prefix != sys.base_exec_prefix:
  58. base_py_dll_path = os.path.join(sys.base_exec_prefix, 'Library', 'bin')
  59. else:
  60. base_py_dll_path = ''
  61. dll_paths = list(filter(os.path.exists, [th_dll_path, py_dll_path, base_py_dll_path]))
  62. if all([not os.path.exists(os.path.join(p, 'nvToolsExt64_1.dll')) for p in dll_paths]):
  63. nvtoolsext_dll_path = os.path.join(
  64. os.getenv('NVTOOLSEXT_PATH', os.path.join(pfiles_path, 'NVIDIA Corporation', 'NvToolsExt')), 'bin', 'x64')
  65. else:
  66. nvtoolsext_dll_path = ''
  67. from .version import cuda as cuda_version
  68. import glob
  69. if cuda_version and all([not glob.glob(os.path.join(p, 'cudart64*.dll')) for p in dll_paths]):
  70. cuda_version_1 = cuda_version.replace('.', '_')
  71. cuda_path_var = 'CUDA_PATH_V' + cuda_version_1
  72. default_path = os.path.join(pfiles_path, 'NVIDIA GPU Computing Toolkit', 'CUDA', 'v' + cuda_version)
  73. cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), 'bin')
  74. else:
  75. cuda_path = ''
  76. dll_paths.extend(filter(os.path.exists, [nvtoolsext_dll_path, cuda_path]))
  77. kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True)
  78. with_load_library_flags = hasattr(kernel32, 'AddDllDirectory')
  79. prev_error_mode = kernel32.SetErrorMode(0x0001)
  80. kernel32.LoadLibraryW.restype = ctypes.c_void_p
  81. if with_load_library_flags:
  82. kernel32.AddDllDirectory.restype = ctypes.c_void_p
  83. kernel32.LoadLibraryExW.restype = ctypes.c_void_p
  84. for dll_path in dll_paths:
  85. if sys.version_info >= (3, 8):
  86. os.add_dll_directory(dll_path)
  87. elif with_load_library_flags:
  88. res = kernel32.AddDllDirectory(dll_path)
  89. if res is None:
  90. err = ctypes.WinError(ctypes.get_last_error())
  91. err.strerror += f' Error adding "{dll_path}" to the DLL directories.'
  92. raise err
  93. try:
  94. ctypes.CDLL('vcruntime140.dll')
  95. ctypes.CDLL('msvcp140.dll')
  96. ctypes.CDLL('vcruntime140_1.dll')
  97. except OSError:
  98. print('''Microsoft Visual C++ Redistributable is not installed, this may lead to the DLL load failure.
  99. It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe''')
  100. dlls = glob.glob(os.path.join(th_dll_path, '*.dll'))
  101. path_patched = False
  102. for dll in dlls:
  103. is_loaded = False
  104. if with_load_library_flags:
  105. res = kernel32.LoadLibraryExW(dll, None, 0x00001100)
  106. last_error = ctypes.get_last_error()
  107. if res is None and last_error != 126:
  108. err = ctypes.WinError(last_error)
  109. err.strerror += f' Error loading "{dll}" or one of its dependencies.'
  110. raise err
  111. elif res is not None:
  112. is_loaded = True
  113. if not is_loaded:
  114. if not path_patched:
  115. os.environ['PATH'] = ';'.join(dll_paths + [os.environ['PATH']])
  116. path_patched = True
  117. res = kernel32.LoadLibraryW(dll)
  118. if res is None:
  119. err = ctypes.WinError(ctypes.get_last_error())
  120. err.strerror += f' Error loading "{dll}" or one of its dependencies.'
  121. raise err
  122. kernel32.SetErrorMode(prev_error_mode)
  123. # See Note [Global dependencies]
  124. def _load_global_deps():
  125. if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
  126. return
  127. lib_name = 'libtorch_global_deps' + ('.dylib' if platform.system() == 'Darwin' else '.so')
  128. here = os.path.abspath(__file__)
  129. lib_path = os.path.join(os.path.dirname(here), 'lib', lib_name)
  130. ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
  131. if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv('TORCH_USE_RTLD_GLOBAL')) and \
  132. platform.system() != 'Windows':
  133. # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a
  134. # few circumstances:
  135. #
  136. # 1. You're in a build environment (e.g., fbcode) where
  137. # libtorch_global_deps is not available, but you still need
  138. # to get mkl to link in with RTLD_GLOBAL or it will just
  139. # not work.
  140. #
  141. # 2. You're trying to run PyTorch under UBSAN and you need
  142. # to ensure that only one copy of libtorch is loaded, so
  143. # vptr checks work properly
  144. #
  145. # If you're using this setting, you must verify that all the libraries
  146. # you load consistently use the same libstdc++, or you may have
  147. # mysterious segfaults.
  148. #
  149. import os as _dl_flags
  150. if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'):
  151. try:
  152. # next try if DLFCN exists
  153. import DLFCN as _dl_flags # type: ignore[import, no-redef]
  154. except ImportError:
  155. # as a last attempt, use compile-time constants
  156. import torch._dl as _dl_flags # type: ignore[import, no-redef]
  157. old_flags = sys.getdlopenflags()
  158. sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY)
  159. from torch._C import * # noqa: F403
  160. sys.setdlopenflags(old_flags)
  161. del old_flags
  162. del _dl_flags
  163. else:
  164. # Easy way. You want this most of the time, because it will prevent
  165. # C++ symbols from libtorch clobbering C++ symbols from other
  166. # libraries, leading to mysterious segfaults.
  167. #
  168. # If building in an environment where libtorch_global_deps isn't available
  169. # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will
  170. # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False
  171. #
  172. # See Note [Global dependencies]
  173. if USE_GLOBAL_DEPS:
  174. _load_global_deps()
  175. from torch._C import * # noqa: F403
# Appease the type checker; ordinarily this binding is inserted by the
# torch._C module initialization code in C
if TYPE_CHECKING:
    # Only seen by static type checkers (TYPE_CHECKING is False at runtime);
    # gives them an explicit module-level `_C` name for the code below.
    import torch._C as _C
  180. # Check to see if we can load C extensions, and if not provide some guidance
  181. # on what the problem might be.
  182. try:
  183. # _initExtension is chosen (arbitrarily) as a sentinel.
  184. from torch._C import _initExtension
  185. except ImportError:
  186. import torch._C as _C_for_compiled_check
  187. # The __file__ check only works for Python 3.7 and above.
  188. if sys.version_info >= (3, 7) and _C_for_compiled_check.__file__ is None:
  189. raise ImportError(textwrap.dedent('''
  190. Failed to load PyTorch C extensions:
  191. It appears that PyTorch has loaded the `torch/_C` folder
  192. of the PyTorch repository rather than the C extensions which
  193. are expected in the `torch._C` namespace. This can occur when
  194. using the `install` workflow. e.g.
  195. $ python setup.py install && python -c "import torch"
  196. This error can generally be solved using the `develop` workflow
  197. $ python setup.py develop && python -c "import torch" # This should succeed
  198. or by running Python from a different directory.
  199. ''').strip()) from None
  200. raise # If __file__ is not None the cause is unknown, so just re-raise.
  201. for name in dir(_C):
  202. if name[0] != '_' and not name.endswith('Base'):
  203. __all__.append(name)
  204. obj = getattr(_C, name)
  205. if (isinstance(obj, Callable) or inspect.isclass(obj)): # type: ignore[arg-type]
  206. if (obj.__module__ != 'torch'):
  207. # TODO: fix their module from C++ side
  208. if name not in ['DisableTorchFunction', 'Generator']:
  209. obj.__module__ = 'torch'
  210. if not TYPE_CHECKING:
  211. # issue 38137 and python issue 43367. Submodules of a C extension are
  212. # non-standard, and attributes of those submodules cannot be pickled since
  213. # pickle expect to be able to import them as "from _C.sub import attr"
  214. # which fails with "_C is not a package
  215. for attr in dir(_C):
  216. candidate = getattr(_C, attr)
  217. if type(candidate) is type(_C):
  218. # submodule
  219. if f'torch._C.{attr}' not in sys.modules:
  220. sys.modules[f'torch._C.{attr}'] = candidate
  221. ################################################################################
  222. # Define basic utilities
  223. ################################################################################
  224. def typename(o):
  225. if isinstance(o, torch.Tensor):
  226. return o.type()
  227. module = ''
  228. class_name = ''
  229. if hasattr(o, '__module__') and o.__module__ != 'builtins' \
  230. and o.__module__ != '__builtin__' and o.__module__ is not None:
  231. module = o.__module__ + '.'
  232. if hasattr(o, '__qualname__'):
  233. class_name = o.__qualname__
  234. elif hasattr(o, '__name__'):
  235. class_name = o.__name__
  236. else:
  237. class_name = o.__class__.__name__
  238. return module + class_name
  239. def is_tensor(obj):
  240. r"""Returns True if `obj` is a PyTorch tensor.
  241. Note that this function is simply doing ``isinstance(obj, Tensor)``.
  242. Using that ``isinstance`` check is better for typechecking with mypy,
  243. and more explicit - so it's recommended to use that instead of
  244. ``is_tensor``.
  245. Args:
  246. obj (Object): Object to test
  247. Example::
  248. >>> x=torch.tensor([1,2,3])
  249. >>> torch.is_tensor(x)
  250. True
  251. """
  252. return isinstance(obj, torch.Tensor)
  253. def is_storage(obj):
  254. r"""Returns True if `obj` is a PyTorch storage object.
  255. Args:
  256. obj (Object): Object to test
  257. """
  258. return type(obj) in _storage_classes
  259. def set_default_tensor_type(t):
  260. r"""Sets the default ``torch.Tensor`` type to floating point tensor type
  261. ``t``. This type will also be used as default floating point type for
  262. type inference in :func:`torch.tensor`.
  263. The default floating point tensor type is initially ``torch.FloatTensor``.
  264. Args:
  265. t (type or string): the floating point tensor type or its name
  266. Example::
  267. >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
  268. torch.float32
  269. >>> torch.set_default_tensor_type(torch.DoubleTensor)
  270. >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
  271. torch.float64
  272. """
  273. if isinstance(t, _string_classes):
  274. t = _import_dotted_name(t)
  275. _C._set_default_tensor_type(t)
  276. def set_default_dtype(d):
  277. r"""
  278. Sets the default floating point dtype to :attr:`d`. Supports torch.float32
  279. and torch.float64 as inputs. Other dtypes may be accepted without complaint
  280. but are not supported and are unlikely to work as expected.
  281. When PyTorch is initialized its default floating point dtype is torch.float32,
  282. and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like
  283. type inference. The default floating point dtype is used to:
  284. 1. Implicitly determine the default complex dtype. When the default floating point
  285. type is float32 the default complex dtype is complex64, and when the default
  286. floating point type is float64 the default complex type is complex128.
  287. 2. Infer the dtype for tensors constructed using Python floats or complex Python
  288. numbers. See examples below.
  289. 3. Determine the result of type promotion between bool and integer tensors and
  290. Python floats and complex Python numbers.
  291. Args:
  292. d (:class:`torch.dtype`): the floating point dtype to make the default.
  293. Either torch.float32 or torch.float64.
  294. Example:
  295. >>> # initial default for floating point is torch.float32
  296. >>> # Python floats are interpreted as float32
  297. >>> torch.tensor([1.2, 3]).dtype
  298. torch.float32
  299. >>> # initial default for floating point is torch.complex64
  300. >>> # Complex Python numbers are interpreted as complex64
  301. >>> torch.tensor([1.2, 3j]).dtype
  302. torch.complex64
  303. >>> torch.set_default_dtype(torch.float64)
  304. >>> # Python floats are now interpreted as float64
  305. >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor
  306. torch.float64
  307. >>> # Complex Python numbers are now interpreted as complex128
  308. >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor
  309. torch.complex128
  310. """
  311. _C._set_default_dtype(d)
  312. def use_deterministic_algorithms(mode, *, warn_only=False):
  313. r""" Sets whether PyTorch operations must use "deterministic"
  314. algorithms. That is, algorithms which, given the same input, and when
  315. run on the same software and hardware, always produce the same output.
  316. When enabled, operations will use deterministic algorithms when available,
  317. and if only nondeterministic algorithms are available they will throw a
  318. :class:`RuntimeError` when called.
  319. .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative
  320. interface for this feature.
  321. The following normally-nondeterministic operations will act
  322. deterministically when ``mode=True``:
  323. * :class:`torch.nn.Conv1d` when called on CUDA tensor
  324. * :class:`torch.nn.Conv2d` when called on CUDA tensor
  325. * :class:`torch.nn.Conv3d` when called on CUDA tensor
  326. * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor
  327. * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor
  328. * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor
  329. * :func:`torch.bmm` when called on sparse-dense CUDA tensors
  330. * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor
  331. and the index is a list of tensors
  332. * :func:`torch.Tensor.index_put` with ``accumulate=False``
  333. * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU
  334. tensor
  335. * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU
  336. tensor
  337. * :func:`torch.Tensor.scatter_add_` when ``input`` dimension is one and called
  338. on a CUDA tensor
  339. * :func:`torch.gather` when ``input`` dimension is one and called
  340. on a CUDA tensor that requires grad
  341. * :func:`torch.index_add` when called on CUDA tensor
  342. * :func:`torch.index_select` when attempting to differentiate a CUDA tensor
  343. * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor
  344. * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor
  345. The following normally-nondeterministic operations will throw a
  346. :class:`RuntimeError` when ``mode=True``:
  347. * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor
  348. * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor
  349. * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor
  350. * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor
  351. * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor
  352. * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor
  353. * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor
  354. * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor
  355. and one of the following modes is used:
  356. - ``linear``
  357. - ``bilinear``
  358. - ``bicubic``
  359. - ``trilinear``
  360. * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor
  361. * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor
  362. * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor
  363. * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor
  364. * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor
  365. * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor
  366. * :class:`torch.nn.NLLLoss` when called on a CUDA tensor
  367. * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor
  368. * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when
  369. ``mode='max'``
  370. * :func:`torch.Tensor.scatter_add_` when ``input`` dimension is larger than one
  371. and called on a CUDA tensor
  372. * :func:`torch.gather` when ``input`` dimension is larger than one
  373. and called on a CUDA tensor that requires grad
  374. * :func:`torch.Tensor.put_` when ``accumulate=False``
  375. * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor
  376. * :func:`torch.histc` when called on a CUDA tensor
  377. * :func:`torch.bincount` when called on a CUDA tensor
  378. * :func:`torch.kthvalue` with called on a CUDA tensor
  379. * :func:`torch.median` with indices output when called on a CUDA tensor
  380. * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor
  381. A handful of CUDA operations are nondeterministic if the CUDA version is
  382. 10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8``
  383. or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more
  384. details: `<https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility>`_
  385. If one of these environment variable configurations is not set, a :class:`RuntimeError`
  386. will be raised from these operations when called with CUDA tensors:
  387. * :func:`torch.mm`
  388. * :func:`torch.mv`
  389. * :func:`torch.bmm`
  390. Note that deterministic operations tend to have worse performance than
  391. nondeterministic operations.
  392. .. note::
  393. This flag does not detect or prevent nondeterministic behavior caused
  394. by calling an inplace operation on a tensor with an internal memory
  395. overlap or by giving such a tensor as the :attr:`out` argument for an
  396. operation. In these cases, multiple writes of different data may target
  397. a single memory location, and the order of writes is not guaranteed.
  398. Args:
  399. mode (:class:`bool`): If True, makes potentially nondeterministic
  400. operations switch to a deterministic algorithm or throw a runtime
  401. error. If False, allows nondeterministic operations.
  402. Keyword args:
  403. warn_only (:class:`bool`, optional): If True, operations that do not
  404. have a deterministic implementation will throw a warning instead of
  405. an error. Default: ``False``
  406. Example::
  407. >>> torch.use_deterministic_algorithms(True)
  408. # Forward mode nondeterministic error
  409. >>> torch.randn(10).index_copy(0, torch.tensor([0]), torch.randn(1))
  410. ...
  411. RuntimeError: index_copy does not have a deterministic implementation...
  412. # Backward mode nondeterministic error
  413. >>> torch.randn(10, requires_grad=True, device='cuda').index_select(0, torch.tensor([0], device='cuda')).backward()
  414. ...
  415. RuntimeError: index_add_cuda_ does not have a deterministic implementation...
  416. """
  417. _C._set_deterministic_algorithms(mode, warn_only=warn_only)
  418. def are_deterministic_algorithms_enabled():
  419. r"""Returns True if the global deterministic flag is turned on. Refer to
  420. :func:`torch.use_deterministic_algorithms` documentation for more details.
  421. """
  422. return _C._get_deterministic_algorithms()
  423. def is_deterministic_algorithms_warn_only_enabled():
  424. r"""Returns True if the global deterministic flag is set to warn only.
  425. Refer to :func:`torch.use_deterministic_algorithms` documentation for more
  426. details.
  427. """
  428. return _C._get_deterministic_algorithms_warn_only()
  429. def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None:
  430. r"""Sets the debug mode for deterministic operations.
  431. .. note:: This is an alternative interface for
  432. :func:`torch.use_deterministic_algorithms`. Refer to that function's
  433. documentation for details about affected operations.
  434. Args:
  435. debug_mode(str or int): If "default" or 0, don't error or warn on
  436. nondeterministic operations. If "warn" or 1, warn on
  437. nondeterministic operations. If "error" or 2, error on
  438. nondeterministic operations.
  439. """
  440. # NOTE: builtins.int is used here because int in this scope resolves
  441. # to torch.int
  442. if not isinstance(debug_mode, (builtins.int, str)):
  443. raise TypeError(f'debug_mode must be str or int, but got {type(debug_mode)}')
  444. if isinstance(debug_mode, str):
  445. if debug_mode == 'default':
  446. debug_mode = 0
  447. elif debug_mode == 'warn':
  448. debug_mode = 1
  449. elif debug_mode == 'error':
  450. debug_mode = 2
  451. else:
  452. raise RuntimeError(
  453. 'invalid value of debug_mode, expected one of `default`, '
  454. f'`warn`, `error`, but got {debug_mode}')
  455. if debug_mode == 0:
  456. _C._set_deterministic_algorithms(False)
  457. elif debug_mode == 1:
  458. _C._set_deterministic_algorithms(True, warn_only=True)
  459. elif debug_mode == 2:
  460. _C._set_deterministic_algorithms(True)
  461. else:
  462. raise RuntimeError(
  463. 'invalid value of debug_mode, expected 0, 1, or 2, '
  464. f'but got {debug_mode}')
  465. def get_deterministic_debug_mode() -> builtins.int:
  466. r"""Returns the current value of the debug mode for deterministic
  467. operations. Refer to :func:`torch.set_deterministic_debug_mode`
  468. documentation for more details.
  469. """
  470. if _C._get_deterministic_algorithms():
  471. if _C._get_deterministic_algorithms_warn_only():
  472. return 1
  473. else:
  474. return 2
  475. else:
  476. return 0
  477. def get_float32_matmul_precision() -> builtins.str:
  478. r"""Returns the current value of float32 matrix multiplication precision. Refer to
  479. :func:`torch.set_float32_matmul_precision` documentation for more details.
  480. """
  481. return _C._get_float32_matmul_precision()
  482. def set_float32_matmul_precision(precision):
  483. r"""Sets the internal precision of float32 matrix multiplications.
  484. Running float32 matrix multiplications in lower precision may significantly increase
  485. performance, and in some programs the loss of precision has a negligible impact.
  486. Supports three settings:
  487. * "highest", float32 matrix multiplications use the float32 datatype for
  488. internal computations.
  489. * "high", float32 matrix multiplications use the TensorFloat32 or bfloat16_3x
  490. datatypes for internal computations, if fast matrix multiplication algorithms
  491. using those datatypes internally are available. Otherwise float32
  492. matrix multiplications are computed as if the precision is "highest".
  493. * "medium", float32 matrix multiplications use the bfloat16 datatype for
  494. internal computations, if a fast matrix multiplication algorithm
  495. using that datatype internally is available. Otherwise float32
  496. matrix multiplications are computed as if the precision is "high".
  497. .. note::
  498. This does not change the output dtype of float32 matrix multiplications,
  499. it controls how the internal computation of the matrix multiplication is performed.
  500. .. note::
  501. This does not change the precision of convolution operations. Other flags,
  502. like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution
  503. operations.
  504. .. note::
  505. This flag currently only affects one native device type: CUDA.
  506. If "high" or "medium" are set then the TensorFloat32 datatype will be used
  507. when computing float32 matrix multiplications, equivalent to setting
  508. `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default)
  509. is set then the float32 datatype is used for internal computations, equivalent
  510. to setting `torch.backends.cuda.matmul.allow_tf32 = False`.
  511. Args:
  512. precision(str): can be set to "highest" (default), "high", or "medium" (see above).
  513. """
  514. _C._set_float32_matmul_precision(precision)
  515. def set_warn_always(b):
  516. r"""When this flag is False (default) then some PyTorch warnings may only
  517. appear once per process. This helps avoid excessive warning information.
  518. Setting it to True causes these warnings to always appear, which may be
  519. helpful when debugging.
  520. Args:
  521. b (:class:`bool`): If True, force warnings to always be emitted
  522. If False, set to the default behaviour
  523. """
  524. _C._set_warnAlways(b)
  525. def is_warn_always_enabled():
  526. r"""Returns True if the global warn_always flag is turned on. Refer to
  527. :func:`torch.set_warn_always` documentation for more details.
  528. """
  529. return _C._get_warnAlways()
  530. ################################################################################
  531. # Define numeric constants
  532. ################################################################################
  533. # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and
  534. # NumPy consistency (https://numpy.org/devdocs/reference/constants.html)
  535. from math import e , nan , inf , pi
  536. __all__.extend(['e', 'pi', 'nan', 'inf'])
  537. ################################################################################
  538. # Define Storage and Tensor classes
  539. ################################################################################
  540. from ._tensor import Tensor
  541. from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage
  542. # NOTE: New <type>Storage classes should never be added. When adding a new
  543. # dtype, use torch.storage._TypedStorage directly.
  544. class ByteStorage(_LegacyStorage):
  545. @classproperty
  546. def dtype(self):
  547. return torch.uint8
  548. class DoubleStorage(_LegacyStorage):
  549. @classproperty
  550. def dtype(self):
  551. return torch.double
  552. class FloatStorage(_LegacyStorage):
  553. @classproperty
  554. def dtype(self):
  555. return torch.float
  556. class HalfStorage(_LegacyStorage):
  557. @classproperty
  558. def dtype(self):
  559. return torch.half
  560. class LongStorage(_LegacyStorage):
  561. @classproperty
  562. def dtype(self):
  563. return torch.long
  564. class IntStorage(_LegacyStorage):
  565. @classproperty
  566. def dtype(self):
  567. return torch.int
  568. class ShortStorage(_LegacyStorage):
  569. @classproperty
  570. def dtype(self):
  571. return torch.short
  572. class CharStorage(_LegacyStorage):
  573. @classproperty
  574. def dtype(self):
  575. return torch.int8
  576. class BoolStorage(_LegacyStorage):
  577. @classproperty
  578. def dtype(self):
  579. return torch.bool
  580. class BFloat16Storage(_LegacyStorage):
  581. @classproperty
  582. def dtype(self):
  583. return torch.bfloat16
  584. class ComplexDoubleStorage(_LegacyStorage):
  585. @classproperty
  586. def dtype(self):
  587. return torch.cdouble
  588. class ComplexFloatStorage(_LegacyStorage):
  589. @classproperty
  590. def dtype(self):
  591. return torch.cfloat
  592. class QUInt8Storage(_LegacyStorage):
  593. @classproperty
  594. def dtype(self):
  595. return torch.quint8
  596. class QInt8Storage(_LegacyStorage):
  597. @classproperty
  598. def dtype(self):
  599. return torch.qint8
  600. class QInt32Storage(_LegacyStorage):
  601. @classproperty
  602. def dtype(self):
  603. return torch.qint32
  604. class QUInt4x2Storage(_LegacyStorage):
  605. @classproperty
  606. def dtype(self):
  607. return torch.quint4x2
  608. class QUInt2x4Storage(_LegacyStorage):
  609. @classproperty
  610. def dtype(self):
  611. return torch.quint2x4
  612. _storage_classes = {
  613. _UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
  614. ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
  615. QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
  616. ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
  617. _TypedStorage
  618. }
  619. # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
  620. _tensor_classes: Set[Type] = set()
  621. # If you edit these imports, please update torch/__init__.py.in as well
  622. from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed
  623. from .serialization import save, load
  624. from ._tensor_str import set_printoptions
  625. ################################################################################
  626. # Initialize extension
  627. ################################################################################
  628. def manager_path():
  629. if platform.system() == 'Windows' or sys.executable == 'torch_deploy':
  630. return b""
  631. path = get_file_path('torch', 'bin', 'torch_shm_manager')
  632. prepare_multiprocessing_environment(get_file_path('torch'))
  633. if not os.path.exists(path):
  634. raise RuntimeError("Unable to find torch_shm_manager at " + path)
  635. return path.encode('utf-8')
  636. from torch.amp import autocast
# Shared memory manager needs to know the exact location of manager executable.
# manager_path() yields b"" on Windows / under torch_deploy, otherwise the
# UTF-8 encoded path to torch_shm_manager.
_C._initExtension(manager_path())
# manager_path is only used for this single call; remove it so it does not
# linger as an attribute of the torch module.
del manager_path
  640. # Appease the type checker: it can't deal with direct setting of globals().
  641. # Note that we will see "too many" functions when reexporting this way; there
  642. # is not a good way to fix this problem. Perhaps, try to redesign VariableFunctions
  643. # so that this import is good enough
  644. if TYPE_CHECKING:
  645. # Some type signatures pulled in from _VariableFunctions here clash with
  646. # signatures already imported. For now these clashes are ignored; see
  647. # PR #43339 for details.
  648. from torch._C._VariableFunctions import * # type: ignore[misc] # noqa: F403
  649. # Ops not to be exposed in `torch` namespace,
  650. # mostly helper ops.
  651. PRIVATE_OPS = (
  652. 'unique_dim',
  653. )
  654. for name in dir(_C._VariableFunctions):
  655. if name.startswith('__') or name in PRIVATE_OPS:
  656. continue
  657. obj = getattr(_C._VariableFunctions, name)
  658. obj.__module__ = 'torch'
  659. globals()[name] = obj
  660. if not name.startswith("_"):
  661. __all__.append(name)
  662. ################################################################################
  663. # Import interface functions defined in Python
  664. ################################################################################
  665. # needs to be after the above ATen bindings so we can overwrite from Python side
  666. from .functional import * # noqa: F403
  667. ################################################################################
  668. # Remove unnecessary members
  669. ################################################################################
  670. del _StorageBase
  671. del _LegacyStorage
  672. ################################################################################
  673. # Define _assert
  674. ################################################################################
  675. # needs to be before the submodule imports to avoid circular dependencies
  676. def _assert(condition, message):
  677. r"""A wrapper around Python's assert which is symbolically traceable.
  678. """
  679. from .overrides import has_torch_function, handle_torch_function
  680. if type(condition) is not torch.Tensor and has_torch_function((condition,)):
  681. return handle_torch_function(_assert, (condition,), condition, message)
  682. assert condition, message
  683. ################################################################################
  684. # Import most common subpackages
  685. ################################################################################
  686. # Use the redundant form so that type checkers know that these are a part of
  687. # the public API. The "regular" import lines are there solely for the runtime
  688. # side effect of adding to the imported module's members for other users.
  689. from torch import cuda as cuda
  690. from torch import cpu as cpu
  691. from torch import autograd as autograd
  692. from torch.autograd import (
  693. no_grad as no_grad,
  694. enable_grad as enable_grad,
  695. set_grad_enabled as set_grad_enabled,
  696. inference_mode as inference_mode,
  697. )
  698. from torch import fft as fft
  699. from torch import futures as futures
  700. from torch import nn as nn
  701. from torch import optim as optim
  702. import torch.optim._multi_tensor
  703. from torch import multiprocessing as multiprocessing
  704. from torch import sparse as sparse
  705. from torch import special as special
  706. import torch.utils.backcompat
  707. from torch import onnx as onnx
  708. from torch import jit as jit
  709. from torch import linalg as linalg
  710. from torch import hub as hub
  711. from torch import random as random
  712. from torch import distributions as distributions
  713. from torch import testing as testing
  714. import torch.backends.cuda
  715. import torch.backends.mps
  716. import torch.backends.cudnn
  717. import torch.backends.mkl
  718. import torch.backends.mkldnn
  719. import torch.backends.openmp
  720. import torch.backends.quantized
  721. import torch.utils.data
  722. from torch import __config__ as __config__
  723. from torch import __future__ as __future__
  724. from torch import profiler as profiler
  725. # Quantized, sparse, AO, etc. should be last to get imported, as nothing
  726. # is expected to depend on them.
  727. import torch.nn.intrinsic
  728. import torch.nn.quantizable
  729. import torch.nn.quantized
  730. # AO depends on nn, as well as quantized stuff -- so should be after those.
  731. from torch import ao as ao
# Hand every storage class collected in ``_storage_classes`` to the C extension.
# NOTE(review): presumably this registers/names them on the C++ side — the
# exact effect lives in _C._init_names; confirm there.
_C._init_names(list(torch._storage_classes))
# attach docstrings to torch and tensor functions (and storages); the modules
# are imported only for that side effect and then removed from the namespace.
from . import _torch_docs, _tensor_docs, _storage_docs
del _torch_docs, _tensor_docs, _storage_docs
  736. def compiled_with_cxx11_abi():
  737. r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
  738. return _C._GLIBCXX_USE_CXX11_ABI
  739. # Import the ops "namespace"
  740. from torch._ops import ops
  741. from torch._classes import classes
  742. # quantization depends on torch.fx
  743. # Import quantization
  744. from torch import quantization as quantization
  745. # Import the quasi random sampler
  746. from torch import quasirandom as quasirandom
# If you are seeing this, it means that this call site was not checked if
# the memory format could be preserved, and it was switched to old default
# behaviour of contiguous
# (Pure alias: ``legacy_contiguous_format`` is the very same object as
# ``contiguous_format``; the distinct name only marks such call sites.)
legacy_contiguous_format = contiguous_format
# Register fork handler to initialize OpenMP in child processes (see gh-28389)
# torch.get_num_threads is the registered callback; presumably invoking it in
# the child is what (re)initializes the OpenMP thread pool — see gh-28389.
from torch.multiprocessing._atfork import register_after_fork
register_after_fork(torch.get_num_threads)
# Only needed for this one registration; keep it out of the torch namespace.
del register_after_fork
  755. # Import tools that require fully imported torch (for applying
  756. # torch.jit.script as a decorator, for instance):
  757. from ._lobpcg import lobpcg as lobpcg
  758. # These were previously defined in native_functions.yaml and appeared on the
  759. # `torch` namespace, but we moved them to c10 dispatch to facilitate custom
  760. # class usage. We add these lines here to preserve backward compatibility.
  761. quantized_lstm = torch.ops.aten.quantized_lstm
  762. quantized_gru = torch.ops.aten.quantized_gru
  763. from torch.utils.dlpack import from_dlpack, to_dlpack
  764. # Import experimental masked operations support. See
  765. # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more
  766. # information.
  767. from . import _masked
  768. # Import removed ops with error message about removal
  769. from ._linalg_utils import solve
  770. def _register_device_module(device_type, module):
  771. r"""Register an external runtime module of the specific :attr:`device_type`
  772. supported by torch.
  773. After the :attr:`module` is registered correctly, the user can refer
  774. the external runtime module as part of torch with attribute torch.xxx.
  775. """
  776. # Make sure the device_type represent a supported device type for torch.
  777. device_type = torch.device(device_type).type
  778. m = sys.modules[__name__]
  779. if hasattr(m, device_type):
  780. raise RuntimeError("The runtime module of '{}' has already "
  781. "been registered with '{}'".format(device_type, getattr(m, device_type)))
  782. setattr(m, device_type, module)
  783. # expose return_types
  784. from . import return_types
  785. if sys.executable != 'torch_deploy':
  786. from . import library
  787. from . import _meta_registrations