Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include <math.h>
25
26 #include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
27
28 #define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 namespace {
35
36 using namespace mkldnn::impl::utils;
37
38 unsigned int L1_cache_size = get_cache_size(1, true);
39 unsigned int L2_cache_size = get_cache_size(2, true);
40 unsigned int LLC_data_size = get_cache_size(3, false);
41
42 // the test funtion takes jcp, the candidate and the current best.
43 // it  returns true if the new candidate is better
44 int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
45         int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
46 {
47     int best_divisor = default_best;
48     auto test_num
49             = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
50                   if (test(jcp, num, best_divisor)) {
51                       best_divisor = num;
52                   }
53               };
54
55     for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
56         if (number % divisor == 0) {
57             test_num(jcp, divisor);
58             test_num(jcp, number / divisor);
59         }
60     }
61
62     return best_divisor;
63 }
64
65 namespace {
66 bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
67     /* Determines if current winograd implementation is faster than direct.
68        Following conditions are empirical and based on performance data */
69     unsigned int ncores_per_socket =
70         cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
71     unsigned int nthreads = mkldnn_get_max_threads();
72
73     if (jcp.prop_kind == prop_kind::forward_inference) {
74         return jcp.mb >= 4;
75     } else if (nthreads > ncores_per_socket) {
76         double src_dst_transforms_per_core = alpha * alpha
77             * (jcp.ic + jcp.oc)
78             * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
79             * ((jcp.ow + tile_size - 1) / tile_size)
80             * sizeof(float) / 1024. / 1024. / nthreads;
81         double wei_transform = alpha * alpha
82             * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.;
83
84         if (jcp.prop_kind == prop_kind::backward_weights) {
85             if (src_dst_transforms_per_core < 0.3
86                     || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
87                 return false;
88             else
89                 return true;
90         } else {
91             if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
92                 return false;
93         }
94     }
95
96     return jcp.mb > 8;
97 }
98 }
99
/* assumes 512 bits registers */
/* TODO: add support for strides */
/* TODO: handle the prefetch distance automatically */
// Cache level targeted by prefetcher_t below; selects which prefetchtN
// instruction is emitted.
typedef enum cache_t_ { L1, L2, L3 } cache_t;
104
105 template <typename data_t>
106 struct prefetcher_t {
107     prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
108             cache_t cache_type, size_t block_size, /* in number of elements*/
109             int nb_instructions_in_block, int fma_ipc)
110         : cg_(generator)
111         , reg_base_addr_(reg_base_addr)
112         , cache_type_(cache_type)
113         , cache_block_size_(block_size)
114     {
115         nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
116         prefetch_spread_
117                 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
118         prefetch_blk_
119                 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
120
121         /* assumption: when fetch in Li, data is already in L(i+1) */
122         int cache_latency;
123         switch (cache_type_) {
124         case L1: cache_latency = 14; break;
125         case L2: cache_latency = 250; break;
126         case L3: cache_latency = 250; break;
127         }
128
129         prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
130     }
131
132     void prefetch(int instruction_number)
133     {
134         if (instruction_number % prefetch_spread_ == 0) {
135             for (int i = 0; (i < prefetch_blk_)
136                     && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
137                     i++, prefetches_issued_++) {
138                 prefetch_inst_(cg_->EVEX_compress_addr(
139                         reg_base_addr_, (cache_block_size_ * prefetch_distance_)
140                                         * sizeof(data_t)
141                                 + (prefetches_issued_ * 64)));
142             }
143         }
144     }
145
146 private:
147     void prefetch_inst_(const Xbyak::Address &addr)
148     {
149         switch (cache_type_) {
150         case L1: cg_->prefetcht0(addr); break;
151         case L2: cg_->prefetcht1(addr); break;
152         case L3: cg_->prefetcht2(addr); break;
153         default:
154             break; // TODO: raise an exception or put an assert
155         }
156     }
157
158     jit_generator *cg_;
159     Xbyak::Reg64 reg_base_addr_;
160     cache_t cache_type_;
161     int cache_block_size_ = 0;
162     int nb_cache_lines_to_prefetch_ = 0;
163     int prefetches_issued_ = 0;
164     int prefetch_spread_ = 0;
165     int prefetch_blk_ = 0;
166     int prefetch_distance_ = 0;
167 };
168
169 // utilities to support kernel parameter selection
170 bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
171         int dimN_block, float C2_min, float C2_max) {
172     float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic)
173         * dimN_block * jcp.dimN_reg_block
174         + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float);
175     float L2_lb = C2_min * L2_cache_size;
176     float L2_ub = C2_max * L2_cache_size;
177     return (block_size > L2_lb && block_size < L2_ub);
178 }
179
180 bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
181         int dimM_block, float C1_min, float C1_max) {
182     float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
183                              * jcp.dimK_reg_block * jcp.dimM_reg_block
184                      + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
185                      + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
186                      * (float)sizeof(float);
187     float L1_lb = C1_min * L1_cache_size;
188     float L1_ub = C1_max * L1_cache_size;
189     return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
190 }
191 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
192         int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
193 {
194     float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
195                         + dimM_block * dimK_block * dimK_reg_block
196                                 * dimM_simd_block * dimM_reg_block
197                         + dimK_block * dimN_reg_block * dimK_reg_block)
198             * (float)sizeof(float);
199     float rhs = C * L1_cache_size;
200     return (lhs < rhs);
201 }
202 bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
203         int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
204 {
205     float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
206             * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block)
207             * (float)sizeof(float);
208     float rhs = C * L1_cache_size;
209     return (lhs < rhs);
210 }
211 bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
212         int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
213         int dimM_simd_block, float C)
214 {
215     float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block
216                               * dimM_simd_block * dimM_reg_block
217                       + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
218                               * dimM_simd_block * dimM_reg_block
219                       + nb_dimN_reg_block * dimK_nb_block * dimK_block
220                               * dimN_reg_block * dimK_reg_block)
221             * (float)sizeof(float);
222     float rhs = C * L2_cache_size;
223     return (lhs < rhs);
224 }
225
226 bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
227         int dimN_block, int dimN_reg_block, int dimK, float C1, float C2)
228 {
229     float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
230         * (float)sizeof(float);
231     float B_size = dimN_block * dimN_reg_block * dimK
232         * (float)sizeof(float);
233     return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
234 }
235 }
236
237 using namespace mkldnn::impl::memory_format;
238 using namespace mkldnn::impl::utils;
239 using namespace Xbyak;
240
/* Emits the GEMM micro-kernel that multiplies transformed input tiles by
 * transformed weights. Register plan: zmm0 holds the A vector; with an
 * explicit-broadcast kernel zmm1..zmmN hold broadcast B values and the C
 * accumulators follow; with an embedded-broadcast kernel B is read straight
 * from memory and zmm1..zmmN are the C accumulators. reg_is_beta_zero
 * selects between overwriting and accumulating into dstC. */
