Publishing R3
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_avx512_core_conv_winograd_kernel_f32.cpp
1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
19 #include "nstl.hpp"
20 #include "type_helpers.hpp"
21 #include "utils.hpp"
22 #include "cpu_memory.hpp"
23
24 #include <math.h>
25
26 #include "jit_avx512_core_conv_winograd_kernel_f32.hpp"
27
28 #define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
29
30 namespace mkldnn {
31 namespace impl {
32 namespace cpu {
33
34 namespace {
35
36 using namespace mkldnn::impl::utils;
37
// Cache capacities queried once at load time and used by the blocking
// heuristics below. NOTE(review): the second argument of get_cache_size()
// presumably selects per-core vs. shared capacity (true for L1/L2, false
// for LLC) — confirm against the get_cache_size() definition.
unsigned int L1_cache_size = get_cache_size(1, true);
unsigned int L2_cache_size = get_cache_size(2, true);
unsigned int LLC_data_size = get_cache_size(3, false);
41
42 // the test funtion takes jcp, the candidate and the current best.
43 // it  returns true if the new candidate is better
44 int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
45         int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
46 {
47     int best_divisor = default_best;
48     auto test_num
49             = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
50                   if (test(jcp, num, best_divisor)) {
51                       best_divisor = num;
52                   }
53               };
54
55     for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
56         if (number % divisor == 0) {
57             test_num(jcp, divisor);
58             test_num(jcp, number / divisor);
59         }
60     }
61
62     return best_divisor;
63 }
64
65 /* assumes 512 bits registers */
66 /* TODO: add support for strides */
67 /* TODO: handle the prefetch distance automatically */
/* Cache-level selector consumed by prefetcher_t below. */
typedef enum cache_t_ { L1, L2, L3 } cache_t;
69
70 template <typename data_t>
71 struct prefetcher_t {
72     prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
73             cache_t cache_type, size_t block_size, /* in number of elements*/
74             int nb_instructions_in_block, int fma_ipc)
75         : cg_(generator)
76         , reg_base_addr_(reg_base_addr)
77         , cache_type_(cache_type)
78         , cache_block_size_(block_size)
79     {
80         nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
81         prefetch_spread_
82                 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
83         prefetch_blk_
84                 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
85
86         /* assumption: when fetch in Li, data is already in L(i+1) */
87         int cache_latency;
88         switch (cache_type_) {
89         case L1: cache_latency = 14; break;
90         case L2: cache_latency = 250; break;
91         case L3: cache_latency = 250; break;
92         }
93
94         prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
95     }
96
97     void prefetch(int instruction_number)
98     {
99         if (instruction_number % prefetch_spread_ == 0) {
100             for (int i = 0; (i < prefetch_blk_)
101                     && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
102                     i++, prefetches_issued_++) {
103                 prefetch_inst_(cg_->EVEX_compress_addr(
104                         reg_base_addr_, (cache_block_size_ * prefetch_distance_)
105                                         * sizeof(data_t)
106                                 + (prefetches_issued_ * 64)));
107             }
108         }
109     }
110
111 private:
112     void prefetch_inst_(const Xbyak::Address &addr)
113     {
114         switch (cache_type_) {
115         case L1: cg_->prefetcht0(addr); break;
116         case L2: cg_->prefetcht1(addr); break;
117         case L3: cg_->prefetcht2(addr); break;
118         default:
119             break; // TODO: raise an exception or put an assert
120         }
121     }
122
123     jit_generator *cg_;
124     Xbyak::Reg64 reg_base_addr_;
125     cache_t cache_type_;
126     int cache_block_size_ = 0;
127     int nb_cache_lines_to_prefetch_ = 0;
128     int prefetches_issued_ = 0;
129     int prefetch_spread_ = 0;
130     int prefetch_blk_ = 0;
131     int prefetch_distance_ = 0;
132 };
133
134 // utilities to support kernel parameter selection
135 bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
136         int dimN_block, float C2_min, float C2_max) {
137     float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic)
138         * dimN_block * jcp.dimN_reg_block
139         + div_up(jcp.ic * jcp.oc,omp_get_max_threads())) * (float)sizeof(float);
140     float L2_lb = C2_min * L2_cache_size;
141     float L2_ub = C2_max * L2_cache_size;
142     return (block_size > L2_lb && block_size < L2_ub);
143 }
144
145 bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
146         int dimM_block, float C1_min, float C1_max) {
147     float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
148                              * jcp.dimK_reg_block * jcp.dimM_reg_block
149                      + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
150                      + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
151                      * (float)sizeof(float);
152     float L1_lb = C1_min * L1_cache_size;
153     float L1_ub = C1_max * L1_cache_size;
154     return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
155 }
156 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
157         int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
158 {
159     float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
160                         + dimM_block * dimK_block * dimK_reg_block
161                                 * dimM_simd_block * dimM_reg_block
162                         + dimK_block * dimN_reg_block * dimK_reg_block)
163             * (float)sizeof(float);
164     float rhs = C * L1_cache_size;
165     return (lhs < rhs);
166 }
167 bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
168         int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
169 {
170     float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
171             * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block)
172             * (float)sizeof(float);
173     float rhs = C * L1_cache_size;
174     return (lhs < rhs);
175 }
176 bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
177         int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
178         int dimM_simd_block, float C)
179 {
180     float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block
181                               * dimM_simd_block * dimM_reg_block
182                       + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
183                               * dimM_simd_block * dimM_reg_block
184                       + nb_dimN_reg_block * dimK_nb_block * dimK_block
185                               * dimN_reg_block * dimK_reg_block)
186             * (float)sizeof(float);
187     float rhs = C * L2_cache_size;
188     return (lhs < rhs);
189 }
190
191 bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
192         int dimN_block, int dimN_reg_block, int dimK, float C1, float C2)
193 {
194     float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
195         * (float)sizeof(float);
196     float B_size = dimN_block * dimN_reg_block * dimK
197         * (float)sizeof(float);
198     return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
199 }
200 }
201
202 using namespace mkldnn::impl::memory_format;
203 using namespace mkldnn::impl::utils;
204 using namespace Xbyak;
205
// Emits the inner GEMM micro-kernel of the Winograd data convolution.
// The generated code implements the reference loop nest below; register
// blocking over dimM_reg_block/dimK_reg_block/dimN_reg_block is fully
// unrolled at generation time.
void _jit_avx512_core_conv_winograd_data_kernel_f32::gemm_loop_generate()
{
    // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
    // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
    //      dimM_reg_block++) // unrolled
    //     for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
    //         for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
    //              dimK_reg_block++) // unrolled
    //             for (int tile =0; tile < jcp.dimN_reg_block; tile++)
    //                 C[dimM_block][dimM_reg_block][tile] +=
    //                 A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
    //                 * broadcast(B[dimK_block][tile][dimK_reg_block]);
    // Notes:
    // jcp.kernel_kind defines embedded or explicit broadcast
    // dimM_reg_block=1 for embedded bcast kernel

    // zmm0 always holds the current A vector.
    auto zmm_srcA = [=]() {
        return Xbyak::Zmm(0);
    };
    // zmm1..zmm[dimN_reg_block]: broadcast B values (expl_bcast kernel only).
    auto zmm_srcB = [=](int tile) {
        int idx = 1 + tile;
        assert(idx < 1 + jcp.dimN_reg_block);
        return Xbyak::Zmm(idx);
    };
    // Accumulator registers. For embd_bcast they occupy the same slots as
    // zmm_srcB (B is read via embedded memory broadcast instead); for
    // expl_bcast they start right after the B registers.
    auto zmm_dstC = [=](int dimM_reg_block, int tile) {
        int idx{0};
        if (jcp.kernel_kind == embd_bcast)
            idx = 1 + tile;
        else
            idx = 1 + jcp.dimN_reg_block
                  + dimM_reg_block * jcp.dimN_reg_block + tile;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };

    // Zero all accumulators.
    auto prepare_output = [=]() {
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
              dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block, tile);
                vpxord(zmm, zmm, zmm);
            }
        }
    };
    // Writes accumulators back to dstC. When beta != 0, the previous dstC
    // contents are added in first; streaming stores are used only when the
    // output will not be re-read soon (W_S_G_D policy, data larger than LLC).
    auto store_output = [=](bool output_is_aligned) {
        Label save;
        cmp(reg_is_beta_zero, 0);
        je(save, T_NEAR);

        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
              dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block,tile);
                int output_offset
                    = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
                vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
            }
        }

        L(save);
        for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
              dimM_reg_block++) {
            for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                Zmm zmm = zmm_dstC(dimM_reg_block,tile);
                int output_offset
                    = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;

                // In W_SGD, output will be reused.
                if (output_is_aligned
                    && jcp.dimK_nb_block == 1
                    && jcp.sched_policy == WSCHED_DATA_W_S_G_D
                    && (jcp.dimN * jcp.dimM * alpha * alpha
                        * sizeof(float) > 2 * LLC_data_size))
                    vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
                else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
            }
        }
    };

    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop;

        if (jcp.dimM_block > 1) {
            mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
            L(dimM_block_loop);
        }

        prepare_output();

        if (jcp.dimK_block > 1) {
            mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
            L(dimK_block_loop);
        }

        for (int dimK_reg_block = 0;
                dimK_reg_block < jcp.dimK_reg_block;
                dimK_reg_block ++) {

            // expl_bcast: preload all B values for this dimK_reg_block into
            // registers; the embd_bcast path reads them inside the fma below.
            if (jcp.kernel_kind == expl_bcast) {
                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    vbroadcastss(zmm_srcB(tile),
                        ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
                }
            }

            /* Performing the fmas */

            for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
                dimM_reg_block++) {

                vmovups(zmm_srcA(),
                    zword[reg_srcA
                            + jcp.dimK_reg_block * jcp.dimK_block * 64
                              * dimM_reg_block
                            + dimK_reg_block * 64]
                    );

                for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
                    if (jcp.kernel_kind == expl_bcast)
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                            zmm_srcB(tile));
                    else
                        vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
                            EVEX_compress_addr(reg_srcB,
                                64 * tile + dimK_reg_block * 4, true));
                }
            }
        }
        // Advance A and B to the next dimK block.
        add(reg_srcA, jcp.dimK_reg_block * 64);
        add(reg_srcB, jcp.dimN_reg_block * 64);
        if (jcp.dimK_block > 1) {
            sub(reg_dimK_block_loop_cnt, 1);
            jnz(dimK_block_loop);
        }

        // Runtime dispatch between aligned (possibly streaming) and
        // unaligned store paths based on dstC alignment.
        Label unaligned_store, end_store;
        test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
        jnz(unaligned_store, T_NEAR);
        store_output(true);
        jmp(end_store, T_NEAR);
        L(unaligned_store); {
            store_output(false);
        }
        L(end_store);

        if (jcp.dimM_block > 1) {
            // Rewind B, advance C (and A for expl_bcast) for the next
            // iteration of the dimM block loop.
            sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
            add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
            if (jcp.kernel_kind == expl_bcast) {
                add(reg_srcA,
                     (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64
                      * jcp.dimK_block);
            }
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    // register used to handle long fma encoding
    push(reg_EVEX_max_8b_offt);
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);

    /* kernel */
    inner_loops();

    /* Postamble */
    pop(reg_EVEX_max_8b_offt);
    uni_vzeroupper();
    ret();
}
377
// Emits the weight-transform kernel: takes the kh x kw (3x3) spatial
// weights and produces their alpha x alpha (6x6) Winograd-domain
// counterparts via G * F * G^T. Forward propagation copies the weights as
// laid out; backward rotates them 180 degrees spatially and transposes the
// 16x16 ic/oc sub-blocks.
void _jit_avx512_core_conv_winograd_data_kernel_f32
    ::weights_transform_data_ker_generate()
{
    bool is_fwd = one_of(jcp.prop_kind,
        mkldnn_forward_training, mkldnn_forward_inference);
    int kh = jcp.kh;
    int kw = jcp.kw;

    auto zmm_temp = Xbyak::Zmm(31);
    auto zmm_zero = Xbyak::Zmm(30);

    // Register maps. zmm_M/zmm_MT are used by the 16x16 transpose; zmm_G
    // (transform coefficients), zmm_F (weights), zmm_T and zmm_t
    // (intermediates) are used by the Winograd transform. The two sets
    // overlap — the transpose and the transform never run interleaved.
    auto zmm_M = [=](int i) {
        return Xbyak::Zmm(i);
    };
    auto zmm_MT = [=](int i) {
        return Xbyak::Zmm(i + simd_w);
    };

    auto zmm_G = [=](int i) {
        return Xbyak::Zmm(i);
    };
    auto zmm_F = [=](int i) {
        return Xbyak::Zmm(alpha + i);
    };
    auto zmm_T = [=](int i) {
        return Xbyak::Zmm(alpha + 3 + i);
    };
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(2 * alpha + 3 + i);
    };

    auto zmm_load = [=](int i) {
        return Xbyak::Zmm(i);
    };

    // Broadcast the alpha transform coefficients G[0..alpha) into registers
    // and zero zmm_zero.
    auto init_G = [=]() {
        mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < alpha; i++) {
            vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
        }
        vpxord(zmm_zero, zmm_zero, zmm_zero);
    };

    // In-register 16x16 float transpose (unpack low/high, then 128-bit lane
    // shuffles): reads 16 rows at [wreg_M], writes 16 rows at [wreg_MT].
    auto trans16x16 = [=]() {
        for (int i = 0; i < simd_w; i+=2 ) {
            vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
            vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]);
            vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1));
            vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1));
        }
        for (int i = 0; i < simd_w; i+=4 ) {
            vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2));
            vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2));
            vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3));
            vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3));
        }
        for (int i = 0; i < simd_w; i += 8) {
            vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
            vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88);
            vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88);
            vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88);
            vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd);
            vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd);
            vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd);
            vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd);
        }
        {
            int i = 0;
            int mask = 0x88;
            vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
            vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
            vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
            vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
            vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
            vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
            vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
            vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
            mask = 0xdd;
            vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
            vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
            vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
            vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
            vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
            vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
            vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
            vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
            vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
            vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
            vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
            vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
            vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
            vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
            vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
            vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
        }
    };

    // Stage the kh x kw weights into the M scratch buffer. Forward: plain
    // copy. Backward: spatial 180-degree rotation (indices (2-j, 2-i)) plus
    // a 16x16 transpose of each ic/oc sub-block.
    auto load_src = [=]() {
        mov(wreg_src, ptr[param1 + GET_OFF(src)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        for (int j = 0; j < kh; j++) {
            for (int i = 0; i < kw; i++) {
                if (is_fwd) {
                    for (int v1 = 0; v1 < simd_w; v1++) {
                        int offset_src = (j * kw * simd_w * simd_w
                            + i * simd_w * simd_w + v1 * simd_w) * typesize;
                        int offset_F = (j * kw * simd_w * simd_w
                            + i * simd_w * simd_w  + v1 * simd_w) * typesize;
                        vmovups(zmm_temp, ptr[wreg_src + offset_src]);
                        vmovups(ptr[wreg_F + offset_F], zmm_temp);
                    }
                } else {
                    int offset_src = ((2 - j) * kw * simd_w * simd_w
                        + (2 - i) * simd_w * simd_w) * typesize;
                    int offset_F = (j * kw * simd_w * simd_w
                        + i * simd_w * simd_w) * typesize;
                    lea(wreg_M, ptr[wreg_src + offset_src]);
                    lea(wreg_MT, ptr[wreg_F + offset_F]);
                    trans16x16();
                }
            }
        }
    };

    // Scatter the transformed weights from the Mw buffer into the blocked
    // destination layout using streaming stores (output is written once).
    auto store_dst = [=]() {
        mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        mov(wreg_dst_aux, wreg_dst);
        mov(wreg_Fw_aux, wreg_Fw);

        // Stride (in elements) between consecutive alpha indices in dst.
        int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
            * jcp.dimK_block * simd_w * simd_w;

        L(Loop_j);
        {
            for (int i = 0; i < alpha; i++) {
                // touch pages
                vmovups(zmm_load(0), ptr[wreg_Fw_aux
                    + (i * simd_w * simd_w) * typesize]);
                mov(wreg_dst_idx, i * dim5 * typesize);
                vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
            }
            for (int i = 0; i < alpha; i++) {
                // Remaining simd_w - 1 rows (row 0 was stored above).
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_Fw = (i * simd_w * simd_w  + v1 * simd_w)
                        * typesize;
                    vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
                }
                mov(wreg_dst_idx, i * dim5 * typesize);
                for (int v1 = 1; v1 < simd_w; v1++) {
                    int offset_dst = v1 * simd_w * typesize;
                    vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
                        zmm_load(v1));
                }
            }
            add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
            add(wreg_dst_aux, alpha * dim5 * typesize);
            add(wreg_cnt_j, 1);
            cmp(wreg_cnt_j, alpha);
            jl(Loop_j, T_NEAR);
        }
    };

    // Winograd weight transform for F(4x4, 3x3): computes G * F * G^T in
    // two passes (rows then columns) per 16-lane slice, via a 6x3 scratch
    // buffer T. Iterates over the 16 input-channel lanes (wreg_cnt_j).
    auto trans_W_4x4_3x3 = [=]() {
        // dst = a + b * c
        auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmovups(dst, a);
            vfmadd231ps(dst, b, c);
        };
        // dst = a - b * c
        auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vmulps(zmm_temp, b, c);
            vsubps(dst, a, zmm_temp);
        };
        // dst = -a - b * c  (uses zmm_zero for the negation)
        auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
            vsubps(dst, zmm_zero, a);
            vfnmadd231ps(dst, b, c);
        };

        mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
        mov(wreg_F, ptr[param1 + GET_OFF(M)]);
        mov(wreg_T, ptr[param1 + GET_OFF(T)]);

        Label Loop_j;
        mov(wreg_cnt_j, 0);
        L(Loop_j);
            mov(wreg_F_aux, wreg_F);
            mov(wreg_Fw_aux, wreg_Fw);
            mov(wreg_temp, wreg_cnt_j);
            // offset = cnt_j * simd_w * typesize (16 * 4 bytes)
            shl(wreg_temp, 4 + 2);
            lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
            lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);

            // First pass: T(6x3) = G * F, column by column.
            for (int i = 0; i < 3; i++) {
                for (int idx = 0; idx < 3; idx ++) {
                    vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w
                        * simd_w + i * simd_w * simd_w) * typesize]);
                }
                vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
                fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
                fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));

                vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
                fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
                fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
                fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
                fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
                vmovaps(zmm_T(5), zmm_F(2));

                for (int idx = 0; idx < 6; idx ++) {
                    vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w)
                        * typesize], zmm_T(idx));
                }
            }
            // Second pass: Fw(6x6) = T * G^T, row by row.
            for (int i = 0; i < 6; i++) {

                for (int idx = 0; idx < 3; idx ++) {
                    vmovups(zmm_T(idx), ptr[wreg_T
                        + (i * 3 * simd_w + idx * simd_w) * typesize]);
                }
                vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
                fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
                fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));

                vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
                fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
                fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
                fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
                fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
                vmovaps(zmm_F(5), zmm_T(2));

                for (int l = 0; l < 6; l++) {
                    vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w
                        + l * simd_w * simd_w) * typesize], zmm_F(l));
                }
            }
        add(wreg_cnt_j, 1);
        cmp(wreg_cnt_j, 16);
        jl(Loop_j, T_NEAR);
    };

    auto inner_loops = [=]() {
        load_src();
        init_G();
        trans_W_4x4_3x3();
        store_dst();
    };

    preamble();
    inner_loops();
    postamble();
}
639
640 void _jit_avx512_core_conv_winograd_data_kernel_f32
641     ::output_transform_data_ker_generate()
642 {
643     bool is_fwd = one_of(jcp.prop_kind,
644         mkldnn_forward_training, mkldnn_forward_inference);
645     int outw = is_fwd ? jcp.ow : jcp.iw;
646     int outh = is_fwd ? jcp.oh : jcp.ih;
647     bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
648     bool with_bias = jcp.with_bias;
649     bool with_relu = jcp.with_eltwise;
650     bool with_relu_postsum = jcp.with_relu_postsum;
651     bool with_sum = jcp.with_sum;
652
653     auto zmm_zero = Xbyak::Zmm(0);
654     auto zmm_temp = Xbyak::Zmm(31);
655     auto zmm_G = [=](int i) {
656         return Xbyak::Zmm(1 + i);
657     };
658     auto zmm_O = [=](int i) {
659         return Xbyak::Zmm(1 + alpha + i);
660     };
661     auto zmm_T = [=](int i) {
662         return Xbyak::Zmm(1 + 2 * alpha + i);
663     };
664     auto zmm_t = [=](int i) {
665         return Xbyak::Zmm(1 + 3 * alpha + i);
666     };
667
668     auto init_G = [=]() {
669         mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
670         for (int i = 0; i < 6; i++) {
671             vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
672         }
673     };
674
675     auto load_src = [=]() {
676         mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
677         mov(oreg_src, ptr[param1 + GET_OFF(src)]);
678
679         mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
680         imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
681             (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
682             * jcp.dimM_simd_block * typesize);
683         add(oreg_src, oreg_nb_tile_block_ur);
684
685         mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
686         imul(oreg_tile_block_ur, oreg_tile_block_ur,
687             jcp.dimM_simd_block * typesize);
688         add(oreg_src, oreg_tile_block_ur);
689
690         if (not_tiled) {
691             mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
692             imul(oreg_tile_block, oreg_tile_block,
693                 jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
694                 * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
695                 * jcp.dimM_simd_block * typesize);
696             add(oreg_src, oreg_tile_block);
697         }
698
699         int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
700             * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
701         for (int j = 0; j < alpha; j++) {
702             for (int i = 0; i < alpha; i++) {
703                 int j_base_offset = j * alpha * last4dim;
704                 int i_base_offset = i * last4dim;
705                 vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]);
706                 vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w)
707                     * typesize], zmm_temp);
708             }
709         }
710     };
711
712     auto store_dst = [=]() {
713         vpxord(zmm_zero, zmm_zero, zmm_zero);
714         mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
715         mov(oreg_O, ptr[param1 + GET_OFF(M)]);
716         mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
717         shl(oreg_ydim, 2); // tj * tile_size (==4)
718         mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
719         shl(oreg_xdim, 2); // ti * tilesize (==4)
720
721         if (with_bias)
722             mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);
723
724         auto store_one = [=](int j, int i, bool is_aligned) {
725             auto zmm_O = Xbyak::Zmm(31);
726             auto zmm_relu_ns = Xbyak::Zmm(30);
727             auto xmm_relu_ns = Xbyak::Xmm(30);
728             int offset = (j * tile_size * simd_w + i * simd_w) * typesize;
729
730             vmovups(zmm_O, ptr[oreg_O + offset]);
731             if (is_fwd) {
732                 if (with_bias) {
733                     vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
734                 }
735                 if (with_relu) {
736                     Opmask kmask = Opmask(7);
737                     if (jcp.eltwise_alpha == 0) {
738                         zmm_relu_ns = zmm_zero;
739                     } else {
740                         mov(imm_addr64, float2int(jcp.eltwise_alpha));
741                         vmovq(xmm_relu_ns, imm_addr64);
742                         vbroadcastss(zmm_relu_ns, xmm_relu_ns);
743                     }
744                     vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
745                     vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
746                 }
747             }
748             if (with_sum) {
749                 vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
750                 if (with_relu_postsum) // orig: with_relu_postsum
751                     vmaxps(zmm_O, zmm_O, zmm_zero);
752             }
753             if (is_aligned)
754                 vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
755             else
756                 vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
757         };
758
759         auto i_loop = [=](int j, bool is_aligned) {
760             for (int i = 0; i < tile_size; i++) {
761                 Label next;
762                 mov(oreg_temp, oreg_xdim);
763                 add(oreg_temp, i);
764                 cmp(oreg_temp, outw);
765                 jge(next, T_NEAR);
766                 shl(oreg_temp, 4 + 2); // * 16 * 4
767
768                 store_one(j, i, is_aligned);
769
770                 L(next);
771             }
772         };
773
774
775         for (int j = 0; j < tile_size; j++) {
776             Label next, unaligned;
777             mov(oreg_temp, oreg_ydim);
778             add(oreg_temp, j);
779             cmp(oreg_temp, outh);
780             jge(next, T_NEAR);
781
782             mov(oreg_out_j, oreg_dst);
783             imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
784             add(oreg_out_j, oreg_temp);
785
786             test(oreg_dst, 63);
787             jnz(unaligned, T_NEAR);
788
789             i_loop(j, true);
790             jmp(next, T_NEAR);
791
792             L(unaligned);
793             i_loop(j, false);
794
795             L(next);
796         }
797     };
798
799     auto trans_O_4x4_3x3 = [=]() {
800         auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){
801             vmulps(dst, v1, u1);
802             vfmadd231ps(dst, v2, u2);
803         };
804         mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
805         mov(oreg_T, ptr[param1 + GET_OFF(T)]);
806         mov(oreg_O, ptr[param1 + GET_OFF(M)]);
807
808         for (int i = 0; i < alpha; i++) {
809             for (int j = 0; j < alpha; j++) {
810                 vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w
811                     + i * simd_w) * typesize]);
812             }
813
814             vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
815             vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
816             vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
817             vsubps(zmm_t(3), zmm_O(3), zmm_O(4));
818
819             vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
820             vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
821             fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
822             fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
823             fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
824             vaddps(zmm_T(3), zmm_T(3), zmm_O(5));
825
826             for (int j = 0; j < tile_size; j++) {
827                 vmovups(ptr[oreg_T + (j * alpha * simd_w
828                     + i * simd_w) * typesize], zmm_T(j));
829             }
830         }
831         for (int j = 0; j < tile_size; j++) {
832             for (int i = 0; i < alpha; i++) {
833                 vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w
834                     + i * simd_w) * typesize]);
835             }
836             vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
837             vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
838             vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
839             vsubps(zmm_t(3), zmm_T(3), zmm_T(4));
840
841             vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
842             vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
843             fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
844             fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
845             fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
846             vaddps(zmm_O(3), zmm_O(3), zmm_T(5));
847
848             for (int i = 0; i < tile_size; i++) {
849                 vmovups(ptr[oreg_O + (j * tile_size * simd_w
850                     + i * simd_w) * typesize], zmm_O(i));
851             }
852         }
853     };
854
855     auto inner_loops = [=]() {
856         init_G();
857         load_src();
858         trans_O_4x4_3x3();
859         store_dst();
860     };
861
862     preamble();
863     inner_loops();
864     postamble();
865 }
866
/* Emits the JIT kernel that loads one 6x6 (alpha x alpha) input tile,
 * applies the Winograd F(4x4,3x3) input transform (B^T * d * B, expressed
 * through the 9 broadcast constants in G), and stores the transformed tile
 * into the scratchpad in the GEMM-friendly blocked layout.
 * Out-of-bounds (padding) elements are zero-filled via an opmask built
 * with cmov-based range checks, so no branches are taken per element. */
