namespace cpu {
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
if (position == 0) {
/* relu before sum */
return false
- || jcp.with_relu
|| p.contain(eltwise, 0);
} else if (position == 1) {
/* relu after sum */
cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope,
memory_desc_t& expect_wei_md);
Zmm vreg_out(int n, int m) {
using namespace primitive_kind;
const auto &p = attr.post_ops_;
- auto is_relu = [&](int idx) {
- return p.entry_[idx].kind == eltwise
- && p.entry_[idx].eltwise.scale == 1.
- && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
- && p.entry_[idx].eltwise.alpha == 0.;
- };
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
switch (p.len_) {
case 0: return true;
- case 1: return true
- && IMPLICATION(jcp.with_relu, p.contain(sum, 0))
- && IMPLICATION(!jcp.with_relu, is_relu(0) || p.contain(sum, 0));
- case 2: return true
- && IMPLICATION(jcp.with_relu, p.contain(sum, 0) && is_relu(1))
- && IMPLICATION(!jcp.with_relu, false
- || (p.contain(sum, 0) && is_relu(1))
- || (p.contain(sum, 1) && is_relu(0)));
- case 3: return true
- && jcp.with_relu == false
- && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
+ case 1: return is_relu(0) || p.contain(sum, 0);
+ case 2: return (p.contain(sum, 0) && is_relu(1)) ||
+ (p.contain(sum, 1) && is_relu(0));
+ case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
default: return false;
}
postamble();
}
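/* With jcp.with_relu gone, the switch above accepts exactly these chains:
 * nothing; a lone relu or sum; sum and relu in either order; or relu, sum,
 * relu. A self-contained sketch of the same acceptance logic, with post-ops
 * reduced to a plain enum (an illustration, not the library's post_ops_t):
 */
#include <vector>

enum class po { sum, relu };

bool chain_ok(const std::vector<po> &p) {
    switch (p.size()) {
    case 0: return true; /* no post-ops */
    case 1: return p[0] == po::relu || p[0] == po::sum;
    case 2: return (p[0] == po::sum && p[1] == po::relu)
            || (p[0] == po::relu && p[1] == po::sum);
    case 3: return p[0] == po::relu && p[1] == po::sum && p[2] == po::relu;
    default: return false;
    }
}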
+namespace {
+bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
+ return jcp.mb >= 4;
+}
+}
+
status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &wei_pd,
cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
- const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
- memory_desc_t &expect_wei_md) {
+ const primitive_attr_t &attr, memory_desc_t &expect_wei_md) {
const memory_desc_wrapper src_d(&src_pd);
const memory_desc_wrapper wei_d(&wei_pd);
const memory_desc_wrapper dst_d(&dst_pd);
const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
+ jcp.nthr = mkldnn_get_max_threads();
+
jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
jcp.mb = src_d.dims()[0];
jcp.oc = dst_d.dims()[1] / jcp.ngroups;
int simdw = 16;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_relu = with_relu;
- jcp.relu_negative_slope = relu_negative_slope;
- if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
- return status::unimplemented;
+
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
if (!(mayiuse(avx512_core)))
return status::unimplemented;
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
if (src_d.data_type() != data_type::f32)
return status::unimplemented;
if (wei_d.data_type() != data_type::f32)
return status::unimplemented;
auto wei_sz = (float)aa * ic * oc;
auto inp_sz = (float)mb * ih * iw * ic;
auto sp_sz = (float)mb * ih * iw;
- const int nthr = mkldnn_get_max_threads();
/* Heuristics here. The numbers '28' and '196' are observations from data. */
if (wei_sz / inp_sz > 5)
jcp.small_mb = true;
else
jcp.small_mb = false;
- if (mb > nstl::min(nthr, 28)
+ if (mb > nstl::min(jcp.nthr, 28)
|| (!jcp.small_mb
&& (wei_sz >= 0.9f * L2_cap
- || inp_sz > L2_cap * nthr + L3_capacity))
+ || inp_sz > L2_cap * jcp.nthr + L3_capacity))
|| (jcp.small_mb && sp_sz > 196))
return status::unimplemented;
/* outer parallelization */
int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
- thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
+ thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
mem_eff = 1.f;
req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
/* inner parallelization */
int bsz = iy * ix / a;
int gemmw = aa * (nb_oc / n2_b);
- int bsz_r = rnd_up(bsz, nthr);
- int gemmw_r = rnd_up(gemmw, nthr);
+ int bsz_r = rnd_up(bsz, jcp.nthr);
+ int gemmw_r = rnd_up(gemmw, jcp.nthr);
thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);
req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
mem_eff = nstl::min(1.f, L2_cap / req_mem);
- int M_per_thr = nstl::max(2, div_up(aa, nthr));
- int oc_per_thr = nstl::min(oc, div_up(aa * (nb_oc / n2_b), nthr));
+ int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
+ int oc_per_thr =
+ nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
if (req_mem > L2_cap)
mem_eff = 0.1f;
}
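/* A worked example of the outer-parallelization efficiency formula above,
 * with made-up tile counts; div_up/rnd_up are re-declared locally so the
 * snippet stands alone.
 */
#include <cstdio>

static int div_up(int a, int b) { return (a + b - 1) / b; }
static int rnd_up(int a, int b) { return div_up(a, b) * b; }

int main() {
    int nblocks = 2 * 4 * 3; /* mb * div_up(ih, iy) * div_up(iw, ix) = 24 */
    int nthr = 28;           /* threads on a 28-core part */
    /* 24 tiles on 28 threads leave 4 threads idle for the whole pass: */
    float thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
    printf("thr_eff = %.3f\n", thr_eff); /* 24 / 28 ~ 0.857 */
    return 0;
}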
////////////////////////////////////////////////////////////////////////////////
-template <bool with_relu>
-status_t _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
+status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t
::pd_t::jit_conf(memory_desc_t& expect_wei_md) {
return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
- jcp_, this->cdesc_(), this->src_pd_, this->weights_pd_,
- this->dst_pd_,this->bias_pd_, *this->attr(),
- with_relu, this->negative_slope(), expect_wei_md);
+ jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
+ this->dst_pd_, this->bias_pd_, *this->attr(), expect_wei_md);
}
-template <bool with_relu>
-_jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>::
- _jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *pd,
+jit_avx512_core_fp32_wino_conv_2x3_fwd_t::
+ jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd), padded_bias_(nullptr) {
- const int nthreads = mkldnn_get_max_threads();
+ : cpu_primitive_t(apd, inputs, outputs)
+{
kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
- conf_.jcp_, *conf_.attr());
+ pd()->jcp_, *pd()->attr());
src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
- conf_.jcp_, *conf_.attr());
+ pd()->jcp_, *pd()->attr());
dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
- conf_.jcp_, *conf_.attr());
-
- int wino_size_offset
- = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2) + (conf_.jcp_.xb);
-
- size_wino_src = (conf_.jcp_.ic * 16) * (wino_size_offset);
- size_wino_dst = (conf_.jcp_.oc * 16) * (wino_size_offset);
-
- wino_src_ = (float *)malloc(sizeof(float) * nthreads * size_wino_src, 4096);
- wino_dst_ = (float *)malloc(sizeof(float) * nthreads * size_wino_dst, 4096);
- if (conf_.want_padded_bias()) {
- const auto &j = conf_.jcp_;
- assert(j.ngroups == 1);
- padded_bias_ = (float *)malloc(sizeof(float) * j.oc, 64);
- for (int oc = j.oc_without_padding; oc < j.oc; ++oc)
- padded_bias_[oc] = 0;
- }
-
-
+ pd()->jcp_, *pd()->attr());
}
-template <bool with_relu>
-_jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
- ::~_jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
+jit_avx512_core_fp32_wino_conv_2x3_fwd_t
+ ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
delete kernel_;
delete src_trans_;
delete dst_trans_;
-
- free(wino_src_);
- free(wino_dst_);
- free(padded_bias_);
}
-template <bool with_relu>
-void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<
- with_relu>::execute_forward() {
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward() const {
const auto &jcp = kernel_->jcp;
if (jcp.small_mb)
execute_forward_small_mb();
else
execute_forward_mbN();
}
-template <bool with_relu>
-void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
-::execute_forward_mbN() {
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN() const {
auto src = reinterpret_cast<const float *>(input_memory(0));
auto wei = reinterpret_cast<const float *>(input_memory(1));
auto bia = reinterpret_cast<const float *>(input_memory(2));
auto dst = reinterpret_cast<float *>(memory(0));
- const auto &jcp = kernel_->jcp;
- const auto &oscales = conf_.attr()->output_scales_;
-
- wino_wei_ = wei;
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bia[oc];
- bia = padded_bias_;
+ auto scratchpad = this->scratchpad();
+ const auto &jcp = kernel_->jcp;
+ const auto &oscales = pd()->attr()->output_scales_;
+
+ const size_t wino_size_offset =
+ (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb);
+ const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
+ const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bia = padded_bias;
}
+ auto ptr_V = scratchpad.get<float>(key_wino_V);
+ auto ptr_M = scratchpad.get<float>(key_wino_M);
+
parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
[&](int mb, int tile_y_b, int tile_x_b) {
int tile_y = tile_y_b * jcp.yb;
int tile_x = tile_x_b * jcp.xb;
int ithr = mkldnn_get_thread_num();
- auto wino_src = wino_src_ + size_wino_src * ithr;
- auto wino_dst = wino_dst_ + size_wino_dst * ithr;
+ auto wino_src = ptr_V + size_wino_src * ithr;
+ auto wino_dst = ptr_M + size_wino_dst * ithr;
auto src_trans_p =
jit_avx512_core_fp32_wino_conv_2x3_src_trans_t
int offset = (tile_ij + ithr) % 16;
gemm_p.src = wino_src + jcp.inp_stride * offset;
gemm_p.dst = wino_dst + jcp.out_stride * offset;
- gemm_p.wei = wino_wei_ + jcp.wei_stride * offset;
+ gemm_p.wei = wei + jcp.wei_stride * offset;
kernel_->ker_(&gemm_p);
}
});
}
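/* The scratchpad.get<float>(...) calls above assume the pd booked matching
 * buffers. A sketch of that booking, following the per-thread size math in
 * execute_forward_mbN(); the helper name and exact sizes here are
 * assumptions, not the committed code.
 */
void init_scratchpad_sketch(memory_tracking::registrar_t &scratchpad,
        const jit_conv_conf_2x3_wino_t &jcp) {
    size_t winoblk = (size_t)(jcp.yb / 2) * (jcp.xb / 2) + jcp.xb;
    scratchpad.book(key_wino_V, /* per-thread transformed sources */
            sizeof(float) * jcp.nthr * winoblk * jcp.ic * 16);
    scratchpad.book(key_wino_M, /* per-thread gemm outputs */
            sizeof(float) * jcp.nthr * winoblk * jcp.oc * 16);
    if (jcp.oc != jcp.oc_without_padding) /* padded-bias path only */
        scratchpad.book(key_conv_padded_bias, sizeof(float) * jcp.oc);
}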
-template <bool with_relu>
-void _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<with_relu>
- ::execute_forward_small_mb() {
+void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb() const
+{
auto src = reinterpret_cast<const float *>(input_memory(0));
auto wei = reinterpret_cast<const float *>(input_memory(1));
auto bia = reinterpret_cast<const float *>(input_memory(2));
auto dst = reinterpret_cast<float *>(memory(0));
- const auto &jcp = kernel_->jcp;
- const auto &oscales = conf_.attr()->output_scales_;
-
- wino_wei_ = wei;
- if (conf_.want_padded_bias()) {
- for (int oc = 0; oc < jcp.oc_without_padding; ++oc)
- padded_bias_[oc] = bia[oc];
- bia = padded_bias_;
+ auto scratchpad = this->scratchpad();
+ const auto &jcp = kernel_->jcp;
+ const auto &oscales = pd()->attr()->output_scales_;
+
+ if (pd()->wants_padded_bias()) {
+ auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
+ utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
+ jcp.oc - jcp.oc_without_padding);
+ bia = padded_bias;
}
+ auto ptr_V = scratchpad.get<float>(key_wino_V);
+ auto ptr_M = scratchpad.get<float>(key_wino_M);
+
for (int mb = 0; mb < jcp.mb; mb++) {
for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
auto local_s = src
+ mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block
+ y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
- auto local_w = wino_src_ + m * jcp.ic;
+ auto local_w = ptr_V + m * jcp.ic;
src_trans_p.src = local_s;
src_trans_p.wino_src = local_w;
auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
call_params_t();
- gemm_p.src = wino_src_ + jcp.inp_stride * tile_ij;
- gemm_p.dst = wino_dst_ + jcp.out_stride * tile_ij
+ gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
+ gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
+ nnb * jcp.n2_block * jcp.n_block;
- gemm_p.wei = wino_wei_ + jcp.wei_stride * tile_ij
+ gemm_p.wei = wei + jcp.wei_stride * tile_ij
+ nnb * jcp.n2_block * jcp.n_block * jcp.K;
kernel_->ker_(&gemm_p);
auto local_d = dst
+ mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block
+ y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
- auto local_w = wino_dst_ + m * jcp.oc;
+ auto local_w = ptr_M + m * jcp.oc;
auto scales = oscales.scales_;
dst_trans_p.dst = local_d;
}}}
}
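/* execute_forward_small_mb() walks tiles serially per image, while
 * execute_forward_mbN() hands the whole (mb, tile_y, tile_x) grid to
 * parallel_nd. A minimal OpenMP stand-in for a 3-dimensional parallel_nd;
 * mkldnn's real implementation balances ranges across threads and supports
 * other runtimes, so this only illustrates the iteration contract.
 */
template <typename F>
void parallel_nd_sketch(int D0, int D1, int D2, F f) {
#pragma omp parallel for collapse(3)
    for (int d0 = 0; d0 < D0; ++d0)
        for (int d1 = 0; d1 < D1; ++d1)
            for (int d2 = 0; d2 < D2; ++d2)
                f(d0, d1, d2);
}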
-template struct _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<true>;
-template struct _jit_avx512_core_fp32_wino_conv_2x3_fwd_t<false>;
-
} // namespace cpu
} // namespace impl
} // namespace mkldnn