void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate()
{
    // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
    // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
    //      dimM_reg_block++) // unrolled
    //     for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
    //         for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
    //              dimK_reg_block++) // unrolled
    //             for (int tile =0; tile < jcp.dimN_reg_block; tile++)
    //                 C[dimM_block][dimM_reg_block][tile] +=
    //                 A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
    //                 * broadcast(B[dimK_block][tile][dimK_reg_block]);
    // Notes:
    // jcp.kernel_kind defines embedded or explicit broadcast
    // dimM_reg_block=1 for embedded bcast kernel

    // zmm0 is reused for every A load.
    auto zmm_srcA = [=]() {
        return Xbyak::Zmm(0);
    };
    // Broadcast B registers (explicit-broadcast kernel only).
    auto zmm_srcB = [=](int tile) {
        int idx = 1 + tile;
        assert(idx < 1 + jcp.dimN_reg_block);
        return Xbyak::Zmm(idx);
    };
    // C accumulators; their placement depends on the kernel kind since
    // the embedded-broadcast kernel does not reserve B registers.
    auto zmm_dstC = [=](int dimM_reg_block, int tile) {
        int idx{0};
        if (jcp.kernel_kind == embd_bcast)
            idx = 1 + tile;
        else
            idx = 1 + jcp.dimN_reg_block
                  + dimM_reg_block * jcp.dimN_reg_block + tile;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };

    // Zero all C accumulators before the K loop.
    auto prepare_output = [=]() {
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
              dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                vpxord(zmm, zmm, zmm);
            }
        }
    };
    // Write accumulators to dstC. When reg_is_beta_zero is clear, first
    // add the existing dstC contents (beta == 1 accumulation).
    auto store_output = [=](bool output_is_aligned) {
        Label save;
        cmp(reg_is_beta_zero, 0);
        je(save, T_NEAR);

        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
              dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block,tile);
                int output_offset
                    = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
                vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
            }
        }

        L(save);
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
              dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block,tile);
                int output_offset
                    = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;

                // In W_SGD, output will be reused.
                // Non-temporal store only when the data will not be re-read
                // soon and clearly exceeds the last-level cache.
                if (output_is_aligned
                    && jcp.dimK_nb_block == 1
                    && jcp.sched_policy == WSCHED_DATA_W_S_G_D
                    && (jcp.dimN * jcp.dimM * alpha * alpha
                        * sizeof(float) > 2 * LLC_data_size))
                    vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
                else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
            }
        }
    };

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;

        // Runtime loop over dimM blocks (skipped when the count is 1).
        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }

        prepare_output();

        // Runtime loop over dimK blocks (skipped when the count is 1).
        if (jcp.dimK_block > 1) {
            mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
            L(dimK_block_loop);
        }

        for (int dimK_reg_block = 0;
                dimK_reg_block < jcp.dimK_reg_block;
                dimK_reg_block ++) {

            // Explicit broadcast: load all B values for this K element
            // into registers up front.
            if (jcp.kernel_kind == expl_bcast) {
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    vbroadcastss(zmm_srcB(tile),
                        ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
                }
            }

            /* Performing the fmas */

            for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {

                vmovups(zmm_srcA(),
                    zword[reg_srcA
                            + jcp.dimK_reg_block * jcp.dimK_block * 64
                              * dimM_reg_block
                            + dimK_reg_block * 64]
                    );

                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    if (jcp.kernel_kind == expl_bcast)
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                            zmm_srcB(tile));
                    else
                        // Embedded broadcast: B operand comes straight from
                        // memory with the {1to16} broadcast bit set.
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                            EVEX_compress_addr(reg_srcB,
                                64 * tile + dimK_reg_block * 4, true));
                }
            }
        }
        // Advance A and B to the next dimK block.
        add(reg_srcA, jcp.dimK_reg_block * 64);
        add(reg_srcB, jcp.dimN_reg_block * 64);
        if (jcp.dimK_block > 1) {
            sub(reg_dimK_block_loop_cnt, 1);
            jnz(dimK_block_loop);
        }

        // Dispatch at runtime on dstC alignment so the aligned path can
        // use non-temporal stores.
        Label unaligned_store, end_store;
        test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
        jnz(unaligned_store, T_NEAR);
        store_output(true);
        jmp(end_store, T_NEAR);
        L(unaligned_store); {
            store_output(false);
        }
        L(end_store);

        if (jcp.dimM_block > 1) {
            // Rewind B to its start; advance dstC and (for explicit
            // broadcast) A to the next dimM block.
            sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
            add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
            if (jcp.kernel_kind == expl_bcast) {
                add(reg_srcA,
                     (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64
                      * jcp.dimK_block);
            }
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    preamble();

    /* kernel */
    inner_loops();

    /* Postamble */
    postamble();
    ret();
}
409
/* Emits the weights-transform kernel: loads a 3x3 (kh x kw) weight block,
 * applies the 4x3 winograd weight transform using the G coefficients passed
 * in the call parameters, and scatters the resulting alpha x alpha tiles
 * into the transformed-weights layout. For backward propagation the weights
 * are additionally read spatially reversed and transposed 16x16. */
void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
    ::weights_transform_data_ker_generate()
{
    bool is_fwd = one_of(jcp.prop_kind,
        mkldnn_forward_training, mkldnn_forward_inference);
    int kh = jcp.kh;
    int kw = jcp.kw;

    auto zmm_temp = Xbyak::Zmm(31);
    auto zmm_zero = Xbyak::Zmm(30);

    // Register maps for the 16x16 transpose (M = input rows, MT = shuffled).
    auto zmm_M = [=](int i) {
        return Xbyak::Zmm(i);
    };
    auto zmm_MT = [=](int i) {
        return Xbyak::Zmm(i + simd_w);
    };

    // Register maps for the transform: G coefficients, F (weights),
    // T (intermediate rows), t (scratch products).
    auto zmm_G = [=](int i) {
        return Xbyak::Zmm(i);
    };
    auto zmm_F = [=](int i) {
        return Xbyak::Zmm(alpha + i);
    };
    auto zmm_T = [=](int i) {
        return Xbyak::Zmm(alpha + 3 + i);
    };
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(2 * alpha + 3 + i);
    };

    auto zmm_load = [=](int i) {
        return Xbyak::Zmm(i);
    };

    // Broadcast the alpha transform coefficients from the G array and
    // zero zmm_zero for later use in fnms4.
    auto init_G = [=]() {
        mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < alpha; i++) {
            vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
        }
        vpxord(zmm_zero, zmm_zero, zmm_zero);
    };

    // In-register 16x16 float transpose from [wreg_M] to [wreg_MT],
    // built from unpack (4-byte), unpackpd (8-byte) and 128-bit
    // shuffle stages.
    auto trans16x16 = [=]() {
        for (int i = 0; i < simd_w; i+=2 ) {
            vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
            vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]);
            vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1));
            vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1));
        }
        for (int i = 0; i < simd_w; i+=4 ) {
            vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2));
            vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2));
            vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3));
            vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3));
        }
        for (int i = 0; i < simd_w; i += 8) {
            vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
            vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88);
            vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88);
            vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88);
            vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd);
            vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd);
            vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd);
            vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd);
        }
        // Final 128-bit lane shuffle stage, storing rows to [wreg_MT].
        {
            int i = 0;
            int mask = 0x88;
            vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
            vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
            vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
            vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
            vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
            vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
            vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
            vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
            mask = 0xdd;
            vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
            vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
            vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
            vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
            vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
            vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
            vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
            vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
        }
    };

    // Copy the kh x kw weight block from src into the M scratch buffer.
    // Forward: straight copy. Backward: read spatially reversed (the
    // (2 - j), (2 - i) indices) and transpose each 16x16 ic/oc block.
    auto load_src = [=]() {
        mov(wreg_src, ptr[param1 + GET_OFF(src)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        for (int j = 0; j < kh; j++) {
            for (int i = 0; i < kw; i++) {
                if (is_fwd) {
                    for (int v1 = 0; v1 < simd_w; v1++) {
                        int offset_src = (j * kw * simd_w * simd_w
                            + i * simd_w * simd_w + v1 * simd_w) * typesize;
                        int offset_F = (j * kw * simd_w * simd_w
                            + i * simd_w * simd_w  + v1 * simd_w) * typesize;
                        vmovups(zmm_temp, ptr[wreg_src + offset_src]);
                        vmovups(ptr[wreg_F + offset_F], zmm_temp);
                    }
                } else {
                    int offset_src = ((2 - j) * kw * simd_w * simd_w
                        + (2 - i) * simd_w * simd_w) * typesize;
                    int offset_F = (j * kw * simd_w * simd_w
                        + i * simd_w * simd_w) * typesize;
                    lea(wreg_M, ptr[wreg_src + offset_src]);
                    lea(wreg_MT, ptr[wreg_F + offset_F]);
                    trans16x16();
                }
            }
        }
    };

    // Scatter the transformed alpha x alpha tiles from Mw into the
    // blocked destination layout using non-temporal stores.
    auto store_dst = [=]() {
        mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        mov(wreg_dst_aux, wreg_dst);
        mov(wreg_Fw_aux, wreg_Fw);

        // Stride between consecutive alpha indices in the destination.
        int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
            * jcp.dimK_block * simd_w * simd_w;

        L(Loop_j);
        {
            for (int i = 0; i < alpha; i++) {
                // touch pages
                vmovups(zmm_load(0), ptr[wreg_Fw_aux
                    + (i * simd_w * simd_w) * typesize]);
                mov(wreg_dst_idx, i * dim5 * typesize);
                vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
            }
            for (int i = 0; i < alpha; i++) {
                // Remaining simd_w - 1 vectors of each tile.
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_Fw = (i * simd_w * simd_w  + v1 * simd_w)
                        * typesize;
                    vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
                }
                mov(wreg_dst_idx, i * dim5 * typesize);
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_dst = v1 * simd_w * typesize;
                    vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
                        zmm_load(v1));
                }
            }
            add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
            add(wreg_dst_aux, alpha * dim5 * typesize);
            add(wreg_cnt_j, 1);
            cmp(wreg_cnt_j, alpha);
            jl(Loop_j, T_NEAR);
        }
    };

    // The 4x4/3x3 winograd weight transform itself, applied per 16-lane
    // slice: a row pass (3 -> 6 rows via T) then a column pass (3 -> 6
    // columns via F), both expressed with the broadcast G coefficients.
    auto trans_W_4x4_3x3 = [=]() {
        // dst = a + b * c
        auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmovups(dst, a);
            vfmadd231ps(dst, b, c);
        };
        // dst = a - b * c
        auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmulps(zmm_temp, b, c);
            vsubps(dst, a, zmm_temp);
        };
        // dst = -a - b * c
        auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vsubps(dst, zmm_zero, a);
            vfnmadd231ps(dst, b, c);
        };

        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        mov(wreg_T, ptr[param1 + GET_OFF(T)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        L(Loop_j);
            mov(wreg_F_aux, wreg_F);
            mov(wreg_Fw_aux, wreg_Fw);
            mov(wreg_temp, wreg_cnt_j);
            // j * 16 * 4: byte offset of this 16-lane slice.
            shl(wreg_temp, 4 + 2);
            lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
            lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);

            // Row pass: expand the 3 weight rows to 6 rows in T.
            for (int i = 0; i < 3; i++) {
                for (int idx = 0; idx < 3; idx ++) {
                    vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w
                        * simd_w + i * simd_w * simd_w) * typesize]);
                }
                vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
                fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
                fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));

                vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
                fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
                fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
                fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
                fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
                vmovaps(zmm_T(5), zmm_F(2));

                for (int idx = 0; idx < 6; idx ++) {
                    vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w)
                        * typesize], zmm_T(idx));
                }
            }
            // Column pass: expand each of the 6 T rows from 3 to 6
            // columns and store into Fw.
            for (int i = 0; i < 6; i++) {

                for (int idx = 0; idx < 3; idx ++) {
                    vmovups(zmm_T(idx), ptr[wreg_T
                        + (i * 3 * simd_w + idx * simd_w) * typesize]);
                }
                vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
                fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
                fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));

                vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
                fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
                fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
                fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
                fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
                vmovaps(zmm_F(5), zmm_T(2));

                for (int l = 0; l < 6; l++) {
                    vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w
                        + l * simd_w * simd_w) * typesize], zmm_F(l));
                }
            }
        add(wreg_cnt_j, 1);
        cmp(wreg_cnt_j, 16);
        jl(Loop_j, T_NEAR);
    };

    auto inner_loops = [=]() {
        load_src();
        init_G();
        trans_W_4x4_3x3();
        store_dst();
    };

    preamble();
    inner_loops();
    postamble();
}
671
672 void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
673     ::output_transform_data_ker_generate()
674 {
675     bool is_fwd = one_of(jcp.prop_kind,
676         mkldnn_forward_training, mkldnn_forward_inference);
677     int outw = is_fwd ? jcp.ow : jcp.iw;
678     int outh = is_fwd ? jcp.oh : jcp.ih;
679     bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
680     bool with_bias = jcp.with_bias;
681     bool with_relu = jcp.with_eltwise;
682     bool with_relu_postsum = jcp.with_relu_postsum;
683     bool with_sum = jcp.with_sum;
684
685     auto zmm_zero = Xbyak::Zmm(0);
686     auto zmm_temp = Xbyak::Zmm(31);
687     auto zmm_G = [=](int i) {
688         return Xbyak::Zmm(1 + i);
689     };
690     auto zmm_O = [=](int i) {
691         return Xbyak::Zmm(1 + alpha + i);
692     };
693     auto zmm_T = [=](int i) {
694         return Xbyak::Zmm(1 + 2 * alpha + i);
695     };
696     auto zmm_t = [=](int i) {
697         return Xbyak::Zmm(1 + 3 * alpha + i);
698     };
699
700     auto init_G = [=]() {
701         mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
702         for (int i = 0; i < 6; i++) {
703             vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
704         }
705     };
706
707     auto load_src = [=]() {
708         mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
709         mov(oreg_src, ptr[param1 + GET_OFF(src)]);
710
711         mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
712         imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
713             (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
714             * jcp.dimM_simd_block * typesize);
715         add(oreg_src, oreg_nb_tile_block_ur);
716
717         mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
718         imul(oreg_tile_block_ur, oreg_tile_block_ur,
719             jcp.dimM_simd_block * typesize);
720         add(oreg_src, oreg_tile_block_ur);
721
722         if (not_tiled) {
723             mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
724             imul(oreg_tile_block, oreg_tile_block,
725                 jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
726                 * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
727                 * jcp.dimM_simd_block * typesize);
728             add(oreg_src, oreg_tile_block);
729         }
730
731         int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
732             * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
733         for (int j = 0; j < alpha; j++) {
734             for (int i = 0; i < alpha; i++) {
735                 int j_base_offset = j * alpha * last4dim;
736                 int i_base_offset = i * last4dim;
737                 vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]);
738                 vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w)
739                     * typesize], zmm_temp);
740             }
741         }
742     };
743
744     auto store_dst = [=]() {
745         vpxord(zmm_zero, zmm_zero, zmm_zero);
746         mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
747         mov(oreg_O, ptr[param1 + GET_OFF(M)]);
748         mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
749         shl(oreg_ydim, 2); // tj * tile_size (==4)
750         mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
751         shl(oreg_xdim, 2); // ti * tilesize (==4)
752
753         if (with_bias)
754             mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);
755
756         auto store_one = [=](int j, int i, bool is_aligned) {
757             auto zmm_O = Xbyak::Zmm(31);
758             auto zmm_relu_ns = Xbyak::Zmm(30);
759             auto xmm_relu_ns = Xbyak::Xmm(30);
760             int offset = (j * tile_size * simd_w + i * simd_w) * typesize;
761
762             vmovups(zmm_O, ptr[oreg_O + offset]);
763             if (is_fwd) {
764                 if (with_bias) {
765                     vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
766                 }
767                 if (with_relu) {
768                     if (jcp.eltwise.alpha == 0) {
769                         vmaxps(zmm_O, zmm_O, zmm_zero);
770                     } else {
771                         Opmask kmask = Opmask(7);
772                         mov(imm_addr64, float2int(jcp.eltwise.alpha));
773                         vmovq(xmm_relu_ns, imm_addr64);
774                         vbroadcastss(zmm_relu_ns, xmm_relu_ns);
775                         vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
776                         vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
777                     }
778                 }
779             }
780             if (with_sum) {
781                 vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
782                 if (with_relu_postsum) // orig: with_relu_postsum
783                     vmaxps(zmm_O, zmm_O, zmm_zero);
784             }
785             if (is_aligned)
786                 vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
787             else
788                 vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
789         };
790
791         auto i_loop = [=](int j, bool is_aligned) {
792             for (int i = 0; i < tile_size; i++) {
793                 Label next;
794                 mov(oreg_temp, oreg_xdim);
795                 add(oreg_temp, i);
796                 cmp(oreg_temp, outw);
797                 jge(next, T_NEAR);
798                 shl(oreg_temp, 4 + 2); // * 16 * 4
799
800                 store_one(j, i, is_aligned);
801
802                 L(next);
803             }
804         };
805
806
807         for (int j = 0; j < tile_size; j++) {
808             Label next, unaligned;
809             mov(oreg_temp, oreg_ydim);
810             add(oreg_temp, j);
811             cmp(oreg_temp, outh);
812             jge(next, T_NEAR);
813
814             mov(oreg_out_j, oreg_dst);
815             imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
816             add(oreg_out_j, oreg_temp);
817
818             test(oreg_dst, 63);
819             jnz(unaligned, T_NEAR);
820
821             i_loop(j, true);
822             jmp(next, T_NEAR);
823
824             L(unaligned);
825             i_loop(j, false);
826
827             L(next);
828         }
829     };
830
831     auto trans_O_4x4_3x3 = [=]() {
832         auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){
833             vmulps(dst, v1, u1);
834             vfmadd231ps(dst, v2, u2);
835         };
836         mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
837         mov(oreg_T, ptr[param1 + GET_OFF(T)]);
838         mov(oreg_O, ptr[param1 + GET_OFF(M)]);
839
840         for (int i = 0; i < alpha; i++) {
841             for (int j = 0; j < alpha; j++) {
842                 vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w
843                     + i * simd_w) * typesize]);
844             }
845
846             vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
847             vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
848             vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
849             vsubps(zmm_t(3), zmm_O(3), zmm_O(4));
850
851             vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
852             vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
853             fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
854             fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
855             fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
856             vaddps(zmm_T(3), zmm_T(3), zmm_O(5));
857
858             for (int j = 0; j < tile_size; j++) {
859                 vmovups(ptr[oreg_T + (j * alpha * simd_w
860                     + i * simd_w) * typesize], zmm_T(j));
861             }
862         }
863         for (int j = 0; j < tile_size; j++) {
864             for (int i = 0; i < alpha; i++) {
865                 vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w
866                     + i * simd_w) * typesize]);
867             }
868             vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
869             vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
870             vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
871             vsubps(zmm_t(3), zmm_T(3), zmm_T(4));
872
873             vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
874             vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
875             fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
876             fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
877             fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
878             vaddps(zmm_O(3), zmm_O(3), zmm_T(5));
879
880             for (int i = 0; i < tile_size; i++) {
881                 vmovups(ptr[oreg_O + (j * tile_size * simd_w
882                     + i * simd_w) * typesize], zmm_O(i));
883             }
884         }
885     };
886
887     auto inner_loops = [=]() {
888         init_G();
889         load_src();
890         trans_O_4x4_3x3();
891         store_dst();
892     };
893
894     preamble();
895     inner_loops();
896     postamble();
897 }
898
/* Generates the JIT kernel performing the Winograd input transform for one
 * simd-wide (16-channel) block of a single 4x4 output tile (3x3 kernel).
 * The emitted code runs four stages:
 *   init_G          - broadcast the 9 transform coefficients into zmm regs,
 *   load_src        - gather the alpha x alpha input patch, zero-masking
 *                     elements that fall into the padding region,
 *   trans_I_4x4_3x3 - apply the 1D transform along both spatial dims,
 *   store_Iw        - scatter the transformed tile into the workspace in
 *                     the blocked layout expected by the GEMM kernel.
 */
