Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_conv_winograd_kernel_f32.cpp
index 0405eee..63cd074 100644 (file)
@@ -66,6 +66,15 @@ int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
     return best_divisor;
 }
 
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
+    if (jcp.ver == ver_4fma)
+        return jcp.mb >= 32;
+    else
+        return jcp.mb >= 16;
+}
+}
+
 /* assumes 512 bits registers */
 /* TODO: add support for strides */
 /* TODO: handle the prefetch distance automatically */
@@ -137,29 +146,6 @@ private:
 };
 
 // utilities to support kernel parameter selection
-bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
-        int dimN_block, float C2_min, float C2_max) {
-    /* V_L2_block + M_L2_block + W */
-    float block_size = (alpha * alpha * (jcp.oc + jcp.ic)
-                     * dimN_block * jcp.dimN_reg_block
-                     + jcp.ic * jcp.oc) * (float)sizeof(float);
-    float L2_lb = C2_min * L2_cache_size;
-    float L2_ub =  C2_max * L2_cache_size;
-    return (block_size > L2_lb && block_size < L2_ub);
-}
-
-bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
-        int dimM_block, float C1_min, float C1_max) {
-    float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
-                             * jcp.dimK_reg_block
-                     + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
-                     + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
-                     * (float)sizeof(float);
-    float L1_lb = C1_min * L1_cache_size;
-    float L1_ub = C1_max * L1_cache_size;
-    return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
-}
-
 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
         int dimM_block, int dimM_simd_block, float C)
 {
@@ -311,10 +297,8 @@ void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
             auto store_output = [=](bool output_is_aligned) {
                 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                     Zmm zmm(jcp.zmm_start + tile);
-                    // In W_SGD, output will be reused.
                     if (output_is_aligned
                         && jcp.dimK_nb_block == 1
-                        && jcp.sched_policy == WSCHED_DATA_W_S_G_D
                         && (jcp.dimN * jcp.dimM * alpha * alpha
                             * sizeof(float) > 2 * LLC_data_size))
                         vmovntps(zword[reg_dstC + 64 * tile], zmm);
@@ -359,15 +343,17 @@ status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
         const memory_desc_wrapper &dst_d)
 {
 
-    if (!mayiuse(avx512_common))
+    if (mayiuse(avx512_core))
+        return status::unimplemented;
+    else if (!mayiuse(avx512_common))
         return status::unimplemented;
-    else if (mayiuse(avx512_core))
-        jcp.ver = ver_avx512_core;
     else if (mayiuse(avx512_mic_4ops))
         jcp.ver = ver_4fma;
     else
         jcp.ver = ver_fma;
 
+    jcp.nthr = mkldnn_get_max_threads();
+
     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
 
     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
@@ -402,6 +388,10 @@ status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
         jcp.ic = rnd_up(jcp.ic, simd_w);
     }
 
+    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+                is_winograd_faster_than_direct(jcp)))
+        return status::unimplemented;
+
     // Checking conditions not supported by these kernels
     if (jcp.ngroups != 1)
         return status::unimplemented;
@@ -431,83 +421,6 @@ status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
     return status::success;
 }
 
