template <typename T0, typename F>
void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
- T0 d0{0}, end{0};
- balance211(D0, nthr, ithr, d0, end);
- for (; d0 < end; ++d0) f(d0);
+ T0 start{0}, end{0};
+ balance211(D0, nthr, ithr, start, end);
+ for (T0 d0 = start; d0 < end; ++d0) f(d0);
}
template <typename T0, typename T1, typename F>
}
}
+// Skip a lambda function in the parameter pack.
+template <typename T>
+constexpr size_t get_work_amount(const T &v) { return 1; }
+template <typename T, typename ...Args>
+constexpr size_t get_work_amount(const T &v, Args &&...args)
+{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
+
/* parallel_nd and parallel_nd_in_omp section */
#if MKLDNN_THR != MKLDNN_THR_TBB
#if MKLDNN_THR == MKLDNN_THR_SEQ
for_nd(0, 1, utils::forward<Args>(args)...);
#elif MKLDNN_THR == MKLDNN_THR_OMP
-# pragma omp parallel
- for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
- utils::forward<Args>(args)...);
+ const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
+# pragma omp parallel if (do_parallel)
+ {
+ const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
+ const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
+ for_nd(ithr, nthr, utils::forward<Args>(args)...);
+ }
#endif
}
#else // MKLDNN_THR != MKLDNN_THR_TBB