void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
    ::input_transform_data_ker_generate()
{
    bool is_fwd = one_of(jcp.prop_kind,
        mkldnn_forward_training, mkldnn_forward_inference);
    // For bwd-data the transform input is diff_dst, so the spatial extent
    // and padding are taken from the output-side fields.
    int inpw = is_fwd ? jcp.iw : jcp.ow;
    int inph = is_fwd ? jcp.ih : jcp.oh;
    int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
    // NOTE(review): the bwd branch uses jcp.t_pad here while l_pad above
    // uses jcp.r_pad -- confirm this asymmetry (jcp.b_pad?) is intended.
    int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
    int wp_max = inpw + l_pad; // first x coordinate past the readable range
    int hp_max = inph + t_pad; // first y coordinate past the readable range
    bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
    int G_size = 9; // number of scalar transform coefficients

    // Register map: zmm0 = zeros, zmm31 = scratch, zmm1..9 = G coefficients,
    // then the I / T / t working sets (alpha registers each).
    auto zmm_zero = Xbyak::Zmm(0);
    auto zmm_temp = Xbyak::Zmm(31);
    auto zmm_G = [=](int i) {
        return Xbyak::Zmm(1 + i);
    };
    auto zmm_I = [=](int i) {
        return Xbyak::Zmm(1 + G_size + i);
    };
    auto zmm_T = [=](int i) {
        return Xbyak::Zmm(1 + G_size + alpha + i);
    };
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(1 + G_size + 2 * alpha + i);
    };

    // Broadcast each transform coefficient from the G table (passed in the
    // call-args struct) into its dedicated zmm register.
    auto init_G = [=]() {
        mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++) {
            vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
        }
    };

    // Load the alpha x alpha patch for tile (tj, ti) into the M scratch
    // buffer; out-of-bounds rows/columns are zero-filled via opmask loads.
    auto load_src = [=]() {
        mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
        mov(ireg_I, ptr[param1 + GET_OFF(M)]);

        xor_(ireg_zero,  ireg_zero);
        vpxord(zmm_zero, zmm_zero, zmm_zero);

        mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
        shl(ireg_ydim, 2); // tj * tile_size (==4)
        mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
        shl(ireg_xdim, 2); // ti * tilesize (==4)

        for (int j = 0; j < alpha; j++) {
            mov(ireg_temp, ireg_ydim);
            add(ireg_temp, j);

            // Row mask: all 16 lanes enabled iff t_pad <= y < hp_max.
            mov(ireg_mask_j, 0xffff);
            cmp(ireg_temp, t_pad);
            cmovl(ireg_mask_j, ireg_zero);
            cmp(ireg_temp, hp_max);
            cmovge(ireg_mask_j, ireg_zero);

            // Base address of source row (y - t_pad).
            sub(ireg_temp, t_pad);
            imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
            mov(ireg_inp_j, ireg_src);
            add(ireg_inp_j, ireg_temp);

            for (int i = 0; i < alpha; i++) {

                mov(ireg_temp, ireg_xdim);
                add(ireg_temp, i);

                // Column mask (l_pad <= x < wp_max), AND-ed with row mask.
                mov(ireg_mask, 0xffff);
                cmp(ireg_temp, l_pad);
                cmovl(ireg_mask, ireg_zero);
                cmp(ireg_temp, wp_max);
                cmovge(ireg_mask, ireg_zero);
                and_(ireg_mask, ireg_mask_j);

                // Column byte offset: (x - l_pad) * simd_w(16) * typesize(4).
                sub(ireg_temp, l_pad);
                shl(ireg_temp, 4 + 2);

                // Masked load: lanes outside the image remain zero.
                vpxord(zmm_temp, zmm_temp, zmm_temp);
                Opmask kmask = Opmask(7);
                kmovw(kmask, ireg_mask_32);
                vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
                vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w)
                    * typesize], zmm_temp);
            }
        }
    };

    // Store the transformed tile (Iw) to the destination workspace, using
    // the schedule-dependent blocked layout.  Streaming (non-temporal)
    // stores are used for the non-fused schedule when the full data set
    // exceeds twice the LLC size.
    auto store_Iw = [=]() {

        mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
        mov(ireg_output, ptr[param1 + GET_OFF(dst)]);

       bool streamout
          = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
            > 2 * LLC_data_size
            ? true : false;

        // With the non-fused (W_S_G_D) schedule, tile_block is the
        // outermost destination dimension.
        if (not_tiled) {
            mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
            imul(ireg_tile_block, ireg_tile_block,
                alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
                * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
                * typesize);
        }

        mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
        imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
            jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
            * jcp.dimK_reg_block * typesize);

        mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
        imul(ireg_tile_block_ur, ireg_tile_block_ur,
            jcp.dimK_reg_block * typesize);

        add(ireg_output, ireg_nb_tile_block_ur);
        add(ireg_output, ireg_tile_block_ur);
        if (not_tiled)
            add(ireg_output, ireg_tile_block);

        for (int j = 0; j < alpha; j++) {
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w
                    + i * simd_w) * typesize]);

                int j_base_offset =
                    j * alpha * jcp.dimN_block * jcp.dimK_nb_block
                    * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
                    * typesize;
                int i_base_offset =
                    i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
                    * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;

                // Non-temporal store bypasses the cache hierarchy for
                // data that will not be re-read soon.
                if (not_tiled && streamout)
                    vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
                        zmm_temp);
                else
                    vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
                        zmm_temp);
            }
        }
    };

    // dst = a * b + c, using zmm_temp as scratch so c may alias a or b.
    auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
        vmulps(zmm_temp, a, b);
        vaddps(dst, zmm_temp, c);
    };

    // Two 1D transform passes: columns first (I -> T), then rows (T -> Iw).
    auto trans_I_4x4_3x3 = [=]() {
        mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
        mov(ireg_T, ptr[param1 + GET_OFF(T)]);
        mov(ireg_I, ptr[param1 + GET_OFF(M)]);

        mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
        for (int i = 0; i < alpha; i++) {
            for (int idx = 0; idx < alpha; idx++) {
                vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w
                    + i * simd_w) * typesize]);
                // Prefetch the destination lines store_Iw will write later.
                int j_base_offset =
                    i * alpha * jcp.dimN_block * jcp.dimK_nb_block
                    * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
                    * typesize;
                int idx_base_offset =
                    idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
                    * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
                prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
            }

            fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
            fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
            fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
            fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
            fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
            fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));

            fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
            fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));

            for (int idx = 0; idx < alpha; idx++) {
                vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w)
                    * typesize],zmm_T(idx));
            }
        }
        for (int i = 0; i < alpha; i++) {
            for (int idx = 0; idx < alpha; idx++) {
                vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx
                    * simd_w) * typesize]);
            }

            fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
            fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
            fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
            fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
            fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
            fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));

            fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
            fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));

            for (int idx = 0; idx < alpha; idx++) {
                vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w)
                    * typesize],zmm_I(idx));
            }
        }
    };

    // Emit the full kernel body.
    auto inner_loops = [=]() {
        init_G();
        load_src();
        trans_I_4x4_3x3();
        store_Iw();
    };

    preamble();
    inner_loops();
    postamble();
}
1124
1125 status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common(
1126         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1127         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
1128         const memory_desc_wrapper &dst_d)
1129 {
1130     if (!mayiuse(avx512_core)) {
1131         return status::unimplemented;
1132     }
1133
1134     jcp.nthr = mkldnn_get_max_threads();
1135
1136     jcp.ver = ver_avx512_core;
1137     jcp.prop_kind = cd.prop_kind;
1138
1139     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1140
1141     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1142     jcp.mb = src_d.dims()[0];
1143     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1144     jcp.oc_without_padding = jcp.oc;
1145     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1146     jcp.ih = src_d.dims()[2];
1147     jcp.iw = src_d.dims()[3];
1148     jcp.oh = dst_d.dims()[2];
1149     jcp.ow = dst_d.dims()[3];
1150     jcp.kh = weights_d.dims()[with_groups + 2];
1151     jcp.kw = weights_d.dims()[with_groups + 3];
1152     jcp.t_pad = cd.padding[0][0];
1153     jcp.l_pad = cd.padding[0][1];
1154     jcp.stride_h = cd.strides[0];
1155     jcp.stride_w = cd.strides[1];
1156     jcp.dilate_h = cd.dilates[0];
1157     jcp.dilate_w = cd.dilates[1];
1158     jcp.r_pad = nstl::max(
1159             0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1160     jcp.b_pad = nstl::max(
1161             0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
1162     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1163     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1164     jcp.ohp = jcp.oh;
1165     jcp.owp = jcp.ow;
1166
1167     bool ok_to_pad_channels = jcp.ngroups == 1;
1168     if (ok_to_pad_channels) {
1169         jcp.oc = rnd_up(jcp.oc, simd_w);
1170         jcp.ic = rnd_up(jcp.ic, simd_w);
1171     }
1172
1173     // Checking conditions not supported by these kernels
1174     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
1175                is_winograd_faster_than_direct(jcp)))
1176         return status::unimplemented;
1177
1178     if (jcp.ngroups != 1)
1179         return status::unimplemented;
1180     if ((jcp.kh != 3) || (jcp.kw != 3))
1181         return status::unimplemented;
1182     if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
1183         return status::unimplemented;
1184     if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
1185         return status::unimplemented;
1186     if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
1187         return status::unimplemented;
1188
1189     if (src_d.format() != nChw16c)
1190         return status::unimplemented;
1191     if (!one_of(weights_d.format(), any,
1192                 with_groups ? gOIhw16i16o : OIhw16i16o, wino_fmt))
1193         return status::unimplemented;
1194     if (dst_d.format() != nChw16c)
1195         return status::unimplemented;
1196
1197     bool layout_consistency = true
1198             && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1199             && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
1200             && (weights_d.format() == any || weights_d.format() == wino_fmt
1201                     || (jcp.ic <= weights_d.blocking_desc()
1202                                             .padding_dims[with_groups + 1]
1203                             && jcp.oc <= weights_d.blocking_desc()
1204                                             .padding_dims[with_groups + 0]));
1205     if (!layout_consistency)
1206         return status::unimplemented;
1207
1208     return status::success;
1209 }
1210
1211 void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
1212
1213     /* ----------- dimM reg block ---------------------*/
1214     auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
1215             int dimM_reg_block, int current_best) {
1216         int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
1217         return (dimM_reg_block >= 1)
1218                 && (dimM_reg_block <= max_dimM_reg_block )
1219                 && (dimM_reg_block > current_best);
1220     };
1221     jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp,
1222         jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
1223
1224     /* ----------- dimN reg block ---------------------*/
1225
1226     auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
1227             int dimN_reg_block, int current_best) {
1228         return jcp.kernel_kind == embd_bcast
1229             ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
1230             : dimN_reg_block >= 1
1231               && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block)
1232                  < jcp.nb_reg
1233               && dimN_reg_block > current_best;
1234     };
1235     jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp,
1236         jcp.dimN, 1, test_cond_dimN_reg_block);
1237 }
1238
1239 status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
1240     if (jcp.ver != ver_avx512_core)
1241         return status::unimplemented;
1242
1243     jcp.kernel_kind = embd_bcast;
1244
1245     set_kernel_dims_reg_block(jcp);
1246
1247     /*-------------- L2 blocking for dimN block ---------*/
1248
1249     auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
1250         int dimN_block, int current_best) {
1251         return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
1252             && (dimN_block > current_best)
1253             && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
1254             >= 1.5 * mkldnn_get_max_threads());
1255     };
1256
1257     jcp.dimN_block = get_divisor_satisfying_cond(
1258             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
1259     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
1260
1261     if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
1262         && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) {
1263
1264         /* ------------------- L1 blocking for GEMM --------------*/
1265         /* -------------------- Choose dimK block ----------------*/
1266
1267         auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
1268                 int dimK_block, int current_best) {
1269             return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
1270                 && (dimK_block > current_best);
1271         };
1272
1273         jcp.dimK_block = get_divisor_satisfying_cond(
1274                 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
1275
1276         if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
1277             jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
1278
1279             /* -------------- Choose dimM block -------------------*/
1280             auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
1281                     int dimM_block, int current_best) {
1282                 return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block,
1283                     0.2, 0.5) && (dimM_block > current_best);
1284             };
1285
1286             jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1287                 jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
1288                 test_cond_dimM_block);
1289             jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
1290                 / jcp.dimM_simd_block;
1291
1292             jcp.sched_policy = WSCHED_DATA_W_SGD;
1293             return status::success;
1294         }
1295
1296     }
1297     return status::unimplemented;
1298 }
1299
1300 void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {
1301
1302     set_kernel_dims_reg_block(jcp);
1303
1304     //********************* Choosing dimK_block **********************//
1305     auto test_cond1_dimK_block = [](
1306             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1307         return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
1308                        1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
1309                 && (dimK_block > current_best);
1310     };
1311
1312     auto test_cond1_bis_dimK_block = [](
1313             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1314         return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
1315                    jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
1316                    jcp.dimM_simd_block, .9f)
1317                 && (dimK_block > current_best);
1318     };
1319
1320     jcp.dimK_block = get_divisor_satisfying_cond(
1321             jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
1322     // If we are not able to use streams, we fall back to condition [1]
1323     if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1324         jcp.dimK_block = get_divisor_satisfying_cond(
1325                 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
1326     jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
1327
1328     //********************* Choosing dimM_block **********************//
1329     auto test_cond1_dimM_block = [](
1330             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1331         return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
1332                    jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1333                    jcp.dimM_simd_block, .5f)
1334                 && (dimM_block > current_best);
1335     };
1336
1337     auto test_cond1_bis_dimM_block = [](
1338             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1339         return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
1340                    jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1341                    jcp.dimM_simd_block, .3f)
1342                 && (dimM_block > current_best);
1343     };
1344
1345     if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1346         jcp.dimM_block = get_divisor_satisfying_cond(
1347                 jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1348                 test_cond1_dimM_block);
1349     else
1350         jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1351                 jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1352                 test_cond1_bis_dimM_block);
1353     jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block
1354                         * jcp.dimM_reg_block);
1355
1356     //******************* Choosing dimN_block *******************//
1357     auto test_cond2_dimN_block = [](
1358             jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
1359         return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
1360                        jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
1361                        jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
1362                 && (dimN_block > current_best);
1363     };
1364
1365     jcp.dimN_block = get_divisor_satisfying_cond(
1366             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
1367     jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
1368 }
1369
1370 status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
1371
1372     jcp.kernel_kind = expl_bcast;
1373     set_kernel_blocking_DATA_W_S_G_D(jcp);
1374     if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
1375         jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK,
1376         .1f, .35f))) {
1377         jcp.kernel_kind = embd_bcast;
1378         set_kernel_blocking_DATA_W_S_G_D(jcp);
1379     }
1380     jcp.sched_policy = WSCHED_DATA_W_S_G_D;
1381     return status::success;
1382 }
1383
1384 status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel(
1385         jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
1386 {
1387     jcp.nb_reg = 32;
1388     jcp.dimN = dimN;
1389     jcp.dimK = dimK;
1390     jcp.dimM = dimM;
1391     jcp.sched_policy = WSCHED_INVALID;
1392
1393     jcp.dimK_reg_block = 16;
1394     jcp.dimM_simd_block = 16;
1395
1396     if (jcp.kernel_kind == embd_bcast) {
1397         jcp.dimM_reg_block = 1;
1398     }
1399
1400     if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success))
1401         set_wsched_DATA_W_S_G_D_avx512_core(jcp);
1402
1403     assert(jcp.sched_policy != WSCHED_INVALID);
1404     return status::success;
1405 }
1406
1407 bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok(
1408         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1409     const auto &p = attr.post_ops_;
1410
1411     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
1412     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1413
1414     switch (p.len_) {
1415     case 0: return true; // no post_ops
1416     case 1: return is_relu(0) || is_sum(0); // relu or sum
1417     case 2: return (is_sum(0) && is_relu(1))
1418                       || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
1419     case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
1420     default: return false;
1421     }
1422
1423     return false;
1424 }
1425
/* Initializes jcp for the forward Winograd kernel: runs the common
 * checks, derives tile counts, records post-ops (bias/eltwise/sum),
 * runs the blocking heuristics, maps the generic GEMM dims back onto
 * ic/oc names, and (for inference) rewrites the weights primitive
 * descriptor to the pre-transformed wino format. */
