Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_conv_kernel.cpp
index c02bd80..d7b3994 100644 (file)
@@ -183,32 +183,6 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter_unrolled(int ur_ch_b
 }
 
 template <cpu_isa_t isa>
-bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::maybe_relu(int position) {
-    using namespace primitive_kind;
-    const auto &p = attr_.post_ops_;
-
-    if (position == 0) {
-        /* relu before sum */
-        return false
-               || jcp.with_eltwise
-               || p.contain(eltwise, 0)
-               || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
-    } else if (position == 1) {
-        /* relu after sum */
-        const int sum_idx = p.contain(sum, 0)
-                            ? 0 : (p.contain(sum, 1) ? 1 : -1);
-        if (sum_idx == -1)
-            return false;
-
-        return false
-               || p.contain(eltwise, sum_idx + 1)
-               || jcp.dst_dt == data_type::u8;
-    }
-
-    return false;
-}
-
-template <cpu_isa_t isa>
 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
     Ymm ymm_dst = Ymm(vmm_dst.getIdx());
     Xmm xmm_dst = Xmm(vmm_dst.getIdx());
@@ -229,7 +203,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &o
             if (isa != sse42 && !scalar_store)
                 vpermq(ymm_dst, ymm_dst, 0x08);
 
-            uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst);
+            uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
 
             if (scalar_store) {
                 movq(reg_tmp_64, xmm_dst);
@@ -247,7 +221,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &o
             if (isa != sse42 && !scalar_store)
                 vpermq(ymm_dst, ymm_dst, 0x08);
 
-            uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst);
+            uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
 
             if (scalar_store) {
                 movq(reg_tmp_64, xmm_dst);
@@ -306,37 +280,89 @@ template <cpu_isa_t isa>
 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int ch_step, int ur_w) {
     int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
 
+    pop(reg_oc_off);
     pop(reg_scales_base);
 
-    uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
+    mov(imm_addr64, l_table);
+
+    const auto &p = attr_.post_ops_;
+    const int sum_idx = p.find(primitive_kind::sum);
+    const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
+
+    bool is_scalar_store = ch_step < jcp.ch_block;
+
     for (int r = 0; r < repeats; r++) {
-        if (ch_step < jcp.ch_block) {
+        for (int ii = 0; ii < ur_ch_blocks; ii++) {
+            if (jcp.with_bias) {
+                int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
+                cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], is_scalar_store);
+            }
+
             for (int jj = 0; jj < ur_w; jj++) {
-                Vmm vmm_dst = get_acc_reg(r * ur_w * ur_ch_blocks + jj);
+                Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
                 uni_vcvtdq2ps(vmm_dst, vmm_dst);
 
-                if (jcp.with_bias) {
-                    int b_off = r * (jcp.ch_block / 2);
-                    cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], true);
+                if (jcp.with_bias)
                     uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
-                }
 
-                int s_off = jcp.is_oc_scale * (r * (jcp.ch_block / 2));
-                cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], true);
+                int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
+                cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], is_scalar_store);
                 uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
+            }
+        }
 
-                int o_off = jj * jcp.oc + r * (jcp.ch_block / 2);
-                if (jcp.with_sum) {
-                    uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
-                    cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], true);
-                    uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+        int eltwise_inj_idx = 0;
+        int depthwise_inj_idx = 0;
+        for (int i = 0; i < p.len_; i++) {
+            int start_idx = 4 + r * ur_ch_blocks*ur_w;
+
+            auto& post_op = p.entry_[i];
+            if (post_op.is_eltwise()) {
+                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + ur_ch_blocks * ur_w);
+                eltwise_inj_idx++;
+            } else if (post_op.is_depthwise()) {
+                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+                add(reg_d_weights, reg_oc_off);
+                add(reg_d_bias, reg_oc_off);
+
+                if (r == 1) {
+                    add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
+                    add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
                 }
 
-                if (maybe_relu(0))
-                    uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
+                for (int ii = 0; ii < ur_ch_blocks; ii++) {
+                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+                            start_idx + ur_w * ii, start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
+
+                    add(reg_d_weights, jcp.ch_block * sizeof(float));
+                    add(reg_d_bias, jcp.ch_block * sizeof(float));
+                }
+
+                depthwise_inj_idx++;
+            } else if (post_op.is_sum(false)) {
+                for (int ii = 0; ii < ur_ch_blocks; ii++) {
+                    for (int jj = 0; jj < ur_w; jj++) {
+                        Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
+                        int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
+
+                        cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], is_scalar_store);
+
+                        if (p_sum_scale == 1.f) {
+                            uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
+                        } else {
+                            uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 0 * vlen]);
+                        }
+                    }
+                }
+            }
+        }
 
