Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_x8s8s32x_1x1_conv_kernel.cpp
index 40ca5f0..011db24 100644 (file)
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/
-#include <float.h>
+
+#include <assert.h>
+
 #include "c_types_map.hpp"
+#include "memory_tracking.hpp"
 #include "nstl.hpp"
 #include "type_helpers.hpp"
-#include "mkldnn_thread.hpp"
 #include "utils.hpp"
+
 #include "cpu_memory.hpp"
 
 #include "jit_uni_1x1_conv_utils.hpp"
@@ -35,32 +38,6 @@ using namespace mkldnn::impl::utils;
 
 using namespace Xbyak;
 
-bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_relu(int position)
-{
-    using namespace primitive_kind;
-    const auto &p = attr_.post_ops_;
-
-    if (position == 0) {
-        /* relu before sum */
-        return false
-            || jcp.with_eltwise
-            || p.contain(eltwise, 0)
-            || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
-    } else if (position == 1) {
-        /* relu after sum */
-        const int sum_idx = p.contain(sum, 0)
-            ? 0 : (p.contain(sum, 1) ? 1 : -1);
-        if (sum_idx == -1)
-            return false;
-
-        return false
-            || p.contain(eltwise, sum_idx + 1)
-            || jcp.dst_dt == data_type::u8;
-    }
-
-    return false;
-}
-
 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
 {
     mov(aux1_reg_bcast_data, reg_bcast_data);
@@ -131,7 +108,7 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
     };
 
     auto vreg_accum = [=](int i_load, int i_ur) {
-        return Zmm(i_ur * load_loop_blk + i_load);
+        return Zmm(i_ur + i_load * ur);
     };
 
     auto zmm_bias_alpha = [=]() {
@@ -242,23 +219,60 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
 
                 zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r;
                 vmulps(mask_zmm, r, scale_ptr(i_load));
-                if (maybe_relu(0)) {
-                    vpxord(zmm_zero, zmm_zero, zmm_zero);
-                    vmaxps(r, zmm_zero, r);
-                }
-                if (p_sum_scale) { // post_op: sum
-                    vpxord(zmm_zero, zmm_zero, zmm_zero);
-                    auto zmm_prev_dst = zmm_zero;
+            }
+        }
+
+        int eltwise_inj_idx = 0;
+        int depthwise_inj_idx = 0;
+        for (int i = 0; i < p.len_; i++) {
+            auto& post_op = p.entry_[i];
+            if (post_op.is_eltwise()) {
+                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
+
+                eltwise_inj_idx++;
+            } else if (post_op.is_depthwise()) {
+                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
+                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
+
+                add(reg_d_weights, reg_oc_off);
+                add(reg_d_bias, reg_oc_off);
 
-                    cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
-                        mask_flag);
+                for (int k = 0; k < load_loop_blk; k++) {
+                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
+                            k * ur, k * ur + ur, reg_d_weights, reg_d_bias);
 
-                    if (*p_sum_scale == 1.f)
-                        vaddps(r, zmm_prev_dst);
-                    else
-                        vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+                    add(reg_d_weights, jcp.oc_block * sizeof(float));
+                    add(reg_d_bias, jcp.oc_block * sizeof(float));
                 }
-                if (maybe_relu(1)) {
+
+                depthwise_inj_idx++;
+            } else if (post_op.is_sum(false)) {
+                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+                    const bool mask_flag = mask_flag_in &&
+                                               i_load == load_loop_blk - 1;
+                    for (int i_ur = 0; i_ur < ur; ++i_ur) {
+                        vpxord(zmm_zero, zmm_zero, zmm_zero);
+                        auto zmm_prev_dst = zmm_zero;
+
+                        auto r = vreg_accum(i_load, i_ur);
+                        cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
+                            mask_flag);
+
+                        if (*p_sum_scale == 1.f)
+                            vaddps(r, zmm_prev_dst);
+                        else
+                            vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
+                    }
+                }
+            }
+        }
+
+        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
+            const bool mask_flag = mask_flag_in &&
+                                       i_load == load_loop_blk - 1;
+            for (int i_ur = 0; i_ur < ur; ++i_ur) {
+                auto r = vreg_accum(i_load, i_ur);
+                if (jcp.dst_dt == data_type::u8) {
                     vpxord(zmm_zero, zmm_zero, zmm_zero);
                     vmaxps(r, zmm_zero, r);
                 }
@@ -274,6 +288,7 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
             for (int i_ur = 0; i_ur < ur; ++i_ur) {
                 auto r = vreg_accum(i_load, i_ur);
                 zmm_t r_zmm = mask_flag ? r | ktail_mask : r;
+
                 switch (jcp.dst_dt) {
                 case data_type::f32:
                 case data_type::s32:
@@ -335,6 +350,8 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
     Label reduce_loop;
     Label reduce_loop_tail;
 
+    push(reg_oc_off);
+
     mov(aux_reg_load_data, reg_load_data);
 
     mov(aux_reg_bcast_data, aux1_reg_bcast_data);
@@ -359,6 +376,8 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
         fma_block(false);
     }
 
+    pop(reg_oc_off);
+
     if (jcp.oc_without_padding != jcp.oc) {
         Label end_store, common_store;
         mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
@@ -388,6 +407,24 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
 
 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
 {
+    const auto &p = attr_.post_ops_;
+    for (int i = 0; i < p.len_; i++) {
+        auto &post_op = p.entry_[i];
+        if (post_op.is_eltwise()) {
+            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
+                    this,
+                    post_op.eltwise.alg,
+                    post_op.eltwise.alpha,
+                    post_op.eltwise.beta
+            ));
+        } else if (post_op.is_depthwise()) {
+            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
+                    this,
+                    post_op.depthwise.alg
+            ));
+        }
+    }
+
     preamble();
 
     xor_(reg_scratch, reg_scratch);