status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
        const cpu_memory_t::pd_t &dst_pd, const primitive_attr_t &attr) {

    status_t st = init_conf_common(jcp, cd,
                        *src_pd.desc(), *weights_pd.desc(), *dst_pd.desc());

    if (st != status::success)
        return st;

    // Winograd specific initialization
    // Number of 4x4 output tiles in each spatial dim (rounded up).
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    // Record the fused post-ops: an eltwise in the first two positions,
    // a sum anywhere, and an eltwise after position 0 (post-sum relu).
    const auto &p = attr.post_ops_;
    const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise)
        jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
    jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;

    // Forward GEMM mapping: M <- oc, N <- ntiles, K <- ic.
    status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);

    jcp.ic_simd_block = jcp.dimK_reg_block;
    jcp.ic_block = jcp.dimK_block;
    jcp.nb_ic = jcp.dimK_nb_block;
    jcp.oc_simd_block = jcp.dimM_simd_block;
    jcp.oc_block = jcp.dimM_block;
    jcp.oc_reg_block = jcp.dimM_reg_block;
    jcp.ic_reg_block = 1;
    jcp.nb_oc = jcp.dimM_nb_block;
    jcp.tile_block_ur = jcp.dimN_reg_block;
    jcp.nb_tile_block_ur = jcp.dimN_block;
    jcp.tile_block = jcp.dimN_nb_block;

    /* re-create weights primitive descriptor
    and set weights wino_blocking */
    if (cd.prop_kind == mkldnn_forward_inference) {
        memory_desc_t expect_wei_md = *weights_pd.desc();

        // Describe the pre-transformed weights layout the kernel expects.
        expect_wei_md.format = mkldnn_wino_fmt;
        expect_wei_md.data_type = data_type::f32;
        mkldnn_wino_desc_t &wd = expect_wei_md.layout_desc.wino_desc;
        wd.wino_format = mkldnn_wino_wei_OBaaIBOIio;
        wd.r = 3;     // kernel size
        wd.alpha = 6; // transform tile size (4 + 3 - 1)

        wd.ic = jcp.ic;
        wd.oc = jcp.oc;
        wd.ic_block = jcp.dimK_reg_block;
        wd.oc_block = jcp.dimM_simd_block;
        wd.ic2_block = jcp.dimK_block;
        wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block;
        // Full transformed-weights buffer: alpha^2 matrices of ic x oc.
        size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc;
        wd.size = max_size;
        wd.adj_scale = 1.f;

        // Adopt the wino layout when the user left the format to "any";
        // otherwise the supplied descriptor must already match it.
        cpu_memory_t::pd_t new_weights_pd(
            weights_pd.engine(), &expect_wei_md);
        if (weights_pd.desc()->format == memory_format::any)
            weights_pd = new_weights_pd;
        if (!weights_pd.is_equal(&new_weights_pd))
            return status::unimplemented;
    }

    return res;
}
1502
1503 status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf(
1504         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1505         const memory_desc_wrapper &diff_src_d,
1506         const memory_desc_wrapper &weights_d,
1507         const memory_desc_wrapper &diff_dst_d)
1508 {
1509     status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
1510
1511     if (st != status::success)
1512         return st;
1513
1514     jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
1515     jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
1516     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1517
1518     status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
1519
1520     jcp.oc_simd_block = jcp.dimK_reg_block;
1521     jcp.oc_block = jcp.dimK_block;
1522     jcp.nb_oc = jcp.dimK_nb_block;
1523     jcp.ic_simd_block = jcp.dimM_simd_block;
1524     jcp.ic_block = jcp.dimM_block;
1525     jcp.ic_reg_block = jcp.dimM_reg_block;
1526     jcp.oc_reg_block = 1;
1527     jcp.nb_ic = jcp.dimM_nb_block;
1528     jcp.tile_block_ur = jcp.dimN_reg_block;
1529     jcp.nb_tile_block_ur = jcp.dimN_block;
1530     jcp.tile_block = jcp.dimN_nb_block;
1531
1532     return res;
1533 }
1534
/* Generates JIT code for the source-side Winograd input transform of the
 * backward-by-weights pass: each 4x4 spatial tile of src is loaded (with
 * padding handled via masked loads that zero out-of-bounds lanes),
 * expanded to a 6x6 (alpha x alpha) transformed tile, and written into
 * the transformed-input buffer.
 * Two drivers are emitted depending on jcp.sched_policy:
 *   - WSCHED_WEI_SDGtWo: streams a fixed number of tiles starting at the
 *     (ti, tj) position from the call structure, wrapping across the tile
 *     grid and the minibatch; terminates purely on the tile counter;
 *   - otherwise: walks all tiles of one image, resuming at the tile_count
 *     passed by the driver and hopping to the next tile-block region when
 *     the current one fills up. */
