ThreadLocalState.h 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #pragma once
  2. #include <stack>
  3. #include <c10/core/InferenceMode.h>
  4. #include <c10/core/impl/LocalDispatchKeySet.h>
  5. #include <c10/util/Exception.h>
  6. #include <c10/util/ThreadLocalDebugInfo.h>
  7. #include <ATen/record_function.h>
  8. #include <ATen/FuncTorchTLS.h>
  9. #include <ATen/core/TorchDispatchModeTLS.h>
  10. #include <ATen/PythonTorchFunctionTLS.h>
  11. namespace at {
  12. // Thread local state contains values that are preserved across
  13. // thread boundaries (e.g. at::launch/JIT fork, autograd).
  14. // Note at::parallel_for doesn't preserve TLS across thread boundaries.
  15. class TORCH_API ThreadLocalState {
  16. public:
  17. // Saves the thread local variables' values and
  18. // returns them as a ThreadLocalState
  19. ThreadLocalState();
  20. // set_grad_mode - force the value of the grad mode TLS in
  21. // the current state object. This is used for example in the
  22. // autograd engine.
  23. void set_grad_mode(bool enabled);
  24. // Sets thread local variables in the current thread,
  25. // according to the thread boundary specified
  26. static void setThreadLocalState(const ThreadLocalState& state);
  27. private:
  28. c10::impl::LocalDispatchKeySet dispatch_key_;
  29. // ThreadLocalDebugInfo does not change after being created
  30. // with DebugInfoGuard
  31. std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
  32. // RecordFunction TLS
  33. RecordFunctionTLS rf_tls_;
  34. // TLS for out-of-tree functorch
  35. // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
  36. // pointer (spoiler alert: it's due to the indirection)
  37. // This needs to be a shared_ptr instead of a unique_ptr because
  38. // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
  39. // consider adding an explicit copy constructor for ThreadLocalState in the
  40. // future but I didn't want to add one just for this.
  41. std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
  42. // TLS for AutogradModes
  43. AutogradState autograd_tls_;
  44. // TLS for enable_torch_dispatch_mode
  45. std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;
  46. // TLS for __torch_function__ (mode and disable_torch_function)
  47. at::impl::PythonTorchFunctionTLS python_torch_function_state_;
  48. // TLS for saved tensors default hooks
  49. std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;
  50. friend class ThreadLocalStateGuard;
  51. };
  52. // Guard to set and reset the thread local state
  53. class TORCH_API ThreadLocalStateGuard {
  54. public:
  55. explicit ThreadLocalStateGuard(const ThreadLocalState& state)
  56. : prev_state_(ThreadLocalState()) {
  57. // set the given state across the thread boundary
  58. ThreadLocalState::setThreadLocalState(state);
  59. }
  60. ~ThreadLocalStateGuard() {
  61. // restore previously set variables
  62. ThreadLocalState::setThreadLocalState(prev_state_);
  63. }
  64. private:
  65. const ThreadLocalState prev_state_;
  66. };
  67. template <typename T>
  68. auto wrapPropagateTLSState(T callback) {
  69. return [tls_state = ThreadLocalState(),
  70. callback = std::move(callback)](auto&&... args) {
  71. ThreadLocalStateGuard g(tls_state);
  72. // Propagate value returned by callback().
  73. return callback(std::forward<decltype(args)>(args)...);
  74. };
  75. }
  76. } // namespace at