@@ -423,7 +460,7 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
     mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
     mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
     mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
-
+    mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
 
     auto load_loop_body = [=](int load_loop_blk) {
         bcast_loop(load_loop_blk);
@@ -451,6 +488,7 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
         add(reg_output_data,
             load_loop_blk * jcp.load_block * jcp.typesize_out);
         sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
+        add(reg_oc_off, load_loop_blk * jcp.oc_block * sizeof(float));
     };
 
     const int simd_w = 16;
@@ -480,6 +518,12 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
                     cmp(reg_load_loop_work, 0);
                     je(load_loop_blk[num_ur_cases], T_NEAR);
                 }
+
+                for (int _i = 1; _i <= label_idx + 1; _i++) {
+                    prefetcht0(ptr [ reg_load_data + _i * jcp.ic * jcp.oc_block ]);
+                    prefetcht1(ptr [ reg_output_data + _i * jcp.oc_block ]);
+                }
+
                 load_loop_body(label_idx + 1);
                 if (label_idx - 1 > 0) {
                     cmp(reg_load_loop_work, 2 * label_idx * simd_w);
@@ -503,6 +547,9 @@ void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
     add(rsp, stack_space_needed);
 
     postamble();
+
+    for (auto& inj : eltwise_injectors)
+        inj->prepare_table();
 }
 
 bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
@@ -510,27 +557,18 @@ bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
     using namespace primitive_kind;
     const auto &p = attr.post_ops_;
 
-    auto is_relu = [&](int idx) {
-        return p.entry_[idx].kind == eltwise
-            && p.entry_[idx].eltwise.scale == 1.
-            && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
-            && p.entry_[idx].eltwise.alpha == 0.;
-    };
+    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
+    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
 
     switch (p.len_) {
-    case 0: return true;
-    case 1: return true
-                && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0))
-                && IMPLICATION(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
-    case 2: return true
-                && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
-                && IMPLICATION(!jcp.with_eltwise, false
-                        || (p.contain(sum, 0) && is_relu(1))
-                        || (p.contain(sum, 1) && is_relu(0)));
-    case 3: return true
-                && jcp.with_eltwise == false
-                && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
-    default: return false;
+        case 0: return true;
+        case 1: return is_simple(0) || is_sum(0);
+        case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
+                       (is_simple(0) && is_simple(1));
+        case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
+        default: return false;
     }
 
     return false;
@@ -540,9 +578,7 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
         jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
         const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
-        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
-        int nthreads, bool reduce_src)
-{
+        const primitive_attr_t &attr, int nthreads, bool reduce_src) {
     if (!mayiuse(avx512_core)) return status::unimplemented;
 
     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
@@ -577,10 +613,6 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
     jcp.stride_w = cd.strides[1];
     jcp.src_fmt = src_d.format();
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
-    if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
-        return status::unimplemented;
 
     jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
 
@@ -646,25 +678,30 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
         max_regs = 8;
     jcp.expl_bcast = true;
 
-    const int spatial = jcp.oh;
-    jcp.ur = 1;
-    for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
-        if ((spatial >= size_treshold && spatial % ur_w == 0)
-                || (spatial < size_treshold && jcp.os % ur_w == 0)) {
-            jcp.ur = ur_w;
-            break;
-        }
-    }
-    if (jcp.ur == 1) {
+    if (jcp.mb == 1 && jcp.ic > 128
+        && (jcp.oh <= size_treshold && jcp.ow <= size_treshold)) {
         jcp.ur = nstl::min(max_regs, jcp.os);
-        int os_tail = jcp.os % max_regs;
-        for (int i = max_regs; i >= min_regs; i--) {
-            int i_tail = jcp.os % i;
-            if (i_tail > os_tail || i_tail == 0) {
-                jcp.ur = i;
-                os_tail = i_tail;
-                if (i_tail == 0)
-                    break;
+    } else {
+        const int spatial = jcp.oh;
+        jcp.ur = 1;
+        for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
+            if ((spatial >= size_treshold && spatial % ur_w == 0)
+                    || (spatial < size_treshold && jcp.os % ur_w == 0)) {
+                jcp.ur = ur_w;
+                break;
+            }
+        }
+        if (jcp.ur == 1) {
+            jcp.ur = nstl::min(max_regs, jcp.os);
+            int os_tail = jcp.os % max_regs;
+            for (int i = max_regs; i >= min_regs; i--) {
+                int i_tail = jcp.os % i;
+                if (i_tail > os_tail || i_tail == 0) {
+                    jcp.ur = i;
+                    os_tail = i_tail;
+                    if (i_tail == 0)
+                        break;
+                }
             }
         }
     }
@@ -786,6 +823,17 @@ status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
     return status::success;
 }
 
+void jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_scratchpad(
+        memory_tracking::registrar_t &scratchpad,
+        const jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) { // Books this kernel's scratchpad entries, derived from jcp and attr.
+    using namespace mkldnn::impl::memory_tracking::names;
+
+    if (jcp.signed_input && jcp.ver != ver_vnni) { // nothing is booked unless the input is signed and the version is not ver_vnni
+        size_t count = nstl::max(attr.output_scales_.count_, 16); // entry count: number of output scales, but never fewer than 16
+        scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count); // one float per entry; NOTE(review): key name suggests adjusted output scales - confirm at the consumer
+    }
+}
+
 }
 }
 }