Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / mkldnn_thread_parallel_nd.hpp
index 77bf53b..4a1f487 100644 (file)
@@ -56,9 +56,9 @@ void parallel(int nthr, F f) {
 
 template <typename T0, typename F>
 void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
-    T0 d0{0}, end{0};
-    balance211(D0, nthr, ithr, d0, end);
-    for (; d0 < end; ++d0) f(d0);
+    T0 start{0}, end{0};
+    balance211(D0, nthr, ithr, start, end);
+    for (T0 d0 = start; d0 < end; ++d0) f(d0);
 }
 
 template <typename T0, typename T1, typename F>
@@ -143,6 +143,13 @@ void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
     }
 }
 
+// Skip a lambda function in the parameter pack.
+template <typename T>
+constexpr size_t get_work_amount(const T &v) { return 1; }
+template <typename T, typename ...Args>
+constexpr size_t get_work_amount(const T &v, Args &&...args)
+{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
+
 /* parallel_nd and parallel_nd_in_omp section */
 
 #if MKLDNN_THR != MKLDNN_THR_TBB
@@ -151,9 +158,13 @@ void parallel_nd(Args &&...args) {
 #if MKLDNN_THR == MKLDNN_THR_SEQ
     for_nd(0, 1, utils::forward<Args>(args)...);
 #elif MKLDNN_THR == MKLDNN_THR_OMP
-#   pragma omp parallel
-    for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
-            utils::forward<Args>(args)...);
+    const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
+#   pragma omp parallel if (do_parallel)
+    {
+        const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
+        const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
+        for_nd(ithr, nthr, utils::forward<Args>(args)...);
+    }
 #endif
 }
 #else // MKLDNN_THR != MKLDNN_THR_TBB