-                if (maybe_relu(1))
-                    uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
+        for (int ii = 0; ii < ur_ch_blocks; ii++) {
+            for (int jj = 0; jj < ur_w; jj++) {
+                Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
+                int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
 
                 if (jcp.dst_dt != data_type::f32) {
                     if (attr_.round_mode_ == round_mode::nearest)
@@ -348,55 +374,13 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int c
                         assert(!"unimplemented");
                 }
 
-                store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, true);
-            }
-        } else {
-            for (int ii = 0; ii < ur_ch_blocks; ii++) {
-                if (jcp.with_bias) {
-                    int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
-                    cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
-                }
-
-                for (int jj = 0; jj < ur_w; jj++) {
-                    Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
-                    uni_vcvtdq2ps(vmm_dst, vmm_dst);
-
-                    if (jcp.with_bias)
-                        uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
-
-                    int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
-                    cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
-                    uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
-
-                    int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
-                    if (jcp.with_sum) {
-                        cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
-                        uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
-                    }
-
-                    if (maybe_relu(0))
-                        uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-
-                    if (maybe_relu(1))
-                        uni_vmaxps(vmm_dst, vmm_dst, vmm_zero);
-
-                    if (jcp.dst_dt != data_type::f32) {
-                        if (attr_.round_mode_ == round_mode::nearest)
-                            uni_vcvtps2dq(vmm_dst, vmm_dst);
-                        else if (attr_.round_mode_ == round_mode::down) {
-                            uni_vroundps(vmm_dst, vmm_dst, 1);
-                            uni_vcvtps2dq(vmm_dst, vmm_dst);
-                        } else
-                            assert(!"unimplemented");
-                    }
-
-                    store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
-                }
+                store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, is_scalar_store);
             }
         }
     }
 
     push(reg_scales_base);
+    push(reg_oc_off);
 }
 
 template <cpu_isa_t isa>
@@ -415,6 +399,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int c
     push(reg_kernel_base);
     push(reg_ch_work);
     push(reg_scales_base);
+    push(reg_oc_off);
 
     L(unrolled_w_label); {
         int ur_w = jcp.ur_w;
@@ -458,6 +443,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int c
 
     L(exit_label);
 
+    pop(reg_oc_off);
     pop(reg_scales_base);
     pop(reg_ch_work);
     pop(reg_kernel_base);
@@ -467,6 +453,24 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int c
 
 template <cpu_isa_t isa>
 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
+    const auto &p = attr_.post_ops_;
+    for (int i = 0; i < p.len_; i++) {
+        auto &post_op = p.entry_[i];
+        if (post_op.is_eltwise()) {
+            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
+                    this,
+                    post_op.eltwise.alg,
+                    post_op.eltwise.alpha,
+                    post_op.eltwise.beta
+            ));
+        } else if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
+                    this,
+                    post_op.depthwise.alg
+            ));
+        }
+    }
+
     this->preamble();
 
     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
@@ -478,6 +482,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
     mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
     mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
     mov(reg_ch_work, ptr[this->param1 + GET_OFF(ch_work)]);
+    mov(reg_oc_off, ptr[this->param1 + GET_OFF(oc_off)]);
 
     Label main_loop_label;
     Label tail_loop_label;
@@ -504,6 +509,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
         add(reg_kernel_base, jcp.ch_block * jcp.kh * jcp.kw * jcp.typesize_in);
         add(reg_bias_base, jcp.ch_block * jcp.typesize_bia);
         add(reg_scales_base, jcp.is_oc_scale * jcp.ch_block * sizeof(float));
+        add(reg_oc_off, jcp.ch_block * sizeof(float));
 
         jmp(main_loop_label, T_NEAR);
     }
@@ -520,6 +526,7 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
         add(reg_kernel_base, 1 * jcp.typesize_in);
         add(reg_bias_base, 1 * jcp.typesize_bia);
         add(reg_scales_base, jcp.is_oc_scale * 1 * sizeof(float));
+        add(reg_oc_off, 1 * sizeof(float));
 
         jmp(tail_loop_label, T_NEAR);
     }
@@ -527,6 +534,30 @@ void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
     L(exit_label);
 
     this->postamble();
+
+    prepare_table();
+
+    for (auto& inj : eltwise_injectors)
+        inj->prepare_table();
+}
+
+template <cpu_isa_t isa>
+void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::prepare_table() {
+    const auto &p = attr_.post_ops_;
+    const int sum_idx = p.find(primitive_kind::sum);
+    const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
+
+    const int32_t cvals_sum_scale[] = {
+        float2int(p_sum_scale)
+    };
+
+    align(64);
+    L(l_table);
+    for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
+        for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
+            dd(cvals_sum_scale[i]);
+        }
+    }
 }
 
 template <cpu_isa_t isa>
@@ -534,14 +565,18 @@ bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::post_ops_ok(
         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
     const auto &p = attr.post_ops_;
 
-    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
-    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
+    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-    case 0: return true; // no post_ops
-    case 1: return !jcp.with_eltwise && (is_relu(0) || is_sum(0)); // sum OR relu
-    case 2: return !jcp.with_eltwise && (is_sum(0) && is_relu(1)); // sum->relu
-    default: return false;
+        case 0: return true;
+        case 1: return is_simple(0) || is_sum(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
+        default: return false;
     }
 
     return false;
@@ -551,8 +586,7 @@ template <cpu_isa_t isa>
 status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
         const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
         const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
-        const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr,
-        bool with_relu, float relu_negative_slope)
+        const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr)
 {
     if (!mayiuse(isa)) return status::unimplemented;
 
@@ -593,8 +627,6 @@ status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jc
 
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
 
@@ -610,13 +642,10 @@ status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jc
 
     const auto &p = attr.post_ops_;
     jcp.with_sum = p.find(primitive_kind::sum) != -1;
-    if (!jcp.with_eltwise) {
-        int eltwise_ind = p.find(primitive_kind::eltwise);
-        if (eltwise_ind != -1) {
-            jcp.with_eltwise  = true;
-            jcp.eltwise_alpha = p.entry_[eltwise_ind].eltwise.alpha;
-        }
-    }
+    const int eltwise_ind = p.find(primitive_kind::eltwise);
+    jcp.with_eltwise = eltwise_ind != -1;
+    if (jcp.with_eltwise)
+        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
 
     auto desired_act_fmt = nhwc;
     auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;