Publishing 2019 R1 content
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/jit_uni_dw_convolution.cpp
index 48c1961..82a7a9d 100644
 * limitations under the License.
 *******************************************************************************/
 
-#include "mkldnn_types.h"
-
 #include "c_types_map.hpp"
-#include "jit_uni_dw_convolution.hpp"
+#include "memory_tracking.hpp"
 #include "mkldnn_thread.hpp"
 
+#include "jit_uni_dw_convolution.hpp"
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
-template <cpu_isa_t isa, bool with_relu>
-void _jit_uni_dw_convolution_fwd_t<isa, with_relu>::execute_forward() {
+template <cpu_isa_t isa>
+void _jit_uni_dw_convolution_fwd_t<isa>::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
     const auto &jcp = kernel_->jcp;
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = this->scratchpad().template get<data_t>(
+                key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
     }
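The scratchpad path above replaces the old member-owned padded_bias_: it copies the valid channels and, unlike the removed code, also zero-fills the padded tail so the JIT kernel can always read a full channel block. A minimal standalone sketch of that preparation (illustrative names, not the mkl-dnn API):

    #include <cstddef>

    // Copy valid channels, zero the padding tail (cf. utils::array_copy /
    // utils::array_set in the hunk above).
    static void prepare_padded_bias(const float *bias, float *padded_bias,
            size_t oc_without_padding, size_t oc_padded) {
        for (size_t oc = 0; oc < oc_without_padding; ++oc)
            padded_bias[oc] = bias[oc];
        for (size_t oc = oc_without_padding; oc < oc_padded; ++oc)
            padded_bias[oc] = 0.f;
    }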
 
     int dil_h = jcp.dilate_h + 1;
@@ -85,7 +89,7 @@ void _jit_uni_dw_convolution_fwd_t<isa, with_relu>::execute_forward() {
         return par_conv;
     };
 
-    int MB = conf_.MB();
+    int MB = pd()->MB();
     const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
     parallel_nd(MB, chb_work, jcp.oh,
             [&](int n, int chb, int oh) {
@@ -134,31 +138,24 @@ void _jit_uni_dw_convolution_fwd_t<isa, with_relu>::execute_forward() {
             kernel_->jit_ker(&par_conv);
         }
     });
-}
 
-template void _jit_uni_dw_convolution_fwd_t<avx512_common, false>
-    ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<avx2, false>
-    ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<sse42, false>
-    ::execute_forward();
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
+}
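parallel_nd(MB, chb_work, jcp.oh, f) above distributes a three-dimensional iteration space across threads. A rough model of what the helper does, assuming OpenMP (the real, more general implementation lives in mkldnn_thread.hpp):

    #include <cstdint>

    // Flatten (d0, d1, d2) into one parallel loop, then de-linearize the
    // flat index back into coordinates for the body f.
    template <typename F>
    void parallel_nd_3d_sketch(int d0, int d1, int d2, F f) {
        const int64_t work = (int64_t)d0 * d1 * d2;
    #pragma omp parallel for
        for (int64_t idx = 0; idx < work; ++idx) {
            const int n   = (int)(idx / ((int64_t)d1 * d2));
            const int rem = (int)(idx % ((int64_t)d1 * d2));
            f(n, rem / d2, rem % d2);
        }
    }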
 
-template void _jit_uni_dw_convolution_fwd_t<avx512_common, true>
-    ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<avx2, true>
-    ::execute_forward();
-template void _jit_uni_dw_convolution_fwd_t<sse42, true>
-    ::execute_forward();
+template struct _jit_uni_dw_convolution_fwd_t<avx512_common>;
+template struct _jit_uni_dw_convolution_fwd_t<avx2>;
+template struct _jit_uni_dw_convolution_fwd_t<sse42>;
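The per-method explicit instantiations are collapsed into whole-struct ones: `template struct X<T>;` instantiates every member of the class, so each line above covers execute_forward() and anything else the type defines. In miniature:

    // One explicit class instantiation emits all members.
    template <typename T> struct holder { T get() const { return v_; } T v_; };
    template struct holder<float>;  // holder<float>::get is instantiated too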
 
 template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() {
+void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper diff_dst_d(conf_.diff_dst_pd());
-    const memory_desc_wrapper diff_src_d(conf_.diff_src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
     const auto &jcp = kernel_->jcp;
 
@@ -192,7 +189,7 @@ void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() {
         return par_conv;
     };
 
-    int MB = conf_.MB();
+    int MB = pd()->MB();
     const int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
     parallel_nd(MB, chb_work, jcp.ih,
         [&](int n, int chb, int ih) {
@@ -247,264 +244,185 @@ void _jit_uni_dw_convolution_bwd_data_t<isa>::execute_backward_data() {
     });
 }
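Both the forward and backward-data loops size their channel work as chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking): the number of groups when nb_ch channel blocks are taken nb_ch_blocking at a time, with a possibly short last group. The helper is plain ceiling division:

    static inline int div_up(int a, int b) { return (a + b - 1) / b; }
    // e.g. nb_ch = 10, nb_ch_blocking = 4  ->  chb_work = 3 (groups of 4, 4, 2)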
 
-template void _jit_uni_dw_convolution_bwd_data_t<avx512_common>
-    ::execute_backward_data();
-template void _jit_uni_dw_convolution_bwd_data_t<avx2>
-    ::execute_backward_data();
-template void _jit_uni_dw_convolution_bwd_data_t<sse42>
-    ::execute_backward_data();
+template struct _jit_uni_dw_convolution_bwd_data_t<avx512_common>;
+template struct _jit_uni_dw_convolution_bwd_data_t<avx2>;
+template struct _jit_uni_dw_convolution_bwd_data_t<sse42>;
 
 template <cpu_isa_t isa>
 _jit_uni_dw_convolution_bwd_weights_t<isa>::
-        _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
-                const input_vector &inputs, const output_vector &outputs)
-    : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {
-
-    const auto &jcp = conf_.jcp_;
-
-    kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(jcp);
-
-    const int max_threads
-            = (mkldnn_in_parallel()) ? 1 : mkldnn_get_max_threads();
-    nthr_ = max_threads;
-
-    nthr_g_ = nthr_mb_ = 1;
-
-    /* Basic-Heuristics for parallel strategy:
-     * 1) Tries to parallel on the number of Groups (g) where tasks are
-     * independent. Otherwise,
-     * 2) Tries to split the work across g and MiniBatch (mb).
-     * Parallelizing on mb requires computing a reduction for weights.
-     *
-     * NOTE: because of 'task partitioning' scheme, there will be unbalanced
-     * per-thread load when the number of threads is high (e.g. > 16).
-     */
-    nthr_g_ = nstl::min(jcp.nb_ch, nthr_);
-    nthr_mb_ = nstl::min(nstl::max(1, nthr_ / nthr_g_), jcp.mb);
-
-    nthr_ = nthr_g_ * nthr_mb_;
-
-    /* Notes: if splitting thread work on 'mb', then a reduction has to take
-     * place. Hence, allocate a per-thread, local weights-buffer for the
-     * reduction */
-    if (nthr_mb_ > 1) {
-        const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
-        ws_reduction_ = (data_t *)malloc(
-                (nthr_mb_ - 1) * wei_size * sizeof(data_t), 64);
-
-        if (jcp.with_bias) {
-            const size_t bias_size = jcp.ngroups;
-            bias_reduction_ = (data_t *)malloc(
-                    (nthr_mb_ - 1) * bias_size * sizeof(data_t), 64);
-        }
-
-        /* Used when executing a parallel reduction */
-        if(do_parallel_reduction()){
-            acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
-            simple_barrier::ctx_init(&reduction_bctx_);
-        }
-    }
+_jit_uni_dw_convolution_bwd_weights_t(const pd_t *apd,
+        const input_vector &inputs, const output_vector &outputs)
+    : cpu_primitive_t(apd, inputs, outputs)
+    , kernel_(nullptr), acc_ker_(nullptr)
+{
+    kernel_ = new jit_uni_dw_conv_bwd_weights_kernel_f32<isa>(pd()->jcp_);
+    if (pd()->jcp_.nthr_mb > 1 && do_parallel_reduction())
+        acc_ker_ = new cpu_accumulator_1d_t<data_type::f32>();
 }
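The constructor no longer derives the thread partitioning itself; the heuristic from the removed lines (parallelize over independent groups first, spend any remaining threads on the minibatch) is now expected to arrive precomputed in jcp.nthr / jcp.nthr_g / jcp.nthr_mb. Restated standalone, same arithmetic as the removed code:

    #include <algorithm>

    // Groups are embarrassingly parallel; splitting over mb additionally
    // requires a weights reduction (see execute_backward_weights below).
    static void balance_bwd_weights(int nthr, int nb_ch, int mb,
            int &nthr_g, int &nthr_mb, int &nthr_total) {
        nthr_g = std::min(nb_ch, nthr);
        nthr_mb = std::min(std::max(1, nthr / nthr_g), mb);
        nthr_total = nthr_g * nthr_mb;
    }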
+
 template <cpu_isa_t isa>
-void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights() {
+void _jit_uni_dw_convolution_bwd_weights_t<isa>::execute_backward_weights() const {
+    auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
+    auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
+    auto diff_weights = reinterpret_cast<data_t *>(this->memory(0));
+    auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
+
+    auto diff_wei_reduction_buf =
+        scratchpad().template get<data_t>(key_conv_wei_reduction);
+    auto diff_bia_reduction_buf =
+        scratchpad().template get<data_t>(key_conv_bia_reduction);
 
-    auto src
-            = (data_t *)reinterpret_cast<const data_t *>(this->input_memory(0));
-    auto diff_dst
-            = (data_t *)reinterpret_cast<const data_t *>(this->input_memory(1));
     const auto &jcp = kernel_->jcp;
 
-    /* JIT-code skips the unnecessary computations within the padded region. */
-    const int SKIP_TOP_PADDING = 0;
+    /* Used when executing a parallel reduction */
+    simple_barrier::ctx_t reduction_bctx;
+    simple_barrier::ctx_init(&reduction_bctx);
 
     const size_t wei_size = jcp.ngroups * jcp.kh * jcp.kw;
     const size_t bias_size = jcp.with_bias ? jcp.ngroups : 0;
 
-    const int oh_blk_size = jcp.oh_blk_size;
-
-    //const int simd_w = jcp.ch_block;
     const int ch_block = jcp.ch_block;
 
     auto set_kernel_params = [&](jit_dw_conv_call_s *conv_params,
-            const int batch, const int group, const int oh_block,
-            const unsigned char table_idx, const int negative_padding_offset,
-            const unsigned char exec_flag) {
+            const int batch, const int group, const int oh_start,
+            const int work_size, const unsigned char exec_flag,
+            const size_t kh_padding, const size_t filter_off) {
+        const int tpad_underflow_off = jcp.t_pad - filter_off;
+
+        conv_params->exec_flags = exec_flag;
+        conv_params->kh_count = jcp.kh - kh_padding;
 
-        const int ih_block = oh_block * jcp.stride_h;
+        const int oh_s = oh_start;
+        const int oh_e = oh_start + work_size;
+        const int ih_s = oh_s * jcp.stride_h;
 
-        conv_params->table_idx = table_idx;
-        conv_params->exec_flag = exec_flag;
+        conv_params->filter_pad_off
+                = filter_off * jcp.kw * ch_block * sizeof(float);
+        conv_params->oh_index = oh_s;
+        conv_params->oh_count = oh_e;
 
         size_t diff_dst_off
-                = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh + oh_block)
+                = ((batch * (jcp.ngroups / ch_block) + group) * jcp.oh
+                          + oh_start)
                 * jcp.ow;
 
         size_t src_off = ((batch * (jcp.ngroups / ch_block) + group) * jcp.ih
-                              + ih_block - negative_padding_offset)
-                * jcp.iw;
+                + ih_s - tpad_underflow_off) * jcp.iw;
 
         conv_params->output = &diff_dst[diff_dst_off * ch_block];
         conv_params->input = &src[src_off * ch_block];
     };
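The offset arithmetic in set_kernel_params assumes a blocked layout (nChw8c/nChw16c style): channels grouped in blocks of ch_block, each pixel row contiguous. Under that assumption, the element offset of row (n, chb, h) has the same shape as diff_dst_off / src_off above:

    #include <cstddef>

    // ((n * nb_ch + chb) * H + h) * W pixels, ch_block elements per pixel.
    static size_t blocked_row_off(size_t n, size_t chb, size_t h,
            size_t nb_ch, size_t H, size_t W, size_t ch_block) {
        return ((n * nb_ch + chb) * H + h) * W * ch_block;
    }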
 
-    parallel(nthr_, [&](const int ithr, const int nthr_) {
+    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
+        assert(nthr == jcp.nthr);
+
         auto conv_params = jit_dw_conv_call_s();
+        const int h_block_size = 15;
 
         /* assign iteration space to thread */
-        const int ithr_g = ithr % nthr_g_;
-        const int ithr_mb = (ithr / nthr_g_) % nthr_mb_;
+        const int ithr_g = ithr % jcp.nthr_g;
+        const int ithr_mb = (ithr / jcp.nthr_g) % jcp.nthr_mb;
 
         /* split dimensions */
         int g_start{ 0 }, g_end{ 0 };
-        balance211(jcp.nb_ch, nthr_g_, ithr_g, g_start, g_end);
+        balance211(jcp.nb_ch, jcp.nthr_g, ithr_g, g_start, g_end);
 
         int mb_start{ 0 }, mb_end{ 0 };
-        balance211(jcp.mb, nthr_mb_, ithr_mb, mb_start, mb_end);
-
-        auto diff_wei = ithr_mb == 0 ?
-                (data_t *)reinterpret_cast<data_t *>(this->memory(0)) :
-                (data_t *)ws_reduction_ + (ithr_mb - 1) * wei_size;
+        balance211(jcp.mb, jcp.nthr_mb, ithr_mb, mb_start, mb_end);
 
-        auto diff_bias = ithr_mb == 0 ?
-                (data_t *)reinterpret_cast<const data_t *>(this->memory(1)) :
-                (data_t *)bias_reduction_ + (ithr_mb - 1) * bias_size;
+        auto diff_wei = ithr_mb == 0
+            ? diff_weights : diff_wei_reduction_buf + (ithr_mb - 1) * wei_size;
+        auto diff_bia = ithr_mb == 0
+            ? diff_bias : diff_bia_reduction_buf + (ithr_mb - 1) * bias_size;
 
         for (int g = g_start; g < g_end; ++g) {
-
-            /* This flag controls whether the kernel loads weights from memory
-             * or initializes the 'weight accummulator' registers to '0'. The
-             * latter happens at the beginning of each group/16 computation. */
-            unsigned char zero_filter_flag = ~FLAG_ZERO_FILTER;
-            unsigned char zero_bias_flag = jcp.with_bias ? ~FLAG_ZERO_BIAS : 0;
+            unsigned char zero_filter_flag = FLAG_ZERO_FILTER;
+            unsigned char zero_bias_flag = jcp.with_bias ? FLAG_ZERO_BIAS : 0;
 
             size_t diff_wei_off = g * jcp.kh * jcp.kw;
             conv_params.filter = &diff_wei[diff_wei_off * ch_block];
 
             if (jcp.with_bias)
-                conv_params.bias = &diff_bias[g * ch_block];
+                conv_params.bias = &diff_bia[g * ch_block];
 
             for (int mb = mb_start; mb < mb_end; ++mb) {
-
-                /* The 'table index' parameter controls the table entry for the
-                 * inner kernel execution. For more details see
-                 * jit_uni_dw_conv_kernel_f32. */
-                int table_idx = 0;
-
-                /* OH_BLOCK is unrolled to separate the computations according
-                 * to numerous condition-setting 'h' parameter. */
-                int oh_blk = 0;
-
-                /* Top-padding case - this case always executes. */
-                set_kernel_params(&conv_params, mb, g, oh_blk, table_idx,
-                        SKIP_TOP_PADDING, zero_filter_flag & zero_bias_flag);
-                kernel_->jit_ker(&conv_params);
-
-                zero_bias_flag |= FLAG_ZERO_BIAS;
-                zero_filter_flag |= FLAG_ZERO_FILTER;
-                oh_blk += oh_blk_size;
-
-                /* Middle OH_BLOCK cases. */
-                for (; oh_blk < (jcp.oh - oh_blk_size); oh_blk += oh_blk_size) {
-                    table_idx = 1;
-                    set_kernel_params(&conv_params, mb, g, oh_blk, table_idx,
-                            jcp.t_pad, zero_filter_flag & zero_bias_flag);
+                int oh = 0;
+                while (oh < jcp.oh) {
+                    const int h_work = nstl::min(h_block_size, jcp.oh - oh);
+                    auto kh_t_padding = nstl::max(0, jcp.t_pad - oh);
+                    auto kh_b_padding
+                            = (oh * jcp.stride_h + jcp.kh - 1 > jcp.ih) ?
+                            jcp.b_pad - (h_work - 1) :
+                            0;
+
+                    set_kernel_params(&conv_params, mb, g, oh, h_work,
+                            zero_filter_flag | zero_bias_flag,
+                            kh_t_padding + kh_b_padding, kh_t_padding);
                     kernel_->jit_ker(&conv_params);
-                }
-                table_idx++;
 
-                /* Bottom block */
-                if (oh_blk < jcp.oh) {
-                    set_kernel_params(&conv_params, mb, g, oh_blk, table_idx,
-                            jcp.t_pad, zero_filter_flag & zero_bias_flag);
-                    kernel_->jit_ker(&conv_params);
+                    zero_bias_flag &= ~FLAG_ZERO_BIAS;
+                    zero_filter_flag &= ~FLAG_ZERO_FILTER;
+                    oh += h_work;
                 }
             }
         }
-        if (do_parallel_reduction() && nthr_mb_ > 1) {
 
+        if (do_parallel_reduction() && jcp.nthr_mb > 1) {
             size_t reduct_start{ 0 }, reduct_end{ 0 };
-            balance211(wei_size, nthr_, ithr, reduct_start, reduct_end);
-
-            const size_t reduct_off = reduct_start;
-
-            auto *acc_data
-                    = (data_t *)reinterpret_cast<data_t *>(this->memory(0))
-                    + reduct_off;
+            balance211(wei_size, nthr, ithr, reduct_start, reduct_end);
 
             const int acc_size = reduct_end - reduct_start;
+            const size_t reduct_off = reduct_start;
+            auto *acc_data = diff_weights + reduct_off;
 
-            simple_barrier::barrier(&reduction_bctx_, nthr_);
-
-            for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
+            simple_barrier::barrier(&reduction_bctx, nthr);
 
-                auto *src_data = (data_t *)ws_reduction_
+            for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+                auto *src_data = diff_wei_reduction_buf
                         + (thr_mb - 1) * wei_size + reduct_off;
-
                 acc_ker_->accumulate(acc_data, src_data, acc_size);
             }
         }
     });
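The parallel reduction splits the flat weights array with balance211 so that, after the barrier, every thread accumulates a nearly equal slice. The helper's contract, restated (illustrative; the real one lives in mkl-dnn's utils):

    #include <algorithm>
    #include <cstddef>

    // Split n items over nthr workers; chunk sizes differ by at most one.
    static void balance211_sketch(size_t n, int nthr, int ithr,
            size_t &start, size_t &end) {
        const size_t base = n / nthr, rem = n % nthr;
        start = (size_t)ithr * base + std::min((size_t)ithr, rem);
        end = start + base + ((size_t)ithr < rem ? 1 : 0);
    }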
 
-    /* Apply single-threaded 'mb' reduction */
-    if (nthr_mb_ > 1) {
-
-        auto diff_weights
-                = (data_t *)reinterpret_cast<data_t *>(this->memory(0));
-        auto diff_bias
-                = (data_t *)reinterpret_cast<const data_t *>(this->memory(1));
-
-        for (int thr_mb = 1; thr_mb < nthr_mb_; ++thr_mb) {
-
-            size_t mb_accum_offset = (thr_mb - 1) * wei_size;
-            size_t b_accum_offset = (thr_mb - 1) * bias_size;
+    if (jcp.nthr_mb <= 1) return;
 
-            for (int g = 0; g < jcp.nb_ch; ++g) {
-
-                /* Reduction on Bias */
-                if (jcp.with_bias) {
-                    PRAGMA_OMP_SIMD()
-                    for (int g_block = 0; g_block < ch_block; ++g_block) {
-                        size_t bias_offset = g * ch_block + g_block;
-                        diff_bias[bias_offset] += bias_reduction_[b_accum_offset
-                                + bias_offset];
-                    }
+    /* Apply single-threaded 'mb' reduction */
+    for (int thr_mb = 1; thr_mb < jcp.nthr_mb; ++thr_mb) {
+        size_t mb_accum_offset = (thr_mb - 1) * wei_size;
+        size_t b_accum_offset = (thr_mb - 1) * bias_size;
+
+        for (int g = 0; g < jcp.nb_ch; ++g) {
+            /* Reduction on Bias */
+            if (jcp.with_bias) {
+                PRAGMA_OMP_SIMD()
+                for (int g_block = 0; g_block < ch_block; ++g_block) {
+                    size_t bias_offset = g * ch_block + g_block;
+                    diff_bias[bias_offset] += diff_bia_reduction_buf[
+                        b_accum_offset + bias_offset];
                 }
-                if (!do_parallel_reduction()) {
-                    for (int kh = 0; kh < jcp.kh; ++kh) {
-                        for (int kw = 0; kw < jcp.kw; ++kw) {
-
-                            size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
-                            PRAGMA_OMP_SIMD()
-                            for (int g_block = 0; g_block < ch_block; ++g_block) {
-                                diff_weights[wei_offset * ch_block + g_block]
-                                        += ws_reduction_[mb_accum_offset
-                                                + wei_offset * ch_block
-                                                + g_block];
-                            }
-                        }
-                    }
+            }
+
+            if (do_parallel_reduction()) continue;
+
+            for (int kh = 0; kh < jcp.kh; ++kh)
+            for (int kw = 0; kw < jcp.kw; ++kw)
+            {
+                size_t wei_offset = (g * jcp.kh + kh) * jcp.kw + kw;
+                PRAGMA_OMP_SIMD()
+                for (int g_block = 0; g_block < ch_block; ++g_block) {
+                    const size_t off = wei_offset * ch_block + g_block;
+                    diff_weights[off] +=
+                        diff_wei_reduction_buf[mb_accum_offset + off];
                 }
             }
         }
     }
 }
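The serial fallback above folds each extra minibatch worker's private scratchpad slice into the primary diff_weights / diff_bias buffers. Reduced to its essentials (illustrative names):

    #include <cstddef>

    // Accumulate n_copies private buffers of `size` elements into dst.
    static void reduce_copies(float *dst, const float *copies,
            size_t size, int n_copies) {
        for (int c = 0; c < n_copies; ++c)
            for (size_t i = 0; i < size; ++i)
                dst[i] += copies[(size_t)c * size + i];
    }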
 
-template _jit_uni_dw_convolution_bwd_weights_t<avx512_common>::
-        _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
-                const input_vector &inputs, const output_vector &outputs);
-template _jit_uni_dw_convolution_bwd_weights_t<avx2>::
-        _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
-                const input_vector &inputs, const output_vector &outputs);
-template _jit_uni_dw_convolution_bwd_weights_t<sse42>::
-        _jit_uni_dw_convolution_bwd_weights_t(const pd_t *pd,
-                const input_vector &inputs, const output_vector &outputs);
-
-template void _jit_uni_dw_convolution_bwd_weights_t<avx512_common>::
-        execute_backward_weights();
-template void _jit_uni_dw_convolution_bwd_weights_t<avx2>::
-        execute_backward_weights();
-template void _jit_uni_dw_convolution_bwd_weights_t<sse42>::
-        execute_backward_weights();
+template struct _jit_uni_dw_convolution_bwd_weights_t<avx512_common>;
+template struct _jit_uni_dw_convolution_bwd_weights_t<avx2>;
+template struct _jit_uni_dw_convolution_bwd_weights_t<sse42>;
 
 }
 }