Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp
index 831f182..164bbe0 100644 (file)
@@ -62,6 +62,41 @@ int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
     return best_divisor;
 }
 
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
+    /* Determines if current winograd implementation is faster than direct.
+       Following conditions are empirical and based on performance data */
+    unsigned int ncores_per_socket =
+        cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
+    unsigned int nthreads = mkldnn_get_max_threads();
+
+    if (jcp.prop_kind == prop_kind::forward_inference) {
+        return jcp.mb >= 4;
+    } else if (nthreads > ncores_per_socket) {
+        double src_dst_transforms_per_core = alpha * alpha
+            * (jcp.ic + jcp.oc)
+            * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
+            * ((jcp.ow + tile_size - 1) / tile_size)
+            * sizeof(float) / 1024. / 1024. / nthreads;
+        double wei_transform = alpha * alpha
+            * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.;
+
+        if (jcp.prop_kind == prop_kind::backward_weights) {
+            if (src_dst_transforms_per_core < 0.3
+                    || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
+                return false;
+            else
+                return true;
+        } else {
+            if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
+                return false;
+        }
+    }
+
+    return jcp.mb > 8;
+}
+}
+
 /* assumes 512 bits registers */
 /* TODO: add support for strides */
 /* TODO: handle the prefetch distance automatically */
@@ -730,16 +765,16 @@ void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
                     vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
                 }
                 if (with_relu) {
-                    Opmask kmask = Opmask(7);
-                    if (jcp.eltwise_alpha == 0) {
-                        zmm_relu_ns = zmm_zero;
+                    if (jcp.eltwise.alpha == 0) {
+                        vmaxps(zmm_O, zmm_O, zmm_zero);
                     } else {
-                        mov(imm_addr64, float2int(jcp.eltwise_alpha));
+                        Opmask kmask = Opmask(7);
+                        mov(imm_addr64, float2int(jcp.eltwise.alpha));
                         vmovq(xmm_relu_ns, imm_addr64);
                         vbroadcastss(zmm_relu_ns, xmm_relu_ns);
+                        vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
+                        vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
                     }
-                    vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
-                    vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
                 }
             }
             if (with_sum) {
@@ -1095,6 +1130,9 @@ status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common(
     if (!mayiuse(avx512_core)) {
         return status::unimplemented;
     }
+
+    jcp.nthr = mkldnn_get_max_threads();
+
     jcp.ver = ver_avx512_core;
     jcp.prop_kind = cd.prop_kind;
 
@@ -1133,6 +1171,10 @@ status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common(
     }
 
     // Checking conditions not supported by these kernels
+    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+               is_winograd_faster_than_direct(jcp)))
+        return status::unimplemented;
+
     if (jcp.ngroups != 1)
         return status::unimplemented;
     if ((jcp.kh != 3) || (jcp.kw != 3))
@@ -1366,28 +1408,16 @@ bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok(
         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
     const auto &p = attr.post_ops_;
 
-    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
 
     switch (p.len_) {
-    case 0:
-        return true; // no post_ops
-    case 1:
-        return true // relu or sum
-                && IMPLICATION(jcp.with_eltwise, is_sum(0))
-                && IMPLICATION(!jcp.with_eltwise, is_eltwise(0) || is_sum(0));
-    case 2:
-        return true // sum->relu or relu->sum
-                && IMPLICATION(jcp.with_eltwise, is_sum(0) && is_eltwise(1))
-                && IMPLICATION(!jcp.with_eltwise, false
-                                   || (is_sum(0) && is_eltwise(1))
-                                   || (is_eltwise(0) && is_sum(1)));
-    case 3:
-        return true // relu->sum->relu
-                && jcp.with_eltwise == false
-                && (is_eltwise(0) && is_sum(1) && is_eltwise(2));
-    default:
-        return false;
+    case 0: return true; // no post_ops
+    case 1: return is_relu(0) || is_sum(0); // relu or sum
+    case 2: return (is_sum(0) && is_relu(1))
+                      || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
+    case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
+    default: return false;
     }
 
     return false;
@@ -1396,8 +1426,7 @@ bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok(
 status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
         const cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
-        const cpu_memory_t::pd_t &dst_pd, const primitive_attr_t &attr,
-        bool with_relu, float relu_negative_slope) {
+        const cpu_memory_t::pd_t &dst_pd, const primitive_attr_t &attr) {
 
     status_t st = init_conf_common(jcp, cd,
                         *src_pd.desc(), *weights_pd.desc(), *dst_pd.desc());
@@ -1411,18 +1440,16 @@ status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
 
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
     const auto &p = attr.post_ops_;
-    if (!jcp.with_eltwise) {
-        /* PostOps ReLU before SUM is handled the same as ReLU primitive */
-        jcp.with_eltwise = p.find(primitive_kind::eltwise, 0, 1) != -1;
-        jcp.eltwise_alpha = 0.f;
-    }
+    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
+    jcp.with_eltwise = eltwise_ind != -1;
+    if (jcp.with_eltwise)
+        jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
     jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
     jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;
 
@@ -2376,6 +2403,8 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
                                 jcp.dimM_block = M_blk;
                                 jcp.sched_policy = WSCHED_WEI_SDGtWo;
                                 set_jcp_WEI_params(jcp);
+                                jcp.nthr = nstl::min(mkldnn_get_max_threads(),
+                                        jcp.tile_block);
                                 return status::success;
                             }
                         }
@@ -2467,6 +2496,9 @@ status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf(
     else
         jcp.ver = ver_avx512_core;
 
+    jcp.nthr = mkldnn_get_max_threads();
+
+    jcp.prop_kind = cd.prop_kind;
     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
     jcp.mb = src_d.dims()[0];
     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
@@ -2507,6 +2539,10 @@ status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf(
     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
 
     // Winograd kernel works only for 3x3 convolution with stride 1
+    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+               is_winograd_faster_than_direct(jcp)))
+        return status::unimplemented;
+
     if (jcp.ngroups != 1)
         return status::unimplemented;
     if ((jcp.kh != 3) || (jcp.kw != 3))