* limitations under the License.
*******************************************************************************/
-#include "mkldnn_types.h"
-
#include "c_types_map.hpp"
-#include "jit_uni_dw_convolution.hpp"
+#include "memory_tracking.hpp"
#include "mkldnn_thread.hpp"
+#include "jit_uni_dw_convolution.hpp"
+
namespace mkldnn {
namespace impl {
namespace cpu {
using namespace mkldnn::impl::status;
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
-template <cpu_isa_t isa, bool with_relu>
-void _jit_uni_dw_convolution_fwd_t<isa, with_relu>::execute_forward() {
+template <cpu_isa_t isa>
+void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward() const {
auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
auto dst = reinterpret_cast<data_t *>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
const auto &jcp = kernel_->jcp;
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bias[oc];
- bias = padded_bias_;
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = this->scratchpad().template get<data_t>(
+ key_conv_padded_bias);
+ utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bias = padded_bias;
}
int dil_h = jcp.dilate_h + 1;
return par_conv;
};
- int MB = conf_.MB();
+ int MB = pd()->MB();
const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
parallel_nd(MB, chb_work, jcp.oh,
[&](int n, int chb, int oh) {
kernel_->jit_ker(&par_conv);
}
});
-}
-template void _jit_uni_dw_convolution_fwd_t<avx512_common, false>
- ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<avx2, false>
- ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<sse42, false>
- ::execute_forward();
+ if (pd()->wants_zero_pad_dst())
+ output_memory_primitive(0)->zero_pad();
+}
-template void _jit_uni_dw_convolution_fwd_t<avx512_common, true>
- ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<avx2, true>
- ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<sse42, true>
- ::execute_forward();
+template struct _jit_uni_dw_convolution_fwd_t<avx512_common>;
+template struct _jit_uni_dw_convolution_fwd_t<avx2>;
+template struct _jit_uni_dw_convolution_fwd_t<sse42>;
template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() {
+void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() const {
auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
auto diff_src = reinterpret_cast<data_t *>(this->memory());
- const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
- const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+ const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+ const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
const auto &jcp = kernel_->jcp;
return par_conv;
};
- int MB = conf_.MB();
+ int MB = pd()->MB();
const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
parallel_nd(MB, chb_work, jcp.ih,
[&](int n, int chb, int ih) {
});
}
-template void _jit_uni_dw_convolution_bwd_data_t<avx512_common>
- ::execute_backward_data();
-template void _jit_uni_dw_convolution_bwd_data_t<avx2>
- ::execute_backward_data();
-template void _jit_uni_dw_convolution_bwd_data_t<sse42>
- ::execute_backward_data();
+template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
+template struct _jit_uni_dw_convolution_bwd_data_t<avx2>;
+template struct _jit_uni_dw_convolution_bwd_data_t<sse42>;
template <cpu_isa_t isa>
_jit_uni_dw_convolution_bwd_weights_t<isa>::
- _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
- const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
-
- const auto &jcp = conf_.jcp_;
-
- kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(jcp);
-
- const int max_threads
- = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
- nthr_ = max_threads;
-
- nthr_g_ = nthr_mb_ = 1;
-
- /* Basic-Heuristics for parallel strategy:
- * 1) Tries to parallel on the number of Groups (g) where tasks are
- * independent. Otherwise,
- * 2) Tries to split the work across g and MiniBatch (mb).
- * Parallelizing on mb requires computing a reduction for weights.
- *
- * NOTE: because of 'task partitioning' scheme, there will be unbalanced
- * per-thread load when the number of threads is high (e.g. > 16).
- */
- nthr_g_ = nstl::min(jcp.nb_ch, nthr_);
- nthr_mb_ = nstl::min(nstl::max(1, nthr_ / nthr_g_), jcp.mb);
-
- nthr_ = nthr_g_ * nthr_mb_;
-
- /* Notes: if splitting thread work on 'mb', then a reduction has to take
- * place. Hence, allocate a per-thread, local weights-buffer for the
- * reduction */
- if (nthr_mb_ > 1) {
- const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
- ws_reduction_ = (data_t *)malloc(
- (nthr_mb_ - 1) * wei_size * sizeof(data_t), 64);
-
- if (jcp.with_bias) {
- const size_t bias_size = jcp.ngroups;
- bias_reduction_ = (data_t *)malloc(
- (nthr_mb_ - 1) * bias_size * sizeof(data_t), 64);
- }
-
- /* Used when executing a parallel reduction */
- if(do_parallel_reduction()){
- acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
- simple_barrier::ctx_init(&reduction_bctx_);
- }
- }
+_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd,
+ const input_vector &inputs, const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs)
+ , kernel_(nullptr), acc_ker_(nullptr)
+{
+ kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_);
+ if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction())
+ acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
}
+
template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights() {
+void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights() const {
+ auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
+ auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
+ auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
+ auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
+
+ auto diff_wei_reduction_buf =
+ scratchpad().template get<data_t>(key_conv_wei_reduction);
+ auto diff_bia_reduction_buf =
+ scratchpad().template get<data_t>(key_conv_bia_reduction);
- auto src
- = (data_t *)reinterpret_cast<const data_t *>(this->input_memory(0));
- auto diff_dst
- = (data_t *)reinterpret_cast<const data_t *>(this->input_memory(1));
const auto &jcp = kernel_->jcp;
- /* JIT-code skips the unnecessary computations within the padded region. */
- const int SKIP_TOP_PADDING = 0;
+ /* Used when executing a parallel reduction */
+ simple_barrier::ctx_t reduction_bctx;
+ simple_barrier::ctx_init(&reduction_bctx);
const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;
- const int oh_blk_size = jcp.oh_blk_size;
-
- //const int simd_w = jcp.ch_block;
const int ch_block = jcp.ch_block;
auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params,
- const int batch, const int group, const int oh_block,
- const unsigned char table_idx, const int negative_padding_offset,
- const unsigned char exec_flag) {
+ const int batch, const int group, const int oh_start,
+ const int work_size, const unsigned char exec_flag,
+ const size_t kh_padding, const size_t filter_off) {
+ const int tpad_underflow_off = jcp.t_pad - filter_off;
+
+ conv_params->exec_flags = exec_flag;
+ conv_params->kh_count = jcp.kh - kh_padding;
- const int ih_block = oh_block * jcp.stride_h;
+ const int oh_s = oh_start;
+ const int oh_e = oh_start + work_size;
+ const int ih_s = oh_s * jcp.stride_h;
- conv_params->table_idx = table_idx;
- conv_params->exec_flag = exec_flag;
+ conv_params->filter_pad_off
+ = filter_off * jcp.kw * ch_block * sizeof(float);
+ conv_params->oh_index = oh_s;
+ conv_params->oh_count = oh_e;
size_t diff_dst_off
- = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh + oh_block)
+ = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
+ + oh_start)
* jcp.ow;
size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
- + ih_block - negative_padding_offset)
- * jcp.iw;
+ + ih_s - tpad_underflow_off) * jcp.iw;
conv_params->output = &diff_dst[diff_dst_off * ch_block];
conv_params->input = &src[src_off * ch_block];
};
- parallel(nthr_, [&](const int ithr, const int nthr_) {
+ parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+ assert(nthr == jcp.nthr);
+
auto conv_params = jit_dw_conv_call_s();
+ const int h_block_size = 15;
/* assign iteration space to thread */
- const int ithr_g = ithr % nthr_g_;
- const int ithr_mb = (ithr / nthr_g_) % nthr_mb_;
+ const int ithr_g = ithr % jcp.nthr_g;
+ const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;
/* split dimensions */
int g_start{ 0 }, g_end{ 0 };
- balance211(jcp.nb_ch, nthr_g_, ithr_g, g_start, g_end);
+ balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);
int mb_start{ 0 }, mb_end{ 0 };
- balance211(jcp.mb, nthr_mb_, ithr_mb, mb_start, mb_end);
-
- auto diff_wei = ithr_mb == 0 ?
- (data_t *)reinterpret_cast<data_t *>(this->memory(0)) :
- (data_t *)ws_reduction_ + (ithr_mb - 1) * wei_size;
+ balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);
- auto diff_bias = ithr_mb == 0 ?
- (data_t *)reinterpret_cast<const data_t *>(this->memory(1)) :
- (data_t *)bias_reduction_ + (ithr_mb - 1) * bias_size;
+ auto diff_wei = ithr_mb == 0
+ ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size;
+ auto diff_bia = ithr_mb == 0
+ ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;
for (int g = g_start; g < g_end; ++g) {
-
- /* This flag controls whether the kernel loads weights from memory
- * or initializes the 'weight accummulator' registers to '0'. The
- * latter happens at the beginning of each group/16 computation. */
- unsigned char zero_filter_flag = ~FLAG_ZERO_FILTER;
- unsigned char zero_bias_flag = jcp.with_bias ? ~FLAG_ZERO_BIAS : 0;
+ unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
+ unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;
size_t diff_wei_off = g * jcp.kh * jcp.kw;
conv_params.filter = &diff_wei[diff_wei_off * ch_block];
if (jcp.with_bias)
- conv_params.bias = &diff_bias[g * ch_block];
+ conv_params.bias = &diff_bia[g * ch_block];
for (int mb = mb_start; mb < mb_end; ++mb) {
-
- /* The 'table index' parameter controls the table entry for the
- * inner kernel execution. For more details see
- * jit_uni_dw_conv_kernel_f32. */
- int table_idx = 0;
-
- /* OH_BLOCK is unrolled to separate the computations according
- * to numerous condition-setting 'h' parameter. */
- int oh_blk = 0;
-
- /* Top-padding case - this case always executes. */
- set_kernel_params(&conv_params, mb, g, oh_blk, table_idx,
- SKIP_TOP_PADDING, zero_filter_flag & zero_bias_flag);
- kernel_->jit_ker(&conv_params);
-
- zero_bias_flag |= FLAG_ZERO_BIAS;
- zero_filter_flag |= FLAG_ZERO_FILTER;
- oh_blk += oh_blk_size;
-
- /* Middle OH_BLOCK cases. */
- for (; oh_blk < (jcp.oh - oh_blk_size); oh_blk += oh_blk_size) {
- table_idx = 1;
- set_kernel_params(&conv_params, mb, g, oh_blk, table_idx,
- jcp.t_pad, zero_filter_flag & zero_bias_flag);
+ int oh = 0;
+ while (oh < jcp.oh) {
+ const int h_work = nstl::min(h_block_size, jcp.oh - oh);
+ auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
+ auto kh_b_padding
+ = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ?
+ jcp.b_pad - (h_work - 1) :
+ 0;
+
+ set_kernel_params(&conv_params, mb, g, oh, h_work,
+ zero_filter_flag | zero_bias_flag,
+ kh_t_padding + kh_b_padding, kh_t_padding);
kernel_->jit_ker(&conv_params);
- }
- table_idx++;
- /* Bottom block */
- if (oh_blk < jcp.oh) {
- set_kernel_params(&conv_params, mb, g, oh_blk, table_idx,
- jcp.t_pad, zero_filter_flag & zero_bias_flag);
- kernel_->jit_ker(&conv_params);
+ zero_bias_flag &= ~FLAG_ZERO_BIAS;
+ zero_filter_flag &= ~FLAG_ZERO_FILTER;
+ oh += h_work;
}
}
}
- if (do_parallel_reduction() && nthr_mb_ > 1) {
+ if (do_parallel_reduction() && jcp.nthr_mb > 1) {
size_t reduct_start{ 0 }, reduct_end{ 0 };
- balance211(wei_size, nthr_, ithr, reduct_start, reduct_end);
-
- const size_t reduct_off = reduct_start;
-
- auto *acc_data
- = (data_t *)reinterpret_cast<data_t *>(this->memory(0))
- + reduct_off;
+ balance211(wei_size, nthr, ithr, reduct_start, reduct_end);
const int acc_size = reduct_end - reduct_start;
+ const size_t reduct_off = reduct_start;
+ auto *acc_data = diff_weights + reduct_off;
- simple_barrier::barrier(&reduction_bctx_, nthr_);
-
- for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
+ simple_barrier::barrier(&reduction_bctx, nthr);
- auto *src_data = (data_t *)ws_reduction_
+ for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+ auto *src_data = diff_wei_reduction_buf
+ (thr_mb - 1) * wei_size + reduct_off;
-
acc_ker_->accumulate(acc_data, src_data, acc_size);
}
}
});
- /* Apply single-threaded 'mb' reduction */
- if (nthr_mb_ > 1) {
-
- auto diff_weights
- = (data_t *)reinterpret_cast<data_t *>(this->memory(0));
- auto diff_bias
- = (data_t *)reinterpret_cast<const data_t *>(this->memory(1));
-
- for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
-
- size_t mb_accum_offset = (thr_mb - 1) * wei_size;
- size_t b_accum_offset = (thr_mb - 1) * bias_size;
+ if (jcp.nthr_mb <= 1) return;
- for (int g = 0; g < jcp.nb_ch; ++g) {
-
- /* Reduction on Bias */
- if (jcp.with_bias) {
- PRAGMA_OMP_SIMD()
- for (int g_block = 0; g_block < ch_block; ++g_block) {
- size_t bias_offset = g * ch_block + g_block;
- diff_bias[bias_offset] += bias_reduction_[b_accum_offset
- + bias_offset];
- }
+ /* Apply single-threaded 'mb' reduction */
+ for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+ size_t mb_accum_offset = (thr_mb - 1) * wei_size;
+ size_t b_accum_offset = (thr_mb - 1) * bias_size;
+
+ for (int g = 0; g < jcp.nb_ch; ++g) {
+ /* Reduction on Bias */
+ if (jcp.with_bias) {
+ PRAGMA_OMP_SIMD()
+ for (int g_block = 0; g_block < ch_block; ++g_block) {
+ size_t bias_offset = g * ch_block + g_block;
+ diff_bias[bias_offset] += diff_bia_reduction_buf[
+ b_accum_offset + bias_offset];
}
- if (!do_parallel_reduction()) {
- for (int kh = 0; kh < jcp.kh; ++kh) {
- for (int kw = 0; kw < jcp.kw; ++kw) {
-
- size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
- PRAGMA_OMP_SIMD()
- for (int g_block = 0; g_block < ch_block; ++g_block) {
- diff_weights[wei_offset * ch_block + g_block]
- += ws_reduction_[mb_accum_offset
- + wei_offset * ch_block
- + g_block];
- }
- }
- }
+ }
+
+ if (do_parallel_reduction()) continue;
+
+ for (int kh = 0; kh < jcp.kh; ++kh)
+ for (int kw = 0; kw < jcp.kw; ++kw)
+ {
+ size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
+ PRAGMA_OMP_SIMD()
+ for (int g_block = 0; g_block < ch_block; ++g_block) {
+ const size_t off = wei_offset * ch_block + g_block;
+ diff_weights[off] +=
+ diff_wei_reduction_buf[mb_accum_offset + off];
}
}
}
}
}
-template _jit_uni_dw_convolution_bwd_weights_t<avx512_common>::
- _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
- const input_vector &inputs, const output_vector &outputs);
-template _jit_uni_dw_convolution_bwd_weights_t<avx2>::
- _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
- const input_vector &inputs, const output_vector &outputs);
-template _jit_uni_dw_convolution_bwd_weights_t<sse42>::
- _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
- const input_vector &inputs, const output_vector &outputs);
-
-template void _jit_uni_dw_convolution_bwd_weights_t<avx512_common>::
- execute_backward_weights();
-template void _jit_uni_dw_convolution_bwd_weights_t<avx2>::
- execute_backward_weights();
-template void _jit_uni_dw_convolution_bwd_weights_t<sse42>::
- execute_backward_weights();
+template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
+template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>;
+template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>;
}
}