#include <assert.h>
#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_engine.hpp"
#include "mkldnn_thread.hpp"
namespace cpu {
using namespace mkldnn::impl::memory_format;
+using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
return Opmask(3 + id);
}
- Reg64 reg_ptr_offset = r15;
Reg64 reg_ptr_src = r14;
Reg64 reg_ptr_dst = r13;
Reg64 reg_scratch_src_alpha = rdx;
Xmm xmm_src_alpha = Xmm(0);
Zmm zmm_src_alpha = Zmm(0);
+
+ Reg64 reg_shift = rax;
+ Xmm xmm_shift = Xmm(1);
+ Xmm xmm_zero = Xmm(0);
+
+ Reg64 reg_maskx = rbx;
+ Reg64 reg_masky = rsi;
+ Reg64 reg_nomask = reg_maskx;
};
void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
Label ic_block_label;
+ Label end_label;
+ Label mask_label;
+ Label nomask_label;
+
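+ /* Load one alpha x alpha tile of u8 inputs (the masked variant zero-fills
+  * padded border pixels), scale by adj_src_scale and saturate back to u8. */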
+ auto load_src = [=](bool mask) {
+ for (int y = 0; y < jcp.alpha; y++) {
+ if (mask)
+ kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
+ for (int x = 0; x < jcp.alpha; x++) {
+ Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
+ Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
+ int inp_offset = sizeof(uint8_t)
+ * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
+ + (-jcp.l_pad + x) * jcp.ic);
+ if (mask) {
+ kandw(r_mask, y_mask, x_mask(x));
+ vmovdqu8(vreg_i | r_mask | T_z,
+ EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ } else {
+ vmovdqu8(vreg_i,
+ EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
+ }
+ vpmovzxbd(zmm_i, vreg_i); // to int32
+ vcvtdq2ps(zmm_i, zmm_i); // to fp32
+ vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
+ vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
+ vpmovusdb(vreg_i, zmm_i); // to u8
+ }
+ }
+ };
- int out_offset = 0, inp_offset = 0;
preamble();
# define READ_PARAM(reg, field) \
READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
# undef READ_PARAM
- xor_(eax, eax);
- mov(ax, (int8_t)-128);
+ mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
+ mov(reg_masky, ptr[reg_ptr_v_y_masks]);
+ test(reg_maskx, reg_maskx);
+ jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
+ test(reg_masky, reg_masky);
+ jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
+ and_(reg_maskx, reg_masky);
+ mov(reg_nomask, reg_maskx);
+ not_(reg_nomask); // zero if x and y masks are all 1's
+
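+ // reg_shift holds the byte value -128, used below to offset every
+ // transformed output byte before it is stored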
+ xor_(reg_shift, reg_shift);
+ mov(reg_shift.cvt8(), (int8_t)-128);
mov(reg_aux_ptr_src, reg_ptr_src);
mov(reg_aux_ptr_dst, reg_ptr_dst);
for (int i = 0; i < jcp.alpha; i++) {
- kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
+ kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
}
mov(reg_scratch_src_alpha, float2int(adj_src_scale));
vmovq(xmm_src_alpha, reg_scratch_src_alpha);
vbroadcastss(zmm_src_alpha, xmm_src_alpha);
- for(int y = 0; y < jcp.alpha; y++) {
- kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
- for(int x = 0; x < jcp.alpha; x++) {
- Zmm zmm_i = zmm_inp(y*jcp.alpha + x);
- Xmm vreg_i = vreg_inp(y*jcp.alpha + x);
- vpxord(vreg_i, vreg_i, vreg_i);
- kandw(r_mask, y_mask, x_mask(x));
- inp_offset = sizeof(uint8_t) *
- ((-jcp.t_pad + y) * jcp.iw * jcp.ic
- + (-jcp.l_pad + x) * jcp.ic);
- vmovdqu8(vreg_i | r_mask, EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
- vpmovzxbd(zmm_i, vreg_i); // to int32
- vcvtdq2ps(zmm_i, zmm_i); // to fp32
- vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
- vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
- vpmovusdb(vreg_i, zmm_i); // to u8
- }
- }
+ test(reg_nomask, reg_nomask);
+ jz(nomask_label, T_NEAR);
+ load_src(true);
+ jmp(mask_label, T_NEAR);
+ L(nomask_label);
+ load_src(false);
+ L(mask_label);
+
for(int y = 0; y < 4; y++) {
vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2));
vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2));
vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3));
}
- movd(Xmm(1), eax);
- pxor(Xmm(0), Xmm(0));
- pshufb(Xmm(1), Xmm(0));
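+ // load the -128 shift and broadcast it into all 16 byte lanes of
+ // xmm_shift (vpshufb with an all-zero index vector)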
+ vmovd(xmm_shift, reg_shift.cvt32());
+ vpxor(xmm_zero, xmm_zero, xmm_zero);
+ vpshufb(xmm_shift, xmm_shift, xmm_zero);
for (int i = 0; i < 16; i++) {
- out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
+ int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
if (i != unsign_val_in_wino_domain)
- vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
+ vpsubb(vreg_out(i), vreg_out(i), xmm_shift);
vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i));
dec(reg_ic_block);
jnz(ic_block_label, T_NEAR);
+ L(end_label);
postamble();
}
if (position == 0) {
/* relu before sum */
return false
- || jcp.with_relu
|| p.contain(eltwise, 0)
|| (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
} else if (position == 1) {
vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
}
for(int y = 0; y < jcp.m; y++) {
- kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]);
+ kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]);
for(int x = 0; x < jcp.m; x++) {
kandw(r_mask, y_mask, x_mask(x));
mov(reg_aux_ptr_dst, reg_ptr_dst);
vpxord(vreg_zero, vreg_zero, vreg_zero);
- for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
- vpxord(vreg_inp(i), vreg_inp(i), vreg_inp(i));
- for (int i = 0; i < jcp.alpha; i++)
- kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
+ for (int i = 0; i < jcp.m; i++)
+ kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
int oc_blocks = jcp.oc / load_block;
mov(reg_oc_block, oc_blocks);
dec(reg_oc_block);
jnz(oc_block_label, T_NEAR);
- sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
- sub(reg_ptr_bias, oc_blocks * sizeof(jcp.typesize_bia) * load_block);
-
postamble();
}
jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
- const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope);
+ const primitive_attr_t &attr);
Zmm vreg_out(int n, int m) {
const int id_reg_out = n * jcp.m_block + m;
using namespace primitive_kind;
const auto &p = attr.post_ops_;
- auto is_relu = [&](int idx) {
- return p.entry_[idx].kind == eltwise
- && p.entry_[idx].eltwise.scale == 1.
- && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
- && p.entry_[idx].eltwise.alpha == 0.;
- };
+ auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
switch (p.len_) {
case 0: return true;
- case 1: return true
- && IMPLICATION(jcp.with_relu, p.contain(sum, 0))
- && IMPLICATION(!jcp.with_relu, is_relu(0) || p.contain(sum, 0));
- case 2: return true
- && IMPLICATION(jcp.with_relu, p.contain(sum, 0) && is_relu(1))
- && IMPLICATION(!jcp.with_relu, false
- || (p.contain(sum, 0) && is_relu(1))
- || (p.contain(sum, 1) && is_relu(0)));
- case 3: return true
- && jcp.with_relu == false
- && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
+ case 1: return is_relu(0) || p.contain(sum, 0);
+ case 2: return (p.contain(sum, 0) && is_relu(1)) ||
+ (p.contain(sum, 1) && is_relu(0));
+ case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
default: return false;
}
postamble();
}
+namespace {
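+/* Heuristic for alg_kind::convolution_auto: on VNNI systems Winograd is
+ * expected to beat direct convolution only for sufficiently large minibatch
+ * and channel counts. */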
+bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
+ if (jcp.ver == ver_vnni) {
+ return (jcp.mb <= mkldnn_get_max_threads()
+ && (jcp.mb > 4
+ && jcp.ic > 64
+ && !(jcp.oc > 128 && jcp.ih < 14)))
+ || jcp.mb > mkldnn_get_max_threads();
+ }
+ return true;
+}
+} // namespace
status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
::init_conf(jit_conv_conf_2x3_wino_t &jcp,
const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
cpu_memory_t::pd_t &wei_pd, cpu_memory_t::pd_t &dst_pd,
- cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr,
- bool with_relu, float relu_negative_slope) {
+ cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
const memory_desc_wrapper src_d(&src_pd);
const memory_desc_wrapper wei_d(&wei_pd);
const memory_desc_wrapper dst_d(&dst_pd);
const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
+ jcp.nthr = mkldnn_get_max_threads();
+
jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
jcp.mb = src_d.dims()[0];
jcp.oc = dst_d.dims()[1] / jcp.ngroups;
if (mayiuse(avx512_core_vnni))
jcp.ver = ver_vnni;
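+ // With convolution_auto, commit to Winograd only when the heuristic above
+ // expects it to outperform the direct implementation.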
+ if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
+ is_winograd_faster_than_direct(jcp)))
+ return status::unimplemented;
+
// block sizes needed for GEMM kernel
jcp.ic_block = 4;
jcp.oc_block = 16;
jcp.src_fmt = src_d.format();
jcp.with_bias = cd.bias_desc.format != memory_format::undef;
- jcp.with_relu = with_relu;
- jcp.relu_negative_slope = relu_negative_slope;
- if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
- return status::unimplemented;
+
if (!post_ops_ok(jcp, attr))
return status::unimplemented;
jcp.alpha = jcp.m + jcp.r - 1;
int aa = jcp.alpha * jcp.alpha;
- int nthr = mkldnn_get_max_threads();
int L1_cap = get_cache_size(1, true);
int L2_cap = get_cache_size(2, true);
// need 1 extra reg for bcast, and 2 tmp regs for non-vnni
float Y = (float)jcp.ic * jcp.oc;
if (small_mb == 0) { // outer par
int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
- thr_eff = (float)nblocks / rnd_up(nblocks, nthr);
+ thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
} else { // inner par
int tranw = iy * ix / jcp.alpha;
int gemmw = aa * (jcp.nb_oc / n2_b);
- int tranw_r = rnd_up(tranw, nthr);
- int gemmw_r = rnd_up(gemmw, nthr);
+ int tranw_r = rnd_up(tranw, jcp.nthr);
+ int gemmw_r = rnd_up(gemmw, jcp.nthr);
thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
}
return thr_eff;
req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
mem_eff = nstl::min(1.f, L2_cap / req_mem);
// memory used during wino transforms
- int M_per_thr = div_up(M, nthr);
+ int M_per_thr = div_up(M, jcp.nthr);
req_mem = (float)aa * M_per_thr
* (jcp.ic + jcp.typesize_acc * jcp.oc);
if (req_mem > L2_cap)
assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
- jcp.inp_stride = jcp.yb * jcp.xb / 4 * jcp.ic;
- jcp.out_stride = jcp.yb * jcp.xb / 4 * jcp.oc;
- jcp.wei_stride = jcp.ic * jcp.oc;
- jcp.bia_stride = jcp.oc;
+ jcp.mb_block = 1;
+ if (jcp.small_mb) {
+ // For the small-mb harness, make mb_block as large as possible subject to
+ // the constraint that the Winograd activations fit in the available L3 cache
+ int L3_cap = get_cache_size(3, true);
+ int M = jcp.xb * jcp.yb / 4;
+ int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
+ int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
+ int max_mb_block = nstl::min(
+ jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
+ for (int i = max_mb_block; i > 1; i--) {
+ if (jcp.mb % i == 0) {
+ jcp.mb_block = i;
+ break;
+ }
+ }
+ }
+ jcp.nb_mb = jcp.mb / jcp.mb_block;
- jcp.M = jcp.xb * jcp.yb / 4;
+ jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
jcp.N = jcp.oc;
jcp.K = jcp.ic;
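+ // element strides between consecutive Winograd components of the
+ // src/dst/weights/bias buffers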
+ jcp.inp_stride = jcp.M * jcp.ic;
+ jcp.out_stride = jcp.M * jcp.oc;
+ jcp.wei_stride = jcp.ic * jcp.oc;
+ jcp.bia_stride = jcp.oc;
+
jcp.n_block = jcp.oc_block;
jcp.k_block = jcp.ic_block;
if (!wei_pd.is_equal(&new_weights_pd))
return status::unimplemented;
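+ // Winograd buffer sizes in elements; the src buffer is rounded up to a
+ // 4K page so per-thread slices of the scratchpad stay page-aligned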
+ const int tilesize = jcp.alpha * jcp.alpha;
+ const int numtiles = jcp.M;
+ const int alltiles = numtiles * tilesize;
+
+ jcp.size_wino_src
+ = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
+ / jcp.typesize_in;
+ jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
+ jcp.size_wino_dst = alltiles * jcp.oc;
+
return status::success;
}
////////////////////////////////////////////////////////////////////////////////
-template <bool with_relu, data_type_t dst_data_type>
-status_t _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
- dst_data_type>::pd_t::jit_conf() {
+template <data_type_t dst_data_type>
+status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+ pd_t::jit_conf() {
return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
- jcp_, this->cdesc_(), this->src_pd_, this->weights_pd_,
- this->dst_pd_,this->bias_pd_, *this->attr(),
- with_relu, this->negative_slope());
+ jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
+ this->dst_pd_,this->bias_pd_, *this->attr());
}
-template <bool with_relu, data_type_t dst_data_type>
-_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu, dst_data_type>::
- _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *pd,
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t::
+init_scratchpad() {
+ auto scratchpad = this->scratchpad_registry().registrar();
+
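+ // the mbN harness gives every thread its own Winograd src/dst buffers,
+ // while the small-mb harness shares a single pair across threads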
+ int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
+ scratchpad.book(key_wino_V,
+ sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
+ scratchpad.book(key_wino_M,
+ sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
+
+ scratchpad.book(key_conv_adjusted_scales,
+ sizeof(float) * nstl::max(attr()->output_scales_.count_, 16));
+}
+
+template <data_type_t dst_data_type>
+jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+ jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd,
const input_vector &inputs, const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs)
- , conf_(*pd)
- , scratchpad_(nullptr) {
- const int nthreads = mkldnn_get_max_threads();
+ : cpu_primitive_t(apd, inputs, outputs, true)
+{
kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
- conf_.jcp_, *conf_.attr());
+ pd()->jcp_, *pd()->attr());
src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
- conf_.jcp_, *conf_.attr());
+ pd()->jcp_, *pd()->attr());
dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
- conf_.jcp_, *conf_.attr());
-
- const int tilesize = conf_.jcp_.alpha * conf_.jcp_.alpha;
- const int numtiles = (conf_.jcp_.yb / 2) * (conf_.jcp_.xb / 2);
- const int alltiles = tilesize * numtiles;
- size_wino_wei_ = tilesize * conf_.jcp_.oc * conf_.jcp_.ic;
- size_wino_src_ = sizeof(src_data_t) * alltiles * conf_.jcp_.ic;
- size_wino_src_ = rnd_up(size_wino_src_, PAGE_4K);
- size_wino_src_ /= sizeof(src_data_t);
- size_wino_dst_ = alltiles * conf_.jcp_.oc;
-
- size_t workspace_size = (conf_.jcp_.small_mb ? 1 : nthreads)
- * (sizeof(src_data_t) * size_wino_src_
- + sizeof(acc_data_t) * size_wino_dst_);
-
- scratchpad_ = create_scratchpad(workspace_size);
- assert(scratchpad_); // TODO: add proper check and raise exception?
-
- wino_shift_ = (conf_.jcp_.small_mb ? 1 : nthreads) * sizeof(src_data_t)
- * size_wino_src_;
-
- updated_output_scales_ = conf_.attr()->output_scales_;
- updated_output_scales_.scale(1.f / (adj_src_scale * adj_wei_scale));
+ pd()->jcp_, *pd()->attr());
}
-template <bool with_relu, data_type_t dst_data_type>
-_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
- dst_data_type>::~_jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
+template <data_type_t dst_data_type>
+jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+ ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
delete kernel_;
delete src_trans_;
delete dst_trans_;
- delete scratchpad_;
}
-template <bool with_relu, data_type_t dst_data_type>
-void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
- dst_data_type>::execute_forward() {
+template <data_type_t dst_data_type>
+const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+adjust_oscales(const memory_tracking::grantor_t &scratchpad) const {
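+ // fold out the adj_src_scale / adj_wei_scale factors applied by the src
+ // and weights transforms so the final output scales remain correct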
+ const float *oscales = pd()->attr()->output_scales_.scales_;
+ auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
+ size_t count = pd()->attr()->output_scales_.count_;
+ float factor = 1.f / (adj_src_scale * adj_wei_scale);
+ if (count == 1)
+ utils::array_set(loc_scales, oscales[0] * factor, 16);
+ else
+ for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor;
+ return loc_scales;
+}
+
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward() const {
const auto &jcp = kernel_->jcp;
    if (jcp.small_mb)
        execute_forward_small_mb();
    else
        execute_forward_mbN();
}
-template <bool with_relu, data_type_t dst_data_type>
-void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
- dst_data_type>::execute_forward_mbN() {
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward_mbN() const {
auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
auto bia = reinterpret_cast<const char *>(input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(memory(0));
+ auto scratchpad = this->scratchpad();
+
const auto &jcp = kernel_->jcp;
- const auto &oscales = updated_output_scales_;
+ const float *oscales = adjust_oscales(scratchpad);
- auto wino_wei = wei;
- auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_);
- auto wino_src_base = (src_data_t *)scratchpad_->get();
- auto wino_dst_base = (acc_data_t *)(scratchpad_->get() + wino_shift_);
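+ // the adjusted bias lives immediately after the transformed weights in
+ // the reordered weights memory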
+ auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
+ auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
+ auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
[&](int mb, int tile_y_b, int tile_x_b) {
int tile_x = tile_x_b * jcp.xb;
int ithr = mkldnn_get_thread_num();
- auto wino_src = wino_src_base + size_wino_src_ * ithr;
- auto wino_dst = wino_dst_base + size_wino_dst_ * ithr;
+ auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
+ auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;
auto src_trans_p =
jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
/* transformation of input tensor to winograd domain */
for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
- unsigned short v_y_masks[4], v_x_masks[4];
+ uint16_t v_y_masks[4], v_x_masks[4];
int y = y_in_block + tile_y;
int x = x_in_block + tile_x;
#pragma unroll(4)
for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
- v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+ v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
+ v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
}
auto local_s = src
+ mb * jcp.ih * jcp.iw * jcp.ic
int offset = (tile_ij + ithr) % 16;
gemm_p.src = wino_src + jcp.inp_stride * offset;
gemm_p.dst = wino_dst + jcp.out_stride * offset;
- gemm_p.wei = wino_wei + jcp.wei_stride * offset;
+ gemm_p.wei = wei + jcp.wei_stride * offset;
gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
kernel_->ker_(&gemm_p);
/* transformation from winograd domain to output tensor */
for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
- unsigned short v_y_masks[2], v_x_masks[2];
+ uint16_t v_y_masks[2], v_x_masks[2];
int y = y_in_block + tile_y;
int x = x_in_block + tile_x;
#pragma unroll(2)
for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
- v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+ v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
+ v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
}
auto local_d = dst
+ mb * jcp.oh * jcp.ow * jcp.oc
+ y * jcp.ow * jcp.oc + x * jcp.oc;
auto local_w = wino_dst + m * jcp.oc;
- auto scales = oscales.scales_;
+ auto scales = oscales;
dst_trans_p.dst = local_d;
dst_trans_p.wino_dst = local_w;
dst_trans_p.v_y_masks = v_y_masks;
});
}
-template <bool with_relu, data_type_t dst_data_type>
-void _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<with_relu,
- dst_data_type>::execute_forward_small_mb() {
+template <data_type_t dst_data_type>
+void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
+execute_forward_small_mb() const {
auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
auto bia = reinterpret_cast<const char *>(input_memory(2));
auto dst = reinterpret_cast<dst_data_t *>(memory(0));
+ auto scratchpad = this->scratchpad();
+
const auto &jcp = kernel_->jcp;
- const auto &oscales = updated_output_scales_;
+ const float *oscales = adjust_oscales(scratchpad);
- auto wino_wei = wei;
- auto dst_bias = (const acc_data_t *)(wei + size_wino_wei_);
- auto wino_src = (src_data_t *)scratchpad_->get();
- auto wino_dst = (acc_data_t *)(scratchpad_->get() + wino_shift_);
+ auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
+ auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
+ auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
- for (int mb = 0; mb < jcp.mb; mb++) {
+ for (int mbb = 0; mbb < jcp.nb_mb; mbb++) {
for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
/* transformation of input tensor to winograd domain */
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
- [&](int y_in_block_b, int x_in_block_b) {
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
+ [&](int y_in_block_b, int x_in_block_b, int mb) {
int y_in_block = y_in_block_b * 2;
int x_in_block = x_in_block_b * 2;
auto src_trans_p =
jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
- unsigned short v_y_masks[4], v_x_masks[4];
+ uint16_t v_y_masks[4], v_x_masks[4];
int y = y_in_block + tile_y;
int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+ int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
+ + (x_in_block / 2);
int v_ys = nstl::max(0, jcp.t_pad - y);
int v_ye = nstl::min(
#pragma unroll(4)
for (int i = 0; i < jcp.alpha; i++) {
- v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
- v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
+ v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
+ v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
}
auto local_s = src
- + mb * jcp.ih * jcp.iw * jcp.ic
+ + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic
+ y * jcp.iw * jcp.ic + x * jcp.ic;
auto local_w = wino_src + m * jcp.ic;
gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
+ nnb * jcp.n2_block * jcp.n_block;
- gemm_p.wei = wino_wei + jcp.wei_stride * tile_ij
+ gemm_p.wei = wei + jcp.wei_stride * tile_ij
+ nnb * jcp.n2_block * jcp.n_block * jcp.K;
gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
+ nnb * jcp.n2_block * jcp.n_block;
});
/* transformation from winograd domain to output tensor */
- parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
- [&](int y_in_block_b, int x_in_block_b) {
+ parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
+ [&](int y_in_block_b, int x_in_block_b, int mb) {
int y_in_block = y_in_block_b * 2;
int x_in_block = x_in_block_b * 2;
auto dst_trans_p =
jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
- unsigned short v_y_masks[2], v_x_masks[2];
+ uint16_t v_y_masks[2], v_x_masks[2];
int y = y_in_block + tile_y;
int x = x_in_block + tile_x;
- int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
+ int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
+ + (x_in_block / 2);
#pragma unroll(2)
for (int i = 0; i < jcp.m; i++) {
- v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
- v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
+ v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
+ v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
}
auto local_d = dst
- + mb * jcp.oh * jcp.ow * jcp.oc
+ + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc
+ y * jcp.ow * jcp.oc + x * jcp.oc;
auto local_w = wino_dst + m * jcp.oc;
- auto scales = oscales.scales_;
+ auto scales = oscales;
dst_trans_p.dst = local_d;
dst_trans_p.wino_dst = local_w;
dst_trans_p.v_y_masks = v_y_masks;
}}}
}
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
- data_type::s8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
- data_type::s8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
- data_type::u8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
- data_type::u8>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
- data_type::s32>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
- data_type::s32>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<true,
- data_type::f32>;
-template struct _jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<false,
- data_type::f32>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>;
+template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>;
} // namespace cpu
} // namespace impl