void _jit_avx512_core_conv_winograd_data_kernel_f32
    ::input_transform_data_ker_generate()
{
    bool is_fwd = one_of(jcp.prop_kind,
        mkldnn_forward_training, mkldnn_forward_inference);
    /* For backward-data the "input" of the transform is diff_dst, so the
     * spatial extents and pads are taken from the output side. */
    int inpw = is_fwd ? jcp.iw : jcp.ow;
    int inph = is_fwd ? jcp.ih : jcp.oh;
    int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
    /* NOTE(review): asymmetric with l_pad above (which uses r_pad for the
     * bwd case); one would expect jcp.b_pad here. Only differs when
     * t_pad != b_pad -- confirm against upstream before changing. */
    int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
    int wp_max = inpw + l_pad;
    int hp_max = inph + t_pad;
    bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
    int G_size = 9; // number of scalar transform coefficients in G

    /* Register map: zmm0 = zero, zmm31 = scratch, zmm1..9 = G constants,
     * then three banks of `alpha` registers for I, T and t. */
    auto zmm_zero = Xbyak::Zmm(0);
    auto zmm_temp = Xbyak::Zmm(31);
    auto zmm_G = [=](int i) {
        return Xbyak::Zmm(1 + i);
    };
    auto zmm_I = [=](int i) {
        return Xbyak::Zmm(1 + G_size + i);
    };
    auto zmm_T = [=](int i) {
        return Xbyak::Zmm(1 + G_size + alpha + i);
    };
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(1 + G_size + 2 * alpha + i);
    };

    /* Broadcast the 9 transform coefficients into zmm_G(0..8). */
    auto init_G = [=]() {
        mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
        for (int i = 0; i < G_size; i++) {
            vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
        }
    };

    /* Load the alpha x alpha tile from src into the M buffer, zeroing
     * elements that fall into the logical padding region. */
    auto load_src = [=]() {
        mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
        mov(ireg_I, ptr[param1 + GET_OFF(M)]);

        xor_(ireg_zero,  ireg_zero);
        vpxord(zmm_zero, zmm_zero, zmm_zero);

        mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
        shl(ireg_ydim, 2); // tj * tile_size (==4)
        mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
        shl(ireg_xdim, 2); // ti * tilesize (==4)

        for (int j = 0; j < alpha; j++) {
            mov(ireg_temp, ireg_ydim);
            add(ireg_temp, j);

            /* Row mask: all-ones when t_pad <= row < hp_max, else zero.
             * cmov avoids branches inside the generated code. */
            mov(ireg_mask_j, 0xffff);
            cmp(ireg_temp, t_pad);
            cmovl(ireg_mask_j, ireg_zero);
            cmp(ireg_temp, hp_max);
            cmovge(ireg_mask_j, ireg_zero);

            /* Row base address: (row - t_pad) * row_stride_in_bytes. */
            sub(ireg_temp, t_pad);
            imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
            mov(ireg_inp_j, ireg_src);
            add(ireg_inp_j, ireg_temp);

            for (int i = 0; i < alpha; i++) {

                mov(ireg_temp, ireg_xdim);
                add(ireg_temp, i);

                /* Column mask, combined with the row mask. */
                mov(ireg_mask, 0xffff);
                cmp(ireg_temp, l_pad);
                cmovl(ireg_mask, ireg_zero);
                cmp(ireg_temp, wp_max);
                cmovge(ireg_mask, ireg_zero);
                and_(ireg_mask, ireg_mask_j);

                sub(ireg_temp, l_pad);
                shl(ireg_temp, 4 + 2); // * simd_w (16) * typesize (4)

                /* Masked load: lanes outside the image stay zero. */
                vpxord(zmm_temp, zmm_temp, zmm_temp);
                Opmask kmask = Opmask(7);
                kmovw(kmask, ireg_mask_32);
                vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
                vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w)
                    * typesize], zmm_temp);
            }
        }
    };

    /* Scatter the transformed tile (Iw) from the scratch buffer into the
     * blocked GEMM layout at dst, using streaming stores when the data
     * set clearly exceeds the LLC (stores would only pollute the cache). */
    auto store_Iw = [=]() {

        mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
        mov(ireg_output, ptr[param1 + GET_OFF(dst)]);

       bool streamout
          = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
            > 2 * LLC_data_size
            ? true : false;

        /* Only the non-tiled schedule indexes by tile_block. */
        if (not_tiled) {
            mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
            imul(ireg_tile_block, ireg_tile_block,
                alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
                * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
                * typesize);
        }

        mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
        imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
            jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
            * jcp.dimK_reg_block * typesize);

        mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
        imul(ireg_tile_block_ur, ireg_tile_block_ur,
            jcp.dimK_reg_block * typesize);

        add(ireg_output, ireg_nb_tile_block_ur);
        add(ireg_output, ireg_tile_block_ur);
        if (not_tiled)
            add(ireg_output, ireg_tile_block);

        for (int j = 0; j < alpha; j++) {
            for (int i = 0; i < alpha; i++) {
                vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w
                    + i * simd_w) * typesize]);

                /* (j, i) select the alpha x alpha component plane; each
                 * plane is one full GEMM input matrix apart. */
                int j_base_offset =
                    j * alpha * jcp.dimN_block * jcp.dimK_nb_block
                    * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
                    * typesize;
                int i_base_offset =
                    i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
                    * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;

                if (not_tiled && streamout)
                    vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
                        zmm_temp);
                else
                    vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
                        zmm_temp);
            }
        }
    };

    /* dst = a * b + c, using zmm_temp as scratch (no true FMA so that
     * `c` may alias `dst`). */
    auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
        vmulps(zmm_temp, a, b);
        vaddps(dst, zmm_temp, c);
    };

    /* Two-pass input transform: first over columns (I -> T), then over
     * rows (T -> Iw). The interleaved prefetcht0 warms the dst cache
     * lines that store_Iw will write next. */
    auto trans_I_4x4_3x3 = [=]() {
        mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
        mov(ireg_T, ptr[param1 + GET_OFF(T)]);
        mov(ireg_I, ptr[param1 + GET_OFF(M)]);

        mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
        for (int i = 0; i < alpha; i++) {
            for (int idx = 0; idx < alpha; idx++) {
                vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w
                    + i * simd_w) * typesize]);
                int j_base_offset =
                    i * alpha * jcp.dimN_block * jcp.dimK_nb_block
                    * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
                    * typesize;
                int idx_base_offset =
                    idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
                    * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
                prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
            }

            /* Column pass of B^T * d, expressed via G coefficients. */
            fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
            fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
            fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
            fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
            fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
            fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));

            fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
            fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));

            for (int idx = 0; idx < alpha; idx++) {
                vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w)
                    * typesize],zmm_T(idx));
            }
        }
        for (int i = 0; i < alpha; i++) {
            for (int idx = 0; idx < alpha; idx++) {
                vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx
                    * simd_w) * typesize]);
            }

            /* Row pass: same stencil applied to the intermediate T. */
            fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
            fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
            fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
            fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
            fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
            fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));

            fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
            fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
            fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
            fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
            fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
            fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));

            for (int idx = 0; idx < alpha; idx++) {
                vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w)
                    * typesize],zmm_I(idx));
            }
        }
    };

    /* Kernel body: constants, load, transform, store. */
    auto inner_loops = [=]() {
        init_G();
        load_src();
        trans_I_4x4_3x3();
        store_Iw();
    };

    preamble();
    inner_loops();
    postamble();
}
1092
1093 status_t _jit_avx512_core_conv_winograd_data_kernel_f32::init_conf_common(
1094         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1095         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
1096         const memory_desc_wrapper &dst_d)
1097 {
1098     if (!mayiuse(avx512_core)) {
1099         return status::unimplemented;
1100     }
1101     jcp.ver = ver_avx512_core;
1102     jcp.prop_kind = cd.prop_kind;
1103
1104     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1105
1106     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1107     jcp.mb = src_d.dims()[0];
1108     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1109     jcp.oc_without_padding = jcp.oc;
1110     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1111     jcp.ih = src_d.dims()[2];
1112     jcp.iw = src_d.dims()[3];
1113     jcp.oh = dst_d.dims()[2];
1114     jcp.ow = dst_d.dims()[3];
1115     jcp.kh = weights_d.dims()[with_groups + 2];
1116     jcp.kw = weights_d.dims()[with_groups + 3];
1117     jcp.t_pad = cd.padding[0][0];
1118     jcp.l_pad = cd.padding[0][1];
1119     jcp.stride_h = cd.strides[0];
1120     jcp.stride_w = cd.strides[1];
1121     jcp.dilate_h = cd.dilates[0];
1122     jcp.dilate_w = cd.dilates[1];
1123     jcp.r_pad = nstl::max(
1124             0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1125     jcp.b_pad = nstl::max(
1126             0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
1127     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1128     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1129     jcp.ohp = jcp.oh;
1130     jcp.owp = jcp.ow;
1131
1132     bool ok_to_pad_channels = jcp.ngroups == 1;
1133     if (ok_to_pad_channels) {
1134         jcp.oc = rnd_up(jcp.oc, simd_w);
1135         jcp.ic = rnd_up(jcp.ic, simd_w);
1136     }
1137
1138     // Checking conditions not supported by these kernels
1139     if (jcp.ngroups != 1)
1140         return status::unimplemented;
1141     if ((jcp.kh != 3) || (jcp.kw != 3))
1142         return status::unimplemented;
1143     if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
1144         return status::unimplemented;
1145     if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
1146         return status::unimplemented;
1147     if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
1148         return status::unimplemented;
1149
1150     if (src_d.format() != nChw16c)
1151         return status::unimplemented;
1152     if (weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
1153         return status::unimplemented;
1154     if (dst_d.format() != nChw16c)
1155         return status::unimplemented;
1156
1157     bool layout_consistency = true
1158         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1159         && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
1160         && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
1161         && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
1162     if (!layout_consistency) return status::unimplemented;
1163
1164     return status::success;
1165 }
1166
1167 void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
1168
1169     /* ----------- dimM reg block ---------------------*/
1170     auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
1171             int dimM_reg_block, int current_best) {
1172         int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
1173         return (dimM_reg_block >= 1)
1174                 && (dimM_reg_block <= max_dimM_reg_block )
1175                 && (dimM_reg_block > current_best);
1176     };
1177     jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp,
1178         jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
1179
1180     /* ----------- dimN reg block ---------------------*/
1181
1182     auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
1183             int dimN_reg_block, int current_best) {
1184         return jcp.kernel_kind == embd_bcast
1185             ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
1186             : dimN_reg_block >= 1
1187               && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block)
1188                  < jcp.nb_reg
1189               && dimN_reg_block > current_best;
1190     };
1191     jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp,
1192         jcp.dimN, 1, test_cond_dimN_reg_block);
1193 }
1194
1195 status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
1196     if (jcp.ver != ver_avx512_core)
1197         return status::unimplemented;
1198
1199     jcp.kernel_kind = embd_bcast;
1200
1201     set_kernel_dims_reg_block(jcp);
1202
1203     /*-------------- L2 blocking for dimN block ---------*/
1204
1205     auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
1206         int dimN_block, int current_best) {
1207         return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
1208             && (dimN_block > current_best)
1209             && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
1210             >= 1.5 * omp_get_max_threads());
1211     };
1212
1213     jcp.dimN_block = get_divisor_satisfying_cond(
1214             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
1215     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
1216
1217     if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
1218         && (jcp.dimN_nb_block >= 1.5 * omp_get_max_threads())) {
1219
1220         /* ------------------- L1 blocking for GEMM --------------*/
1221         /* -------------------- Choose dimK block ----------------*/
1222
1223         auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
1224                 int dimK_block, int current_best) {
1225             return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
1226                 && (dimK_block > current_best);
1227         };
1228
1229         jcp.dimK_block = get_divisor_satisfying_cond(
1230                 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
1231
1232         if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
1233             jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
1234
1235             /* -------------- Choose dimM block -------------------*/
1236             auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
1237                     int dimM_block, int current_best) {
1238                 return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block,
1239                     0.2, 0.5) && (dimM_block > current_best);
1240             };
1241
1242             jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1243                 jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
1244                 test_cond_dimM_block);
1245             jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
1246                 / jcp.dimM_simd_block;
1247
1248             jcp.sched_policy = WSCHED_DATA_W_SGD;
1249             return status::success;
1250         }
1251
1252     }
1253     return status::unimplemented;
1254 }
1255
1256 void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {
1257
1258     set_kernel_dims_reg_block(jcp);
1259
1260     //********************* Choosing dimK_block **********************//
1261     auto test_cond1_dimK_block = [](
1262             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1263         return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
1264                        1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
1265                 && (dimK_block > current_best);
1266     };
1267
1268     auto test_cond1_bis_dimK_block = [](
1269             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1270         return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
1271                    jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
1272                    jcp.dimM_simd_block, .9f)
1273                 && (dimK_block > current_best);
1274     };
1275
1276     jcp.dimK_block = get_divisor_satisfying_cond(
1277             jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
1278     // If we are not able to use streams, we fall back to condition [1]
1279     if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1280         jcp.dimK_block = get_divisor_satisfying_cond(
1281                 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
1282     jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
1283
1284     //********************* Choosing dimM_block **********************//
1285     auto test_cond1_dimM_block = [](
1286             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1287         return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
1288                    jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1289                    jcp.dimM_simd_block, .5f)
1290                 && (dimM_block > current_best);
1291     };
1292
1293     auto test_cond1_bis_dimM_block = [](
1294             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1295         return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
1296                    jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1297                    jcp.dimM_simd_block, .3f)
1298                 && (dimM_block > current_best);
1299     };
1300
1301     if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1302         jcp.dimM_block = get_divisor_satisfying_cond(
1303                 jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1304                 test_cond1_dimM_block);
1305     else
1306         jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1307                 jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1308                 test_cond1_bis_dimM_block);
1309     jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block
1310                         * jcp.dimM_reg_block);
1311
1312     //******************* Choosing dimN_block *******************//
1313     auto test_cond2_dimN_block = [](
1314             jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
1315         return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
1316                        jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
1317                        jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
1318                 && (dimN_block > current_best);
1319     };
1320
1321     jcp.dimN_block = get_divisor_satisfying_cond(
1322             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
1323     jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
1324 }
1325
1326 status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
1327
1328     jcp.kernel_kind = expl_bcast;
1329     set_kernel_blocking_DATA_W_S_G_D(jcp);
1330     if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
1331         jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK,
1332         .1f, .35f))) {
1333         jcp.kernel_kind = embd_bcast;
1334         set_kernel_blocking_DATA_W_S_G_D(jcp);
1335     }
1336     jcp.sched_policy = WSCHED_DATA_W_S_G_D;
1337     return status::success;
1338 }
1339
1340 status_t _jit_avx512_core_conv_winograd_data_kernel_f32::init_conf_kernel(
1341         jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
1342 {
1343     jcp.nb_reg = 32;
1344     jcp.dimN = dimN;
1345     jcp.dimK = dimK;
1346     jcp.dimM = dimM;
1347     jcp.sched_policy = WSCHED_INVALID;
1348
1349     jcp.dimK_reg_block = 16;
1350     jcp.dimM_simd_block = 16;
1351
1352     if (jcp.kernel_kind == embd_bcast) {
1353         jcp.dimM_reg_block = 1;
1354     }
1355
1356     if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success))
1357         set_wsched_DATA_W_S_G_D_avx512_core(jcp);
1358
1359     assert(jcp.sched_policy != WSCHED_INVALID);
1360     return status::success;
1361 }
1362
1363 bool jit_avx512_core_conv_winograd_fwd_kernel_f32::post_ops_ok(
1364         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1365     const auto &p = attr.post_ops_;
1366
1367     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
1368     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1369
1370     switch (p.len_) {
1371     case 0:
1372         return true; // no post_ops
1373     case 1:
1374         return true // relu or sum
1375                 && implication(jcp.with_eltwise, is_sum(0))
1376                 && implication(!jcp.with_eltwise, is_eltwise(0) || is_sum(0));
1377     case 2:
1378         return true // sum->relu or relu->sum
1379                 && implication(jcp.with_eltwise, is_sum(0) && is_eltwise(1))
1380                 && implication(!jcp.with_eltwise, false
1381                                    || (is_sum(0) && is_eltwise(1))
1382                                    || (is_eltwise(0) && is_sum(1)));
1383     case 3:
1384         return true // relu->sum->relu
1385                 && jcp.with_eltwise == false
1386                 && (is_eltwise(0) && is_sum(1) && is_eltwise(2));
1387     default:
1388         return false;
1389     }
1390
1391     return false;
1392 }
1393
/* Forward-pass configuration: runs the common checks, derives the
 * Winograd tiling over the output, resolves the fused post-ops, and maps
 * the generic GEMM blocking (M=oc, N=ntiles, K=ic) back onto
 * convolution-specific jcp fields. */