void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
src_transform_generate() {
    /* number of scalar coefficients of the transform matrix G; each is
     * broadcast into its own zmm register by init_G() below */
    constexpr int G_size = 9;
    /* right/bottom limits of the padded input, used for boundary masks */
    const size_t ifwp = jcp.iw + jcp.l_pad;
    const size_t ifhp = jcp.ih + jcp.t_pad;

    /* Register map (alpha == 6):
     *   zmm0 ..zmm8  : broadcast G coefficients
     *   zmm9 ..zmm14 : I, the loaded input tile rows
     *   zmm15..zmm20 : T, intermediate (first-pass) transform
     *   zmm21..zmm26 : t, temporaries of the 1D transform
     *   zmm31        : scratch for the masked src loads (load_src only) */
    auto zmm_G = [=](int i) {
        return Xbyak::Zmm(i);
    };
    auto zmm_I = [=](int i) {
        return Xbyak::Zmm(G_size + i);
    };
    auto zmm_T = [=](int i) {
        return Xbyak::Zmm(G_size + alpha + i);
    };
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(G_size + 2 * alpha + i);
    };

    /* broadcast the 9 transform coefficients from the G table passed in
     * the call structure */
    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++) {
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
        }
    };

    /* Loads one alpha x alpha input tile into the scratch buffer M.
     * For each row/column an all-ones mask (0xffff) is cleared via cmov
     * when the coordinate falls outside the valid (padded) range, so
     * out-of-bounds elements are read as zeros through the opmask. */
    auto load_src = [=]() {
        mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
        xor_(reg_zero, reg_zero);

        mov(reg_ydim, reg_tj);
        shl(reg_ydim, 2); //tj * tile_size(=4)

        for (int j = 0; j < alpha; j++) {
            /* check if tile index is within physical spatial boundaries*/
            mov(reg_maskj, 0xffff);
            cmp(reg_ydim, jcp.t_pad);
            cmovl(reg_maskj, reg_zero);
            cmp(reg_ydim, ifhp);
            cmovge(reg_maskj, reg_zero);

            /*address offset for tile in src*/
            mov(reg_src_offset, reg_ydim);
            sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
            imul(reg_src_offset, reg_src_offset, jcp.iw);

            mov(reg_xdim, reg_ti);
            shl(reg_xdim, 2); // xdim = ti * tile_size

            add(reg_src_offset, reg_xdim);
            sub(reg_src_offset, jcp.l_pad);
            imul(reg_src_offset, reg_src_offset, simd_w * typesize);
            for (int i = 0; i < alpha; i++) {
                /* check if tile index is within physical spatial boundaries*/
                mov(reg_maski, 0xffff);
                cmp(reg_xdim, jcp.l_pad);
                cmovl(reg_maski, reg_zero);
                cmp(reg_xdim, ifwp);
                cmovge(reg_maski, reg_zero);
                and_(reg_maski, reg_maskj);

                /* masked load: lanes outside the image stay zero */
                Opmask kmask_src = Xbyak::Opmask(7);
                auto zmm_src = Xbyak::Zmm(31);
                kmovw(kmask_src, reg_maski_32);
                vpxord(zmm_src, zmm_src, zmm_src);
                vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
                vmovups(ptr[reg_I], zmm_src);

                add(reg_xdim, 1); //xdim = ti * tile_size + i
                add(reg_src_offset, simd_w * typesize);
                add(reg_I, simd_w * typesize);
            }
            add(reg_ydim, 1);
        }
    };

    /* dst = a * b + c without clobbering c (c is first copied to dst) */
    auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
        vmovups(dst, c);
        vfmadd231ps(dst, a, b);
    };

    /* Two-pass separable 6x6 transform of the loaded tile:
     * first pass works column-by-column (I -> T, spilled to the T
     * scratch buffer), second pass row-by-row (T -> output), writing
     * the result into the transformed-src buffer at reg_dst with an
     * alpha^2-strided layout. */
    auto trans_I_3x3_4x4 = [=]() {
        //Use 24 registers
        mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
        mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
        for (int i = 0; i < alpha; i++) {
            /* gather column i of the input tile */
            for (int j = 0; j < alpha; j++) {
                size_t I_off = (j * alpha + i) * simd_w * typesize;
                vmovups(zmm_I(j), ptr[reg_I + I_off]);
            }

            fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
            fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
            fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
            fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
            fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
            fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));

            fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
            fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));

            /* spill column result to the T scratch buffer */
            for (int j = 0; j < alpha; j++) {
                vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
                        zmm_T(j));
            }

        }

        for (int j = 0; j < alpha; j++) {
            /* gather row j of the intermediate transform */
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
            }

            fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
            fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
            fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
            fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
            fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
            fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));

            fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
            fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));

            /* scatter the final alpha x alpha tile; consecutive (j, i)
             * positions are one ic_block * tile-block stride apart */
            for (int i = 0; i < alpha; i++) {
                size_t dst_off = (j * alpha * jcp.ic_block
                    * jcp.nb_tile_block_ur * jcp.tile_block_ur
                    + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur)
                    * simd_w * typesize;
                vmovups(ptr[reg_dst + dst_off], zmm_I(i));
            }
        }
    };

    /* SDGtWo schedule: transform exactly nb_tile_block_ur * tile_block_ur
     * tiles starting at the (ti, tj) position supplied by the driver,
     * wrapping over itiles, jtiles and the minibatch. Note the mb loop
     * has no explicit bound: termination comes from the tile counter. */
    auto compute_transform_SDGtWo = [=]() {
        mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
        mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        xor_(reg_tile_count, reg_tile_count);
        Label loop_mb, loop_jtiles, loop_itiles, done;
        L(loop_mb);
        {
            L(loop_jtiles);
            {
                L(loop_itiles);
                {
                    load_src();

                    trans_I_3x3_4x4();

                    add(reg_tile_count, 1);
                    cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                    jge(done);

                    add(reg_dst, simd_w * typesize);
                    add(reg_ti, 1);
                    cmp(reg_ti, jcp.itiles);
                    jl(loop_itiles);
                }
                xor_(reg_ti, reg_ti);
                add(reg_tj, 1);
                cmp(reg_tj, jcp.jtiles);
                jl(loop_jtiles);
            }
            xor_(reg_tj, reg_tj);
            /* advance to the next image in the minibatch */
            add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
            jmp(loop_mb);
        }
        L(done);
    };

    /* Default schedule: transform every tile of one image. The output
     * pointer starts at the driver-supplied tile_count offset; when a
     * tile block fills up, rewind within the block and jump to the next
     * tile-block region. */
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        xor_(reg_ti, reg_ti);
        xor_(reg_tj, reg_tj);

        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
        imul(reg_temp, reg_tile_count, simd_w * typesize);
        add(reg_dst, reg_temp);

        Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
        L(loop_jtiles);

        {
            L(loop_itiles);
            {
                load_src();

                trans_I_3x3_4x4();

                add(reg_tile_count, 1);
                cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                jge(next_tile_block);
                add(reg_dst, simd_w * typesize);
                jmp(next_tile);

                L(next_tile_block);
                /* rewind to the block start, then hop one whole
                 * alpha^2 tile-block region forward */
                sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
                        * simd_w * typesize);
                size_t tblk_off = alpha * alpha * jcp.ic_block
                    * jcp.nb_tile_block_ur * jcp.tile_block_ur
                    * simd_w * typesize;
                add(reg_dst, tblk_off);
                xor_(reg_tile_count, reg_tile_count);

                L(next_tile);
                add(reg_ti, 1);
                cmp(reg_ti, jcp.itiles);
                jl(loop_itiles);
            }
            xor_(reg_ti, reg_ti);
            add(reg_tj, 1);
            cmp(reg_tj, jcp.jtiles);
            jl(loop_jtiles);
        }
    };

    preamble();
    init_G();
    if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
        compute_transform_SDGtWo();
    else
        compute_transform();
    postamble();
}
1769
/* Generates JIT code for the diff_dst-side Winograd transform of the
 * backward-by-weights pass: each 4x4 tile of diff_dst is loaded with
 * boundary masking, transformed to a 6x6 (alpha x alpha) tile, and
 * scattered into the transformed buffer. When with_bias is set, the
 * masked diff_dst values are also accumulated into the bias buffer.
 * As with src_transform_generate, a streaming (SDGtWo) and a per-image
 * driver loop are emitted depending on jcp.sched_policy. */
