Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_1x1_convolution.cpp
index 3b95a10..2fe6e8f 100644 (file)
@@ -34,36 +34,38 @@ namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
-template <bool with_relu>
-void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward() {
+void jit_sse42_1x1_convolution_fwd_t::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
     const int ndims = src_d.ndims();
     const auto &jcp = kernel_->jcp;
-    int MB = conf_.MB();
+    int MB = pd()->MB();
 
     const int work_amount = MB * jcp.ngroups * jcp.nb_bcast;
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
     }
 
     parallel(0, [&](const int ithr, const int nthr) {
         // TODO (Roma): remove this restriction
         assert(jcp.stride_w == 1 && jcp.stride_h == 1);
 
-        jit_1x1_conv_call_s par_conv = {};
+        auto par_conv = jit_1x1_conv_call_s();
 
         const int nb_oc = jcp.nb_load;
         const int nb_ic = jcp.nb_reduce;
@@ -120,7 +122,7 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward() {
                     const size_t src_off = data_blk_off(src_d, n, _icb, ih, iw);
                     par_conv.bcast_data = &src[src_off];
 
-                    par_conv.load_data = &weights[conf_.with_groups()
+                    par_conv.load_data = &weights[pd()->with_groups()
                         ? weights_d.blk_off(g, ocb, icb)
                         : weights_d.blk_off(ocb, icb)];
 
@@ -135,22 +137,25 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward() {
             iwork += bcast_step;
         }
     });
+
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
 }
 
-template <bool with_relu>
-void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
+void jit_sse42_1x1_convolution_fwd_t::execute_forward_with_dw_conv() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t *>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
 
-    auto &jcp = kernel_->jcp;
-    int MB = conf_.MB();
+    const auto &jcp = kernel_->jcp;
+    const auto &jcp_dw = kernel_dw_->jcp;
+    int MB = pd()->MB();
 
-    auto dw_bias = jcp.dw_conv_biases;
+    auto dw_bias = jcp_dw.conv_biases;
 
     int ocb_work = jcp.with_dw_conv ? utils::div_up(jcp.nb_load, jcp.nb_load_blocking) : 1;
     const int work_amount = MB * jcp.ngroups * ocb_work * jcp.nb_bcast;
@@ -173,8 +178,8 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
 
                 if ((oh + h) < 0 || (oh + h) >= jcp.ih) {
                     for (int chb = ocb; chb < ocb + load_step; chb++) {
-                        memset(ws_p + (((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow * jcp.oc_block +
-                               (chb - ocb) * jcp.dw_conv_ker_h * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
+                        memset(ws_p + (((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block +
+                               (chb - ocb) * jcp_dw.kh * jcp.ow * jcp.oc_block, 0, jcp.ow * jcp.oc_block * sizeof(float));
                     }
                 } else {
                     const int _ocb = g * jcp.nb_load + ocb;
@@ -182,7 +187,7 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                     p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
                     p.load_dim = this_block_size(ocb * jcp.oc_block, jcp.oc, load_step * jcp.oc_block);
 
-                    p.output_data = &ws_p[(((oh + h) + 1) % jcp.dw_conv_ker_h) * jcp.ow * jcp.oc_block];
+                    p.output_data = &ws_p[(((oh + h) + 1) % jcp_dw.kh) * jcp.ow * jcp.oc_block];
 
                     p.bias_data = &bias[_ocb * jcp.oc_block];
 
@@ -194,7 +199,7 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
 
                         p.reduce_dim = this_block_size(icb * jcp.ic_block, jcp.ic,
                                                        jcp.nb_reduce_blocking * jcp.ic_block);
-                        p.load_data = &weights[conf_.with_groups()
+                        p.load_data = &weights[pd()->with_groups()
                                                ? weights_d.blk_off(g, ocb, icb)
                                                : weights_d.blk_off(ocb, icb)];
 
@@ -210,8 +215,6 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         };
 
         auto compute_row_dw = [&](const float* ws_p, int n, int ocb, int load_step, int dst_idx) {
-            const auto &jcp_dw = kernel_dw_->jcp;
-
             for (int chb = ocb; chb < ocb + load_step; chb++) {
                 auto par_conv_dw = jit_conv_call_s();
 
@@ -226,9 +229,11 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                                        dst_idx/jcp_dw.stride_h*jcp_dw.ow*jcp_dw.ch_block];
 
                 par_conv_dw.kh_padding = jcp_dw.kh;
-                par_conv_dw.filt = &jcp.dw_conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
+                par_conv_dw.filt = &jcp_dw.conv_weights[chb * jcp_dw.kh * jcp_dw.kw * jcp_dw.ch_block];
                 par_conv_dw.bias = &dw_bias[chb * jcp_dw.ch_block];
                 par_conv_dw.ur_w = (size_t)(jcp_dw.ow);
+                par_conv_dw.oc_work = nstl::min((chb + 1) * jcp_dw.ch_block, (int)jcp_dw.oc) - chb*jcp_dw.ch_block;
+                par_conv_dw.oc_off = chb * jcp_dw.ch_block * sizeof(float);
 
                 kernel_dw_->jit_ker(&par_conv_dw);
             }
@@ -239,11 +244,12 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         int start{0}, end{0};
         balance211(work_amount, nthr, ithr, start, end);
 
-        auto pbuf = dw_conv_buffer_ + ithr * dw_conv_buffer_size_;
+        auto dw_conv_buffer = scratchpad().get<data_t>(key_dw_conv_buffer);
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
+        auto pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
 
         const int os_block = jcp.iw;
 
-
         int iwork = start;
         while (iwork < end) {
             int n{0}, g{0}, ocbb{0}, osb{0};
@@ -272,7 +278,7 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
                 compute_block_1x1(pbuf, n, g, oh + 1, ow, ih, iw, os, os_block, bcast_step, ocb, load_step, bcast_step);
             }
 
-            if ((oh % jcp.dw_conv_str_h == 0)) {
+            if ((oh % jcp_dw.stride_h == 0)) {
                 compute_row_dw(pbuf, n, ocb, load_step, oh);
             }
 
@@ -280,23 +286,25 @@ void _jit_sse42_1x1_convolution_fwd_t<with_relu>::execute_forward_fusing() {
         }
     };
 
-    if (conf_.want_padded_bias()) {
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            padded_bias_[oc] = bias[oc];
-        bias = padded_bias_;
-
-        for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
-            dw_padded_bias_[oc] = dw_bias[oc];
-        dw_bias = dw_padded_bias_;
+    if (pd()->wants_padded_bias()) {
+        auto padded_bias = scratchpad().get<data_t>(key_conv_padded_bias);
+        utils::array_copy(padded_bias, bias, jcp.oc_without_padding);
+        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+                jcp.oc - jcp.oc_without_padding);
+        bias = padded_bias;
+
+        auto dw_padded_bias = scratchpad().get<data_t>(key_dw_conv_padded_bias);
+        utils::array_copy(dw_padded_bias, dw_bias, jcp.oc_without_padding);
+        utils::array_set(dw_padded_bias + jcp.oc_without_padding, 0.f,
+                         jcp.oc - jcp.oc_without_padding);
+        dw_bias = dw_padded_bias;
     }
 
     parallel(0, ker);
-}
 
-template void _jit_sse42_1x1_convolution_fwd_t<true>::execute_forward();
-template void _jit_sse42_1x1_convolution_fwd_t<false>::execute_forward();
-template void _jit_sse42_1x1_convolution_fwd_t<true>::execute_forward_fusing();
-template void _jit_sse42_1x1_convolution_fwd_t<false>::execute_forward_fusing();
+    if (pd()->wants_zero_pad_dst())
+        output_memory_primitive(0)->zero_pad();
+}
 
 }
 }