-status_t set_wsched_DATA_W_SGD_avx512_common(jit_conv_winograd_conf_t &jcp) {
-
-    if (jcp.ver != ver_avx512_core)
-        return status::unimplemented;
-
-    /* ----------- dimN reg block ---------------------*/
-    auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
-            int dimN_reg_block, int current_best) {
-        return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
-            && (dimN_reg_block <= jcp.nb_reg)
-            && (dimN_reg_block < current_best);
-    };
-
-    jcp.dimN_reg_block = get_divisor_satisfying_cond(
-            jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);
-
-    if (jcp.dimN_reg_block >= jcp.nb_reg) {
-        auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
-                int dimN_reg_block, int current_best) {
-            return (dimN_reg_block < jcp.nb_reg)
-                    && (dimN_reg_block > current_best);
-        };
-
-        jcp.dimN_reg_block = get_divisor_satisfying_cond(
-                jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
-    }
-
-    /*-------------- L2 blocking for dimN block ---------*/
-
-    auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
-            int dimN_block, int current_best) {
-        return check_L2_block_per_thread(jcp, dimN_block, 0.1, 1.3)
-            && (dimN_block > current_best)
-            && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) > 2 * mkldnn_get_max_threads());
-    };
-
-    jcp.dimN_block = get_divisor_satisfying_cond(
-            jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
-
-    if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 1.3)
-        && jcp.dimN/ jcp.dimN_block/ jcp.dimN_reg_block > 2 * mkldnn_get_max_threads()) {
-        jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
-
-        /* ------------------- L1 blocking for GEMM --------------*/
-        /* -------------------- Choose dimK block ----------------*/
-        auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
-                int dimK_block, int current_best) {
-            return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.6)
-                && (dimK_block > current_best);
-        };
-
-        jcp.dimK_block = get_divisor_satisfying_cond(
-                jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
-
-        if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 0.6)) {
-            jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
-
-            /* -------------- Choose dimM block -------------------*/
-            auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
-                    int dimM_block, int current_best) {
-                return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, 0.1, 0.7)
-                    && (dimM_block > current_best);
-            };
-
-            jcp.dimM_block = get_divisor_satisfying_cond(
-                    jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond_dimM_block);
-            jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_simd_block;
-
-            jcp.sched_policy = WSCHED_DATA_W_SGD;
-            return status::success;
-        }
-
-    }
-    return status::unimplemented;
-
-}
-
 
 status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
 
@@ -593,7 +506,6 @@ status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
     jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
     jcp.sched_policy = WSCHED_DATA_W_S_G_D;
     return status::success;
-    //return status::unimplemented;
 }
 
 status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
@@ -618,10 +530,9 @@ status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
     jcp.dimM = dimM;
 
     jcp.sched_policy = WSCHED_INVALID;
-    if (!(set_wsched_DATA_W_SGD_avx512_common(jcp) == status::success))
-        set_wsched_DATA_W_S_G_D_avx512_common(jcp);
+    set_wsched_DATA_W_S_G_D_avx512_common(jcp);
 
-    assert(jcp.sched_policy != WSCHED_INVALID);
+    assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
     return status::success;
 }
 
@@ -629,28 +540,16 @@ bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
     const auto &p = attr.post_ops_;
 
-    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
 
     switch (p.len_) {
-    case 0:
-        return true; // no post_ops
-    case 1:
-        return true // relu or sum
-                && IMPLICATION(jcp.with_eltwise, is_sum(0))
-                && IMPLICATION(!jcp.with_eltwise, is_eltwise(0) || is_sum(0));
-    case 2:
-        return true // sum->relu or relu->sum
-                && IMPLICATION(jcp.with_eltwise, is_sum(0) && is_eltwise(1))
-                && IMPLICATION(!jcp.with_eltwise, false
-                                   || (is_sum(0) && is_eltwise(1))
-                                   || (is_eltwise(0) && is_sum(1)));
-    case 3:
-        return true // relu->sum->relu
-                && jcp.with_eltwise == false
-                && (is_eltwise(0) && is_sum(1) && is_eltwise(2));
-    default:
-        return false;
+    case 0: return true; // no post_ops
+    case 1: return is_relu(0) || is_sum(0); // relu or sum
+    case 2: return (is_sum(0) && is_relu(1)) ||
+                       (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
+    case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
+    default: return false;
     }
 
     return false;
@@ -659,8 +558,7 @@ bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
 status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
-        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
-        bool with_relu, float relu_negative_slope) {
+        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
     status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
 
     if (st != status::success)
@@ -672,18 +570,14 @@ status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
 
     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
-    jcp.with_eltwise = with_relu;
-    jcp.eltwise_alpha = relu_negative_slope;
 
     if (!post_ops_ok(jcp, attr))
         return status::unimplemented;
 
     const auto &p = attr.post_ops_;
-    if (!jcp.with_eltwise) {
-        /* PostOps ReLU before SUM is handled the same as ReLU primitive */
-        jcp.with_eltwise = p.find(primitive_kind::eltwise, 0, 1) != -1;
-        jcp.eltwise_alpha = 0.f;
-    }
+    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
+    jcp.with_eltwise = eltwise_ind != -1;
+    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
     jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
 
     status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
@@ -1014,7 +908,7 @@ bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block,
 }
 } // namespace
 
