Publishing 2019 R1 content
diff --git a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm_convolution.cpp b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm_convolution.cpp
index c403e45..154b5c3 100644
--- a/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm_convolution.cpp
+++ b/inference-engine/thirdparty/mkl-dnn/src/cpu/gemm_convolution.cpp
@@ -14,7 +14,6 @@
 * limitations under the License.
 *******************************************************************************/
 
-#include <common/primitive_attr.hpp>
 #include "mkldnn_types.h"
 
 #include "c_types_map.hpp"
@@ -22,7 +21,6 @@
 #include "utils.hpp"
 #include "type_helpers.hpp"
 #include "mkldnn_thread.hpp"
-
 #include "ref_eltwise.hpp"
 
 namespace mkldnn {
@@ -31,20 +29,22 @@ namespace cpu {
 
 using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
-template <bool with_relu>
-void _gemm_convolution_fwd_t<with_relu>::execute_forward() {
+void gemm_convolution_fwd_t::execute_forward() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto bias = reinterpret_cast<const data_t *>(this->input_memory(2));
     auto dst = reinterpret_cast<data_t*>(this->memory());
 
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
-    const int MB = conf_.MB();
+    auto col = scratchpad().get<data_t>(key_conv_gemm_col);
+
+    const auto &jcp = this->pd()->jcp_;
+    const int MB = pd()->MB();
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
 
     const int M = jcp.os * jcp.od;
     const size_t src_step = (src_d.blk_off(1) - src_d.off_l(0)) / jcp.ngroups;
@@ -53,60 +53,68 @@ void _gemm_convolution_fwd_t<with_relu>::execute_forward() {
     src += src_d.off_l(0);
     dst += dst_d.off_l(0);
 
+    assert(IMPLICATION(
+            jcp.id != 1, jcp.oh_block == jcp.oh && jcp.ow_block == jcp.ow));
+    assert(IMPLICATION(jcp.ow_block != jcp.ow, jcp.oh_block == 1));
+
     const int K = jcp.ic * jcp.ks;
     const int N = jcp.oc;
-    const int m = jcp.os;
-    const int LDA = jcp.im2col_sz ? m : M;
-
-    const data_t one = 1.0;
-
-    data_t *col = (jcp.im2col_sz)
-        ? (data_t *)this->scratchpad_->get()
-        : nullptr;
 
-    parallel_nd(jcp.im2col_sz * jcp.nthr,
-            [&](ptrdiff_t i) { col[i] = (data_t)0; });
+    if (jcp.im2col_sz && jcp.id != 1)
+        parallel_nd(jcp.im2col_sz * jcp.nthr,
+                [&](ptrdiff_t i) { col[i] = (data_t)0; });
 
-    const size_t work_amount = jcp.ngroups * MB * jcp.od;
+    const int nb_oh = div_up(jcp.oh, jcp.oh_block);
+    const int nb_ow = div_up(jcp.ow, jcp.ow_block);
+    const size_t work_amount = jcp.ngroups * MB * jcp.od * nb_oh * nb_ow;
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         data_t *_col = col + (ptrdiff_t)ithr * jcp.im2col_sz;
 
-        int g{0}, n{0}, od{0};
+        int g{ 0 }, n{ 0 }, od{ 0 }, ohb{ 0 }, owb{ 0 };
         size_t start = 0, end = 0;
 
         balance211(work_amount, nthr, ithr, start, end);
-        nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od);
-
+        nd_iterator_init(start, g, jcp.ngroups, n, MB, od, jcp.od, ohb,
+                nb_oh, owb, nb_ow);
         for (size_t iwork = start; iwork < end; ++iwork) {
+            int oh = ohb * jcp.oh_block;
+            int ow = owb * jcp.ow_block;
             const data_t *_src = src + (n * jcp.ngroups + g) * src_step;
             const data_t *_weights = weights + g * weights_g_size;
-            data_t *_dst = dst + (n * jcp.ngroups + g) * dst_step;
-
+            data_t *_dst_im = dst + (n * jcp.ngroups + g) * dst_step;
+            const int h_step = nstl::min(jcp.oh_block, jcp.oh - oh);
+            const int w_step = nstl::min(jcp.ow_block, jcp.ow - ow);
             if (jcp.im2col_sz) {
                 if (jcp.id == 1)
-                    jit_gemm_convolution_utils::im2col(jcp, _src, _col);
+                    jit_gemm_convolution_utils::im2col(
+                            jcp, _src, _col, oh, h_step, ow, w_step);
                 else
                     jit_gemm_convolution_utils::im2col_3d(jcp, _src, _col, od);
             }
 
             const data_t one = 1.0;
+
+            const int m = h_step * w_step;
+            const int LDA = jcp.im2col_sz ? m : M;
+            data_t *_dst = _dst_im + od * jcp.os + oh * jcp.ow + ow;
+
             extended_sgemm("N", "N", &m, &N, &K, &one,
                     jcp.im2col_sz ? _col : _src + od * m, &LDA, _weights, &K,
-                    &this->beta_, _dst + od * m, &M);
+                    &this->beta_, _dst, &M);
 
-            const auto &p = conf_.attr()->post_ops_;
+            data_t *d = _dst;
+            const auto &p = pd()->attr()->post_ops_;
             bool need_bias = jcp.with_bias;
             if (use_fast_relu) {
-                data_t *d = _dst + od * m;
-
-                for (int oc = 0; oc < jcp.oc; ++oc) {
+                parallel_nd(jcp.oc, [&](const int oc) {
                     data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
+                    data_t *d_ = d + oc * M;
+                    PRAGMA_OMP_SIMD()
                     for (int oS = 0; oS < m; ++oS) {
-                        d[oS] += b;
-                        if (d[oS] < 0) d[oS] *= fast_relu_ns;
+                        d_[oS] += b;
+                        if (d_[oS] < 0) d_[oS] *= fast_relu_ns;
                     }
-                    d += M;
-                }
+                });
 
                 need_bias = false;
             } else if (p.len_ > 0) {
@@ -114,17 +122,17 @@ void _gemm_convolution_fwd_t<with_relu>::execute_forward() {
                 int depthwise_inj_idx = 0;
 
                 for (int i = 0; i < p.len_; i++) {
-                    data_t *d = _dst + od * m;
                     auto& post_op = p.entry_[i];
                     if (post_op.is_eltwise()) {
-                        for (int oc = 0; oc < jcp.oc; ++oc) {
+                        parallel_nd(jcp.oc, [&](const int oc) {
                             data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
+                            data_t *d_ = d + oc * M;
+                            PRAGMA_OMP_SIMD()
                             for (int oS = 0; oS < m; ++oS) {
-                                d[oS] += b;
-                                d[oS] = eltwise_injectors[eltwise_inj_idx]->compute_scalar(d[oS]);
+                                d_[oS] += b;
+                                d_[oS] = eltwise_injectors[eltwise_inj_idx]->compute_scalar(d_[oS]);
                             }
-                            d += M;
-                        }
+                        });
 
                         eltwise_inj_idx++;
                         need_bias = false;
@@ -132,16 +140,17 @@ void _gemm_convolution_fwd_t<with_relu>::execute_forward() {
                         auto depthwise_weights = post_op.depthwise.weights_data;
                         auto depthwise_bias = post_op.depthwise.biases_data;
 
-                        for (int oc = 0; oc < jcp.oc; ++oc) {
+                        parallel_nd(jcp.oc, [&](const int oc) {
                             data_t b = need_bias ? bias[g * jcp.oc + oc] : 0;
+                            data_t *d_ = d + oc * M;
+                            PRAGMA_OMP_SIMD()
                             for (int oS = 0; oS < m; ++oS) {
-                                d[oS] += b;
-                                d[oS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d[oS],
+                                d_[oS] += b;
+                                d_[oS] = depthwise_injectors[depthwise_inj_idx]->compute_scalar(d_[oS],
                                                                   depthwise_weights + g * jcp.oc + oc,
                                                                   depthwise_bias + g * jcp.oc + oc);
                             }
-                            d += M;
-                        }
+                        });
 
                         depthwise_inj_idx++;
                         need_bias = false;
@@ -150,46 +159,53 @@ void _gemm_convolution_fwd_t<with_relu>::execute_forward() {
             }
 
             if (need_bias) {
-                data_t *d = _dst + od * m;
-
-                for (int oc = 0; oc < jcp.oc; ++oc) {
+                parallel_nd(jcp.oc, [&](const int oc) {
                     data_t b = bias[g * jcp.oc + oc];
+                    data_t *d_ = d + oc * M;
+                    PRAGMA_OMP_SIMD()
                     for (int oS = 0; oS < m; ++oS) {
-                        d[oS] += b;
+                        d_[oS] += b;
                     }
-                    d += M;
-                }
+                });
             }
 
-            nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od);
+            nd_iterator_step(g, jcp.ngroups, n, MB, od, jcp.od, ohb, nb_oh,
+                    owb, nb_ow);
         }
     });
 }
 
-void gemm_convolution_bwd_data_t::execute_backward_data() {
+void gemm_convolution_bwd_data_t::execute_backward_data() const {
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto weights = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_src = reinterpret_cast<data_t*>(this->memory());
 
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
-    const int MB = conf_.MB();
+    auto col = scratchpad().get<data_t>(key_conv_gemm_col);
+
+    const auto &jcp = this->pd()->jcp_;
+    const int MB = pd()->MB();
 
     const int M = jcp.os * jcp.od;
-    const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
-    const size_t dst_step = jcp.oc * M;
+    const size_t src_step_to_clean = jcp.ic * jcp.ih * jcp.iw * jcp.id;
+    const memory_desc_wrapper diff_src_d(pd()->diff_src_pd());
+    const memory_desc_wrapper diff_dst_d(pd()->diff_dst_pd());
+    const size_t src_step = diff_src_d.blk_off(1) / jcp.ngroups;
+    const size_t dst_step = diff_dst_d.blk_off(1) / jcp.ngroups;
     const size_t weights_g_size = jcp.ic * jcp.oc * jcp.ks;
 
     const int m = jcp.os;
     const int K = jcp.oc;
     const int N = jcp.ic * jcp.ks;
     const int LDC = jcp.im2col_sz ? m : M;
-    data_t *col = jcp.im2col_sz ? (data_t *)this->scratchpad_->get() : nullptr;
 
     const size_t work_amount = (size_t)jcp.ngroups * MB;
 
     if (jcp.id > 1) {
-        const ptrdiff_t diff_src_sz = (ptrdiff_t)(work_amount * src_step);
-        parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[i] = (data_t)0; });
+        for (size_t j = 0; j < work_amount; j++) {
+            int j_step = src_step * j;
+            const ptrdiff_t diff_src_sz = (ptrdiff_t)(src_step_to_clean);
+            parallel_nd(diff_src_sz, [&](ptrdiff_t i) { diff_src[j_step + i] = (data_t)0; });
+        }
     }
 
     parallel(jcp.nthr, [&](const int ithr, const int nthr) {
@@ -201,7 +217,7 @@ void gemm_convolution_bwd_data_t::execute_backward_data() {
         nd_iterator_init(start, g, jcp.ngroups, n, MB);
         for (size_t iwork = start; iwork < end; ++iwork) {
 
-            data_t *_diff_src = diff_src + (n * jcp.ngroups + g)*src_step;
+            data_t *_diff_src = diff_src + (n * jcp.ngroups + g) * src_step;
             const data_t *_weights = weights + g * weights_g_size;
             for (int od = 0; od < jcp.od; ++od) {
                 const data_t *_diff_dst = diff_dst + (n * jcp.ngroups + g)
@@ -226,13 +242,17 @@ void gemm_convolution_bwd_data_t::execute_backward_data() {
     });
 }
 
-void gemm_convolution_bwd_weights_t::execute_backward_weights() {
+void gemm_convolution_bwd_weights_t::execute_backward_weights() const {
     auto src = reinterpret_cast<const data_t *>(this->input_memory(0));
     auto diff_dst = reinterpret_cast<const data_t *>(this->input_memory(1));
     auto diff_weights = reinterpret_cast<data_t*>(this->memory(0));
     auto diff_bias = reinterpret_cast<data_t *>(this->memory(1));
 
-    jit_gemm_conv_conf_t &jcp = this->conf_.jcp_;
+    auto col = scratchpad().get<data_t>(key_conv_gemm_col);
+    auto wei_reduction = scratchpad().get<data_t>(key_conv_wei_reduction);
+
+    const jit_gemm_conv_conf_t &jcp = this->pd()->jcp_;
+
     const int K = jcp.os * jcp.od;
     const size_t src_step = jcp.ic * jcp.ih * jcp.iw * jcp.id;
     const size_t dst_step = jcp.oc * K;
@@ -243,15 +263,6 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() {
     const int M = jcp.ic * jcp.ks;
     const int LDA = jcp.im2col_sz ? k : K;
 
-    data_t *col = nullptr, *wei_reduction = nullptr;
-    ptrdiff_t wei_offset = 0;
-    if (jcp.im2col_sz) {
-        col = (data_t *)this->scratchpad_->get();
-        wei_offset = jcp.im2col_sz * jcp.nthr;
-    }
-    if (jcp.need_wei_reduction)
-        wei_reduction = (data_t *)this->scratchpad_->get() + wei_offset;
-
     parallel_nd(jcp.im2col_sz * jcp.nthr,
             [&](ptrdiff_t i) { col[i] = (data_t)0; });
 
@@ -289,7 +300,8 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() {
 
                     if (jcp.im2col_sz) {
                         if (jcp.id == 1)
-                            jit_gemm_convolution_utils::im2col(jcp, _src, _col);
+                            jit_gemm_convolution_utils::im2col(
+                                    jcp, _src, _col, 0, jcp.oh, 0, jcp.ow);
                         else
                             jit_gemm_convolution_utils::im2col_3d(jcp, _src,
                                 _col, od);
@@ -331,13 +343,10 @@ void gemm_convolution_bwd_weights_t::execute_backward_weights() {
                 }
             }
             diff_bias[g*jcp.oc+oc] = db;
-            nd_iterator_step(g, jcp.ngroups, oc, jcp.oc);
         });
     }
 }
 
-template struct _gemm_convolution_fwd_t<true>;
-template struct _gemm_convolution_fwd_t<false>;
 }
 }
 }