status_t jit_avx512_core_conv_winograd_fwd_kernel_f32::init_conf(
        jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
        bool with_relu, float relu_negative_slope) {
    status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);

    if (st != status::success)
        return st;

    // Winograd specific initialization
    // Number of 4x4 output tiles in each spatial direction (rounded up).
    jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
    jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
    jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;

    jcp.with_bias = cd.bias_desc.format != memory_format::undef;
    jcp.with_eltwise = with_relu;
    jcp.eltwise_alpha = relu_negative_slope;

    // Must be checked before post-ops are folded into jcp below.
    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    const auto &p = attr.post_ops_;
    if (!jcp.with_eltwise) {
        /* PostOps ReLU before SUM is handled the same as ReLU primitive */
        // (search limited to position 0; a leading eltwise implies alpha 0)
        jcp.with_eltwise = p.find(primitive_kind::eltwise, 0, 1) != -1;
        jcp.eltwise_alpha = 0.f;
    }
    jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
    // An eltwise at position >= 1 is a ReLU applied after the sum.
    jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;

    // Forward GEMM mapping: M = oc, N = ntiles, K = ic.
    status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);

    // Translate generic GEMM blocking into convolution terms.
    jcp.ic_simd_block = jcp.dimK_reg_block;
    jcp.ic_block = jcp.dimK_block;
    jcp.nb_ic = jcp.dimK_nb_block;
    jcp.oc_simd_block = jcp.dimM_simd_block;
    jcp.oc_block = jcp.dimM_block;
    jcp.oc_reg_block = jcp.dimM_reg_block;
    jcp.ic_reg_block = 1;
    jcp.nb_oc = jcp.dimM_nb_block;
    jcp.tile_block_ur = jcp.dimN_reg_block;
    jcp.nb_tile_block_ur = jcp.dimN_block;
    jcp.tile_block = jcp.dimN_nb_block;

    return res;
}
1441
1442 status_t jit_avx512_core_conv_winograd_bwd_data_kernel_f32::init_conf(
1443         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1444         const memory_desc_wrapper &diff_src_d,
1445         const memory_desc_wrapper &weights_d,
1446         const memory_desc_wrapper &diff_dst_d)
1447 {
1448     status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
1449
1450     if (st != status::success)
1451         return st;
1452
1453     jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
1454     jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
1455     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1456
1457     status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
1458
1459     jcp.oc_simd_block = jcp.dimK_reg_block;
1460     jcp.oc_block = jcp.dimK_block;
1461     jcp.nb_oc = jcp.dimK_nb_block;
1462     jcp.ic_simd_block = jcp.dimM_simd_block;
1463     jcp.ic_block = jcp.dimM_block;
1464     jcp.ic_reg_block = jcp.dimM_reg_block;
1465     jcp.oc_reg_block = 1;
1466     jcp.nb_ic = jcp.dimM_nb_block;
1467     jcp.tile_block_ur = jcp.dimN_reg_block;
1468     jcp.nb_tile_block_ur = jcp.dimN_block;
1469     jcp.tile_block = jcp.dimN_nb_block;
1470
1471     return res;
1472 }
1473
1474 void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::
1475 src_transform_generate() {
1476     constexpr int G_size = 9;
1477     const size_t ifwp = jcp.iw + jcp.l_pad;
1478     const size_t ifhp = jcp.ih + jcp.t_pad;
1479
1480     auto zmm_G = [=](int i) {
1481         return Xbyak::Zmm(i);
1482     };
1483     auto zmm_I = [=](int i) {
1484         return Xbyak::Zmm(G_size + i);
1485     };
1486     auto zmm_T = [=](int i) {
1487         return Xbyak::Zmm(G_size + alpha + i);
1488     };
1489     auto zmm_t = [=](int i) {
1490         return Xbyak::Zmm(G_size + 2 * alpha + i);
1491     };
1492
1493     auto init_G = [=]() {
1494         mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1495         for (int i = 0; i < G_size; i++) {
1496             vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
1497         }
1498     };
1499
1500     auto load_src = [=]() {
1501         mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
1502         xor_(reg_zero, reg_zero);
1503
1504         mov(reg_ydim, reg_tj);
1505         shl(reg_ydim, 2); //tj * tile_size(=4)
1506
1507         for (int j = 0; j < alpha; j++) {
1508             /* check if tile index is within physical spatial boundaries*/
1509             mov(reg_maskj, 0xffff);
1510             cmp(reg_ydim, jcp.t_pad);
1511             cmovl(reg_maskj, reg_zero);
1512             cmp(reg_ydim, ifhp);
1513             cmovge(reg_maskj, reg_zero);
1514
1515             /*address offset for tile in src*/
1516             mov(reg_src_offset, reg_ydim);
1517             sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
1518             imul(reg_src_offset, reg_src_offset, jcp.iw);
1519
1520             mov(reg_xdim, reg_ti);
1521             shl(reg_xdim, 2); // xdim = ti * tile_size
1522
1523             add(reg_src_offset, reg_xdim);
1524             sub(reg_src_offset, jcp.l_pad);
1525             imul(reg_src_offset, reg_src_offset, simd_w * typesize);
1526             for (int i = 0; i < alpha; i++) {
1527                 /* check if tile index is within physical spatial boundaries*/
1528                 mov(reg_maski, 0xffff);
1529                 cmp(reg_xdim, jcp.l_pad);
1530                 cmovl(reg_maski, reg_zero);
1531                 cmp(reg_xdim, ifwp);
1532                 cmovge(reg_maski, reg_zero);
1533                 and_(reg_maski, reg_maskj);
1534
1535                 Opmask kmask_src = Xbyak::Opmask(7);
1536                 auto zmm_src = Xbyak::Zmm(31);
1537                 kmovw(kmask_src, reg_maski_32);
1538                 vpxord(zmm_src, zmm_src, zmm_src);
1539                 vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
1540                 vmovups(ptr[reg_I], zmm_src);
1541
1542                 add(reg_xdim, 1); //xdim = ti * tile_size + i
1543                 add(reg_src_offset, simd_w * typesize);
1544                 add(reg_I, simd_w * typesize);
1545             }
1546             add(reg_ydim, 1);
1547         }
1548     };
1549
1550     auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
1551         vmovups(dst, c);
1552         vfmadd231ps(dst, a, b);
1553     };
1554
1555     auto trans_I_3x3_4x4 = [=]() {
1556         //Use 24 registers
1557         mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
1558         mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
1559         for (int i = 0; i < alpha; i++) {
1560             for (int j = 0; j < alpha; j++) {
1561                 size_t I_off = (j * alpha + i) * simd_w * typesize;
1562                 vmovups(zmm_I(j), ptr[reg_I + I_off]);
1563             }
1564
1565             fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
1566             fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
1567             fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
1568             fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
1569             fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
1570             fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
1571
1572             fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
1573             fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
1574             fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
1575             fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
1576             fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
1577             fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
1578
1579             for (int j = 0; j < alpha; j++) {
1580                 vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
1581                         zmm_T(j));
1582             }
1583
1584         }
1585
1586         for (int j = 0; j < alpha; j++) {
1587             for (int i = 0; i < alpha; i++) {
1588                 vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
1589             }
1590
1591             fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
1592             fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
1593             fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
1594             fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
1595             fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
1596             fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
1597
1598             fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
1599             fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
1600             fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
1601             fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
1602             fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
1603             fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
1604
1605             for (int i = 0; i < alpha; i++) {
1606                 size_t dst_off = (j * alpha * jcp.ic_block
1607                     * jcp.nb_tile_block_ur * jcp.tile_block_ur
1608                     + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur)
1609                     * simd_w * typesize;
1610                 vmovups(ptr[reg_dst + dst_off], zmm_I(i));
1611             }
1612         }
1613     };
1614
1615     auto compute_transform_SDGtWo = [=]() {
1616         mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
1617         mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
1618         mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1619         mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1620         xor_(reg_tile_count, reg_tile_count);
1621         Label loop_mb, loop_jtiles, loop_itiles, done;
1622         L(loop_mb);
1623         {
1624             L(loop_jtiles);
1625             {
1626                 L(loop_itiles);
1627                 {
1628                     load_src();
1629
1630                     trans_I_3x3_4x4();
1631
1632                     add(reg_tile_count, 1);
1633                     cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1634                     jge(done);
1635
1636                     add(reg_dst, simd_w * typesize);
1637                     add(reg_ti, 1);
1638                     cmp(reg_ti, jcp.itiles);
1639                     jl(loop_itiles);
1640                 }
1641                 xor_(reg_ti, reg_ti);
1642                 add(reg_tj, 1);
1643                 cmp(reg_tj, jcp.jtiles);
1644                 jl(loop_jtiles);
1645             }
1646             xor_(reg_tj, reg_tj);
1647             add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
1648             jmp(loop_mb);
1649         }
1650         L(done);
1651     };
1652
1653     auto compute_transform = [=]() {
1654         mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1655         xor_(reg_ti, reg_ti);
1656         xor_(reg_tj, reg_tj);
1657
1658         mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1659         mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1660         imul(reg_temp, reg_tile_count, simd_w * typesize);
1661         add(reg_dst, reg_temp);
1662
1663         Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
1664         L(loop_jtiles);
1665
1666         {
1667             L(loop_itiles);
1668             {
1669                 load_src();
1670
1671                 trans_I_3x3_4x4();
1672
1673                 add(reg_tile_count, 1);
1674                 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1675                 jge(next_tile_block);
1676                 add(reg_dst, simd_w * typesize);
1677                 jmp(next_tile);
1678
1679                 L(next_tile_block);
1680                 sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
1681                         * simd_w * typesize);
1682                 size_t tblk_off = alpha * alpha * jcp.ic_block
1683                     * jcp.nb_tile_block_ur * jcp.tile_block_ur
1684                     * simd_w * typesize;
1685                 add(reg_dst, tblk_off);
1686                 xor_(reg_tile_count, reg_tile_count);
1687
1688                 L(next_tile);
1689                 add(reg_ti, 1);
1690                 cmp(reg_ti, jcp.itiles);
1691                 jl(loop_itiles);
1692             }
1693             xor_(reg_ti, reg_ti);
1694             add(reg_tj, 1);
1695             cmp(reg_tj, jcp.jtiles);
1696             jl(loop_jtiles);
1697         }
1698     };
1699
1700     preamble();
1701     init_G();
1702     if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
1703         compute_transform_SDGtWo();
1704     else
1705         compute_transform();
1706     postamble();
1707 }
1708
/* Generates the JIT kernel that applies the Winograd F(4x4, 3x3) transform
 * to the output gradient (diff_dst) for the backward-by-weights pass: each
 * tile_size x tile_size (4x4) spatial tile is expanded into an alpha x alpha
 * (6x6) transformed tile and written into the layout consumed by the
 * weight-update GEMM.  When with_bias is set the kernel additionally
 * accumulates the diff-bias (running sum of in-bounds diff_dst values)
 * into the buffer at reg_bias. */