void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
diff_dst_transform_generate(bool with_bias) {

    constexpr int G_size = 8;
    /* NOTE: the index is intentionally ignored — every G coefficient is
     * broadcast into the single scratch register zmm31 immediately before
     * its one use in trans_W_3x3_4x4. zmm31 also aliases zmm_bias, which
     * is safe because load_src() flushes the bias accumulator back to
     * memory before the transform runs. */
    auto zmm_G = [](int i) {
        return Xbyak::Zmm(31);
    };

    /* the 4x4 loaded diff_dst tile occupies zmm8..zmm23 */
    auto zmm_src = [=](int j, int i) {
        return Xbyak::Zmm(G_size + j * 4 + i);
    };

    /* bias accumulator; shares zmm31 with the G scratch (see above) */
    auto zmm_bias = Xbyak::Zmm(31);

    /* Loads one tile_size x tile_size diff_dst tile into registers,
     * zeroing out-of-image lanes through an opmask built with cmov.
     * Optionally accumulates the masked values into the bias register,
     * which is reloaded at entry and stored back at exit. */
    auto load_src = [=]() {
        if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
        mov(reg_ydim, reg_tj);
        shl(reg_ydim, 2); //tj * tile_size(=4)
        for (int j = 0; j < tile_size; j++) {
            /* check if tile index is within physical spatial boundaries*/
            mov(reg_maskj, 0xffff);
            cmp(reg_ydim, jcp.oh);
            cmovge(reg_maskj, reg_zero);

            /*address offset for tile in src*/
            mov(reg_src_offset, reg_ydim);
            imul(reg_src_offset, reg_src_offset, jcp.ow);

            mov(reg_xdim, reg_ti);
            shl(reg_xdim, 2); // xdim = ti * tile_size

            add(reg_src_offset, reg_xdim);
            imul(reg_src_offset, reg_src_offset, simd_w * typesize);
            for (int i = 0; i < tile_size; i++) {
                /* check if tile index is within physical spatial boundaries*/
                mov(reg_maski, 0xffff);
                cmp(reg_xdim, jcp.ow);
                cmovge(reg_maski, reg_zero);
                and_(reg_maski, reg_maskj);

                Opmask kmask_src = Xbyak::Opmask(7);
                kmovw(kmask_src, reg_maski_32);
                vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
                vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]);
                if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias,
                        ptr[reg_src + reg_src_offset]);

                add(reg_xdim, 1); //xdim = ti * tile_size + i
                add(reg_src_offset, simd_w * typesize);
            }
            add(reg_ydim, 1);
        }
        if(with_bias) vmovups(ptr[reg_bias], zmm_bias);
    };

    /* transform temporaries: zmm24..zmm30 */
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(G_size + 16 + i);
    };

    /* intermediate tile T in zmm0..zmm23; rows j >= 2 deliberately reuse
     * the zmm_src registers — each src register is overwritten only after
     * its last read in the column pass below */
    auto zmm_T = [=](int j, int i) {
        return Xbyak::Zmm(j * 4 + i);
    };

    /* Output store helper. For the SDGtWo policy the destination buffer
     * is reused shortly after (presumably cache-resident), so a regular
     * store is used; otherwise a non-temporal streaming store avoids
     * polluting the cache — TODO confirm the intent. */
    auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
        if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
            vmovups(ptr[reg_dst + dst_off], a);
        else
            vmovntps(ptr[reg_dst + dst_off], a);
    };

    /* Two-pass separable transform of the 4x4 tile to 6x6: first pass
     * per column (src -> T), second pass per row of T, storing the six
     * outputs alpha_offset apart in the transformed buffer. The G
     * coefficients are re-broadcast from memory into zmm31 before each
     * use. */
    auto trans_W_3x3_4x4 = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i < tile_size; i++) {
            vbroadcastss(zmm_G(0), ptr[reg_G]);
            vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));

            vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
            vmovups(zmm_t(1), zmm_t(0));
            vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));

            vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
            vmovups(zmm_t(2), zmm_t(0));
            vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));

            vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
            vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));

            vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
            vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));

            vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
            vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));

            vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
            vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));

            vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
            vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
            vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
            vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
            vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
            vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
            /* zmm_T(5, i) is the same register as zmm_src(3, i), so this
             * move is effectively a no-op that documents the aliasing */
            vmovups(zmm_T(5, i), zmm_src(3, i));
        }

        for (int j = 0; j < alpha; j++) {
            vbroadcastss(zmm_G(0), ptr[reg_G]);
            vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));

            vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
            vmovups(zmm_t(1), zmm_t(0));
            vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));

            vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
            vmovups(zmm_t(2), zmm_t(0));
            vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));

            vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
            vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));

            vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
            vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));

            vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
            vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));

            vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
            vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));

            vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
            vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
            vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
            vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
            vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
            vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
            vmovups(zmm_t(3), zmm_T(j, 3));

            /* stride between consecutive (j, i) planes of the
             * transformed buffer */
            int alpha_offset = (jcp.oc / jcp.nb_oc)
                * (jcp.ntiles / jcp.tile_block) * typesize;
            int dst_off = j * alpha * alpha_offset;
            movps(reg_dst, dst_off, zmm_t(0));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(5));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(1));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(6));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(2));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(3));
        }

    };
    /* SDGtWo schedule: for each of the oc_reg_block channel lanes,
     * transform exactly nb_tile_block_ur * tile_block_ur tiles starting
     * at (ti, tj), wrapping over the tile grid and the minibatch; the mb
     * loop is bounded solely by the tile counter. */
    auto compute_transform_SDGtWo = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);

        xor_(reg_zero, reg_zero);
        xor_(reg_oc_ur, reg_oc_ur);
        Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;

        L(loop_oc_ur);
        {
            mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
            mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
            xor_(reg_tile_count, reg_tile_count);
            L(loop_mb);
            {
                L(loop_jtiles);
                {
                    L(loop_itiles);
                    {
                        load_src();

                        trans_W_3x3_4x4();

                        add(reg_tile_count, 1);
                        cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                        jge(tiles_done);

                        add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
                        add(reg_ti, 1);
                        cmp(reg_ti, jcp.itiles);
                        jl(loop_itiles);
                    }
                    xor_(reg_ti, reg_ti);
                    add(reg_tj, 1);
                    cmp(reg_tj, jcp.jtiles);
                    jl(loop_jtiles);
                }
                xor_(reg_tj, reg_tj);
                /* advance to the next image in the minibatch */
                add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
                jmp(loop_mb);
            }

            L(tiles_done);
            /* restart src/dst for the next channel lane of the block */
            mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
            add(reg_dst, simd_w * typesize);
            mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
            add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);

            if (with_bias) add(reg_bias, simd_w * typesize);
            add(reg_oc_ur, 1);
            cmp(reg_oc_ur, jcp.oc_reg_block);
            jl(loop_oc_ur);
        }
    };

    /* Default schedule: for each channel lane, transform all tiles of
     * one image, resuming at the driver-supplied tile_count and hopping
     * to the next tile-block region when the current one fills up. */
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);

        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
        imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
        add(reg_dst, reg_temp);

        xor_(reg_zero, reg_zero);
        xor_(reg_oc_ur, reg_oc_ur);
        Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile;

        L(loop_oc_ur);
        {
            xor_(reg_ti, reg_ti);
            xor_(reg_tj, reg_tj);

            L(loop_jtiles);
            {
                L(loop_itiles);
                {
                    load_src();

                    trans_W_3x3_4x4();

                    add(reg_tile_count, 1);
                    cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                    jge(next_tile_block);
                    add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
                    jmp(next_tile);

                    L(next_tile_block);
                    /* rewind within the block, then hop one whole
                     * alpha^2 tile-block region forward */
                    sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
                            * jcp.oc_reg_block * simd_w * typesize);
                    int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc)
                        * (jcp.ntiles/jcp.tile_block) * typesize;
                    add(reg_dst, tblk_off);
                    xor_(reg_tile_count, reg_tile_count);

                    L(next_tile);
                    add(reg_ti, 1);
                    cmp(reg_ti, jcp.itiles);
                    jl(loop_itiles);
                }
                xor_(reg_ti, reg_ti);
                add(reg_tj, 1);
                cmp(reg_tj, jcp.jtiles);
                jl(loop_jtiles);
            }

            /* restart src/dst for the next channel lane of the block */
            mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
            mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
            imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, reg_temp);
            add(reg_dst, simd_w * typesize);
            mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
            add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);

            if (with_bias) add(reg_bias, simd_w * typesize);
            add(reg_oc_ur, 1);
            cmp(reg_oc_ur, jcp.oc_reg_block);
            jl(loop_oc_ur);
        }
    };

    preamble();
    if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
        compute_transform_SDGtWo();
    } else {
        compute_transform();
    }
    postamble();
}
2055
/* Generates JIT code for the output transform of the backward-by-weights
 * pass: reduces each accumulated alpha x alpha (6x6) transformed-weights
 * tile back to the kh x kw (3x3) spatial weights, one simd lane of input
 * channels per iteration.
 * When first_tile is set the destination is overwritten; otherwise the
 * result is added to the values already in dst (read-modify-write). */
