Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_common_conv_winograd_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include <math.h>
25
26 #include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
27
28 #ifndef KERNEL_SIZE_THRESHOLD
29 #define KERNEL_SIZE_THRESHOLD 16
30 #endif
31
32 #define MIN_REQUIRED_DIMN_REG_BLOCK 14
33
34 namespace mkldnn {
35 namespace impl {
36 namespace cpu {
37
38 namespace {
39
40 using namespace mkldnn::impl::utils;
41
/* Cache sizes queried once at static-init time and used by the blocking
 * heuristics (check_cond*) below.
 * NOTE(review): the second get_cache_size() argument presumably selects
 * per-core (true) vs. total (false) capacity — confirm against
 * get_cache_size()'s definition. */
unsigned int L1_cache_size = get_cache_size(1, true);
unsigned int L2_cache_size = get_cache_size(2, true);
unsigned int LLC_data_size = get_cache_size(3, false);
45
46 // the test funtion takes jcp, the candidate and the current best.
47 // it  returns true if the new candidate is better
48 int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
49         int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
50 {
51     int best_divisor = default_best;
52     auto test_num
53             = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
54                   if (test(jcp, num, best_divisor)) {
55                       best_divisor = num;
56                   }
57               };
58
59     for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
60         if (number % divisor == 0) {
61             test_num(jcp, divisor);
62             test_num(jcp, number / divisor);
63         }
64     }
65
66     return best_divisor;
67 }
68
69 namespace {
70 bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
71     if (jcp.ver == ver_4fma)
72         return jcp.mb >= 32;
73     else
74         return jcp.mb >= 16;
75 }
76 }
77
78 /* assumes 512 bits registers */
79 /* TODO: add support for strides */
80 /* TODO: handle the prefetch distance automatically */
81 typedef enum cache_t_ { L1, L2, L3 } cache_t;
82
/* Emits software prefetches for a block of `block_size` elements reachable
 * from `reg_base_addr`, spreading them across a stream of
 * `nb_instructions_in_block` generated instructions so that loads arrive
 * ahead of use. `cache_type` selects which prefetch hint (t0/t1/t2) is
 * emitted. This is a code-generation-time helper: prefetch() is called once
 * per emitted fma-group, not at kernel run time. */
template <typename data_t>
struct prefetcher_t {
    prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
            cache_t cache_type, size_t block_size, /* in number of elements*/
            int nb_instructions_in_block, int fma_ipc)
        : cg_(generator)
        , reg_base_addr_(reg_base_addr)
        , cache_type_(cache_type)
        , cache_block_size_(block_size)
    {
        /* 64-byte cache lines: elements per line = 64 / sizeof(data_t) */
        nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
        /* instructions between consecutive prefetch bursts */
        prefetch_spread_
                = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
        /* prefetches emitted per burst */
        prefetch_blk_
                = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);

        /* assumption: when fetch in Li, data is already in L(i+1) */
        int cache_latency;
        switch (cache_type_) {
        case L1: cache_latency = 14; break;
        case L2:
        case L3:
        default: cache_latency = 250; break;
        }

        /* how many blocks ahead to prefetch so data lands in time */
        prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
    }

    /* Called once per generated instruction (indexed by
     * `instruction_number`); emits a burst of up to prefetch_blk_
     * prefetches every prefetch_spread_ instructions, until all cache
     * lines of the block prefetch_distance_ blocks ahead are covered. */
    void prefetch(int instruction_number)
    {
        if (instruction_number % prefetch_spread_ == 0) {
            for (int i = 0; (i < prefetch_blk_)
                    && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
                    i++, prefetches_issued_++) {
                prefetch_inst_(cg_->EVEX_compress_addr(
                        reg_base_addr_, (cache_block_size_ * prefetch_distance_)
                                        * sizeof(data_t)
                                + (prefetches_issued_ * 64)));
            }
        }
    }

