updated readme file due to moving CMake scripts to the root folder
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_u8s8s32x_wino_convolution.cpp
1 /*******************************************************************************
2  * Copyright 2018 Intel Corporation
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "c_types_map.hpp"
20 #include "memory_tracking.hpp"
21 #include "cpu_convolution_pd.hpp"
22 #include "cpu_engine.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26
27 #include "jit_avx512_core_u8s8s32x_wino_convolution.hpp"
28 #include "jit_generator.hpp"
29
30 #include <string.h>
31
32 namespace mkldnn {
33 namespace impl {
34 namespace cpu {
35
36 using namespace mkldnn::impl::memory_format;
37 using namespace mkldnn::impl::memory_tracking::names;
38 using namespace mkldnn::impl::utils;
39 using namespace Xbyak;
40
41 namespace {
42     // Below scales are applied to source and weights data accordingly
43     // because this winograd implementation
44     // transforms source which may increase values up to 4x
45     // and transforms weights which may increase values up to 9/4x
46     const float adj_src_scale = 1.f / 4.f;
47     const float adj_wei_scale = 4.f / 9.f;
48     // Winograd transforms need ic and oc to be multiples of 16
49     const int load_block = 16;
50 }
51
52 /// SRC TRANSFORMS /////////////////////////////////////////////////////////////
53 struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t: public jit_generator {
54     DECLARE_CPU_JIT_AUX_FUNCTIONS(
55             jit_avx512_core_u8s8s32x_wino_conv_src_trans_t)
56
57     jit_conv_conf_2x3_wino_t jcp;
58     const primitive_attr_t &attr_;
59
60     struct call_params_t {
61         const void *src;
62         const void *wino_src;
63         const void *v_y_masks;
64         const void *v_x_masks;
65     };
66     void (*ker_)(const call_params_t *);
67
68     jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
69         jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
70         : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) {
71         generate();
72         ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
73     }
74     void generate();
75
76     int reg_inp_ind(int i) {
77         assert(i < jcp.alpha * jcp.alpha);
78         return (31 - i);
79     }
80
81     Xmm vreg_inp(int i) {
82         return Xmm(reg_inp_ind(i));
83     }
84
85     Zmm zmm_inp(int i) {
86         return Zmm(reg_inp_ind(i));
87     }
88
89     Xmm vreg_tmp(int i) {
90         assert(i < jcp.alpha * jcp.alpha);
91         return Xmm(15 - i);
92     }
93     Xmm vreg_out(int i) {
94         assert(i < jcp.alpha * jcp.alpha);
95         return Xmm(31 - i);
96     }
97
98     Opmask y_mask = Opmask(1);
99     Opmask r_mask = Opmask(2);
100     Opmask x_mask(int id) {
101         assert(id < 4);
102         return Opmask(3 + id);
103     }
104
105     Reg64 reg_ptr_src = r14;
106     Reg64 reg_ptr_dst = r13;
107
108     Reg64 reg_ptr_v_y_masks = r12;
109     Reg64 reg_ptr_v_x_masks = r11;
110
111     Reg64 reg_aux_ptr_src = r10;
112     Reg64 reg_aux_ptr_dst = r9;
113
114     Reg64 reg_ic_block = r8;
115
116     int unsign_val_in_wino_domain;
117
118     Reg64 reg_scratch_src_alpha = rdx;
119     Xmm xmm_src_alpha = Xmm(0);
120     Zmm zmm_src_alpha = Zmm(0);
121
122     Reg64 reg_shift = rax;
123     Xmm xmm_shift = Xmm(1);
124     Xmm xmm_zero = Xmm(0);
125
126     Reg64 reg_maskx = rbx;
127     Reg64 reg_masky = rsi;
128     Reg64 reg_nomask = reg_maskx;
129 };
130
131 void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
132     Label ic_block_label;
133     Label end_label;
134     Label mask_label;
135     Label nomask_label;
136
137     auto load_src = [=](bool mask) {
138         for (int y = 0; y < jcp.alpha; y++) {
139             if (mask)
140                 kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
141             for (int x = 0; x < jcp.alpha; x++) {
142                 Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
143                 Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
144                 int inp_offset = sizeof(uint8_t)
145                         * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
146                                 + (-jcp.l_pad + x) * jcp.ic);
147                 if (mask) {
148                     kandw(r_mask, y_mask, x_mask(x));
149                     vmovdqu8(vreg_i | r_mask | T_z,
150                             EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
151                 } else {
152                     vmovdqu8(vreg_i,
153                             EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
154                 }
155                 vpmovzxbd(zmm_i, vreg_i); // to int32
156                 vcvtdq2ps(zmm_i, zmm_i); // to fp32
157                 vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
158                 vcvtps2dq(zmm_i | T_rn_sae, zmm_i); // to int32
159                 vpmovusdb(vreg_i, zmm_i); // to u8
160             }
161         }
162     };
163
164     preamble();
165
166 #   define READ_PARAM(reg, field) \
167         mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
168     READ_PARAM(reg_ptr_src, src);
169     READ_PARAM(reg_ptr_dst, wino_src);
170     READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
171     READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
172 #   undef READ_PARAM
173
174     mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
175     mov(reg_masky, ptr[reg_ptr_v_y_masks]);
176     test(reg_maskx, reg_maskx);
177     jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
178     test(reg_masky, reg_masky);
179     jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
180     and_(reg_maskx, reg_masky);
181     mov(reg_nomask, reg_maskx);
182     not_(reg_nomask); // zero if x and y masks are all 1's
183
184     xor_(reg_shift, reg_shift);
185     mov(reg_shift.cvt8(), (int8_t)-128);
186
187     mov(reg_aux_ptr_src, reg_ptr_src);
188     mov(reg_aux_ptr_dst, reg_ptr_dst);
189
190     for (int i = 0; i < jcp.alpha; i++) {
191         kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
192     }
193
194     mov(reg_scratch_src_alpha, float2int(adj_src_scale));
195
196     mov(reg_ic_block, jcp.ic / load_block);
197     L(ic_block_label);
198     {
199         vmovq(xmm_src_alpha, reg_scratch_src_alpha);
200         vbroadcastss(zmm_src_alpha, xmm_src_alpha);
201
202         test(reg_nomask, reg_nomask);
203         jz(nomask_label, T_NEAR);
204         load_src(true);
205         jmp(mask_label, T_NEAR);
206         L(nomask_label);
207         load_src(false);
208         L(mask_label);
209
210         for(int y = 0; y < 4; y++) {
211             vpsubb(vreg_tmp(y*4+0), vreg_inp(y*4+0), vreg_inp(y*4+2));
212             vpaddb(vreg_tmp(y*4+1), vreg_inp(y*4+1), vreg_inp(y*4+2));
213             vpsubb(vreg_tmp(y*4+2), vreg_inp(y*4+2), vreg_inp(y*4+1));
214             vpsubb(vreg_tmp(y*4+3), vreg_inp(y*4+1), vreg_inp(y*4+3));
215         }
216         for(int x = 0;x < 4; x++) {
217             vpsubb(vreg_out(x+0*4), vreg_tmp(x+4*0), vreg_tmp(x+4*2));
218             vpaddb(vreg_out(x+1*4), vreg_tmp(x+4*1), vreg_tmp(x+4*2));
219             vpsubb(vreg_out(x+2*4), vreg_tmp(x+4*2), vreg_tmp(x+4*1));
220             vpsubb(vreg_out(x+3*4), vreg_tmp(x+4*1), vreg_tmp(x+4*3));
221         }
222
223         vmovd(xmm_shift, reg_shift.cvt32());
224         vpxor(xmm_zero, xmm_zero, xmm_zero);
225         vpshufb(xmm_shift, xmm_shift, xmm_zero);
226
227         for (int i = 0; i < 16; i++) {
228             int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
229             if (i != unsign_val_in_wino_domain)
230                 vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
231             vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset), vreg_out(i));
232         }
233
234         add(reg_aux_ptr_src, sizeof(uint8_t) * load_block);
235         add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
236     }
237     dec(reg_ic_block);
238     jnz(ic_block_label, T_NEAR);
239
240     L(end_label);
241     postamble();
242 }
243
244 /// DST TRANSFORMS /////////////////////////////////////////////////////////////
245 struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t: public jit_generator {
246     DECLARE_CPU_JIT_AUX_FUNCTIONS(
247             jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t)
248
249     jit_conv_conf_2x3_wino_t jcp;
250     const primitive_attr_t &attr_;
251
252     struct call_params_t {
253         const void *wino_dst;
254         const void *dst;
255         const void *v_y_masks;
256         const void *v_x_masks;
257
258         const void *bias;
259         const void *scales;
260     };
261     void (*ker_)(const call_params_t *);
262
263     jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
264         jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
265         : jcp(ajcp), attr_(attr) {
266         generate();
267         ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
268     }
269
270     void generate();
271     bool maybe_relu(int position);
272
273     Zmm vreg_inp(int i) { // 16
274         assert(i < jcp.alpha * jcp.alpha);
275         return Zmm(31 - i);
276     }
277     Zmm vreg_stg(int id) { // 8
278         const int id_reg_stg = jcp.alpha * jcp.alpha + id;
279         assert(id < 8);
280         return Zmm(31 - id_reg_stg);
281     }
282     Zmm vreg_out(int id) { // 4
283         const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
284         assert(id < 4);
285         return Zmm(31 - id_reg_out);
286     }
287     Xmm xmm_out(int id) { // 4
288         const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
289         assert(id < 4);
290         return Xmm(31 - id_reg_out);
291     }
292     Zmm vreg_tmp(int id) { // 2
293         const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
294         assert(id < 2);
295         return Zmm(31 - id_reg_tmp);
296     }
297
298     Zmm vreg_zero = Zmm(0);
299     Zmm vreg_bias = Zmm(1);
300     Zmm vreg_prev_dst = Zmm(2);
301     Zmm zmm_bias_alpha = Zmm(2);
302     Xmm xmm_bias_alpha = Xmm(2);
303
304     Opmask y_mask = Opmask(1);
305     Opmask r_mask = Opmask(2);
306     Opmask x_mask(int id) {
307         assert(id < 4);
308         return Opmask(3 + id);
309     }
310
311     Reg64 reg_scratch_bias_alpha = r15;
312
313     Reg64 reg_ptr_src = r14;
314     Reg64 reg_ptr_dst = r13;
315
316     Reg64 reg_ptr_v_y_masks = r12;
317     Reg64 reg_ptr_v_x_masks = r11;
318
319     Reg64 reg_aux_ptr_src = r10;
320     Reg64 reg_aux_ptr_dst = r9;
321
322     Reg64 reg_oc_block = r8;
323
324     Reg64 reg_ptr_bias = rbx;
325     Reg64 reg_ptr_scales = abi_not_param1;
326     Reg64 reg_ptr_sum_scale = rdx;
327 };
328
329 bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
330     using namespace primitive_kind;
331     const auto &p = attr_.post_ops_;
332
333     if (position == 0) {
334         /* relu before sum */
335         return false
336             || p.contain(eltwise, 0)
337             || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
338     } else if (position == 1) {
339         /* relu after sum */
340         const int sum_idx = p.contain(sum, 0)
341             ? 0 : (p.contain(sum, 1) ? 1 : -1);
342         if (sum_idx == -1)
343             return false;
344
345         return false
346             || p.contain(eltwise, sum_idx + 1)
347             || jcp.dst_dt == data_type::u8;
348     }
349
350     return false;
351 }
352
353 void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
354     Label oc_block_label;
355
356     auto loop_body = [=]() {
357         const auto &p = attr_.post_ops_;
358         const int sum_idx = p.find(primitive_kind::sum);
359         const float *p_sum_scale = (sum_idx != -1)
360                 ? &p.entry_[sum_idx].sum.scale
361                 : nullptr;
362         if (p_sum_scale && *p_sum_scale != 1.f)
363             mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
364
365         for(int i = 0; i < 16; i++) {
366             int internal_offset = sizeof(int32_t) * jcp.out_stride * i;
367             vmovups(vreg_inp(i),
368                 EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
369         }
370         for(int y = 0; y < jcp.alpha; y++) {
371             vpaddd(vreg_tmp(0), vreg_inp(y*4 + 0), vreg_inp(y*4 + 1));
372             vpaddd(vreg_stg(y*2), vreg_tmp(0), vreg_inp(y*4 + 2));
373
374             vpsubd(vreg_tmp(1), vreg_inp(y*4 + 1), vreg_inp(y*4 + 2));
375             vpsubd(vreg_stg(y*2+1), vreg_tmp(1), vreg_inp(y*4 + 3));
376         }
377         for(int x = 0; x < jcp.m; x++) {
378             vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2*1));
379             vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x+2*2));
380
381             vpsubd(vreg_tmp(1), vreg_stg(x+2*1), vreg_stg(x+2*2));
382             vpsubd(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2*3));
383         }
384
385
386         if (jcp.with_bias) {
387             vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
388             vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);
389
390             auto bias_addr = ptr [ reg_ptr_bias ];
391             switch (jcp.bia_dt) {
392             case data_type::f32:
393             case data_type::s32: vmovups(vreg_bias, bias_addr); break;
394             case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break;
395             case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break;
396             default: assert(!"unsupported dst data type");
397             }
398             if (jcp.bia_dt != data_type::f32)
399                 vcvtdq2ps(vreg_bias, vreg_bias);
400             vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
401         }
402         for(int y = 0; y < jcp.m; y++) {
403             kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(uint16_t) * y ]);
404             for(int x = 0; x < jcp.m; x++) {
405                 kandw(r_mask, y_mask, x_mask(x));
406
407                 int i = y * jcp.m + x;
408                 int offset = jcp.typesize_out *
409                     (y * jcp.ow * jcp.oc + x * jcp.oc);
410                 Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
411
412                 Zmm zmm = vreg_out(i);
413                 Xmm xmm = xmm_out(i);
414                 vcvtdq2ps(zmm, zmm);
415                 if (jcp.with_bias)
416                     vaddps(zmm, zmm, vreg_bias);
417                 vmulps(zmm, zmm, ptr [reg_ptr_scales]);
418                 if (maybe_relu(0))
419                     vmaxps(zmm, vreg_zero, zmm);
420                 if (p_sum_scale) { // post_op: sum
421                     vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
422                     switch (jcp.dst_dt) {
423                     case data_type::f32:
424                     case data_type::s32:
425                         vmovups(vreg_prev_dst | r_mask, addr); break;
426                     case data_type::s8:
427                         vpmovsxbd(vreg_prev_dst | r_mask, addr); break;
428                     case data_type::u8:
429                         vpmovzxbd(vreg_prev_dst | r_mask, addr); break;
430                     default: assert(!"unknown dst_dt");
431                     }
432                     if (jcp.dst_dt != data_type::f32)
433                         vcvtdq2ps(vreg_prev_dst, vreg_prev_dst);
434                     if (*p_sum_scale == 1.f)
435                         vaddps(zmm, vreg_prev_dst);
436                     else
437                         vfmadd231ps(zmm, vreg_prev_dst,
438                             zword_b[reg_ptr_sum_scale]);
439                 }
440                 if (maybe_relu(1))
441                     vmaxps(zmm, vreg_zero, zmm);
442                 if (jcp.dst_dt != data_type::f32) {
443                     if (attr_.round_mode_ == round_mode::nearest)
444                         vcvtps2dq(zmm | T_rn_sae, zmm);
445                     else if (attr_.round_mode_ == round_mode::down)
446                         vcvtps2dq(zmm | T_rd_sae, zmm);
447                     else
448                         assert(!"unimplemented");
449                 }
450                 switch (jcp.dst_dt) {
451                 case data_type::f32:
452                 case data_type::s32:
453                     vmovups(addr,  zmm | r_mask); break;
454                 case data_type::s8:
455                     vpmovsdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
456                 case data_type::u8:
457                     vpmovusdb(xmm, zmm); vmovups(addr, xmm | r_mask); break;
458                 default: assert(!"unknown dst_dt");
459                 }
460             }
461         }
462     };
463
464     preamble();
465
466 #   define READ_PARAM(reg, field) \
467         mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
468     READ_PARAM(reg_ptr_src, wino_dst);
469     READ_PARAM(reg_ptr_dst, dst);
470     READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
471     READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
472     READ_PARAM(reg_ptr_bias, bias);
473     READ_PARAM(reg_ptr_scales, scales);
474 #   undef READ_PARAM
475
476     if (jcp.with_bias)
477         mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));
478
479     mov(reg_aux_ptr_src, reg_ptr_src);
480     mov(reg_aux_ptr_dst, reg_ptr_dst);
481
482     vpxord(vreg_zero, vreg_zero, vreg_zero);
483
484     for (int i = 0; i < jcp.m; i++)
485         kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
486
487     int oc_blocks = jcp.oc / load_block;
488     mov(reg_oc_block, oc_blocks);
489     L(oc_block_label); {
490         loop_body();
491         add(reg_aux_ptr_src, sizeof(int32_t) * load_block);
492         add(reg_aux_ptr_dst, jcp.typesize_out * load_block);
493
494         add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
495         add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block);
496     }
497     dec(reg_oc_block);
498     jnz(oc_block_label, T_NEAR);
499
500     postamble();
501
502 }
503
504 /// GEMM kernel ////////////////////////////////////////////////////////////////
505 struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t: public jit_generator {
506     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t)
507     jit_conv_conf_2x3_wino_t jcp;
508     const primitive_attr_t &attr_;
509
510     struct call_params_t {
511         const void *src;
512         const void *dst;
513         const void *wei;
514         const void *dst_b;
515     };
516     void (*ker_)(const call_params_t *);
517
518     void generate();
519     static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
520                             const primitive_attr_t &attr);
521
522     jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
523         jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
524         : jcp(ajcp), attr_(attr)
525     {
526         generate();
527         ker_ = reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
528     }
529
530     static status_t init_conf(
531             jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
532             cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
533             cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
534             const primitive_attr_t &attr);
535
536     Zmm vreg_out(int n, int m) {
537         const int id_reg_out = n * jcp.m_block + m;
538         assert(id_reg_out < jcp.n2_block * jcp.m_block);
539         return Zmm(31 - id_reg_out);
540     }
541     Zmm vreg_wei(int i) {
542         assert(31 - jcp.n2_block * jcp.m_block - i
543                 > (jcp.ver == ver_vnni ? 0 : 2));
544         return Zmm(31 - jcp.n2_block * jcp.m_block - i);
545     }
546
547     Zmm vreg_src = Zmm(0);
548     Zmm vreg_one = Zmm(1);
549     Zmm vreg_tmp = Zmm(2);
550
551     Reg64 reg_ptr_src = r15;
552
553     Reg64 reg_aux_dst_b = r13;
554     Reg64 reg_aux_dst = r12;
555     Reg64 reg_aux_dst2 = r11;
556     Reg64 reg_aux_wei = r10;
557     Reg64 reg_aux_wei2 = r9;
558     Reg64 reg_aux_src = r8;
559     Reg64 reg_aux_src2 = rax;
560     Reg64 reg_mb = rbx;
561     Reg64 reg_nnb = abi_not_param1;
562     Reg64 reg_scratch = rdx;
563     Reg64 reg_K = rsi;
564 };
565
566 bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
567         jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
568     using namespace primitive_kind;
569     const auto &p = attr.post_ops_;
570
571     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
572
573     switch (p.len_) {
574     case 0: return true;
575     case 1: return is_relu(0) || p.contain(sum, 0);
576     case 2: return (p.contain(sum, 0) && is_relu(1)) ||
577                        (p.contain(sum, 1) && is_relu(0));
578     case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
579     default: return false;
580     }
581
582     return false;
583 }
584
585 void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
586     Label nnb_loop_label, K_loop_label, mb_loop_label;
587
588     auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
589         if (jcp.ver == ver_vnni) {
590             vpdpbusd(vreg_acc, vreg_src, vreg_wei);
591         } else {
592             vpmaddubsw(vreg_tmp, vreg_src, vreg_wei);
593             vpmaddwd(vreg_tmp, vreg_tmp, vreg_one);
594             vpaddd(vreg_acc, vreg_acc, vreg_tmp);
595         }
596     };
597
598     preamble();
599 #   define READ_PARAM(reg, field) \
600         mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
601     READ_PARAM(reg_ptr_src, src);
602     READ_PARAM(reg_aux_dst, dst);
603     READ_PARAM(reg_aux_wei, wei);
604     READ_PARAM(reg_aux_dst_b, dst_b);
605 #   undef READ_PARAM
606
607     if (jcp.ver != ver_vnni) {
608         xor_(reg_scratch, reg_scratch);
609         Reg16 _t = reg_scratch.cvt16();
610         mov(_t, 0x1);
611         vpbroadcastw(vreg_one, _t);
612     }
613
614     if (!jcp.small_mb) {
615         mov(reg_nnb, jcp.n_chunks);
616         L(nnb_loop_label);
617     }
618     mov(reg_aux_dst2, reg_aux_dst);
619     mov(reg_aux_src, reg_ptr_src);
620     mov(reg_mb, jcp.M / jcp.m_block);
621     L(mb_loop_label);
622     {
623         for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
624             for (int m = 0; m < jcp.m_block; m++) {
625                 int offset = jcp.typesize_acc * nb2 * jcp.n_block;
626                 vmovups(vreg_out(nb2, m),
627                         EVEX_compress_addr(reg_aux_dst_b, offset));
628             }
629         }
630         mov(reg_aux_src2, reg_aux_src);
631         mov(reg_aux_wei2, reg_aux_wei);
632         mov(reg_K, jcp.k_chunks);
633         L(K_loop_label);
634         {
635             for (int k = 0; k < jcp.k2_block; k += 4) {
636                 for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
637                     int wei_offset
638                             = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
639                     vmovups(vreg_wei(nb2),
640                             EVEX_compress_addr(reg_aux_wei2, wei_offset));
641                 }
642                 for (int m = 0; m < jcp.m_block; m++) {
643                     int inp_offset = jcp.typesize_in * m * jcp.K;
644                     vpbroadcastd(vreg_src,
645                             EVEX_compress_addr(reg_aux_src2, inp_offset));
646                     for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
647                         compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
648                 }
649                 add(reg_aux_src2, jcp.typesize_in * 4);
650                 add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
651             }
652         }
653         dec(reg_K);
654         jnz(K_loop_label, T_NEAR);
655
656         for (int m = 0; m < jcp.m_block; m++) {
657             for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
658                 int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block);
659                 vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
660                         vreg_out(nb2, m));
661             }
662         }
663         add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
664         add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
665     }
666     dec(reg_mb);
667     jnz(mb_loop_label, T_NEAR);
668
669     if (!jcp.small_mb) {
670         add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
671         add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
672         add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);
673
674         dec(reg_nnb);
675         jnz(nnb_loop_label, T_NEAR);
676     }
677
678     postamble();
679 }
680 namespace {
681 bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
682     if (jcp.ver == ver_vnni) {
683         return (jcp.mb <= mkldnn_get_max_threads()
684             && (jcp.mb > 4
685                 && jcp.ic > 64
686                 && !(jcp.oc > 128 && jcp.ih < 14)))
687             || jcp.mb > mkldnn_get_max_threads();
688     }
689     return true;
690 }
691 }
692
693 status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t
694 ::init_conf(jit_conv_conf_2x3_wino_t &jcp,
695             const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
696             cpu_memory_t::pd_t &wei_pd, cpu_memory_t::pd_t &dst_pd,
697             cpu_memory_t::pd_t &bias_pd, const primitive_attr_t &attr) {
698     const memory_desc_wrapper src_d(&src_pd);
699     const memory_desc_wrapper wei_d(&wei_pd);
700     const memory_desc_wrapper dst_d(&dst_pd);
701     const memory_desc_wrapper bias_d(&bias_pd);
702
703     const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
704
705     jcp.nthr = mkldnn_get_max_threads();
706
707     jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
708     jcp.mb = src_d.dims()[0];
709     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
710     jcp.ic = src_d.dims()[1] / jcp.ngroups;
711     jcp.ih = src_d.dims()[2];
712     jcp.iw = src_d.dims()[3];
713     jcp.oh = dst_d.dims()[2];
714     jcp.ow = dst_d.dims()[3];
715     jcp.kh = wei_d.dims()[with_groups + 2];
716     jcp.kw = wei_d.dims()[with_groups + 3];
717     jcp.t_pad = cd.padding[0][0];
718     jcp.b_pad = cd.padding[1][0];
719     jcp.l_pad = cd.padding[0][1];
720     jcp.r_pad = cd.padding[1][1];
721     jcp.stride_h = cd.strides[0];
722     jcp.stride_w = cd.strides[1];
723     jcp.dilate_h = cd.dilates[0];
724     jcp.dilate_w = cd.dilates[1];
725
726     jcp.ver = ver_avx512_core;
727     if (!(mayiuse(avx512_core) &&
728             src_d.data_type() == data_type::u8
729          && wei_d.data_type() == data_type::s8
730          && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
731             data_type::s8, data_type::u8)))
732         return status::unimplemented;
733     if (mayiuse(avx512_core_vnni))
734         jcp.ver = ver_vnni;
735
736     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
737                is_winograd_faster_than_direct(jcp)))
738         return status::unimplemented;
739
740     // block sizes needed for GEMM kernel
741     jcp.ic_block = 4;
742     jcp.oc_block = 16;
743
744     bool ok = true
745         && jcp.ngroups == 1
746         && jcp.oc % load_block == 0 && jcp.ic % load_block == 0
747         && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
748         && everyone_is(3, jcp.kh, jcp.kw)
749         && everyone_is(1, jcp.stride_h, jcp.stride_w)
750         && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
751         && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
752         && one_of(jcp.t_pad, 0, 1)
753         && one_of(jcp.l_pad, 0, 1);
754     if (!ok) return status::unimplemented;
755
756     jcp.src_fmt = src_d.format();
757     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
758
759     if (!post_ops_ok(jcp, attr))
760         return status::unimplemented;
761
762     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
763     jcp.dst_dt = cd.dst_desc.data_type;
764
765     jcp.typesize_in = types::data_type_size(src_d.data_type());
766     jcp.typesize_out = types::data_type_size(dst_d.data_type());
767     jcp.typesize_acc = sizeof(int32_t);
768     jcp.typesize_bia = jcp.with_bias
769         ? types::data_type_size(bias_d.data_type())
770         : 0;
771
772     jcp.nb_oc = jcp.oc / jcp.oc_block;
773     jcp.nb_ic = jcp.ic / jcp.ic_block;
774
775     jcp.m = 2;
776     jcp.r = 3;
777     jcp.alpha = jcp.m + jcp.r - 1;
778
779     int aa = jcp.alpha * jcp.alpha;
780     int L1_cap = get_cache_size(1, true);
781     int L2_cap = get_cache_size(2, true);
782     // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
783     int free_regs = jcp.ver == ver_vnni ? 31 : 29;
784
785     auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
786         float thr_eff;
787         float Z = (float)jcp.ic + jcp.oc;
788         float Y = (float)jcp.ic * jcp.oc;
789         if (small_mb == 0) { // outer par
790             int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
791             thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
792         } else { // inner par
793             int tranw = iy * ix / jcp.alpha;
794             int gemmw = aa * (jcp.nb_oc / n2_b);
795             int tranw_r = rnd_up(tranw, jcp.nthr);
796             int gemmw_r = rnd_up(gemmw, jcp.nthr);
797             thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
798         }
799         return thr_eff;
800     };
801
802     auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
803         float mem_eff, req_mem;
804         int M = ix * iy / jcp.alpha;
805         if (small_mb == 0) { // outer parallelization strategy
806             // memory for wino transforms (other memory has poor reuse)
807             req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
808             mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
809         } else { // inner parallelization strategy
810             // memory used during gemm
811             int N = jcp.oc_block * n2_b;
812             req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
813             mem_eff = nstl::min(1.f, L2_cap / req_mem);
814             // memory used during wino transforms
815             int M_per_thr = div_up(M, jcp.nthr);
816             req_mem = (float)aa * M_per_thr
817                     * (jcp.ic + jcp.typesize_acc * jcp.oc);
818             if (req_mem > L2_cap)
819                 mem_eff = 0.1f;
820         }
821         return mem_eff;
822     };
823
824     auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
825             float mem_eff, float reg_eff) {
826         // these coefficients are chosen empirically
827         float mem_fac = 0.1f, reg_fac = 0.2f;
828         // normalized overhead relative to memory and register components
829         float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
830         // thread and work components affect all others
831         tot_eff *= thr_eff * work_eff;
832         return tot_eff;
833     };
834
835     auto find_m_n2_blocks = [&](bool small_mb, int ix, int iy, float work_eff,
836             int &m_block, int &n2_block, float &tot_eff) {
837         int M = (ix * iy) / jcp.alpha;
838         int max_m_block = nstl::min(M, free_regs);
839         int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
840         tot_eff = 0.f;
841         for (int im = max_m_block; im > 0; im--) {
842             if (M % im)
843                 continue;
844             for (int in2 = max_n2_block; in2 > 0; in2--) {
845                 int used_regs = (im + 1) * in2;
846                 float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
847                 float reg_eff = (float)(im * in2) / (im + in2);
848                 float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
849                 float cur_tot_eff = get_tot_eff(
850                         small_mb, thr_eff, work_eff, mem_eff, reg_eff);
851                 if (jcp.nb_oc % in2 || used_regs > free_regs
852                         || cur_tot_eff <= tot_eff)
853                     continue;
854                 tot_eff = cur_tot_eff;
855                 m_block = im;
856                 n2_block = in2;
857             }
858         }
859     };
860
861     /* Selecting xb and yb blocking */
862     int min_yb = jcp.m;
863     int min_xb = jcp.m;
864     int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
865     int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
866     float best_eff = 0.f;
867     for (int ix = min_xb; ix <= max_xb; ix += 2) {
868         assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
869         for (int iy = max_yb; iy >= min_yb; iy -= 2) {
870             assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);
871
872             int m_b[2];
873             int n2_b[2];
874             bool small_mb;
875             float inner_eff, outer_eff, work_eff;
876
877             int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
878             work_eff = (float)jcp.oh * jcp.ow / tiled_area;
879             if (best_eff > 0.f && work_eff < 4.f / 9.f)
880                 continue; // no gain from Winograd transformation
881
882             /* outer parallelization */
883             find_m_n2_blocks(0, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);
884
885             /* inner parallelization */
886             find_m_n2_blocks(1, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);
887
888             small_mb = inner_eff > outer_eff;
889             float eff = small_mb ? inner_eff : outer_eff;
890             if (eff > best_eff) {
891                 best_eff = eff;
892                 jcp.yb = iy;
893                 jcp.xb = ix;
894                 jcp.m_block = m_b[small_mb];
895                 jcp.n2_block = n2_b[small_mb];
896                 jcp.small_mb = small_mb;
897             }
898         }
899     }
900
901     assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
902     assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
903
904     jcp.mb_block = 1;
905     if (jcp.small_mb) {
906         // For small mb harness, set mb_block as large as possible subject to
907         // the constraint that winograd activations fit into available L3 cache
908         int L3_cap = get_cache_size(3, true);
909         int M = jcp.xb * jcp.yb / 4;
910         int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
911         int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
912         int max_mb_block = nstl::min(
913                 jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
914         for (int i = max_mb_block; i > 1; i--) {
915             if (jcp.mb % i == 0) {
916                 jcp.mb_block = i;
917                 break;
918             }
919         }
920     }
921     jcp.nb_mb = jcp.mb / jcp.mb_block;
922
923     jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
924     jcp.N = jcp.oc;
925     jcp.K = jcp.ic;
926
927     jcp.inp_stride = jcp.M * jcp.ic;
928     jcp.out_stride = jcp.M * jcp.oc;
929     jcp.wei_stride = jcp.ic * jcp.oc;
930     jcp.bia_stride = jcp.oc;
931
932     jcp.n_block = jcp.oc_block;
933     jcp.k_block = jcp.ic_block;
934
935     jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;
936
937     // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
938     // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
939     // a multiple of load_block = 16, we just use that for now.
940     jcp.k2_block = load_block;
941     jcp.k_chunks = jcp.K / jcp.k2_block;
942
943     const auto &oscales = attr.output_scales_;
944     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
945     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
946
947     /* re-create weights primitive descriptor
948                                     and set weights wino_blocking */
949     memory_desc_t expect_wei_md = *(wei_pd.desc());
950
951     expect_wei_md.format = mkldnn_wino_fmt;
952     expect_wei_md.data_type = data_type::s8;
953     mkldnn_wino_desc_t &wd = expect_wei_md.layout_desc.wino_desc;
954     wd.wino_format = mkldnn_wino_wei_aaOIoi;
955     wd.r = jcp.r;
956     wd.alpha = jcp.alpha;
957     wd.ic = jcp.ic;
958     wd.oc = jcp.oc;
959     wd.ic_block = jcp.ic_block;
960     wd.oc_block = jcp.oc_block;
961     wd.oc2_block = jcp.n2_block;
962     wd.ic2_block = 1;
963     wd.adj_scale = adj_wei_scale;
964
965     size_t max_size = types::data_type_size(data_type::s8) *
966                         jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
967     max_size += types::data_type_size(data_type::s32) *
968                                 jcp.alpha * jcp.alpha * jcp.oc;
969     wd.size = max_size;
970
971     cpu_memory_t::pd_t new_weights_pd(wei_pd.engine(), &expect_wei_md);
972     if (wei_pd.desc()->format == any)
973         wei_pd = new_weights_pd;
974     if (!wei_pd.is_equal(&new_weights_pd))
975         return status::unimplemented;
976
977     const int tilesize = jcp.alpha * jcp.alpha;
978     const int numtiles = jcp.M;
979     const int alltiles = numtiles * tilesize;
980
981     jcp.size_wino_src
982         = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
983         / jcp.typesize_in;
984     jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
985     jcp.size_wino_dst = alltiles * jcp.oc;
986
987     return status::success;
988 }
989 ////////////////////////////////////////////////////////////////////////////////
990
991 template <data_type_t dst_data_type>
992 status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
993         pd_t::jit_conf() {
994     return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
995             jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
996             this->dst_pd_,this->bias_pd_, *this->attr());
997 }
998
999 template <data_type_t dst_data_type>
1000 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::pd_t::
1001 init_scratchpad() {
1002     auto scratchpad = this->scratchpad_registry().registrar();
1003
1004     int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
1005     scratchpad.book(key_wino_V,
1006             sizeof(src_data_t) * jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
1007     scratchpad.book(key_wino_M,
1008             sizeof(acc_data_t) * jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
1009
1010     scratchpad.book(key_conv_adjusted_scales,
1011             sizeof(float) * nstl::max(attr()->output_scales_.count_, 16));
1012 }
1013
1014 template <data_type_t dst_data_type>
1015 jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
1016         jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd,
1017                 const input_vector &inputs, const output_vector &outputs)
1018     : cpu_primitive_t(apd, inputs, outputs, true)
1019 {
1020     kernel_ = new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
1021             pd()->jcp_, *pd()->attr());
1022     src_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
1023             pd()->jcp_, *pd()->attr());
1024     dst_trans_ = new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
1025             pd()->jcp_, *pd()->attr());
1026 }
1027
1028 template <data_type_t dst_data_type>
1029 jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
1030         ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t() {
1031     delete kernel_;
1032     delete src_trans_;
1033     delete dst_trans_;
1034 }
1035
1036 template <data_type_t dst_data_type>
1037 const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
1038 adjust_oscales(const memory_tracking::grantor_t &scratchpad) const {
1039     const float *oscales = pd()->attr()->output_scales_.scales_;
1040     auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
1041     size_t count = pd()->attr()->output_scales_.count_;
1042     float factor = 1.f / (adj_src_scale * adj_wei_scale);
1043     if (count == 1)
1044         utils::array_set(loc_scales, oscales[0] * factor, 16);
1045     else
1046         for (size_t c = 0; c < count; c++) loc_scales[c] = oscales[c] * factor;
1047     return loc_scales;
1048 }
1049
1050 template <data_type_t dst_data_type>
1051 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
1052 execute_forward() const {
1053     const auto &jcp = kernel_->jcp;
1054     if (jcp.small_mb)
1055         execute_forward_small_mb();
1056     else
1057         execute_forward_mbN();
1058 }
1059
1060 template <data_type_t dst_data_type>
1061 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
1062 execute_forward_mbN() const {
1063     auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
1064     auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
1065     auto bia = reinterpret_cast<const char *>(input_memory(2));
1066     auto dst = reinterpret_cast<dst_data_t *>(memory(0));
1067
1068     auto scratchpad = this->scratchpad();
1069
1070     const auto &jcp = kernel_->jcp;
1071     const float *oscales = adjust_oscales(scratchpad);
1072
1073     auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
1074     auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
1075     auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
1076
1077     parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
1078             [&](int mb, int tile_y_b, int tile_x_b) {
1079
1080         int tile_y = tile_y_b * jcp.yb;
1081         int tile_x = tile_x_b * jcp.xb;
1082
1083         int ithr = mkldnn_get_thread_num();
1084         auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
1085         auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;
1086
1087         auto src_trans_p =
1088             jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
1089         auto dst_trans_p =
1090             jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
1091         auto gemm_p =
1092             jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::call_params_t();
1093
1094         /* transformation of input tensor to winograd domain */
1095         for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
1096             for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
1097                 uint16_t v_y_masks[4], v_x_masks[4];
1098
1099                 int y = y_in_block + tile_y;
1100                 int x = x_in_block + tile_x;
1101                 int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
1102
1103                 int v_ys = nstl::max(0, jcp.t_pad - y);
1104                 int v_ye = nstl::min(jcp.alpha,
1105                         nstl::max(0, jcp.ih + jcp.t_pad - y));
1106
1107                 int v_xs = nstl::max(0, jcp.l_pad - x);
1108                 int v_xe = nstl::min(jcp.alpha,
1109                         nstl::max(0, jcp.iw + jcp.l_pad - x));
1110
1111 #pragma unroll(4)
1112                 for (int i = 0; i < jcp.alpha; i++) {
1113                     v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
1114                     v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
1115                 }
1116                 auto local_s = src
1117                         + mb * jcp.ih * jcp.iw * jcp.ic
1118                         + y * jcp.iw * jcp.ic + x * jcp.ic;
1119                 auto local_w = wino_src + m * jcp.ic;
1120
1121                 src_trans_p.src = local_s;
1122                 src_trans_p.wino_src = local_w;
1123                 src_trans_p.v_y_masks = v_y_masks;
1124                 src_trans_p.v_x_masks = v_x_masks;
1125
1126                 src_trans_->ker_(&src_trans_p);
1127             }
1128         }
1129         /* gemms */
1130         for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
1131             // start threads at different GEMMs to help bring weights into LLC
1132             int offset = (tile_ij + ithr) % 16;
1133             gemm_p.src = wino_src + jcp.inp_stride * offset;
1134             gemm_p.dst = wino_dst + jcp.out_stride * offset;
1135             gemm_p.wei = wei + jcp.wei_stride * offset;
1136             gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
1137
1138             kernel_->ker_(&gemm_p);
1139         }
1140
1141         /* transformation from winograd domain to output tensor */
1142         for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
1143             for (int x_in_block = 0; x_in_block < jcp.xb; x_in_block += 2) {
1144                 uint16_t v_y_masks[2], v_x_masks[2];
1145
1146                 int y = y_in_block + tile_y;
1147                 int x = x_in_block + tile_x;
1148                 int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
1149
1150 #pragma unroll(2)
1151                 for (int i = 0; i < jcp.m; i++) {
1152                     v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
1153                     v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
1154                 }
1155                 auto local_d = dst
1156                         + mb * jcp.oh * jcp.ow * jcp.oc
1157                         + y * jcp.ow * jcp.oc + x * jcp.oc;
1158                 auto local_w = wino_dst + m * jcp.oc;
1159
1160                 auto scales = oscales;
1161                 dst_trans_p.dst = local_d;
1162                 dst_trans_p.wino_dst = local_w;
1163                 dst_trans_p.v_y_masks = v_y_masks;
1164                 dst_trans_p.v_x_masks = v_x_masks;
1165
1166                 dst_trans_p.scales = scales;
1167                 dst_trans_p.bias = bia;
1168
1169                 dst_trans_->ker_(&dst_trans_p);
1170             }
1171         }
1172     });
1173 }
1174
1175 template <data_type_t dst_data_type>
1176 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<dst_data_type>::
1177 execute_forward_small_mb() const {
1178     auto src = reinterpret_cast<const src_data_t *>(input_memory(0));
1179     auto wei = reinterpret_cast<const wei_data_t *>(input_memory(1));
1180     auto bia = reinterpret_cast<const char *>(input_memory(2));
1181     auto dst = reinterpret_cast<dst_data_t *>(memory(0));
1182
1183     auto scratchpad = this->scratchpad();
1184
1185     const auto &jcp = kernel_->jcp;
1186     const float *oscales = adjust_oscales(scratchpad);
1187
1188     auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
1189     auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
1190     auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
1191
1192     for (int mbb = 0; mbb < jcp.nb_mb; mbb++) {
1193     for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
1194     for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
1195         /* transformation of input tensor to winograd domain */
1196         parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
1197             [&](int y_in_block_b, int x_in_block_b, int mb) {
1198             int y_in_block = y_in_block_b * 2;
1199             int x_in_block = x_in_block_b * 2;
1200
1201             auto src_trans_p =
1202                 jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::call_params_t();
1203
1204             uint16_t v_y_masks[4], v_x_masks[4];
1205
1206             int y = y_in_block + tile_y;
1207             int x = x_in_block + tile_x;
1208             int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
1209                     + (x_in_block / 2);
1210
1211             int v_ys = nstl::max(0, jcp.t_pad - y);
1212             int v_ye = nstl::min(
1213                     jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
1214
1215             int v_xs = nstl::max(0, jcp.l_pad - x);
1216             int v_xe = nstl::min(
1217                     jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
1218
1219 #pragma unroll(4)
1220             for (int i = 0; i < jcp.alpha; i++) {
1221                 v_y_masks[i] = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
1222                 v_x_masks[i] = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
1223             }
1224             auto local_s = src
1225                     + (mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw * jcp.ic
1226                     + y * jcp.iw * jcp.ic + x * jcp.ic;
1227             auto local_w = wino_src + m * jcp.ic;
1228
1229             src_trans_p.src = local_s;
1230             src_trans_p.wino_src = local_w;
1231             src_trans_p.v_y_masks = v_y_masks;
1232             src_trans_p.v_x_masks = v_x_masks;
1233
1234             src_trans_->ker_(&src_trans_p);
1235         });
1236
1237         /* gemms */
1238         parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
1239             auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
1240                     call_params_t();
1241
1242             gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
1243             gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
1244                     + nnb * jcp.n2_block * jcp.n_block;
1245             gemm_p.wei = wei + jcp.wei_stride * tile_ij
1246                     + nnb * jcp.n2_block * jcp.n_block * jcp.K;
1247             gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
1248                     + nnb * jcp.n2_block * jcp.n_block;
1249
1250             kernel_->ker_(&gemm_p);
1251         });
1252
1253         /* transformation from winograd domain to output tensor */
1254         parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
1255             [&](int y_in_block_b, int x_in_block_b, int mb) {
1256             int y_in_block = y_in_block_b * 2;
1257             int x_in_block = x_in_block_b * 2;
1258
1259             auto dst_trans_p =
1260                 jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::call_params_t();
1261
1262             uint16_t v_y_masks[2], v_x_masks[2];
1263
1264             int y = y_in_block + tile_y;
1265             int x = x_in_block + tile_x;
1266             int m = (mb * (jcp.yb / 2) + (y_in_block / 2)) * (jcp.xb / 2)
1267                     + (x_in_block / 2);
1268
1269 #pragma unroll(2)
1270             for (int i = 0; i < jcp.m; i++) {
1271                 v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
1272                 v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
1273             }
1274             auto local_d = dst
1275                     + (mbb * jcp.mb_block + mb) * jcp.oh * jcp.ow * jcp.oc
1276                     + y * jcp.ow * jcp.oc + x * jcp.oc;
1277             auto local_w = wino_dst + m * jcp.oc;
1278
1279             auto scales = oscales;
1280             dst_trans_p.dst = local_d;
1281             dst_trans_p.wino_dst = local_w;
1282             dst_trans_p.v_y_masks = v_y_masks;
1283             dst_trans_p.v_x_masks = v_x_masks;
1284
1285             dst_trans_p.scales = scales;
1286             dst_trans_p.bias = bia;
1287
1288             dst_trans_->ker_(&dst_trans_p);
1289         });
1290     }}}
1291 }
1292
1293 template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s8>;
1294 template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::u8>;
1295 template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::s32>;
1296 template struct jit_avx512_core_u8s8s32x_wino_convolution_fwd_t<data_type::f32>;
1297
1298 } // namespace cpu
1299 } // namespace impl
1300 } // namespace mkldnn