void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
diff_weights_transform_generate(bool first_tile) {
    /* number of scalar transform coefficients, broadcast to zmm0..zmm3 */
    int G_size = 4;

    auto zmm_G = [](int i) {
        return Xbyak::Zmm(i);
    };

    /* broadcast the 4 transform coefficients from the G table */
    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i  < G_size; i++)
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
    };

    /* one alpha-row of the transformed tile: zmm4..zmm9 */
    auto zmm_src = [=](int i) {
        return Xbyak::Zmm(G_size + i);
    };

    /* load column i of the alpha x alpha tile; consecutive (j, i)
     * positions are one alpha_offset apart in the source buffer */
    auto load_src = [=](int i) {
        for (int j = 0; j < alpha; j++) {
            size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block
                * jcp.ic_block * simd_w * simd_w * typesize;
            size_t src_off = (j * alpha + i) * alpha_offset;
            vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
        }
    };

    /* temporaries: zmm10..zmm12 */
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(G_size + 6 + i);
    };

    /* intermediate kh x alpha tile T: zmm13..zmm30 */
    auto zmm_T = [=](int j, int i) {
        return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
    };

    /* final row of spatial weights; deliberately reuses the zmm_src
     * registers (zmm4..), which are free by the time dst is computed */
    auto zmm_dst = [=](int i) {
        return Xbyak::Zmm(G_size + i);
    };

    auto zmm_temp = Xbyak::Zmm(31);

    /* store one kh-row of the 3x3 result with non-temporal stores;
     * accumulate into the existing dst values unless this is the
     * first tile */
    auto store_dst = [=](int j) {
        for (int i = 0; i < jcp.kw; i++) {
            size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;

            if (!first_tile) {
                vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
                vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
            }
            vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
        }
    };

    /* Two-pass separable 6x6 -> 3x3 reduction, repeated once per simd
     * lane of input channels (simd_w iterations): first pass per column
     * (src -> T), second pass per kh-row of T producing and storing one
     * row of spatial weights. */
    auto compute_transform = [=] () {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);

        xor_(reg_ic_simd, reg_ic_simd);
        Label loop_ic_simd;
        L(loop_ic_simd);
        {
            for (int i = 0; i < alpha; i++) {
                load_src(i);

                vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
                vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
                vmovups(zmm_t(2), zmm_src(5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
                vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
                vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
                vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
                vsubps(zmm_temp, zmm_src(3), zmm_src(4));
                vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
                vmovups(zmm_T(2, i), zmm_t(2));
                vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
            }

            for (int j = 0; j < jcp.kh; j++) {
                vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
                vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
                vmovups(zmm_t(2), zmm_T(j, 5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
                vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
                vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
                vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
                vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
                vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
                vmovups(zmm_dst(2), zmm_t(2));
                vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));

                store_dst(j);
            }

            /* advance to the next input-channel simd lane */
            add(reg_src, jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, simd_w * typesize);
            add(reg_ic_simd, 1);
            cmp(reg_ic_simd, simd_w);
            jl(loop_ic_simd);
        }
    };
    preamble();
    /* enlarge the base offset used by EVEX_compress_addr so the large
     * src strides stay encodable with compressed displacements; the
     * register is callee-visible, hence the push/pop */
    push(reg_EVEX_max_8b_offt);
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
    init_G();
    compute_transform();
    pop(reg_EVEX_max_8b_offt);
    postamble();
}
2168
/* Generates the GEMM microkernel of the backward-by-weights pass:
 * C[M][N] (+)= A[M][K] * B[K][N] over the blocked dimensions, where
 * M maps to output channels, N to input channels and K to tiles.
 * When is_first_tile is set C is overwritten; otherwise the product is
 * accumulated into the existing C values.
 *
 * Register map:
 *   zmm0                          : srcA vector load (also the C
 *                                   accumulation temporary in store_dstC)
 *   zmm1 .. zmm{dimN_bcast_ur}    : broadcast srcB scalars
 *   the rest                      : the C accumulator tile
 *                                   (dimM_reg_block x dimN_bcast_ur) */
void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate(
        bool is_first_tile)
{
    auto zmm_srcA = [=]() {
        return Xbyak::Zmm(0);
    };

    auto zmm_srcB = [=] (size_t N_ur){
        return Xbyak::Zmm(N_ur + 1);
    };

    /* broadcast dimN_bcast_ur scalars of row K_ur of B */
    auto broadcastB = [=](size_t K_ur) {
        for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
            size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast)
                * sizeof(float);
            vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
        }
    };

    /* load one simd column of A at (K_ur, M_ur) */
    auto load_srcA = [=] (size_t K_ur, int M_ur) {
        size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                        + M_ur * jcp.dimM_simd_block) * sizeof(float);
        vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
    };

    /* C accumulator register for the (M_reg_ur, N_bcast) element */
    auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){
        size_t idx = 1 // zmm_srcA
            + jcp.dimN_bcast_ur // zmm_srcB
            + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };
    /* zero all C accumulators */
    auto prepare_accumm = [=](){
        for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
            for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
                Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
                vpxord(zmm, zmm, zmm);
            }
        }
    };

    /* write C back; for non-first tiles accumulate into memory using
     * zmm0 (free at this point) as the temporary */
    auto store_dstC = [=](){
        /******** Write C back to memory *******/
        for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
            for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
                Zmm zmm = zmm_dstC(M_reg, N_ur);
                size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                             + M_reg * jcp.dimM_simd_block) * sizeof(float);
                if (!is_first_tile) {
                    vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
                    vaddps(zmm, zmm, Xbyak::Zmm(0));
                }
                vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
            }
        }
    };

    /* M -> N -> N_bcast -> K loop nest; after each loop level the srcA /
     * srcB pointers are rewound or advanced so the next level resumes at
     * the right position */
    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;

        mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
        L(dimM_block_loop);
        { /************* OC_block (M) loop ***********/
            mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
            L(dimN_block_loop);
            { /*************** IC_block (N) loop *********/

                mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur);
                L(dimN_bcast_ur);
                {
                    prepare_accumm();

                    mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                    L(dimK_block_loop);
                    {
                     /************* nb_tile_ur(K) loop ********/
                        for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {

                            broadcastB(K_ur);

                            for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
                                load_srcA(K_ur, M_reg_ur);
                                for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) {
                                    vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(),
                                            zmm_srcB(N_bcast));
                                }
                            }
                        }
                        add(reg_srcA, jcp.dimK_reg_block
                                      * jcp.dimM_reg_block * jcp.dimM_simd_block
                                      * sizeof(float));
                        add(reg_srcB, jcp.dimK_reg_block
                                      * jcp.dimN_reg_block
                                      * sizeof(float));
                        sub(reg_dimK_block_loop_cnt, 1);
                        jnz(dimK_block_loop);
                    }

                    store_dstC();

                    /* rewind A and B over the K range just consumed,
                     * then step B/C to the next N_bcast sub-column */
                    sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
                                  * jcp.dimM_reg_block * jcp.dimM_simd_block
                                  * sizeof(float));
                    sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block
                                  * jcp.dimN_reg_block
                                  * sizeof(float));
                    add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
                    add(reg_dstC, jcp.dimN_bcast_ur
                            * jcp.dimM_reg_block * jcp.dimM_simd_block
                            * sizeof(float));
                    sub(reg_nb_dimN_bcast_ur, 1);
                    jnz(dimN_bcast_ur);
                }

                /* advance B to the next N block */
                sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
                add(reg_srcB, jcp.dimK_block
                        * jcp.dimK_reg_block
                        * jcp.dimN_reg_block * sizeof(float));
                sub(reg_dimN_block_loop_cnt, 1);
                jnz(dimN_block_loop);
            }

            /* rewind B to its start and advance A to the next M block */
            sub(reg_srcB, jcp.dimN_block
                          * jcp.dimK_block * jcp.dimK_reg_block
                          * jcp.dimN_reg_block
                          * sizeof(float));
            add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
                          * jcp.dimM_reg_block * jcp.dimM_simd_block
                          * sizeof(float));
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    preamble();

    inner_loops();

    /* Postamble */
    postamble();
    /* NOTE(review): postamble() appears to already emit a ret, which
     * would make this second ret unreachable — confirm against
     * jit_generator::postamble before removing */
    ret();
}
2312
2313 namespace {
2314
2315 void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
2316 /*M params*/
2317     jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
2318         / jcp.dimM_simd_block;
2319     jcp.oc_reg_block = jcp.dimM_reg_block;
2320     jcp.oc_block = jcp.dimM_block;
2321     jcp.nb_oc = jcp.dimM_nb_block;
2322     /*N params*/
2323     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
2324     jcp.ic_block = jcp.dimN_block;
2325     jcp.nb_ic = jcp.dimN_nb_block;
2326
2327     /*K params*/
2328     jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
2329     jcp.tile_block_ur = jcp.dimK_reg_block;
2330     jcp.nb_tile_block_ur = jcp.dimK_block;
2331     jcp.tile_block = jcp.dimK_nb_block;
2332 }
2333
/* Attempt the WSCHED_WEI_SDGtWo weight-update schedule: pick register/cache
 * blocking for the M (oc), N (ic), and K (tiles) GEMM dimensions such that
 * per-iteration working sets fit the L1/L2 budgets below. On success fills
 * jcp's dim* blocking, sched_policy, and nthr, and returns status::success;
 * returns status::unimplemented when the strategy is infeasible or no
 * acceptable blocking exists. */
