return best_divisor;
}
+namespace {
+/* Heuristic predicate: returns true when the Winograd convolution
+   implementation is expected to outperform the direct one for the problem
+   shape described by jcp. Used to resolve alg_kind::convolution_auto.
+   All thresholds below are empirical (tuned from performance data), not
+   derived analytically.
+   NOTE(review): relies on file-scope constants `alpha` and `tile_size`
+   (presumably the Winograd transform size F(m,r) parameters — confirm
+   against their definitions elsewhere in this file). */
+bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
+ /* Determines if current winograd implementation is faster than direct.
+ Following conditions are empirical and based on performance data */
+ unsigned int ncores_per_socket =
+ cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
+ unsigned int nthreads = mkldnn_get_max_threads();
+
+ if (jcp.prop_kind == prop_kind::forward_inference) {
+ /* Inference: Winograd pays off only from a minimum batch size. */
+ return jcp.mb >= 4;
+ } else if (nthreads > ncores_per_socket) {
+ /* Training on more threads than one socket's physical cores (i.e. a
+ multi-socket or hyper-threaded run): compare the per-thread volume of
+ the src/dst tile transforms against the (replicated) weight-transform
+ volume. Both estimates are in MiB (sizeof(float) / 1024 / 1024); the
+ tile counts round up the spatial dims to whole tiles. */
+ double src_dst_transforms_per_core = alpha * alpha
+ * (jcp.ic + jcp.oc)
+ * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
+ * ((jcp.ow + tile_size - 1) / tile_size)
+ * sizeof(float) / 1024. / 1024. / nthreads;
+ double wei_transform = alpha * alpha
+ * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.;
+
+ if (jcp.prop_kind == prop_kind::backward_weights) {
+ /* backward-by-weights: reject very small per-core transform work, or
+ the mid-size region where the weight transform is also small;
+ otherwise Winograd wins. */
+ if (src_dst_transforms_per_core < 0.3
+ || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
+ return false;
+ else
+ return true;
+ } else {
+ /* forward training / backward-by-data: both the per-core src/dst
+ transform volume and the weight transform must be large enough,
+ otherwise fall back to direct. When both thresholds are met,
+ control falls through to the batch-size check below. */
+ if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
+ return false;
+ }
+ }
+
+ /* Default (single-socket training, or the multi-socket case above that
+ passed its volume thresholds): require batch size > 8. */
+ return jcp.mb > 8;
+}
+}
+
/* assumes 512 bits registers */
/* TODO: add support for strides */
/* TODO: handle the prefetch distance automatically */
vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
}
if (with_relu) {
- Opmask kmask = Opmask(7);
- if (jcp.eltwise_alpha == 0) {
- zmm_relu_ns = zmm_zero;
+ if (jcp.eltwise.alpha == 0) {
+ vmaxps(zmm_O, zmm_O, zmm_zero);
} else {
- mov(imm_addr64, float2int(jcp.eltwise_alpha));
+ Opmask kmask = Opmask(7);
+ mov(imm_addr64, float2int(jcp.eltwise.alpha));
vmovq(xmm_relu_ns, imm_addr64);
vbroadcastss(zmm_relu_ns, xmm_relu_ns);
+ vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
+ vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
}
- vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
- vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
}
}
if (with_sum) {
if (!mayiuse(avx512_core)) {
return status::unimplemented;
}
+
+ jcp.nthr = mkldnn_get_max_threads();
+
jcp.ver = ver_avx512_core;
jcp.prop_kind = cd.prop_kind;
}
// Checking conditions not supported by these kernels
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
if (jcp.ngroups != 1)
return status::unimplemented;
if ((jcp.kh != 3) || (jcp.kw != 3))
jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
const auto &p = attr.post_ops_;
- auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
switch (p.len_) {
- case 0:
- return true; // no post_ops
- case 1:
- return true // relu or sum
- && IMPLICATION(jcp.with_eltwise, is_sum(0))
- && IMPLICATION(!jcp.with_eltwise, is_eltwise(0) || is_sum(0));
- case 2:
- return true // sum->relu or relu->sum
- && IMPLICATION(jcp.with_eltwise, is_sum(0) && is_eltwise(1))
- && IMPLICATION(!jcp.with_eltwise, false
- || (is_sum(0) && is_eltwise(1))
- || (is_eltwise(0) && is_sum(1)));
- case 3:
- return true // relu->sum->relu
- && jcp.with_eltwise == false
- && (is_eltwise(0) && is_sum(1) && is_eltwise(2));
- default:
- return false;
+ case 0: return true; // no post_ops
+ case 1: return is_relu(0) || is_sum(0); // relu or sum
+ case 2: return (is_sum(0) && is_relu(1))
+ || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
+ case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
+ default: return false;
}
return false;
status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
const cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
- const cpu_memory_t::pd_t &dst_pd, const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope) {
+ const cpu_memory_t::pd_t &dst_pd, const primitive_attr_t &attr) {
status_t st = init_conf_common(jcp, cd,
*src_pd.desc(), *weights_pd.desc(), *dst_pd.desc());
jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_eltwise = with_relu;
- jcp.eltwise_alpha = relu_negative_slope;
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
const auto &p = attr.post_ops_;
- if (!jcp.with_eltwise) {
- /* PostOps ReLU before SUM is handled the same as ReLU primitive */
- jcp.with_eltwise = p.find(primitive_kind::eltwise, 0, 1) != -1;
- jcp.eltwise_alpha = 0.f;
- }
+ const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
+ jcp.with_eltwise = eltwise_ind != -1;
+ if (jcp.with_eltwise)
+ jcp.eltwise = p.entry_[eltwise_ind].eltwise;
+
jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;
jcp.dimM_block = M_blk;
jcp.sched_policy = WSCHED_WEI_SDGtWo;
set_jcp_WEI_params(jcp);
+ jcp.nthr = nstl::min(mkldnn_get_max_threads(),
+ jcp.tile_block);
return status::success;
}
}
else
jcp.ver = ver_avx512_core;
+ jcp.nthr = mkldnn_get_max_threads();
+
+ jcp.prop_kind = cd.prop_kind;
const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
jcp.mb = src_d.dims()[0];
jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
// Winograd kernel works only for 3x3 convolution with stride 1
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
if (jcp.ngroups != 1)
return status::unimplemented;
if ((jcp.kh != 3) || (jcp.kw != 3))