void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::
diff_dst_transform_generate(bool with_bias) {

    /* number of transform coefficients read from the G table */
    constexpr int G_size = 8;
    /* Every G coefficient goes through a single scratch register: each
     * broadcasted value is fully consumed before the next vbroadcastss
     * overwrites it.  NOTE(review): Zmm(31) also backs zmm_bias below;
     * load_src flushes the bias accumulator to memory before
     * trans_W_3x3_4x4 reuses the register, so the aliasing appears
     * intentional — confirm before changing either allocation. */
    auto zmm_G = [](int i) {
        return Xbyak::Zmm(31);
    };

    /* 4x4 block of loaded diff_dst vectors, registers 8..23 */
    auto zmm_src = [=](int j, int i) {
        return Xbyak::Zmm(G_size + j * 4 + i);
    };

    /* diff-bias accumulator (aliases the zmm_G scratch, see note above) */
    auto zmm_bias = Xbyak::Zmm(31);

    /* Loads one 4x4 tile of diff_dst into zmm_src.  Lanes whose spatial
     * coordinate falls outside (oh, ow) are zeroed through an opmask:
     * the mask starts as 0xffff and is replaced by reg_zero via cmov when
     * the row/column index is out of range.  With bias, in-bounds values
     * are also summed into zmm_bias under the same mask. */
    auto load_src = [=]() {
        if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
        mov(reg_ydim, reg_tj);
        shl(reg_ydim, 2); //tj * tile_size(=4)
        for (int j = 0; j < tile_size; j++) {
            /* check if tile index is within physical spatial boundaries*/
            mov(reg_maskj, 0xffff);
            cmp(reg_ydim, jcp.oh);
            cmovge(reg_maskj, reg_zero);

            /*address offset for tile in src*/
            mov(reg_src_offset, reg_ydim);
            imul(reg_src_offset, reg_src_offset, jcp.ow);

            mov(reg_xdim, reg_ti);
            shl(reg_xdim, 2); // xdim = ti * tile_size

            add(reg_src_offset, reg_xdim);
            imul(reg_src_offset, reg_src_offset, simd_w * typesize);
            for (int i = 0; i < tile_size; i++) {
                /* check if tile index is within physical spatial boundaries*/
                mov(reg_maski, 0xffff);
                cmp(reg_xdim, jcp.ow);
                cmovge(reg_maski, reg_zero);
                and_(reg_maski, reg_maskj);

                /* combined row/column mask predicates the 16-lane load */
                Opmask kmask_src = Xbyak::Opmask(7);
                kmovw(kmask_src, reg_maski_32);
                vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
                vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]);
                if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias,
                        ptr[reg_src + reg_src_offset]);

                add(reg_xdim, 1); //xdim = ti * tile_size + i
                add(reg_src_offset, simd_w * typesize);
            }
            add(reg_ydim, 1);
        }
        /* persist the bias partial sum before Zmm(31) is reused as zmm_G */
        if(with_bias) vmovups(ptr[reg_bias], zmm_bias);
    };

    /* transform temporaries, registers 24..30 */
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(G_size + 16 + i);
    };

    /* Column-pass result (6 rows x 4 cols), registers 0..23.
     * NOTE(review): for j >= 2 these indices overlap zmm_src; in the
     * emission order below each source value is consumed before its
     * register is rewritten, so the reuse appears deliberate — confirm
     * before reordering any instruction in trans_W_3x3_4x4. */
    auto zmm_T = [=](int j, int i) {
        return Xbyak::Zmm(j * 4 + i);
    };

    /* Non-temporal stores are used only for the non-SDGtWo schedule,
     * where the destination is not expected to be re-read soon. */
    auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
        if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
            vmovups(ptr[reg_dst + dst_off], a);
        else
            vmovntps(ptr[reg_dst + dst_off], a);
    };

    /* Applies the transform on both sides of the loaded tile: first a
     * column pass (4 input rows -> 6 rows in zmm_T), then a row pass that
     * produces the 6 outputs of each row and stores them alpha_offset
     * apart in the GEMM source layout.  Coefficients are broadcast on
     * demand from the 8-entry G table. */
    auto trans_W_3x3_4x4 = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        /* column pass */
        for (int i = 0; i < tile_size; i++) {
            vbroadcastss(zmm_G(0), ptr[reg_G]);
            vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));

            vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
            vmovups(zmm_t(1), zmm_t(0));
            vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));

            vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
            vmovups(zmm_t(2), zmm_t(0));
            vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));

            vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
            vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));

            vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
            vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));

            vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
            vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));

            vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
            vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));

            vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
            vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
            vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
            vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
            vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
            vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
            vmovups(zmm_T(5, i), zmm_src(3, i));
        }

        /* row pass + store */
        for (int j = 0; j < alpha; j++) {
            vbroadcastss(zmm_G(0), ptr[reg_G]);
            vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));

            vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
            vmovups(zmm_t(1), zmm_t(0));
            vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));

            vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
            vmovups(zmm_t(2), zmm_t(0));
            vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));

            vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
            vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));

            vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
            vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));

            vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
            vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));

            vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
            vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));

            vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
            vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
            vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
            vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
            vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
            vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
            vmovups(zmm_t(3), zmm_T(j, 3));

            /* stride between consecutive alpha positions in dst */
            size_t alpha_offset = jcp.oc/jcp.nb_oc * jcp.ntiles/jcp.tile_block
                * typesize;
            size_t dst_off = (j * alpha * alpha_offset);
            movps(reg_dst, dst_off, zmm_t(0));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(5));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(1));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(6));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(2));
            dst_off += alpha_offset;
            movps(reg_dst, dst_off, zmm_t(3));
        }

    };
    /* SDGtWo schedule: transforms exactly nb_tile_block_ur * tile_block_ur
     * tiles starting at (ti, tj), walking itiles/jtiles and falling through
     * to the next minibatch image (reg_src advances by a whole image; the
     * loop_mb jump is unconditional, the only exit is the tile counter).
     * The whole walk repeats for each of the oc_reg_block channels. */
    auto compute_transform_SDGtWo = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);

        xor_(reg_zero, reg_zero);
        xor_(reg_oc_ur, reg_oc_ur);
        Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;

        L(loop_oc_ur);
        {
            mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
            mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
            xor_(reg_tile_count, reg_tile_count);
            L(loop_mb);
            {
                L(loop_jtiles);
                {
                    L(loop_itiles);
                    {
                        load_src();

                        trans_W_3x3_4x4();

                        add(reg_tile_count, 1);
                        cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                        jge(tiles_done);

                        add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
                        add(reg_ti, 1);
                        cmp(reg_ti, jcp.itiles);
                        jl(loop_itiles);
                    }
                    xor_(reg_ti, reg_ti);
                    add(reg_tj, 1);
                    cmp(reg_tj, jcp.jtiles);
                    jl(loop_jtiles);
                }
                xor_(reg_tj, reg_tj);
                /* next image in the minibatch */
                add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
                jmp(loop_mb);
            }

            L(tiles_done);
            /* restart at the next oc simd slot of the same tile range */
            mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
            add(reg_dst, simd_w * typesize);
            mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
            add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);

            if (with_bias) add(reg_bias, simd_w * typesize);
            add(reg_oc_ur, 1);
            cmp(reg_oc_ur, jcp.oc_reg_block);
            jl(loop_oc_ur);
        }
    };

    /* Default schedule: transforms every tile of one image.  dst starts
     * tile_count slots in; whenever nb_tile_block_ur * tile_block_ur tiles
     * are filled, dst is rewound within the block and advanced by a whole
     * tile block (tblk_off), and the counter restarts. */
    auto compute_transform = [=]() {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);

        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
        mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
        imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
        add(reg_dst, reg_temp);

        xor_(reg_zero, reg_zero);
        xor_(reg_oc_ur, reg_oc_ur);
        Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile;

        L(loop_oc_ur);
        {
            xor_(reg_ti, reg_ti);
            xor_(reg_tj, reg_tj);

            L(loop_jtiles);
            {
                L(loop_itiles);
                {
                    load_src();

                    trans_W_3x3_4x4();

                    add(reg_tile_count, 1);
                    cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
                    jge(next_tile_block);
                    add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
                    jmp(next_tile);

                    L(next_tile_block);
                    /* rewind within the finished block, jump to the next */
                    sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
                            * jcp.oc_reg_block * simd_w * typesize);
                    size_t tblk_off = alpha * alpha * jcp.oc/jcp.nb_oc
                        * jcp.ntiles/jcp.tile_block * typesize;
                    add(reg_dst, tblk_off);
                    xor_(reg_tile_count, reg_tile_count);

                    L(next_tile);
                    add(reg_ti, 1);
                    cmp(reg_ti, jcp.itiles);
                    jl(loop_itiles);
                }
                xor_(reg_ti, reg_ti);
                add(reg_tj, 1);
                cmp(reg_tj, jcp.jtiles);
                jl(loop_jtiles);
            }

            /* restart at the next oc simd slot of the same image */
            mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
            mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
            imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, reg_temp);
            add(reg_dst, simd_w * typesize);
            mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
            add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);

            if (with_bias) add(reg_bias, simd_w * typesize);
            add(reg_oc_ur, 1);
            cmp(reg_oc_ur, jcp.oc_reg_block);
            jl(loop_oc_ur);
        }
    };

    preamble();
    if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
        compute_transform_SDGtWo();
    } else {
        compute_transform();
    }
    postamble();
}
1994
/* Generates the JIT kernel for the inverse Winograd transform of the
 * accumulated weight gradients: each alpha x alpha (6x6) transformed tile
 * is reduced back to the kh x kw spatial kernel.  With first_tile the
 * result overwrites diff_weights; otherwise it is accumulated into the
 * values already stored there.  The kernel walks the simd_w input-channel
 * positions of one (oc_reg_block, ic) slice. */