private:
    /* Maps the target cache level to the matching x86 prefetch hint. */
    void prefetch_inst_(const Xbyak::Address &addr)
    {
        switch (cache_type_) {
        case L1: cg_->prefetcht0(addr); break;
        case L2: cg_->prefetcht1(addr); break;
        case L3: cg_->prefetcht2(addr); break;
        default:
            break; // TODO: raise an exception or put an assert
        }
    }

    jit_generator *cg_;            // generator the prefetches are emitted into
    Xbyak::Reg64 reg_base_addr_;   // runtime base address of the block
    cache_t cache_type_;
    int cache_block_size_ = 0;     // block size in elements
    int nb_cache_lines_to_prefetch_ = 0;
    int prefetches_issued_ = 0;    // lines covered so far for current block
    int prefetch_spread_ = 0;      // instructions between bursts
    int prefetch_blk_ = 0;         // prefetches per burst
    int prefetch_distance_ = 0;    // look-ahead, in blocks
};
147
148 // utilities to support kernel parameter selection
149 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
150         int dimM_block, int dimM_simd_block, float C)
151 {
152     float lhs = (dimM_block * dimN_reg_block * dimM_simd_block
153                         + dimM_block * dimK_block * dimK_reg_block
154                                 * dimM_simd_block
155                         + dimK_block * dimN_reg_block * dimK_reg_block)
156             * (float)sizeof(float);
157     float rhs = C * L1_cache_size;
158     return (lhs < rhs);
159 }
160
161 bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
162         int dimM_block, int dimM_simd_block, float C)
163 {
164     float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block
165                         + dimK_block * dimN_reg_block * dimK_reg_block)
166             * (float)sizeof(float);
167     float rhs = C * L1_cache_size;
168     return (lhs < rhs);
169 }
170
171 bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
172         int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block,
173         float C)
174 {
175     float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block
176                       + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
177                               * dimM_simd_block
178                       + nb_dimN_reg_block * dimK_nb_block * dimK_block
179                               * dimN_reg_block * dimK_reg_block)
180             * (float)sizeof(float);
181     float rhs = C * L2_cache_size;
182     return (lhs < rhs);
183 }
184 }
185
186 using namespace mkldnn::impl::memory_format;
187 using namespace mkldnn::impl::utils;
188 using namespace Xbyak;
189
/* Generates the inner GEMM kernel C += A * B (or C = A * B when
 * `is_beta_zero`) used for the winograd data transforms. The dimN_reg_block
 * accumulator tiles live in zmm registers from jcp.zmm_start upward; A is
 * double-buffered in the registers below jcp.zmm_start. */
void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
        bool is_beta_zero)
{
    // const int dimK_simd_block = jcp.dimK_reg_block;

    // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
    //     for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
    //         for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
    //         dimK_reg_block++)
    //                 for (int tile =0; tile < jcp.dimN_reg_block; tile++)
    //                     C[dimM_block][tile] +=
    //                     A[dimM_block][dimK_block][dimK_reg_block] *
    //                     broadcast(B[dimK_block][tile][dimK_reg_block]);
    // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block],
    // so we load it before the loop on tile
    // 2) the loop on tile must be fully unrolled. Don't know about the one on
    // dimK_reg_block. I think it should be

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;
        // 4fma consumes 4 consecutive A vectors per instruction
        const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1;
        const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;

        /* prefetch the next B panel into L1 and L2 while the fmas of the
         * current panel are being emitted */
        prefetcher_t<float> L1_pf(this, reg_srcB, L1,
                jcp.dimN_reg_block * jcp.dimK_reg_block,
                jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
                fma_ipc);
        prefetcher_t<float> L2_pf(this, reg_srcB, L2,
                jcp.dimN_reg_block * jcp.dimK_reg_block,
                jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
                fma_ipc);

        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }
        {
            // First, we zero the accumulators if first nb_ic iteration,
            // otherwise we load them
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm(jcp.zmm_start + tile);
                if (is_beta_zero)
                    vpxord(zmm, zmm, zmm);
                else
                    vmovups(zmm, zword[reg_dstC + 64 * tile]);
            }

            if (jcp.dimK_block > 1) {
                mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                L(dimK_block_loop);
            }
            {
                /* loads inc_dimK_reg_block consecutive A vectors starting at
                 * zmm reg_idx */
                auto load_A = [=](int reg_idx, int offset) {
                    for (int i = 0; i < inc_dimK_reg_block; i++)
                        vmovups(Zmm(reg_idx + i),
                                zword[reg_srcA + 64 * (offset + i)]);
                };

                // Used when doing double buffering
                int next = 0;
                if (jcp.double_buffering) {
                    load_A(next, 0);
                }
                for (int dimK_reg_block = 0;
                        dimK_reg_block < jcp.dimK_reg_block;
                        dimK_reg_block += inc_dimK_reg_block) {
                    int current;
                    /* Loading the next vector from A */
                    current = next;
                    if (jcp.double_buffering) {
                        /* ping-pong between the two register buffers */
                        next = (dimK_reg_block + inc_dimK_reg_block)
                                % (2 * inc_dimK_reg_block);
                        load_A(next, dimK_reg_block + inc_dimK_reg_block);
                    } else {
                        next = 0;
                        load_A(next, dimK_reg_block);
                    }
                    /* Performing the fmas */
                    for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                        Zmm zmm(jcp.zmm_start + tile);
                        if (jcp.ver != ver_avx512_core)
                            L1_pf.prefetch(
                                    dimK_reg_block * jcp.dimN_reg_block + tile);
                        if (jcp.ver == ver_4fma)
                            v4fmaddps(zmm, Zmm(current),
                                    EVEX_compress_addr(reg_srcB,
                                              64 * tile + dimK_reg_block * 4));
                        else
                            /* last arg `true`: embedded broadcast of the
                             * B scalar */
                            vfmadd231ps(zmm, Zmm(current),
                                    EVEX_compress_addr(reg_srcB,
                                                64 * tile + dimK_reg_block * 4,
                                                true));
                        if (jcp.ver != ver_avx512_core)
                            L2_pf.prefetch(
                                    dimK_reg_block * jcp.dimN_reg_block + tile);
                    }
                }

                /* advance A and B to the next dimK block */
                add(reg_srcA, jcp.dimK_reg_block * 64);
                add(reg_srcB, jcp.dimN_reg_block * 64);
                if (jcp.dimK_block > 1) {
                    sub(reg_dimK_block_loop_cnt, 1);
                    jnz(dimK_block_loop);
                }
            }


            /* streaming (non-temporal) stores when the output is aligned,
             * written only once (dimK_nb_block == 1), and too large to be
             * reused from the LLC; regular stores otherwise */
            auto store_output = [=](bool output_is_aligned) {
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    Zmm zmm(jcp.zmm_start + tile);
                    if (output_is_aligned
                        && jcp.dimK_nb_block == 1
                        && (jcp.dimN * jcp.dimM * alpha * alpha
                            * sizeof(float) > 2 * LLC_data_size))
                        vmovntps(zword[reg_dstC + 64 * tile], zmm);
                    else
                        vmovups(zword[reg_dstC + 64 * tile], zmm);
                }
            };

            /* runtime dispatch on dstC alignment (vmovntps requires it) */
            Label unaligned_store, end_store;
            test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1);
            jnz(unaligned_store, T_NEAR);
            store_output(true);
            jmp(end_store, T_NEAR);
            L(unaligned_store); {
                store_output(false);
            }
            L(end_store);

            if (jcp.dimM_block > 1) {
                /* rewind B to its start; move C to the next M block */
                sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
                add(reg_dstC, jcp.dimN_reg_block * 64);
                sub(reg_dimM_block_loop_cnt, 1);
                jnz(dimM_block_loop);
            }
        }
    };

    /* Preamble */
    preamble();

    /* kernel */
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
339
/* Shared configuration for the forward and backward-data winograd kernels:
 * selects the ISA variant, fills jcp from the convolution and memory
 * descriptors, and rejects every shape/layout these kernels do not handle.
 * Returns status::unimplemented on any unsupported case. */
