Context.h 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. #pragma once
  2. #include <ATen/core/ATenGeneral.h>
  3. #include <ATen/core/Generator.h>
  4. #include <ATen/CPUGeneratorImpl.h>
  5. #include <ATen/LinalgBackend.h>
  6. #include <ATen/core/LegacyTypeDispatch.h>
  7. #include <ATen/core/DeprecatedTypeProperties.h>
  8. #include <ATen/detail/CUDAHooksInterface.h>
  9. #include <ATen/detail/HIPHooksInterface.h>
  10. #include <ATen/detail/ORTHooksInterface.h>
  11. #include <c10/util/Exception.h>
  12. #include <c10/core/impl/DeviceGuardImplInterface.h>
  13. #include <c10/core/QEngine.h>
  14. #include <c10/util/irange.h>
  15. #include <memory>
  16. #include <mutex>
  17. #include <cstdint>
  18. namespace at {
  19. class Tensor;
  20. enum class TORCH_API Float32MatmulPrecision {HIGHEST, HIGH, MEDIUM};
  21. class TORCH_API Context {
  22. public:
  23. Context();
  24. const Generator& defaultGenerator(Device device) {
  25. DeviceType device_type = device.type();
  26. initCUDAIfNeeded(device_type);
  27. initHIPIfNeeded(device_type);
  28. if (device_type == at::kCPU) {
  29. return at::detail::getDefaultCPUGenerator();
  30. } else if (device_type == at::kCUDA) {
  31. return at::detail::getCUDAHooks().getDefaultCUDAGenerator(device.index());
  32. } else {
  33. AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
  34. }
  35. }
  36. Device getDeviceFromPtr(void* data, DeviceType device_type) {
  37. initCUDAIfNeeded(device_type);
  38. initHIPIfNeeded(device_type);
  39. if (device_type == at::kCPU) {
  40. return DeviceType::CPU;
  41. } else if (device_type == at::kCUDA) {
  42. return at::detail::getCUDAHooks().getDeviceFromPtr(data);
  43. } else {
  44. AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
  45. }
  46. }
  47. static bool isPinnedPtr(void* data) {
  48. return detail::getCUDAHooks().isPinnedPtr(data);
  49. }
  50. static bool hasOpenMP() ;
  51. static bool hasMKL() ;
  52. static bool hasLAPACK() ;
  53. static bool hasMKLDNN() ;
  54. static bool hasMAGMA() {
  55. return detail::getCUDAHooks().hasMAGMA();
  56. }
  57. static bool hasCUDA() {
  58. return detail::getCUDAHooks().hasCUDA();
  59. }
  60. static bool hasCUDART() {
  61. return detail::getCUDAHooks().hasCUDART();
  62. }
  63. static long versionCUDART() {
  64. return detail::getCUDAHooks().versionCUDART();
  65. }
  66. static bool hasCuDNN() {
  67. return detail::getCUDAHooks().hasCuDNN();
  68. }
  69. static long versionCuDNN() {
  70. return detail::getCUDAHooks().versionCuDNN();
  71. }
  72. static bool hasCuSOLVER() {
  73. return detail::getCUDAHooks().hasCuSOLVER();
  74. }
  75. static bool hasHIP() {
  76. return detail::getHIPHooks().hasHIP();
  77. }
  78. static bool hasIPU() {
  79. return c10::impl::hasDeviceGuardImpl(at::DeviceType::IPU);
  80. }
  81. static bool hasXLA() {
  82. return c10::impl::hasDeviceGuardImpl(at::DeviceType::XLA);
  83. }
  84. static bool hasLazy() {
  85. return c10::impl::hasDeviceGuardImpl(at::DeviceType::Lazy);
  86. }
  87. static bool hasMPS();
  88. static bool hasORT() {
  89. return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
  90. }
  91. // defined in header so that getNonVariableType has ability to inline
  92. // call_once check. getNonVariableType is called fairly frequently
  93. void lazyInitCUDA() {
  94. std::call_once(thc_init,[&] {
  95. detail::getCUDAHooks().initCUDA();
  96. });
  97. }
  98. void lazyInitHIP() {
  99. std::call_once(thh_init,[&] {
  100. detail::getHIPHooks().initHIP();
  101. });
  102. }
  103. static const at::cuda::NVRTC& getNVRTC() {
  104. return detail::getCUDAHooks().nvrtc();
  105. }
  106. static bool setFlushDenormal(bool on);
  107. // NB: This method is *purely* whether or not a user requested
  108. // that CuDNN was enabled, it doesn't actually say anything about
  109. // whether or not CuDNN is actually usable. Use cudnn_is_acceptable
  110. // to test this instead
  111. bool userEnabledCuDNN() const;
  112. void setUserEnabledCuDNN(bool e);
  113. bool userEnabledMkldnn() const;
  114. void setUserEnabledMkldnn(bool e);
  115. bool benchmarkCuDNN() const;
  116. void setBenchmarkCuDNN(bool);
  117. bool deterministicCuDNN() const;
  118. void setDeterministicCuDNN(bool);
  119. at::LinalgBackend linalgPreferredBackend() const;
  120. void setLinalgPreferredBackend(at::LinalgBackend);
  121. // Note [Enabling Deterministic Operations]
  122. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  123. // Operations in PyTorch that normally act nondeterministically, but have an alternate
  124. // deterministic implementation, should satisfy the following requirements:
  125. //
  126. // * Include this comment: "See Note [Enabling Deterministic Operations]"
  127. //
  128. // * Check the value of `at::globalContext().deterministicAlgorithms()` to toggle
  129. // between nondeterministic and deterministic implementations.
  130. //
  131. // * Have an entry in the list of PyTorch operations that toggle between nondeterministic
  132. // and deterministic implementations, in the docstring of `use_deterministic_algorithms()`
  133. // in torch/__init__.py
  134. //
  135. // `example_func()` below shows an example of toggling between nondeterministic and
  136. // deterministic implementations:
  137. //
  138. // void example_func() {
  139. // // See Note [Enabling Deterministic Operations]
  140. // if (at::globalContext().deterministicAlgorithms()) {
  141. // example_func_deterministic();
  142. // } else {
  143. // example_func_nondeterministic();
  144. // }
  145. // }
  146. bool deterministicAlgorithms() const;
  147. bool deterministicAlgorithmsWarnOnly() const;
  148. void setDeterministicAlgorithms(bool, bool);
  149. // Note [Writing Nondeterministic Operations]
  150. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  151. // Operations in PyTorch that act nondeterministically and do not have an alternate
  152. // deterministic implementation should satisfy the following requirements:
  153. //
  154. // * Include this comment: "See Note [Writing Nondeterministic Operations]"
  155. //
  156. // * Include a comment explaining why the operation is nondeterministic.
  157. //
  158. // * Throw an error when `Context::deterministicAlgorithms()` is true. Most
  159. // of the time, this should be accomplished by calling
  160. // `at::globalContext().alertNotDeterminstic()`. However, if the
  161. // nondeterministic behavior is caused by the CuBLAS workspace
  162. // configuration in CUDA >= 10.2,
  163. // `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
  164. // called instead (in this case, a comment explaining why the operation is
  165. // nondeterministic is not necessary). See below for details on these
  166. // methods.
  167. //
  168. // * Have an entry in the list of nondeterministic PyTorch operations in the
  169. // docstring of `use_deterministic_algorithms()` in torch/__init__.py
  170. //
  171. // * Have a test function in `test/test_torch.py` whose name begins with
  172. // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace
  173. // configuration is the reason for nondeterminism, the operation should be
  174. // included in the `test_cublas_config_nondeterministic_alert` test. Any new
  175. // tests should ideally follow a pattern similar to the existing ones.
  176. //
  177. // `example_func()` below shows an example of the comments and error-throwing code
  178. // for a nondeterministic operation:
  179. //
  180. // void example_func() {
  181. // // See Note [Writing Nondeterministic Operations]
  182. // // Nondeterministic because <reason>
  183. // at::globalContext().alertNondeterministic("example_func");
  184. // ...
  185. // }
  186. // Throws an error if `Context::deterministicAlgorithms()` is true
  187. static void alertNotDeterministic(c10::string_view const& caller);
  188. // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA >= 10.2, and
  189. // CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or ":4096:8". For more details:
  190. // https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
  191. void alertCuBLASConfigNotDeterministic() const;
  192. void setFloat32MatmulPrecision(const std::string & s);
  193. bool allowTF32CuDNN() const;
  194. void setAllowTF32CuDNN(bool);
  195. bool allowTF32CuBLAS() const;
  196. void setAllowTF32CuBLAS(bool);
  197. Float32MatmulPrecision float32MatmulPrecision() const;
  198. void setFloat32MatmulPrecision(Float32MatmulPrecision p);
  199. bool allowFP16ReductionCuBLAS() const;
  200. void setAllowFP16ReductionCuBLAS(bool);
  201. at::QEngine qEngine() const;
  202. void setQEngine(at::QEngine e);
  203. static const std::vector<at::QEngine>& supportedQEngines() ;
  204. static bool isXNNPACKAvailable() ;
  205. // This method is used to release the original weight after pre-packing.
  206. // It should be called once before loading/running the model.
  207. // NB: By default it is set to true for mobile builds.
  208. void setReleaseWeightsWhenPrepacking(bool e);
  209. bool releaseWeightsWhenPrepacking() const;
  210. void setDisplayVmapFallbackWarnings(bool enabled);
  211. bool areVmapFallbackWarningsEnabled() const;
  212. void setDefaultMobileCPUAllocator();
  213. void unsetDefaultMobileCPUAllocator();
  214. private:
  215. void initCUDAIfNeeded(DeviceType p) {
  216. if (p == DeviceType::CUDA) {
  217. lazyInitCUDA();
  218. }
  219. }
  220. void initHIPIfNeeded(DeviceType p) {
  221. if (p == DeviceType::HIP) {
  222. lazyInitHIP();
  223. }
  224. }
  225. static bool checkCuBLASConfigDeterministic();
  226. std::once_flag thc_init;
  227. std::once_flag thh_init;
  228. bool enabled_cudnn = true;
  229. bool deterministic_cudnn = false;
  230. bool _deterministic_algorithms = false;
  231. bool _deterministic_algorithms_warn_only = false;
  232. bool benchmark_cudnn = false;
  233. Float32MatmulPrecision float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
  234. bool allow_tf32_cudnn = true;
  235. bool allow_fp16_reduction_cublas = true;
  236. bool enabled_mkldnn = true;
  237. at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
  238. #ifdef C10_MOBILE
  239. bool release_original_weights = true;
  240. #else
  241. bool release_original_weights = false;
  242. #endif
  243. bool display_vmap_fallback_warnings_ = false;
  244. c10::optional<at::QEngine> quantized_engine = c10::nullopt;
  245. Allocator* prev_allocator_ptr_{nullptr};
  246. };
  247. TORCH_API Context& globalContext();
  248. static inline void init() {
  249. globalContext();
  250. }
  251. TORCH_API Allocator* getCPUAllocator();
  252. static inline DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) {
  253. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  254. p, s);
  255. }
  256. static inline DeprecatedTypeProperties& CPU(ScalarType s) {
  257. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  258. Backend::CPU, s);
  259. }
  260. static inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  261. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  262. Backend::CUDA, s);
  263. }
  264. static inline DeprecatedTypeProperties& HIP(ScalarType s) {
  265. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  266. Backend::HIP, s);
  267. }
  268. static inline DeprecatedTypeProperties& MPS(ScalarType s) {
  269. return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
  270. Backend::MPS, s);
  271. }
  272. static inline bool hasCUDA() {
  273. return globalContext().hasCUDA();
  274. }
  275. static inline bool hasHIP() {
  276. return globalContext().hasHIP();
  277. }
  278. static inline bool hasIPU() {
  279. return globalContext().hasIPU();
  280. }
  281. static inline bool hasXLA() {
  282. return globalContext().hasXLA();
  283. }
  284. static inline bool hasMPS() {
  285. return globalContext().hasMPS();
  286. }
  287. static inline bool hasORT() {
  288. return globalContext().hasORT();
  289. }
  290. // Despite its name, this function returns the number of *CUDA* GPUs.
  291. static inline size_t getNumGPUs() {
  292. // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
  293. // FUNCTION. If you are interested in interrogating the number of
  294. // devices for a specific device type, add that function to the
  295. // relevant library (e.g., similar to at::cuda::device_count())
  296. if (hasCUDA() && hasHIP()) {
  297. throw std::runtime_error(
  298. "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
  299. "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
  300. "means HIP. Rebuild PyTorch with one or the other disabled.");
  301. } else if (hasCUDA()) {
  302. return detail::getCUDAHooks().getNumGPUs();
  303. } else if (hasHIP()) {
  304. return detail::getHIPHooks().getNumGPUs();
  305. } else {
  306. return 0;
  307. }
  308. }
  309. static inline bool hasOpenMP() {
  310. return globalContext().hasOpenMP();
  311. }
  312. static inline bool hasMKL() {
  313. return globalContext().hasMKL();
  314. }
  315. static inline bool hasLAPACK() {
  316. return globalContext().hasLAPACK();
  317. }
  318. static inline bool hasMAGMA() {
  319. return globalContext().hasMAGMA();
  320. }
  321. static inline bool hasMKLDNN() {
  322. return globalContext().hasMKLDNN();
  323. }
  324. static inline void manual_seed(uint64_t seed) {
  325. auto gen = globalContext().defaultGenerator(DeviceType::CPU);
  326. {
  327. // See Note [Acquire lock when using random generators]
  328. std::lock_guard<std::mutex> lock(gen.mutex());
  329. gen.set_current_seed(seed);
  330. }
  331. // NB: Sometimes we build with CUDA, but we don't have any GPUs
  332. // available. In that case, we must not seed CUDA; it will fail!
  333. const auto num_gpus = detail::getCUDAHooks().getNumGPUs();
  334. if (hasCUDA() && num_gpus > 0) {
  335. for (const auto i : c10::irange(num_gpus)) {
  336. auto cuda_gen = globalContext().defaultGenerator(
  337. Device(at::kCUDA, static_cast<c10::DeviceIndex>(i))
  338. );
  339. {
  340. // See Note [Acquire lock when using random generators]
  341. std::lock_guard<std::mutex> lock(cuda_gen.mutex());
  342. cuda_gen.set_current_seed(seed);
  343. }
  344. }
  345. }
  346. }
  347. // When the global flag `allow_tf32` is set to true, cuBLAS handles are
  348. // automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
  349. // For some operators, such as addmv, TF32 offers no performance improvement
  350. // but causes precision loss. To help this case, this class implements
  351. // a RAII guard that can be used to quickly disable TF32 within its scope.
  352. //
  353. // Usage:
  354. // NoTF32Guard disable_tf32;
  355. struct TORCH_API NoTF32Guard {
  356. NoTF32Guard();
  357. ~NoTF32Guard();
  358. static bool should_disable_tf32();
  359. private:
  360. bool changed = false;
  361. };
  362. #ifdef USE_ROCM
  363. struct TORCH_API ROCmBackwardPassGuard {
  364. ROCmBackwardPassGuard();
  365. ~ROCmBackwardPassGuard();
  366. static bool is_backward_pass();
  367. private:
  368. static thread_local bool is_backward_pass_;
  369. };
  370. #endif
  371. } // namespace at