void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::
diff_weights_transform_generate(bool first_tile) {
    /* number of transform coefficients kept resident in registers */
    int G_size = 4;

    /* G coefficients pinned in registers 0..3 for the whole kernel */
    auto zmm_G = [](int i) {
        return Xbyak::Zmm(i);
    };

    /* broadcast all G coefficients once, up front */
    auto init_G = [=]() {
        mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
        for (int i = 0; i  < G_size; i++)
            vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
    };

    /* one column of the transformed tile (alpha rows), registers 4..9 */
    auto zmm_src = [=](int i) {
        return Xbyak::Zmm(G_size + i);
    };

    /* Loads column i of the transformed weight-gradient tile: alpha
     * entries spaced alpha_offset apart in the GEMM output layout. */
    auto load_src = [=](int i) {
        for (int j = 0; j < alpha; j++) {
            size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block
                * jcp.ic_block * simd_w * simd_w * typesize;
            size_t src_off = (j * alpha + i) * alpha_offset;
            vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
        }
    };

    /* transform temporaries, registers 10..12 */
    auto zmm_t = [=](int i) {
        return Xbyak::Zmm(G_size + 6 + i);
    };

    /* column-reduced tile (kh rows x alpha cols), registers 13..30;
     * stays within 31 only while jcp.kh <= 3, which the 3x3 transform
     * implies — presumably guaranteed by the caller, verify. */
    auto zmm_T = [=](int j, int i) {
        return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
    };

    /* Final kw outputs of one kernel row; deliberately aliases zmm_src,
     * whose values have been consumed by the time zmm_dst is written. */
    auto zmm_dst = [=](int i) {
        return Xbyak::Zmm(G_size + i);
    };

    auto zmm_temp = Xbyak::Zmm(31);

    /* Writes kernel row j with a non-temporal store; when accumulating
     * (not first_tile) the previous diff_weights values are added in. */
    auto store_dst = [=](int j) {
        for (int i = 0; i < jcp.kw; i++) {
            size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;

            if (!first_tile) {
                vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
                vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
            }
            vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
        }
    };

    /* For each of the simd_w ic positions: reduce columns (alpha -> kh,
     * into zmm_T), then rows (alpha -> kw, into zmm_dst), then store. */
    auto compute_transform = [=] () {
        mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
        mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);

        xor_(reg_ic_simd, reg_ic_simd);
        Label loop_ic_simd;
        L(loop_ic_simd);
        {
            /* column pass */
            for (int i = 0; i < alpha; i++) {
                load_src(i);

                vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
                vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
                vmovups(zmm_t(2), zmm_src(5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
                vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
                vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
                vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
                vsubps(zmm_temp, zmm_src(3), zmm_src(4));
                vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
                vmovups(zmm_T(2, i), zmm_t(2));
                vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
            }

            /* row pass + store, one kernel row at a time */
            for (int j = 0; j < jcp.kh; j++) {
                vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
                vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
                vmovups(zmm_t(2), zmm_T(j, 5));
                vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));

                vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
                vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
                vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
                vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
                vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
                vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
                vmovups(zmm_dst(2), zmm_t(2));
                vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));

                store_dst(j);
            }

            /* advance to the next ic simd position */
            add(reg_src, jcp.oc_reg_block * simd_w * typesize);
            add(reg_dst, simd_w * typesize);
            add(reg_ic_simd, 1);
            cmp(reg_ic_simd, simd_w);
            jl(loop_ic_simd);
        }
    };
    preamble();
    /* widen EVEX disp8 compression range for the large src offsets */
    push(reg_EVEX_max_8b_offt);
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
    init_G();
    compute_transform();
    pop(reg_EVEX_max_8b_offt);
    postamble();
}
2107
/* Generates the GEMM microkernel of the weight-update pass: accumulates
 * dstC += srcA * srcB over the blocked M (dimM_block) x N (dimN_block) x
 * K (dimK_block * dimK_reg_block) iteration space.  With is_first_tile
 * the accumulators simply overwrite dstC; otherwise the previous dstC
 * contents are loaded and added in before each store. */