status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d)
{

    /* avx512_core machines are explicitly excluded here (a different
     * implementation covers them); 4fma is used when available. */
    if (mayiuse(avx512_core))
        return status::unimplemented;
    else if (!mayiuse(avx512_common))
        return status::unimplemented;
    else if (mayiuse(avx512_mic_4ops))
        jcp.ver = ver_4fma;
    else
        jcp.ver = ver_fma;

    jcp.nthr = mkldnn_get_max_threads();

    /* grouped weights carry one extra leading (groups) dimension */
    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];
    /* right/bottom padding implied by the requested output size */
    jcp.r_pad = nstl::max(
            0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
    jcp.b_pad = nstl::max(
            0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;

    /* channels may be rounded up to the SIMD width only without groups */
    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    /* under convolution_auto, use winograd only when the heuristic says it
     * beats the direct algorithm */
    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;

    // Checking conditions not supported by these kernels
    if (jcp.ngroups != 1)
        return status::unimplemented;
    if ((jcp.kh != 3) || (jcp.kw != 3))
        return status::unimplemented;
    if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
        return status::unimplemented;
    if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
        return status::unimplemented;
    if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
        return status::unimplemented;

    /* only blocked-by-16 layouts are supported */
    if (src_d.format() != nChw16c)
        return status::unimplemented;
    if (weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
        return status::unimplemented;
    if (dst_d.format() != nChw16c)
        return status::unimplemented;

    /* (possibly padded) channel counts must fit the descriptor padding */
    bool layout_consistency = true
        && jcp.ic <= src_d.blocking_desc().padding_dims[1]
        && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
        && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
        && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
    if (!layout_consistency) return status::unimplemented;

    return status::success;
}
423
424
/* Chooses the GEMM blocking parameters (dimN/dimK/dimM register and cache
 * blocks) for the W_S_G_D schedule so the working sets satisfy the L1/L2
 * capacity conditions checked by check_cond1/check_cond1_bis/check_cond2.
 * Always succeeds and sets jcp.sched_policy = WSCHED_DATA_W_S_G_D. */
status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {

    /* prefer the smallest dimN divisor that still provides enough
     * independent fmas (>= MIN_REQUIRED_DIMN_REG_BLOCK) and fits in the
     * available zmm registers */
    auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
            int dimN_reg_block, int current_best) {
        return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
                && (dimN_reg_block < jcp.nb_reg)
                && (dimN_reg_block < current_best);
    };
    jcp.dimN_reg_block = get_divisor_satisfying_cond(
            jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);

    /* fallback: no divisor met the minimum, so take the largest divisor
     * that still fits in the registers */
    if (jcp.dimN_reg_block >= jcp.nb_reg) {
        auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
                int dimN_reg_block, int current_best) {
            return (dimN_reg_block < jcp.nb_reg)
                    && (dimN_reg_block > current_best);
        };

        jcp.dimN_reg_block = get_divisor_satisfying_cond(
                jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
    }

    //********************* Choosing dimK_block **********************//
    auto test_cond1_dimK_block = [](
            jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
        return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
                       1, jcp.dimM_simd_block, .75f)
                && (dimK_block > current_best);
    };

    auto test_cond1_bis_dimK_block = [](
            jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
        return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
                       jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f)
                && (dimK_block > current_best);
    };

    /* first try the streaming-store condition (C excluded from L1) */
    jcp.dimK_block = get_divisor_satisfying_cond(
            jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
    // If we are not able to use streams, we fall back to condition [1]
    if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
        jcp.dimK_block = get_divisor_satisfying_cond(
                jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
    jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;

    //********************* Choosing dimM_block **********************//
    jcp.dimM_simd_block = 16;
    /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/
    auto test_cond1_dimM_block = [](
            jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
        return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
                       jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f)
                && (dimM_block > current_best);
    };

    auto test_cond1_bis_dimM_block = [](
            jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
        return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
                       jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f)
                && (dimM_block > current_best);
    };

    /* pick the condition matching how dimK_block was chosen above */
    if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
        jcp.dimM_block = get_divisor_satisfying_cond(
                jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
    else
        jcp.dimM_block = get_divisor_satisfying_cond(jcp,
                jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block);
    jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;

    //******************* Choosing dimN_block *******************//
    /* largest dimN_block whose full working set still fits in L2 */
    auto test_cond2_dimN_block = [](
            jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
        return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
                       jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
                       jcp.dimM_simd_block, .5f)
                && (dimN_block > current_best);
    };

    jcp.dimN_block = get_divisor_satisfying_cond(
            jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
    jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
    jcp.sched_policy = WSCHED_DATA_W_S_G_D;
    return status::success;
}
510
511 status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
512         jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
513 {
514     jcp.dimK_reg_block = 16;
515     jcp.dimM_simd_block = 16;
516
517     // TODO: replace double buffering with nuple buffering to maximize register
518     // usage.
519     // the choice of the number of buffers will then come after choosing
520     // dimN_reg_block
521     jcp.double_buffering = true;
522     if (jcp.double_buffering)
523         jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2);
524     else
525         jcp.zmm_start = 1;
526     jcp.nb_reg = 32 - jcp.zmm_start;
527
528     jcp.dimN = dimN;
529     jcp.dimK = dimK;
530     jcp.dimM = dimM;
531
532     jcp.sched_policy = WSCHED_INVALID;
533     set_wsched_DATA_W_S_G_D_avx512_common(jcp);
534
535     assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
536     return status::success;
537 }
538
539 bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
540         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
541     const auto &p = attr.post_ops_;
542
543     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
544     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
545
546     switch (p.len_) {
547     case 0: return true; // no post_ops
548     case 1: return is_relu(0) || is_sum(0); // relu or sum
549     case 2: return (is_sum(0) && is_relu(1)) ||
550                        (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
551     case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
552     default: return false;
553     }
554
555     return false;
556 }
557
/* Forward-convolution configuration: runs the common checks, computes the
 * winograd tile counts, validates the fused post-ops, selects the GEMM
 * blocking (M = oc, N = ntiles, K = ic), and maps the generic GEMM
 * dimensions back onto the convolution-specific jcp fields. */
