ParallelOpenMP.h 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. #pragma once
  2. #include <atomic>
  3. #include <cstddef>
  4. #include <exception>
  5. #ifdef _OPENMP
  6. #define INTRA_OP_PARALLEL
  7. #include <omp.h>
  8. #endif
  9. namespace at {
  10. #ifdef _OPENMP
  11. namespace internal {
  12. template <typename F>
  13. inline void invoke_parallel(int64_t begin, int64_t end, int64_t grain_size, const F& f) {
  14. std::atomic_flag err_flag = ATOMIC_FLAG_INIT;
  15. std::exception_ptr eptr;
  16. #pragma omp parallel
  17. {
  18. // choose number of tasks based on grain size and number of threads
  19. // can't use num_threads clause due to bugs in GOMP's thread pool (See #32008)
  20. int64_t num_threads = omp_get_num_threads();
  21. if (grain_size > 0) {
  22. num_threads = std::min(num_threads, divup((end - begin), grain_size));
  23. }
  24. int64_t tid = omp_get_thread_num();
  25. int64_t chunk_size = divup((end - begin), num_threads);
  26. int64_t begin_tid = begin + tid * chunk_size;
  27. if (begin_tid < end) {
  28. try {
  29. internal::ThreadIdGuard tid_guard(tid);
  30. f(begin_tid, std::min(end, chunk_size + begin_tid));
  31. } catch (...) {
  32. if (!err_flag.test_and_set()) {
  33. eptr = std::current_exception();
  34. }
  35. }
  36. }
  37. }
  38. if (eptr) {
  39. std::rethrow_exception(eptr);
  40. }
  41. }
  42. } // namespace internal
  43. #endif // _OPENMP
  44. } // namespace at