Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_sse42_1x1_conv_kernel_f32.cpp
index cbce262..3ba4715 100644 (file)
@@ -139,7 +139,7 @@ void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
         default:
             if (jcp.with_dw_conv)
                 return ptr[aux_reg_output_data +
-                           (i * jcp.dw_conv_ker_h * jcp.ow + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
+                           (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
             else
                 return ptr[aux_reg_output_data +
                     (i * jcp.os + j) * jcp.oc_block * sizeof(float) + n*4*sizeof(float)];
@@ -185,7 +185,6 @@ void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
     }; // init()
 
     auto store = [=]() {
-        Label store_done;
         Label store_noadd;
 
         if (!jcp.with_sum) {
@@ -203,16 +202,13 @@ void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
 
         L(store_noadd);
 
-        Label store_norelu;
+        Label store_no_postops;
         test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
-        jz(store_norelu, T_NEAR);
+        jz(store_no_postops, T_NEAR);
 
         int eltwise_inj_idx = 0;
         int depthwise_inj_idx = 0;
         const auto &p = attr_.post_ops_;
-        if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-            eltwise_injectors[0]->compute_vector_range(1, 2 * ur * load_loop_blk + 1);
-        }
 
         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
         for (int i = 0; i < end_idx; i++) {
@@ -244,15 +240,13 @@ void jit_sse42_1x1_conv_kernel_f32::generate_reduce_loop(
             }
         }
 
-        L(store_norelu);
+        L(store_no_postops);
 
         for (int j = 0; j < ur; ++j)
             for (int i = 0; i < load_loop_blk; ++i) {
                 movups(output_ptr(i, j, 0), reg_accum(i, j, 0));
                 movups(output_ptr(i, j, 1), reg_accum(i, j, 1));
             }
-
-        L(store_done);
     };
 
     auto fma_block = [=](bool last_block) {
@@ -375,12 +369,6 @@ void jit_sse42_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
 
 void jit_sse42_1x1_conv_kernel_f32::generate()
 {
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<sse42>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
     const auto &p = attr_.post_ops_;
     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
     for (int i = 0; i < end_idx; i++) {
@@ -513,24 +501,15 @@ bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-        case 0: return true; // no post_ops
-        case 1:
-            return true // sum OR eltwise OR dw_conv
-                   && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
-        case 2:
-            return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
-                   // eltwise->depthwise OR depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
-                                            (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
-                                            (is_simple(0) && is_simple(1)));
-        case 3:
-            return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
-                   // sum->depthwise->eltwise OR sum->depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
-                                            (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
-                                            (is_sum(0) && is_simple(1) && is_simple(2)));
-        case 4: return true // eltwise->dw_conv->sum->eltwise
-                       && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+        case 0: return true;
+        case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+                       (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+                       (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+                       (is_sum(0) && is_simple(1) && is_simple(2));
+        case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
         default: return false;
     }
 
@@ -540,7 +519,7 @@ bool jit_sse42_1x1_conv_kernel_f32::post_ops_ok(
 status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+        const primitive_attr_t &attr)
 {
     if (!mayiuse(sse42))
         return status::unimplemented;
@@ -576,47 +555,25 @@ status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
 
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
-
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
     const auto &p = attr.post_ops_;
-    jcp.with_dw_conv = false;
-    int dw_conv_ind = p.find(primitive_kind::convolution);
-    if (dw_conv_ind != -1) {
-        jcp.with_dw_conv = true;
-        jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
-        jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
-        jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
-        jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
-        jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
-        jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
-        jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
-        jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
-    }
 
+    int dw_conv_ind = p.find(primitive_kind::convolution);
+    jcp.with_dw_conv = dw_conv_ind != -1;
     if (jcp.with_dw_conv) {
-        int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
-        if (dw_conv_eltwise_ind != -1) {
-            jcp.dw_conv_with_eltwise = true;
-            jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
-            jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
-            jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
-        }
+        jcp.dw_conv_oh = jcp.oh;
+        jcp.dw_conv_ow = jcp.ow;
+        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
     }
 
     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
-    if (jcp.with_dw_conv) {
-        jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
-    }
 
-    if (jcp.with_dw_conv) {
-        jcp.oh = jcp.dw_conv_in_h;
-        jcp.ow = jcp.dw_conv_in_w;
-    }
+    jcp.src_dt = cd.src_desc.data_type;
+    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+    jcp.dst_dt = cd.dst_desc.data_type;
 
     jcp.os = jcp.oh * jcp.ow;
     jcp.is = jcp.ih * jcp.iw;
@@ -791,6 +748,24 @@ status_t jit_sse42_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
     return status::success;
 }
 
+void jit_sse42_1x1_conv_kernel_f32::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
+    using namespace mkldnn::impl::memory_tracking::names;
+
+    if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+
+    if (jcp.with_dw_conv) {
+        const int nthreads = mkldnn_get_max_threads();
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
+        scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
+
+        if (jcp.oc != jcp.oc_without_padding)
+            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
+    }
+}
+
 }
 }
 }