Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_2x3.cpp
1
2 /*******************************************************************************
3  * Copyright 2018 Intel Corporation
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  *******************************************************************************/
17
18 #include <assert.h>
19
20 #include "c_types_map.hpp"
21 #include "cpu_convolution_pd.hpp"
22 #include "cpu_engine.hpp"
23 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
25 #include "utils.hpp"
26
27 #include "jit_avx512_core_fp32_wino_conv_2x3.hpp"
28 #include "jit_generator.hpp"
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::memory_tracking::names;
36 using namespace mkldnn::impl::utils;
37 using namespace Xbyak;
38
39 /// SRC TRANSFORMS /////////////////////////////////////////////////////////////
40 struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator {
41     DECLARE_CPU_JIT_AUX_FUNCTIONS(
42             jit_avx512_core_fp32_wino_conv_2x3_src_trans_t)
43
44     jit_conv_conf_2x3_wino_t jcp;
45
46     struct call_params_t {
47         const void *src;
48         const void *wino_src;
49         const void *v_y_masks;
50         const void *v_x_masks;
51     };
52     void (*ker_)(const call_params_t *);
53
54     jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
55         jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
56         : jcp(ajcp) {
57         generate();
58         ker_ =
59             reinterpret_cast<decltype(ker_)>(const_cast<uint8_t*>(getCode()));
60     }
61
62     void generate();
63
64     Zmm vreg_inp(int i) {
65         assert(i < jcp.alpha * jcp.alpha);
66         return Zmm(31 - i);
67     }
68
69     Zmm vreg_tmp(int i) {
70         assert(i < jcp.alpha * jcp.alpha);
71         return Zmm(15 - i);
72     }
73
74     Zmm vreg_out(int i) {
75         assert(i < jcp.alpha * jcp.alpha);
76         return Zmm(31 - i);
77     }
78
79     Opmask y_mask = Opmask(1);
80     Opmask r_mask = Opmask(2);
81     Opmask x_mask(int id) {
82         assert (id < 4);
83         return Opmask(3 + id);
84     }
85
86     Reg64 reg_ptr_v_y_masks = r12;
87     Reg64 reg_ptr_v_x_masks = r11;
88
89     Reg64 reg_aux_ptr_src = r10;
90     Reg64 reg_aux_ptr_dst = r9;
91
92     Reg64 reg_ic_block = r8;
93
94 };
95
96 void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() {
97     Label ic_block_label;
98
99     const int load_block = 16;
100     int out_offset = 0, inp_offset = 0;
101     preamble();
102
103 #define READ_PARAM(reg, field) \
104         mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
105     READ_PARAM(reg_aux_ptr_src, src);
106     READ_PARAM(reg_aux_ptr_dst, wino_src);
107     READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
108     READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
109 #undef READ_PARAM
110
111     for (int i = 0; i < jcp.alpha; i++) {
112         kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
113     }
114     mov(reg_ic_block, jcp.ic / load_block);
115     L(ic_block_label);
116     {
117         for (int y = 0; y < jcp.alpha; y++) {
118             kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
119             for (int x = 0; x < jcp.alpha; x++) {
120                 Zmm zmm = vreg_inp(y * jcp.alpha + x);
121
122                 vxorps(zmm, zmm, zmm);
123                 kandw(r_mask, y_mask, x_mask(x));
124                 inp_offset = sizeof(float)
125                         * ((-jcp.t_pad + y) * jcp.iw * load_block
126                                   + (-jcp.l_pad + x) * load_block);
127                 vmovups(zmm | r_mask,
128                         EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
129             }
130         }
131         for (int y = 0; y < jcp.alpha; y++) {
132             vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0),
133                     vreg_inp(y * jcp.alpha + 2));
134             vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1),
135                     vreg_inp(y * jcp.alpha + 2));
136             vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2),
137                     vreg_inp(y * jcp.alpha + 1));
138             vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1),
139                     vreg_inp(y * jcp.alpha + 3));
140         }
141         for (int x = 0; x < jcp.alpha; x++) {
142             vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0),
143                     vreg_tmp(x + jcp.alpha * 2));
144             vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
145                     vreg_tmp(x + jcp.alpha * 2));
146             vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2),
147                     vreg_tmp(x + jcp.alpha * 1));
148             vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
149                     vreg_tmp(x + jcp.alpha * 3));
150         }
151
152         for (int i = 0; i < 16; i++) {
153             out_offset = sizeof(float) * (jcp.inp_stride * i);
154             vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
155                     vreg_out(i));
156         }
157
158         add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block);
159         add(reg_aux_ptr_dst, sizeof(float) * load_block);
160     }
161     dec(reg_ic_block);
162     cmp(reg_ic_block, 0);
163     jg(ic_block_label, T_NEAR);
164     postamble();
165 }
166
167 /// DST TRANSFORMS /////////////////////////////////////////////////////////////
168 struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator {
169     DECLARE_CPU_JIT_AUX_FUNCTIONS(
170             jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t)
171
172     jit_conv_conf_2x3_wino_t jcp;
173     const primitive_attr_t &attr_;
174
175     struct call_params_t {
176         const void *wino_dst;
177         const void *dst;
178         const void *v_y_masks;
179         const void *v_x_masks;
180
181         const void *bias;
182         const void *scales;
183     };
184     void (*ker_)(const call_params_t *);
185
186     jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
187             jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
188         : jcp(ajcp), attr_(attr) {
189         generate();
190         ker_ = reinterpret_cast<decltype(ker_)>(
191                 const_cast<uint8_t *>(getCode()));
192     }
193
194     void generate();
195     bool maybe_relu(int position);
196
197     Zmm vreg_inp(int i) { // 16
198         assert(i < jcp.alpha * jcp.alpha);
199         return Zmm(31 - i);
200     }
201
202     Zmm vreg_stg(int id) { // 8
203         const int id_reg_stg = jcp.alpha * jcp.alpha + id;
204         assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
205         return Zmm(31 - id_reg_stg);
206     }
207
208     Zmm vreg_out(int id) { // 4
209         const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
210         assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
211         return Zmm(31 - id_reg_out);
212     }
213
214     Zmm vreg_tmp(int id) { // 2
215         const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
216         assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
217         return Zmm(31 - id_reg_tmp);
218     }
219
220     Zmm vreg_zero = Zmm(0);
221     Zmm vreg_prev_dst = Zmm(0);
222     Zmm vreg_bias = Zmm(2);
223
224     Opmask y_mask = Opmask(1);
225     Opmask r_mask = Opmask(2);
226     Opmask x_mask(int id) {
227         assert (id < 4);
228         return Opmask(3 + id);
229     }
230
231     Reg64 reg_ptr_v_y_masks = r12;
232     Reg64 reg_ptr_v_x_masks = r11;
233
234     Reg64 reg_aux_ptr_src = r10;
235     Reg64 reg_aux_ptr_dst = r9;
236
237     Reg64 reg_oc_block = r8;
238
239     Reg64 reg_ptr_bias = rbx;
240     Reg64 reg_ptr_scales = abi_not_param1;
241     Reg64 reg_ptr_sum_scale = rdx;
242 };
243
244 bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) {
245     using namespace primitive_kind;
246     const auto &p = attr_.post_ops_;
247
248     if (position == 0) {
249         /* relu before sum */
250         return false
251             || p.contain(eltwise, 0);
252     } else if (position == 1) {
253         /* relu after sum */
254         const int sum_idx = p.contain(sum, 0)
255             ? 0 : (p.contain(sum, 1) ? 1 : -1);
256         if (sum_idx == -1)
257             return false;
258
259         return false
260             || p.contain(eltwise, sum_idx + 1);
261     }
262
263     return false;
264 }
265
266 void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() {
267     Label oc_block_label;
268
269     const int load_block = 16;
270
271     auto loop_body = [=]() {
272         const auto &p = attr_.post_ops_;
273         const int sum_idx = p.find(primitive_kind::sum);
274         const float *p_sum_scale = (sum_idx != -1)
275                 ? &p.entry_[sum_idx].sum.scale
276                 : nullptr;
277         if (p_sum_scale && *p_sum_scale != 1.f)
278             mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
279
280         for (int i = 0; i < 16; i++) {
281             int internal_offset = sizeof(float) * jcp.out_stride * i;
282             vmovups(vreg_inp(i),
283                 EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
284         }
285         for (int y = 0; y < jcp.alpha; y++) {
286             vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
287             vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));
288
289             vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
290             vsubps(vreg_stg(y * 2+1), vreg_tmp(1), vreg_inp(y * 4 + 3));
291         }
292         for (int x = 0; x < jcp.m; x++) {
293             vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x+2 * 1));
294             vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x+2 * 2));
295
296             vsubps(vreg_tmp(1), vreg_stg(x+2 * 1), vreg_stg(x+2 * 2));
297             vsubps(vreg_out(x+2), vreg_tmp(1), vreg_stg(x+2 * 3));
298         }
299
300
301         if (jcp.with_bias) {
302             auto bias_addr = ptr [ reg_ptr_bias ];
303             vmovups(vreg_bias, bias_addr);
304         }
305         for (int y = 0; y < jcp.m; y++) {
306             kmovw(y_mask, ptr[ reg_ptr_v_y_masks + sizeof(int16_t) * y ]);
307             for (int x = 0; x < jcp.m; x++) {
308                 kandw(r_mask, y_mask, x_mask(x));
309
310                 int i = y * jcp.m + x;
311                 int offset = sizeof(float) *
312                     (y * jcp.ow * jcp.oc_block + x * jcp.oc_block);
313                 Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
314
315                 Zmm zmm = vreg_out(i);
316                 if (jcp.with_bias)
317                     vaddps(zmm, zmm, vreg_bias);
318                 vmulps(zmm, zmm, ptr [reg_ptr_scales]);
319
320                 if (maybe_relu(0)) {
321                     vxorps(vreg_zero, vreg_zero, vreg_zero);
322                     vmaxps(zmm, vreg_zero, zmm);
323                 }
324                 if (p_sum_scale) { // post_op: sum
325                     vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
326                     vmovups(vreg_prev_dst | r_mask, addr);
327                     if (*p_sum_scale == 1.f)
328                         vaddps(zmm, vreg_prev_dst);
329                     else
330                         vfmadd231ps(zmm, vreg_prev_dst,
331                             zword_b[reg_ptr_sum_scale]);
332                 }
333                 if (maybe_relu(1)) {
334                     vxorps(vreg_zero, vreg_zero, vreg_zero);
335                     vmaxps(zmm, vreg_zero, zmm);
336                 }
337
338                 vmovups(addr, zmm | r_mask);
339             }
340         }
341     };
342
343     preamble();
344
345 #define READ_PARAM(reg, field) \
346         mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
347     READ_PARAM(reg_aux_ptr_src, wino_dst);
348     READ_PARAM(reg_aux_ptr_dst, dst);
349     READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
350     READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
351     READ_PARAM(reg_ptr_bias, bias);
352     READ_PARAM(reg_ptr_scales, scales);
353 #undef READ_PARAM
354
355     for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
356         vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i));
357
358     for (int i = 0; i < jcp.alpha; i++)
359         kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
360
361     int oc_blocks = 1;
362     oc_blocks = jcp.oc / load_block;
363     mov(reg_oc_block, oc_blocks);
364     L(oc_block_label);
365     {
366         loop_body();
367         add(reg_aux_ptr_src, sizeof(float) * load_block);
368         add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block);
369
370         add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
371         add(reg_ptr_bias, jcp.typesize_bia * load_block);
372     }
373     dec(reg_oc_block);
374     cmp(reg_oc_block, 0);
375     jg(oc_block_label, T_NEAR);
376
377     sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
378     sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block);
379
380     postamble();
381
382 }
383
384 /// GEMM kernel ////////////////////////////////////////////////////////////////
385 struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator {
386     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t)
387     jit_conv_conf_2x3_wino_t jcp;
388
389     struct call_params_t {
390         const void *src;
391         const void *dst;
392         const void *wei;
393         const void *dst_b;
394     };
395     void (*ker_)(const call_params_t *);
396
397     void generate();
398     static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
399                             const primitive_attr_t &attr);
400
401     jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
402             jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
403         : jcp(ajcp) {
404         generate();
405         ker_ = reinterpret_cast<decltype(ker_)>(
406                 const_cast<uint8_t *>(getCode()));
407     }
408
409     static status_t init_conf(
410             jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
411             cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
412             cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
413             const primitive_attr_t &attr,
414             memory_desc_t& expect_wei_md);
415
416     Zmm vreg_out(int n, int m) {
417         const int id_reg_out = n * jcp.m_block + m;
418         assert(id_reg_out < jcp.n2_block * jcp.m_block);
419         return Zmm(31 - id_reg_out);
420     }
421     Zmm vreg_wei(int i) {
422         assert (31 - jcp.n2_block * jcp.m_block - i > 1);
423         return Zmm(31 - jcp.n2_block * jcp.m_block - i);
424     }
425
426     Zmm vreg_src = Zmm(0);
427     Zmm vreg_one = Zmm(1);
428     Zmm vreg_tmp = Zmm(2);
429
430     Reg64 reg_ptr_src = r15;
431
432     Reg64 reg_aux_dst = r12;
433     Reg64 reg_aux_dst2 = r11;
434     Reg64 reg_aux_wei = r10;
435     Reg64 reg_aux_wei2 = r9;
436     Reg64 reg_aux_src = r8;
437     Reg64 reg_aux_src2 = rax;
438
439     Reg64 reg_mb = rbx;
440     Reg64 reg_nnb = rdx;
441     Reg64 reg_K = rsi;
442
443 };
444
445 bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok(
446         jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
447     using namespace primitive_kind;
448     const auto &p = attr.post_ops_;
449
450     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
451
452     switch (p.len_) {
453     case 0: return true;
454     case 1: return is_relu(0) || p.contain(sum, 0);
455     case 2: return (p.contain(sum, 0) && is_relu(1)) ||
456                        (p.contain(sum, 1) && is_relu(0));
457     case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
458     default: return false;
459     }
460
461     return false;
462 }
463
464 void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() {
465     Label nnb_loop_label, K_loop_label, mb_loop_label;
466
467     preamble();
468 #define READ_PARAM(reg, field) \
469     mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
470     READ_PARAM(reg_ptr_src, src);
471     READ_PARAM(reg_aux_dst, dst);
472     READ_PARAM(reg_aux_wei, wei);
473 #undef READ_PARAM
474
475     if (!jcp.small_mb) {
476         mov(reg_nnb, jcp.n_chunks);
477         L(nnb_loop_label);
478     }
479     mov(reg_aux_dst2, reg_aux_dst);
480     mov(reg_aux_src, reg_ptr_src);
481     mov(reg_mb, jcp.M / jcp.m_block);
482     L(mb_loop_label);
483     {
484         int nb2 = 0;
485         for (nb2 = 0; nb2 < jcp.n2_block; nb2++) {
486             for (int m = 0; m < jcp.m_block; m++) {
487                 vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m));
488             }
489         }
490         mov(reg_aux_src2, reg_aux_src);
491         mov(reg_aux_wei2, reg_aux_wei);
492
493         mov(reg_K, jcp.k_chunks);
494         L(K_loop_label); {
495             int wei_offset = 0;
496             for (int _i = 0; _i < jcp.k2_block; _i++) {
497                 for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
498                     if (jcp.small_mb) {
499                         int wei_offset = sizeof(float)
500                                 * ((nb2 * jcp.nb_ic * jcp.ic_block
501                                            * jcp.oc_block)
502                                           + _i * jcp.oc_block);
503                         vmovups(vreg_wei(nb2),
504                                 EVEX_compress_addr(reg_aux_wei2, wei_offset));
505                     } else {
506                         vmovups(vreg_wei(nb2),
507                                 EVEX_compress_addr(reg_aux_wei2,
508                                         sizeof(float) * wei_offset));
509                         wei_offset += jcp.oc_block;
510                     }
511                 }
512                 for (int m = 0; m < jcp.m_block; m++) {
513                     int inp_offset = sizeof(float) * (m * jcp.K + _i);
514                     if (jcp.n2_block > 1) {
515                         vbroadcastss(vreg_src,
516                             EVEX_compress_addr(reg_aux_src2, inp_offset));
517                         for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
518                             vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2),
519                                 vreg_src);
520                     } else {
521                         vfmadd231ps(vreg_out(0, m), vreg_wei(0),
522                             EVEX_compress_addr(reg_aux_src2, inp_offset, true));
523                     }
524                 }
525             }
526             add(reg_aux_src2, sizeof(float) * jcp.ic_block);
527             if (jcp.small_mb)
528                 add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block);
529             else
530                 add(reg_aux_wei2,
531                         sizeof(float) * jcp.k2_block * jcp.n2_block
532                                 * jcp.oc_block);
533         }
534         dec(reg_K);
535         cmp(reg_K, 0);
536         jg(K_loop_label, T_NEAR);
537
538         for (int m = 0; m < jcp.m_block; m++) {
539             int nb2 = 0;
540             for (nb2 = 0; nb2 < jcp.n2_block; nb2++) {
541                 int offset = sizeof(float) *
542                     (m * jcp.N + nb2 * jcp.oc_block);
543                 vmovups(EVEX_compress_addr(reg_aux_dst2,offset),
544                             vreg_out(nb2, m));
545             }
546         }
547         add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K);
548         add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N);
549     }
550     dec(reg_mb);
551     cmp(reg_mb, 0);
552     jg(mb_loop_label, T_NEAR);
553
554     if (!jcp.small_mb) {
555         add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block);
556         add(reg_aux_wei,
557                 sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block
558                         * jcp.oc_block);
559
560         dec(reg_nnb);
561         cmp(reg_nnb, 0);
562         jg(nnb_loop_label, T_NEAR);
563     }
564     postamble();
565 }
566
567 namespace {
568 bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
569     return jcp.mb >= 4;
570 }
571 }
572
573 status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::init_conf(
574         jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
575         cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &wei_pd,
576         cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
577         const primitive_attr_t &attr, memory_desc_t &expect_wei_md) {
578     const memory_desc_wrapper src_d(&src_pd);
579     const memory_desc_wrapper wei_d(&wei_pd);
580     const memory_desc_wrapper dst_d(&dst_pd);
581     const memory_desc_wrapper bias_d(&bias_pd);
582
583     const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
584
585     jcp.nthr = mkldnn_get_max_threads();
586
587     jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
588     jcp.mb = src_d.dims()[0];
589     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
590     jcp.oc_without_padding = jcp.oc;
591     jcp.ic = src_d.dims()[1] / jcp.ngroups;
592     jcp.ih = src_d.dims()[2];
593     jcp.iw = src_d.dims()[3];
594     jcp.oh = dst_d.dims()[2];
595     jcp.ow = dst_d.dims()[3];
596     jcp.kh = wei_d.dims()[with_groups + 2];
597     jcp.kw = wei_d.dims()[with_groups + 3];
598     jcp.t_pad = cd.padding[0][0];
599     jcp.b_pad = cd.padding[1][0];
600     jcp.l_pad = cd.padding[0][1];
601     jcp.r_pad = cd.padding[1][1];
602     jcp.stride_h = cd.strides[0];
603     jcp.stride_w = cd.strides[1];
604     jcp.dilate_h = cd.dilates[0];
605     jcp.dilate_w = cd.dilates[1];
606
607     jcp.m = 2;
608     jcp.r = 3;
609     jcp.alpha = jcp.m + jcp.r - 1;
610     int simdw = 16;
611     jcp.src_fmt = src_d.format();
612     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
613
614     if (!post_ops_ok(jcp, attr))
615         return status::unimplemented;
616
617     bool ok_to_pad_channels = jcp.ngroups == 1;
618     if (ok_to_pad_channels) {
619         jcp.oc = rnd_up(jcp.oc, simdw);
620         jcp.ic = rnd_up(jcp.ic, simdw);
621     }
622
623     if (src_d.format() != nChw16c
624             || dst_d.format() != nChw16c
625             || !IMPLICATION(jcp.with_bias,
626                 bias_d.format() == x))
627         return status::unimplemented;
628
629     jcp.ver = ver_avx512_core;
630     if (!(mayiuse(avx512_core)))
631         return status::unimplemented;
632
633     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
634                is_winograd_faster_than_direct(jcp)))
635         return status::unimplemented;
636
637     if (src_d.data_type() != data_type::f32)
638         return status::unimplemented;
639     if (wei_d.data_type() != data_type::f32)
640         return status::unimplemented;
641     if (dst_d.data_type() != data_type::f32)
642         return status::unimplemented;
643
644     if (mayiuse(avx512_core_vnni))
645         jcp.ver = ver_vnni;
646
647     jcp.ic_block = simdw;
648     jcp.oc_block = simdw;
649
650     bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
651             && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
652             && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0
653             && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad
654             && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0
655             && jcp.l_pad < 2 && jcp.l_pad >= 0;
656     if (!ok)
657         return status::unimplemented;
658
659     const int L2_cap = get_cache_size(2, true) / sizeof(float);
660     const int L3_capacity = get_cache_size(3, false) / sizeof(float);
661     int a = jcp.alpha;
662     int aa = a * a;
663     int mb = jcp.mb;
664     int ic = jcp.ic;
665     int oc = jcp.oc;
666     int ih = jcp.ih;
667     int iw = jcp.iw;
668     auto wei_sz = (float)aa * ic * oc;
669     auto inp_sz = (float)mb * ih * iw * ic;
670     auto sp_sz = (float)mb * ih * iw;
671
672     /* Heuristics here. Numbers '28','196' is an observation from data. */
673     if (wei_sz / inp_sz > 5)
674         jcp.small_mb = true;
675     else
676         jcp.small_mb = false;
677
678     if (mb > nstl::min(jcp.nthr, 28)
679         || (!jcp.small_mb
680             && (wei_sz >= 0.9f * L2_cap
681                 || inp_sz > L2_cap * jcp.nthr + L3_capacity))
682         || (jcp.small_mb && sp_sz > 196))
683         return unimplemented;
684
685     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
686     jcp.dst_dt = cd.dst_desc.data_type;
687
688     jcp.typesize_bia
689             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
690
691     jcp.nb_oc = jcp.oc / jcp.oc_block;
692     jcp.nb_ic = jcp.ic / jcp.ic_block;
693
694     const int skx_free_regs = 30;
695
696     auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block,
697                                     int &n2_block, float &reg_eff) {
698         M = (xb * yb) / jcp.alpha;
699         int max_m_block = m_block = nstl::min(M, skx_free_regs);
700         int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs);
701         reg_eff = 0;
702         for (int im = max_m_block; im > 0; im--) {
703             for (int in2 = max_n2_block; in2 > 0; in2--) {
704                 int used_regs = in2 * im + in2;
705                 float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f;
706                 if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs
707                         || cur_reg_eff <= reg_eff)
708                     continue;
709                 reg_eff = cur_reg_eff;
710                 m_block = im;
711                 n2_block = in2;
712             }
713         }
714     };
715
716     int oh = jcp.oh;
717     int ow = jcp.ow;
718     int nb_oc = jcp.nb_oc;
719     int Z = ic + oc;
720     int Y = ic * oc;
721     const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float);
722
723     /* Selecting xb and yb blocking */
724     int min_yb = jcp.alpha;
725     int min_xb = jcp.alpha;
726     int max_yb = nstl::max(min_yb, rnd_up(ih, 2));
727     int max_xb = nstl::max(min_xb, rnd_up(iw, 2));
728     float best_eff = 0.f;
729     for (int ix = max_xb; ix >= min_xb; ix -= 2) {
730         if (rnd_up(ow, ix) < iw - 2)
731             continue;
732         for (int iy = max_yb; iy >= min_yb; iy -= 2) {
733             if (rnd_up(oh, iy) < ih - 2)
734                 continue;
735             int ex_y = rnd_up(oh, iy);
736             int ex_x = rnd_up(ow, ix);
737             float work_eff = (float)(ih * iw) / (ex_y * ex_x);
738
739             int M, m_block, n2_b;
740             float reg_eff, thr_eff, par_eff, mem_eff, req_mem;
741
742             find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff);
743
744             /* outer parallelization */
745             int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
746             thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
747
748             mem_eff = 1.f;
749             req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
750             if (req_mem > L2_cap / 2) {
751                 if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7)
752                     mem_eff /= (n2_b + 1) / 2.f;
753                 else
754                     mem_eff /= (n2_b + 1) / 3.f;
755             }
756
757             float outer_eff = thr_eff + work_eff + reg_eff + mem_eff;
758
759             /* inner parallelization */
760             int bsz = iy * ix / a;
761             int gemmw = aa * (nb_oc / n2_b);
762             int bsz_r = rnd_up(bsz, jcp.nthr);
763             int gemmw_r = rnd_up(gemmw, jcp.nthr);
764             thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);
765
766             req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
767             mem_eff = nstl::min(1.f, L2_cap / req_mem);
768             int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
769             int oc_per_thr =
770                 nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
771             req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
772             if (req_mem > L2_cap)
773                 mem_eff = 0.1f;
774             par_eff = 1 / (2.f * nblocks);
775
776             float inner_eff = thr_eff + work_eff + mem_eff + par_eff;
777
778             float eff = jcp.small_mb ? inner_eff : outer_eff;
779             if (eff > best_eff) {
780                 best_eff = eff;
781                 jcp.yb = iy;
782                 jcp.xb = ix;
783                 jcp.M = M;
784                 jcp.m_block = m_block;
785                 jcp.n2_block = n2_b;
786             }
787         }
788     }
789
790     assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
791
792     jcp.inp_stride = jcp.M * jcp.ic;
793     jcp.out_stride = jcp.M * jcp.oc;
794     jcp.wei_stride = jcp.ic * jcp.oc;
795     jcp.bia_stride = jcp.oc;
796
797     jcp.N = jcp.oc;
798     jcp.K = jcp.ic;
799
800     jcp.n_block = jcp.oc_block;
801     jcp.k_block = jcp.ic_block;
802
803     assert(jcp.M % jcp.m_block == 0);
804     assert(jcp.nb_oc % jcp.n2_block == 0);
805
806     jcp.n_chunks = jcp.nb_oc / jcp.n2_block;
807     jcp.k2_block = jcp.ic_block;
808     jcp.k_chunks = jcp.K / jcp.k2_block;
809
810     const auto &oscales = attr.output_scales_;
811     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
812     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
813
814     /* re-create weights primitive descriptor
815                                     and set weights wino_blocking */
816     expect_wei_md.format = mkldnn_wino_fmt;
817     expect_wei_md.data_type = data_type::f32;
818     mkldnn_wino_desc_t &wd = expect_wei_md.layout_desc.wino_desc;
819     wd.wino_format
820             = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo;
821     wd.r = jcp.r;
822     wd.alpha = jcp.alpha;
823     wd.ic = jcp.ic;
824     wd.oc = jcp.oc;
825     wd.ic_block = jcp.ic_block;
826     wd.oc_block = jcp.oc_block;
827     wd.oc2_block = jcp.n2_block;
828     wd.ic2_block = 1;
829     wd.adj_scale = 1.f;
830     size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
831     wd.size = max_size;
832
833     return status::success;
834 }
835 ////////////////////////////////////////////////////////////////////////////////
836
837 status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t
838     ::pd_t::jit_conf(memory_desc_t& expect_wei_md) {
839     return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
840             jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
841             this->dst_pd_,this->bias_pd_, *this->attr(), expect_wei_md);
842 }
843
844 jit_avx512_core_fp32_wino_conv_2x3_fwd_t::
845         jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd,
846                 const input_vector &inputs, const output_vector &outputs)
847     : cpu_primitive_t(apd, inputs, outputs)
848 {
849     kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
850             pd()->jcp_, *pd()->attr());
851     src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
852             pd()->jcp_, *pd()->attr());
853     dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
854             pd()->jcp_, *pd()->attr());
855 }
856
857 jit_avx512_core_fp32_wino_conv_2x3_fwd_t
858     ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
859     delete kernel_;
860     delete src_trans_;
861     delete dst_trans_;
862 }
863
864 void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward() const {
865     const auto &jcp = kernel_->jcp;
866
867     if (jcp.small_mb)
868         execute_forward_small_mb();
869     else
870         execute_forward_mbN();
871 }
872
873 void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN() const {
874     auto src = reinterpret_cast<const float *>(input_memory(0));
875     auto wei = reinterpret_cast<const float *>(input_memory(1));
876     auto bia = reinterpret_cast<const float *>(input_memory(2));
877     auto dst = reinterpret_cast<float *>(memory(0));
878
879     auto scratchpad = this->scratchpad();
880
881     const auto &jcp = kernel_->jcp;
882     const auto &oscales = pd()->attr()->output_scales_;
883
884     const size_t wino_size_offset =
885         (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb);
886     const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
887     const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;
888
889     if (pd()->wants_padded_bias()) {
890         auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
891         utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
892         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
893                 jcp.oc - jcp.oc_without_padding);
894         bia = padded_bias;
895     }
896
897     auto ptr_V = scratchpad.get<float>(key_wino_V);
898     auto ptr_M = scratchpad.get<float>(key_wino_M);
899
900     parallel_nd(jcp.mb, div_up(jcp.oh,jcp.yb), div_up(jcp.ow, jcp.xb),
901         [&](int mb, int tile_y_b, int tile_x_b) {
902         int tile_y = tile_y_b * jcp.yb;
903         int tile_x = tile_x_b * jcp.xb;
904
905         int ithr = mkldnn_get_thread_num();
906         auto wino_src = ptr_V + size_wino_src * ithr;
907         auto wino_dst = ptr_M + size_wino_dst * ithr;
908
909         auto src_trans_p =
910             jit_avx512_core_fp32_wino_conv_2x3_src_trans_t
911                 ::call_params_t();
912         auto dst_trans_p =
913             jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t
914                 ::call_params_t();
915         auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
916                 call_params_t();
917
918         /* transformation of input tensor to winograd domain */
919         for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
920             for (int x_in_block = 0; x_in_block < jcp.xb;
921                     x_in_block += 2) {
922
923                 unsigned short v_y_masks[4], v_x_masks[4];
924
925                 int y = y_in_block + tile_y;
926                 int x = x_in_block + tile_x;
927                 int m = (y_in_block / 2) * (jcp.xb / 2)
928                         + (x_in_block / 2);
929
930                 int v_ys = nstl::max(0, jcp.t_pad - y);
931                 int v_ye = nstl::min(jcp.alpha,
932                         nstl::max(0, jcp.ih + jcp.t_pad - y));
933
934                 int v_xs = nstl::max(0, jcp.l_pad - x);
935                 int v_xe = nstl::min(jcp.alpha,
936                         nstl::max(0, jcp.iw + jcp.l_pad - x));
937
938 #pragma unroll(4)
939                 for (int i = 0; i < jcp.alpha; i++) {
940                     v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
941                     v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
942                 }
943                 auto local_s = src
944                         + mb * jcp.nb_ic * jcp.ih * jcp.iw
945                                 * jcp.ic_block
946                         + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
947                 auto local_w = wino_src + m * jcp.ic;
948
949                 src_trans_p.src = local_s;
950                 src_trans_p.wino_src = local_w;
951                 src_trans_p.v_y_masks = v_y_masks;
952                 src_trans_p.v_x_masks = v_x_masks;
953
954                 src_trans_->ker_(&src_trans_p);
955             }
956         }
957         /* gemms */
958         for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
959             int offset = (tile_ij + ithr) % 16;
960             gemm_p.src = wino_src + jcp.inp_stride * offset;
961             gemm_p.dst = wino_dst + jcp.out_stride * offset;
962             gemm_p.wei = wei + jcp.wei_stride * offset;
963
964             kernel_->ker_(&gemm_p);
965         }
966
967         /* transformation from winograd domain to output tensor */
968         for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
969             for (int x_in_block = 0; x_in_block < jcp.xb;
970                     x_in_block += 2) {
971                 unsigned short v_y_masks[2], v_x_masks[2];
972
973                 int y = y_in_block + tile_y;
974                 int x = x_in_block + tile_x;
975                 int m = (y_in_block / 2) * (jcp.xb / 2)
976                         + (x_in_block / 2);
977
978 #pragma unroll(2)
979                 for (int i = 0; i < jcp.m; i++) {
980                     v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
981                     v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
982                 }
983                 auto local_d = dst
984                         + mb * jcp.nb_oc * jcp.oh * jcp.ow
985                                 * jcp.oc_block
986                         + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
987                 auto local_w = wino_dst + m * jcp.oc;
988
989                 auto scales = oscales.scales_;
990                 dst_trans_p.dst = local_d;
991                 dst_trans_p.wino_dst = local_w;
992                 dst_trans_p.v_y_masks = v_y_masks;
993                 dst_trans_p.v_x_masks = v_x_masks;
994
995                 dst_trans_p.scales = scales;
996                 dst_trans_p.bias = bia;
997
998                 dst_trans_->ker_(&dst_trans_p);
999             }
1000         }
1001     });
1002 }
1003
1004 void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb() const
1005 {
1006     auto src = reinterpret_cast<const float *>(input_memory(0));
1007     auto wei = reinterpret_cast<const float *>(input_memory(1));
1008     auto bia = reinterpret_cast<const float *>(input_memory(2));
1009     auto dst = reinterpret_cast<float *>(memory(0));
1010
1011     auto scratchpad = this->scratchpad();
1012
1013     const auto &jcp = kernel_->jcp;
1014     const auto &oscales = pd()->attr()->output_scales_;
1015
1016     if (pd()->wants_padded_bias()) {
1017         auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
1018         utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
1019         utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
1020                 jcp.oc - jcp.oc_without_padding);
1021         bia = padded_bias;
1022     }
1023
1024     auto ptr_V = scratchpad.get<float>(key_wino_V);
1025     auto ptr_M = scratchpad.get<float>(key_wino_M);
1026
1027     for (int mb = 0; mb < jcp.mb; mb++) {
1028     for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
1029     for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
1030         /* transformation of input tensor to winograd domain */
1031         parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
1032             [&](int y_in_block_b, int x_in_block_b) {
1033             int y_in_block = y_in_block_b * 2;
1034             int x_in_block = x_in_block_b * 2;
1035
1036             auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t ::
1037                     call_params_t();
1038
1039             unsigned short v_y_masks[4], v_x_masks[4];
1040
1041             int y = y_in_block + tile_y;
1042             int x = x_in_block + tile_x;
1043             int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
1044
1045             int v_ys = nstl::max(0, jcp.t_pad - y);
1046             int v_ye = nstl::min(
1047                     jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
1048
1049             int v_xs = nstl::max(0, jcp.l_pad - x);
1050             int v_xe = nstl::min(
1051                     jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
1052
1053 #pragma unroll(4)
1054             for (int i = 0; i < jcp.alpha; i++) {
1055                 v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
1056                 v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
1057             }
1058             auto local_s = src
1059                     + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block
1060                     + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
1061             auto local_w = ptr_V + m * jcp.ic;
1062
1063             src_trans_p.src = local_s;
1064             src_trans_p.wino_src = local_w;
1065             src_trans_p.v_y_masks = v_y_masks;
1066             src_trans_p.v_x_masks = v_x_masks;
1067
1068             src_trans_->ker_(&src_trans_p);
1069         });
1070
1071         /* gemms */
1072         parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
1073             auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
1074                     call_params_t();
1075
1076             gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
1077             gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
1078                     + nnb * jcp.n2_block * jcp.n_block;
1079             gemm_p.wei = wei + jcp.wei_stride * tile_ij
1080                     + nnb * jcp.n2_block * jcp.n_block * jcp.K;
1081
1082             kernel_->ker_(&gemm_p);
1083         });
1084
1085         /* transformation from winograd domain to output tensor */
1086
1087         parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
1088             [&](int y_in_block_b, int x_in_block_b) {
1089             int y_in_block = y_in_block_b * 2;
1090             int x_in_block = x_in_block_b * 2;
1091
1092             auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t ::
1093                     call_params_t();
1094
1095             unsigned short v_y_masks[2], v_x_masks[2];
1096
1097             int y = y_in_block + tile_y;
1098             int x = x_in_block + tile_x;
1099             int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);
1100
1101 #pragma unroll(2)
1102             for (int i = 0; i < jcp.m; i++) {
1103                 v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
1104                 v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
1105             }
1106             auto local_d = dst
1107                     + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block
1108                     + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
1109             auto local_w = ptr_M + m * jcp.oc;
1110
1111             auto scales = oscales.scales_;
1112             dst_trans_p.dst = local_d;
1113             dst_trans_p.wino_dst = local_w;
1114             dst_trans_p.v_y_masks = v_y_masks;
1115             dst_trans_p.v_x_masks = v_x_masks;
1116
1117             dst_trans_p.scales = scales;
1118             dst_trans_p.bias = bia;
1119
1120             dst_trans_->ker_(&dst_trans_p);
1121         });
1122     }}}
1123 }
1124
1125 } // namespace cpu
1126 } // namespace impl
1127 } // namespace mkldnn