-bool set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
+status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
 {
     /*************** Choose dimN_reg_block (ic_simd_block)
      * *******************************/
@@ -1113,245 +1007,7 @@ bool set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
     jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
 
     jcp.sched_policy = WSCHED_WEI_S_D_G_W;
-    return true;
-}
-
-namespace {
-bool is_in_L1_range(int v, float C1, float C2)
-{
-    return ((v > C1 * L1_cache_size) && (v < C2 * L1_cache_size));
-}
-
-bool is_in_L2_range(int v, float C1, float C2)
-{
-    return ((v > C1 * L2_cache_size) && (v < C2 * L2_cache_size));
-}
-
-void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp, int tile_block_ur,
-        int tile_block, int nb_ic, int nb_oc)
-{
-    jcp.tile_block_ur = tile_block_ur;
-    jcp.tile_block = tile_block;
-    jcp.nb_ic = nb_ic;
-    jcp.nb_oc = nb_oc;
-
-    jcp.nb_tile_block_ur = jcp.ntiles / jcp.tile_block / jcp.tile_block_ur;
-    jcp.ic_block = jcp.ic / jcp.ic_simd_block / jcp.nb_ic;
-    jcp.oc_block = jcp.oc / jcp.oc_simd_block / jcp.nb_oc;
-
-    jcp.dimK_reg_block = jcp.tile_block_ur;
-    jcp.dimK_block = jcp.nb_tile_block_ur;
-    jcp.dimK_nb_block = jcp.tile_block;
-    jcp.dimN_reg_block = jcp.ic_simd_block;
-    jcp.dimN_block = jcp.ic_block;
-    jcp.dimN_nb_block = jcp.nb_ic;
-    jcp.dimM_simd_block = jcp.oc_simd_block;
-    jcp.dimM_block = jcp.oc_block;
-    jcp.dimM_nb_block = jcp.nb_oc;
-}
-}
-
-bool set_wsched_WEI_SDGt_W_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
-    jcp.ic_simd_block = jcp.oc_simd_block = 16;
-    int nb_ic_simd_block = jcp.ic / jcp.ic_simd_block;
-    int nb_oc_simd_block = jcp.oc / jcp.oc_simd_block;
-
-    int min_tile_block_ur = 8;
-    int max_tile_block_ur = 64;
-    int max_tile_block = jcp.ntiles / min_tile_block_ur;
-
-    // Consider L2 + L3 together on SKX
-    const float C1_min = .1, C1_0 = .4, C1_max = .5;
-    const float C2_0 = .4, C2_max = .5;
-    const float TC2_0 = .7, TC2_max = 1.2;
-    const int T_min = 2, T0 = 20;
-    float C1, C2, TC2;
-    int T, tile_block, tile_block_ur, nb_oc, nb_ic;
-
-    auto blocking_ok = [&]() -> bool {
-        // V:tile_block + M:tile_block + U
-        int thread_size = alpha * alpha * jcp.oc
-                        * (jcp.ntiles / tile_block) * sizeof(float)
-                + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
-                        * sizeof(float)
-                + alpha * alpha * jcp.ic * jcp.oc * sizeof(float);
-        // V:tile_block + M:tile_block
-        int L2_reuse = alpha * alpha * jcp.oc
-                        * (jcp.ntiles / tile_block) * sizeof(float)
-                + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
-                        * sizeof(float);
-        // V:nb_ic + M:nb_tile_block_ur
-        // Use M:nb_oc + V:nb_ic as an superset estimation
-        int L1_reuse
-                = (jcp.ic / nb_ic) * (jcp.ntiles / tile_block) * sizeof(float)
-                + (jcp.oc / nb_oc) * (jcp.ntiles / tile_block) * sizeof(float);
-
-        return jcp.ntiles % tile_block == 0
-                && (jcp.ntiles / tile_block) % tile_block_ur == 0
-                && is_in_L2_range(thread_size, TC2, TC2_max)
-                && is_in_L2_range(L2_reuse, C2, C2_max)
-                && tile_block > T * mkldnn_get_max_threads()
-                && nb_oc_simd_block % nb_oc == 0
-                && nb_ic_simd_block % nb_ic == 0
-                && is_in_L1_range(L1_reuse, C1, C1_max);
-    };
-
-    for (C1 = C1_0, C2 = C2_0, TC2 = TC2_0; C1 > C1_min;
-            C1 -= .02, C2 -= .02, TC2 -= .04) {
-        for (T = T0; T >= T_min; --T) {
-            for (tile_block = 1; tile_block <= max_tile_block; ++tile_block) {
-                for (tile_block_ur = max_tile_block_ur;
-                        tile_block_ur >= min_tile_block_ur; --tile_block_ur) {
-                    for (nb_oc = 1; nb_oc <= nb_oc_simd_block; ++nb_oc) {
-                        for (nb_ic = nb_ic_simd_block; nb_ic >= 1; --nb_ic) {
-                            if (blocking_ok()) {
-                                set_jcp_WEI_params(jcp, tile_block_ur,
-                                        tile_block, nb_ic, nb_oc);
-                                jcp.sched_policy = WSCHED_WEI_SDGt_W;
-                                return true;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    return false;
-}
-
-bool set_wsched_WEI_SDGtWo_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
-    jcp.ic_simd_block = jcp.oc_simd_block = 16;
-    int nb_ic_simd_block = jcp.ic / jcp.ic_simd_block;
-    int nb_oc_simd_block = jcp.oc / jcp.oc_simd_block;
-
-    int min_tile_block_ur = 12;
-    int max_tile_block_ur = 64;
-    int max_tile_block = jcp.ntiles / min_tile_block_ur;
-
-    const float C1_min = .1, C1_0 = .4, C1_max = .5;
-    const float C2_0 = .4, C2_max = .6;
-    const float TC2_0 = .7, TC2_max = 1.6;
-
-    const int max_nb_oc = 2; // Limit the # of sequential execution
-    const int T0 = 12, T_min = 8;
-    float C1, C2, TC2;
-    int T, tile_block, tile_block_ur, nb_oc, nb_ic;
-
-    auto blocking_ok = [&]() -> bool {
-        // M:tile_block:nb_oc + V:tile_block + U:nb_oc
-        int thread_size = alpha * alpha * (jcp.oc / nb_oc)
-                        * (jcp.ntiles / tile_block) * sizeof(float)
-                + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
-                        * sizeof(float)
-                + alpha * alpha * jcp.ic * (jcp.oc / nb_oc)
-                        * sizeof(float);
-        // M:tile_block:nb_oc + V:tile_block
-        int L2_reuse = alpha * alpha * (jcp.oc / nb_oc)
-                        * (jcp.ntiles / tile_block) * sizeof(float)
-                + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
-                        * sizeof(float);
-        // V:nb_ic + M:nb_tile_block_ur
-        // Use M:nb_oc + V:nb_ic as an superset estimation
-        int L1_reuse
-                = (jcp.ic / nb_ic) * (jcp.ntiles / tile_block) * sizeof(float)
-                + (jcp.oc / nb_oc) * (jcp.ntiles / tile_block) * sizeof(float);
-
-        return jcp.ntiles % tile_block == 0
-                && (jcp.ntiles / tile_block) % tile_block_ur == 0
-                && is_in_L2_range(thread_size, TC2, TC2_max)
-                && is_in_L2_range(L2_reuse, C2, C2_max)
-                && tile_block > T * mkldnn_get_max_threads()
-                && nb_oc_simd_block % nb_oc == 0
-                && nb_ic_simd_block % nb_ic == 0
-                && is_in_L1_range(L1_reuse, C1, C1_max);
-    };
-
-    for (T = T0; T >= T_min; --T) {
-        for (C1 = C1_0, C2 = C2_0, TC2 = TC2_0; C1 > C1_min;
-                C1 -= .02, C2 -= .02, TC2 -= .04) {
-            for (nb_oc = 1; nb_oc <= max_nb_oc; ++nb_oc) {
-                for (tile_block = max_tile_block; tile_block >= 1;
-                        --tile_block) {
-                    for (tile_block_ur = min_tile_block_ur;
-                            tile_block_ur <= max_tile_block_ur;
-                            ++tile_block_ur) {
-                        for (nb_ic = 1; nb_ic <= nb_ic_simd_block; ++nb_ic) {
-                            if (blocking_ok()) {
-                                set_jcp_WEI_params(jcp, tile_block_ur,
-                                        tile_block, nb_ic, nb_oc);
-                                jcp.sched_policy = WSCHED_WEI_SDGtWo;
-                                return true;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    return false;
-}
-
-bool set_wsched_WEI_S_D_Giot_W_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
-    jcp.ic_simd_block = jcp.oc_simd_block = 16;
-    int nb_ic_simd_block = jcp.ic / jcp.ic_simd_block;
-
-    int min_tile_block_ur = 8;
-    int max_tile_block_ur = 64;
-    const float C1_min = .2, C1_0 = .4, C1_max = .9;
-    const float C2_min = .1, C2_0 = .4, C2_max = .5;
-    const int T0 = 16, T_min = 12;
-    float C1, C2;
-    int T, tile_block, tile_block_ur, nb_ic;
-    int nb_oc = 1; // Keep nb_oc small to increase
-                   // oc_block, for better reuse of V in
-                   // L2
-
-    auto blocking_ok = [&]() -> bool {
-        // V[:ic_block][][][]
-        int L2_reuse
-                = (jcp.ic / nb_ic) * (jcp.ntiles / tile_block) * sizeof(float);
-        // M[:nb_tile_block_ur][][] + V[:nb_tile_block_ur][][]
-        int L1_reuse
-                = (jcp.ntiles / tile_block) * jcp.oc_simd_block * sizeof(float);
-
-        int work_amount = tile_block * nb_ic * nb_oc * alpha * alpha;
-
-        return (jcp.ntiles / tile_block_ur) % tile_block == 0
-                && jcp.ntiles % tile_block_ur == 0
-                && nb_ic_simd_block % nb_ic == 0
-                && is_in_L2_range(L2_reuse, C2, C2_max)
-                && is_in_L1_range(L1_reuse, C1, C1_max)
-                && work_amount > T * mkldnn_get_max_threads();
-    };
-
-    for (T = T0; T >= T_min; --T) {
-        for (C1 = C1_0; C1 > C1_min; C1 -= .02) {
-            for (C2 = C2_0; C2 > C2_min; C2 -= .02) {
-                for (nb_ic = 1; nb_ic <= nb_ic_simd_block; ++nb_ic) {
-                    for (tile_block_ur = min_tile_block_ur;
-                            tile_block_ur <= max_tile_block_ur;
-                            ++tile_block_ur) {
-                        for (tile_block = 1;
-                                tile_block <= jcp.ntiles / min_tile_block_ur;
-                                ++tile_block) {
-                            if (blocking_ok()) {
-                                set_jcp_WEI_params(jcp, tile_block_ur,
-                                        tile_block, nb_ic, nb_oc);
-                                jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
-                                return true;
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-    return false;
+    return status::success;
 }
 
 status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
@@ -1359,8 +1015,7 @@ status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
         const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
         const memory_desc_wrapper &diff_weights_d)
 {
-    if (!mayiuse(avx512_common))
-        return status::unimplemented;
+    jcp.nthr = mkldnn_get_max_threads();
 
     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
 
@@ -1397,15 +1052,18 @@ status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
         jcp.ic = rnd_up(jcp.ic, simd_w);
     }
 
+    if (mayiuse(avx512_core))
+        return status::unimplemented;
     if (!mayiuse(avx512_common))
         return status::unimplemented;
-    else if (mayiuse(avx512_core))
-        jcp.ver = ver_avx512_core;
     else if (mayiuse(avx512_mic_4ops))
         jcp.ver = ver_4fma;
     else
         jcp.ver = ver_fma;
 
+    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+                is_winograd_faster_than_direct(jcp)))
+        return status::unimplemented;
     // Winograd specific initialization
     jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
     jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
@@ -1474,16 +1132,9 @@ status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
         jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
     jcp.nb_reg = 32 - jcp.zmm_start;
 
-    status_t res;
     jcp.sched_policy = WSCHED_INVALID;
-    if ((jcp.ver == ver_avx512_core &&
-            (set_wsched_WEI_SDGt_W_avx512_common(jcp)
-            || set_wsched_WEI_SDGtWo_avx512_common(jcp)
-            || set_wsched_WEI_S_D_Giot_W_avx512_common(jcp)))
-        || set_wsched_WEI_S_D_G_W_avx512_common(jcp))
-        res = status::success;
-    else
-        return status::unimplemented;
+    status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
+    assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);
 
     jcp.tile_block_ur = jcp.dimK_reg_block;
     jcp.nb_tile_block_ur = jcp.dimK_block;