1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
20 #include "type_helpers.hpp"
22 #include "cpu_memory.hpp"
26 #include "jit_avx512_core_fp32_wino_conv_4x3_kernel.hpp"
28 #define GET_OFF(field) offsetof(jit_wino_transform_call_s, field)
36 using namespace mkldnn::impl::utils;
// File-scope cache-size constants, queried once at load time via Xbyak's
// get_cache_size(). L1/L2 are per-core sizes; LLC_data_size is the last-level
// cache size (per_core=false). Used below to size Winograd blocking factors.
38 unsigned int L1_cache_size = get_cache_size(1, true);
39 unsigned int L2_cache_size = get_cache_size(2, true);
40 unsigned int LLC_data_size = get_cache_size(3, false);
42 // the test function takes jcp, the candidate and the current best.
43 // it returns true if the new candidate is better
// Scans all divisors of `number` (pairs d and number/d, so only up to
// sqrt(number) iterations) and keeps the candidate preferred by `test`.
// NOTE(review): interior lines of this listing are elided (the declaration of
// the local test_num lambda, its closing, and the final return of
// best_divisor are not visible here).
44 int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
45 int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
47 int best_divisor = default_best;
49 = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
50 if (test(jcp, num, best_divisor)) {
55 for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
56 if (number % divisor == 0) {
57 test_num(jcp, divisor);
58 test_num(jcp, number / divisor);
66 bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
67 /* Determines if current winograd implementation is faster than direct.
68 Following conditions are empirical and based on performance data */
// NOTE(review): several branches are elided in this listing (the bodies of
// the forward_inference case and the final returns are not visible).
// The costs below are transform data volumes in MiB:
//  - src_dst_transforms_per_core: per-thread input/output transform volume
//  - wei_transform: weight transform volume (whole problem)
69 unsigned int ncores_per_socket =
70 cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel);
71 unsigned int nthreads = mkldnn_get_max_threads();
73 if (jcp.prop_kind == prop_kind::forward_inference) {
75 } else if (nthreads > ncores_per_socket) {
76 double src_dst_transforms_per_core = alpha * alpha
78 * jcp.mb * ((jcp.oh + tile_size - 1) / tile_size)
79 * ((jcp.ow + tile_size - 1) / tile_size)
80 * sizeof(float) / 1024. / 1024. / nthreads;
81 double wei_transform = alpha * alpha
82 * jcp.ic * jcp.oc * sizeof(float) /1024. / 1024.;
// Empirical thresholds per propagation kind (tuned on measured data):
84 if (jcp.prop_kind == prop_kind::backward_weights) {
85 if (src_dst_transforms_per_core < 0.3
86 || (src_dst_transforms_per_core <= 28 && wei_transform < 4))
91 if (src_dst_transforms_per_core < 2.0 || wei_transform < 0.02)
100 /* assumes 512 bits registers */
101 /* TODO: add support for strides */
102 /* TODO: handle the prefetch distance automatically */
103 typedef enum cache_t_ { L1, L2, L3 } cache_t;
// Helper that interleaves software prefetches with generated FMA code:
// given a block of `nb_instructions_in_block` instructions, it spreads the
// prefetch of `block_size` elements across them so each cache line is fetched
// `prefetch_distance_` blocks ahead. NOTE(review): parts of the constructor
// initializer list and some member declarations are elided in this listing.
105 template <typename data_t>
106 struct prefetcher_t {
107 prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
108 cache_t cache_type, size_t block_size, /* in number of elements*/
109 int nb_instructions_in_block, int fma_ipc)
111 , reg_base_addr_(reg_base_addr)
112 , cache_type_(cache_type)
113 , cache_block_size_(block_size)
// 64 bytes per cache line => block_size / (64 / sizeof(data_t)) lines.
115 nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
117 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
119 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
121 /* assumption: when fetch in Li, data is already in L(i+1) */
// Latency estimates in cycles; note L2 and L3 deliberately share the
// same value here (empirical — TODO confirm intent).
123 switch (cache_type_) {
124 case L1: cache_latency = 14; break;
125 case L2: cache_latency = 250; break;
126 case L3: cache_latency = 250; break;
129 prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
// Emits up to prefetch_blk_ prefetches every prefetch_spread_-th
// generated instruction, until the whole block has been issued.
132 void prefetch(int instruction_number)
134 if (instruction_number % prefetch_spread_ == 0) {
135 for (int i = 0; (i < prefetch_blk_)
136 && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
137 i++, prefetches_issued_++) {
138 prefetch_inst_(cg_->EVEX_compress_addr(
139 reg_base_addr_, (cache_block_size_ * prefetch_distance_)
141 + (prefetches_issued_ * 64)));
// Maps the target cache level to the matching prefetch hint.
147 void prefetch_inst_(const Xbyak::Address &addr)
149 switch (cache_type_) {
150 case L1: cg_->prefetcht0(addr); break;
151 case L2: cg_->prefetcht1(addr); break;
152 case L3: cg_->prefetcht2(addr); break;
154 break; // TODO: raise an exception or put an assert
159 Xbyak::Reg64 reg_base_addr_;
161 int cache_block_size_ = 0;
162 int nb_cache_lines_to_prefetch_ = 0;
163 int prefetches_issued_ = 0;
164 int prefetch_spread_ = 0;
165 int prefetch_blk_ = 0;
166 int prefetch_distance_ = 0;
169 // utilities to support kernel parameter selection
// Returns true when the per-thread working set for a dimN block (src/dst
// transform tiles for oc+ic channels, double-buffered, plus this thread's
// share of the weights) fits inside (C2_min, C2_max) fractions of L2.
170 bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp,
171 int dimN_block, float C2_min, float C2_max) {
172 float block_size = alpha * alpha * (2*(jcp.oc + jcp.ic)
173 * dimN_block * jcp.dimN_reg_block
174 + div_up(jcp.ic * jcp.oc,mkldnn_get_max_threads())) * (float)sizeof(float);
175 float L2_lb = C2_min * L2_cache_size;
176 float L2_ub = C2_max * L2_cache_size;
177 return (block_size > L2_lb && block_size < L2_ub);
// Returns true when one gemm micro-kernel tile (A panel + B panel + C tile)
// fits inside (C1_min, C1_max) fractions of the L1 data cache.
180 bool check_L1_block_gemm(jit_conv_winograd_conf_t &jcp, int dimK_block,
181 int dimM_block, float C1_min, float C1_max) {
182 float gemm_block_size = (dimM_block * jcp.dimM_simd_block * dimK_block
183 * jcp.dimK_reg_block * jcp.dimM_reg_block
184 + dimK_block * jcp.dimK_reg_block * jcp.dimN_reg_block
185 + dimM_block * jcp.dimM_simd_block * jcp.dimN_reg_block)
186 * (float)sizeof(float);
187 float L1_lb = C1_min * L1_cache_size;
188 float L1_ub = C1_max * L1_cache_size;
189 return (gemm_block_size > L1_lb && gemm_block_size < L1_ub);
// L1 footprint test: C tile + A block + B block must be below C * L1 size.
// NOTE(review): the final comparison/return line is elided in this listing.
191 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
192 int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
194 float lhs = (dimM_block * dimN_reg_block * dimM_simd_block * dimM_reg_block
195 + dimM_block * dimK_block * dimK_reg_block
196 * dimM_simd_block * dimM_reg_block
197 + dimK_block * dimN_reg_block * dimK_reg_block)
198 * (float)sizeof(float);
199 float rhs = C * L1_cache_size;
// Variant of check_cond1 that excludes the C tile from the L1 footprint
// (A block + B block only). NOTE(review): return line elided in this listing.
202 bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
203 int dimM_block, int dimM_reg_block, int dimM_simd_block, float C)
205 float lhs = (dimM_block * dimM_reg_block * dimK_block * dimK_reg_block
206 * dimM_simd_block + dimK_block * dimN_reg_block * dimK_reg_block)
207 * (float)sizeof(float);
208 float rhs = C * L1_cache_size;
// L2 footprint test over the outer (nb_*) blocking: full C, A and B panels
// must stay below C * L2 size. NOTE(review): return line elided in this
// listing.
211 bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
212 int dimK_block, int dimK_reg_block, int dimM_block, int dimM_reg_block,
213 int dimM_simd_block, float C)
215 float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block
216 * dimM_simd_block * dimM_reg_block
217 + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
218 * dimM_simd_block * dimM_reg_block
219 + nb_dimN_reg_block * dimK_nb_block * dimK_block
220 * dimN_reg_block * dimK_reg_block)
221 * (float)sizeof(float);
222 float rhs = C * L2_cache_size;
// Returns true when both the A (weights) panel and the B (src) panel are
// LARGER than the given fractions of L2 — i.e. the problem is big enough to
// warrant this kernel's blocking scheme.
226 bool check_kernel_cond(int dimM_block, int dimM_reg_block, int dimM_simd_block,
227 int dimN_block, int dimN_reg_block, int dimK, float C1, float C2)
229 float A_size = dimM_block * dimM_reg_block * dimM_simd_block * dimK
230 * (float)sizeof(float);
231 float B_size = dimN_block * dimN_reg_block * dimK
232 * (float)sizeof(float);
233 return (A_size > C1 * L2_cache_size && B_size > C2 * L2_cache_size);
237 using namespace mkldnn::impl::memory_format;
238 using namespace mkldnn::impl::utils;
239 using namespace Xbyak;
// JIT-emits the Winograd gemm micro-kernel: C += A * broadcast(B) over the
// blocked dims described by the reference loop nest in the comment below.
// NOTE(review): many interior lines of this listing are elided (register
// index computation, loop labels, aligned-store branch bodies, prologue and
// epilogue), so the visible code is a fragment of the full generator.
241 void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::gemm_loop_generate()
243 // for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
244 // for (int dimM_reg_block =0; dimM_reg_block < jcp.dimM_reg_block;
245 // dimM_reg_block++) // unrolled
246 // for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
247 // for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
248 // dimK_reg_block++) // unrolled
249 // for (int tile =0; tile < jcp.dimN_reg_block; tile++)
250 // C[dimM_block][dimM_reg_block][tile] +=
251 // A[dimM_block][dimM_reg_block][dimK_block][dimK_reg_block]
252 // * broadcast(B[dimK_block][tile][dimK_reg_block]);
254 // jcp.kernel_kind defines embedded or explicit broadcast
255 // dimM_reg_block=1 for embedded bcast kernel
// Register allocators: zmm0 holds the A operand; for the explicit-broadcast
// kernel zmm1..dimN_reg_block hold broadcast B values; the C accumulators
// follow (see zmm_dstC).
257 auto zmm_srcA = [=]() {
258 return Xbyak::Zmm(0);
260 auto zmm_srcB = [=](int tile) {
262 assert(idx < 1 + jcp.dimN_reg_block);
263 return Xbyak::Zmm(idx);
265 auto zmm_dstC = [=](int dimM_reg_block, int tile) {
267 if (jcp.kernel_kind == embd_bcast)
270 idx = 1 + jcp.dimN_reg_block
271 + dimM_reg_block * jcp.dimN_reg_block + tile;
273 return Xbyak::Zmm(idx);
// Zero the C accumulator registers.
276 auto prepare_output = [=]() {
277 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
279 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
280 Zmm zmm = zmm_dstC(dimM_reg_block, tile);
281 vpxord(zmm, zmm, zmm);
// Store C: when beta != 0, accumulate into dstC first; streaming stores
// (vmovntps) are used only for large aligned W_S_G_D outputs that exceed
// twice the LLC, to avoid polluting the cache.
285 auto store_output = [=](bool output_is_aligned) {
287 cmp(reg_is_beta_zero, 0);
290 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
292 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
293 Zmm zmm = zmm_dstC(dimM_reg_block,tile);
295 = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
296 vaddps(zmm, zmm, EVEX_compress_addr(reg_dstC, output_offset));
301 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
303 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
304 Zmm zmm = zmm_dstC(dimM_reg_block,tile);
306 = jcp.dimN_reg_block * dimM_reg_block * 64 + tile * 64;
308 // In W_SGD, output will be reused.
309 if (output_is_aligned
310 && jcp.dimK_nb_block == 1
311 && jcp.sched_policy == WSCHED_DATA_W_S_G_D
312 && (jcp.dimN * jcp.dimM * alpha * alpha
313 * sizeof(float) > 2 * LLC_data_size))
314 vmovntps(EVEX_compress_addr(reg_dstC, output_offset), zmm);
315 else vmovups(EVEX_compress_addr(reg_dstC, output_offset), zmm);
// The main blocked loop nest; dimM/dimK loops are emitted as runtime
// counters only when their trip counts exceed 1.
320 auto inner_loops = [=]() {
321 Label dimM_block_loop, dimK_block_loop;
323 if (jcp.dimM_block > 1) {
324 mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
330 if (jcp.dimK_block > 1) {
331 mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
335 for (int dimK_reg_block = 0;
336 dimK_reg_block < jcp.dimK_reg_block;
339 if (jcp.kernel_kind == expl_bcast) {
340 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
341 vbroadcastss(zmm_srcB(tile),
342 ptr[reg_srcB + 64 * tile + dimK_reg_block * 4]);
346 /* Performing the fmas */
348 for (int dimM_reg_block = 0; dimM_reg_block < jcp.dimM_reg_block;
353 + jcp.dimK_reg_block * jcp.dimK_block * 64
355 + dimK_reg_block * 64]
// embd_bcast path folds the broadcast into the fma's memory operand.
358 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
359 if (jcp.kernel_kind == expl_bcast)
360 vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
363 vfmadd231ps(zmm_dstC(dimM_reg_block, tile), zmm_srcA(),
364 EVEX_compress_addr(reg_srcB,
365 64 * tile + dimK_reg_block * 4, true));
369 add(reg_srcA, jcp.dimK_reg_block * 64);
370 add(reg_srcB, jcp.dimN_reg_block * 64);
371 if (jcp.dimK_block > 1) {
372 sub(reg_dimK_block_loop_cnt, 1);
373 jnz(dimK_block_loop);
// Dispatch aligned vs unaligned store epilogues based on dstC alignment.
376 Label unaligned_store, end_store;
377 test(reg_dstC, cpu_isa_traits<avx512_core>::vlen - 1);
378 jnz(unaligned_store, T_NEAR);
380 jmp(end_store, T_NEAR);
381 L(unaligned_store); {
386 if (jcp.dimM_block > 1) {
387 sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
388 add(reg_dstC, jcp.dimM_reg_block * jcp.dimN_reg_block * 64);
389 if (jcp.kernel_kind == expl_bcast) {
391 (jcp.dimM_reg_block-1) * jcp.dimK_reg_block * 64
394 sub(reg_dimM_block_loop_cnt, 1);
395 jnz(dimM_block_loop);
// JIT-emits the weights transform kernel: loads a 3x3xsimd_w x simd_w weight
// block, applies the Winograd F(4x4,3x3) weight transform (G * F * G^T,
// computed row-wise then column-wise via trans_W_4x4_3x3), transposes
// 16x16 tiles, and streams the result to the transformed-weights buffer.
// NOTE(review): many interior lines of this listing are elided (lambda
// closings, loop headers, the epilogue wiring of these lambdas), so the
// visible code is a fragment of the full generator.
410 void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
411 ::weights_transform_data_ker_generate()
413 bool is_fwd = one_of(jcp.prop_kind,
414 mkldnn_forward_training, mkldnn_forward_inference);
418 auto zmm_temp = Xbyak::Zmm(31);
419 auto zmm_zero = Xbyak::Zmm(30);
// Register allocators. zmm_M/zmm_MT are used by the 16x16 transpose;
// zmm_G holds broadcast transform coefficients; zmm_F/zmm_T/zmm_t hold
// the matrices and temporaries of the weight transform.
421 auto zmm_M = [=](int i) {
422 return Xbyak::Zmm(i);
424 auto zmm_MT = [=](int i) {
425 return Xbyak::Zmm(i + simd_w);
428 auto zmm_G = [=](int i) {
429 return Xbyak::Zmm(i);
431 auto zmm_F = [=](int i) {
432 return Xbyak::Zmm(alpha + i);
434 auto zmm_T = [=](int i) {
435 return Xbyak::Zmm(alpha + 3 + i);
437 auto zmm_t = [=](int i) {
438 return Xbyak::Zmm(2 * alpha + 3 + i);
441 auto zmm_load = [=](int i) {
442 return Xbyak::Zmm(i);
// Broadcast the transform coefficient table G into registers and zero
// the helper register.
445 auto init_G = [=]() {
446 mov(wreg_temp, ptr[param1 + GET_OFF(G)]);
447 for (int i = 0; i < alpha; i++) {
448 vbroadcastss(zmm_G(i), ptr[wreg_temp + i * typesize]);
450 vpxord(zmm_zero, zmm_zero, zmm_zero);
// In-register 16x16 float transpose using the classic unpck/shuff ladder:
// interleave pairs, then quads, then 128-bit lanes, then store rows.
453 auto trans16x16 = [=]() {
454 for (int i = 0; i < simd_w; i+=2 ) {
455 vmovups(zmm_M(i), ptr[wreg_M + i * simd_w * 4]);
456 vmovups(zmm_M(i+1), ptr[wreg_M + (i + 1) * simd_w * 4]);
457 vunpcklps(zmm_MT(i), zmm_M(i), zmm_M(i+1));
458 vunpckhps(zmm_MT(i+1), zmm_M(i), zmm_M(i+1));
460 for (int i = 0; i < simd_w; i+=4 ) {
461 vunpcklpd(zmm_M(i), zmm_MT(i), zmm_MT(i+2));
462 vunpckhpd(zmm_M(i+1), zmm_MT(i), zmm_MT(i+2));
463 vunpcklpd(zmm_M(i+2), zmm_MT(i+1), zmm_MT(i+3));
464 vunpckhpd(zmm_M(i+3), zmm_MT(i+1), zmm_MT(i+3));
466 for (int i = 0; i < simd_w; i += 8) {
467 vshuff32x4(zmm_MT(i), zmm_M(i), zmm_M(i + 4), 0x88);
468 vshuff32x4(zmm_MT(i+1), zmm_M(i+1), zmm_M(i + 5), 0x88);
469 vshuff32x4(zmm_MT(i+2), zmm_M(i+2), zmm_M(i + 6), 0x88);
470 vshuff32x4(zmm_MT(i+3), zmm_M(i+3), zmm_M(i + 7), 0x88);
471 vshuff32x4(zmm_MT(i+4), zmm_M(i), zmm_M(i + 4), 0xdd);
472 vshuff32x4(zmm_MT(i+5), zmm_M(i+1), zmm_M(i + 5), 0xdd);
473 vshuff32x4(zmm_MT(i+6), zmm_M(i+2), zmm_M(i + 6), 0xdd);
474 vshuff32x4(zmm_MT(i+7), zmm_M(i+3), zmm_M(i + 7), 0xdd);
// Final lane shuffles; `i` and `mask` are set in elided lines above.
479 vshuff32x4(zmm_M(0), zmm_MT(i), zmm_MT(i + 8), mask);
480 vmovups(ptr[wreg_MT + 0 * 16 * 4], zmm_M(0));
481 vshuff32x4(zmm_M(1), zmm_MT(i + 1), zmm_MT(i + 9), mask);
482 vmovups(ptr[wreg_MT + 1 * 16 * 4], zmm_M(1));
483 vshuff32x4(zmm_M(2), zmm_MT(i + 2), zmm_MT(i + 10), mask);
484 vmovups(ptr[wreg_MT + 2 * 16 * 4], zmm_M(2));
485 vshuff32x4(zmm_M(3), zmm_MT(i + 3), zmm_MT(i + 11), mask);
486 vmovups(ptr[wreg_MT + 3 * 16 * 4], zmm_M(3));
487 vshuff32x4(zmm_M(4), zmm_MT(i + 4), zmm_MT(i + 12), mask);
488 vmovups(ptr[wreg_MT + 4 * 16 * 4], zmm_M(4));
489 vshuff32x4(zmm_M(5), zmm_MT(i + 5), zmm_MT(i + 13), mask);
490 vmovups(ptr[wreg_MT + 5 * 16 * 4], zmm_M(5));
491 vshuff32x4(zmm_M(6), zmm_MT(i + 6), zmm_MT(i + 14), mask);
492 vmovups(ptr[wreg_MT + 6 * 16 * 4], zmm_M(6));
493 vshuff32x4(zmm_M(7), zmm_MT(i + 7), zmm_MT(i + 15), mask);
494 vmovups(ptr[wreg_MT + 7 * 16 * 4], zmm_M(7));
496 vshuff32x4(zmm_M(8), zmm_MT(i), zmm_MT(i + 8), mask);
497 vmovups(ptr[wreg_MT + 8 * 16 * 4], zmm_M(8));
498 vshuff32x4(zmm_M(9), zmm_MT(i + 1), zmm_MT(i + 9), mask);
499 vmovups(ptr[wreg_MT + 9 * 16 * 4], zmm_M(9));
500 vshuff32x4(zmm_M(10), zmm_MT(i + 2), zmm_MT(i + 10), mask);
501 vmovups(ptr[wreg_MT + 10 * 16 * 4], zmm_M(10));
502 vshuff32x4(zmm_M(11), zmm_MT(i + 3), zmm_MT(i + 11), mask);
503 vmovups(ptr[wreg_MT + 11 * 16 * 4], zmm_M(11));
504 vshuff32x4(zmm_M(12), zmm_MT(i + 4), zmm_MT(i + 12), mask);
505 vmovups(ptr[wreg_MT + 12 * 16 * 4], zmm_M(12));
506 vshuff32x4(zmm_M(13), zmm_MT(i + 5), zmm_MT(i + 13), mask);
507 vmovups(ptr[wreg_MT + 13 * 16 * 4], zmm_M(13));
508 vshuff32x4(zmm_M(14), zmm_MT(i + 6), zmm_MT(i + 14), mask);
509 vmovups(ptr[wreg_MT + 14 * 16 * 4], zmm_M(14));
510 vshuff32x4(zmm_M(15), zmm_MT(i + 7), zmm_MT(i + 15), mask);
511 vmovups(ptr[wreg_MT + 15 * 16 * 4], zmm_M(15));
// Copy the 3x3 weight block from src into the scratch M buffer; for the
// backward path the (2-j, 2-i) indexing rotates the kernel 180 degrees.
515 auto load_src = [=]() {
516 mov(wreg_src, ptr[param1 + GET_OFF(src)]);
517 mov(wreg_F, ptr[param1 + GET_OFF(M)]);
518 for (int j = 0; j < kh; j++) {
519 for (int i = 0; i < kw; i++) {
521 for (int v1 = 0; v1 < simd_w; v1++) {
522 int offset_src = (j * kw * simd_w * simd_w
523 + i * simd_w * simd_w + v1 * simd_w) * typesize;
524 int offset_F = (j * kw * simd_w * simd_w
525 + i * simd_w * simd_w + v1 * simd_w) * typesize;
526 vmovups(zmm_temp, ptr[wreg_src + offset_src]);
527 vmovups(ptr[wreg_F + offset_F], zmm_temp);
530 int offset_src = ((2 - j) * kw * simd_w * simd_w
531 + (2 - i) * simd_w * simd_w) * typesize;
532 int offset_F = (j * kw * simd_w * simd_w
533 + i * simd_w * simd_w) * typesize;
534 lea(wreg_M, ptr[wreg_src + offset_src]);
535 lea(wreg_MT, ptr[wreg_F + offset_F]);
// Scatter the transformed weights into the blocked dst layout using
// non-temporal stores (dst is written once, then consumed by gemm).
542 auto store_dst = [=]() {
543 mov(wreg_dst, ptr[param1 + GET_OFF(dst)]);
544 mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
548 mov(wreg_dst_aux, wreg_dst);
549 mov(wreg_Fw_aux, wreg_Fw);
// Stride (in elements) between consecutive alpha rows of dst.
551 int dim5 = jcp.dimK_nb_block * (jcp.dimM_block * jcp.dimM_reg_block)
552 * jcp.dimK_block * simd_w * simd_w;
556 for (int i = 0; i < alpha; i++) {
558 vmovups(zmm_load(0), ptr[wreg_Fw_aux
559 + (i * simd_w * simd_w) * typesize]);
560 mov(wreg_dst_idx, i * dim5 * typesize);
561 vmovntps(ptr[wreg_dst_aux + wreg_dst_idx], zmm_load(0));
563 for (int i = 0; i < alpha; i++) {
564 for (int v1 = 1; v1 < simd_w; v1++) {
565 int offset_Fw = (i * simd_w * simd_w + v1 * simd_w)
567 vmovups(zmm_load(v1), ptr[wreg_Fw_aux + offset_Fw]);
569 mov(wreg_dst_idx, i * dim5 * typesize);
570 for (int v1 = 1; v1 < simd_w; v1++) {
571 int offset_dst = v1 * simd_w * typesize;
572 vmovntps(ptr[wreg_dst_aux + wreg_dst_idx + offset_dst],
576 add(wreg_Fw_aux, alpha * simd_w * simd_w * typesize);
577 add(wreg_dst_aux, alpha * dim5 * typesize);
579 cmp(wreg_cnt_j, alpha);
// The 3x3 -> 6x6 weight transform: fused-multiply helpers implement the
// +/-/negated accumulation patterns of the G matrix rows.
584 auto trans_W_4x4_3x3 = [=]() {
585 auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
587 vfmadd231ps(dst, b, c);
589 auto fms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
590 vmulps(zmm_temp, b, c);
591 vsubps(dst, a, zmm_temp);
593 auto fnms4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
594 vsubps(dst, zmm_zero, a);
595 vfnmadd231ps(dst, b, c);
598 mov(wreg_Fw, ptr[param1 + GET_OFF(Mw)]);
599 mov(wreg_F, ptr[param1 + GET_OFF(M)]);
600 mov(wreg_T, ptr[param1 + GET_OFF(T)]);
605 mov(wreg_F_aux, wreg_F);
606 mov(wreg_Fw_aux, wreg_Fw);
607 mov(wreg_temp, wreg_cnt_j);
608 shl(wreg_temp, 4 + 2);
609 lea(wreg_F_aux, ptr[wreg_F + wreg_temp]);
610 lea(wreg_Fw_aux, ptr[wreg_Fw + wreg_temp]);
// First pass: T = G * F (column-wise over the 3 kernel columns).
612 for (int i = 0; i < 3; i++) {
613 for (int idx = 0; idx < 3; idx ++) {
614 vmovups(zmm_F(idx), ptr[wreg_F_aux + (idx * 3 * simd_w
615 * simd_w + i * simd_w * simd_w) * typesize]);
617 vmulps(zmm_t(0), zmm_G(0), zmm_F(2));
618 fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_F(0));
619 fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_F(0));
621 vmulps(zmm_T(0), zmm_G(3), zmm_F(0));
622 fms4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_F(1));
623 fma4(zmm_T(2), zmm_t(1), zmm_G(4), zmm_F(1));
624 fma4(zmm_T(3), zmm_t(2), zmm_G(5), zmm_F(1));
625 fms4(zmm_T(4), zmm_t(2), zmm_G(5), zmm_F(1));
626 vmovaps(zmm_T(5), zmm_F(2));
628 for (int idx = 0; idx < 6; idx ++) {
629 vmovups(ptr[wreg_T + (idx * 3 * simd_w + i * simd_w)
630 * typesize], zmm_T(idx));
// Second pass: Fw = T * G^T (row-wise over the 6 transformed rows).
633 for (int i = 0; i < 6; i++) {
635 for (int idx = 0; idx < 3; idx ++) {
636 vmovups(zmm_T(idx), ptr[wreg_T
637 + (i * 3 * simd_w + idx * simd_w) * typesize]);
639 vmulps(zmm_t(0), zmm_G(0), zmm_T(2));
640 fnms4(zmm_t(1), zmm_t(0), zmm_G(1), zmm_T(0));
641 fma4(zmm_t(2), zmm_t(0), zmm_G(2), zmm_T(0));
643 vmulps(zmm_F(0), zmm_G(3), zmm_T(0));
644 fms4(zmm_F(1), zmm_t(1), zmm_G(4), zmm_T(1));
645 fma4(zmm_F(2), zmm_t(1), zmm_G(4), zmm_T(1));
646 fma4(zmm_F(3), zmm_t(2), zmm_G(5), zmm_T(1));
647 fms4(zmm_F(4), zmm_t(2), zmm_G(5), zmm_T(1));
648 vmovaps(zmm_F(5), zmm_T(2));
650 for (int l = 0; l < 6; l++) {
651 vmovups(ptr[wreg_Fw_aux + (i * 6 * simd_w * simd_w
652 + l * simd_w * simd_w) * typesize], zmm_F(l));
// Body elided in this listing: wires init_G/load_src/trans/store together.
660 auto inner_loops = [=]() {
// JIT-emits the output (inverse) transform kernel: gathers gemm results from
// the blocked Mw buffer, applies the inverse Winograd transform
// (trans_O_4x4_3x3), then adds bias / applies ReLU / sum post-ops and writes
// the 4x4 output tile with bounds checks against the real output size.
// NOTE(review): many interior lines of this listing are elided (labels,
// lambda closings, the epilogue wiring), so the visible code is a fragment.
672 void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
673 ::output_transform_data_ker_generate()
675 bool is_fwd = one_of(jcp.prop_kind,
676 mkldnn_forward_training, mkldnn_forward_inference);
677 int outw = is_fwd ? jcp.ow : jcp.iw;
678 int outh = is_fwd ? jcp.oh : jcp.ih;
679 bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
680 bool with_bias = jcp.with_bias;
681 bool with_relu = jcp.with_eltwise;
682 bool with_relu_postsum = jcp.with_relu_postsum;
683 bool with_sum = jcp.with_sum;
// Register allocators: coefficients in zmm1..6, then the O/T/t working
// sets of the inverse transform.
685 auto zmm_zero = Xbyak::Zmm(0);
686 auto zmm_temp = Xbyak::Zmm(31);
687 auto zmm_G = [=](int i) {
688 return Xbyak::Zmm(1 + i);
690 auto zmm_O = [=](int i) {
691 return Xbyak::Zmm(1 + alpha + i);
693 auto zmm_T = [=](int i) {
694 return Xbyak::Zmm(1 + 2 * alpha + i);
696 auto zmm_t = [=](int i) {
697 return Xbyak::Zmm(1 + 3 * alpha + i);
700 auto init_G = [=]() {
701 mov(oreg_temp, ptr[param1 + GET_OFF(G)]);
702 for (int i = 0; i < 6; i++) {
703 vbroadcastss(zmm_G(i), ptr[oreg_temp + i * typesize]);
// Gather one alpha x alpha tile of gemm output from the blocked src
// layout into the contiguous Ow scratch buffer.
707 auto load_src = [=]() {
708 mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
709 mov(oreg_src, ptr[param1 + GET_OFF(src)]);
711 mov(oreg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
712 imul(oreg_nb_tile_block_ur, oreg_nb_tile_block_ur,
713 (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
714 * jcp.dimM_simd_block * typesize);
715 add(oreg_src, oreg_nb_tile_block_ur);
717 mov(oreg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
718 imul(oreg_tile_block_ur, oreg_tile_block_ur,
719 jcp.dimM_simd_block * typesize);
720 add(oreg_src, oreg_tile_block_ur);
723 mov(oreg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
724 imul(oreg_tile_block, oreg_tile_block,
725 jcp.dimM_nb_block * alpha * alpha * jcp.dimN_block
726 * (jcp.dimM_block * jcp.dimM_reg_block) * jcp.dimN_reg_block
727 * jcp.dimM_simd_block * typesize);
728 add(oreg_src, oreg_tile_block);
731 int last4dim = jcp.dimN_block * (jcp.dimM_block * jcp.dimM_reg_block)
732 * jcp.dimN_reg_block * jcp.dimM_simd_block * typesize;
733 for (int j = 0; j < alpha; j++) {
734 for (int i = 0; i < alpha; i++) {
735 int j_base_offset = j * alpha * last4dim;
736 int i_base_offset = i * last4dim;
737 vmovups(zmm_temp, ptr[oreg_src + j_base_offset + i_base_offset]);
738 vmovups(ptr[oreg_Ow + (j * alpha * simd_w + i * simd_w)
739 * typesize], zmm_temp);
// Write the 4x4 spatial tile, clipping rows/cols that fall outside the
// output, applying bias, (leaky-)ReLU, sum and post-sum ReLU as configured.
744 auto store_dst = [=]() {
745 vpxord(zmm_zero, zmm_zero, zmm_zero);
746 mov(oreg_dst, ptr[param1 + GET_OFF(dst)]);
747 mov(oreg_O, ptr[param1 + GET_OFF(M)]);
748 mov(oreg_ydim, ptr[param1 + GET_OFF(tj)]);
749 shl(oreg_ydim, 2); // tj * tile_size (==4)
750 mov(oreg_xdim, ptr[param1 + GET_OFF(ti)]);
751 shl(oreg_xdim, 2); // ti * tilesize (==4)
754 mov(oreg_bias, ptr[param1 + GET_OFF(bias)]);
756 auto store_one = [=](int j, int i, bool is_aligned) {
757 auto zmm_O = Xbyak::Zmm(31);
758 auto zmm_relu_ns = Xbyak::Zmm(30);
759 auto xmm_relu_ns = Xbyak::Xmm(30);
760 int offset = (j * tile_size * simd_w + i * simd_w) * typesize;
762 vmovups(zmm_O, ptr[oreg_O + offset]);
765 vaddps(zmm_O, zmm_O, ptr[oreg_bias]);
// alpha == 0: plain ReLU; otherwise leaky ReLU via masked multiply.
768 if (jcp.eltwise.alpha == 0) {
769 vmaxps(zmm_O, zmm_O, zmm_zero);
771 Opmask kmask = Opmask(7);
772 mov(imm_addr64, float2int(jcp.eltwise.alpha));
773 vmovq(xmm_relu_ns, imm_addr64);
774 vbroadcastss(zmm_relu_ns, xmm_relu_ns);
775 vcmpps(kmask, zmm_O, zmm_zero, _cmp_lt_os);
776 vmulps(zmm_O | kmask, zmm_O, zmm_relu_ns);
781 vaddps(zmm_O, zmm_O, ptr[oreg_out_j + oreg_temp]);
782 if (with_relu_postsum) // orig: with_relu_postsum
783 vmaxps(zmm_O, zmm_O, zmm_zero);
786 vmovntps(ptr[oreg_out_j + oreg_temp], zmm_O);
788 vmovups(ptr[oreg_out_j + oreg_temp], zmm_O);
// Per-row column loop: skip columns past the output width.
791 auto i_loop = [=](int j, bool is_aligned) {
792 for (int i = 0; i < tile_size; i++) {
794 mov(oreg_temp, oreg_xdim);
796 cmp(oreg_temp, outw);
798 shl(oreg_temp, 4 + 2); // * 16 * 4
800 store_one(j, i, is_aligned);
807 for (int j = 0; j < tile_size; j++) {
808 Label next, unaligned;
809 mov(oreg_temp, oreg_ydim);
811 cmp(oreg_temp, outh);
814 mov(oreg_out_j, oreg_dst);
815 imul(oreg_temp, oreg_temp, outw * simd_w * typesize);
816 add(oreg_out_j, oreg_temp);
819 jnz(unaligned, T_NEAR);
// Inverse transform: O = A^T * Ow * A, done column-wise then row-wise.
831 auto trans_O_4x4_3x3 = [=]() {
832 auto fma2 = [=](Zmm dst, Zmm v1, Zmm u1, Zmm v2, Zmm u2){
834 vfmadd231ps(dst, v2, u2);
836 mov(oreg_Ow, ptr[param1 + GET_OFF(Mw)]);
837 mov(oreg_T, ptr[param1 + GET_OFF(T)]);
838 mov(oreg_O, ptr[param1 + GET_OFF(M)]);
840 for (int i = 0; i < alpha; i++) {
841 for (int j = 0; j < alpha; j++) {
842 vmovups(zmm_O(j), ptr[oreg_Ow + (j * alpha * simd_w
843 + i * simd_w) * typesize]);
846 vaddps(zmm_t(0), zmm_O(1), zmm_O(2));
847 vaddps(zmm_t(1), zmm_O(3), zmm_O(4));
848 vsubps(zmm_t(2), zmm_O(1), zmm_O(2));
849 vsubps(zmm_t(3), zmm_O(3), zmm_O(4));
851 vaddps(zmm_T(0), zmm_t(0), zmm_t(1));
852 vaddps(zmm_T(0), zmm_T(0), zmm_O(0));
853 fma2(zmm_T(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
854 fma2(zmm_T(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
855 fma2(zmm_T(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
856 vaddps(zmm_T(3), zmm_T(3), zmm_O(5));
858 for (int j = 0; j < tile_size; j++) {
859 vmovups(ptr[oreg_T + (j * alpha * simd_w
860 + i * simd_w) * typesize], zmm_T(j));
863 for (int j = 0; j < tile_size; j++) {
864 for (int i = 0; i < alpha; i++) {
865 vmovups(zmm_T(i), ptr[oreg_T + (j * alpha * simd_w
866 + i * simd_w) * typesize]);
868 vaddps(zmm_t(0), zmm_T(1), zmm_T(2));
869 vaddps(zmm_t(1), zmm_T(3), zmm_T(4));
870 vsubps(zmm_t(2), zmm_T(1), zmm_T(2));
871 vsubps(zmm_t(3), zmm_T(3), zmm_T(4));
873 vaddps(zmm_O(0), zmm_t(0), zmm_t(1));
874 vaddps(zmm_O(0), zmm_O(0), zmm_T(0));
875 fma2(zmm_O(1), zmm_t(2), zmm_G(0), zmm_t(3), zmm_G(1));
876 fma2(zmm_O(2), zmm_t(0), zmm_G(2), zmm_t(1), zmm_G(3));
877 fma2(zmm_O(3), zmm_t(2), zmm_G(4), zmm_t(3), zmm_G(5));
878 vaddps(zmm_O(3), zmm_O(3), zmm_T(5));
880 for (int i = 0; i < tile_size; i++) {
881 vmovups(ptr[oreg_O + (j * tile_size * simd_w
882 + i * simd_w) * typesize], zmm_O(i));
// Body elided in this listing: wires init_G/load_src/trans/store together.
887 auto inner_loops = [=]() {
// JIT-emits the input transform kernel: loads a 6x6 input patch with
// opmask-based zero padding at the borders, applies the Winograd input
// transform B^T * I * B (trans_I_4x4_3x3), and scatters the result into the
// blocked layout consumed by the gemm, optionally with streaming stores.
// NOTE(review): many interior lines of this listing are elided (labels,
// lambda closings, some offset computations, the epilogue), so the visible
// code is a fragment of the full generator.
899 void _jit_avx512_core_fp32_wino_conv_4x3_data_kernel
900 ::input_transform_data_ker_generate()
902 bool is_fwd = one_of(jcp.prop_kind,
903 mkldnn_forward_training, mkldnn_forward_inference);
// For the backward-data path the roles of input/output are swapped, and
// padding is recomputed accordingly.
904 int inpw = is_fwd ? jcp.iw : jcp.ow;
905 int inph = is_fwd ? jcp.ih : jcp.oh;
906 int l_pad = is_fwd ? jcp.l_pad : jcp.iw + jcp.r_pad - jcp.ow;
907 int t_pad = is_fwd ? jcp.t_pad : jcp.ih + jcp.t_pad - jcp.oh;
908 int wp_max = inpw + l_pad;
909 int hp_max = inph + t_pad;
910 bool not_tiled = jcp.sched_policy == WSCHED_DATA_W_S_G_D;
913 auto zmm_zero = Xbyak::Zmm(0);
914 auto zmm_temp = Xbyak::Zmm(31);
915 auto zmm_G = [=](int i) {
916 return Xbyak::Zmm(1 + i);
918 auto zmm_I = [=](int i) {
919 return Xbyak::Zmm(1 + G_size + i);
921 auto zmm_T = [=](int i) {
922 return Xbyak::Zmm(1 + G_size + alpha + i);
924 auto zmm_t = [=](int i) {
925 return Xbyak::Zmm(1 + G_size + 2 * alpha + i);
928 auto init_G = [=]() {
929 mov(ireg_temp, ptr[param1 + GET_OFF(G)]);
930 for (int i = 0; i < G_size; i++) {
931 vbroadcastss(zmm_G(i), ptr[ireg_temp + i * typesize]);
// Load the 6x6 patch; rows/cols in the padding region are zeroed by
// building a 0x0000/0xffff opmask per element position (cmov against
// the pad bounds) and doing a masked load over a zeroed register.
935 auto load_src = [=]() {
936 mov(ireg_src, ptr[param1 + GET_OFF(src)]); // base addr of inp
937 mov(ireg_I, ptr[param1 + GET_OFF(M)]);
939 xor_(ireg_zero, ireg_zero);
940 vpxord(zmm_zero, zmm_zero, zmm_zero);
942 mov(ireg_ydim, ptr[param1 + GET_OFF(tj)]);
943 shl(ireg_ydim, 2); // tj * tile_size (==4)
944 mov(ireg_xdim, ptr[param1 + GET_OFF(ti)]);
945 shl(ireg_xdim, 2); // ti * tilesize (==4)
947 for (int j = 0; j < alpha; j++) {
948 mov(ireg_temp, ireg_ydim);
951 mov(ireg_mask_j, 0xffff);
952 cmp(ireg_temp, t_pad);
953 cmovl(ireg_mask_j, ireg_zero);
954 cmp(ireg_temp, hp_max);
955 cmovge(ireg_mask_j, ireg_zero);
957 sub(ireg_temp, t_pad);
958 imul(ireg_temp, ireg_temp, inpw * simd_w * typesize);
959 mov(ireg_inp_j, ireg_src);
960 add(ireg_inp_j, ireg_temp);
962 for (int i = 0; i < alpha; i++) {
964 mov(ireg_temp, ireg_xdim);
967 mov(ireg_mask, 0xffff);
968 cmp(ireg_temp, l_pad);
969 cmovl(ireg_mask, ireg_zero);
970 cmp(ireg_temp, wp_max);
971 cmovge(ireg_mask, ireg_zero);
972 and_(ireg_mask, ireg_mask_j);
974 sub(ireg_temp, l_pad);
975 shl(ireg_temp, 4 + 2);
977 vpxord(zmm_temp, zmm_temp, zmm_temp);
978 Opmask kmask = Opmask(7);
979 kmovw(kmask, ireg_mask_32);
980 vmovups(zmm_temp | kmask, ptr[ireg_inp_j + ireg_temp]);
981 vmovups(ptr[ireg_I + (j * alpha * simd_w + i * simd_w)
982 * typesize], zmm_temp);
// Scatter the transformed tile into the blocked gemm-input layout;
// streaming stores when not tiled and the volume warrants it.
987 auto store_Iw = [=]() {
989 mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
990 mov(ireg_output, ptr[param1 + GET_OFF(dst)]);
993 = jcp.dimN * jcp.dimK * alpha * alpha * sizeof(float)
998 mov(ireg_tile_block, ptr[param1 + GET_OFF(tile_block)]);
999 imul(ireg_tile_block, ireg_tile_block,
1000 alpha * alpha * jcp.dimN_block * jcp.dimK_nb_block
1001 * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
1005 mov(ireg_nb_tile_block_ur, ptr[param1 + GET_OFF(nb_tile_block_ur)]);
1006 imul(ireg_nb_tile_block_ur, ireg_nb_tile_block_ur,
1007 jcp.dimK_nb_block * jcp.dimK_block * jcp.dimN_reg_block
1008 * jcp.dimK_reg_block * typesize);
1010 mov(ireg_tile_block_ur, ptr[param1 + GET_OFF(tile_block_ur)]);
1011 imul(ireg_tile_block_ur, ireg_tile_block_ur,
1012 jcp.dimK_reg_block * typesize);
1014 add(ireg_output, ireg_nb_tile_block_ur);
1015 add(ireg_output, ireg_tile_block_ur);
1017 add(ireg_output, ireg_tile_block);
1019 for (int j = 0; j < alpha; j++) {
1020 for (int i = 0; i < alpha; i++) {
1021 vmovups(zmm_temp,ptr[ireg_Iw + (j * alpha * simd_w
1022 + i * simd_w) * typesize]);
1025 j * alpha * jcp.dimN_block * jcp.dimK_nb_block
1026 * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
1029 i * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
1030 * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
1032 if (not_tiled && streamout)
1033 vmovntps(ptr[ireg_output + j_base_offset + i_base_offset],
1036 vmovups(ptr[ireg_output + j_base_offset + i_base_offset],
// Non-destructive fma helper: dst = a*b + c via a temp register.
1042 auto fma4 = [=](Zmm dst, Zmm a, Zmm b, Zmm c) {
1043 vmulps(zmm_temp, a, b);
1044 vaddps(dst, zmm_temp, c);
// Input transform Iw = B^T * I * B, column pass then row pass; also
// prefetches the eventual dst locations while computing.
1047 auto trans_I_4x4_3x3 = [=]() {
1048 mov(ireg_Iw, ptr[param1 + GET_OFF(Mw)]);
1049 mov(ireg_T, ptr[param1 + GET_OFF(T)]);
1050 mov(ireg_I, ptr[param1 + GET_OFF(M)]);
1052 mov(ireg_output, ptr[param1 + GET_OFF(dst)]); // for prefetch
1053 for (int i = 0; i < alpha; i++) {
1054 for (int idx = 0; idx < alpha; idx++) {
1055 vmovups(zmm_I(idx), ptr[ireg_I + (idx * alpha * simd_w
1056 + i * simd_w) * typesize]);
1058 i * alpha * jcp.dimN_block * jcp.dimK_nb_block
1059 * jcp.dimK_block * jcp.dimN_reg_block * jcp.dimK_reg_block
1061 int idx_base_offset =
1062 idx * jcp.dimN_block * jcp.dimK_nb_block * jcp.dimK_block
1063 * jcp.dimN_reg_block * jcp.dimK_reg_block * typesize;
1064 prefetcht0(ptr[ireg_output + j_base_offset + idx_base_offset]);
1067 fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
1068 fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
1069 fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
1070 fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
1071 fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
1072 fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
1074 fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
1075 fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
1076 fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
1077 fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
1078 fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
1079 fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
1081 for (int idx = 0; idx < alpha; idx++) {
1082 vmovups(ptr[ireg_T + (idx * alpha * simd_w + i * simd_w)
1083 * typesize],zmm_T(idx));
1086 for (int i = 0; i < alpha; i++) {
1087 for (int idx = 0; idx < alpha; idx++) {
1088 vmovups(zmm_T(idx), ptr[ireg_T + (i * alpha * simd_w + idx
1089 * simd_w) * typesize]);
1092 fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
1093 fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
1094 fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
1095 fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
1096 fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
1097 fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
1099 fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
1100 fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
1101 fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
1102 fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
1103 fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
1104 fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
1106 for (int idx = 0; idx < alpha; idx++) {
1107 vmovups(ptr[ireg_Iw + (i * alpha * simd_w + idx * simd_w)
1108 * typesize],zmm_I(idx));
// Body elided in this listing: wires init_G/load_src/trans/store together.
1113 auto inner_loops = [=]() {
// Common configuration for the fwd / bwd-data Winograd 4x3 data kernels:
// checks the ISA, fills the problem geometry from the operation descriptors,
// and rejects every shape/layout these kernels do not support.
// Returns status::success on a supported problem, status::unimplemented
// otherwise. (NOTE(review): some statements of this function are elided in
// this extracted view — the visible lines are not contiguous.)
1125 status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_common(
1126 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1127 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
1128 const memory_desc_wrapper &dst_d)
// These kernels require at least the avx512_core ISA.
1130 if (!mayiuse(avx512_core)) {
1131 return status::unimplemented;
1134 jcp.nthr = mkldnn_get_max_threads();
1136 jcp.ver = ver_avx512_core;
1137 jcp.prop_kind = cd.prop_kind;
// With groups the weights tensor has one extra leading dimension (g).
1139 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
// Geometry: per-group channel counts, spatial sizes, kernel, padding, strides.
1141 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
1142 jcp.mb = src_d.dims()[0];
1143 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
1144 jcp.oc_without_padding = jcp.oc;
1145 jcp.ic = src_d.dims()[1] / jcp.ngroups;
1146 jcp.ih = src_d.dims()[2];
1147 jcp.iw = src_d.dims()[3];
1148 jcp.oh = dst_d.dims()[2];
1149 jcp.ow = dst_d.dims()[3];
1150 jcp.kh = weights_d.dims()[with_groups + 2];
1151 jcp.kw = weights_d.dims()[with_groups + 3];
1152 jcp.t_pad = cd.padding[0][0];
1153 jcp.l_pad = cd.padding[0][1];
1154 jcp.stride_h = cd.strides[0];
1155 jcp.stride_w = cd.strides[1];
1156 jcp.dilate_h = cd.dilates[0];
1157 jcp.dilate_w = cd.dilates[1];
// Derive right/bottom padding so the output extent is fully covered.
1158 jcp.r_pad = nstl::max(
1159 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1160 jcp.b_pad = nstl::max(
1161 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
1162 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1163 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
// Channels may be rounded up to the SIMD width only for non-grouped conv,
// where the extra padded channels cannot collide with group boundaries.
1167 bool ok_to_pad_channels = jcp.ngroups == 1;
1168 if (ok_to_pad_channels) {
1169 jcp.oc = rnd_up(jcp.oc, simd_w);
1170 jcp.ic = rnd_up(jcp.ic, simd_w);
// For convolution_auto, only take the Winograd path when the heuristic
// says it beats the direct implementation.
1173 // Checking conditions not supported by these kernels
1174 if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
1175 is_winograd_faster_than_direct(jcp)))
1176 return status::unimplemented;
// Supported subset: no groups, 3x3 kernel, no dilation, unit stride,
// SIMD-divisible channels.
1178 if (jcp.ngroups != 1)
1179 return status::unimplemented;
1180 if ((jcp.kh != 3) || (jcp.kw != 3))
1181 return status::unimplemented;
1182 if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
1183 return status::unimplemented;
1184 if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
1185 return status::unimplemented;
1186 if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
1187 return status::unimplemented;
// Layouts: blocked nChw16c activations; blocked or Winograd-opaque weights.
1189 if (src_d.format() != nChw16c)
1190 return status::unimplemented;
1191 if (!one_of(weights_d.format(), any,
1192 with_groups ? gOIhw16i16o : OIhw16i16o, wino_fmt))
1193 return status::unimplemented;
1194 if (dst_d.format() != nChw16c)
1195 return status::unimplemented;
// The (possibly rounded-up) ic/oc must still fit inside the padded
// dimensions recorded in each memory descriptor's blocking layout.
1197 bool layout_consistency = true
1198 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1199 && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
1200 && (weights_d.format() == any || weights_d.format() == wino_fmt
1201 || (jcp.ic <= weights_d.blocking_desc()
1202 .padding_dims[with_groups + 1]
1203 && jcp.oc <= weights_d.blocking_desc()
1204 .padding_dims[with_groups + 0]));
1205 if (!layout_consistency)
1206 return status::unimplemented;
1208 return status::success;
// Chooses the register-blocking factors (dimM_reg_block, dimN_reg_block)
// for the GEMM microkernel: for each dimension, the largest divisor that
// satisfies the corresponding register-budget predicate is selected by
// get_divisor_satisfying_cond.
1211 void set_kernel_dims_reg_block(jit_conv_winograd_conf_t &jcp) {
1213 /* ----------- dimM reg block ---------------------*/
// embd_bcast kernels keep a single M register block; expl_bcast may use
// up to 4.
1214 auto test_cond_dimM_reg_block = [](jit_conv_winograd_conf_t &jcp,
1215 int dimM_reg_block, int current_best) {
1216 int max_dimM_reg_block = jcp.kernel_kind == embd_bcast ? 1 : 4;
1217 return (dimM_reg_block >= 1)
1218 && (dimM_reg_block <= max_dimM_reg_block )
1219 && (dimM_reg_block > current_best);
1221 jcp.dimM_reg_block = get_divisor_satisfying_cond(jcp,
1222 jcp.dimM/jcp.dimM_simd_block, 1, test_cond_dimM_reg_block);
1224 /* ----------- dimN reg block ---------------------*/
// N register blocking is bounded by the available register file
// (jcp.nb_reg for embd_bcast; an elided bound in the expl_bcast branch —
// NOTE(review): the right-hand side of that comparison is not visible here).
1226 auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
1227 int dimN_reg_block, int current_best) {
1228 return jcp.kernel_kind == embd_bcast
1229 ? dimN_reg_block < jcp.nb_reg && dimN_reg_block > current_best
1230 : dimN_reg_block >= 1
1231 && (dimN_reg_block * jcp.dimM_reg_block + dimN_reg_block)
1233 && dimN_reg_block > current_best;
1235 jcp.dimN_reg_block = get_divisor_satisfying_cond(jcp,
1236 jcp.dimN, 1, test_cond_dimN_reg_block);
// Tries to configure the fused W_SGD (transform + GEMM fused over tile
// blocks) schedule for the data kernels. Picks L2-friendly dimN blocking,
// then L1-friendly dimK/dimM GEMM blocking; succeeds only if both cache
// checks and the thread-balance requirement hold, otherwise returns
// unimplemented so the caller can fall back to W_S_G_D.
1239 status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
1240 if (jcp.ver != ver_avx512_core)
1241 return status::unimplemented;
1243 jcp.kernel_kind = embd_bcast;
1245 set_kernel_dims_reg_block(jcp);
1247 /*-------------- L2 blocking for dimN block ---------*/
// A candidate dimN_block must fit the per-thread L2 working set and still
// leave at least 1.5x as many tile blocks as threads for load balance.
1249 auto test_cond_dimN_block = [](jit_conv_winograd_conf_t &jcp,
1250 int dimN_block, int current_best) {
1251 return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
1252 && (dimN_block > current_best)
1253 && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
1254 >= 1.5 * mkldnn_get_max_threads());
1257 jcp.dimN_block = get_divisor_satisfying_cond(
1258 jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
1259 jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
// Re-validate the chosen block with a looser L2 bound before committing.
1261 if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
1262 && (jcp.dimN_nb_block >= 1.5 * mkldnn_get_max_threads())) {
1264 /* ------------------- L1 blocking for GEMM --------------*/
1265 /* -------------------- Choose dimK block ----------------*/
1267 auto test_cond_dimK_block = [](jit_conv_winograd_conf_t &jcp,
1268 int dimK_block, int current_best) {
1269 return check_L1_block_gemm(jcp, dimK_block, 1, 0.1, 0.5)
1270 && (dimK_block > current_best);
1273 jcp.dimK_block = get_divisor_satisfying_cond(
1274 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond_dimK_block);
1276 if (check_L1_block_gemm(jcp, jcp.dimK_block, 1, 0.1, 1.0)) {
1277 jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
1279 /* -------------- Choose dimM block -------------------*/
1280 auto test_cond_dimM_block = [](jit_conv_winograd_conf_t &jcp,
1281 int dimM_block, int current_best) {
1282 return check_L1_block_gemm(jcp, jcp.dimK_block, dimM_block,
1283 0.2, 0.5) && (dimM_block > current_best);
1286 jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1287 jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_reg_block), 1,
1288 test_cond_dimM_block);
1289 jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
1290 / jcp.dimM_simd_block;
1292 jcp.sched_policy = WSCHED_DATA_W_SGD;
1293 return status::success;
// Cache/balance constraints not met: signal the caller to fall back.
1297 return status::unimplemented;
// Computes the full cache blocking (dimK, dimM, dimN) for the unfused
// W_S_G_D schedule. For K and M it first tries the streaming-friendly
// "cond1_bis" predicate and falls back to "cond1" when the whole K range
// cannot be kept as a single block.
1300 void set_kernel_blocking_DATA_W_S_G_D(jit_conv_winograd_conf_t &jcp) {
1302 set_kernel_dims_reg_block(jcp);
1304 //********************* Choosing dimK_block **********************//
1305 auto test_cond1_dimK_block = [](
1306 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1307 return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
1308 1, jcp.dimM_reg_block, jcp.dimM_simd_block, .75f)
1309 && (dimK_block > current_best);
1312 auto test_cond1_bis_dimK_block = [](
1313 jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
1314 return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
1315 jcp.dimK_reg_block, 1, jcp.dimM_reg_block,
1316 jcp.dimM_simd_block, .9f)
1317 && (dimK_block > current_best);
1320 jcp.dimK_block = get_divisor_satisfying_cond(
1321 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
1322 // If we are not able to use streams, we fall back to condition [1]
1323 if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1324 jcp.dimK_block = get_divisor_satisfying_cond(
1325 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
1326 jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
1328 //********************* Choosing dimM_block **********************//
1329 auto test_cond1_dimM_block = [](
1330 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1331 return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
1332 jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1333 jcp.dimM_simd_block, .5f)
1334 && (dimM_block > current_best);
1337 auto test_cond1_bis_dimM_block = [](
1338 jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
1339 return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
1340 jcp.dimK_reg_block, dimM_block, jcp.dimM_reg_block,
1341 jcp.dimM_simd_block, .3f)
1342 && (dimM_block > current_best);
// M blocking mirrors the K decision: cond1 when K had to be split,
// otherwise the streaming variant cond1_bis. (NOTE(review): the `else`
// branch line between these two calls is elided in this view.)
1345 if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
1346 jcp.dimM_block = get_divisor_satisfying_cond(
1347 jcp, jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1348 test_cond1_dimM_block);
1350 jcp.dimM_block = get_divisor_satisfying_cond(jcp,
1351 jcp.dimM / (jcp.dimM_simd_block*jcp.dimM_reg_block), 1,
1352 test_cond1_bis_dimM_block);
1353 jcp.dimM_nb_block = jcp.dimM / (jcp.dimM_simd_block * jcp.dimM_block
1354 * jcp.dimM_reg_block);
1356 //******************* Choosing dimN_block *******************//
1357 auto test_cond2_dimN_block = [](
1358 jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
1359 return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
1360 jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
1361 jcp.dimM_reg_block, jcp.dimM_simd_block, .9f)
1362 && (dimN_block > current_best);
1365 jcp.dimN_block = get_divisor_satisfying_cond(
1366 jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
1367 jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
// Configures the unfused W_S_G_D schedule. Prefers the explicit-broadcast
// kernel; if the resulting blocking fails check_kernel_cond, redoes the
// blocking with the embedded-broadcast kernel. Always succeeds.
1370 status_t set_wsched_DATA_W_S_G_D_avx512_core(jit_conv_winograd_conf_t &jcp) {
1372 jcp.kernel_kind = expl_bcast;
1373 set_kernel_blocking_DATA_W_S_G_D(jcp);
// NOTE(review): trailing arguments/closing of this check_kernel_cond call
// are elided in this view.
1374 if (!(check_kernel_cond(jcp.dimM_block, jcp.dimM_reg_block,
1375 jcp.dimM_simd_block, jcp.dimN_block, jcp.dimN_reg_block, jcp.dimK,
1377 jcp.kernel_kind = embd_bcast;
1378 set_kernel_blocking_DATA_W_S_G_D(jcp);
1380 jcp.sched_policy = WSCHED_DATA_W_S_G_D;
1381 return status::success;
// Maps the problem onto GEMM dimensions (M = output channels side,
// N = tiles, K = reduction side), fixes the SIMD blocking to 16, and
// selects a schedule: fused W_SGD first, unfused W_S_G_D as fallback.
1384 status_t _jit_avx512_core_fp32_wino_conv_4x3_data_kernel::init_conf_kernel(
1385 jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
1391 jcp.sched_policy = WSCHED_INVALID;
// Both the K register block and the M SIMD block are one zmm of floats.
1393 jcp.dimK_reg_block = 16;
1394 jcp.dimM_simd_block = 16;
1396 if (jcp.kernel_kind == embd_bcast) {
1397 jcp.dimM_reg_block = 1;
// Prefer the fused schedule; fall back to the always-successful one.
1400 if (!(set_wsched_DATA_W_SGD_avx512_core(jcp) == status::success))
1401 set_wsched_DATA_W_S_G_D_avx512_core(jcp);
1403 assert(jcp.sched_policy != WSCHED_INVALID);
1404 return status::success;
// Accepts only the post-op chains the fwd kernel can fuse: nothing,
// a single relu or sum, sum+relu in either order, or relu->sum->relu.
// (NOTE(review): the switch header over the post-op count is elided in
// this view; the visible cases are its body.)
1407 bool jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::post_ops_ok(
1408 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
1409 const auto &p = attr.post_ops_;
1411 auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
1412 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
1415 case 0: return true; // no post_ops
1416 case 1: return is_relu(0) || is_sum(0); // relu or sum
1417 case 2: return (is_sum(0) && is_relu(1))
1418 || (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
1419 case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
1420 default: return false;
// Forward-kernel configuration: runs the common checks, derives the tile
// grid, validates/records post-ops, runs the GEMM blocking with
// M = oc, N = ntiles, K = ic, publishes the blocking under conv-specific
// names, and (for inference) rewrites the weights descriptor to the
// opaque Winograd format.
1426 status_t jit_avx512_core_fp32_wino_conv_4x3_fwd_kernel::init_conf(
1427 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1428 const cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
1429 const cpu_memory_t::pd_t &dst_pd, const primitive_attr_t &attr) {
1431 status_t st = init_conf_common(jcp, cd,
1432 *src_pd.desc(), *weights_pd.desc(), *dst_pd.desc());
1434 if (st != status::success)
1437 // Winograd specific initialization
// Output is covered by 4x4 tiles (ceil division on each spatial dim).
1438 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
1439 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
1440 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1442 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
1444 if (!post_ops_ok(jcp, attr))
1445 return status::unimplemented;
// Record which fused post-ops are present (first eltwise, any sum, and
// an eltwise that follows the sum).
1447 const auto &p = attr.post_ops_;
1448 const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
1449 jcp.with_eltwise = eltwise_ind != -1;
1450 if (jcp.with_eltwise)
1451 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
1453 jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
1454 jcp.with_relu_postsum = p.find(primitive_kind::eltwise, 1) != -1;
// Forward GEMM mapping: M = oc, N = tiles, K = ic.
1456 status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
1458 jcp.ic_simd_block = jcp.dimK_reg_block;
1459 jcp.ic_block = jcp.dimK_block;
1460 jcp.nb_ic = jcp.dimK_nb_block;
1461 jcp.oc_simd_block = jcp.dimM_simd_block;
1462 jcp.oc_block = jcp.dimM_block;
1463 jcp.oc_reg_block = jcp.dimM_reg_block;
1464 jcp.ic_reg_block = 1;
1465 jcp.nb_oc = jcp.dimM_nb_block;
1466 jcp.tile_block_ur = jcp.dimN_reg_block;
1467 jcp.nb_tile_block_ur = jcp.dimN_block;
1468 jcp.tile_block = jcp.dimN_nb_block;
1470 /* re-create weights primitive descriptor
1471 and set weights wino_blocking */
1472 if (cd.prop_kind == mkldnn_forward_inference) {
1473 memory_desc_t expect_wei_md = *weights_pd.desc();
// Describe the pre-transformed weights layout so the reorder can
// produce them once at preparation time.
1475 expect_wei_md.format = mkldnn_wino_fmt;
1476 expect_wei_md.data_type = data_type::f32;
1477 mkldnn_wino_desc_t &wd = expect_wei_md.layout_desc.wino_desc;
1478 wd.wino_format = mkldnn_wino_wei_OBaaIBOIio;
1484 wd.ic_block = jcp.dimK_reg_block;
1485 wd.oc_block = jcp.dimM_simd_block;
1486 wd.ic2_block = jcp.dimK_block;
1487 wd.oc2_block = jcp.dimM_block * jcp.dimM_reg_block;
1488 size_t max_size = sizeof(float) * wd.alpha * wd.alpha * jcp.ic * jcp.oc;
1492 cpu_memory_t::pd_t new_weights_pd(
1493 weights_pd.engine(), &expect_wei_md);
// Adopt the Winograd layout when the user left the format as `any`;
// otherwise the user-supplied descriptor must already match it.
1494 if (weights_pd.desc()->format == memory_format::any)
1495 weights_pd = new_weights_pd;
1496 if (!weights_pd.is_equal(&new_weights_pd))
1497 return status::unimplemented;
// Backward-data configuration: common checks, tile grid over the *input*
// (diff_src) spatial dims, then GEMM blocking with the roles of ic and oc
// swapped relative to forward (M = ic, N = ntiles, K = oc).
1503 status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_data_kernel::init_conf(
1504 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1505 const memory_desc_wrapper &diff_src_d,
1506 const memory_desc_wrapper &weights_d,
1507 const memory_desc_wrapper &diff_dst_d)
1509 status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
1511 if (st != status::success)
// Tiles cover the input (the tensor being computed here).
1514 jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
1515 jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
1516 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1518 status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
// Publish the blocking: K-side names map to oc, M-side names to ic.
1520 jcp.oc_simd_block = jcp.dimK_reg_block;
1521 jcp.oc_block = jcp.dimK_block;
1522 jcp.nb_oc = jcp.dimK_nb_block;
1523 jcp.ic_simd_block = jcp.dimM_simd_block;
1524 jcp.ic_block = jcp.dimM_block;
1525 jcp.ic_reg_block = jcp.dimM_reg_block;
1526 jcp.oc_reg_block = 1;
1527 jcp.nb_ic = jcp.dimM_nb_block;
1528 jcp.tile_block_ur = jcp.dimN_reg_block;
1529 jcp.nb_tile_block_ur = jcp.dimN_block;
1530 jcp.tile_block = jcp.dimN_nb_block;
// JIT emitter for the backward-weights source transform: loads one
// alpha x alpha input tile (zero-masking out-of-padding elements), applies
// the input transform in two 1-D passes (per-column, staged in a scratch
// T buffer, then per-row), and stores the result into the transformed-src
// layout. Loop drivers over tiles are emitted for either the SDGtWo or the
// default schedule. (NOTE(review): labels, jumps and closing braces of the
// emitted loops are partially elided in this extracted view.)
1535 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
1536 src_transform_generate() {
1537 constexpr int G_size = 9;
// Input extent including only the leading (top/left) padding; used as the
// upper bound for in-bounds checks below.
1538 const size_t ifwp = jcp.iw + jcp.l_pad;
1539 const size_t ifhp = jcp.ih + jcp.t_pad;
// Register map: zmm0..8 = broadcast G coefficients; then I, T and
// temporary t ranges of `alpha` registers each.
1541 auto zmm_G = [=](int i) {
1542 return Xbyak::Zmm(i);
1544 auto zmm_I = [=](int i) {
1545 return Xbyak::Zmm(G_size + i);
1547 auto zmm_T = [=](int i) {
1548 return Xbyak::Zmm(G_size + alpha + i);
1550 auto zmm_t = [=](int i) {
1551 return Xbyak::Zmm(G_size + 2 * alpha + i);
// Broadcast the 9 transform coefficients from the G table into registers.
1554 auto init_G = [=]() {
1555 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1556 for (int i = 0; i < G_size; i++) {
1557 vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
// Load one alpha x alpha tile into the scratch buffer M, writing zeros for
// elements that fall into the padded border. Bounds masks are built with
// branch-free cmov and applied through opmask k7.
1561 auto load_src = [=]() {
1562 mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
1563 xor_(reg_zero, reg_zero);
1565 mov(reg_ydim, reg_tj);
1566 shl(reg_ydim, 2); //tj * tile_size(=4)
1568 for (int j = 0; j < alpha; j++) {
1569 /* check if tile index is within physical spatial boundaries*/
1570 mov(reg_maskj, 0xffff);
1571 cmp(reg_ydim, jcp.t_pad);
1572 cmovl(reg_maskj, reg_zero);
1573 cmp(reg_ydim, ifhp);
1574 cmovge(reg_maskj, reg_zero);
1576 /*address offset for tile in src*/
1577 mov(reg_src_offset, reg_ydim);
1578 sub(reg_src_offset, jcp.t_pad); // tj*tile_size - t_pad
1579 imul(reg_src_offset, reg_src_offset, jcp.iw);
1581 mov(reg_xdim, reg_ti);
1582 shl(reg_xdim, 2); // xdim = ti * tile_size
1584 add(reg_src_offset, reg_xdim);
1585 sub(reg_src_offset, jcp.l_pad);
1586 imul(reg_src_offset, reg_src_offset, simd_w * typesize);
1587 for (int i = 0; i < alpha; i++) {
1588 /* check if tile index is within physical spatial boundaries*/
1589 mov(reg_maski, 0xffff);
1590 cmp(reg_xdim, jcp.l_pad);
1591 cmovl(reg_maski, reg_zero);
1592 cmp(reg_xdim, ifwp);
1593 cmovge(reg_maski, reg_zero);
1594 and_(reg_maski, reg_maskj);
// Masked load: lanes outside the image stay zero from the vpxord.
1596 Opmask kmask_src = Xbyak::Opmask(7);
1597 auto zmm_src = Xbyak::Zmm(31);
1598 kmovw(kmask_src, reg_maski_32);
1599 vpxord(zmm_src, zmm_src, zmm_src);
1600 vmovups(zmm_src | kmask_src, ptr[reg_src + reg_src_offset]);
1601 vmovups(ptr[reg_I], zmm_src);
1603 add(reg_xdim, 1); //xdim = ti * tile_size + i
1604 add(reg_src_offset, simd_w * typesize);
1605 add(reg_I, simd_w * typesize);
// Fused multiply-add helper: dst accumulates a * b. NOTE(review): the
// elided line before vfmadd231ps presumably initializes dst from c —
// confirm against the full source.
1611 auto fma4 = [=](Xbyak::Zmm dst, Xbyak::Zmm a, Xbyak::Zmm b, Xbyak::Zmm c) {
1613 vfmadd231ps(dst, a, b);
// Apply the 2-D input transform: a column pass into the T buffer,
// then a row pass writing to dst in the transformed-src layout.
1616 auto trans_I_3x3_4x4 = [=]() {
1618 mov(reg_I, ptr[reg_transp + GET_OFF(M)]);
1619 mov(reg_T, ptr[reg_transp + GET_OFF(T)]);
// Column pass: read column i of the tile, write column i of T.
1620 for (int i = 0; i < alpha; i++) {
1621 for (int j = 0; j < alpha; j++) {
1622 size_t I_off = (j * alpha + i) * simd_w * typesize;
1623 vmovups(zmm_I(j), ptr[reg_I + I_off]);
1626 fma4(zmm_t(0), zmm_I(2), zmm_G(0), zmm_I(4));
1627 fma4(zmm_t(1), zmm_I(1), zmm_G(0), zmm_I(3));
1628 fma4(zmm_t(2), zmm_I(2), zmm_G(1), zmm_I(4));
1629 fma4(zmm_t(3), zmm_I(1), zmm_G(1), zmm_I(3));
1630 fma4(zmm_t(4), zmm_I(0), zmm_G(2), zmm_I(4));
1631 fma4(zmm_t(5), zmm_I(1), zmm_G(2), zmm_I(5));
1633 fma4(zmm_T(0), zmm_I(2), zmm_G(3), zmm_t(4));
1634 fma4(zmm_T(1), zmm_t(1), zmm_G(4), zmm_t(0));
1635 fma4(zmm_T(2), zmm_t(1), zmm_G(5), zmm_t(0));
1636 fma4(zmm_T(3), zmm_t(3), zmm_G(6), zmm_t(2));
1637 fma4(zmm_T(4), zmm_t(3), zmm_G(7), zmm_t(2));
1638 fma4(zmm_T(5), zmm_I(3), zmm_G(8), zmm_t(5));
1640 for (int j = 0; j < alpha; j++) {
1641 vmovups(ptr[reg_T + (j * alpha + i) * simd_w * typesize],
// Row pass: read row j of T, produce row j of the transformed tile and
// scatter it across the (j, i) Winograd components of dst.
1647 for (int j = 0; j < alpha; j++) {
1648 for (int i = 0; i < alpha; i++) {
1649 vmovups(zmm_T(i), ptr[reg_T + (j * alpha + i) * simd_w * typesize]);
1652 fma4(zmm_t(0), zmm_T(2), zmm_G(0), zmm_T(4));
1653 fma4(zmm_t(1), zmm_T(1), zmm_G(0), zmm_T(3));
1654 fma4(zmm_t(2), zmm_T(2), zmm_G(1), zmm_T(4));
1655 fma4(zmm_t(3), zmm_T(1), zmm_G(1), zmm_T(3));
1656 fma4(zmm_t(4), zmm_T(0), zmm_G(2), zmm_T(4));
1657 fma4(zmm_t(5), zmm_T(1), zmm_G(2), zmm_T(5));
1659 fma4(zmm_I(0), zmm_T(2), zmm_G(3), zmm_t(4));
1660 fma4(zmm_I(1), zmm_t(1), zmm_G(4), zmm_t(0));
1661 fma4(zmm_I(2), zmm_t(1), zmm_G(5), zmm_t(0));
1662 fma4(zmm_I(3), zmm_t(3), zmm_G(6), zmm_t(2));
1663 fma4(zmm_I(4), zmm_t(3), zmm_G(7), zmm_t(2));
1664 fma4(zmm_I(5), zmm_T(3), zmm_G(8), zmm_t(5));
1666 for (int i = 0; i < alpha; i++) {
// Stride between Winograd components in the transformed-src buffer.
1667 size_t dst_off = (j * alpha * jcp.ic_block
1668 * jcp.nb_tile_block_ur * jcp.tile_block_ur
1669 + i * jcp.ic_block * jcp.nb_tile_block_ur * jcp.tile_block_ur)
1670 * simd_w * typesize;
1671 vmovups(ptr[reg_dst + dst_off], zmm_I(i));
// Tile-loop driver for the SDGtWo schedule: iterates ti/tj/mb with the
// tile counter bounded by nb_tile_block_ur * tile_block_ur.
1676 auto compute_transform_SDGtWo = [=]() {
1677 mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
1678 mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
1679 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1680 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1681 xor_(reg_tile_count, reg_tile_count);
1682 Label loop_mb, loop_jtiles, loop_itiles, done;
1693 add(reg_tile_count, 1);
1694 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1697 add(reg_dst, simd_w * typesize);
1699 cmp(reg_ti, jcp.itiles);
1702 xor_(reg_ti, reg_ti);
1704 cmp(reg_tj, jcp.jtiles);
1707 xor_(reg_tj, reg_tj);
// Advance src to the next image in the minibatch.
1708 add(reg_src, jcp.ic * jcp.iw * jcp.ih * typesize);
// Tile-loop driver for the default schedule: dst is pre-offset by the
// incoming tile_count and advances by whole tile blocks.
1714 auto compute_transform = [=]() {
1715 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1716 xor_(reg_ti, reg_ti);
1717 xor_(reg_tj, reg_tj);
1719 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1720 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1721 imul(reg_temp, reg_tile_count, simd_w * typesize);
1722 add(reg_dst, reg_temp);
1724 Label loop_jtiles, loop_itiles, next_tile_block, next_tile;
1734 add(reg_tile_count, 1);
1735 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1736 jge(next_tile_block);
1737 add(reg_dst, simd_w * typesize);
// Current tile block exhausted: rewind within the block and jump to the
// start of the next block in the transformed layout.
1741 sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
1742 * simd_w * typesize);
1743 size_t tblk_off = alpha * alpha * jcp.ic_block
1744 * jcp.nb_tile_block_ur * jcp.tile_block_ur
1745 * simd_w * typesize;
1746 add(reg_dst, tblk_off);
1747 xor_(reg_tile_count, reg_tile_count);
1751 cmp(reg_ti, jcp.itiles);
1754 xor_(reg_ti, reg_ti);
1756 cmp(reg_tj, jcp.jtiles);
// Entry: select the driver that matches the chosen weights schedule.
1763 if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
1764 compute_transform_SDGtWo();
1766 compute_transform();
// JIT emitter for the backward-weights diff_dst transform: loads a
// tile_size x tile_size output-gradient tile (zero-masking out-of-bounds
// elements and optionally accumulating the bias reduction), applies the
// 4x4 -> 6x6 output-side transform in two 1-D passes, and stores the six
// components per row to the transformed buffer. (NOTE(review): labels,
// jumps and closing braces of the emitted loops are partially elided in
// this extracted view.)
1770 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
1771 diff_dst_transform_generate(bool with_bias) {
1773 constexpr int G_size = 8;
// zmm_G deliberately ignores its index: coefficients are re-broadcast
// into the single scratch register zmm31 right before each use.
// NOTE(review): zmm31 is shared with zmm_bias/zmm_temp — the visible code
// keeps their live ranges disjoint, but confirm against the full source.
1774 auto zmm_G = [](int i) {
1775 return Xbyak::Zmm(31);
1778 auto zmm_src = [=](int j, int i) {
1779 return Xbyak::Zmm(G_size + j * 4 + i);
1782 auto zmm_bias = Xbyak::Zmm(31);
// Load the 4x4 diff_dst tile with bounds masks (no padding here: only the
// right/bottom image edge is checked) and fold each in-bounds vector into
// the running bias sum when with_bias is set.
1784 auto load_src = [=]() {
1785 if (with_bias) vmovups(zmm_bias, ptr[reg_bias]);
1786 mov(reg_ydim, reg_tj);
1787 shl(reg_ydim, 2); //tj * tile_size(=4)
1788 for (int j = 0; j < tile_size; j++) {
1789 /* check if tile index is within physical spatial boundaries*/
1790 mov(reg_maskj, 0xffff);
1791 cmp(reg_ydim, jcp.oh);
1792 cmovge(reg_maskj, reg_zero);
1794 /*address offset for tile in src*/
1795 mov(reg_src_offset, reg_ydim);
1796 imul(reg_src_offset, reg_src_offset, jcp.ow);
1798 mov(reg_xdim, reg_ti);
1799 shl(reg_xdim, 2); // xdim = ti * tile_size
1801 add(reg_src_offset, reg_xdim);
1802 imul(reg_src_offset, reg_src_offset, simd_w * typesize);
1803 for (int i = 0; i < tile_size; i++) {
1804 /* check if tile index is within physical spatial boundaries*/
1805 mov(reg_maski, 0xffff);
1806 cmp(reg_xdim, jcp.ow);
1807 cmovge(reg_maski, reg_zero);
1808 and_(reg_maski, reg_maskj);
1810 Opmask kmask_src = Xbyak::Opmask(7);
1811 kmovw(kmask_src, reg_maski_32);
1812 vpxord(zmm_src(j, i), zmm_src(j, i), zmm_src(j, i))<
1813 vmovups(zmm_src(j, i) | kmask_src, ptr[reg_src + reg_src_offset]);
1814 if (with_bias) vaddps(zmm_bias | kmask_src, zmm_bias,
1815 ptr[reg_src + reg_src_offset]);
1817 add(reg_xdim, 1); //xdim = ti * tile_size + i
1818 add(reg_src_offset, simd_w * typesize);
1822 if(with_bias) vmovups(ptr[reg_bias], zmm_bias);
1825 auto zmm_t = [=](int i) {
1826 return Xbyak::Zmm(G_size + 16 + i);
1829 auto zmm_T = [=](int j, int i) {
1830 return Xbyak::Zmm(j * 4 + i);
// Store helper: the SDGtWo schedule uses regular stores (its output is
// consumed soon after); otherwise use non-temporal streaming stores.
1833 auto movps = [=](Xbyak::Reg64 reg_dst, size_t dst_off, Xbyak::Zmm a) {
1834 if (jcp.sched_policy == WSCHED_WEI_SDGtWo)
1835 vmovups(ptr[reg_dst + dst_off], a);
1837 vmovntps(ptr[reg_dst + dst_off], a);
// Output-side transform: column pass (src rows -> T), then row pass that
// writes the six transformed components with stride alpha_offset.
1840 auto trans_W_3x3_4x4 = [=]() {
1841 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1842 for (int i = 0; i < tile_size; i++) {
1843 vbroadcastss(zmm_G(0), ptr[reg_G]);
1844 vmulps(zmm_t(0), zmm_src(2, i), zmm_G(0));
1846 vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
1847 vmovups(zmm_t(1), zmm_t(0));
1848 vfmsub231ps(zmm_t(1), zmm_src(0, i), zmm_G(1));
1850 vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
1851 vmovups(zmm_t(2), zmm_t(0));
1852 vfmadd231ps(zmm_t(2), zmm_src(0, i), zmm_G(2));
1854 vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
1855 vmulps(zmm_t(3), zmm_src(1, i), zmm_G(3));
1857 vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
1858 vfmadd231ps(zmm_t(3), zmm_src(3, i), zmm_G(4));
1860 vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
1861 vmulps(zmm_t(4), zmm_src(1, i), zmm_G(5));
1863 vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
1864 vfmadd231ps(zmm_t(4), zmm_src(3, i), zmm_G(6));
1866 vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
1867 vmulps(zmm_T(0, i), zmm_src(0, i), zmm_G(7));
1868 vsubps(zmm_T(1, i), zmm_t(1), zmm_t(3));
1869 vaddps(zmm_T(2, i), zmm_t(1), zmm_t(3));
1870 vaddps(zmm_T(3, i), zmm_t(2), zmm_t(4));
1871 vsubps(zmm_T(4, i), zmm_t(2), zmm_t(4));
1872 vmovups(zmm_T(5, i), zmm_src(3, i));
1875 for (int j = 0; j < alpha; j++) {
1876 vbroadcastss(zmm_G(0), ptr[reg_G]);
1877 vmulps(zmm_t(0), zmm_T(j, 2), zmm_G(0));
1879 vbroadcastss(zmm_G(1), ptr[reg_G + typesize]);
1880 vmovups(zmm_t(1), zmm_t(0));
1881 vfmsub231ps(zmm_t(1), zmm_T(j, 0), zmm_G(1));
1883 vbroadcastss(zmm_G(2), ptr[reg_G + 2 * typesize]);
1884 vmovups(zmm_t(2), zmm_t(0));
1885 vfmadd231ps(zmm_t(2), zmm_T(j, 0), zmm_G(2));
1887 vbroadcastss(zmm_G(3), ptr[reg_G + 3 * typesize]);
1888 vmulps(zmm_t(3), zmm_T(j, 1), zmm_G(3));
1890 vbroadcastss(zmm_G(4), ptr[reg_G + 4 * typesize]);
1891 vfmadd231ps(zmm_t(3), zmm_T(j, 3), zmm_G(4));
1893 vbroadcastss(zmm_G(5), ptr[reg_G + 5 * typesize]);
1894 vmulps(zmm_t(4), zmm_T(j, 1), zmm_G(5));
1896 vbroadcastss(zmm_G(6), ptr[reg_G + 6 * typesize]);
1897 vfmadd231ps(zmm_t(4), zmm_T(j, 3), zmm_G(6));
1899 vbroadcastss(zmm_G(7), ptr[reg_G + 7 * typesize]);
1900 vmulps(zmm_t(0), zmm_T(j, 0), zmm_G(7));
1901 vsubps(zmm_t(5), zmm_t(1), zmm_t(3));
1902 vaddps(zmm_t(1), zmm_t(1), zmm_t(3));
1903 vaddps(zmm_t(6), zmm_t(2), zmm_t(4));
1904 vsubps(zmm_t(2), zmm_t(2), zmm_t(4));
1905 vmovups(zmm_t(3), zmm_T(j, 3));
// Distance between consecutive Winograd components in the transformed
// diff_dst buffer (per oc chunk and tile block).
1907 int alpha_offset = (jcp.oc / jcp.nb_oc)
1908 * (jcp.ntiles / jcp.tile_block) * typesize;
1909 int dst_off = j * alpha * alpha_offset;
1910 movps(reg_dst, dst_off, zmm_t(0));
1911 dst_off += alpha_offset;
1912 movps(reg_dst, dst_off, zmm_t(5));
1913 dst_off += alpha_offset;
1914 movps(reg_dst, dst_off, zmm_t(1));
1915 dst_off += alpha_offset;
1916 movps(reg_dst, dst_off, zmm_t(6));
1917 dst_off += alpha_offset;
1918 movps(reg_dst, dst_off, zmm_t(2));
1919 dst_off += alpha_offset;
1920 movps(reg_dst, dst_off, zmm_t(3));
// Tile/oc_reg loop driver for the SDGtWo schedule.
1924 auto compute_transform_SDGtWo = [=]() {
1925 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1926 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1927 if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
1929 xor_(reg_zero, reg_zero);
1930 xor_(reg_oc_ur, reg_oc_ur);
1931 Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, tiles_done;
1935 mov(reg_ti, ptr[reg_transp + GET_OFF(ti)]);
1936 mov(reg_tj, ptr[reg_transp + GET_OFF(tj)]);
1937 xor_(reg_tile_count, reg_tile_count);
1948 add(reg_tile_count, 1);
1949 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
1952 add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
1954 cmp(reg_ti, jcp.itiles);
1957 xor_(reg_ti, reg_ti);
1959 cmp(reg_tj, jcp.jtiles);
1962 xor_(reg_tj, reg_tj);
// Next image in the minibatch.
1963 add(reg_src, jcp.oc * jcp.ow * jcp.oh * typesize);
// Advance to the next oc SIMD slice within the oc register block.
1968 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1969 add(reg_dst, simd_w * typesize);
1970 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1971 add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
1973 if (with_bias) add(reg_bias, simd_w * typesize);
1975 cmp(reg_oc_ur, jcp.oc_reg_block);
// Tile/oc_reg loop driver for the default schedule; dst is pre-offset by
// the incoming tile_count and advances by whole tile blocks.
1980 auto compute_transform = [=]() {
1981 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
1982 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
1983 if (with_bias) mov(reg_bias, ptr[reg_transp + GET_OFF(bias)]);
1985 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
1986 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
1987 imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
1988 add(reg_dst, reg_temp);
1990 xor_(reg_zero, reg_zero);
1991 xor_(reg_oc_ur, reg_oc_ur);
1992 Label loop_mb, loop_jtiles, loop_itiles, loop_oc_ur, next_tile_block, next_tile;
1996 xor_(reg_ti, reg_ti);
1997 xor_(reg_tj, reg_tj);
2007 add(reg_tile_count, 1);
2008 cmp(reg_tile_count, jcp.nb_tile_block_ur * jcp.tile_block_ur);
2009 jge(next_tile_block);
2010 add(reg_dst, jcp.oc_reg_block * simd_w * typesize);
// Tile block exhausted: rewind inside the block, hop to the next block.
2014 sub(reg_dst, (jcp.nb_tile_block_ur * jcp.tile_block_ur - 1)
2015 * jcp.oc_reg_block * simd_w * typesize);
2016 int tblk_off = alpha * alpha * (jcp.oc/jcp.nb_oc)
2017 * (jcp.ntiles/jcp.tile_block) * typesize;
2018 add(reg_dst, tblk_off);
2019 xor_(reg_tile_count, reg_tile_count);
2023 cmp(reg_ti, jcp.itiles);
2026 xor_(reg_ti, reg_ti);
2028 cmp(reg_tj, jcp.jtiles);
// Next oc SIMD slice: recompute dst from the base plus tile_count offset.
2032 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
2033 mov(reg_tile_count, ptr[reg_transp + GET_OFF(tile_count)]);
2034 imul(reg_temp, reg_tile_count, jcp.oc_reg_block * simd_w * typesize);
2035 add(reg_dst, reg_temp);
2036 add(reg_dst, simd_w * typesize);
2037 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
2038 add(reg_src, jcp.oh * jcp.ow * simd_w * typesize);
2040 if (with_bias) add(reg_bias, simd_w * typesize);
2042 cmp(reg_oc_ur, jcp.oc_reg_block);
// Entry: select the driver that matches the chosen weights schedule.
2048 if (jcp.sched_policy == WSCHED_WEI_SDGtWo) {
2049 compute_transform_SDGtWo();
2051 compute_transform();
// JIT emitter for the backward-weights output transform: reduces the
// alpha x alpha accumulated GEMM results back to the kh x kw (3x3) weight
// gradient, one ic SIMD lane at a time. When first_tile is false the
// result is accumulated into dst rather than overwriting it.
// (NOTE(review): several lines — including the first_tile branch inside
// store_dst and lambda closers — are elided in this extracted view.)
2056 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::
2057 diff_weights_transform_generate(bool first_tile) {
// zmm0..G_size-1 hold the broadcast transform coefficients.
2060 auto zmm_G = [](int i) {
2061 return Xbyak::Zmm(i);
2064 auto init_G = [=]() {
2065 mov(reg_G, ptr[reg_transp + GET_OFF(G)]);
2066 for (int i = 0; i < G_size; i++)
2067 vbroadcastss(zmm_G(i), ptr[reg_G + i * typesize]);
// zmm_src and zmm_dst alias the same register range: sources are fully
// consumed into zmm_T before zmm_dst is written.
2070 auto zmm_src = [=](int i) {
2071 return Xbyak::Zmm(G_size + i);
// Load column i of the alpha x alpha accumulator, whose Winograd
// components are spaced alpha_offset apart.
2074 auto load_src = [=](int i) {
2075 for (int j = 0; j < alpha; j++) {
2076 size_t alpha_offset = jcp.oc_block * jcp.oc_reg_block
2077 * jcp.ic_block * simd_w * simd_w * typesize;
2078 size_t src_off = (j * alpha + i) * alpha_offset;
2079 vmovups(zmm_src(j), EVEX_compress_addr(reg_src, src_off));
2083 auto zmm_t = [=](int i) {
2084 return Xbyak::Zmm(G_size + 6 + i);
2087 auto zmm_T = [=](int j, int i) {
2088 return Xbyak::Zmm(G_size + 6 + 3 + j * 6 + i);
2091 auto zmm_dst = [=](int i) {
2092 return Xbyak::Zmm(G_size + i);
2095 auto zmm_temp = Xbyak::Zmm(31);
// Store one kh row of the 3x3 result; the load+add pair accumulates into
// existing dst contents (NOTE(review): the guard that skips accumulation
// when first_tile is set sits on elided lines — confirm in full source).
2097 auto store_dst = [=](int j) {
2098 for (int i = 0; i < jcp.kw; i++) {
2099 size_t dst_off = (j * jcp.kw + i) * simd_w * simd_w * typesize;
2102 vmovups(zmm_temp, EVEX_compress_addr(reg_dst, dst_off));
2103 vaddps(zmm_dst(i), zmm_dst(i), zmm_temp);
2105 vmovntps(EVEX_compress_addr(reg_dst, dst_off), zmm_dst(i));
// Main loop over the simd_w ic lanes: column pass (6 -> 3 rows into T),
// then row pass (6 -> 3 columns), storing each kh row as it is produced.
2109 auto compute_transform = [=] () {
2110 mov(reg_src, ptr[reg_transp + GET_OFF(src)]);
2111 mov(reg_dst, ptr[reg_transp + GET_OFF(dst)]);
2113 xor_(reg_ic_simd, reg_ic_simd);
2117 for (int i = 0; i < alpha; i++) {
2120 vaddps(zmm_t(0), zmm_src(1), zmm_src(2));
2121 vaddps(zmm_t(1), zmm_src(3), zmm_src(4));
2122 vmovups(zmm_t(2), zmm_src(5));
2123 vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
2125 vaddps(zmm_T(0, i), zmm_src(0), zmm_t(0));
2126 vaddps(zmm_T(0, i), zmm_T(0, i), zmm_t(1));
2127 vsubps(zmm_T(1, i), zmm_src(1), zmm_src(2));
2128 vmulps(zmm_T(1, i), zmm_T(1, i), zmm_G(1));
2129 vsubps(zmm_temp, zmm_src(3), zmm_src(4));
2130 vfmadd231ps(zmm_T(1, i), zmm_temp, zmm_G(2));
2131 vmovups(zmm_T(2, i), zmm_t(2));
2132 vfmadd231ps(zmm_T(2, i), zmm_t(0), zmm_G(3));
2135 for (int j = 0; j < jcp.kh; j++) {
2136 vaddps(zmm_t(0), zmm_T(j, 1), zmm_T(j, 2));
2137 vaddps(zmm_t(1), zmm_T(j, 3), zmm_T(j, 4));
2138 vmovups(zmm_t(2), zmm_T(j, 5));
2139 vfmadd231ps(zmm_t(2), zmm_t(1), zmm_G(0));
2141 vaddps(zmm_dst(0), zmm_T(j, 0), zmm_t(0));
2142 vaddps(zmm_dst(0), zmm_dst(0), zmm_t(1));
2143 vsubps(zmm_dst(1), zmm_T(j, 1), zmm_T(j, 2));
2144 vmulps(zmm_dst(1), zmm_dst(1), zmm_G(1));
2145 vsubps(zmm_temp, zmm_T(j, 3), zmm_T(j, 4));
2146 vfmadd231ps(zmm_dst(1), zmm_temp, zmm_G(2));
2147 vmovups(zmm_dst(2), zmm_t(2));
2148 vfmadd231ps(zmm_dst(2), zmm_t(0), zmm_G(3));
// Advance src/dst to the next ic SIMD lane and loop simd_w times.
2153 add(reg_src, jcp.oc_reg_block * simd_w * typesize);
2154 add(reg_dst, simd_w * typesize);
2155 add(reg_ic_simd, 1);
2156 cmp(reg_ic_simd, simd_w);
// Widen the EVEX compressed-displacement base so the large src offsets
// computed above stay encodable; restore it afterwards.
2161 push(reg_EVEX_max_8b_offt);
2162 mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
2164 compute_transform();
2165 pop(reg_EVEX_max_8b_offt);
// Emits the GEMM micro-kernel for the backward-weights Winograd convolution:
// accumulates dstC += srcA * srcB, blocked along M (oc), N (ic) and K (tiles).
// Register map: zmm0 holds the current srcA vector, zmm1..zmm[dimN_bcast_ur]
// hold broadcast srcB scalars, and the remaining zmm registers are the
// dstC accumulators.
2169 void jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::gemm_loop_generate(
// zmm0 is reserved as the srcA load register (and reused as scratch in
// store_dstC below).
2172 auto zmm_srcA = [=]() {
2173 return Xbyak::Zmm(0);
// srcB broadcast registers sit directly after zmm_srcA.
2176 auto zmm_srcB = [=] (size_t N_ur){
2177 return Xbyak::Zmm(N_ur + 1);
// Broadcast dimN_bcast_ur consecutive fp32 scalars of row K_ur of srcB.
2180 auto broadcastB = [=](size_t K_ur) {
2181 for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
2182 size_t srcB_off = (K_ur * jcp.dimN_reg_block + N_bcast)
2184 vbroadcastss(zmm_srcB(N_bcast), EVEX_compress_addr(reg_srcB, srcB_off));
// Load one simd-wide slice of srcA for unroll step (K_ur, M_ur).
2188 auto load_srcA = [=] (size_t K_ur, int M_ur) {
2189 size_t srcA_off = (K_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
2190 + M_ur * jcp.dimM_simd_block) * sizeof(float);
2191 vmovups(zmm_srcA(), EVEX_compress_addr(reg_srcA, srcA_off));
// Accumulator registers start right after the srcA and srcB registers.
2194 auto zmm_dstC = [=](size_t M_reg_ur, int N_bcast){
2195 size_t idx = 1 // zmm_srcA
2196 + jcp.dimN_bcast_ur // zmm_srcB
2197 + M_reg_ur * jcp.dimN_bcast_ur + N_bcast;
2199 return Xbyak::Zmm(idx);
// Zero all accumulators (vpxord zmm,zmm,zmm idiom).
2201 auto prepare_accumm = [=](){
2202 for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
2203 for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; N_bcast++) {
2204 Zmm zmm = zmm_dstC(M_reg_ur, N_bcast);
2205 vpxord(zmm, zmm, zmm);
2210 auto store_dstC = [=](){
2211 /******** Write C back to memory *******/
// NOTE(review): is_first_tile appears to be a parameter of this generator
// (the signature is not fully visible here) — when it is false, previously
// stored C is read back and added so results accumulate across tiles.
// zmm0 (zmm_srcA) is free at this point and is reused as scratch.
2212 for (int M_reg = 0; M_reg < jcp.dimM_reg_block; M_reg++) {
2213 for (int N_ur = 0; N_ur < jcp.dimN_bcast_ur; ++N_ur) {
2214 Zmm zmm = zmm_dstC(M_reg, N_ur);
2215 size_t C_off = (N_ur * jcp.dimM_reg_block * jcp.dimM_simd_block
2216 + M_reg * jcp.dimM_simd_block) * sizeof(float);
2217 if (!is_first_tile) {
2218 vmovups(Xbyak::Zmm(0), EVEX_compress_addr(reg_dstC, C_off));
2219 vaddps(zmm, zmm, Xbyak::Zmm(0));
2221 vmovups(EVEX_compress_addr(reg_dstC, C_off), zmm);
// Loop nest: M (oc) blocks -> N (ic) blocks -> N broadcast chunks -> K
// blocks, with the innermost K step fully unrolled over dimK_reg_block.
2226 auto inner_loops = [=]() {
2227 Label dimM_block_loop, dimK_block_loop, dimN_block_loop, dimN_bcast_ur;
2229 mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
2231 { /************* OC_block (M) loop ***********/
2232 mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
2234 { /*************** IC_block (N) loop *********/
2236 mov(reg_nb_dimN_bcast_ur, jcp.dimN_reg_block/jcp.dimN_bcast_ur);
2241 mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
2244 /************* nb_tile_ur(K) loop ********/
// Fully unrolled: each K_ur step broadcasts B scalars and issues
// dimM_reg_block * dimN_bcast_ur FMAs into the accumulators.
2245 for (int K_ur = 0; K_ur < jcp.dimK_reg_block; K_ur++) {
2249 for (int M_reg_ur = 0; M_reg_ur < jcp.dimM_reg_block; M_reg_ur++) {
2250 load_srcA(K_ur, M_reg_ur);
2251 for (int N_bcast = 0; N_bcast < jcp.dimN_bcast_ur; ++N_bcast) {
2252 vfmadd231ps(zmm_dstC(M_reg_ur, N_bcast), zmm_srcA(),
// Advance A and B past the K register block just consumed.
2257 add(reg_srcA, jcp.dimK_reg_block
2258 * jcp.dimM_reg_block * jcp.dimM_simd_block
2260 add(reg_srcB, jcp.dimK_reg_block
2261 * jcp.dimN_reg_block
2263 sub(reg_dimK_block_loop_cnt, 1);
2264 jnz(dimK_block_loop);
// Rewind A and B to the start of the K range, then step B and C forward
// to the next dimN_bcast_ur chunk.
2269 sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
2270 * jcp.dimM_reg_block * jcp.dimM_simd_block
2272 sub(reg_srcB, jcp.dimK_block * jcp.dimK_reg_block
2273 * jcp.dimN_reg_block
2275 add(reg_srcB, jcp.dimN_bcast_ur * sizeof(float));
2276 add(reg_dstC, jcp.dimN_bcast_ur
2277 * jcp.dimM_reg_block * jcp.dimM_simd_block
2279 sub(reg_nb_dimN_bcast_ur, 1);
// Undo the per-chunk B offset, then advance B to the next N block.
2283 sub(reg_srcB, jcp.dimN_reg_block * sizeof(float));
2284 add(reg_srcB, jcp.dimK_block
2285 * jcp.dimK_reg_block
2286 * jcp.dimN_reg_block * sizeof(float));
2287 sub(reg_dimN_block_loop_cnt, 1);
2288 jnz(dimN_block_loop);
// Rewind B over all N blocks and advance A to the next M block.
2291 sub(reg_srcB, jcp.dimN_block
2292 * jcp.dimK_block * jcp.dimK_reg_block
2293 * jcp.dimN_reg_block
2295 add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
2296 * jcp.dimM_reg_block * jcp.dimM_simd_block
2298 sub(reg_dimM_block_loop_cnt, 1);
2299 jnz(dimM_block_loop);
// Translate the generic GEMM blocking (dimM / dimN / dimK) chosen by the
// weight-gradient scheduling heuristics into the conv-specific fields used
// by the rest of the implementation: M maps to oc, N to ic, K to tiles.
2315 void set_jcp_WEI_params(jit_conv_winograd_conf_t &jcp) {
// M (oc) hierarchy: nb_block x block x reg_block x simd_block.
2317 jcp.dimM_nb_block = jcp.dimM / jcp.dimM_block / jcp.dimM_reg_block
2318 / jcp.dimM_simd_block;
2319 jcp.oc_reg_block = jcp.dimM_reg_block;
2320 jcp.oc_block = jcp.dimM_block;
2321 jcp.nb_oc = jcp.dimM_nb_block;
// N (ic) hierarchy: nb_block x block x reg_block.
2323 jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
2324 jcp.ic_block = jcp.dimN_block;
2325 jcp.nb_ic = jcp.dimN_nb_block;
// K (tile) hierarchy: tile_block x nb_tile_block_ur x tile_block_ur.
2328 jcp.dimK_nb_block = jcp.dimK / jcp.dimK_block / jcp.dimK_reg_block;
2329 jcp.tile_block_ur = jcp.dimK_reg_block;
2330 jcp.nb_tile_block_ur = jcp.dimK_block;
2331 jcp.tile_block = jcp.dimK_nb_block;
// Try the SDGtWo scheduling policy: search for an M/N/K blocking whose
// working sets fit the cache hierarchy. Returns status::unimplemented when
// no feasible blocking exists so the caller can fall back to another policy.
2334 status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
2336 size_t K_blk_ur, N_blk, M_blk;
2337 /* Is this strategy feasible? */
// Feasible only when the transformed M and V buffers together exceed twice
// the per-thread L2 budget and every thread gets at least one K slice.
2338 auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
2339 size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
2340 size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
2341 size_t nthreads = mkldnn_get_max_threads();
2342 return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
2343 && (jcp.dimK / nthreads >= 1.0);
// Accept dimK_block_ur when the A+B footprint lands between 10% and 50% of
// L1, the M block fits L2, and (when dimK divides evenly across threads)
// the K blocking keeps work balanced.
2346 auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
2348 size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * dimK_block_ur * sizeof(float);
2349 size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
2350 size_t M_L2_block = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
2351 size_t nthreads = mkldnn_get_max_threads();
2352 bool load_balance=true;
2353 if (!(jcp.dimK % nthreads)) {
2354 load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
2356 return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
2357 && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
2359 && (M_L2_block < L2_cache_size);
// Innermost K unroll must be in [2, 8]. The third parameter exists only to
// match the get_divisor_satisfying_cond callback signature.
2362 auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
2363 int useless_arg=0) {
2364 return (dimK_ur >= 2) && (dimK_ur <= 8);
// Combined M + V + U L2 footprint for the candidate (M_blk, N_blk,
// K_blk_ur) must land inside the target L2 window.
2367 auto blocking_ok = [&](){
2368 size_t M_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
2369 * K_blk_ur * sizeof(float);
2370 size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
2371 * K_blk_ur * sizeof(float);
2372 size_t U_L2_block = alpha * alpha * M_blk * jcp.dimM_reg_block * jcp.dimM_simd_block
2373 * N_blk * jcp.dimN_reg_block * sizeof(float);
2374 size_t L2_block = M_L2_block + V_L2_block + U_L2_block;
2375 /* TODO: derive the hard-coded 0.1 / 1.2 bounds from the actual L2(+L3) cache sizes */
2376 return (L2_block > 0.1 * L2_cache_size) && (L2_block <= 1.2 * L2_cache_size);
2379 if (test_MV_large_enough(jcp)) {
// Use a 2-wide M register block when the oc simd blocks split evenly.
2380 if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
2381 jcp.dimM_reg_block = 2;
2383 jcp.dimM_reg_block = 1;
2385 jcp.dimM_simd_block = jcp.oc_simd_block;
2386 jcp.dimN_reg_block = jcp.ic_simd_block;
2387 jcp.dimN_bcast_ur = 8;
2388 /*dimK_block and dimK_ur*/
2389 size_t min_dimK_block_ur = get_divisor_satisfying_cond(jcp, jcp.dimK, 1, test_min_dimK_L1);
2391 jcp.dimM_block = jcp.dimM/jcp.dimM_reg_block/jcp.dimM_simd_block;
2392 jcp.dimN_block = jcp.dimN/jcp.dimN_reg_block;
// Exhaustive divisor search, largest blocks first; the first blocking that
// satisfies all cache constraints wins.
2393 for (K_blk_ur = min_dimK_block_ur; K_blk_ur >= 1; --K_blk_ur) {
2394 if (test_min_dimK_L1(jcp, K_blk_ur) && !(jcp.dimK % K_blk_ur)) {
2395 for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
2396 if (!(jcp.dimN_block % N_blk)) {
2397 for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
2398 if (!(jcp.dimM_block % M_blk) && blocking_ok()) {
// Split K_blk_ur into dimK_block * dimK_reg_block with the unroll in [2,8];
// bail out if no valid unroll factor exists.
2399 jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
2400 if (!test_dimK_ur(jcp, jcp.dimK_reg_block)) return status::unimplemented;
2401 jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
2402 jcp.dimN_block = N_blk;
2403 jcp.dimM_block = M_blk;
2404 jcp.sched_policy = WSCHED_WEI_SDGtWo;
2405 set_jcp_WEI_params(jcp);
2406 jcp.nthr = nstl::min(mkldnn_get_max_threads(),
2408 return status::success;
// No feasible blocking found; let the caller try another schedule.
2416 return status::unimplemented;
// Fallback scheduling policy (S_D_Giot_W): always succeeds. First searches
// for a cache-friendly blocking; if none passes the constraints, falls back
// to the trivial blocking set at the bottom of the function.
2419 status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
// Use a 2-wide M register block when the oc simd blocks split evenly.
2420 if ((jcp.dimM/jcp.dimM_simd_block) % 2 == 0) {
2421 jcp.dimM_reg_block = 2;
2423 jcp.dimM_reg_block = 1;
2425 jcp.dimN_bcast_ur = 8;
2426 jcp.dimN_reg_block = jcp.ic_simd_block;
2427 jcp.dimM_simd_block = jcp.oc_simd_block;
2428 jcp.dimN_block = jcp.dimN / jcp.dimN_reg_block;
2429 jcp.dimM_block = jcp.dimM / jcp.dimM_reg_block / jcp.dimM_simd_block;
// C1/C2 are lower bounds and C1_max/C2_max upper bounds, expressed as
// fractions of the L1 and L2 cache sizes respectively.
2430 float C1 = 0.0, C2 = 0.0;
2431 float C1_max = 0.5, C2_max = 1.4;
2432 int N_blk, M_blk, K_blk_ur;
// Innermost K unroll must be in [2, 8]. The third parameter exists only to
// match the get_divisor_satisfying_cond callback signature.
2434 auto test_dimK_ur = [](jit_conv_winograd_conf_t &jcp, int dimK_ur,
2435 int useless_arg=0) {
2436 return (dimK_ur >= 2) && (dimK_ur <= 8);
// Accept candidate (M_blk, N_blk, K_blk_ur) when the A+B L1 footprint,
// the thread load balance, and the V-buffer L2 footprint all pass.
2439 auto blocking_ok = [&]() -> bool {
2440 size_t L1_block_M = jcp.dimM_reg_block * jcp.dimM_simd_block * K_blk_ur * sizeof(float);
2441 size_t L1_block_N = jcp.dimN_reg_block * K_blk_ur * sizeof(float);
2442 bool L1_cond = ((L1_block_N + L1_block_M) >= C1 * L1_cache_size)
2443 && ((L1_block_N + L1_block_M) <= C1_max * L1_cache_size);
// Enough independent blocks to keep every thread busy; when the K blocks
// divide evenly over threads, additionally require exact balance.
2445 size_t nb_N_blk = jcp.dimN/N_blk/jcp.dimN_reg_block;
2446 size_t nb_M_blk = jcp.dimM/M_blk/jcp.dimM_reg_block/jcp.dimM_simd_block;
2447 size_t nb_K_blk = jcp.dimK / K_blk_ur;
2448 size_t nthreads = mkldnn_get_max_threads();
2449 bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
2450 if (!(nb_K_blk % nthreads)) {
2451 load_balance = load_balance && (nb_K_blk % nthreads == 0);
2454 size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block * K_blk_ur * sizeof(float);
2456 size_t L2_block = V_L2_block;
2457 /* TODO: derive the C2 / C2_max bounds from the actual L2(+L3) cache sizes */
2458 bool L2_cond = (L2_block >= C2 * L2_cache_size) && (L2_block <= C2_max * L2_cache_size);
2459 return L1_cond && load_balance && L2_cond;
// Divisor search, largest blocks first; the first candidate that passes
// blocking_ok wins.
2462 for (K_blk_ur = jcp.dimK; K_blk_ur >= 1; --K_blk_ur) {
2463 if (jcp.dimK % K_blk_ur == 0) {
2464 for (N_blk = jcp.dimN_block; N_blk >= 1; --N_blk) {
2465 if (jcp.dimN_block % N_blk == 0) {
2466 for (M_blk = jcp.dimM_block; M_blk >= 1; --M_blk) {
2467 if (jcp.dimM_block % M_blk == 0) {
2468 if (blocking_ok()) {
2469 jcp.dimN_block = N_blk;
2470 jcp.dimM_block = M_blk;
2471 jcp.dimK_reg_block = get_divisor_satisfying_cond(jcp, K_blk_ur, 1, test_dimK_ur);
2472 jcp.dimK_block = K_blk_ur / jcp.dimK_reg_block;
2473 jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
2474 set_jcp_WEI_params(jcp);
2475 return status::success;
// No blocking satisfied the constraints: use the trivial K blocking so the
// policy still succeeds as the guaranteed fallback.
2483 jcp.dimK_reg_block = 1;
2485 jcp.sched_policy = WSCHED_WEI_S_D_Giot_W;
2486 set_jcp_WEI_params(jcp);
2487 return status::success;
// Validate the backward-weights convolution descriptor against the 4x3
// Winograd kernel's constraints, fill jcp with the problem geometry and
// blocking parameters, and pick a scheduling policy. Returns
// status::unimplemented on any unsupported configuration so the framework
// can fall back to a different implementation.
2490 status_t jit_avx512_core_fp32_wino_conv_4x3_bwd_weights_kernel::init_conf(
2491 jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
2492 const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
2493 const memory_desc_wrapper &diff_weights_d) {
2494 if (!mayiuse(avx512_core))
2495 return status::unimplemented;
2497 jcp.ver = ver_avx512_core;
2499 jcp.nthr = mkldnn_get_max_threads();
// Problem geometry, extracted from the descriptor and memory descriptors.
// With groups, diff_weights has one extra leading dimension.
2501 jcp.prop_kind = cd.prop_kind;
2502 const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
2503 jcp.mb = src_d.dims()[0];
2504 jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
2505 jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
2506 jcp.oc_without_padding = jcp.oc;
2507 jcp.ic = src_d.dims()[1] / jcp.ngroups;
2508 jcp.ih = src_d.dims()[2];
2509 jcp.iw = src_d.dims()[3];
2510 jcp.oh = diff_dst_d.dims()[2];
2511 jcp.ow = diff_dst_d.dims()[3];
2512 jcp.kh = diff_weights_d.dims()[with_groups + 2];
2513 jcp.kw = diff_weights_d.dims()[with_groups + 3];
2514 jcp.t_pad = cd.padding[0][0];
2515 jcp.l_pad = cd.padding[0][1];
2516 jcp.stride_h = cd.strides[0];
2517 jcp.stride_w = cd.strides[1];
// Right/bottom padding derived so the output geometry is consistent.
2518 jcp.r_pad = nstl::max(
2519 0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
2520 jcp.b_pad = nstl::max(
2521 0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
2522 jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
2523 jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
2526 jcp.with_bias = (cd.diff_bias_desc.format != memory_format::undef);
2527 jcp.dilate_h = cd.dilates[0];
2528 jcp.dilate_w = cd.dilates[1];
// Channels may be rounded up to the simd width only for non-grouped conv.
2530 bool ok_to_pad_channels = jcp.ngroups == 1;
2531 if (ok_to_pad_channels) {
2532 jcp.oc = rnd_up(jcp.oc, simd_w);
2533 jcp.ic = rnd_up(jcp.ic, simd_w);
2536 // Winograd specific initialization
2537 jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
2538 jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
2539 jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
2541 // Winograd kernel works only for 3x3 convolution with stride 1
// For convolution_auto, only proceed when the heuristic says Winograd
// beats the direct implementation.
2542 if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
2543 is_winograd_faster_than_direct(jcp)))
2544 return status::unimplemented;
// Hard constraints: no groups, 3x3 kernel, no dilation, unit stride,
// simd-aligned channels, and nChw16c / (g)OIhw16i16o layouts.
2546 if (jcp.ngroups != 1)
2547 return status::unimplemented;
2548 if ((jcp.kh != 3) || (jcp.kw != 3))
2549 return status::unimplemented;
2550 if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
2551 return status::unimplemented;
2552 if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
2553 return status::unimplemented;
2554 if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
2555 return status::unimplemented;
2556 if (src_d.format() != nChw16c)
2557 return status::unimplemented;
2558 if (diff_weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
2559 return status::unimplemented;
2560 if (diff_dst_d.format() != nChw16c)
2561 return status::unimplemented;
// The (possibly padded) ic/oc must fit within each tensor's padded dims.
2563 bool layout_consistency = true
2564 && jcp.ic <= src_d.blocking_desc().padding_dims[1]
2565 && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
2566 && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
2567 && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
2568 if (!layout_consistency) return status::unimplemented;
2570 /******************Kernel blocking Parameters ***********/
2571 jcp.ic_simd_block = simd_w;
2572 jcp.oc_simd_block = simd_w;
// For backward weights the GEMM K dimension is the number of tiles.
2574 jcp.dimK = jcp.ntiles;
2577 jcp.dimM_simd_block = jcp.oc_simd_block;
2578 jcp.dimN_reg_block = jcp.ic_simd_block;
2579 jcp.sched_policy = WSCHED_INVALID;
// Prefer SDGtWo; fall back to S_D_Giot_W, which always succeeds.
// NOTE(review): `res` is only read by the assert, so it is unused in
// NDEBUG builds and may trigger an unused-variable warning — consider
// MAYBE_UNUSED if the project provides it.
2580 status_t res = set_wsched_WEI_SDGtWo(jcp);
2581 if (res == status::unimplemented) {
2582 res = set_wsched_WEI_S_D_Giot_W(jcp);
2583 assert(res == status::success);
2591 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s