status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
    status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);

    if (st != status::success)
        return st;

    // Winograd specific initialization
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    /* record the first eltwise (if any) and whether a sum is fused */
    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
    jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;

    /* forward pass: GEMM M = oc, N = ntiles, K = ic */
    status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
    jcp.ic_simd_block = jcp.dimK_reg_block;
    jcp.ic_block = jcp.dimK_block;
    jcp.nb_ic = jcp.dimK_nb_block;
    jcp.oc_simd_block = jcp.dimM_simd_block;
    jcp.oc_block = jcp.dimM_block;
    jcp.nb_oc = jcp.dimM_nb_block;
    jcp.tile_block_ur = jcp.dimN_reg_block;
    jcp.nb_tile_block_ur = jcp.dimN_block;
    jcp.tile_block = jcp.dimN_nb_block;
    jcp.tile_4fma_padding = 0; // only relevant for backward weights
    
    return res;
}
597
/* Backward-data configuration: same as forward but with the roles of ic
 * and oc swapped in the GEMM (M = ic, N = ntiles, K = oc) and tiles
 * counted over the input (diff_src) spatial dimensions. */
status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &diff_src_d,
        const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &diff_dst_d)
{
    status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);

    if (st != status::success)
        return st;

    /* tiles cover the input (the tensor being computed here) */
    jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    /* backward data: GEMM M = ic, N = ntiles, K = oc */
    status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
    jcp.oc_simd_block = jcp.dimK_reg_block;
    jcp.oc_block = jcp.dimK_block;
    jcp.nb_oc = jcp.dimK_nb_block;
    jcp.ic_simd_block = jcp.dimM_simd_block;
    jcp.ic_block = jcp.dimM_block;
    jcp.nb_ic = jcp.dimM_nb_block;
    jcp.tile_block_ur = jcp.dimN_reg_block;
    jcp.nb_tile_block_ur = jcp.dimN_block;
    jcp.tile_block = jcp.dimN_nb_block;
    jcp.tile_4fma_padding = 0; // only relevant for backward weights

    return res;
}
627
/* Generates the B-matrix transpose kernel for backward weights: for each of
 * the alpha*alpha winograd components, transposes 4xdimN_reg_block float
 * panels from origB into transB using zmm unpack shuffles, with two
 * 4-register buffers (zmm0-3 / zmm4-7) double-buffered to hide load
 * latency; zmm8/zmm9 are shuffle scratch. Results are written with
 * non-temporal stores. */
