return best_divisor;
}
+namespace {
+// Heuristic gate for alg_kind::convolution_auto: Winograd only pays off for
+// a large enough minibatch. The 4fma (KNM) code path needs mb >= 32 to win
+// over direct convolution; the plain fma path needs mb >= 16.
+// NOTE(review): thresholds look empirical — confirm against benchmarks.
+bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
+ if (jcp.ver == ver_4fma)
+ return jcp.mb >= 32;
+ else
+ return jcp.mb >= 16;
+}
+}
+
/* assumes 512 bits registers */
/* TODO: add support for strides */
/* TODO: handle the prefetch distance automatically */
};
// utilities to support kernel parameter selection
-// Returns true when the per-thread working set for a dimN block — the
-// transformed input (V) and output (M) L2 blocks plus the weights (W) —
-// falls strictly inside [C2_min, C2_max] fractions of the L2 cache size.
-// Used as a predicate when searching for an L2-friendly dimN_block.
-bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
- int dimN_block, float C2_min, float C2_max) {
- /* V_L2_block + M_L2_block + W */
- float block_size = (alpha * alpha * (jcp.oc + jcp.ic)
- * dimN_block * jcp.dimN_reg_block
- + jcp.ic * jcp.oc) * (float)sizeof(float);
- float L2_lb = C2_min * L2_cache_size;
- float L2_ub = C2_max * L2_cache_size;
- return (block_size > L2_lb && block_size < L2_ub);
-}
-
-// Returns true when the GEMM micro-kernel working set — the A panel
-// (dimM x dimK), B panel (dimK x dimN_reg) and C tile (dimM x dimN_reg)
-// for the candidate dimK_block/dimM_block — falls strictly inside
-// [C1_min, C1_max] fractions of the L1 cache size.
-bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
- int dimM_block, float C1_min, float C1_max) {
- float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
- * jcp.dimK_reg_block
- + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
- + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
- * (float)sizeof(float);
- float L1_lb = C1_min * L1_cache_size;
- float L1_ub = C1_max * L1_cache_size;
- return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
-}
-
bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
int dimM_block, int dimM_simd_block, float C)
{
auto store_output = [=](bool output_is_aligned) {
for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
Zmm zmm(jcp.zmm_start + tile);
- // In W_SGD, output will be reused.
if (output_is_aligned
&& jcp.dimK_nb_block == 1
- && jcp.sched_policy == WSCHED_DATA_W_S_G_D
&& (jcp.dimN * jcp.dimM * alpha * alpha
* sizeof(float) > 2 * LLC_data_size))
vmovntps(zword[reg_dstC + 64 * tile], zmm);
const memory_desc_wrapper &dst_d)
{
- if (!mayiuse(avx512_common))
+ if (mayiuse(avx512_core))
+ return status::unimplemented;
+ else if (!mayiuse(avx512_common))
return status::unimplemented;
- else if (mayiuse(avx512_core))
- jcp.ver = ver_avx512_core;
else if (mayiuse(avx512_mic_4ops))
jcp.ver = ver_4fma;
else
jcp.ver = ver_fma;
+ jcp.nthr = mkldnn_get_max_threads();
+
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
jcp.ic = rnd_up(jcp.ic, simd_w);
}
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
// Checking conditions not supported by these kernels
if (jcp.ngroups != 1)
return status::unimplemented;
return status::success;
}
-// Attempts the W_SGD (scatter-gather over dimN) data scheduling policy,
-// available only on avx512_core. Picks dimN_reg_block, then an L2-sized
-// dimN_block, then L1-sized dimK/dimM GEMM blocks; on success sets
-// jcp.sched_policy = WSCHED_DATA_W_SGD and returns status::success,
-// otherwise returns status::unimplemented so the caller can fall back.
-status_t set_wsched_DATA_W_SGD_avx512_common(jit_conv_winograd_conf_t &jcp) {
-
- if (jcp.ver != ver_avx512_core)
- return status::unimplemented;
-
- /* ----------- dimN reg block ---------------------*/
- // Smallest divisor of dimN that still fills the register tile.
- auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_reg_block, int current_best) {
- return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
- && (dimN_reg_block <= jcp.nb_reg)
- && (dimN_reg_block < current_best);
- };
-
- jcp.dimN_reg_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);
-
- if (jcp.dimN_reg_block >= jcp.nb_reg) {
- // Fallback: largest divisor that still fits in the register budget.
- auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_reg_block, int current_best) {
- return (dimN_reg_block < jcp.nb_reg)
- && (dimN_reg_block > current_best);
- };
-
- jcp.dimN_reg_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
- }
-
- /*-------------- L2 blocking for dimN block ---------*/
-
- // Largest dimN_block that fits L2 while leaving > 2x threads of work.
- auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
- int dimN_block, int current_best) {
- return check_L2_block_per_thread(jcp, dimN_block, 0.1, 1.3)
- && (dimN_block > current_best)
- && ((jcp.dimN / dimN_block / jcp.dimN_reg_block) > 2 * mkldnn_get_max_threads());
- };
-
- jcp.dimN_block = get_divisor_satisfying_cond(
- jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
-
- // Re-verify the chosen block; if blocking fails, fall through to
- // status::unimplemented so a simpler policy is used instead.
- if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 1.3)
- && jcp.dimN/ jcp.dimN_block/ jcp.dimN_reg_block > 2 * mkldnn_get_max_threads()) {
- jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
-
- /* ------------------- L1 blocking for GEMM --------------*/
- /* -------------------- Choose dimK block ----------------*/
- auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
- int dimK_block, int current_best) {
- return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.6)
- && (dimK_block > current_best);
- };
-
- jcp.dimK_block = get_divisor_satisfying_cond(
- jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
-
- if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 0.6)) {
- jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
-
- /* -------------- Choose dimM block -------------------*/
- auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
- int dimM_block, int current_best) {
- return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block, 0.1, 0.7)
- && (dimM_block > current_best);
- };
-
- jcp.dimM_block = get_divisor_satisfying_cond(
- jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond_dimM_block);
- jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_simd_block;
-
- jcp.sched_policy = WSCHED_DATA_W_SGD;
- return status::success;
- }
-
- }
- return status::unimplemented;
-
-}
-
-
// Default data schedule (W_S_G_D): derives dimN_nb_block from the already
// chosen register/L2 blocking and commits the policy. Always succeeds, so
// it is a safe terminal fallback for init_conf_kernel.
status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
    jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
    jcp.sched_policy = WSCHED_DATA_W_S_G_D;
    return status::success;
