1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
20 #include "type_helpers.hpp"
22 #include "cpu_memory.hpp"
26 #include "jit_avx512_core_conv_winograd_kernel_f32.hpp"
28 #define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
36 using namespace mkldnn::impl::utils;
38 unsigned int L1_cache_size = get_cache_size(1, true);
39 unsigned int L2_cache_size = get_cache_size(2, true);
40 unsigned int LLC_data_size = get_cache_size(3, false);
42 // The test function takes jcp, a candidate divisor and the current best;
43 // it returns true if the new candidate is better.
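// get_divisor_satisfying_cond enumerates the divisor pairs (d, number / d)
// for d up to sqrt(number) and keeps the candidate preferred by `test`, so
// the value it returns always divides `number` exactly.
// Illustrative use (the limit of 28 is hypothetical, not taken from this file):
//   auto test = [](jit_conv_winograd_conf_t &jcp, int cand, int best)
//   { return cand <= 28 && cand > best; };
//   int blk = get_divisor_satisfying_cond(jcp, jcp.dimN, 1, test);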
44 int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
45 int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
47 int best_divisor = default_best;
49 = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
50 if (test(jcp, num, best_divisor)) {
55 for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
56 if (number % divisor == 0) {
57 test_num(jcp, divisor);
58 test_num(jcp, number / divisor);
65 /* assumes 512-bit registers */
66 /* TODO: add support for strides */
67 /* TODO: handle the prefetch distance automatically */
68 typedef enum cache_t_ { L1, L2, L3 } cache_t;
70 template <typename data_t>
72 prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
73 cache_t cache_type, size_t block_size, /* in number of elements*/
74 int nb_instructions_in_block, int fma_ipc)
76 , reg_base_addr_(reg_base_addr)
77 , cache_type_(cache_type)
78 , cache_block_size_(block_size)
80 nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
82 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
84 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
86 /* assumption: when fetching into Li, the data is already present in L(i+1) */
88 switch (cache_type_) {
89 case L1: cache_latency = 14; break;
90 case L2: cache_latency = 250; break;
91 case L3: cache_latency = 250; break;
94 prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
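// Worked example (float data, a block of 64 elements): each cache line holds
// 64 / sizeof(float) = 16 elements, so nb_cache_lines_to_prefetch_ = 4.
// prefetch_spread_ then spaces the prefetch bursts across the fma block,
// prefetch_blk_ is how many prefetches are issued per burst, and with the
// assumed L1 latency of 14 lines prefetch_distance_ = div_up(14, 4) = 4,
// i.e. addresses are prefetched four blocks ahead of the current one.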
97 void prefetch(int instruction_number)
99 if (instruction_number % prefetch_spread_ == 0) {
100 for (int i = 0; (i < prefetch_blk_)
101 && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
102 i++, prefetches_issued_++) {
103 prefetch_inst_(cg_->EVEX_compress_addr(
104 reg_base_addr_, (cache_block_size_ * prefetch_distance_)
106 + (prefetches_issued_ * 64)));
112 void prefetch_inst_(const Xbyak::Address &addr)
114 switch (cache_type_) {
115 case L1: cg_->prefetcht0(addr); break;
116 case L2: cg_->prefetcht1(addr); break;
117 case L3: cg_->prefetcht2(addr); break;
119 break; // TODO: raise an exception or put an assert
124 Xbyak::Reg64 reg_base_addr_;
126 int cache_block_size_ = 0;
127 int nb_cache_lines_to_prefetch_ = 0;
128 int prefetches_issued_ = 0;
129 int prefetch_spread_ = 0;
130 int prefetch_blk_ = 0;
131 int prefetch_distance_ = 0;
134 // utilities to support kernel parameter selection
135 bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
136 int dimN_block, float C2_min, float C2_max) {
137 float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic)
138 * dimN_block * jcp.dimN_reg_block
139         + div_up(jcp.ic * jcp.oc, omp_get_max_threads())) * (float)sizeof(float);
140 float L2_lb = C2_min * L2_cache_size;
141 float L2_ub = C2_max * L2_cache_size;
142 return (block_size > L2_lb && block_size < L2_ub);
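// The estimate above counts, per thread and across all alpha * alpha transform
// components, the transformed input and output tiles of one dimN block
// (the 2 * (oc + ic) * dimN_block * dimN_reg_block term) plus this thread's
// share of the transformed weights (ic * oc / nthreads); the block is accepted
// only if that footprint lands between C2_min and C2_max of the L2 size.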
145 bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
146 int dimM_block, float C1_min, float C1_max) {
147 float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
148 * jcp.dimK_reg_block * jcp.dimM_reg_block
149 + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
150 + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
151 * (float)sizeof(float);
152 float L1_lb = C1_min * L1_cache_size;
153 float L1_ub = C1_max * L1_cache_size;
154 return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
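// Working set of one GEMM micro-kernel: the A panel (dimM x dimK block),
// the B panel (dimK block x dimN_reg_block) and the C tile, counted in
// floats; the block is accepted only if this footprint falls between the
// C1_min and C1_max fractions of the L1 size.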
156 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
157 int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
159 float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
160 + dimM_block * dimK_block * dimK_reg_block
161 * dimM_simd_block * dimM_reg_block
162 + dimK_block * dimN_reg_block * dimK_reg_block)
163 * (float)sizeof(float);
164 float rhs = C * L1_cache_size;
167 bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
168 int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
170 float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
171 * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block)
172 * (float)sizeof(float);
173 float rhs = C * L1_cache_size;
176 bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
177 int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
178 int dimM_simd_block, float C)
180 float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block
181 * dimM_simd_block * dimM_reg_block
182 + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
183 * dimM_simd_block * dimM_reg_block
184 + nb_dimN_reg_block * dimK_nb_block * dimK_block
185 * dimN_reg_block * dimK_reg_block)
186 * (float)sizeof(float);
187 float rhs = C * L2_cache_size;
191 bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
192 int dimN_block, int dimN_reg_block, int dimK, float C1, float C2)
194 float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
195 * (float)sizeof(float);
196 float B_size = dimN_block * dimN_reg_block * dimK
197 * (float)sizeof(float);
198 return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
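// check_kernel_cond is used by set_wsched_DATA_W_S_G_D_avx512_core below:
// the explicit-broadcast kernel is kept only if both the A panel and the
// B panel are large relative to L2 (fractions C1 and C2); otherwise the
// blocking is recomputed for the embedded-broadcast kernel.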
202 using namespace mkldnn::impl::memory_format;
203 using namespace mkldnn::impl::utils;
204 using namespace Xbyak;
206 void _jit_avx512_core_conv_winograd_data_kernel_f32::gemm_loop_generate()
208 // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
209 // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
210 // dimM_reg_block++) // unrolled
211 // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
212 // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
213 // dimK_reg_block++) // unrolled
214 // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
215 // C[dimM_block][dimM_reg_block][tile] +=
216 // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
217 // * broadcast(B[dimK_block][tile][dimK_reg_block]);
219 // jcp.kernel_kind defines embedded or explicit broadcast
220 // dimM_reg_block=1 for embedded bcast kernel
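// Register allocation below: zmm0 holds the current A value; with
// expl_bcast, zmm1..zmm{dimN_reg_block} hold the broadcast B values and the
// C accumulators follow them; with embd_bcast, B is broadcast directly from
// memory by the fma's embedded-broadcast operand, so the accumulators start
// right after zmm0.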
222 auto zmm_srcA = [=]() {
223 return Xbyak::Zmm(0);
225 auto zmm_srcB = [=](int tile) {
227 assert(idx < 1 + jcp.dimN_reg_block);
228 return Xbyak::Zmm(idx);
230 auto zmm_dstC = [=](int dimM_reg_block, int tile) {
232 if (jcp.kernel_kind == embd_bcast)
235 idx = 1 + jcp.dimN_reg_block
236 + dimM_reg_block * jcp.dimN_reg_block + tile;
238 return Xbyak::Zmm(idx);
241 auto prepare_output = [=]() {
242 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
244 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
245 Zmm zmm = zmm_dstC(dimM_reg_block, tile);
246 vpxord(zmm, zmm, zmm);
250 auto store_output = [=](bool output_is_aligned) {
252 cmp(reg_is_beta_zero, 0);
255 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
257 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
258 Zmm zmm = zmm_dstC(dimM_reg_block,tile);
260 = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
261 vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
266 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
268 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
269 Zmm zmm = zmm_dstC(dimM_reg_block,tile);
271 = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
273 // In W_SGD, output will be reused.
274 if (output_is_aligned
275 && jcp.dimK_nb_block == 1
276 && jcp.sched_policy == WSCHED_DATA_W_S_G_D
277 && (jcp.dimN * jcp.dimM * alpha * alpha
278 * sizeof(float) > 2 * LLC_data_size))
279 vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
280 else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
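// Streaming (non-temporal) stores are used only when the destination is
// aligned, there is a single dimK block, the W_S_G_D schedule is active and
// the output exceeds twice the LLC, i.e. when the data will not be read
// again soon; otherwise regular stores keep it cache resident.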
285 auto inner_loops = [=]() {
286 Label dimM_block_loop, dimK_block_loop;
288 if (jcp.dimM_block > 1) {
289 mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
295 if (jcp.dimK_block > 1) {
296 mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
300 for (int dimK_reg_block = 0;
301 dimK_reg_block < jcp.dimK_reg_block;
304 if (jcp.kernel_kind == expl_bcast) {
305 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
306 vbroadcastss(zmm_srcB(tile),
307 ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
311 /* Performing the fmas */
313 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
318 + jcp.dimK_reg_block * jcp.dimK_block * 64
320 + dimK_reg_block * 64]
323 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
324 if (jcp.kernel_kind == expl_bcast)
325 vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
328 vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
329 EVEX_compress_addr(reg_srcB,
330 64 * tile + dimK_reg_block * 4, true));
334 add(reg_srcA, jcp.dimK_reg_block * 64);
335 add(reg_srcB, jcp.dimN_reg_block * 64);
336 if (jcp.dimK_block > 1) {
337 sub(reg_dimK_block_loop_cnt, 1);
338 jnz(dimK_block_loop);
341 Label unaligned_store, end_store;
342 test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
343 jnz(unaligned_store, T_NEAR);
345 jmp(end_store, T_NEAR);
346 L(unaligned_store); {
351 if (jcp.dimM_block > 1) {
352 sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
353 add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
354 if (jcp.kernel_kind == expl_bcast) {
356 (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64
359 sub(reg_dimM_block_loop_cnt, 1);
360 jnz(dimM_block_loop);
365 // register used to handle long fma encoding
366 push(reg_EVEX_max_8b_offt);
367 mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
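// jit_generator's EVEX_compress_addr can use reg_EVEX_max_8b_offt (preloaded
// with 2 * EVEX_max_8b_offt) as an extra base, so large displacements still
// fit the compressed 8-bit disp8*N form of EVEX encoding and the fma
// instructions stay short.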
373 pop(reg_EVEX_max_8b_offt);
378 void _jit_avx512_core_conv_winograd_data_kernel_f32
379 ::weights_transform_data_ker_generate()
381 bool is_fwd = one_of(jcp.prop_kind,
382 mkldnn_forward_training, mkldnn_forward_inference);
386 auto zmm_temp = Xbyak::Zmm(31);
387 auto zmm_zero = Xbyak::Zmm(30);
389 auto zmm_M = [=](int i) {
390 return Xbyak::Zmm(i);
392 auto zmm_MT = [=](int i) {
393 return Xbyak::Zmm(i + simd_w);
396 auto zmm_G = [=](int i) {
397 return Xbyak::Zmm(i);
399 auto zmm_F = [=](int i) {
400 return Xbyak::Zmm(alpha + i);
402 auto zmm_T = [=](int i) {
403 return Xbyak::Zmm(alpha + 3 + i);
405 auto zmm_t = [=](int i) {
406 return Xbyak::Zmm(2 * alpha + 3 + i);
409 auto zmm_load = [=](int i) {
410 return Xbyak::Zmm(i);
413 auto init_G = [=]() {
414 mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
415 for (int i = 0; i < alpha; i++) {
416 vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
418 vpxord(zmm_zero, zmm_zero, zmm_zero);
421 auto trans16x16 = [=]() {
422 for (int i = 0; i < simd_w; i+=2 ) {
423 vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
424 vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]);
425 vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1));
426 vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1));
428 for (int i = 0; i < simd_w; i+=4 ) {
429 vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2));
430 vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2));
431 vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3));
432 vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3));
434 for (int i = 0; i < simd_w; i += 8) {
435 vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
436 vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88);
437 vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88);
438 vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88);
439 vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd);
440 vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd);
441 vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd);
442 vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd);
447 vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
448 vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
449 vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
450 vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
451 vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
452 vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
453 vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
454 vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
455 vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
456 vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
457 vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
458 vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
459 vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
460 vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
461 vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
462 vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
464 vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
465 vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
466 vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
467 vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
468 vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
469 vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
470 vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
471 vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
472 vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
473 vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
474 vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
475 vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
476 vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
477 vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
478 vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
479 vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
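// trans16x16 is a standard AVX-512 16x16 float transpose: vunpcklps /
// vunpckhps interleave adjacent rows at 32-bit granularity, vunpcklpd /
// vunpckhpd at 64-bit granularity, and the two vshuff32x4 passes exchange
// 128-bit lanes, leaving the fully transposed tile stored at wreg_MT.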
483 auto load_src = [=]() {
484 mov(wreg_src, ptr[param1 + GET_OFF(src)]);
485 mov(wreg_F, ptr[param1 + GET_OFF(M)]);
486 for (int j = 0; j < kh; j++) {
487 for (int i = 0; i < kw; i++) {
489 for (int v1 = 0; v1 < simd_w; v1++) {
490 int offset_src = (j * kw * simd_w * simd_w
491 + i * simd_w * simd_w + v1 * simd_w) * typesize;
492 int offset_F = (j * kw * simd_w * simd_w
493 + i * simd_w * simd_w + v1 * simd_w) * typesize;
494 vmovups(zmm_temp, ptr[wreg_src + offset_src]);
495 vmovups(ptr[wreg_F + offset_F], zmm_temp);
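// In the other branch below, the source offset is taken at (2 - j, 2 - i)
// while F is written at (j, i): the 3x3 kernel is read rotated by 180
// degrees, as needed when transforming weights for the backward-data pass.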
498 int offset_src = ((2 - j) * kw * simd_w * simd_w
499 + (2 - i) * simd_w * simd_w) * typesize;
500 int offset_F = (j * kw * simd_w * simd_w
501 + i * simd_w * simd_w) * typesize;
502 lea(wreg_M, ptr[wreg_src + offset_src]);
503 lea(wreg_MT, ptr[wreg_F + offset_F]);
510 auto store_dst = [=]() {
511 mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
512 mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
516 mov(wreg_dst_aux, wreg_dst);
517 mov(wreg_Fw_aux, wreg_Fw);
519 int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
520 * jcp.dimK_block * simd_w * simd_w;
524 for (int i = 0; i < alpha; i++) {
526 vmovups(zmm_load(0), ptr[wreg_Fw_aux
527 + (i * simd_w * simd_w) * typesize]);
528 mov(wreg_dst_idx, i * dim5 * typesize);
529 vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
531 for (int i = 0; i < alpha; i++) {
532 for (int v1 = 1; v1 < simd_w; v1++) {
533 int offset_Fw = (i * simd_w * simd_w + v1 * simd_w)
535 vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
537 mov(wreg_dst_idx, i * dim5 * typesize);
538 for (int v1 = 1; v1 < simd_w; v1++) {
539 int offset_dst = v1 * simd_w * typesize;
540 vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
544 add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
545 add(wreg_dst_aux, alpha * dim5 * typesize);
547 cmp(wreg_cnt_j, alpha);
552 auto trans_W_4x4_3x3 = [=]() {
553 auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
555 vfmadd231ps(dst, b, c);
557 auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
558 vmulps(zmm_temp, b, c);
559 vsubps(dst, a, zmm_temp);
561 auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
562 vsubps(dst, zmm_zero, a);
563 vfnmadd231ps(dst, b, c);
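// Helper semantics: fma4 computes dst = a + b * c, fms4 computes
// dst = a - b * c and fnms4 computes dst = -a - b * c (zmm_temp / zmm_zero
// are scratch). trans_W_4x4_3x3 uses them to evaluate the F(4x4, 3x3)
// Winograd weight transform U = G * g * G^T: the first i-loop applies the
// 6x3 matrix G on the left, the second applies G^T on the right.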
566 mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
567 mov(wreg_F, ptr[param1 + GET_OFF(M)]);
568 mov(wreg_T, ptr[param1 + GET_OFF(T)]);
573 mov(wreg_F_aux, wreg_F);
574 mov(wreg_Fw_aux, wreg_Fw);
575 mov(wreg_temp, wreg_cnt_j);
576 shl(wreg_temp, 4 + 2);
577 lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
578 lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);
580 for (int i = 0; i < 3; i++) {
581 for (int idx = 0; idx < 3; idx ++) {
582 vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w
583 * simd_w + i * simd_w * simd_w) * typesize]);
585 vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
586 fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
587 fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));
589 vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
590 fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
591 fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
592 fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
593 fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
594 vmovaps(zmm_T(5), zmm_F(2));
596 for (int idx = 0; idx < 6; idx ++) {
597 vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w)
598 * typesize], zmm_T(idx));
601 for (int i = 0; i < 6; i++) {
603 for (int idx = 0; idx < 3; idx ++) {
604 vmovups(zmm_T(idx), ptr[wreg_T
605 + (i * 3 * simd_w + idx * simd_w) * typesize]);
607 vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
608 fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
609 fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));
611 vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
612 fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
613 fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
614 fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
615 fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
616 vmovaps(zmm_F(5), zmm_T(2));
618 for (int l = 0; l < 6; l++) {
619 vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w
620 + l * simd_w * simd_w) * typesize], zmm_F(l));
628 auto inner_loops = [=]() {
640 void _jit_avx512_core_conv_winograd_data_kernel_f32
641 ::output_transform_data_ker_generate()
643 bool is_fwd = one_of(jcp.prop_kind,
644 mkldnn_forward_training, mkldnn_forward_inference);
645 int outw = is_fwd ? jcp.ow : jcp.iw;
646 int outh = is_fwd ? jcp.oh : jcp.ih;
647 bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
648 bool with_bias = jcp.with_bias;
649 bool with_relu = jcp.with_eltwise;
650 bool with_relu_postsum = jcp.with_relu_postsum;
651 bool with_sum = jcp.with_sum;
653 auto zmm_zero = Xbyak::Zmm(0);
654 auto zmm_temp = Xbyak::Zmm(31);
655 auto zmm_G = [=](int i) {
656 return Xbyak::Zmm(1 + i);
658 auto zmm_O = [=](int i) {
659 return Xbyak::Zmm(1 + alpha + i);
661 auto zmm_T = [=](int i) {
662 return Xbyak::Zmm(1 + 2 * alpha + i);
664 auto zmm_t = [=](int i) {
665 return Xbyak::Zmm(1 + 3 * alpha + i);
668 auto init_G = [=]() {
669 mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
670 for (int i = 0; i < 6; i++) {
671 vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
675 auto load_src = [=]() {
676 mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
677 mov(oreg_src, ptr[param1 + GET_OFF(src)]);
679 mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
680 imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
681 (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
682 * jcp.dimM_simd_block * typesize);
683 add(oreg_src, oreg_nb_tile_block_ur);
685 mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
686 imul(oreg_tile_block_ur, oreg_tile_block_ur,
687 jcp.dimM_simd_block * typesize);
688 add(oreg_src, oreg_tile_block_ur);
691 mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
692 imul(oreg_tile_block, oreg_tile_block,
693 jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
694 * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
695 * jcp.dimM_simd_block * typesize);
696 add(oreg_src, oreg_tile_block);
699 int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
700 * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
701 for (int j = 0; j < alpha; j++) {
702 for (int i = 0; i < alpha; i++) {
703 int j_base_offset = j * alpha * last4dim;
704 int i_base_offset = i * last4dim;
705 vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]);
706 vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w)
707 * typesize], zmm_temp);
712 auto store_dst = [=]() {
713 vpxord(zmm_zero, zmm_zero, zmm_zero);
714 mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
715 mov(oreg_O, ptr[param1 + GET_OFF(M)]);
716 mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
717 shl(oreg_ydim, 2); // tj * tile_size (==4)
718 mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
719         shl(oreg_xdim, 2); // ti * tile_size (==4)
722 mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);
724 auto store_one = [=](int j, int i, bool is_aligned) {
725 auto zmm_O = Xbyak::Zmm(31);
726 auto zmm_relu_ns = Xbyak::Zmm(30);
727 auto xmm_relu_ns = Xbyak::Xmm(30);
728 int offset = (j * tile_size * simd_w + i * simd_w) * typesize;
730 vmovups(zmm_O, ptr[oreg_O + offset]);
733 vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
736 Opmask kmask = Opmask(7);
737 if (jcp.eltwise_alpha == 0) {
738 zmm_relu_ns = zmm_zero;
740 mov(imm_addr64, float2int(jcp.eltwise_alpha));
741 vmovq(xmm_relu_ns, imm_addr64);
742 vbroadcastss(zmm_relu_ns, xmm_relu_ns);
744 vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
745 vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
749 vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
750 if (with_relu_postsum) // orig: with_relu_postsum
751 vmaxps(zmm_O, zmm_O, zmm_zero);
754 vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
756 vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
759 auto i_loop = [=](int j, bool is_aligned) {
760 for (int i = 0; i < tile_size; i++) {
762 mov(oreg_temp, oreg_xdim);
764 cmp(oreg_temp, outw);
766 shl(oreg_temp, 4 + 2); // * 16 * 4
768 store_one(j, i, is_aligned);
775 for (int j = 0; j < tile_size; j++) {
776 Label next, unaligned;
777 mov(oreg_temp, oreg_ydim);
779 cmp(oreg_temp, outh);
782 mov(oreg_out_j, oreg_dst);
783 imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
784 add(oreg_out_j, oreg_temp);
787 jnz(unaligned, T_NEAR);
799 auto trans_O_4x4_3x3 = [=]() {
800 auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){
802 vfmadd231ps(dst, v2, u2);
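// fma2 computes dst = v1 * u1 + v2 * u2. trans_O_4x4_3x3 applies the inverse
// transform Y = A^T * M * A of F(4x4, 3x3): the 6x6 accumulated tile M is
// reduced to a 4x4 output tile, first along columns (zmm_T), then along
// rows (zmm_O).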
804 mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
805 mov(oreg_T, ptr[param1 + GET_OFF(T)]);
806 mov(oreg_O, ptr[param1 + GET_OFF(M)]);
808 for (int i = 0; i < alpha; i++) {
809 for (int j = 0; j < alpha; j++) {
810 vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w
811 + i * simd_w) * typesize]);
814 vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
815 vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
816 vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
817 vsubps(zmm_t(3), zmm_O(3), zmm_O(4));
819 vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
820 vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
821 fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
822 fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
823 fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
824 vaddps(zmm_T(3), zmm_T(3), zmm_O(5));
826 for (int j = 0; j < tile_size; j++) {
827 vmovups(ptr[oreg_T + (j * alpha * simd_w
828 + i * simd_w) * typesize], zmm_T(j));
831 for (int j = 0; j < tile_size; j++) {
832 for (int i = 0; i < alpha; i++) {
833 vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w
834 + i * simd_w) * typesize]);
836 vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
837 vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
838 vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
839 vsubps(zmm_t(3), zmm_T(3), zmm_T(4));
841 vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
842 vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
843 fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
844 fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
845 fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
846 vaddps(zmm_O(3), zmm_O(3), zmm_T(5));
848 for (int i = 0; i < tile_size; i++) {
849 vmovups(ptr[oreg_O + (j * tile_size * simd_w
850 + i * simd_w) * typesize], zmm_O(i));
855 auto inner_loops = [=]() {
867 void _jit_avx512_core_conv_winograd_data_kernel_f32
868 ::input_transform_data_ker_generate()
870 bool is_fwd = one_of(jcp.prop_kind,
871 mkldnn_forward_training, mkldnn_forward_inference);
872 int inpw = is_fwd ? jcp.iw : jcp.ow;
873 int inph = is_fwd ? jcp.ih : jcp.oh;
874 int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
875 int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
876 int wp_max = inpw + l_pad;
877 int hp_max = inph + t_pad;
878 bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
881 auto zmm_zero = Xbyak::Zmm(0);
882 auto zmm_temp = Xbyak::Zmm(31);
883 auto zmm_G = [=](int i) {
884 return Xbyak::Zmm(1 + i);
886 auto zmm_I = [=](int i) {
887 return Xbyak::Zmm(1 + G_size + i);
889 auto zmm_T = [=](int i) {
890 return Xbyak::Zmm(1 + G_size + alpha + i);
892 auto zmm_t = [=](int i) {
893 return Xbyak::Zmm(1 + G_size + 2 * alpha + i);
896 auto init_G = [=]() {
897 mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
898 for (int i = 0; i < G_size; i++) {
899 vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
903 auto load_src = [=]() {
904 mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
905 mov(ireg_I, ptr[param1 + GET_OFF(M)]);
907 xor_(ireg_zero, ireg_zero);
908 vpxord(zmm_zero, zmm_zero, zmm_zero);
910 mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
911 shl(ireg_ydim, 2); // tj * tile_size (==4)
912 mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
913         shl(ireg_xdim, 2); // ti * tile_size (==4)
915 for (int j = 0; j < alpha; j++) {
916 mov(ireg_temp, ireg_ydim);
919 mov(ireg_mask_j, 0xffff);
920 cmp(ireg_temp, t_pad);
921 cmovl(ireg_mask_j, ireg_zero);
922 cmp(ireg_temp, hp_max);
923 cmovge(ireg_mask_j, ireg_zero);
925 sub(ireg_temp, t_pad);
926 imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
927 mov(ireg_inp_j, ireg_src);
928 add(ireg_inp_j, ireg_temp);
930 for (int i = 0; i < alpha; i++) {
932 mov(ireg_temp, ireg_xdim);
935 mov(ireg_mask, 0xffff);
936 cmp(ireg_temp, l_pad);
937 cmovl(ireg_mask, ireg_zero);
938 cmp(ireg_temp, wp_max);
939 cmovge(ireg_mask, ireg_zero);
940 and_(ireg_mask, ireg_mask_j);
942 sub(ireg_temp, l_pad);
943 shl(ireg_temp, 4 + 2);
945 vpxord(zmm_temp, zmm_temp, zmm_temp);
946 Opmask kmask = Opmask(7);
947 kmovw(kmask, ireg_mask_32);
948 vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
949 vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w)
950 * typesize], zmm_temp);
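// Zero padding is implemented with an opmask: the 16-bit mask starts as
// 0xffff and is cleared with cmovl / cmovge whenever the y or x coordinate
// falls inside the top/left padding or beyond the input, so the masked
// vmovups leaves zmm_temp zeroed for out-of-bounds positions.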
955 auto store_Iw = [=]() {
957 mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
958 mov(ireg_output, ptr[param1 + GET_OFF(dst)]);
961 = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
966 mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
967 imul(ireg_tile_block, ireg_tile_block,
968 alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
969 * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
973 mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
974 imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
975 jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
976 * jcp.dimK_reg_block * typesize);
978 mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
979 imul(ireg_tile_block_ur, ireg_tile_block_ur,
980 jcp.dimK_reg_block * typesize);
982 add(ireg_output, ireg_nb_tile_block_ur);
983 add(ireg_output, ireg_tile_block_ur);
985 add(ireg_output, ireg_tile_block);
987 for (int j = 0; j < alpha; j++) {
988 for (int i = 0; i < alpha; i++) {
989 vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w
990 + i * simd_w) * typesize]);
993 j * alpha * jcp.dimN_block * jcp.dimK_nb_block
994 * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
997 i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
998 * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
1000 if (not_tiled && streamout)
1001 vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
1004 vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
1010 auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
1011 vmulps(zmm_temp, a, b);
1012 vaddps(dst, zmm_temp, c);
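// Here fma4 computes dst = a * b + c. trans_I_4x4_3x3 below evaluates the
// F(4x4, 3x3) input transform V = B^T * d * B, where B is the 6x6 input
// transform matrix; the values broadcast by init_G are the constants of
// this transform.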
1015 auto trans_I_4x4_3x3 = [=]() {
1016 mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
1017 mov(ireg_T, ptr[param1 + GET_OFF(T)]);
1018 mov(ireg_I, ptr[param1 + GET_OFF(M)]);
1020 mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
1021 for (int i = 0; i < alpha; i++) {
1022 for (int idx = 0; idx < alpha; idx++) {
1023 vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w
1024 + i * simd_w) * typesize]);
1026 i * alpha * jcp.dimN_block * jcp.dimK_nb_block
1027 * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
1029 int idx_base_offset =
1030 idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
1031 * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
1032 prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
1035 fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
1036 fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
1037 fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
1038 fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
1039 fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
1040 fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
1042 fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
1043 fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
1044 fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
1045 fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
1046 fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
1047 fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
1049 for (int idx = 0; idx < alpha; idx++) {
1050 vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w)
1051 * typesize],zmm_T(idx));
1054 for (int i = 0; i < alpha; i++) {
1055 for (int idx = 0; idx < alpha; idx++) {
1056 vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx
1057 * simd_w) * typesize]);
1060 fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
1061 fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
1062 fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
1063 fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
1064 fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
1065 fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
1067 fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
1068 fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
1069 fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
1070 fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
1071 fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
1072 fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
1074 for (int idx = 0; idx < alpha; idx++) {
1075 vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w)
1076 * typesize],zmm_I(idx));
1081 auto inner_loops = [=]() {
1093 status_t _jit_avx512_core_conv_winograd_data_kernel_f32::init_conf_common(
1094 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1095 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
1096 const memory_desc_wrapper &dst_d)
1098 if (!mayiuse(avx512_core)) {
1099 return status::unimplemented;
1101 jcp.ver = ver_avx512_core;
1102 jcp.prop_kind = cd.prop_kind;
1104 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
1106 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1107 jcp.mb = src_d.dims()[0];
1108 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1109 jcp.oc_without_padding = jcp.oc;
1110 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1111 jcp.ih = src_d.dims()[2];
1112 jcp.iw = src_d.dims()[3];
1113 jcp.oh = dst_d.dims()[2];
1114 jcp.ow = dst_d.dims()[3];
1115 jcp.kh = weights_d.dims()[with_groups + 2];
1116 jcp.kw = weights_d.dims()[with_groups + 3];
1117 jcp.t_pad = cd.padding[0][0];
1118 jcp.l_pad = cd.padding[0][1];
1119 jcp.stride_h = cd.strides[0];
1120 jcp.stride_w = cd.strides[1];
1121 jcp.dilate_h = cd.dilates[0];
1122 jcp.dilate_w = cd.dilates[1];
1123 jcp.r_pad = nstl::max(
1124 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1125 jcp.b_pad = nstl::max(
1126 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
1127 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1128 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1132 bool ok_to_pad_channels = jcp.ngroups == 1;
1133 if (ok_to_pad_channels) {
1134 jcp.oc = rnd_up(jcp.oc, simd_w);
1135 jcp.ic = rnd_up(jcp.ic, simd_w);
1138 // Checking conditions not supported by these kernels
1139 if (jcp.ngroups != 1)
1140 return status::unimplemented;
1141 if ((jcp.kh != 3) || (jcp.kw != 3))
1142 return status::unimplemented;
1143 if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
1144 return status::unimplemented;
1145 if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
1146 return status::unimplemented;
1147 if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
1148 return status::unimplemented;
1150 if (src_d.format() != nChw16c)
1151 return status::unimplemented;
1152 if (weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
1153 return status::unimplemented;
1154 if (dst_d.format() != nChw16c)
1155 return status::unimplemented;
1157 bool layout_consistency = true
1158 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1159 && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
1160 && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
1161 && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
1162 if (!layout_consistency) return status::unimplemented;
1164 return status::success;
1167 void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
1169 /* ----------- dimM reg block ---------------------*/
1170 auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
1171 int dimM_reg_block, int current_best) {
1172 int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
1173 return (dimM_reg_block >= 1)
1174 && (dimM_reg_block <= max_dimM_reg_block )
1175 && (dimM_reg_block > current_best);
1177 jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp,
1178 jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
1180 /* ----------- dimN reg block ---------------------*/
1182 auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
1183 int dimN_reg_block, int current_best) {
1184 return jcp.kernel_kind == embd_bcast
1185 ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
1186 : dimN_reg_block >= 1
1187 && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block)
1189 && dimN_reg_block > current_best;
1191 jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp,
1192 jcp.dimN, 1, test_cond_dimN_reg_block);
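// The condition above keeps the dimN_reg_block broadcast registers plus the
// dimN_reg_block * dimM_reg_block accumulators within the available zmm
// budget (jcp.nb_reg) for the explicit-broadcast kernel; for the
// embedded-broadcast kernel only the accumulator count is limited.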
1195 status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
1196 if (jcp.ver != ver_avx512_core)
1197 return status::unimplemented;
1199 jcp.kernel_kind = embd_bcast;
1201 set_kernel_dims_reg_block(jcp);
1203 /*-------------- L2 blocking for dimN block ---------*/
1205 auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
1206 int dimN_block, int current_best) {
1207 return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
1208 && (dimN_block > current_best)
1209 && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
1210 >= 1.5 * omp_get_max_threads());
1213 jcp.dimN_block = get_divisor_satisfying_cond(
1214 jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
1215 jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
1217 if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
1218 && (jcp.dimN_nb_block >= 1.5 * omp_get_max_threads())) {
1220 /* ------------------- L1 blocking for GEMM --------------*/
1221 /* -------------------- Choose dimK block ----------------*/
1223 auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
1224 int dimK_block, int current_best) {
1225 return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
1226 && (dimK_block > current_best);
1229 jcp.dimK_block = get_divisor_satisfying_cond(
1230 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
1232 if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
1233 jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
1235 /* -------------- Choose dimM block -------------------*/
1236 auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
1237 int dimM_block, int current_best) {
1238 return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block,
1239 0.2, 0.5) && (dimM_block > current_best);
1242 jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1243 jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
1244 test_cond_dimM_block);
1245 jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
1246 / jcp.dimM_simd_block;
1248 jcp.sched_policy = WSCHED_DATA_W_SGD;
1249 return status::success;
1253 return status::unimplemented;
1256 void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {
1258 set_kernel_dims_reg_block(jcp);
1260 //********************* Choosing dimK_block **********************//
1261 auto test_cond1_dimK_block = [](
1262 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1263 return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
1264 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
1265 && (dimK_block > current_best);
1268 auto test_cond1_bis_dimK_block = [](
1269 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1270 return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
1271 jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
1272 jcp.dimM_simd_block, .9f)
1273 && (dimK_block > current_best);
1276 jcp.dimK_block = get_divisor_satisfying_cond(
1277 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
1278 // If we are not able to use streams, we fall back to condition [1]
1279 if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1280 jcp.dimK_block = get_divisor_satisfying_cond(
1281 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
1282 jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
1284 //********************* Choosing dimM_block **********************//
1285 auto test_cond1_dimM_block = [](
1286 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1287 return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
1288 jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1289 jcp.dimM_simd_block, .5f)
1290 && (dimM_block > current_best);
1293 auto test_cond1_bis_dimM_block = [](
1294 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1295 return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
1296 jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1297 jcp.dimM_simd_block, .3f)
1298 && (dimM_block > current_best);
1301 if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1302 jcp.dimM_block = get_divisor_satisfying_cond(
1303 jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1304 test_cond1_dimM_block);
1306 jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1307 jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1308 test_cond1_bis_dimM_block);
1309 jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block
1310 * jcp.dimM_reg_block);
1312 //******************* Choosing dimN_block *******************//
1313 auto test_cond2_dimN_block = [](
1314 jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
1315 return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
1316 jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
1317 jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
1318 && (dimN_block > current_best);
1321 jcp.dimN_block = get_divisor_satisfying_cond(
1322 jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
1323 jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
1326 status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
1328 jcp.kernel_kind = expl_bcast;
1329 set_kernel_blocking_DATA_W_S_G_D(jcp);
1330 if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
1331 jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK,
1333 jcp.kernel_kind = embd_bcast;
1334 set_kernel_blocking_DATA_W_S_G_D(jcp);
1336 jcp.sched_policy = WSCHED_DATA_W_S_G_D;
1337 return status::success;
1340 status_t _jit_avx512_core_conv_winograd_data_kernel_f32::init_conf_kernel(
1341 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
1347 jcp.sched_policy = WSCHED_INVALID;
1349 jcp.dimK_reg_block = 16;
1350 jcp.dimM_simd_block = 16;
1352 if (jcp.kernel_kind == embd_bcast) {
1353 jcp.dimM_reg_block = 1;
1356 if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success))
1357 set_wsched_DATA_W_S_G_D_avx512_core(jcp);
1359 assert(jcp.sched_policy != WSCHED_INVALID);
1360 return status::success;
1363 bool jit_avx512_core_conv_winograd_fwd_kernel_f32::post_ops_ok(
1364 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1365 const auto &p = attr.post_ops_;
1367 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
1368 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1372 return true; // no post_ops
1374 return true // relu or sum
1375 && implication(jcp.with_eltwise, is_sum(0))
1376 && implication(!jcp.with_eltwise, is_eltwise(0) || is_sum(0));
1378 return true // sum->relu or relu->sum
1379 && implication(jcp.with_eltwise, is_sum(0) && is_eltwise(1))
1380 && implication(!jcp.with_eltwise, false
1381 || (is_sum(0) && is_eltwise(1))
1382 || (is_eltwise(0) && is_sum(1)));
1384 return true // relu->sum->relu
1385 && jcp.with_eltwise == false
1386 && (is_eltwise(0) && is_sum(1) && is_eltwise(2));
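// Accepted post-op chains: none; a single sum or eltwise; sum followed by
// eltwise (or eltwise followed by sum); and eltwise -> sum -> eltwise when
// the primitive itself has no fused relu.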
1394 status_t jit_avx512_core_conv_winograd_fwd_kernel_f32::init_conf(
1395 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1396 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
1397 const memory_desc_wrapper &dst_d, const primitive_attr_t &attr,
1398 bool with_relu, float relu_negative_slope) {
1399 status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
1401 if (st != status::success)
1404 // Winograd specific initialization
1405 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
1406 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
1407 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
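// tile_size is 4 for F(4x4, 3x3), so itiles / jtiles count the 4x4 output
// tiles along width / height, and ntiles = mb * itiles * jtiles becomes the
// GEMM dimN below, with oc mapped to dimM and ic to dimK.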
1409 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
1410 jcp.with_eltwise = with_relu;
1411 jcp.eltwise_alpha = relu_negative_slope;
1413 if (!post_ops_ok(jcp, attr))
1414 return status::unimplemented;
1416 const auto &p = attr.post_ops_;
1417 if (!jcp.with_eltwise) {
1418 /* PostOps ReLU before SUM is handled the same as ReLU primitive */
1419 jcp.with_eltwise = p.find(primitive_kind::eltwise, 0, 1) != -1;
1420 jcp.eltwise_alpha = 0.f;
1422 jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
1423 jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;
1425 status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
1427 jcp.ic_simd_block = jcp.dimK_reg_block;
1428 jcp.ic_block = jcp.dimK_block;
1429 jcp.nb_ic = jcp.dimK_nb_block;
1430 jcp.oc_simd_block = jcp.dimM_simd_block;
1431 jcp.oc_block = jcp.dimM_block;
1432 jcp.oc_reg_block = jcp.dimM_reg_block;
1433 jcp.ic_reg_block = 1;
1434 jcp.nb_oc = jcp.dimM_nb_block;
1435 jcp.tile_block_ur = jcp.dimN_reg_block;
1436 jcp.nb_tile_block_ur = jcp.dimN_block;
1437 jcp.tile_block = jcp.dimN_nb_block;
1442 status_t jit_avx512_core_conv_winograd_bwd_data_kernel_f32::init_conf(
1443 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1444 const memory_desc_wrapper &diff_src_d,
1445 const memory_desc_wrapper &weights_d,
1446 const memory_desc_wrapper &diff_dst_d)
1448 status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
1450 if (st != status::success)
1453 jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
1454 jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
1455 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1457 status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
1459 jcp.oc_simd_block = jcp.dimK_reg_block;
1460 jcp.oc_block = jcp.dimK_block;
1461 jcp.nb_oc = jcp.dimK_nb_block;
1462 jcp.ic_simd_block = jcp.dimM_simd_block;
1463 jcp.ic_block = jcp.dimM_block;
1464 jcp.ic_reg_block = jcp.dimM_reg_block;
1465 jcp.oc_reg_block = 1;
1466 jcp.nb_ic = jcp.dimM_nb_block;
1467 jcp.tile_block_ur = jcp.dimN_reg_block;
1468 jcp.nb_tile_block_ur = jcp.dimN_block;
1469 jcp.tile_block = jcp.dimN_nb_block;
1474 void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::
1475 src_transform_generate() {
1476 constexpr int G_size = 9;
1477 const size_t ifwp = jcp.iw + jcp.l_pad;
1478 const size_t ifhp = jcp.ih + jcp.t_pad;
1480 auto zmm_G = [=](int i) {
1481 return Xbyak::Zmm(i);
1483 auto zmm_I = [=](int i) {
1484 return Xbyak::Zmm(G_size + i);
1486 auto zmm_T = [=](int i) {
1487 return Xbyak::Zmm(G_size + alpha + i);
1489 auto zmm_t = [=](int i) {
1490 return Xbyak::Zmm(G_size + 2 * alpha + i);
1493 auto init_G = [=]() {
1494 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1495 for (int i = 0; i < G_size; i++) {
1496 vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
1500 auto load_src = [=]() {
1501 mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
1502 xor_(reg_zero, reg_zero);
1504 mov(reg_ydim, reg_tj);
1505 shl(reg_ydim, 2); //tj * tile_size(=4)
1507 for (int j = 0; j < alpha; j++) {
1508 /* check if tile index is within physical spatial boundaries*/
1509 mov(reg_maskj, 0xffff);
1510 cmp(reg_ydim, jcp.t_pad);
1511 cmovl(reg_maskj, reg_zero);
1512 cmp(reg_ydim, ifhp);
1513 cmovge(reg_maskj, reg_zero);
1515 /*address offset for tile in src*/
1516 mov(reg_src_offset, reg_ydim);
1517 sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
1518 imul(reg_src_offset, reg_src_offset, jcp.iw);
1520 mov(reg_xdim, reg_ti);
1521 shl(reg_xdim, 2); // xdim = ti * tile_size
1523 add(reg_src_offset, reg_xdim);
1524 sub(reg_src_offset, jcp.l_pad);
1525 imul(reg_src_offset, reg_src_offset, simd_w * typesize);
1526 for (int i = 0; i < alpha; i++) {
1527 /* check if tile index is within physical spatial boundaries*/
1528 mov(reg_maski, 0xffff);
1529 cmp(reg_xdim, jcp.l_pad);
1530 cmovl(reg_maski, reg_zero);
1531 cmp(reg_xdim, ifwp);
1532 cmovge(reg_maski, reg_zero);
1533 and_(reg_maski, reg_maskj);
1535 Opmask kmask_src = Xbyak::Opmask(7);
1536 auto zmm_src = Xbyak::Zmm(31);
1537 kmovw(kmask_src, reg_maski_32);
1538 vpxord(zmm_src, zmm_src, zmm_src);
1539 vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
1540 vmovups(ptr[reg_I], zmm_src);
1542 add(reg_xdim, 1); //xdim = ti * tile_size + i
1543 add(reg_src_offset, simd_w * typesize);
1544 add(reg_I, simd_w * typesize);
1550 auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
1552 vfmadd231ps(dst, a, b);
1555 auto trans_I_3x3_4x4 = [=]() {
1557 mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
1558 mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
1559 for (int i = 0; i < alpha; i++) {
1560 for (int j = 0; j < alpha; j++) {
1561 size_t I_off = (j * alpha + i) * simd_w * typesize;
1562 vmovups(zmm_I(j), ptr[reg_I + I_off]);
1565 fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
1566 fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
1567 fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
1568 fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
1569 fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
1570 fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
1572 fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
1573 fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
1574 fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
1575 fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
1576 fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
1577 fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
1579 for (int j = 0; j < alpha; j++) {
1580 vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
1586 for (int j = 0; j < alpha; j++) {
1587 for (int i = 0; i < alpha; i++) {
1588 vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
1591 fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
1592 fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
1593 fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
1594 fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
1595 fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
1596 fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
1598 fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
1599 fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
1600 fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
1601 fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
1602 fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
1603 fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
1605 for (int i = 0; i < alpha; i++) {
1606 size_t dst_off = (j * alpha * jcp.ic_block
1607 * jcp.nb_tile_block_ur * jcp.tile_block_ur
1608 + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur)
1609 * simd_w * typesize;
1610 vmovups(ptr[reg_dst + dst_off], zmm_I(i));
1615 auto compute_transform_SDGtWo = [=]() {
1616 mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
1617 mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
1618 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1619 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1620 xor_(reg_tile_count, reg_tile_count);
1621 Label loop_mb, loop_jtiles, loop_itiles, done;
1632 add(reg_tile_count, 1);
1633 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1636 add(reg_dst, simd_w * typesize);
1638 cmp(reg_ti, jcp.itiles);
1641 xor_(reg_ti, reg_ti);
1643 cmp(reg_tj, jcp.jtiles);
1646 xor_(reg_tj, reg_tj);
1647 add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
1653 auto compute_transform = [=]() {
1654 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1655 xor_(reg_ti, reg_ti);
1656 xor_(reg_tj, reg_tj);
1658 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1659 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1660 imul(reg_temp, reg_tile_count, simd_w * typesize);
1661 add(reg_dst, reg_temp);
1663 Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
1673 add(reg_tile_count, 1);
1674 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1675 jge(next_tile_block);
1676 add(reg_dst, simd_w * typesize);
1680 sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
1681 * simd_w * typesize);
1682 size_t tblk_off = alpha * alpha * jcp.ic_block
1683 * jcp.nb_tile_block_ur * jcp.tile_block_ur
1684 * simd_w * typesize;
1685 add(reg_dst, tblk_off);
1686 xor_(reg_tile_count, reg_tile_count);
1690 cmp(reg_ti, jcp.itiles);
1693 xor_(reg_ti, reg_ti);
1695 cmp(reg_tj, jcp.jtiles);
1702 if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
1703 compute_transform_SDGtWo();
1705 compute_transform();
1709 void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::
1710 diff_dst_transform_generate(bool with_bias) {
1712 constexpr int G_size = 8;
1713 auto zmm_G = [](int i) {
1714 return Xbyak::Zmm(31);
1717 auto zmm_src = [=](int j, int i) {
1718 return Xbyak::Zmm(G_size + j * 4 + i);
1721 auto zmm_bias = Xbyak::Zmm(31);
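// zmm_G above always maps to zmm31: each transform coefficient is
// re-broadcast into it immediately before use in trans_W_3x3_4x4, so a
// single register is enough. zmm_bias also reuses zmm31; the bias value is
// written back to memory at the end of load_src, before the transform runs.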
1723 auto load_src = [=]() {
1724 if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
1725 mov(reg_ydim, reg_tj);
1726 shl(reg_ydim, 2); //tj * tile_size(=4)
1727 for (int j = 0; j < tile_size; j++) {
1728 /* check if tile index is within physical spatial boundaries*/
1729 mov(reg_maskj, 0xffff);
1730 cmp(reg_ydim, jcp.oh);
1731 cmovge(reg_maskj, reg_zero);
1733 /*address offset for tile in src*/
1734 mov(reg_src_offset, reg_ydim);
1735 imul(reg_src_offset, reg_src_offset, jcp.ow);
1737 mov(reg_xdim, reg_ti);
1738 shl(reg_xdim, 2); // xdim = ti * tile_size
1740 add(reg_src_offset, reg_xdim);
1741 imul(reg_src_offset, reg_src_offset, simd_w * typesize);
1742 for (int i = 0; i < tile_size; i++) {
1743 /* check if tile index is within physical spatial boundaries*/
1744 mov(reg_maski, 0xffff);
1745 cmp(reg_xdim, jcp.ow);
1746 cmovge(reg_maski, reg_zero);
1747 and_(reg_maski, reg_maskj);
1749 Opmask kmask_src = Xbyak::Opmask(7);
1750 kmovw(kmask_src, reg_maski_32);
1751 vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i));
1752 vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]);
1753 if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias,
1754 ptr[reg_src + reg_src_offset]);
1756 add(reg_xdim, 1); //xdim = ti * tile_size + i
1757 add(reg_src_offset, simd_w * typesize);
1761 if(with_bias) vmovups(ptr[reg_bias], zmm_bias);
1764 auto zmm_t = [=](int i) {
1765 return Xbyak::Zmm(G_size + 16 + i);
1768 auto zmm_T = [=](int j, int i) {
1769 return Xbyak::Zmm(j * 4 + i);
1772 auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
1773 if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
1774 vmovups(ptr[reg_dst + dst_off], a);
1776 vmovntps(ptr[reg_dst + dst_off], a);
1779 auto trans_W_3x3_4x4 = [=]() {
1780 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1781 for (int i = 0; i < tile_size; i++) {
1782 vbroadcastss(zmm_G(0), ptr[reg_G]);
1783 vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));
1785 vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
1786 vmovups(zmm_t(1), zmm_t(0));
1787 vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));
1789 vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
1790 vmovups(zmm_t(2), zmm_t(0));
1791 vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));
1793 vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
1794 vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));
1796 vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
1797 vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));
1799 vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
1800 vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));
1802 vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
1803 vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));
1805 vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
1806 vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
1807 vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
1808 vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
1809 vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
1810 vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
1811 vmovups(zmm_T(5, i), zmm_src(3, i));
1814 for (int j = 0; j < alpha; j++) {
1815 vbroadcastss(zmm_G(0), ptr[reg_G]);
1816 vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));
1818 vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
1819 vmovups(zmm_t(1), zmm_t(0));
1820 vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));
1822 vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
1823 vmovups(zmm_t(2), zmm_t(0));
1824 vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));
1826 vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
1827 vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));
1829 vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
1830 vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));
1832 vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
1833 vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));
1835 vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
1836 vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));
1838 vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
1839 vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
1840 vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
1841 vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
1842 vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
1843 vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
1844 vmovups(zmm_t(3), zmm_T(j, 3));
1846 size_t alpha_offset = jcp.oc/jcp.nb_oc * jcp.ntiles/jcp.tile_block
1848 size_t dst_off = (j * alpha * alpha_offset);
1849 movps(reg_dst, dst_off, zmm_t(0));
1850 dst_off += alpha_offset;
1851 movps(reg_dst, dst_off, zmm_t(5));
1852 dst_off += alpha_offset;
1853 movps(reg_dst, dst_off, zmm_t(1));
1854 dst_off += alpha_offset;
1855 movps(reg_dst, dst_off, zmm_t(6));
1856 dst_off += alpha_offset;
1857 movps(reg_dst, dst_off, zmm_t(2));
1858 dst_off += alpha_offset;
1859 movps(reg_dst, dst_off, zmm_t(3));
1863 auto compute_transform_SDGtWo = [=]() {
1864 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1865 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1866 if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
1868 xor_(reg_zero, reg_zero);
1869 xor_(reg_oc_ur, reg_oc_ur);
1870 Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;
1874 mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
1875 mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
1876 xor_(reg_tile_count, reg_tile_count);
1887 add(reg_tile_count, 1);
1888 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1891 add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
1893 cmp(reg_ti, jcp.itiles);
1896 xor_(reg_ti, reg_ti);
1898 cmp(reg_tj, jcp.jtiles);
1901 xor_(reg_tj, reg_tj);
1902 add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
1907 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1908 add(reg_dst, simd_w * typesize);
1909 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1910 add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
1912 if (with_bias) add(reg_bias, simd_w * typesize);
1914 cmp(reg_oc_ur, jcp.oc_reg_block);
1919 auto compute_transform = [=]() {
1920 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1921 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1922 if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
1924 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1925 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1926 imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
1927 add(reg_dst, reg_temp);
1929 xor_(reg_zero, reg_zero);
1930 xor_(reg_oc_ur, reg_oc_ur);
1931 Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile;
1935 xor_(reg_ti, reg_ti);
1936 xor_(reg_tj, reg_tj);
1946 add(reg_tile_count, 1);
1947 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1948 jge(next_tile_block);
1949 add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
1953 sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
1954 * jcp.oc_reg_block * simd_w * typesize);
1955 size_t tblk_off = alpha * alpha * jcp.oc/jcp.nb_oc
1956 * jcp.ntiles/jcp.tile_block * typesize;
1957 add(reg_dst, tblk_off);
1958 xor_(reg_tile_count, reg_tile_count);
1962 cmp(reg_ti, jcp.itiles);
1965 xor_(reg_ti, reg_ti);
1967 cmp(reg_tj, jcp.jtiles);
1971 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1972 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1973 imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
1974 add(reg_dst, reg_temp);
1975 add(reg_dst, simd_w * typesize);
1976 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1977 add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
1979 if (with_bias) add(reg_bias, simd_w * typesize);
1981 cmp(reg_oc_ur, jcp.oc_reg_block);
1987 if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
1988 compute_transform_SDGtWo();
1990 compute_transform();
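/* Back-transform of the alpha x alpha Winograd-domain accumulators into
 * jcp.kh x jcp.kw spatial weight gradients; the first_tile flag controls
 * whether store_dst overwrites or accumulates into the existing
 * diff_weights values. */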
1995 void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::
1996 diff_weights_transform_generate(bool first_tile) {
1999 auto zmm_G = [](int i) {
2000 return Xbyak::Zmm(i);
2003 auto init_G = [=]() {
2004 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
2005 for (int i = 0; i < G_size; i++)
2006 vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
2009 auto zmm_src = [=](int i) {
2010 return Xbyak::Zmm(G_size + i);
2013 auto load_src = [=](int i) {
2014 for (int j = 0; j < alpha; j++) {
2015 size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block
2016 * jcp.ic_block * simd_w * simd_w * typesize;
2017 size_t src_off = (j * alpha + i) * alpha_offset;
2018 vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
2022 auto zmm_t = [=](int i) {
2023 return Xbyak::Zmm(G_size + 6 + i);
2026 auto zmm_T = [=](int j, int i) {
2027 return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
2030 auto zmm_dst = [=](int i) {
2031 return Xbyak::Zmm(G_size + i);
2034 auto zmm_temp = Xbyak::Zmm(31);
2036 auto store_dst = [=](int j) {
2037 for (int i = 0; i < jcp.kw; i++) {
2038 size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;
2041 vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
2042 vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
2044 vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
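/* The transform below runs in two passes: the first loop reduces the alpha
 * source rows to jcp.kh rows of T (one column i at a time); the second loop
 * reduces the alpha columns of each row of T to jcp.kw outputs, both using
 * the coefficients preloaded into zmm_G by init_G. */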
2048 auto compute_transform = [=] () {
2049 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
2050 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
2052 xor_(reg_ic_simd, reg_ic_simd);
2056 for (int i = 0; i < alpha; i++) {
2059 vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
2060 vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
2061 vmovups(zmm_t(2), zmm_src(5));
2062 vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
2064 vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
2065 vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
2066 vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
2067 vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
2068 vsubps(zmm_temp, zmm_src(3), zmm_src(4));
2069 vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
2070 vmovups(zmm_T(2, i), zmm_t(2));
2071 vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
2074 for (int j = 0; j < jcp.kh; j++) {
2075 vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
2076 vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
2077 vmovups(zmm_t(2), zmm_T(j, 5));
2078 vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
2080 vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
2081 vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
2082 vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
2083 vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
2084 vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
2085 vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
2086 vmovups(zmm_dst(2), zmm_t(2));
2087 vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));
2092 add(reg_src, jcp.oc_reg_block * simd_w * typesize);
2093 add(reg_dst, simd_w * typesize);
2094 add(reg_ic_simd, 1);
2095 cmp(reg_ic_simd, simd_w);
2100 push(reg_EVEX_max_8b_offt);
2101 mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
2103 compute_transform();
2104 pop(reg_EVEX_max_8b_offt);
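/* GEMM microkernel for the weight-gradient update. Schematically (a sketch
 * of the loop nest generated below, reduction over the tile dimension K):
 *
 *   for M_blk in 0..dimM_block:                  // OC blocks
 *     for N_blk in 0..dimN_block:                // IC blocks
 *       for N_ur in 0..dimN_reg_block/dimN_bcast_ur:
 *         for K_blk in 0..dimK_block:
 *           for K_ur in 0..dimK_reg_block:       // fully unrolled
 *             C[M][N] += A[K][M] * broadcast(B[K][N])
 *
 * A is loaded as full zmm vectors, B elements are broadcast, and the C
 * accumulators stay in registers until store_dstC writes them back. */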
2108 void jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
2111 auto zmm_srcA = [=]() {
2112 return Xbyak::Zmm(0);
2115 auto zmm_srcB = [=] (size_t N_ur){
2116 return Xbyak::Zmm(N_ur + 1);
2119 auto broadcastB = [=](size_t K_ur) {
2120 for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
2121 size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast)
2123 vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
2127 auto load_srcA = [=] (size_t K_ur, int M_ur) {
2128 size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
2129 + M_ur * jcp.dimM_simd_block) * sizeof(float);
2130 vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
2133 auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){
2134 size_t idx = 1 // zmm_srcA
2135 + jcp.dimN_bcast_ur // zmm_srcB
2136 + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
2138 return Xbyak::Zmm(idx);
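/* Register budget: zmm0 holds the current A vector (and doubles as scratch
 * when accumulating into existing C in store_dstC), zmm1..zmm(dimN_bcast_ur)
 * hold the broadcast B values, and the remaining registers hold the
 * dimM_reg_block x dimN_bcast_ur C accumulators. */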
2140 auto prepare_accumm = [=](){
2141 for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
2142 for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
2143 Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
2144 vpxord(zmm, zmm, zmm);
2149 auto store_dstC = [=](){
2150 /******** Write C back to memory *******/
2151 for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
2152 for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
2153 Zmm zmm = zmm_dstC(M_reg, N_ur);
2154 size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
2155 + M_reg * jcp.dimM_simd_block) * sizeof(float);
2156 if (!is_first_tile) {
2157 vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
2158 vaddps(zmm, zmm, Xbyak::Zmm(0));
2160 vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
2165 auto inner_loops = [=]() {
2166 Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;
2168 mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
2170 { /************* OC_block (M) loop ***********/
2171 mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
2173 { /*************** IC_block (N) loop *********/
2175 mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur);
2180 mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
2183 /************* nb_tile_ur(K) loop ********/
2184 for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {
2188 for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
2189 load_srcA(K_ur, M_reg_ur);
2190 for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) {
2191 vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(),
2196 add(reg_srcA, jcp.dimK_reg_block
2197 * jcp.dimM_reg_block * jcp.dimM_simd_block
2199 add(reg_srcB, jcp.dimK_reg_block
2200 * jcp.dimN_reg_block
2202 sub(reg_dimK_block_loop_cnt, 1);
2203 jnz(dimK_block_loop);
2208 sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
2209 * jcp.dimM_reg_block * jcp.dimM_simd_block
2211 sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block
2212 * jcp.dimN_reg_block
2214 add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
2215 add(reg_dstC, jcp.dimN_bcast_ur
2216 * jcp.dimM_reg_block * jcp.dimM_simd_block
2218 sub(reg_nb_dimN_bcast_ur, 1);
2222 sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
2223 add(reg_srcB, jcp.dimK_block
2224 * jcp.dimK_reg_block
2225 * jcp.dimN_reg_block * sizeof(float));
2226 sub(reg_dimN_block_loop_cnt, 1);
2227 jnz(dimN_block_loop);
2230 sub(reg_srcB, jcp.dimN_block
2231 * jcp.dimK_block * jcp.dimK_reg_block
2232 * jcp.dimN_reg_block
2234 add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
2235 * jcp.dimM_reg_block * jcp.dimM_simd_block
2237 sub(reg_dimM_block_loop_cnt, 1);
2238 jnz(dimM_block_loop);
2243 // base register used by EVEX_compress_addr to encode large offsets
2244 push(reg_EVEX_max_8b_offt);
2245 push(reg_dimK_block_loop_cnt);
2246 mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
2251 pop(reg_dimK_block_loop_cnt);
2252 pop(reg_EVEX_max_8b_offt);
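/* Translate the generic gemm blocking (M, N, K) chosen by the schedulers
 * back into the convolution-specific fields used by the driver:
 * M maps to oc blocking, N to ic blocking, and K to tile blocking. */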
2259 void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
2261 jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
2262 / jcp.dimM_simd_block;
2263 jcp.oc_reg_block = jcp.dimM_reg_block;
2264 jcp.oc_block = jcp.dimM_block;
2265 jcp.nb_oc = jcp.dimM_nb_block;
2267 jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
2268 jcp.ic_block = jcp.dimN_block;
2269 jcp.nb_ic = jcp.dimN_nb_block;
2272 jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
2273 jcp.tile_block_ur = jcp.dimK_reg_block;
2274 jcp.nb_tile_block_ur = jcp.dimK_block;
2275 jcp.tile_block = jcp.dimK_nb_block;
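/* Blocking search for the fused WSCHED_WEI_SDGtWo schedule. It is only
 * attempted when the Winograd-domain M and V buffers exceed the per-thread
 * L2 budget (test_MV_large_enough); it then searches for K/N/M block sizes
 * whose per-iteration working set fits the L1/L2 bounds checked below. */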
2278 status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
2280 size_t K_blk_ur, N_blk, M_blk;
2281 /* Is this strategy feasible? */
2282 auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
2283 size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
2284 size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
2285 size_t nthreads = omp_get_max_threads();
2286 return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
2287 && ((size_t)jcp.dimK >= nthreads);
2290 auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
2292 size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float);
2293 size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
2294 size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
2295 size_t nthreads = omp_get_max_threads();
2296 bool load_balance = true;
2297 if (!(jcp.dimK % nthreads)) {
2298 load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
2300 return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
2301 && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
2303 && (M_L2_block < L2_cache_size);
2306 auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
2307 int useless_arg=0) {
2308 return (dimK_ur >= 2) && (dimK_ur <= 8);
2311 auto blocking_ok = [&](){
2312 size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
2313 * K_blk_ur * sizeof(float);
2314 size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
2315 * K_blk_ur * sizeof(float);
2316 size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
2317 * N_blk * jcp.dimN_reg_block * sizeof(float);
2318 size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
2319 /* TODO: base the upper bound on the combined L2+L3 capacity instead of a fixed multiple of L2 */
2320 return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size);
2323 if (test_MV_large_enough(jcp)) {
2324 if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
2325 jcp.dimM_reg_block = 2;
2327 jcp.dimM_reg_block = 1;
2329 jcp.dimM_simd_block = jcp.oc_simd_block;
2330 jcp.dimN_reg_block = jcp.ic_simd_block;
2331 jcp.dimN_bcast_ur = 8;
2332 /* Pick the dimK blocking: dimK_block and dimK_reg_block (dimK_ur) */
2333 size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1);
2335 jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block;
2336 jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block;
2337 for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
2338 if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
2339 for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
2340 if (!(jcp.dimN_block % N_blk)) {
2341 for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
2342 if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
2343 jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
2344 if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented;
2345 jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
2346 jcp.dimN_block = N_blk;
2347 jcp.dimM_block = M_blk;
2348 jcp.sched_policy = WSCHED_WEI_SDGtWo;
2349 set_jcp_WEI_params(jcp);
2350 return status::success;
2358 return status::unimplemented;
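/* Fallback blocking for the WSCHED_WEI_S_D_Giot_W schedule (used when no
 * fused SDGtWo blocking can be found): the search below keeps the
 * per-iteration gemm working set within the C1/C2 cache bounds and checks
 * that the blocks distribute evenly across threads. */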
2361 status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
2362 if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
2363 jcp.dimM_reg_block = 2;
2365 jcp.dimM_reg_block = 1;
2367 jcp.dimN_bcast_ur = 8;
2368 jcp.dimN_reg_block = jcp.ic_simd_block;
2369 jcp.dimM_simd_block = jcp.oc_simd_block;
2370 jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
2371 jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
2372 float C1 = 0.0, C2 = 0.0;
2373 float C1_max = 0.5, C2_max = 1.4;
2374 int N_blk, M_blk, K_blk_ur;
2376 auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
2377 int useless_arg=0) {
2378 return (dimK_ur >= 2) && (dimK_ur <= 8);
2381 auto blocking_ok = [&]() -> bool {
2382 size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
2383 size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
2384 bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
2385 && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);
2387 size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block;
2388 size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block;
2389 size_t nb_K_blk = jcp.dimK / K_blk_ur;
2390 size_t nthreads = omp_get_max_threads();
2391 bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
2392 if (!(nb_K_blk % nthreads)) {
2393 load_balance = load_balance && (nb_K_blk % nthreads == 0);
2396 size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float);
2398 size_t L2_block = V_L2_block;
2399 /* TODO: base the bound on the combined L2+L3 capacity instead of a fixed multiple of L2 */
2400 bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size);
2401 return L1_cond && load_balance && L2_cond;
2404 for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
2405 if (jcp.dimK % K_blk_ur == 0) {
2406 for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
2407 if (jcp.dimN_block % N_blk == 0) {
2408 for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
2409 if (jcp.dimM_block % M_blk == 0) {
2410 if (blocking_ok()) {
2411 jcp.dimN_block = N_blk;
2412 jcp.dimM_block = M_blk;
2413 jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
2414 jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
2415 jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
2416 set_jcp_WEI_params(jcp);
2417 return status::success;
2425 return status::success;
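/* Configuration entry point for the backward-weights Winograd kernel:
 * extract the problem descriptor, reject unsupported shapes (only 3x3,
 * stride 1, no dilation, channel-padded nChw16c/OIhw16i16o layouts), then
 * pick a weight-gradient schedule. */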
2428 status_t jit_avx512_core_conv_winograd_bwd_weights_kernel_f32::init_conf(
2429 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
2430 const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
2431 const memory_desc_wrapper &diff_weights_d) {
2432 if (!mayiuse(avx512_core))
2433 return status::unimplemented;
2435 jcp.ver = ver_avx512_core;
2437 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
2438 jcp.mb = src_d.dims()[0];
2439 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
2440 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2441 jcp.oc_without_padding = jcp.oc;
2442 jcp.ic = src_d.dims()[1] / jcp.ngroups;
2443 jcp.ih = src_d.dims()[2];
2444 jcp.iw = src_d.dims()[3];
2445 jcp.oh = diff_dst_d.dims()[2];
2446 jcp.ow = diff_dst_d.dims()[3];
2447 jcp.kh = diff_weights_d.dims()[with_groups + 2];
2448 jcp.kw = diff_weights_d.dims()[with_groups + 3];
2449 jcp.t_pad = cd.padding[0][0];
2450 jcp.l_pad = cd.padding[0][1];
2451 jcp.stride_h = cd.strides[0];
2452 jcp.stride_w = cd.strides[1];
2453 jcp.r_pad = nstl::max(
2454 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
2455 jcp.b_pad = nstl::max(
2456 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
2457 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
2458 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
2461 jcp.with_bias = (cd.diff_bias_desc.format != memory_format::undef);
2462 jcp.dilate_h = cd.dilates[0];
2463 jcp.dilate_w = cd.dilates[1];
2465 bool ok_to_pad_channels = jcp.ngroups == 1;
2466 if (ok_to_pad_channels) {
2467 jcp.oc = rnd_up(jcp.oc, simd_w);
2468 jcp.ic = rnd_up(jcp.ic, simd_w);
2471 // Winograd specific initialization
2472 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
2473 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
2474 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
2476 // Winograd kernel works only for 3x3 convolution with stride 1
2477 if (jcp.ngroups != 1)
2478 return status::unimplemented;
2479 if ((jcp.kh != 3) || (jcp.kw != 3))
2480 return status::unimplemented;
2481 if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
2482 return status::unimplemented;
2483 if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
2484 return status::unimplemented;
2485 if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
2486 return status::unimplemented;
2487 if (src_d.format() != nChw16c)
2488 return status::unimplemented;
2489 if (diff_weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
2490 return status::unimplemented;
2491 if (diff_dst_d.format() != nChw16c)
2492 return status::unimplemented;
2494 bool layout_consistency = true
2495 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
2496 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
2497 && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
2498 && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
2499 if (!layout_consistency) return status::unimplemented;
2501 /****************** Kernel blocking parameters ******************/
2502 jcp.ic_simd_block = simd_w;
2503 jcp.oc_simd_block = simd_w;
2505 jcp.dimK = jcp.ntiles;
2508 jcp.dimM_simd_block = jcp.oc_simd_block;
2509 jcp.dimN_reg_block = jcp.ic_simd_block;
2510 jcp.sched_policy = WSCHED_INVALID;
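// Prefer the fused SDGtWo schedule; otherwise fall back to S_D_Giot_W,
// which is expected to succeed (checked by the assert below).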
2511 status_t res = set_wsched_WEI_SDGtWo(jcp);
2512 if (res == status::unimplemented) {
2513 res = set_wsched_WEI_S_D_Giot_W(jcp);
2514 assert(res == status::success);
2522 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s