void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
        bool is_first_tile)
{
    /* Zmm(0) is the single srcA staging register, reloaded per use */
    auto zmm_srcA = [=]() {
        return Xbyak::Zmm(0);
    };

    /* broadcasted srcB scalars occupy registers 1..dimN_bcast_ur */
    auto zmm_srcB = [=] (size_t N_ur){
        return Xbyak::Zmm(N_ur + 1);
    };

    /* broadcast dimN_bcast_ur consecutive srcB scalars of K-row K_ur */
    auto broadcastB = [=](size_t K_ur) {
        for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
            size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast)
                * sizeof(float);
            vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
        }
    };

    /* load one simd-wide srcA vector for position (K_ur, M_ur) */
    auto load_srcA = [=] (size_t K_ur, int M_ur) {
        size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                        + M_ur * jcp.dimM_simd_block) * sizeof(float);
        vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
    };

    /* accumulators fill the registers left after srcA and srcB */
    auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){
        size_t idx = 1 // zmm_srcA
            + jcp.dimN_bcast_ur // zmm_srcB
            + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
        assert(idx < 32);
        return Xbyak::Zmm(idx);
    };
    /* zero every accumulator ahead of the K loop */
    auto prepare_accumm = [=](){
        for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
            for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
                Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
                vpxord(zmm, zmm, zmm);
            }
        }
    };

    /* Flushes the accumulators to dstC.  When not the first tile, the
     * existing C values are added in through Zmm(0) — safe because srcA
     * is reloaded before every use. */
    auto store_dstC = [=](){
        /******** Write C back to memory *******/
        for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
            for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
                Zmm zmm = zmm_dstC(M_reg, N_ur);
                size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
                             + M_reg * jcp.dimM_simd_block) * sizeof(float);
                if (!is_first_tile) {
                    vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
                    vaddps(zmm, zmm, Xbyak::Zmm(0));
                }
                vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
            }
        }
    };

    /* M -> N -> N_bcast -> K loop nest; reg_srcA/reg_srcB are advanced
     * inside the K loop and rewound at each level so every level restarts
     * at the right base address. */
    auto inner_loops = [=]() {
        Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;

        mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
        L(dimM_block_loop);
        { /************* OC_block (M) loop ***********/
            mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
            L(dimN_block_loop);
            { /*************** IC_block (N) loop *********/

                mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur);
                L(dimN_bcast_ur);
                {
                    prepare_accumm();

                    mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
                    L(dimK_block_loop);
                    {
                     /************* nb_tile_ur(K) loop ********/
                        /* fully unrolled over dimK_reg_block */
                        for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {

                            broadcastB(K_ur);

                            for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
                                load_srcA(K_ur, M_reg_ur);
                                for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) {
                                    vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(),
                                            zmm_srcB(N_bcast));
                                }
                            }
                        }
                        /* advance A and B past the unrolled K chunk */
                        add(reg_srcA, jcp.dimK_reg_block
                                      * jcp.dimM_reg_block * jcp.dimM_simd_block
                                      * sizeof(float));
                        add(reg_srcB, jcp.dimK_reg_block
                                      * jcp.dimN_reg_block
                                      * sizeof(float));
                        sub(reg_dimK_block_loop_cnt, 1);
                        jnz(dimK_block_loop);
                    }

                    store_dstC();

                    /* rewind A/B over the whole K block, then step B to
                     * the next bcast chunk and C to the next output chunk */
                    sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
                                  * jcp.dimM_reg_block * jcp.dimM_simd_block
                                  * sizeof(float));
                    sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block
                                  * jcp.dimN_reg_block
                                  * sizeof(float));
                    add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
                    add(reg_dstC, jcp.dimN_bcast_ur
                            * jcp.dimM_reg_block * jcp.dimM_simd_block
                            * sizeof(float));
                    sub(reg_nb_dimN_bcast_ur, 1);
                    jnz(dimN_bcast_ur);
                }

                /* undo the bcast stepping, then advance B one N block */
                sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
                add(reg_srcB, jcp.dimK_block
                        * jcp.dimK_reg_block
                        * jcp.dimN_reg_block * sizeof(float));
                sub(reg_dimN_block_loop_cnt, 1);
                jnz(dimN_block_loop);
            }

            /* rewind B over all N blocks; advance A to the next M block */
            sub(reg_srcB, jcp.dimN_block
                          * jcp.dimK_block * jcp.dimK_reg_block
                          * jcp.dimN_reg_block
                          * sizeof(float));
            add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
                          * jcp.dimM_reg_block * jcp.dimM_simd_block
                          * sizeof(float));
            sub(reg_dimM_block_loop_cnt, 1);
            jnz(dimM_block_loop);
        }
    };

    /* Preamble */
    // register used to handle long fma encoding
    push(reg_EVEX_max_8b_offt);
    push(reg_dimK_block_loop_cnt);
    mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);

    inner_loops();

    /* Postamble */
    pop(reg_dimK_block_loop_cnt);
    pop(reg_EVEX_max_8b_offt);
    uni_vzeroupper();
    ret();
}
2256
2257 namespace {
2258
2259 void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
2260 /*M params*/
2261     jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
2262         / jcp.dimM_simd_block;
2263     jcp.oc_reg_block = jcp.dimM_reg_block;
2264     jcp.oc_block = jcp.dimM_block;
2265     jcp.nb_oc = jcp.dimM_nb_block;
2266     /*N params*/
2267     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
2268     jcp.ic_block = jcp.dimN_block;
2269     jcp.nb_ic = jcp.dimN_nb_block;
2270
2271     /*K params*/
2272     jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
2273     jcp.tile_block_ur = jcp.dimK_reg_block;
2274     jcp.nb_tile_block_ur = jcp.dimK_block;
2275     jcp.tile_block = jcp.dimK_nb_block;
2276 }
2277
/* Scheduler WEI_SDGtWo: each thread performs the src (S) and diff_dst (D)
 * transforms, the GEMM (G), and accumulates its private weight transform
 * (tWo) locally, so the per-thread working set has to fit (roughly) in L2.
 * Searches for a feasible (K_blk_ur, N_blk, M_blk) blocking; returns
 * status::unimplemented when the strategy is not applicable. */