-    //return status::unimplemented;
}
status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
jcp.dimM = dimM;
jcp.sched_policy = WSCHED_INVALID;
- if (!(set_wsched_DATA_W_SGD_avx512_common(jcp) == status::success))
- set_wsched_DATA_W_S_G_D_avx512_common(jcp);
+ set_wsched_DATA_W_S_G_D_avx512_common(jcp);
- assert(jcp.sched_policy != WSCHED_INVALID);
+ assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
return status::success;
}
jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
const auto &p = attr.post_ops_;
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
switch (p.len_) {
- case 0:
- return true; // no post_ops
- case 1:
- return true // relu or sum
- && IMPLICATION(jcp.with_eltwise, is_sum(0))
- && IMPLICATION(!jcp.with_eltwise, is_eltwise(0) || is_sum(0));
- case 2:
- return true // sum->relu or relu->sum
- && IMPLICATION(jcp.with_eltwise, is_sum(0) && is_eltwise(1))
- && IMPLICATION(!jcp.with_eltwise, false
- || (is_sum(0) && is_eltwise(1))
- || (is_eltwise(0) && is_sum(1)));
- case 3:
- return true // relu->sum->relu
- && jcp.with_eltwise == false
- && (is_eltwise(0) && is_sum(1) && is_eltwise(2));
- default:
- return false;
+ case 0: return true; // no post_ops
+ case 1: return is_relu(0) || is_sum(0); // relu or sum
+ case 2: return (is_sum(0) && is_relu(1)) ||
+ (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
+ case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
+ default: return false;
}
return false;
status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
- const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope) {
+ const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
if (st != status::success)
jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alpha = relu_negative_slope;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
const auto &p = attr.post_ops_;
- if (!jcp.with_eltwise) {
- /* PostOps ReLU before SUM is handled the same as ReLU primitive */
- jcp.with_eltwise = p.find(primitive_kind::eltwise, 0, 1) != -1;
- jcp.eltwise_alpha = 0.f;
- }
+ const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
}
} // namespace
-bool set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
+status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
{
/*************** Choose dimN_reg_block (ic_simd_block)
* *******************************/
jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
jcp.sched_policy = WSCHED_WEI_S_D_G_W;
- return true;
-}
-
-namespace {
-// True when v lies strictly inside (C1, C2) fractions of the L1 cache size.
-bool is_in_L1_range(int v, float C1, float C2)
-{
- return ((v > C1 * L1_cache_size) && (v < C2 * L1_cache_size));
-}
-
-// True when v lies strictly inside (C1, C2) fractions of the L2 cache size.
-bool is_in_L2_range(int v, float C1, float C2)
-{
- return ((v > C1 * L2_cache_size) && (v < C2 * L2_cache_size));
-}
-
-// Commits a weights-update blocking choice to jcp: stores the tile/channel
-// blocking directly, derives the remaining block counts, and mirrors them
-// into the generic dimK/dimN/dimM GEMM fields (K=tiles, N=ic, M=oc).
-void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp, int tile_block_ur,
- int tile_block, int nb_ic, int nb_oc)
-{
- jcp.tile_block_ur = tile_block_ur;
- jcp.tile_block = tile_block;
- jcp.nb_ic = nb_ic;
- jcp.nb_oc = nb_oc;
-
- // Derived counts; assumes the chosen factors divide evenly — callers
- // verify divisibility in their blocking_ok() predicates.
- jcp.nb_tile_block_ur = jcp.ntiles / jcp.tile_block / jcp.tile_block_ur;
- jcp.ic_block = jcp.ic / jcp.ic_simd_block / jcp.nb_ic;
- jcp.oc_block = jcp.oc / jcp.oc_simd_block / jcp.nb_oc;
-
- jcp.dimK_reg_block = jcp.tile_block_ur;
- jcp.dimK_block = jcp.nb_tile_block_ur;
- jcp.dimK_nb_block = jcp.tile_block;
- jcp.dimN_reg_block = jcp.ic_simd_block;
- jcp.dimN_block = jcp.ic_block;
- jcp.dimN_nb_block = jcp.nb_ic;
- jcp.dimM_simd_block = jcp.oc_simd_block;
- jcp.dimM_block = jcp.oc_block;
- jcp.dimM_nb_block = jcp.nb_oc;
-}
-}
-
-
-// Weights-update schedule WEI_SDGt_W: searches (relaxing cache-occupancy
-// targets C1/C2/TC2 and the per-thread work factor T) for a blocking where
-// the whole per-thread V+M+U working set fits L2 and the inner panels fit
-// L1. Returns true and commits via set_jcp_WEI_params on success.
-bool set_wsched_WEI_SDGt_W_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
- jcp.ic_simd_block = jcp.oc_simd_block = 16;
- int nb_ic_simd_block = jcp.ic / jcp.ic_simd_block;
- int nb_oc_simd_block = jcp.oc / jcp.oc_simd_block;
-
- int min_tile_block_ur = 8;
- int max_tile_block_ur = 64;
- int max_tile_block = jcp.ntiles / min_tile_block_ur;
-
- // Consider L2 + L3 together on SKX
- const float C1_min = .1, C1_0 = .4, C1_max = .5;
- const float C2_0 = .4, C2_max = .5;
- const float TC2_0 = .7, TC2_max = 1.2;
- const int T_min = 2, T0 = 20;
- float C1, C2, TC2;
- int T, tile_block, tile_block_ur, nb_oc, nb_ic;
-
- // Validates one candidate (tile_block, tile_block_ur, nb_oc, nb_ic)
- // against divisibility, cache-range and parallelism constraints.
- auto blocking_ok = [&]() -> bool {
- // V:tile_block + M:tile_block + U
- int thread_size = alpha * alpha * jcp.oc
- * (jcp.ntiles / tile_block) * sizeof(float)
- + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
- * sizeof(float)
- + alpha * alpha * jcp.ic * jcp.oc * sizeof(float);
- // V:tile_block + M:tile_block
- int L2_reuse = alpha * alpha * jcp.oc
- * (jcp.ntiles / tile_block) * sizeof(float)
- + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
- * sizeof(float);
- // V:nb_ic + M:nb_tile_block_ur
- // Use M:nb_oc + V:nb_ic as an superset estimation
- int L1_reuse
- = (jcp.ic / nb_ic) * (jcp.ntiles / tile_block) * sizeof(float)
- + (jcp.oc / nb_oc) * (jcp.ntiles / tile_block) * sizeof(float);
-
- return jcp.ntiles % tile_block == 0
- && (jcp.ntiles / tile_block) % tile_block_ur == 0
- && is_in_L2_range(thread_size, TC2, TC2_max)
- && is_in_L2_range(L2_reuse, C2, C2_max)
- && tile_block > T * mkldnn_get_max_threads()
- && nb_oc_simd_block % nb_oc == 0
- && nb_ic_simd_block % nb_ic == 0
- && is_in_L1_range(L1_reuse, C1, C1_max);
- };
-
- // Outer loops progressively relax the cache targets; inner loops scan
- // candidate blockings, preferring larger ur and fewer ic splits first.
- for (C1 = C1_0, C2 = C2_0, TC2 = TC2_0; C1 > C1_min;
- C1 -= .02, C2 -= .02, TC2 -= .04) {
- for (T = T0; T >= T_min; --T) {
- for (tile_block = 1; tile_block <= max_tile_block; ++tile_block) {
- for (tile_block_ur = max_tile_block_ur;
- tile_block_ur >= min_tile_block_ur; --tile_block_ur) {
- for (nb_oc = 1; nb_oc <= nb_oc_simd_block; ++nb_oc) {
- for (nb_ic = nb_ic_simd_block; nb_ic >= 1; --nb_ic) {
- if (blocking_ok()) {
- set_jcp_WEI_params(jcp, tile_block_ur,
- tile_block, nb_ic, nb_oc);
- jcp.sched_policy = WSCHED_WEI_SDGt_W;
- return true;
- }
- }
- }
- }
- }
- }
- }
-
- return false;
-}
-
-
-// Weights-update schedule WEI_SDGtWo: like WEI_SDGt_W but additionally
-// splits oc (nb_oc, capped at 2) so each sequential pass works on a slice
-// of output channels, shrinking the per-thread M and U footprint. Returns
-// true and commits the blocking via set_jcp_WEI_params on success.
-bool set_wsched_WEI_SDGtWo_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
- jcp.ic_simd_block = jcp.oc_simd_block = 16;
- int nb_ic_simd_block = jcp.ic / jcp.ic_simd_block;
- int nb_oc_simd_block = jcp.oc / jcp.oc_simd_block;
-
- int min_tile_block_ur = 12;
- int max_tile_block_ur = 64;
- int max_tile_block = jcp.ntiles / min_tile_block_ur;
-
- const float C1_min = .1, C1_0 = .4, C1_max = .5;
- const float C2_0 = .4, C2_max = .6;
- const float TC2_0 = .7, TC2_max = 1.6;
-
- const int max_nb_oc = 2; // Limit the # of sequential execution
- const int T0 = 12, T_min = 8;
- float C1, C2, TC2;
- int T, tile_block, tile_block_ur, nb_oc, nb_ic;
-
- // Validates one candidate blocking; footprints are per oc slice
- // (oc / nb_oc), unlike WEI_SDGt_W which uses the full oc.
- auto blocking_ok = [&]() -> bool {
- // M:tile_block:nb_oc + V:tile_block + U:nb_oc
- int thread_size = alpha * alpha * (jcp.oc / nb_oc)
- * (jcp.ntiles / tile_block) * sizeof(float)
- + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
- * sizeof(float)
- + alpha * alpha * jcp.ic * (jcp.oc / nb_oc)
- * sizeof(float);
- // M:tile_block:nb_oc + V:tile_block
- int L2_reuse = alpha * alpha * (jcp.oc / nb_oc)
- * (jcp.ntiles / tile_block) * sizeof(float)
- + alpha * alpha * jcp.ic * (jcp.ntiles / tile_block)
- * sizeof(float);
- // V:nb_ic + M:nb_tile_block_ur
- // Use M:nb_oc + V:nb_ic as an superset estimation
- int L1_reuse
- = (jcp.ic / nb_ic) * (jcp.ntiles / tile_block) * sizeof(float)
- + (jcp.oc / nb_oc) * (jcp.ntiles / tile_block) * sizeof(float);
-
- return jcp.ntiles % tile_block == 0
- && (jcp.ntiles / tile_block) % tile_block_ur == 0
- && is_in_L2_range(thread_size, TC2, TC2_max)
- && is_in_L2_range(L2_reuse, C2, C2_max)
- && tile_block > T * mkldnn_get_max_threads()
- && nb_oc_simd_block % nb_oc == 0
- && nb_ic_simd_block % nb_ic == 0
- && is_in_L1_range(L1_reuse, C1, C1_max);
- };
-
- // Search order differs from WEI_SDGt_W: T relaxes first, then the
- // cache targets, preferring small nb_oc and large tile_block.
- for (T = T0; T >= T_min; --T) {
- for (C1 = C1_0, C2 = C2_0, TC2 = TC2_0; C1 > C1_min;
- C1 -= .02, C2 -= .02, TC2 -= .04) {
- for (nb_oc = 1; nb_oc <= max_nb_oc; ++nb_oc) {
- for (tile_block = max_tile_block; tile_block >= 1;
- --tile_block) {
- for (tile_block_ur = min_tile_block_ur;
- tile_block_ur <= max_tile_block_ur;
- ++tile_block_ur) {
- for (nb_ic = 1; nb_ic <= nb_ic_simd_block; ++nb_ic) {
- if (blocking_ok()) {
- set_jcp_WEI_params(jcp, tile_block_ur,
- tile_block, nb_ic, nb_oc);
- jcp.sched_policy = WSCHED_WEI_SDGtWo;
- return true;
- }
- }
- }
- }
- }
- }
- }
-
- return false;
-}
-
-
-// Weights-update schedule WEI_S_D_Giot_W: parallelizes over
-// tile_block x nb_ic x nb_oc x alpha^2 work items, keeping nb_oc = 1 so
-// oc_block stays large and V is reused from L2. Searches for a blocking
-// whose V slab fits L2 and whose M/V panels fit L1; returns true and
-// commits via set_jcp_WEI_params on success.
-bool set_wsched_WEI_S_D_Giot_W_avx512_common(jit_conv_winograd_conf_t &jcp)
-{
- jcp.ic_simd_block = jcp.oc_simd_block = 16;
- int nb_ic_simd_block = jcp.ic / jcp.ic_simd_block;
-
- int min_tile_block_ur = 8;
- int max_tile_block_ur = 64;
- const float C1_min = .2, C1_0 = .4, C1_max = .9;
- const float C2_min = .1, C2_0 = .4, C2_max = .5;
- const int T0 = 16, T_min = 12;
- float C1, C2;
- int T, tile_block, tile_block_ur, nb_ic;
- int nb_oc = 1; // Keep nb_oc small to increase
- // oc_block, for better reuse of V in
- // L2
-
- // Validates one candidate (tile_block, tile_block_ur, nb_ic) against
- // divisibility, cache-range and total-work constraints.
- auto blocking_ok = [&]() -> bool {
- // V[:ic_block][][][]
- int L2_reuse
- = (jcp.ic / nb_ic) * (jcp.ntiles / tile_block) * sizeof(float);
- // M[:nb_tile_block_ur][][] + V[:nb_tile_block_ur][][]
- int L1_reuse
- = (jcp.ntiles / tile_block) * jcp.oc_simd_block * sizeof(float);
-
- int work_amount = tile_block * nb_ic * nb_oc * alpha * alpha;
-
- return (jcp.ntiles / tile_block_ur) % tile_block == 0
- && jcp.ntiles % tile_block_ur == 0
- && nb_ic_simd_block % nb_ic == 0
- && is_in_L2_range(L2_reuse, C2, C2_max)
- && is_in_L1_range(L1_reuse, C1, C1_max)
- && work_amount > T * mkldnn_get_max_threads();
- };
-
- // Relax T, then C1/C2, scanning candidate blockings innermost.
- for (T = T0; T >= T_min; --T) {
- for (C1 = C1_0; C1 > C1_min; C1 -= .02) {
- for (C2 = C2_0; C2 > C2_min; C2 -= .02) {
- for (nb_ic = 1; nb_ic <= nb_ic_simd_block; ++nb_ic) {
- for (tile_block_ur = min_tile_block_ur;
- tile_block_ur <= max_tile_block_ur;
- ++tile_block_ur) {
- for (tile_block = 1;
- tile_block <= jcp.ntiles / min_tile_block_ur;
- ++tile_block) {
- if (blocking_ok()) {
- set_jcp_WEI_params(jcp, tile_block_ur,
- tile_block, nb_ic, nb_oc);
- jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
- return true;
- }
- }
- }
- }
- }
- }
- }
- return false;
+ return status::success;
}
status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
const memory_desc_wrapper &diff_weights_d)
{
- if (!mayiuse(avx512_common))
- return status::unimplemented;
+ jcp.nthr = mkldnn_get_max_threads();
const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
jcp.ic = rnd_up(jcp.ic, simd_w);
}
+ if (mayiuse(avx512_core))
+ return status::unimplemented;
if (!mayiuse(avx512_common))
return status::unimplemented;
- else if (mayiuse(avx512_core))
- jcp.ver = ver_avx512_core;
else if (mayiuse(avx512_mic_4ops))
jcp.ver = ver_4fma;
else
jcp.ver = ver_fma;
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
// Winograd specific initialization
jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
jcp.nb_reg = 32 - jcp.zmm_start;
- status_t res;
jcp.sched_policy = WSCHED_INVALID;
- if ((jcp.ver == ver_avx512_core &&
- (set_wsched_WEI_SDGt_W_avx512_common(jcp)
- || set_wsched_WEI_SDGtWo_avx512_common(jcp)
- || set_wsched_WEI_S_D_Giot_W_avx512_common(jcp)))
- || set_wsched_WEI_S_D_G_W_avx512_common(jcp))
- res = status::success;
- else
- return status::unimplemented;
+ status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
+ assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);
jcp.tile_block_ur = jcp.dimK_reg_block;
jcp.nb_tile_block_ur = jcp.dimK_block;