PTThreadPool.h 394 B

12345678910111213141516171819
  1. #pragma once
  2. #include <ATen/Parallel.h>
  3. #include <c10/core/thread_pool.h>
  4. namespace at {
  5. class TORCH_API PTThreadPool : public c10::ThreadPool {
  6. public:
  7. explicit PTThreadPool(
  8. int pool_size,
  9. int numa_node_id = -1)
  10. : c10::ThreadPool(pool_size, numa_node_id, [](){
  11. c10::setThreadName("PTThreadPool");
  12. at::init_num_threads();
  13. }) {}
  14. };
  15. } // namespace at