status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {

    size_t K_blk_ur, N_blk, M_blk;
    /* Is this strategy feasible at all? The transformed M (diff_dst) and
     * V (src) buffers must be big enough that even split across threads
     * they exceed 2x L2, and there must be at least one dimK iteration
     * per thread. */
    auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
        size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
        size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
        size_t nthreads = omp_get_max_threads();
        return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
            && (jcp.dimK / nthreads >= 1.0);
    };

    /* Accept a dimK_block_ur whose L1 working set (one M panel plus one N
     * panel) occupies 10%..50% of L1, whose M slab fits in L2, and that
     * keeps the dimK iterations evenly divisible over the threads. */
    auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
            int max_block=1) {
        size_t L1_block_M  = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
        size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
        size_t nthreads = omp_get_max_threads();
        bool load_balance=true;
        /* Only enforce even distribution when dimK itself divides evenly. */
        if (!(jcp.dimK % nthreads)) {
            load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
        }
        return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
            && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
            && load_balance
            && (M_L2_block < L2_cache_size);
    };

    /* The innermost GEMM unroll factor over dimK must be in [2, 8]. */
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
            int useless_arg=0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    /* Combined L2 footprint of the M, V and U blocks must lie between 10%
     * and 120% of L2. */
    auto blocking_ok =  [&](){
        size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
                          * K_blk_ur * sizeof(float);
        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
                          * K_blk_ur * sizeof(float);
        size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
                          * N_blk * jcp.dimN_reg_block * sizeof(float);
        size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
        /* NOTE(review): the 1.2 upper bound presumably allows some spill
         * into L3 -- confirm before tuning per platform. */
        return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size);
    };

    if (test_MV_large_enough(jcp)) {
        /* Use a 2-wide M register block when oc (in simd units) is even. */
        if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
            jcp.dimM_reg_block = 2;
        } else {
            jcp.dimM_reg_block = 1;
        }
        jcp.dimM_simd_block = jcp.oc_simd_block;
        jcp.dimN_reg_block = jcp.ic_simd_block;
        jcp.dimN_bcast_ur = 8;
        /* Pick the largest dimK micro-block satisfying the L1 constraint,
         * then search downward through divisor combinations. */
        size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1);

        jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block;
        jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block;
        for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
            if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
                for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                    if (!(jcp.dimN_block % N_blk)) {
                        for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                            if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
                                /* Split K_blk_ur into an unroll factor and a
                                 * block count; reject if no unroll in [2,8]
                                 * exists. */
                                jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
                                if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented;
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                jcp.sched_policy = WSCHED_WEI_SDGtWo;
                                set_jcp_WEI_params(jcp);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    return status::unimplemented;
}
2360
/* Scheduler WEI_S_D_Giot_W: fallback schedule with shared transform buffers,
 * the GEMM parallelized across the i/o/t dimensions. Searches for a blocking
 * whose L1/L2 footprints fall within the [C1, C1_max] / [C2, C2_max]
 * fractions of the respective cache sizes. */
