Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx2_1x1_conv_kernel_f32.cpp
index 9ef2558..73f01f5 100644 (file)
 * limitations under the License.
 *******************************************************************************/
 
+#include <assert.h>
+
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
 #include "nstl.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
+
 #include "cpu_memory.hpp"
 
 #include "jit_avx2_1x1_conv_kernel_f32.hpp"
@@ -140,7 +144,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
         default:
             if (jcp.with_dw_conv) {
                 return ptr[aux_reg_output_data +
-                           (i * jcp.dw_conv_ker_h * jcp.ow + j) * jcp.oc_block * sizeof(float)];
+                           (i * jcp_dw.kh * jcp.ow + j) * jcp.oc_block * sizeof(float)];
             } else {
                 return ptr[aux_reg_output_data +
                            (i * jcp.os + j) * jcp.oc_block * sizeof(float)];
@@ -176,7 +180,7 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
     };
 
     auto store = [=]() {
-        Label store_done, store_noadd;
+        Label store_noadd;
 
         if (!jcp.with_sum) {
             test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
@@ -198,9 +202,6 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
         int eltwise_inj_idx = 0;
         int depthwise_inj_idx = 0;
         const auto &p = attr_.post_ops_;
-        if (p.len_ == 0 && eltwise_injectors.size() == 1) {
-            eltwise_injectors[0]->compute_vector_range(0, ur * load_loop_blk);
-        }
 
         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
         for (int i = 0; i < end_idx; i++) {
@@ -236,8 +237,6 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
             for (int i = 0; i < load_loop_blk; ++i) {
                 vmovups(output_ptr(i, j), vreg_accum(i, j));
             }
-
-        L(store_done);
     };
 
     auto fma_block = [=](bool last_block) {
@@ -247,9 +246,8 @@ void jit_avx2_1x1_conv_kernel_f32::generate_reduce_loop(
                     if (mayiuse(avx2))
                         vfmadd231ps(vreg_accum(i, j), vreg_load(i), vreg_bcast);
                     else { // Intel(R) Advanced Vector Extensions (Intel(R) AVX) support
-                        auto tmp = vmask;
-                        vmulps(tmp, vreg_bcast, vreg_load(i));
-                        vaddps(vreg_accum(i, j), vreg_accum(i, j), tmp);
+                        vmulps(vtmp, vreg_bcast, vreg_load(i));
+                        vaddps(vreg_accum(i, j), vreg_accum(i, j), vtmp);
                     }
                     if (j == ur - 1 && !(last_block
                                 && u == jcp.reduce_loop_unroll - 1))
@@ -347,12 +345,6 @@ void jit_avx2_1x1_conv_kernel_f32::generate_diff_bias_loop(int load_loop_blk)
 
 void jit_avx2_1x1_conv_kernel_f32::generate()
 {
-    if (jcp.with_eltwise) {
-        eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx2>(
-                this, jcp.eltwise_alg, jcp.eltwise_alpha, 0
-        ));
-    }
-
     const auto &p = attr_.post_ops_;
     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
     for (int i = 0; i < end_idx; i++) {
@@ -485,24 +477,15 @@ bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok(
     auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-        case 0: return true; // no post_ops
-        case 1:
-            return true // sum OR eltwise OR dw_conv
-                   && !jcp.with_eltwise && (is_simple(0) || is_sum(0) || is_dw_conv(0));
-        case 2:
-            return true // sum->eltwise OR dw_conv->eltwise OR eltwise->dw_conv OR dw_conv->sum OR sum->depthwise OR
-                   // eltwise->depthwise OR depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
-                                            (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
-                                            (is_simple(0) && is_simple(1)));
-        case 3:
-            return true // eltwise->dw_conv->eltwise OR dw_conv->sum->eltwise OR sum->eltwise->depthwise OR
-                   // sum->depthwise->eltwise OR sum->depthwise->depthwise
-                   && !jcp.with_eltwise && ((is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
-                                            (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
-                                            (is_sum(0) && is_simple(1) && is_simple(2)));
-        case 4: return true // eltwise->dw_conv->sum->eltwise
-                       && !jcp.with_eltwise && (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
+        case 0: return true;
+        case 1: return is_simple(0) || is_sum(0) || is_dw_conv(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_dw_conv(0) && is_eltwise(1)) ||
+                       (is_eltwise(0) && is_dw_conv(1)) || (is_dw_conv(0) && is_sum(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_eltwise(0) && is_dw_conv(1) && is_eltwise(2)) ||
+                       (is_dw_conv(0) && is_sum(1) && is_eltwise(2)) ||
+                       (is_sum(0) && is_simple(1) && is_simple(2));
+        case 4: return (is_eltwise(0) && is_dw_conv(1) && is_sum(2) && is_eltwise(3));
         default: return false;
     }
 
@@ -512,7 +495,7 @@ bool jit_avx2_1x1_conv_kernel_f32::post_ops_ok(
 status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope)
+        const primitive_attr_t &attr)
 {
     if (!mayiuse(avx)) return status::unimplemented;
 
@@ -547,51 +530,41 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
 
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alg = mkldnn_eltwise_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
-
-    if (!post_ops_ok(jcp, attr)) {
+    if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
-    }
 
     const auto &p = attr.post_ops_;
-    jcp.with_dw_conv = false;
+
     int dw_conv_ind = p.find(primitive_kind::convolution);
-    if (dw_conv_ind != -1) {
-        jcp.with_dw_conv = true;
-        jcp.dw_conv_in_h = p.entry_[dw_conv_ind].dw_conv.in_h;
-        jcp.dw_conv_in_w = p.entry_[dw_conv_ind].dw_conv.in_w;
-        jcp.dw_conv_ker_h = p.entry_[dw_conv_ind].dw_conv.ker_h;
-        jcp.dw_conv_ker_w = p.entry_[dw_conv_ind].dw_conv.ker_w;
-        jcp.dw_conv_str_h = p.entry_[dw_conv_ind].dw_conv.str_h;
-        jcp.dw_conv_str_w = p.entry_[dw_conv_ind].dw_conv.str_w;
-        jcp.dw_conv_weights = p.entry_[dw_conv_ind].dw_conv.weights_data;
-        jcp.dw_conv_biases = p.entry_[dw_conv_ind].dw_conv.biases_data;
+    jcp.with_dw_conv = dw_conv_ind != -1;
+    jcp.with_dw_conv = dw_conv_ind != -1;
+    if (jcp.with_dw_conv) {
+        jcp.dw_conv_oh = jcp.oh;
+        jcp.dw_conv_ow = jcp.ow;
+        jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
+        jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
     }
 
     if (jcp.with_dw_conv && !mayiuse(avx2))
         return status::unimplemented;
 
-    if (jcp.with_dw_conv) {
-        int dw_conv_eltwise_ind = p.find(primitive_kind::eltwise, dw_conv_ind);
-        if (dw_conv_eltwise_ind != -1) {
-            jcp.dw_conv_with_eltwise = true;
-            jcp.dw_conv_eltwise_alg = p.entry_[dw_conv_eltwise_ind].eltwise.alg;
-            jcp.dw_conv_eltwise_alpha = p.entry_[dw_conv_eltwise_ind].eltwise.alpha;
-            jcp.dw_conv_eltwise_beta = p.entry_[dw_conv_eltwise_ind].eltwise.beta;
+    if (!mayiuse(avx2)) {
+        for (int i = 0; i < p.len_; i++) {
+            auto &post_op = p.entry_[i];
+            if (post_op.is_eltwise()) {
+                if (post_op.eltwise.alg != alg_kind::eltwise_relu)
+                    return status::unimplemented;
+            } else if (post_op.is_depthwise()) {
+                return status::unimplemented;
+            }
         }
     }
 
     jcp.with_sum = p.find(primitive_kind::sum, 0, dw_conv_ind) != -1;
-    if (jcp.with_dw_conv) {
-        jcp.dw_conv_with_sum = p.find(primitive_kind::sum, dw_conv_ind) != -1;
-    }
 
-    if (jcp.with_dw_conv) {
-        jcp.oh = jcp.dw_conv_in_h;
-        jcp.ow = jcp.dw_conv_in_w;
-    }
+    jcp.src_dt = cd.src_desc.data_type;
+    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
+    jcp.dst_dt = cd.dst_desc.data_type;
 
     jcp.os = jcp.oh * jcp.ow;
     jcp.is = jcp.ih * jcp.iw;
@@ -770,6 +743,24 @@ status_t jit_avx2_1x1_conv_kernel_f32::init_conf(jit_1x1_conv_conf_t &jcp,
     return status::success;
 }
 
+void jit_avx2_1x1_conv_kernel_f32::init_scratchpad( // Books the scratchpad entries implied by jcp (1x1 conv) and jcp_dw (fused depthwise conv).
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw) {
+    using namespace mkldnn::impl::memory_tracking::names;
+    // Unless prop_kind is backward_data, book jcp.oc floats for a padded bias when oc differs from oc_without_padding.
+    if (jcp.prop_kind != backward_data && jcp.oc != jcp.oc_without_padding)
+        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
+    // With a fused depthwise conv: book kh * iw * ch_block * (oc / oc_block) floats, multiplied by the max thread count.
+    if (jcp.with_dw_conv) {
+        const int nthreads = mkldnn_get_max_threads();
+        size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * (jcp.oc / jcp.oc_block);
+        scratchpad.book(key_dw_conv_buffer, sizeof(float) * dw_conv_buffer_size_ * nthreads);
+        // The depthwise bias gets its own jcp.oc-float padded buffer under the same oc mismatch condition.
+        if (jcp.oc != jcp.oc_without_padding)
+            scratchpad.book(key_dw_conv_padded_bias, sizeof(float) * jcp.oc);
+    }
+}
+
 }
 }
 }