status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {

    size_t K_blk_ur, N_blk, M_blk;
    /* Is this strategy feasible at all? Require the combined M and V
     * transform buffers, split across threads, to exceed twice the L2 size
     * (i.e. the problem is large enough to be worth this schedule) and at
     * least one dimK unit per thread. */
    auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
        size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
        size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
        size_t nthreads = mkldnn_get_max_threads();
        return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
            && (jcp.dimK / nthreads >= 1.0);
    };

    /* L1/L2 fitness test for a candidate dimK micro-block. The unused
     * max_block parameter keeps the signature compatible with the
     * three-argument predicate expected by get_divisor_satisfying_cond. */
    auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
            int max_block=1) {
        size_t L1_block_M  = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
        size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
        size_t nthreads = mkldnn_get_max_threads();
        bool load_balance=true;
        /* When dimK divides evenly among threads, also require the number of
         * dimK micro-blocks to divide evenly, so no thread is left idle. */
        if (!(jcp.dimK % nthreads)) {
            load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
        }
        /* Keep the per-iteration M+N slices within 10%..50% of L1 and the
         * M transform block within L2. */
        return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
            && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
            && load_balance
            && (M_L2_block < L2_cache_size);
    };

    /* Acceptable GEMM unroll along K: between 2 and 8. The unused third
     * parameter again only matches the divisor-test signature. */
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
            int useless_arg=0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    /* L2 fitness for the full (M, V, U) working set of one block:
     * between 10% and 120% of the L2 size. */
    auto blocking_ok =  [&](){
        size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
                          * K_blk_ur * sizeof(float);
        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
                          * K_blk_ur * sizeof(float);
        size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
                          * N_blk * jcp.dimN_reg_block * sizeof(float);
        size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
        /*Replace 2.375 with L2+L3 cache size*/
        return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size);
    };

    if (test_MV_large_enough(jcp)) {
        /* Use a 2-wide M register block when the oc SIMD chunks split evenly. */
        if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
            jcp.dimM_reg_block = 2;
        } else {
            jcp.dimM_reg_block = 1;
        }
        jcp.dimM_simd_block = jcp.oc_simd_block;
        jcp.dimN_reg_block = jcp.ic_simd_block;
        jcp.dimN_bcast_ur = 8;
        /*dimK_block and dimK_ur*/
        /* Largest divisor of dimK that satisfies the L1 constraints becomes
         * the upper bound of the K micro-block search. */
        size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1);

        jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block;
        jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block;
        /* Exhaustive descent over (K micro-block, N block, M block)
         * candidates; the first combination passing all cache tests wins. */
        for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
            if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
                for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                    if (!(jcp.dimN_block % N_blk)) {
                        for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                            if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
                                /* Split the K micro-block into an unroll
                                 * factor (2..8) and an outer block count. */
                                jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
                                if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented;
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                jcp.sched_policy = WSCHED_WEI_SDGtWo;
                                set_jcp_WEI_params(jcp);
                                /* No more threads than K-space tile blocks. */
                                jcp.nthr = nstl::min(mkldnn_get_max_threads(),
                                        jcp.tile_block);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    return status::unimplemented;
}
2418
/* Fallback weight-update schedule WSCHED_WEI_S_D_Giot_W. Searches for a
 * cache-friendly (K, N, M) blocking like set_wsched_WEI_SDGtWo, but always
 * succeeds: if no candidate passes the tests, a degenerate blocking
 * (dimK_reg_block = dimK_block = 1) is installed. Always returns
 * status::success. */
status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
    /* Use a 2-wide M register block when the oc SIMD chunks split evenly. */
    if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
        jcp.dimM_reg_block = 2;
    } else {
        jcp.dimM_reg_block = 1;
    }
    jcp.dimN_bcast_ur = 8;
    jcp.dimN_reg_block = jcp.ic_simd_block;
    jcp.dimM_simd_block = jcp.oc_simd_block;
    jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
    jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
    /* Cache-occupancy windows: [C1, C1_max] of L1 and [C2, C2_max] of L2.
     * C1 = C2 = 0 makes the lower bounds trivially satisfied. */
    float C1 = 0.0, C2 = 0.0;
    float C1_max = 0.5, C2_max = 1.4;
    int N_blk, M_blk, K_blk_ur;

    /* Acceptable GEMM unroll along K: between 2 and 8. The unused third
     * parameter only matches the signature get_divisor_satisfying_cond
     * expects from its predicate. */
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
            int useless_arg=0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    /* Accept a (K_blk_ur, N_blk, M_blk) candidate when the L1 working set,
     * thread load balance, and L2 V-buffer occupancy are all within range. */
    auto blocking_ok = [&]() -> bool {
        size_t L1_block_M  = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
        bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
                     && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);

        size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block;
        size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block;
        size_t nb_K_blk = jcp.dimK / K_blk_ur;
        size_t nthreads = mkldnn_get_max_threads();
        /* Enough independent blocks to occupy every thread. */
        bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
        /* NOTE(review): this refinement is a tautology — the guard and the
         * refined condition test the same predicate (nb_K_blk % nthreads),
         * so the branch never changes load_balance. The sibling function
         * set_wsched_WEI_SDGtWo guards on jcp.dimK instead; presumably that
         * was the intent here too — confirm before changing behavior. */
        if (!(nb_K_blk % nthreads)) {
            load_balance = load_balance && (nb_K_blk % nthreads == 0);
        }

        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float);

        size_t L2_block = V_L2_block;
        /*Replace 2.375 with L2+L3 cache size*/
        bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size);
        return L1_cond && load_balance && L2_cond;
    };

    /* Descend over divisors of dimK, dimN_block, and dimM_block; the first
     * candidate accepted by blocking_ok wins. */
    for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
        if (jcp.dimK % K_blk_ur == 0) {
            for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                if (jcp.dimN_block % N_blk == 0) {
                    for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                        if (jcp.dimM_block % M_blk == 0) {
                            if (blocking_ok()) {
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                /* Split K_blk_ur into an unroll factor and an
                                 * outer block count. */
                                jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
                                set_jcp_WEI_params(jcp);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    /* No candidate passed: fall back to the trivial K blocking so this
     * schedule still succeeds. */
    jcp.dimK_reg_block = 1;
    jcp.dimK_block = 1;
    jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
    set_jcp_WEI_params(jcp);
    return status::success;
}
2489 } // namespace
2490 status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf(
2491         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
2492         const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
2493         const memory_desc_wrapper &diff_weights_d) {
2494     if (!mayiuse(avx512_core))
2495         return status::unimplemented;
2496     else
2497         jcp.ver = ver_avx512_core;
2498
2499     jcp.nthr = mkldnn_get_max_threads();
2500
2501     jcp.prop_kind = cd.prop_kind;
2502     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
2503     jcp.mb = src_d.dims()[0];
2504     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
2505     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2506     jcp.oc_without_padding = jcp.oc;
2507     jcp.ic = src_d.dims()[1] / jcp.ngroups;
2508     jcp.ih = src_d.dims()[2];
2509     jcp.iw = src_d.dims()[3];
2510     jcp.oh = diff_dst_d.dims()[2];
2511     jcp.ow = diff_dst_d.dims()[3];
2512     jcp.kh = diff_weights_d.dims()[with_groups + 2];
2513     jcp.kw = diff_weights_d.dims()[with_groups + 3];
2514     jcp.t_pad = cd.padding[0][0];
2515     jcp.l_pad = cd.padding[0][1];
2516     jcp.stride_h = cd.strides[0];
2517     jcp.stride_w = cd.strides[1];
2518     jcp.r_pad = nstl::max(
2519             0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
2520     jcp.b_pad = nstl::max(
2521             0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
2522     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
2523     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
2524     jcp.ohp = jcp.oh;
2525     jcp.owp = jcp.ow;
2526     jcp.with_bias = (cd.diff_bias_desc.format != memory_format::undef);
2527     jcp.dilate_h = cd.dilates[0];
2528     jcp.dilate_w = cd.dilates[1];
2529
2530     bool ok_to_pad_channels = jcp.ngroups == 1;
2531     if (ok_to_pad_channels) {
2532         jcp.oc = rnd_up(jcp.oc, simd_w);
2533         jcp.ic = rnd_up(jcp.ic, simd_w);
2534     }
2535
2536     // Winograd specific initialization
2537     jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
2538     jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
2539     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
2540
2541     // Winograd kernel works only for 3x3 convolution with stride 1
2542     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
2543                is_winograd_faster_than_direct(jcp)))
2544         return status::unimplemented;
2545
2546     if (jcp.ngroups != 1)
2547         return status::unimplemented;
2548     if ((jcp.kh != 3) || (jcp.kw != 3))
2549         return status::unimplemented;
2550     if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
2551         return status::unimplemented;
2552     if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
2553         return status::unimplemented;
2554     if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
2555         return status::unimplemented;
2556     if (src_d.format() != nChw16c)
2557         return status::unimplemented;
2558     if (diff_weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
2559         return status::unimplemented;
2560     if (diff_dst_d.format() != nChw16c)
2561         return status::unimplemented;
2562
2563     bool layout_consistency = true
2564         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
2565         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
2566         && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
2567         && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
2568     if (!layout_consistency) return status::unimplemented;
2569
2570     /******************Kernel blocking Parameters ***********/
2571     jcp.ic_simd_block = simd_w;
2572     jcp.oc_simd_block = simd_w;
2573
2574     jcp.dimK = jcp.ntiles;
2575     jcp.dimN = jcp.ic;
2576     jcp.dimM = jcp.oc;
2577     jcp.dimM_simd_block = jcp.oc_simd_block;
2578     jcp.dimN_reg_block = jcp.ic_simd_block;
2579     jcp.sched_policy = WSCHED_INVALID;
2580     status_t res = set_wsched_WEI_SDGtWo(jcp);
2581     if (res == status::unimplemented) {
2582         res = set_wsched_WEI_S_D_Giot_W(jcp);
2583         assert(res == status::success);
2584     }
2585     return res;
2586 }
2587 }
2588 }
2589 }
2590
2591 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s