status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
    /* Use a 2-wide M register block when oc (in simd units) is even. */
    if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
        jcp.dimM_reg_block = 2;
    } else {
        jcp.dimM_reg_block = 1;
    }
    jcp.dimN_bcast_ur = 8;
    jcp.dimN_reg_block = jcp.ic_simd_block;
    jcp.dimM_simd_block = jcp.oc_simd_block;
    jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
    jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
    /* Cache-occupancy bounds (fractions of L1 / L2 size). */
    float C1 = 0.0, C2 = 0.0;
    float C1_max = 0.5, C2_max = 1.4;
    int N_blk, M_blk, K_blk_ur;

    /* The innermost GEMM unroll factor over dimK must be in [2, 8]. */
    auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
            int useless_arg=0) {
        return (dimK_ur >= 2) && (dimK_ur <= 8);
    };

    /* Accept the candidate (K_blk_ur, N_blk, M_blk) when the L1 panel
     * footprint, the thread load balance, and the L2 V-block footprint are
     * all within bounds. */
    auto blocking_ok = [&]() -> bool {
        size_t L1_block_M  = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
        size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
        bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
                     && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);

        size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block;
        size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block;
        size_t nb_K_blk = jcp.dimK / K_blk_ur;
        size_t nthreads = omp_get_max_threads();
        bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
        /* NOTE(review): this inner test re-checks the exact condition of the
         * guard above, so the assignment is a no-op. The analogous code in
         * set_wsched_WEI_SDGtWo tests jcp.dimK against nthreads instead --
         * presumably that was the intent here too; confirm before changing. */
        if (!(nb_K_blk % nthreads)) {
            load_balance = load_balance && (nb_K_blk % nthreads == 0);
        }

        size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float);

        size_t L2_block = V_L2_block;
        /* NOTE(review): stale comment in the original mentioned a 2.375
         * factor; the live bound is C2_max (= 1.4) of L2 -- presumably a
         * proxy for combined L2+L3 capacity. */
        bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size);
        return L1_cond && load_balance && L2_cond;
    };

    /* Exhaustive divisor search, largest blocks first. */
    for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
        if (jcp.dimK % K_blk_ur == 0) {
            for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
                if (jcp.dimN_block % N_blk == 0) {
                    for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
                        if (jcp.dimM_block % M_blk == 0) {
                            if (blocking_ok()) {
                                jcp.dimN_block = N_blk;
                                jcp.dimM_block = M_blk;
                                jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
                                jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
                                jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
                                set_jcp_WEI_params(jcp);
                                return status::success;
                            }
                        }
                    }
                }
            }
        }
    }
    /* NOTE(review): falling through here reports success without having set
     * dimK/dimN/dimM blocking or sched_policy; the caller (init_conf)
     * asserts success on this path, so it appears the search is assumed to
     * always hit -- confirm whether a default blocking should be installed
     * instead. */
    return status::success;
}
2427 } // namespace
2428 status_t jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::init_conf(
2429         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
2430         const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
2431         const memory_desc_wrapper &diff_weights_d) {
2432     if (!mayiuse(avx512_core))
2433         return status::unimplemented;
2434     else
2435         jcp.ver = ver_avx512_core;
2436
2437     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
2438     jcp.mb = src_d.dims()[0];
2439     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
2440     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2441     jcp.oc_without_padding = jcp.oc;
2442     jcp.ic = src_d.dims()[1] / jcp.ngroups;
2443     jcp.ih = src_d.dims()[2];
2444     jcp.iw = src_d.dims()[3];
2445     jcp.oh = diff_dst_d.dims()[2];
2446     jcp.ow = diff_dst_d.dims()[3];
2447     jcp.kh = diff_weights_d.dims()[with_groups + 2];
2448     jcp.kw = diff_weights_d.dims()[with_groups + 3];
2449     jcp.t_pad = cd.padding[0][0];
2450     jcp.l_pad = cd.padding[0][1];
2451     jcp.stride_h = cd.strides[0];
2452     jcp.stride_w = cd.strides[1];
2453     jcp.r_pad = nstl::max(
2454             0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
2455     jcp.b_pad = nstl::max(
2456             0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
2457     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
2458     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
2459     jcp.ohp = jcp.oh;
2460     jcp.owp = jcp.ow;
2461     jcp.with_bias = (cd.diff_bias_desc.format != memory_format::undef);
2462     jcp.dilate_h = cd.dilates[0];
2463     jcp.dilate_w = cd.dilates[1];
2464
2465     bool ok_to_pad_channels = jcp.ngroups == 1;
2466     if (ok_to_pad_channels) {
2467         jcp.oc = rnd_up(jcp.oc, simd_w);
2468         jcp.ic = rnd_up(jcp.ic, simd_w);
2469     }
2470
2471     // Winograd specific initialization
2472     jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
2473     jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
2474     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
2475
2476     // Winograd kernel works only for 3x3 convolution with stride 1
2477     if (jcp.ngroups != 1)
2478         return status::unimplemented;
2479     if ((jcp.kh != 3) || (jcp.kw != 3))
2480         return status::unimplemented;
2481     if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
2482         return status::unimplemented;
2483     if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
2484         return status::unimplemented;
2485     if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
2486         return status::unimplemented;
2487     if (src_d.format() != nChw16c)
2488         return status::unimplemented;
2489     if (diff_weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
2490         return status::unimplemented;
2491     if (diff_dst_d.format() != nChw16c)
2492         return status::unimplemented;
2493
2494     bool layout_consistency = true
2495         && jcp.ic <= src_d.blocking_desc().padding_dims[1]
2496         && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
2497         && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
2498         && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
2499     if (!layout_consistency) return status::unimplemented;
2500
2501     /******************Kernel blocking Parameters ***********/
2502     jcp.ic_simd_block = simd_w;
2503     jcp.oc_simd_block = simd_w;
2504
2505     jcp.dimK = jcp.ntiles;
2506     jcp.dimN = jcp.ic;
2507     jcp.dimM = jcp.oc;
2508     jcp.dimM_simd_block = jcp.oc_simd_block;
2509     jcp.dimN_reg_block = jcp.ic_simd_block;
2510     jcp.sched_policy = WSCHED_INVALID;
2511     status_t res = set_wsched_WEI_SDGtWo(jcp);
2512     if (res == status::unimplemented) {
2513         res = set_wsched_WEI_S_D_Giot_W(jcp);
2514         assert(res == status::success);
2515     }
2516     return res;
2517 }
2518 }
2519 }
2520 }
2521
2522 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s