void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate()
{
    /* loads 4 consecutive rows of B into zmm reg_idx..reg_idx+3 */
    auto load_B = [=](int reg_idx, int offset) {
        for (int i = 0; i < 4; i++) {
            vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]);
        }
    };

    preamble();
    int curr = 0;
    for (int j = 0; j < alpha; j++) {
        for (int i = 0; i < alpha; i++) {
            int origB_offset = (j * alpha + i) * jcp.dimK_4fma;
            size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block *
                jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block *
                jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float);
            mov(reg_transB_idx, transB_offset);
            for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) {
                /*double buffering to hide load latencies*/
                int next = (curr + 4) % 8;
                if (i == 0 && tb == 0) {
                    load_B(0, origB_offset);
                }
                /* preload the next 4 rows (of this component, or of the
                 * next component on the last iteration).
                 * NOTE(review): `origB_offset + 4` does not advance with
                 * tb, so this preload is only correct if this branch is
                 * taken at most once per component (i.e. dimK_4fma <= 8)
                 * — confirm against the values dimK_4fma can take. */
                if (tb + 4 < (jcp.dimK_4fma -1)) {
                    load_B(next, origB_offset + 4);
                } else if (i < alpha - 1) {
                    load_B(next, origB_offset + jcp.dimK_4fma);
                }

                /* 4x16 -> 16x4 transpose via unpack lo/hi ps then pd */
                vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1));
                vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3));
                vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1));
                vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3));

                vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9));
                vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9));

                vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1));
                vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1));

                /* streaming stores of the 4 transposed rows */
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * tb * jcp.dimN_reg_block],
                        Zmm(curr+2));
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * (tb + 1) * jcp.dimN_reg_block],
                        Zmm(curr+3));
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * (tb + 2) * jcp.dimN_reg_block],
                        Zmm(8));
                vmovntps(zword[reg_transB + reg_transB_idx
                        + sizeof(float) * (tb + 3) * jcp.dimN_reg_block],
                        Zmm(9));
                curr = next;

            }
        }
    }
    postamble();
    ret();
}
688 void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
689         bool is_first_tile)
690 {
691     // for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++)
692     //     for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++)
693     //             for (int nb_tile_block_ur = 0; nb_tile_block_ur <
694     //             jcp.nb_tile_block_ur; nb_tile_block_ur++)
695     //                 for (int tile_block_ur = 0; tile_block_ur <
696     //                 jcp.tile_block_ur; tile_block_ur++)
697     //                     for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3)
698     //                         U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] +=
699     //                             M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block]
700     //                              *
701     //                              broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur])
702     auto inner_loops = [=]() {
703         int inc_fma = jcp.ver == ver_4fma ? 4 : 1;
704         const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
705         prefetcher_t<float> L1_pf(this, reg_srcB, L1,
706                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
707                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
708                         / inc_fma,
709                 fma_ipc);
710         prefetcher_t<float> L2_pf(this, reg_srcB, L2,
711                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
712                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
713                         / inc_fma,
714                 fma_ipc);
715
716         auto load_A = [=](int reg_idx, int offset) {
717             for (int i = 0; i < inc_fma; i++) {
718                 vmovups(Zmm(reg_idx + i),
719                         zword[reg_srcA +
720                         sizeof(float) * jcp.dimM_simd_block * (offset + i)]);
721             }
722         };
723
724         Label dimM_block_loop, dimK_block_loop, dimN_block_loop;
725         if (jcp.dimM_block > 1) {
726             mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
727             L(dimM_block_loop);
728         }
729         { /************* OC_block (M) loop ***********/
730             if (jcp.dimN_block > 1) {
731                 mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
732                 L(dimN_block_loop);
733             }
734             { /*************** IC_block (N) loop *********/
735                 for (int dimN_reg_block = 0;
736                         dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
737                     Zmm zmm(jcp.zmm_start + dimN_reg_block);
738                     if (is_first_tile)
739                         vpxord(zmm, zmm, zmm);
740                     else
741                         vmovups(zmm, zword[reg_dstC +
742                                 dimN_reg_block * jcp.dimM_simd_block *
743                                 sizeof(float)]);
744                 }
745
746                 if (jcp.dimK_block > 1) {
747                     mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
748                     L(dimK_block_loop);
749                 }
750                 { /************* nb_tile_ur(K) loop ********/
751                     int next = 0;
752                     if (jcp.double_buffering) {
753                         load_A(next, 0);
754                     }
755                     for (int dimK_reg_block = 0;
756                             dimK_reg_block < jcp.dimK_reg_block;
757                             dimK_reg_block++) {
758                         int srcB_offset = dimK_reg_block * jcp.dimK_4fma
759                                 * jcp.dimN_reg_block;
760                         for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma;
761                                 dimK_4fma += inc_fma) {
762                             int current = next;
763                             if (jcp.double_buffering) {
764                                 next = (dimK_reg_block * jcp.dimK_4fma
765                                                + dimK_4fma + inc_fma)
766                                         % (2 * inc_fma);
767                                 load_A(next, dimK_reg_block * jcp.dimK_4fma
768                                                 + dimK_4fma + inc_fma);
769                             } else {
770                                 next = 0;
771                                 load_A(next, dimK_reg_block * jcp.dimK_4fma
772                                                 + dimK_4fma);
773                             }
774                             for (int dimN_reg_block = 0;
775                                     dimN_reg_block < jcp.dimN_reg_block;
776                                     ++dimN_reg_block) {
777                                 L1_pf.prefetch(srcB_offset / inc_fma
778                                         + dimK_4fma / inc_fma
779                                                 * jcp.dimN_reg_block
780                                         + dimN_reg_block);
781                                 L2_pf.prefetch(srcB_offset / inc_fma
782                                         + dimK_4fma / inc_fma
783                                                 * jcp.dimN_reg_block
784                                         + dimN_reg_block);
785                                 if (jcp.ver == ver_4fma) {
786                                     int srcB_trans_offset = (dimK_4fma / 4) * 64
787                                             + dimK_4fma % 4;
788                                     v4fmaddps(
789                                             Zmm(jcp.zmm_start + dimN_reg_block),
790                                             Zmm(current),
791                                             EVEX_compress_addr(reg_srcB,
792                                                     sizeof(float) * (
793                                                         srcB_offset +
794                                                         srcB_trans_offset +
795                                                         (dimN_reg_block % 4) * 16 +
796                                                         (dimN_reg_block / 4) * 4)));
797                                 } else {
798                                     vfmadd231ps(
799                                             Zmm(jcp.zmm_start + dimN_reg_block),
800                                             Zmm(current),
801                                             EVEX_compress_addr(reg_srcB,
802                                                 sizeof(float) * (srcB_offset + dimN_reg_block),
803                                                     true));
804                                 }
805                             }
806                         }
807                     }
808                 }
809
810                 add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma
811                                 * jcp.dimM_simd_block * sizeof(float));
812                 add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block
813                                 * jcp.dimK_4fma * sizeof(float));
814                 if (jcp.dimK_block > 1) {
815                     sub(reg_dimK_block_loop_cnt, 1);
816                     jnz(dimK_block_loop);
817                 }
818
819                 /******** Write C back to memory *******/
820                 for (int dimN_reg_block = 0;
821                         dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
822                     Zmm zmm(jcp.zmm_start + dimN_reg_block);
823                     vmovups(zword[reg_dstC +
824                             dimN_reg_block * jcp.dimM_simd_block * sizeof(float)],
825                             zmm);
826                 }
827
828                 sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block *
829                         jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
830                 add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block
831                         * sizeof(float));
832                 if (jcp.dimN_block > 1) {
833                     sub(reg_dimN_block_loop_cnt, 1);
834                     jnz(dimN_block_loop);
835                 }
836             }
837
838             if (jcp.dimM_block > 1) {
839                 sub(reg_srcB, jcp.dimN_block * jcp.dimK_block
840                                 * jcp.dimK_reg_block * jcp.dimN_reg_block
841                                 * jcp.dimK_4fma * sizeof(float));
842                 add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
843                                 * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
844                 sub(reg_dimM_block_loop_cnt, 1);
845                 jnz(dimM_block_loop);
846             }
847         }
848     };
849
850     /* Preamble */
851     // register used to handle long fma encoding
852     preamble();
853     mov(reg_srcA, reg_srcA_const);
854     inner_loops();
855
856     /* Postamble */
857     postamble();
858     ret();
859 }
860
861 namespace {
862 bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block,
863         int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
864 {
865     float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw;
866     lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw;
867     lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
868     lhs *= sizeof(float);
869     float rhs = C * L1_cache_size;
870     return (lhs <= rhs);
871 }
872
873 bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
874         int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
875 {
876     float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma
877             * dimM_simdw;
878     lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
879     lhs *= sizeof(float);
880     float rhs = C * L1_cache_size;
881     return (lhs <= rhs);
882 }
883
884 bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
885         int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
886         float C)
887 {
888     float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block
889             * dimK_4fma;
890     lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
891             * dimN_reg_block;
892     lhs *= sizeof(float);
893     float rhs = C * L2_cache_size;
894     return (lhs <= rhs);
895 }
896
897 bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block,
898         int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
899         float C)
900 {
901     float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block;
902     lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma;
903     lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
904             * dimN_reg_block;
905     lhs *= sizeof(float);
906     float rhs = C * L2_cache_size;
907     return (lhs <= rhs);
908 }
909 } // namespace
910
911 status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
912 {
913     /*************** Choose dimN_reg_block (ic_simd_block)
914      * *******************************/
915     jcp.dimN = jcp.ic;
916     /*Hardcoded to 16 because N = ic for bwd weights and
917      innermost dimension for ic is assumed 16 in src transforms. This
918      choice covers load latencies while maintaining simplicity of kernel
919      for POR topologies. FIXME in future??: Will not work for future topologies
920      when ic%16 != 0*/
921     jcp.dimN_reg_block = jcp.ic_simd_block;
922
923     /****************************** Choose dimK_block
924      * **************************/
925     // No freedom for choosing dimM_simd_block because ic_simd_block
926     // is determined by input data format
927     jcp.dimM_simd_block = jcp.oc_simd_block;
928
929     auto test_cond1bis_dimK_block = [](
930             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
931         return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
932                        jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
933                 && (dimK_block > current_best);
934     };
935
936     auto test_cond1_dimK_block = [](
937             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
938         return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1,
939                        jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
940                 && (dimK_block > current_best);
941     };
942
943     auto test_cond2bis_dimK_block = [](
944             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
945         return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
946                        jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f)
947                 && (dimK_block > current_best);
948     };
949
950     auto test_cond2_dimK_block = [](
951             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
952         return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1,
953                        jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f)
954                 && (dimK_block > current_best);
955     };
956
957     jcp.dimK_block = get_divisor_satisfying_cond(
958             jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block);
959     if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma)
960         jcp.dimK_block = get_divisor_satisfying_cond(
961                 jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block);
962
963     jcp.dimK_reg_block = get_divisor_satisfying_cond(
964             jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block);
965     if (jcp.dimK_reg_block < jcp.dimK_block) {
966         jcp.dimK_reg_block = get_divisor_satisfying_cond(
967                 jcp, jcp.dimK_block, 1, test_cond1_dimK_block);
968     }
969     jcp.dimK_block /= jcp.dimK_reg_block;
970     jcp.dimK_nb_block
971             = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block;
972     jcp.tile_block_ur = jcp.dimK_reg_block;
973     jcp.nb_tile_block_ur = jcp.dimK_block;
974     jcp.tile_block = jcp.dimK_nb_block;
975
976     /***************************** Chose dimN block
977      * ****************************/
978     auto test_cond2_dimN_block = [](
979             jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
980         return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block,
981                        jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block,
982                        jcp.dimN_reg_block, 0.5f)
983                 && (dimN_block > current_best);
984     };
985
986     jcp.dimN_block = get_divisor_satisfying_cond(
987             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
988     jcp.ic_block = jcp.dimN_block;
989     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block;
990     jcp.nb_ic = jcp.dimN_nb_block;
991
992     /********************************* Choose dimM block
993      * ************************/
994     jcp.dimM = jcp.oc;
995
996     auto test_cond1_dimM_block = [](
997             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
998         return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1,
999                        jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block,
1000                        1.0f)
1001                 && (dimM_block > current_best)
1002                 && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2;
1003     };
1004
1005     jcp.dimM_block = get_divisor_satisfying_cond(
1006             jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
1007     jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
1008
1009     jcp.sched_policy = WSCHED_WEI_S_D_G_W;
1010     return status::success;
1011 }
1012
status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
        const memory_desc_wrapper &diff_weights_d)
{
    jcp.nthr = mkldnn_get_max_threads();

    // With groups the diff_weights tensor carries one extra leading
    // (groups) dimension.
    const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;

    jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    // Channel counts are per-group.
    jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = diff_dst_d.dims()[2];
    jcp.ow = diff_dst_d.dims()[3];
    jcp.kh = diff_weights_d.dims()[with_groups + 2];
    jcp.kw = diff_weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    // Derive right/bottom padding from the output extent, clamped at 0.
    jcp.r_pad = nstl::max(
            0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
    jcp.b_pad = nstl::max(
            0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
    jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad; // padded input height
    jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad; // padded input width
    jcp.ohp = jcp.oh;
    jcp.owp = jcp.ow;
    jcp.with_bias = (cd.diff_bias_desc.format != memory_format::undef);
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    // Channels are padded up to the simd width only in the no-groups case.
    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simd_w);
        jcp.ic = rnd_up(jcp.ic, simd_w);
    }

    // ISA dispatch: this kernel targets avx512_common / avx512_mic only.
    // avx512_core is rejected outright -- presumably a core-specific
    // implementation covers it; confirm against the dispatch order.
    if (mayiuse(avx512_core))
        return status::unimplemented;
    if (!mayiuse(avx512_common))
        return status::unimplemented;
    else if (mayiuse(avx512_mic_4ops))
        jcp.ver = ver_4fma;
    else
        jcp.ver = ver_fma;

    // For convolution_auto, only take Winograd when the heuristic predicts
    // it beats the direct algorithm.
    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;
    // Winograd specific initialization: number of output tiles per image
    // and in total across the minibatch.
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    // Winograd kernel works only for 3x3 convolution with stride 1,
    // no dilation, no groups, full simd channel multiples, and the
    // expected blocked memory formats.
    if (jcp.ngroups != 1)
        return status::unimplemented;
    if ((jcp.kh != 3) || (jcp.kw != 3))
        return status::unimplemented;
    if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
        return status::unimplemented;
    if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
        return status::unimplemented;
    if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
        return status::unimplemented;
    if (src_d.format() != nChw16c)
        return status::unimplemented;
    if (diff_weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
        return status::unimplemented;
    if (diff_dst_d.format() != nChw16c)
        return status::unimplemented;

    // The logical channel counts must fit inside the (possibly padded)
    // physical layouts of all three tensors.
    bool layout_consistency = true
        && jcp.ic <= src_d.blocking_desc().padding_dims[1]
        && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
        && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
        && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
    if (!layout_consistency) return status::unimplemented;

    /*************************** New Kernel Parameters
     * *****************************/
    jcp.ic_simd_block = simd_w;
    jcp.oc_simd_block = simd_w;
    jcp.dimK_4fma = 1;
    jcp.tile_4fma_padding = 0;

#define MAX_4FMA_UR 8
    // For the 4fma version, pick a K unroll that is a multiple of 4 (up to
    // MAX_4FMA_UR), preferably dividing itiles * jtiles; when no such
    // divisor exists, pad the tile count up to a multiple of the unroll.
    if (jcp.ver == ver_4fma) {
        auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma,
                                      int current_best) {
            return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR)
                    && (dimK_4fma > current_best);
        };
        jcp.dimK_4fma = get_divisor_satisfying_cond(
                jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma);
        if (jcp.dimK_4fma == 1)
            jcp.dimK_4fma = 4;
        if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0)
            jcp.tile_4fma_padding = jcp.dimK_4fma
                    - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma);
    }

    jcp.tile_4fma = jcp.dimK_4fma;
    /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src
     * transform
     * will not work correctly, this is solved by applying padding.*/
    jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
    jcp.dimN = jcp.ic;
    jcp.dimM = jcp.oc;

    // Registers [0, zmm_start) are used for srcA loads in the gemm loop
    // (twice as many when double buffering); the remaining nb_reg zmms
    // hold the accumulators.
    jcp.double_buffering = true;
    if (jcp.double_buffering)
        jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2;
    else
        jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
    jcp.nb_reg = 32 - jcp.zmm_start;

    // Run the blocking heuristic; it is expected to always succeed with
    // the S_D_G_W schedule for this kernel.
    jcp.sched_policy = WSCHED_INVALID;
    status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
    assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);

    // Mirror the chosen gemm blocking into the tile/channel-block fields
    // consumed by the transform drivers.
    jcp.tile_block_ur = jcp.dimK_reg_block;
    jcp.nb_tile_block_ur = jcp.dimK_block;
    jcp.tile_block = jcp.dimK_nb_block;

    jcp.ic_block = jcp.dimN_block;
    jcp.nb_ic = jcp.dimN_nb_block;

    jcp.oc_block = jcp.dimM_block;
    jcp.nb_oc = jcp.dimM_nb_block;

    return res;

}
1152 }
1153 }
1154 }
1155
1156 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s