#pragma once
#include <cstddef>
+#include <type_traits>
#define IE_THREAD_TBB 0
#define IE_THREAD_OMP 1
template <typename T0, typename R, typename F>
R parallel_sum(const T0& D0, const R& input, const F& func) {
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
- return tbb::parallel_reduce(
+ return tbb::parallel_deterministic_reduce(
tbb::blocked_range<T0>(0, D0), input,
[&](const tbb::blocked_range<T0>& r, R init) -> R {
R sum = init;
template <typename T0, typename T1, typename R, typename F>
R parallel_sum2d(const T0& D0, const T1& D1, const R& input, const F& func) {
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
- return tbb::parallel_reduce(
+ return tbb::parallel_deterministic_reduce(
tbb::blocked_range2d<T0, T1>(0, D0, 0, D1), input,
[&](const tbb::blocked_range2d<T0, T1>& r, R init) -> R {
R sum = init;
template <typename T0, typename T1, typename T2, typename R, typename F>
R parallel_sum3d(const T0& D0, const T1& D1, const T2& D2, const R& input, const F& func) {
#if (IE_THREAD == IE_THREAD_TBB || IE_THREAD == IE_THREAD_TBB_AUTO)
- return tbb::parallel_reduce(
+ return tbb::parallel_deterministic_reduce(
tbb::blocked_range3d<T0, T1, T2>(0, D0, 0, D1, 0, D2), input,
[&](const tbb::blocked_range3d<T0, T1, T2>& r, R init) -> R {
R sum = init;
n_end += n_start;
}
+namespace details {
+ template<typename T>
+ struct num_of_lambda_args : public num_of_lambda_args<decltype(&T::operator())> {
+ };
+
+ template<typename C, typename R, typename... Args>
+ struct num_of_lambda_args<R(C::*)(Args...) const> {
+ constexpr static int value = sizeof...(Args);
+ };
+
+ template<typename ACT, typename ...T, size_t N_ARGS = num_of_lambda_args<ACT>::value>
+ typename std::enable_if<N_ARGS == sizeof...(T) + 1, void>::type
+ call_with_args(ACT body, size_t g_id, T ...arg) {
+ body(g_id, arg...);
+ }
+
+ template<typename ACT, typename ...T, size_t N_ARGS = num_of_lambda_args<ACT>::value>
+ typename std::enable_if<N_ARGS == sizeof...(T), void>::type
+ call_with_args(ACT body, size_t g_id, T ...arg) {
+ body(arg...);
+ }
+} // namespace details
+
template <typename T0, typename F>
void for_1d(const int& ithr, const int& nthr, const T0& D0, const F& func) {
T0 d0 {0}, end {0};
splitter(D0, nthr, ithr, d0, end);
- for (; d0 < end; ++d0) func(d0);
+ for (; d0 < end; ++d0)
+ details::call_with_args(func, ithr, d0);
}
template <typename T0, typename F>
T1 d1 {0};
parallel_it_init(start, d0, D0, d1, D1);
for (size_t iwork = start; iwork < end; ++iwork) {
- func(d0, d1);
+ details::call_with_args(func, ithr, d0, d1);
parallel_it_step(d0, D0, d1, D1);
}
}
T2 d2 {0};
parallel_it_init(start, d0, D0, d1, D1, d2, D2);
for (size_t iwork = start; iwork < end; ++iwork) {
- func(d0, d1, d2);
+ details::call_with_args(func, ithr, d0, d1, d2);
parallel_it_step(d0, D0, d1, D1, d2, D2);
}
}
T3 d3 {0};
parallel_it_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
for (size_t iwork = start; iwork < end; ++iwork) {
- func(d0, d1, d2, d3);
+ details::call_with_args(func, ithr, d0, d1, d2, d3);
parallel_it_step(d0, D0, d1, D1, d2, D2, d3, D3);
}
}
T4 d4 {0};
parallel_it_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
for (size_t iwork = start; iwork < end; ++iwork) {
- func(d0, d1, d2, d3, d4);
+ details::call_with_args(func, ithr, d0, d1, d2, d3, d4);
parallel_it_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
}
}
for (int i = 0; i < mean_buffer.size(); i++)
mean_buffer[i] = 0.f;
- parallel_for2d(D, H, [&](size_t d, size_t h) {
+ parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_off = is_nhwc ? ccb + d * H * W * C + h * W * C + cb * blk_size
: ccb + d * H * W * blk_size + h * W * blk_size + cb * D * H * W * blk_size;
- auto thr_idx = mkldnn_get_thread_num();
auto mean_buffer_ptr = &mean_buffer[blk_size * cb + aux_buffer_size * thr_idx];
auto arg = jit_mvn_call_args();
for (int i = 0; i < variance_buffer.size(); i++)
variance_buffer[i] = 0.f;
- parallel_for2d(D, H, [&](size_t d, size_t h) {
+ parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) {
for (size_t cb = 0; cb < CB; cb++) {
size_t src_off = is_nhwc ? ccb + d * H * W * C + h * W * C + cb * blk_size
: ccb + d * H * W * blk_size + h * W * blk_size + cb * D * H * W * blk_size;
- auto thr_idx = mkldnn_get_thread_num();
auto mean_buffer_ptr = &mean_buffer[blk_size * cb];
auto variance_buffer_ptr = &variance_buffer[blk_size * cb + aux_buffer_size * thr_idx];