Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_conv_kernel_f32.cpp
index 32f1903..c192504 100644 (file)
@@ -29,6 +29,7 @@ namespace cpu {
 
 using namespace mkldnn::impl::prop_kind;
 using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
 using namespace mkldnn::impl::utils;
 
 using namespace Xbyak;
@@ -170,7 +171,7 @@ void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         for (int jj = 0; jj < ur_w; jj++) {
             int o_off;
             if (jcp.with_dw_conv)
-                o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+                o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
             else
                 o_off = (ii * oh * ow + jj) * oc_blk;
 
@@ -206,7 +207,8 @@ void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
 
     Label skip_kh_loop;
     mov(kj, reg_kh);
-    if ((jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
+    if ((jcp.dilate_h >= jcp.ih)
+            || (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad)) {
         cmp(kj, 0);
         je(skip_kh_loop, T_NEAR);
     }
@@ -240,10 +242,6 @@ void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
     int depthwise_inj_idx = 0;
     const auto &p = attr_.post_ops_;
 
-    if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-        eltwise_injectors[0]->compute_vector_range(1, oc_blocks * ur_w + 1);
-    }
-
     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
     for (int i = 0; i < end_idx; i++) {
         auto& post_op = p.entry_[i];
@@ -275,7 +273,7 @@ void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         for (int jj = 0; jj < ur_w; jj++) {
             int o_off;
             if (jcp.with_dw_conv)
-                o_off = (ii * jcp.dw_conv_ker_h * ow + jj) * oc_blk;
+                o_off = (ii * jcp_dw.kh * ow + jj) * oc_blk;
             else
                 o_off = (ii * oh * ow + jj) * oc_blk;
 
@@ -284,8 +282,6 @@ void jit_sse42_conv_fwd_kernel_f32::width_blk_step(int ur_w,
         }
     }
 
-    L(done);
-
     mov(aux_reg_kernel, reg_kernel);
     mov(aux_reg_input, reg_input);
     add(aux_reg_kernel, sizeof(float) * 4);
@@ -359,12 +355,6 @@ inline void jit_sse42_conv_fwd_kernel_f32::solve_common(int oc_blocks)
 
 void jit_sse42_conv_fwd_kernel_f32::generate()
 {
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
     const auto &p = attr_.post_ops_;
     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
     for (int i = 0; i < end_idx; i++) {
@@ -431,24 +421,15 @@ bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-        case 0: return true; // no post_ops
-        case 1:
-            return true // sum OR eltwise OR dw_conv
-                   && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
-        case 2:
-            return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
-                   // eltwise->depthwise OR depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
-                                            (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
-                                            (is_simple(0) && is_simple(1)));
-        case 3:
-            return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
-                   // sum->depthwise->eltwise OR sum->depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
-                                            (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
-                                            (is_sum(0) && is_simple(1) && is_simple(2)));
-        case 4: return true // eltwise->dw_conv->sum->eltwise
-                       && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+        case 0: return true;
+        case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+                       (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+                       (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+                       (is_sum(0) && is_simple(1) && is_simple(2));
+        case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
         default: return false;
     }
 
@@ -458,7 +439,7 @@ bool jit_sse42_conv_fwd_kernel_f32::post_ops_ok(
 status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+        const primitive_attr_t &attr)
 {
     if (!mayiuse(sse42)) return status::unimplemented;
 
@@ -496,47 +477,26 @@ status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
 
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
     const auto &p = attr.post_ops_;
-    jcp.with_dw_conv = false;
-    int dw_conv_ind = p.find(primitive_kind::convolution);
-    if (dw_conv_ind != -1) {
-        jcp.with_dw_conv = true;
-        jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
-        jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
-        jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
-        jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
-        jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
-        jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
-        jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
-        jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
-    }
 
+    int dw_conv_ind = p.find(primitive_kind::convolution);
+    jcp.with_dw_conv = dw_conv_ind != -1;
     if (jcp.with_dw_conv) {
-        int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
-        if (dw_conv_eltwise_ind != -1) {
-            jcp.dw_conv_with_eltwise = true;
-            jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
-            jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
-            jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
-        }
+        jcp.dw_conv_oh = jcp.oh;
+        jcp.dw_conv_ow = jcp.ow;
+        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
     }
 
     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
-    if (jcp.with_dw_conv) {
-        jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
-    }
 
-    if (jcp.with_dw_conv) {
-        jcp.oh = jcp.dw_conv_in_h;
-        jcp.ow = jcp.dw_conv_in_w;
-    }
+    jcp.src_dt = cd.src_desc.data_type;
+    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+    jcp.dst_dt = cd.dst_desc.data_type;
 
     const bool flat = jcp.ic == 3 || jcp.ic == 1;
     const bool mimo = !flat;
@@ -613,6 +573,21 @@ status_t jit_sse42_conv_fwd_kernel_f32::init_conf(jit_conv_conf_t &jcp,
     return status::success;
 }
 
+void jit_sse42_conv_fwd_kernel_f32::init_scratchpad( // books on `scratchpad` the buffers selected by the jcp / jcp_dw settings below
+        memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
+    if (jcp.with_bias && jcp.oc != jcp.oc_without_padding) // only when a bias exists and oc differs from oc_without_padding
+        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc); // room for jcp.oc floats
+
+    if (jcp.with_dw_conv) { // the remaining bookings apply only when with_dw_conv is set
+        const int nthreads = mkldnn_get_max_threads(); // multiplier for the buffer below; presumably one slice per thread - TODO confirm
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking; // element count of one slice; the leading (size_t) cast makes the product evaluate in size_t
+        scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads); // slice size in bytes times nthreads
+
+        if (jcp.oc != jcp.oc_without_padding) // NOTE(review): unlike the first check, jcp.with_bias is not tested here - confirm this is intended
+            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc); // room for jcp.oc floats
+    }
+}
+
 }
 }
 }