/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "memory_tracking.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_engine.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"

#include "jit_avx512_core_u8s8s32x_wino_convolution.hpp"
#include "jit_generator.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
// The adjustment scales below are applied to the source and weights data,
// respectively, because this Winograd implementation transforms the source
// (which may grow values by up to 4x) and the weights (by up to 9/4x).
const float adj_src_scale = 1.f / 4.f;
const float adj_wei_scale = 4.f / 9.f;
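// For reference: this kernel implements F(2x2, 3x3) Winograd convolution,
// whose standard 1D transform matrices are
//     B^T = [ 1  0 -1  0 ]        G = [  1    0    0  ]
//           [ 0  1  1  0 ]            [ 1/2  1/2  1/2 ]
//           [ 0 -1  1  0 ]            [ 1/2 -1/2  1/2 ]
//           [ 0  1  0 -1 ]            [  0    0    1  ]
// The worst-case growth of a transformed value is the maximum absolute row
// sum, applied once per dimension of the 2D transform: 2 * 2 = 4x for the
// source and (3/2) * (3/2) = 9/4x for the weights, which is exactly what
// adj_src_scale and adj_wei_scale compensate.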
// Winograd transforms need ic and oc to be multiples of 16
const int load_block = 16;
/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_u8s8s32x_wino_conv_src_trans_t)

    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;
    struct call_params_t {
        const void *src;
        const void *wino_src;
        const void *v_y_masks;
        const void *v_x_masks;
    };
    void (*ker_)(const call_params_t *);
    jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
        jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
    }

    void generate();
    // NB: the concrete register indices below are reconstructed from their
    // uses in generate()
    int reg_inp_ind(int i) {
        assert(i < jcp.alpha * jcp.alpha);
        return 31 - i;
    }
    Xmm vreg_inp(int i) { return Xmm(reg_inp_ind(i)); }
    Zmm zmm_inp(int i) { return Zmm(reg_inp_ind(i)); }
    Xmm vreg_tmp(int i) {
        assert(i < jcp.alpha * jcp.alpha);
        return Xmm(15 - i);
    }
    Xmm vreg_out(int i) {
        assert(i < jcp.alpha * jcp.alpha);
        return Xmm(31 - i); // reuses input registers, consumed by then
    }
    Opmask y_mask = Opmask(1);
    Opmask r_mask = Opmask(2);
    Opmask x_mask(int id) {
        assert(id < 4);
        return Opmask(3 + id);
    }
    Reg64 reg_ptr_src = r14;
    Reg64 reg_ptr_dst = r13;

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_ic_block = r8;

    int unsign_val_in_wino_domain;

    Reg64 reg_scratch_src_alpha = rdx;
    Xmm xmm_src_alpha = Xmm(0);
    Zmm zmm_src_alpha = Zmm(0);

    Reg64 reg_shift = rax;
    Xmm xmm_shift = Xmm(1);
    Xmm xmm_zero = Xmm(0);

    Reg64 reg_maskx = rbx;
    Reg64 reg_masky = rsi;
    Reg64 reg_nomask = reg_maskx;
};
void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
    Label ic_block_label;
    Label end_label, mask_label, nomask_label;
    auto load_src = [=](bool mask) {
        for (int y = 0; y < jcp.alpha; y++) {
            if (mask)
                kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
            for (int x = 0; x < jcp.alpha; x++) {
                Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
                Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
                int inp_offset = sizeof(uint8_t)
                        * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
                                + (-jcp.l_pad + x) * jcp.ic);
                if (mask) {
                    kandw(r_mask, y_mask, x_mask(x));
                    vmovdqu8(vreg_i | r_mask | T_z,
                            EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
                } else {
                    vmovdqu8(vreg_i,
                            EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
                }
                vpmovzxbd(zmm_i, vreg_i); // to int32
                vcvtdq2ps(zmm_i, zmm_i); // to fp32
                vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
                vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
                vpmovusdb(vreg_i, zmm_i); // to u8
            }
        }
    };
    preamble();

#   define READ_PARAM(reg, field) \
        mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, src);
    READ_PARAM(reg_ptr_dst, wino_src);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
#   undef READ_PARAM
    mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
    mov(reg_masky, ptr[reg_ptr_v_y_masks]);
    test(reg_maskx, reg_maskx);
    jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
    test(reg_masky, reg_masky);
    jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
    and_(reg_maskx, reg_masky);
    mov(reg_nomask, reg_maskx);
    not_(reg_nomask); // zero if x and y masks are all 1's

    xor_(reg_shift, reg_shift);
    mov(reg_shift.cvt8(), (int8_t)-128);
    mov(reg_aux_ptr_src, reg_ptr_src);
    mov(reg_aux_ptr_dst, reg_ptr_dst);

    for (int i = 0; i < jcp.alpha; i++) {
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
    }

    mov(reg_scratch_src_alpha, float2int(adj_src_scale));

    mov(reg_ic_block, jcp.ic / load_block);
    L(ic_block_label);
    {
        vmovq(xmm_src_alpha, reg_scratch_src_alpha);
        vbroadcastss(zmm_src_alpha, xmm_src_alpha);

        test(reg_nomask, reg_nomask);
        jz(nomask_label, T_NEAR);
        load_src(true);
        jmp(mask_label, T_NEAR);
        L(nomask_label);
        load_src(false);
        L(mask_label);
        for (int y = 0; y < 4; y++) {
            vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2));
            vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2));
            vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1));
            vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3));
        }
        for (int x = 0; x < 4; x++) {
            vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2));
            vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2));
            vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1));
            vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3));
        }
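        // The two passes above apply the 1D transform B^T (see the matrices
        // near adj_src_scale) along y and then along x, i.e. V = B^T * d * B
        // for each 4x4 input tile d; the 1/4 pre-scaling of the source keeps
        // the transformed values within byte range.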
        vmovd(xmm_shift, reg_shift.cvt32());
        vpxor(xmm_zero, xmm_zero, xmm_zero);
        vpshufb(xmm_shift, xmm_shift, xmm_zero);
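        // Annotation: the Winograd-domain values are signed, but the GEMM
        // consumes unsigned bytes, so the vpsubb of -128 below adds 128 to
        // every component except index 5 (unsign_val_in_wino_domain), which
        // is a sum of non-negative inputs and already spans the full u8
        // range; the offsets are presumably compensated via the precomputed
        // dst_b values that initialize the GEMM accumulators.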
        for (int i = 0; i < 16; i++) {
            int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
            if (i != unsign_val_in_wino_domain)
                vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
            vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i));
        }

        add(reg_aux_ptr_src, sizeof(uint8_t) * load_block);
        add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
    }
    dec(reg_ic_block);
    jnz(ic_block_label, T_NEAR);

    L(end_label);
    postamble();
}
/// DST TRANSFORMS /////////////////////////////////////////////////////////////
struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t)

    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;
    struct call_params_t {
        const void *wino_dst;
        const void *dst;
        const void *v_y_masks;
        const void *v_x_masks;
        const void *bias;
        const void *scales;
    };
    void (*ker_)(const call_params_t *);
    jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
        jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
    }

    void generate();
    bool maybe_relu(int position);
    Zmm vreg_inp(int i) { // 16
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }
    Zmm vreg_stg(int id) { // 8
        const int id_reg_stg = jcp.alpha * jcp.alpha + id;
        assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
        return Zmm(31 - id_reg_stg);
    }
    Zmm vreg_out(int id) { // 4
        const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
        assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
        return Zmm(31 - id_reg_out);
    }
    Xmm xmm_out(int id) { // 4
        const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
        assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
        return Xmm(31 - id_reg_out);
    }
    Zmm vreg_tmp(int id) { // 2
        const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
        assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
        return Zmm(31 - id_reg_tmp);
    }
    Zmm vreg_zero = Zmm(0);
    Zmm vreg_bias = Zmm(1);
    Zmm vreg_prev_dst = Zmm(2);
    Zmm zmm_bias_alpha = Zmm(2);
    Xmm xmm_bias_alpha = Xmm(2);

    Opmask y_mask = Opmask(1);
    Opmask r_mask = Opmask(2);
    Opmask x_mask(int id) {
        assert(id < 4);
        return Opmask(3 + id);
    }

    Reg64 reg_scratch_bias_alpha = r15;

    Reg64 reg_ptr_src = r14;
    Reg64 reg_ptr_dst = r13;

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_oc_block = r8;

    Reg64 reg_ptr_bias = rbx;
    Reg64 reg_ptr_scales = abi_not_param1;
    Reg64 reg_ptr_sum_scale = rdx;
};
bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
    using namespace primitive_kind;
    const auto &p = attr_.post_ops_;

    if (position == 0) {
        /* relu before sum */
        return false
                || p.contain(eltwise, 0)
                || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
    } else if (position == 1) {
        /* relu after sum */
        const int sum_idx = p.contain(sum, 0)
                ? 0 : (p.contain(sum, 1) ? 1 : -1);
        if (sum_idx == -1) return false;

        return false
                || p.contain(eltwise, sum_idx + 1)
                || jcp.dst_dt == data_type::u8;
    }

    return false;
}
void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
    Label oc_block_label;

    auto loop_body = [=]() {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale = (sum_idx != -1)
                ? &p.entry_[sum_idx].sum.scale
                : nullptr;
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, (size_t)p_sum_scale);

        for (int i = 0; i < 16; i++) {
            int internal_offset = sizeof(int32_t) * jcp.out_stride * i;
            vmovups(vreg_inp(i),
                    EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
        }
        for (int y = 0; y < jcp.alpha; y++) {
            vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1));
            vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2));

            vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2));
            vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3));
        }
        for (int x = 0; x < jcp.m; x++) {
            vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1));
            vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2));

            vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2));
            vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3));
        }
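        // The two passes above apply the 1D inverse transform
        //     A^T = [ 1  1  1  0 ]
        //           [ 0  1 -1 -1 ]
        // along y and then along x, i.e. Y = A^T * M * A for each 4x4 tile M
        // of GEMM output, yielding the 2x2 block of output pixels.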
        if (jcp.with_bias) {
            vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
            vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);

            auto bias_addr = ptr[reg_ptr_bias];
            switch (jcp.bia_dt) {
            case data_type::f32:
            case data_type::s32: vmovups(vreg_bias, bias_addr); break;
            case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break;
            case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break;
            default: assert(!"unsupported bias data type");
            }
            if (jcp.bia_dt != data_type::f32)
                vcvtdq2ps(vreg_bias, vreg_bias);
            vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
        }
        for (int y = 0; y < jcp.m; y++) {
            kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
            for (int x = 0; x < jcp.m; x++) {
                kandw(r_mask, y_mask, x_mask(x));

                int i = y * jcp.m + x;
                int offset = jcp.typesize_out *
                        (y * jcp.ow * jcp.oc + x * jcp.oc);
                Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);

                Zmm zmm = vreg_out(i);
                Xmm xmm = xmm_out(i);

                vcvtdq2ps(zmm, zmm);
                if (jcp.with_bias)
                    vaddps(zmm, zmm, vreg_bias);
                vmulps(zmm, zmm, ptr[reg_ptr_scales]);
                if (maybe_relu(0))
                    vmaxps(zmm, vreg_zero, zmm);
                if (p_sum_scale) { // post_op: sum
                    vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
                    switch (jcp.dst_dt) {
                    case data_type::f32:
                    case data_type::s32:
                        vmovups(vreg_prev_dst | r_mask, addr); break;
                    case data_type::s8:
                        vpmovsxbd(vreg_prev_dst | r_mask, addr); break;
                    case data_type::u8:
                        vpmovzxbd(vreg_prev_dst | r_mask, addr); break;
                    default: assert(!"unknown dst_dt");
                    }
                    if (jcp.dst_dt != data_type::f32)
                        vcvtdq2ps(vreg_prev_dst, vreg_prev_dst);
                    if (*p_sum_scale == 1.f)
                        vaddps(zmm, vreg_prev_dst);
                    else
                        vfmadd231ps(zmm, vreg_prev_dst,
                                zword_b[reg_ptr_sum_scale]);
                }
                if (maybe_relu(1))
                    vmaxps(zmm, vreg_zero, zmm);

                if (jcp.dst_dt != data_type::f32) {
                    if (attr_.round_mode_ == round_mode::nearest)
                        vcvtps2dq(zmm | T_rn_sae, zmm);
                    else if (attr_.round_mode_ == round_mode::down)
                        vcvtps2dq(zmm | T_rd_sae, zmm);
                    else
                        assert(!"unimplemented");
                }
                switch (jcp.dst_dt) {
                case data_type::f32:
                case data_type::s32:
                    vmovups(addr, zmm | r_mask); break;
                case data_type::s8:
                    vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
                case data_type::u8:
                    vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
                default: assert(!"unknown dst_dt");
                }
            }
        }
    };
    preamble();

#   define READ_PARAM(reg, field) \
        mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, wino_dst);
    READ_PARAM(reg_ptr_dst, dst);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
    READ_PARAM(reg_ptr_bias, bias);
    READ_PARAM(reg_ptr_scales, scales);
#   undef READ_PARAM
    mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));

    mov(reg_aux_ptr_src, reg_ptr_src);
    mov(reg_aux_ptr_dst, reg_ptr_dst);

    vpxord(vreg_zero, vreg_zero, vreg_zero);

    for (int i = 0; i < jcp.m; i++)
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);

    int oc_blocks = jcp.oc / load_block;
    mov(reg_oc_block, oc_blocks);
    L(oc_block_label);
    {
        loop_body();

        add(reg_aux_ptr_src, sizeof(int32_t) * load_block);
        add(reg_aux_ptr_dst, jcp.typesize_out * load_block);

        add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
        add(reg_ptr_bias, jcp.typesize_bia * load_block);
    }
    dec(reg_oc_block);
    jnz(oc_block_label, T_NEAR);

    postamble();
}
/// GEMM kernel ////////////////////////////////////////////////////////////////
struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t)
    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;

    struct call_params_t {
        const void *src;
        const void *dst;
        const void *wei;
        const void *dst_b;
    };
    void (*ker_)(const call_params_t *);

    void generate();
    static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
            const primitive_attr_t &attr);

    jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
        jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr)
    {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
    }
    static status_t init_conf(
            jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
            cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
            cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
            const primitive_attr_t &attr);

    Zmm vreg_out(int n, int m) {
        const int id_reg_out = n * jcp.m_block + m;
        assert(id_reg_out < jcp.n2_block * jcp.m_block);
        return Zmm(31 - id_reg_out);
    }
    Zmm vreg_wei(int i) {
        assert(31 - jcp.n2_block * jcp.m_block - i
                > (jcp.ver == ver_vnni ? 0 : 2));
        return Zmm(31 - jcp.n2_block * jcp.m_block - i);
    }
    Zmm vreg_src = Zmm(0);
    Zmm vreg_one = Zmm(1);
    Zmm vreg_tmp = Zmm(2);

    Reg64 reg_ptr_src = r15;

    Reg64 reg_aux_dst_b = r13;
    Reg64 reg_aux_dst = r12;
    Reg64 reg_aux_dst2 = r11;
    Reg64 reg_aux_wei = r10;
    Reg64 reg_aux_wei2 = r9;
    Reg64 reg_aux_src = r8;
    Reg64 reg_aux_src2 = rax;
    Reg64 reg_mb = rbx; // NB: GPR choices for reg_mb/reg_K are reconstructed

    Reg64 reg_nnb = abi_not_param1;
    Reg64 reg_scratch = rdx;
    Reg64 reg_K = rsi;
};
bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
        jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;

    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };

    switch (p.len_) {
    case 0: return true;
    case 1: return is_relu(0) || p.contain(sum, 0);
    case 2: return (p.contain(sum, 0) && is_relu(1)) ||
                   (p.contain(sum, 1) && is_relu(0));
    case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
    default: return false;
    }
}
void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
    Label nnb_loop_label, K_loop_label, mb_loop_label;

    auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(vreg_tmp, vreg_src, vreg_wei);
            vpmaddwd(vreg_tmp, vreg_tmp, vreg_one);
            vpaddd(vreg_acc, vreg_acc, vreg_tmp);
        }
    };
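    // Note: on pre-VNNI AVX512, the u8*s8 dot product is emulated in three
    // steps: vpmaddubsw multiplies byte pairs into saturating s16 sums,
    // vpmaddwd against a vector of ones widens the pairs to s32, and vpaddd
    // accumulates. vpdpbusd fuses all of this and avoids the s16 saturation.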
    preamble();

#   define READ_PARAM(reg, field) \
        mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, src);
    READ_PARAM(reg_aux_dst, dst);
    READ_PARAM(reg_aux_wei, wei);
    READ_PARAM(reg_aux_dst_b, dst_b);
#   undef READ_PARAM

    if (jcp.ver != ver_vnni) {
        xor_(reg_scratch, reg_scratch);
        Reg16 _t = reg_scratch.cvt16();
        mov(_t, 0x1);
        vpbroadcastw(vreg_one, _t);
    }
    mov(reg_nnb, jcp.n_chunks);
    L(nnb_loop_label);
    {
        mov(reg_aux_dst2, reg_aux_dst);
        mov(reg_aux_src, reg_ptr_src);
        mov(reg_mb, jcp.M / jcp.m_block);
        L(mb_loop_label);
        {
            for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                for (int m = 0; m < jcp.m_block; m++) {
                    int offset = jcp.typesize_acc * nb2 * jcp.n_block;
                    vmovups(vreg_out(nb2, m),
                            EVEX_compress_addr(reg_aux_dst_b, offset));
                }
            }
            mov(reg_aux_src2, reg_aux_src);
            mov(reg_aux_wei2, reg_aux_wei);
            mov(reg_K, jcp.k_chunks);
            L(K_loop_label);
            {
                for (int k = 0; k < jcp.k2_block; k += 4) {
                    for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                        int wei_offset
                                = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
                        vmovups(vreg_wei(nb2),
                                EVEX_compress_addr(reg_aux_wei2, wei_offset));
                    }
                    for (int m = 0; m < jcp.m_block; m++) {
                        int inp_offset = jcp.typesize_in * m * jcp.K;
                        vpbroadcastd(vreg_src,
                                EVEX_compress_addr(reg_aux_src2, inp_offset));
                        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
                            compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
                    }
                    add(reg_aux_src2, jcp.typesize_in * 4);
                    add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
                }
            }
            dec(reg_K);
            jnz(K_loop_label, T_NEAR);

            for (int m = 0; m < jcp.m_block; m++) {
                for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                    int offset = jcp.typesize_acc
                            * (m * jcp.N + nb2 * jcp.n_block);
                    vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
                            vreg_out(nb2, m));
                }
            }
            add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
            add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
        }
        dec(reg_mb);
        jnz(mb_loop_label, T_NEAR);

        add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
        add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
        add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);
    }
    dec(reg_nnb);
    jnz(nnb_loop_label, T_NEAR);

    postamble();
}
bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
    if (jcp.ver == ver_vnni) {
        // empirical shape cut-offs (the middle conditions are reconstructed)
        return (jcp.mb <= mkldnn_get_max_threads()
                && (jcp.mb > 4
                    && jcp.ic > 64
                    && !(jcp.oc > 128 && jcp.ih < 14)))
                || jcp.mb > mkldnn_get_max_threads();
    }
    return true;
}
status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
::init_conf(jit_conv_conf_2x3_wino_t &jcp,
        const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
        cpu_memory_t::pd_t &wei_pd, cpu_memory_t::pd_t &dst_pd,
        cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper wei_d(&wei_pd);
    const memory_desc_wrapper dst_d(&dst_pd);
    const memory_desc_wrapper bias_d(&bias_pd);

    const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;

    jcp.nthr = mkldnn_get_max_threads();
    jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = wei_d.dims()[with_groups + 2];
    jcp.kw = wei_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.b_pad = cd.padding[1][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.r_pad = cd.padding[1][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];
    jcp.ver = ver_avx512_core;
    if (!(mayiuse(avx512_core) &&
            src_d.data_type() == data_type::u8
            && wei_d.data_type() == data_type::s8
            && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
                    data_type::s8, data_type::u8)))
        return status::unimplemented;
    if (mayiuse(avx512_core_vnni))
        jcp.ver = ver_vnni;

    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
            is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;
    // block sizes needed for GEMM kernel
    jcp.ic_block = 4;
    jcp.oc_block = 16;

    bool ok = true
            && jcp.ngroups == 1
            && jcp.oc % load_block == 0 && jcp.ic % load_block == 0
            && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
            && everyone_is(3, jcp.kh, jcp.kw)
            && everyone_is(1, jcp.stride_h, jcp.stride_w)
            && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
            && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
            && one_of(jcp.t_pad, 0, 1)
            && one_of(jcp.l_pad, 0, 1);
    if (!ok) return status::unimplemented;
    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_acc = sizeof(int32_t);
    jcp.typesize_bia = jcp.with_bias
            ? types::data_type_size(bias_d.data_type())
            : 0;

    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    jcp.m = 2;
    jcp.r = 3;
    jcp.alpha = jcp.m + jcp.r - 1;

    int aa = jcp.alpha * jcp.alpha;
    int L1_cap = get_cache_size(1, true);
    int L2_cap = get_cache_size(2, true);
    // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
    int free_regs = jcp.ver == ver_vnni ? 31 : 29;
    auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
        float thr_eff;
        float Z = (float)jcp.ic + jcp.oc;
        float Y = (float)jcp.ic * jcp.oc;
        if (small_mb == 0) { // outer par
            int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
            thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
        } else { // inner par
            int tranw = iy * ix / jcp.alpha;
            int gemmw = aa * (jcp.nb_oc / n2_b);
            int tranw_r = rnd_up(tranw, jcp.nthr);
            int gemmw_r = rnd_up(gemmw, jcp.nthr);
            thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
        }
        return thr_eff;
    };
    auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
        float mem_eff, req_mem;
        int M = ix * iy / jcp.alpha;
        if (small_mb == 0) { // outer parallelization strategy
            // memory for wino transforms (other memory has poor reuse)
            req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
            mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
        } else { // inner parallelization strategy
            // memory used during gemm
            int N = jcp.oc_block * n2_b;
            req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
            mem_eff = nstl::min(1.f, L2_cap / req_mem);
            // memory used during wino transforms
            int M_per_thr = div_up(M, jcp.nthr);
            req_mem = (float)aa * M_per_thr
                    * (jcp.ic + jcp.typesize_acc * jcp.oc);
            if (req_mem > L2_cap)
                mem_eff = 0.1f; // penalty value reconstructed
        }
        return mem_eff;
    };
    auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
            float mem_eff, float reg_eff) {
        // these coefficients are chosen empirically
        float mem_fac = 0.1f, reg_fac = 0.2f;
        // normalized overhead relative to memory and register components
        float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
        // thread and work components affect all others
        tot_eff *= thr_eff * work_eff;
        return tot_eff;
    };
    auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff,
            int &m_block, int &n2_block, float &tot_eff) {
        int M = (ix * iy) / jcp.alpha;
        int max_m_block = nstl::min(M, free_regs);
        int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
        tot_eff = 0.f;
        for (int im = max_m_block; im > 0; im--) {
            if (M % im)
                continue;
            for (int in2 = max_n2_block; in2 > 0; in2--) {
                int used_regs = (im + 1) * in2;
                float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
                float reg_eff = (float)(im * in2) / (im + in2);
                float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
                float cur_tot_eff = get_tot_eff(
                        small_mb, thr_eff, work_eff, mem_eff, reg_eff);
                if (jcp.nb_oc % in2 || used_regs > free_regs
                        || cur_tot_eff <= tot_eff)
                    continue;
                tot_eff = cur_tot_eff;
                m_block = im;
                n2_block = in2;
            }
        }
    };
    /* Selecting xb and yb blocking */
    int min_yb = jcp.m;
    int min_xb = jcp.m;
    int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
    int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
    float best_eff = 0.f;
    for (int ix = min_xb; ix <= max_xb; ix += 2) {
        assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
        for (int iy = max_yb; iy >= min_yb; iy -= 2) {
            assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);

            int m_b[2], n2_b[2];
            bool small_mb;
            float inner_eff, outer_eff, work_eff;

            int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
            work_eff = (float)jcp.oh * jcp.ow / tiled_area;
            if (best_eff > 0.f && work_eff < 4.f / 9.f)
                continue; // no gain from Winograd transformation

            /* outer parallelization */
            find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);

            /* inner parallelization */
            find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);

            small_mb = inner_eff > outer_eff;
            float eff = small_mb ? inner_eff : outer_eff;
            if (eff > best_eff) {
                best_eff = eff;
                jcp.yb = iy;
                jcp.xb = ix;
                jcp.m_block = m_b[small_mb];
                jcp.n2_block = n2_b[small_mb];
                jcp.small_mb = small_mb;
            }
        }
    }

    assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
    assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
    jcp.mb_block = 1;
    if (jcp.small_mb) {
        // For small mb harness, set mb_block as large as possible subject to
        // the constraint that winograd activations fit into available L3 cache
        int L3_cap = get_cache_size(3, true);
        int M = jcp.xb * jcp.yb / 4;
        int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
        int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
        int max_mb_block = nstl::min(
                jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
        for (int i = max_mb_block; i > 1; i--) {
            if (jcp.mb % i == 0) {
                jcp.mb_block = i;
                break;
            }
        }
    }
    jcp.nb_mb = jcp.mb / jcp.mb_block;
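    // each 2x2 block of output pixels maps to one point in the Winograd
    // domain, hence the division by 4 in the GEMM M dimension below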
    jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
    jcp.N = jcp.oc;
    jcp.K = jcp.ic;

    jcp.inp_stride = jcp.M * jcp.ic;
    jcp.out_stride = jcp.M * jcp.oc;
    jcp.wei_stride = jcp.ic * jcp.oc;
    jcp.bia_stride = jcp.oc;

    jcp.n_block = jcp.oc_block;
    jcp.k_block = jcp.ic_block;

    jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;

    // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
    // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
    // a multiple of load_block = 16, we just use that for now.
    jcp.k2_block = load_block;
    jcp.k_chunks = jcp.K / jcp.k2_block;
    const auto &oscales = attr.output_scales_;
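    // mask bit 1 set means one output scale per output channel (dst dim 1)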
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;
    assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
    /* re-create weights primitive descriptor
       and set weights wino_blocking */
    memory_desc_t expect_wei_md = *(wei_pd.desc());

    expect_wei_md.format = mkldnn_wino_fmt;
    expect_wei_md.data_type = data_type::s8;
    mkldnn_wino_desc_t &wd = expect_wei_md.layout_desc.wino_desc;
    wd.wino_format = mkldnn_wino_wei_aaOIoi;
    wd.r = jcp.r;
    wd.alpha = jcp.alpha;
    wd.ic = jcp.ic;
    wd.oc = jcp.oc;
    wd.ic_block = jcp.ic_block;
    wd.oc_block = jcp.oc_block;
    wd.oc2_block = jcp.n2_block;
    wd.ic2_block = 1;
    wd.adj_scale = adj_wei_scale;

    size_t max_size = types::data_type_size(data_type::s8) *
            jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
    max_size += types::data_type_size(data_type::s32) *
            jcp.alpha * jcp.alpha * jcp.oc;
    wd.size = max_size;

    cpu_memory_t::pd_t new_weights_pd(wei_pd.engine(), &expect_wei_md);
    if (wei_pd.desc()->format == any)
        wei_pd = new_weights_pd;
    if (!wei_pd.is_equal(&new_weights_pd))
        return status::unimplemented;
    const int tilesize = jcp.alpha * jcp.alpha;
    const int numtiles = jcp.M;
    const int alltiles = numtiles * tilesize;

    jcp.size_wino_src
            = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
            / jcp.typesize_in;
    jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
    jcp.size_wino_dst = alltiles * jcp.oc;

    return status::success;
}
////////////////////////////////////////////////////////////////////////////////

template <data_type_t dst_data_type>
status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
pd_t::jit_conf() {
    return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
            jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
            this->dst_pd_, this->bias_pd_, *this->attr());
}
template <data_type_t dst_data_type>
void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t::
init_scratchpad() {
    auto scratchpad = this->scratchpad_registry().registrar();

    int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
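    // the mbN (large batch) harness gives each thread private wino buffers;
    // the small_mb harness parallelizes within one tile block, so a single
    // shared buffer suffices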
    scratchpad.book(key_wino_V,
            sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
    scratchpad.book(key_wino_M,
            sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);

    scratchpad.book(key_conv_adjusted_scales,
            sizeof(float) * nstl::max(attr()->output_scales_.count_, 16));
}
template <data_type_t dst_data_type>
jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd,
        const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs, true)
{
    kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
            pd()->jcp_, *pd()->attr());
    src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
            pd()->jcp_, *pd()->attr());
    dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
            pd()->jcp_, *pd()->attr());
}
template <data_type_t dst_data_type>
jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
    delete kernel_;
    delete src_trans_;
    delete dst_trans_;
}
template <data_type_t dst_data_type>
const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
adjust_oscales(const memory_tracking::grantor_t &scratchpad) const {
    const float *oscales = pd()->attr()->output_scales_.scales_;
    auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
    size_t count = pd()->attr()->output_scales_.count_;
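    // undo the adjustment scales that were folded into the int8 source and
    // weights transforms (adj_src_scale, adj_wei_scale), so dequantization
    // ends up using exactly the user-supplied output scales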
    float factor = 1.f / (adj_src_scale * adj_wei_scale);
    if (count == 1)
        utils::array_set(loc_scales, oscales[0] * factor, 16);
    else
        for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor;
    return loc_scales;
}
template <data_type_t dst_data_type>
void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
execute_forward() const {
    const auto &jcp = kernel_->jcp;
    if (jcp.small_mb)
        execute_forward_small_mb();
    else
        execute_forward_mbN();
}
template <data_type_t dst_data_type>
void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
execute_forward_mbN() const {
    auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
    auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
    auto bia = reinterpret_cast<const char *>(input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(memory(0));

    auto scratchpad = this->scratchpad();

    const auto &jcp = kernel_->jcp;
    const float *oscales = adjust_oscales(scratchpad);

    auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
    auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
    auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
    parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
            [&](int mb, int tile_y_b, int tile_x_b) {

        int tile_y = tile_y_b * jcp.yb;
        int tile_x = tile_x_b * jcp.xb;

        int ithr = mkldnn_get_thread_num();
        auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
        auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;

        auto src_trans_p =
                jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
        auto dst_trans_p =
                jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
        auto gemm_p =
                jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t();
        /* transformation of input tensor to winograd domain */
        for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
            for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
                uint16_t v_y_masks[4], v_x_masks[4];

                int y = y_in_block + tile_y;
                int x = x_in_block + tile_x;
                int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);

                int v_ys = nstl::max(0, jcp.t_pad - y);
                int v_ye = nstl::min(jcp.alpha,
                        nstl::max(0, jcp.ih + jcp.t_pad - y));

                int v_xs = nstl::max(0, jcp.l_pad - x);
                int v_xe = nstl::min(jcp.alpha,
                        nstl::max(0, jcp.iw + jcp.l_pad - x));

                for (int i = 0; i < jcp.alpha; i++) {
                    v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
                    v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
                }
                auto local_s = src
                        + mb * jcp.ih * jcp.iw * jcp.ic
                        + y * jcp.iw * jcp.ic + x * jcp.ic;
                auto local_w = wino_src + m * jcp.ic;

                src_trans_p.src = local_s;
                src_trans_p.wino_src = local_w;
                src_trans_p.v_y_masks = v_y_masks;
                src_trans_p.v_x_masks = v_x_masks;

                src_trans_->ker_(&src_trans_p);
            }
        }
        /* gemms */
        for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
            // start threads at different GEMMs to help bring weights into LLC
            int offset = (tile_ij + ithr) % 16;
            gemm_p.src = wino_src + jcp.inp_stride * offset;
            gemm_p.dst = wino_dst + jcp.out_stride * offset;
            gemm_p.wei = wei + jcp.wei_stride * offset;
            gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;

            kernel_->ker_(&gemm_p);
        }
        /* transformation from winograd domain to output tensor */
        for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
            for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
                uint16_t v_y_masks[2], v_x_masks[2];

                int y = y_in_block + tile_y;
                int x = x_in_block + tile_x;
                int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);

                for (int i = 0; i < jcp.m; i++) {
                    v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
                    v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
                }
                auto local_d = dst
                        + mb * jcp.oh * jcp.ow * jcp.oc
                        + y * jcp.ow * jcp.oc + x * jcp.oc;
                auto local_w = wino_dst + m * jcp.oc;

                auto scales = oscales;
                dst_trans_p.dst = local_d;
                dst_trans_p.wino_dst = local_w;
                dst_trans_p.v_y_masks = v_y_masks;
                dst_trans_p.v_x_masks = v_x_masks;

                dst_trans_p.scales = scales;
                dst_trans_p.bias = bia;

                dst_trans_->ker_(&dst_trans_p);
            }
        }
    });
}
template <data_type_t dst_data_type>
void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
execute_forward_small_mb() const {
    auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
    auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
    auto bia = reinterpret_cast<const char *>(input_memory(2));
    auto dst = reinterpret_cast<dst_data_t *>(memory(0));

    auto scratchpad = this->scratchpad();

    const auto &jcp = kernel_->jcp;
    const float *oscales = adjust_oscales(scratchpad);

    auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
    auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
    auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
    for (int mbb = 0; mbb < jcp.nb_mb; mbb++) {
    for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
    for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
        /* transformation of input tensor to winograd domain */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
                [&](int y_in_block_b, int x_in_block_b, int mb) {
            int y_in_block = y_in_block_b * 2;
            int x_in_block = x_in_block_b * 2;

            auto src_trans_p =
                jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();

            uint16_t v_y_masks[4], v_x_masks[4];

            int y = y_in_block + tile_y;
            int x = x_in_block + tile_x;
            int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
                    + (x_in_block / 2);

            int v_ys = nstl::max(0, jcp.t_pad - y);
            int v_ye = nstl::min(
                    jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));

            int v_xs = nstl::max(0, jcp.l_pad - x);
            int v_xe = nstl::min(
                    jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));

            for (int i = 0; i < jcp.alpha; i++) {
                v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
                v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
            }
            auto local_s = src
                    + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic
                    + y * jcp.iw * jcp.ic + x * jcp.ic;
            auto local_w = wino_src + m * jcp.ic;

            src_trans_p.src = local_s;
            src_trans_p.wino_src = local_w;
            src_trans_p.v_y_masks = v_y_masks;
            src_trans_p.v_x_masks = v_x_masks;

            src_trans_->ker_(&src_trans_p);
        });
        /* gemms */
        parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
            auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
                    call_params_t();

            gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
            gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block;
            gemm_p.wei = wei + jcp.wei_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block * jcp.K;
            gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block;

            kernel_->ker_(&gemm_p);
        });
        /* transformation from winograd domain to output tensor */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
                [&](int y_in_block_b, int x_in_block_b, int mb) {
            int y_in_block = y_in_block_b * 2;
            int x_in_block = x_in_block_b * 2;

            auto dst_trans_p =
                jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();

            uint16_t v_y_masks[2], v_x_masks[2];

            int y = y_in_block + tile_y;
            int x = x_in_block + tile_x;
            int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
                    + (x_in_block / 2);

            for (int i = 0; i < jcp.m; i++) {
                v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
                v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
            }
            auto local_d = dst
                    + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc
                    + y * jcp.ow * jcp.oc + x * jcp.oc;
            auto local_w = wino_dst + m * jcp.oc;

            auto scales = oscales;
            dst_trans_p.dst = local_d;
            dst_trans_p.wino_dst = local_w;
            dst_trans_p.v_y_masks = v_y_masks;
            dst_trans_p.v_x_masks = v_x_masks;

            dst_trans_p.scales = scales;
            dst_trans_p.bias = bia;

            dst_trans_->ker_(&dst_trans_p);
        });
    }}}
}
template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>;
template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>;
template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>;
template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>;
} // namespace cpu
} // namespace impl
} // namespace mkldnn