1 /*******************************************************************************
2 * Copyright 2017-2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
18 #include "mkldnn_thread.hpp"
20 #include "type_helpers.hpp"
22 #include "cpu_memory.hpp"
26 #include "jit_avx512_common_conv_winograd_kernel_f32.hpp"
28 #ifndef KERNEL_SIZE_THRESHOLD
29 #define KERNEL_SIZE_THRESHOLD 16
32 #define MIN_REQUIRED_DIMN_REG_BLOCK 14
40 using namespace mkldnn::impl::utils;

// Cache sizes queried once at namespace scope; consumed by the check_cond*
// blocking heuristics below (working-set <= C * cache_size tests).
// NOTE(review): second get_cache_size() argument presumably selects per-core
// vs. total capacity (true for L1/L2, false for LLC) -- confirm against the
// get_cache_size() declaration.
42 unsigned int L1_cache_size = get_cache_size(1, true);
43 unsigned int L2_cache_size = get_cache_size(2, true);
44 unsigned int LLC_data_size = get_cache_size(3, false);
46 // the test function takes jcp, the candidate and the current best.
47 // it returns true if the new candidate is better
// Enumerates every divisor of `number` (as pairs divisor / number.divisor up
// to sqrt(number)), offers each to `test`, and returns the winning divisor;
// `default_best` is returned when no candidate beats it.
// NOTE(review): several lines are elided in this extraction (the lambda's
// body tail and the final return) -- code left untouched.
48 int get_divisor_satisfying_cond(jit_conv_winograd_conf_t &jcp, int number,
49         int default_best, bool (*test)(jit_conv_winograd_conf_t &, int, int))
51     int best_divisor = default_best;
53             = [&best_divisor, test](jit_conv_winograd_conf_t &jcp, int num) {
54                 if (test(jcp, num, best_divisor)) {
59     for (int divisor = 1; divisor <= ::sqrt(number); divisor++) {
60         if (number % divisor == 0) {
61             test_num(jcp, divisor);
62             test_num(jcp, number / divisor);
// Heuristic gate used with alg_kind::convolution_auto: decides whether this
// Winograd kernel is expected to outperform direct convolution for `jcp`.
// NOTE(review): the body is truncated in this view (only the ver_4fma branch
// header is visible); the actual decision criterion cannot be documented
// from here -- verify against the full source.
70 bool is_winograd_faster_than_direct(const jit_conv_winograd_conf_t &jcp) {
71     if (jcp.ver == ver_4fma)
78 /* assumes 512 bits registers */
79 /* TODO: add support for strides */
80 /* TODO: handle the prefetch distance automatically */
81 typedef enum cache_t_ { L1, L2, L3 } cache_t;

// Helper that spreads software prefetches for one `cache_block_size_`-element
// block of data_t across the instructions of a generated GEMM block, so that
// roughly prefetch_blk_ cache lines are issued every prefetch_spread_
// instructions and the whole block is covered once per pass.
// NOTE(review): the struct header, some members (cg_, prefetch targets) and
// several statements are elided in this extraction -- code left untouched.
83 template <typename data_t>
85     prefetcher_t(jit_generator *generator, Xbyak::Reg64 reg_base_addr,
86             cache_t cache_type, size_t block_size, /* in number of elements*/
87             int nb_instructions_in_block, int fma_ipc)
89         , reg_base_addr_(reg_base_addr)
90         , cache_type_(cache_type)
91         , cache_block_size_(block_size)
// 64 = cache line size in bytes; lines-per-block drives spread/blk factors.
93         nb_cache_lines_to_prefetch_ = cache_block_size_ / (64 / sizeof(data_t));
95                 = div_up(nb_instructions_in_block, nb_cache_lines_to_prefetch_);
97                 = div_up(nb_cache_lines_to_prefetch_, nb_instructions_in_block);
99         /* assumption: when fetch in Li, data is already in L(i+1) */
101         switch (cache_type_) {
// latencies below are hard-coded cycle estimates, not measured per-CPU
102         case L1: cache_latency = 14; break;
105         default: cache_latency = 250; break;
108         prefetch_distance_ = div_up(cache_latency, nb_cache_lines_to_prefetch_);
// Called once per generated instruction; emits the next prefetch(es) when
// this instruction index falls on a spread boundary.
111     void prefetch(int instruction_number)
113         if (instruction_number % prefetch_spread_ == 0) {
114             for (int i = 0; (i < prefetch_blk_)
115                     && (prefetches_issued_ < nb_cache_lines_to_prefetch_);
116                     i++, prefetches_issued_++) {
117                 prefetch_inst_(cg_->EVEX_compress_addr(
118                         reg_base_addr_, (cache_block_size_ * prefetch_distance_)
120                                 + (prefetches_issued_ * 64)));
// Emits the prefetch instruction matching the target cache level.
126     void prefetch_inst_(const Xbyak::Address &addr)
128         switch (cache_type_) {
129         case L1: cg_->prefetcht0(addr); break;
130         case L2: cg_->prefetcht1(addr); break;
131         case L3: cg_->prefetcht2(addr); break;
133             break; // TODO: raise an exception or put an assert
138     Xbyak::Reg64 reg_base_addr_;
140     int cache_block_size_ = 0;
141     int nb_cache_lines_to_prefetch_ = 0;
142     int prefetches_issued_ = 0;
143     int prefetch_spread_ = 0;
144     int prefetch_blk_ = 0;
145     int prefetch_distance_ = 0;
148 // utilities to support kernel parameter selection
149 bool check_cond1(int dimN_reg_block, int dimK_block, int dimK_reg_block,
150 int dimM_block, int dimM_simd_block, float C)
152 float lhs = (dimM_block * dimN_reg_block * dimM_simd_block
153 + dimM_block * dimK_block * dimK_reg_block
155 + dimK_block * dimN_reg_block * dimK_reg_block)
156 * (float)sizeof(float);
157 float rhs = C * L1_cache_size;
161 bool check_cond1_bis(int dimN_reg_block, int dimK_block, int dimK_reg_block,
162 int dimM_block, int dimM_simd_block, float C)
164 float lhs = (dimM_block * dimK_block * dimK_reg_block * dimM_simd_block
165 + dimK_block * dimN_reg_block * dimK_reg_block)
166 * (float)sizeof(float);
167 float rhs = C * L1_cache_size;
171 bool check_cond2(int nb_dimN_reg_block, int dimN_reg_block, int dimK_nb_block,
172 int dimK_block, int dimK_reg_block, int dimM_block, int dimM_simd_block,
175 float lhs = (nb_dimN_reg_block * dimM_block * dimN_reg_block * dimM_simd_block
176 + dimK_nb_block * dimM_block * dimK_block * dimK_reg_block
178 + nb_dimN_reg_block * dimK_nb_block * dimK_block
179 * dimN_reg_block * dimK_reg_block)
180 * (float)sizeof(float);
181 float rhs = C * L2_cache_size;
186 using namespace mkldnn::impl::memory_format;
187 using namespace mkldnn::impl::utils;
188 using namespace Xbyak;
// Emits the forward/bwd-data GEMM microkernel: for each dimM block it keeps
// jcp.dimN_reg_block C-tile accumulators in zmm registers, runs the dimK
// loop with double-buffered A loads and broadcast-B fmas (v4fmaddps on KNM,
// vfmadd231ps otherwise), then writes the tiles back -- with streaming
// stores when aligned and the problem exceeds 2x LLC.
// NOTE(review): numerous emitter lines are elided in this extraction
// (labels, else-branches, loop closings); code left byte-identical.
190 void _jit_avx512_common_conv_winograd_data_kernel_f32::gemm_loop_generate(
193     // const int dimK_simd_block = jcp.dimK_reg_block;
195     //     for (int dimM_block =0; dimM_block < jcp.dimM_block; dimM_block++)
196     //         for (int dimK_block = 0; dimK_block < jcp.dimK_block; dimK_block++)
197     //             for (int dimK_reg_block= 0; dimK_reg_block < jcp.dimK_reg_block;
199     //                 for (int tile =0; tile < jcp.dimN_reg_block; tile++)
200     //                     C[dimM_block][tile] +=
201     //                         A[dimM_block][dimK_block][dimK_reg_block] *
202     //                         broadcast(B[dimK_block][tile][dimK_reg_block]);
203     // 1) We do register blocking on A[dimM_block][dimK_block][dimK_reg_block],
204     //    so we load it before the loop on tile
205     // 2) the loop on tile must be fully unrolled. Don't know about the one on
206     //    dimK_reg_block. I think it should be
208     auto inner_loops = [=]() {
209         Label dimM_block_loop, dimK_block_loop;
210         const int inc_dimK_reg_block = jcp.ver == ver_4fma ? 4 : 1;
211         const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
// B is prefetched into both L1 and L2 one block ahead of the fma stream.
213         prefetcher_t<float> L1_pf(this, reg_srcB, L1,
214                 jcp.dimN_reg_block * jcp.dimK_reg_block,
215                 jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
217         prefetcher_t<float> L2_pf(this, reg_srcB, L2,
218                 jcp.dimN_reg_block * jcp.dimK_reg_block,
219                 jcp.dimK_reg_block * jcp.dimN_reg_block / inc_dimK_reg_block,
222         if (jcp.dimM_block > 1) {
223             mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
227             // First, we zero the accumulators if first nb_ic iteration,
228             // otherwise we load them
229             for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
230                 Zmm zmm(jcp.zmm_start + tile);
232                     vpxord(zmm, zmm, zmm);
234                     vmovups(zmm, zword[reg_dstC + 64 * tile]);
237             if (jcp.dimK_block > 1) {
238                 mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
// Loads inc_dimK_reg_block consecutive A vectors into registers reg_idx..
242             auto load_A = [=](int reg_idx, int offset) {
243                 for (int i = 0; i < inc_dimK_reg_block; i++)
244                     vmovups(Zmm(reg_idx + i),
245                             zword[reg_srcA + 64 * (offset + i)]);
248             // Used when doing double buffering
250             if (jcp.double_buffering) {
253             for (int dimK_reg_block = 0;
254                     dimK_reg_block < jcp.dimK_reg_block;
255                     dimK_reg_block += inc_dimK_reg_block) {
257                 /* Loading the next vector from A */
259                 if (jcp.double_buffering) {
260                     next = (dimK_reg_block + inc_dimK_reg_block)
261                             % (2 * inc_dimK_reg_block);
262                     load_A(next, dimK_reg_block + inc_dimK_reg_block);
265                     load_A(next, dimK_reg_block);
267                 /* Performing the fmas */
268                 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
269                     Zmm zmm(jcp.zmm_start + tile);
270                     if (jcp.ver != ver_avx512_core)
272                                 dimK_reg_block * jcp.dimN_reg_block + tile);
273                     if (jcp.ver == ver_4fma)
274                         v4fmaddps(zmm, Zmm(current),
275                                 EVEX_compress_addr(reg_srcB,
276                                         64 * tile + dimK_reg_block * 4));
278                         vfmadd231ps(zmm, Zmm(current),
279                                 EVEX_compress_addr(reg_srcB,
280                                         64 * tile + dimK_reg_block * 4,
282                     if (jcp.ver != ver_avx512_core)
284                                 dimK_reg_block * jcp.dimN_reg_block + tile);
288             add(reg_srcA, jcp.dimK_reg_block * 64);
289             add(reg_srcB, jcp.dimN_reg_block * 64);
290             if (jcp.dimK_block > 1) {
291                 sub(reg_dimK_block_loop_cnt, 1);
292                 jnz(dimK_block_loop);
// Streaming (non-temporal) stores only when the destination is 64B-aligned,
// this is the last dimK chunk, and the tensor is too big to live in LLC.
297             auto store_output = [=](bool output_is_aligned) {
298                 for (int tile = 0; tile < jcp.dimN_reg_block; tile++) {
299                     Zmm zmm(jcp.zmm_start + tile);
300                     if (output_is_aligned
301                             && jcp.dimK_nb_block == 1
302                             && (jcp.dimN * jcp.dimM * alpha * alpha
303                                 * sizeof(float) > 2 * LLC_data_size))
304                         vmovntps(zword[reg_dstC + 64 * tile], zmm);
306                         vmovups(zword[reg_dstC + 64 * tile], zmm);
310             Label unaligned_store, end_store;
311             test(reg_dstC, cpu_isa_traits<avx512_common>::vlen - 1);
312             jnz(unaligned_store, T_NEAR);
314             jmp(end_store, T_NEAR);
315             L(unaligned_store); {
320             if (jcp.dimM_block > 1) {
321                 sub(reg_srcB, jcp.dimK_block * jcp.dimN_reg_block * 64);
322                 add(reg_dstC, jcp.dimN_reg_block * 64);
323                 sub(reg_dimM_block_loop_cnt, 1);
324                 jnz(dimM_block_loop);
// Shared problem setup for the fwd and bwd-data Winograd kernels: rejects
// unsupported ISAs, fills jcp geometry from the descriptors, pads channels
// to simd_w when groups == 1, then rejects anything that is not a 3x3,
// stride-1, no-dilation convolution in nChw16c / (g)OIhw16i16o layouts.
// NOTE(review): some lines are elided in this extraction (e.g. the
// avx512_mic_4ops branch body and simd_w definition); code left untouched.
340 status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_common(
341         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
342         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
343         const memory_desc_wrapper &dst_d)
// avx512_core parts are handled by a separate implementation; bail out here.
346     if (mayiuse(avx512_core))
347         return status::unimplemented;
348     else if (!mayiuse(avx512_common))
349         return status::unimplemented;
350     else if (mayiuse(avx512_mic_4ops))
355     jcp.nthr = mkldnn_get_max_threads();
357     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
359     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
360     jcp.mb = src_d.dims()[0];
361     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
362     jcp.oc_without_padding = jcp.oc;
363     jcp.ic = src_d.dims()[1] / jcp.ngroups;
364     jcp.ih = src_d.dims()[2];
365     jcp.iw = src_d.dims()[3];
366     jcp.oh = dst_d.dims()[2];
367     jcp.ow = dst_d.dims()[3];
368     jcp.kh = weights_d.dims()[with_groups + 2];
369     jcp.kw = weights_d.dims()[with_groups + 3];
370     jcp.t_pad = cd.padding[0][0];
371     jcp.l_pad = cd.padding[0][1];
372     jcp.stride_h = cd.strides[0];
373     jcp.stride_w = cd.strides[1];
374     jcp.dilate_h = cd.dilates[0];
375     jcp.dilate_w = cd.dilates[1];
// Right/bottom padding derived so that the output size equation holds.
376     jcp.r_pad = nstl::max(
377             0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
378     jcp.b_pad = nstl::max(
379             0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
380     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
381     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
385     bool ok_to_pad_channels = jcp.ngroups == 1;
386     if (ok_to_pad_channels) {
387         jcp.oc = rnd_up(jcp.oc, simd_w);
388         jcp.ic = rnd_up(jcp.ic, simd_w);
391     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
392             is_winograd_faster_than_direct(jcp)))
393         return status::unimplemented;
395     // Checking conditions not supported by these kernels
396     if (jcp.ngroups != 1)
397         return status::unimplemented;
398     if ((jcp.kh != 3) || (jcp.kw != 3))
399         return status::unimplemented;
400     if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
401         return status::unimplemented;
402     if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
403         return status::unimplemented;
404     if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
405         return status::unimplemented;
407     if (src_d.format() != nChw16c)
408         return status::unimplemented;
409     if (weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
410         return status::unimplemented;
411     if (dst_d.format() != nChw16c)
412         return status::unimplemented;
// Padded channel counts must not exceed what the memory layouts provide.
414     bool layout_consistency = true
415             && jcp.ic <= src_d.blocking_desc().padding_dims[1]
416             && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
417             && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
418             && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
419     if (!layout_consistency) return status::unimplemented;
421     return status::success;
// Selects the W_S_G_D blocking scheme for the data kernels: dimN_reg_block
// from the register budget, dimK_block / dimM_block from L1-fit conditions
// (check_cond1 / check_cond1_bis, the latter assuming streaming stores),
// and dimN_block from the L2-fit condition check_cond2. Divisor search is
// delegated to get_divisor_satisfying_cond.
425 status_t set_wsched_DATA_W_S_G_D_avx512_common(jit_conv_winograd_conf_t &jcp) {
// Prefer the smallest dimN_reg_block still >= MIN_REQUIRED_DIMN_REG_BLOCK
// that fits in the available registers.
427     auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
428             int dimN_reg_block, int current_best) {
429         return (dimN_reg_block >= MIN_REQUIRED_DIMN_REG_BLOCK)
430                 && (dimN_reg_block < jcp.nb_reg)
431                 && (dimN_reg_block < current_best);
433     jcp.dimN_reg_block = get_divisor_satisfying_cond(
434             jcp, jcp.dimN, jcp.dimN, test_cond_dimN_reg_block);
// Fallback: if nothing fit, take the largest divisor below the register cap.
436     if (jcp.dimN_reg_block >= jcp.nb_reg) {
437         auto test_cond_dimN_reg_block = [](jit_conv_winograd_conf_t &jcp,
438                 int dimN_reg_block, int current_best) {
439             return (dimN_reg_block < jcp.nb_reg)
440                     && (dimN_reg_block > current_best);
443         jcp.dimN_reg_block = get_divisor_satisfying_cond(
444                 jcp, jcp.dimN, 1, test_cond_dimN_reg_block);
447     //********************* Choosing dimK_block **********************//
448     auto test_cond1_dimK_block = [](
449             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
450         return check_cond1(jcp.dimN_reg_block, dimK_block, jcp.dimK_reg_block,
451                        1, jcp.dimM_simd_block, .75f)
452                 && (dimK_block > current_best);
455     auto test_cond1_bis_dimK_block = [](
456             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
457         return check_cond1_bis(jcp.dimN_reg_block, dimK_block,
458                        jcp.dimK_reg_block, 1, jcp.dimM_simd_block, .9f)
459                 && (dimK_block > current_best);
462     jcp.dimK_block = get_divisor_satisfying_cond(
463             jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_bis_dimK_block);
464     // If we are not able to use streams, we fall back to condition [1]
465     if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
466         jcp.dimK_block = get_divisor_satisfying_cond(
467                 jcp, jcp.dimK / jcp.dimK_reg_block, 1, test_cond1_dimK_block);
468     jcp.dimK_nb_block = (jcp.dimK / jcp.dimK_reg_block) / jcp.dimK_block;
470     //********************* Choosing dimM_block **********************//
471     jcp.dimM_simd_block = 16;
472     /*XXX: Why C=0.5 here but C=0.75 for dimK_block?*/
473     auto test_cond1_dimM_block = [](
474             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
475         return check_cond1(jcp.dimN_reg_block, jcp.dimK_block,
476                        jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .5f)
477                 && (dimM_block > current_best);
480     auto test_cond1_bis_dimM_block = [](
481             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
482         return check_cond1_bis(jcp.dimN_reg_block, jcp.dimK_block,
483                        jcp.dimK_reg_block, dimM_block, jcp.dimM_simd_block, .3f)
484                 && (dimM_block > current_best);
487     if (jcp.dimK_block < jcp.dimK / jcp.dimK_reg_block)
488         jcp.dimM_block = get_divisor_satisfying_cond(
489                 jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
491         jcp.dimM_block = get_divisor_satisfying_cond(jcp,
492                 jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_bis_dimM_block);
493     jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
495     //******************* Choosing dimN_block *******************//
496     auto test_cond2_dimN_block = [](
497             jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
498         return check_cond2(dimN_block, jcp.dimN_reg_block, jcp.dimK_nb_block,
499                        jcp.dimK_block, jcp.dimK_reg_block, jcp.dimM_block,
500                        jcp.dimM_simd_block, .5f)
501                 && (dimN_block > current_best);
504     jcp.dimN_block = get_divisor_satisfying_cond(
505             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
506     jcp.dimN_nb_block = jcp.dimN / (jcp.dimN_reg_block * jcp.dimN_block);
507     jcp.sched_policy = WSCHED_DATA_W_S_G_D;
508     return status::success;
// Fixes the simd-level blocking (16-wide K and M), reserves zmm registers
// for double buffering of A loads, and runs the W_S_G_D scheduling pass.
// dimM/dimN/dimK are the GEMM dimensions mapped by each caller (oc/ic swap
// between fwd and bwd-data).
511 status_t _jit_avx512_common_conv_winograd_data_kernel_f32::init_conf_kernel(
512         jit_conv_winograd_conf_t &jcp, int dimM, int dimN, int dimK)
514     jcp.dimK_reg_block = 16;
515     jcp.dimM_simd_block = 16;
517     // TODO: replace double buffering with nuple buffering to maximize register
519     // the choice of the number of buffers will then come after choosing
521     jcp.double_buffering = true;
522     if (jcp.double_buffering)
// 2 buffers x (4 regs for 4fma A-loads, else 2); accumulators start above.
523         jcp.zmm_start = 2 * ((jcp.ver == ver_4fma) ? 4 : 2);
526     jcp.nb_reg = 32 - jcp.zmm_start;
532     jcp.sched_policy = WSCHED_INVALID;
533     set_wsched_DATA_W_S_G_D_avx512_common(jcp);
535     assert(jcp.sched_policy == WSCHED_DATA_W_S_G_D);
536     return status::success;
// Accepts only relu/sum post-op chains: none; a single relu or sum;
// sum+relu in either order; or relu->sum->relu.
// NOTE(review): the switch header (presumably `switch (p.len_)`) is elided
// in this extraction -- the case labels below belong to it.
539 bool jit_avx512_common_conv_winograd_fwd_kernel_f32::post_ops_ok(
540         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
541     const auto &p = attr.post_ops_;
543     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
544     auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
547     case 0: return true; // no post_ops
548     case 1: return is_relu(0) || is_sum(0); // relu or sum
549     case 2: return (is_sum(0) && is_relu(1)) ||
550                        (is_relu(0) && is_sum(1)); // sum->relu or relu->sum
551     case 3: return is_relu(0) && is_sum(1) && is_relu(2); // relu->sum->relu
552     default: return false;
// Forward-pass configuration: runs the common checks, computes Winograd
// tiling (output tiled by tile_size), validates post-ops, then maps the
// GEMM blocking chosen by init_conf_kernel (M=oc, N=ntiles, K=ic) onto the
// convolution-level jcp fields.
558 status_t jit_avx512_common_conv_winograd_fwd_kernel_f32::init_conf(
559         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
560         const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
561         const memory_desc_wrapper &dst_d, const primitive_attr_t &attr) {
562     status_t st = init_conf_common(jcp, cd, src_d, weights_d, dst_d);
564     if (st != status::success)
567     // Winograd specific initialization
568     jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
569     jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
570     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
572     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
574     if (!post_ops_ok(jcp, attr))
575         return status::unimplemented;
577     const auto &p = attr.post_ops_;
578     const int eltwise_ind = p.find(primitive_kind::eltwise, 0, 1);
579     jcp.with_eltwise = eltwise_ind != -1;
580     if (jcp.with_eltwise) jcp.eltwise = p.entry_[eltwise_ind].eltwise;
581     jcp.with_sum = p.find(primitive_kind::sum, 0) != -1;
// GEMM dims: M = oc, N = ntiles, K = ic for the forward pass.
583     status_t res = init_conf_kernel(jcp, jcp.oc, jcp.ntiles, jcp.ic);
584     jcp.ic_simd_block = jcp.dimK_reg_block;
585     jcp.ic_block = jcp.dimK_block;
586     jcp.nb_ic = jcp.dimK_nb_block;
587     jcp.oc_simd_block = jcp.dimM_simd_block;
588     jcp.oc_block = jcp.dimM_block;
589     jcp.nb_oc = jcp.dimM_nb_block;
590     jcp.tile_block_ur = jcp.dimN_reg_block;
591     jcp.nb_tile_block_ur = jcp.dimN_block;
592     jcp.tile_block = jcp.dimN_nb_block;
593     jcp.tile_4fma_padding = 0; // only relevant for backward weights
// Backward-data configuration: same common checks, but tiles are computed
// on the *input* (diff_src) and the GEMM roles of ic/oc are swapped
// relative to forward (M = ic, N = ntiles, K = oc).
598 status_t jit_avx512_common_conv_winograd_bwd_data_kernel_f32::init_conf(
599         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
600         const memory_desc_wrapper &diff_src_d,
601         const memory_desc_wrapper &weights_d,
602         const memory_desc_wrapper &diff_dst_d)
604     status_t st = init_conf_common(jcp, cd, diff_src_d, weights_d, diff_dst_d);
606     if (st != status::success)
609     jcp.itiles = (jcp.iw + tile_size - 1) / tile_size;
610     jcp.jtiles = (jcp.ih + tile_size - 1) / tile_size;
611     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
613     status_t res = init_conf_kernel(jcp, jcp.ic, jcp.ntiles, jcp.oc);
614     jcp.oc_simd_block = jcp.dimK_reg_block;
615     jcp.oc_block = jcp.dimK_block;
616     jcp.nb_oc = jcp.dimK_nb_block;
617     jcp.ic_simd_block = jcp.dimM_simd_block;
618     jcp.ic_block = jcp.dimM_block;
619     jcp.nb_ic = jcp.dimM_nb_block;
620     jcp.tile_block_ur = jcp.dimN_reg_block;
621     jcp.nb_tile_block_ur = jcp.dimN_block;
622     jcp.tile_block = jcp.dimN_nb_block;
623     jcp.tile_4fma_padding = 0; // only relevant for backward weights
// Emits the B-matrix transpose kernel for the 4fma backward-weights path:
// for each of the alpha x alpha Winograd components it loads groups of 4
// rows, transposes 4x4 float blocks via vunpck{l,h}ps / vunpck{l,h}pd,
// and streams the result to the transB buffer. Loads are double-buffered
// across registers 0..7 (curr/next) to hide latency.
// NOTE(review): several lines (curr updates, store operands, epilogue) are
// elided in this extraction; code left byte-identical.
628 void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::transpose_ker_generate()
630     auto load_B = [=](int reg_idx, int offset) {
631         for (int i = 0; i < 4; i++) {
632             vmovups(Zmm(reg_idx + i), zword[reg_origB + (offset + i) * jcp.dimN_reg_block * sizeof(float)]);
638     for (int j = 0; j < alpha; j++) {
639         for (int i = 0; i < alpha; i++) {
640             int origB_offset = (j * alpha + i) * jcp.dimK_4fma;
641             size_t transB_offset = (size_t)(j * alpha + i) * jcp.dimK_nb_block *
642                 jcp.dimN_block * jcp.dimK_block * jcp.dimK_reg_block *
643                 jcp.dimK_4fma * jcp.dimN_reg_block * sizeof(float);
644             mov(reg_transB_idx, transB_offset);
645             for (int tb = 0; tb < jcp.dimK_4fma; tb+=4) {
646                 /*double buffering to hide load latencies*/
647                 int next = (curr + 4) % 8;
648                 if (i == 0 && tb == 0) {
649                     load_B(0, origB_offset);
651                 if (tb + 4 < (jcp.dimK_4fma -1)) {
652                     load_B(next, origB_offset + 4);
653                 } else if (i < alpha - 1) {
654                     load_B(next, origB_offset + jcp.dimK_4fma);
// 4x4 transpose of the curr..curr+3 rows using zmm8/zmm9 as scratch.
657                 vunpcklps(Zmm(8), Zmm(curr), Zmm(curr + 1));
658                 vunpcklps(Zmm(9), Zmm(curr + 2), Zmm(curr + 3));
659                 vunpckhps(Zmm(curr), Zmm(curr), Zmm(curr + 1));
660                 vunpckhps(Zmm(curr + 1), Zmm(curr + 2), Zmm(curr + 3));
662                 vunpcklpd(Zmm(curr + 2), Zmm(8), Zmm(9));
663                 vunpckhpd(Zmm(curr + 3), Zmm(8), Zmm(9));
665                 vunpcklpd(Zmm(8), Zmm(curr), Zmm(curr + 1));
666                 vunpckhpd(Zmm(9), Zmm(curr), Zmm(curr + 1));
// Non-temporal stores: transB is written once and consumed later.
668                 vmovntps(zword[reg_transB + reg_transB_idx
669                         + sizeof(float) * tb * jcp.dimN_reg_block],
671                 vmovntps(zword[reg_transB + reg_transB_idx
672                         + sizeof(float) * (tb + 1) * jcp.dimN_reg_block],
674                 vmovntps(zword[reg_transB + reg_transB_idx
675                         + sizeof(float) * (tb + 2) * jcp.dimN_reg_block],
677                 vmovntps(zword[reg_transB + reg_transB_idx
678                         + sizeof(float) * (tb + 3) * jcp.dimN_reg_block],
// Emits the backward-weights GEMM: accumulates the weight-update tile
// U += M * broadcast(V) with accumulators indexed by dimN_reg_block,
// double-buffered A loads, interleaved L1/L2 software prefetch of srcB,
// and 4fma-specific transposed-B addressing on KNM. Loop nest is
// dimM_block -> dimN_block -> dimK_block -> (dimK_reg_block x dimK_4fma).
// NOTE(review): many emitter lines are elided in this extraction (label
// placements, else-branches, pointer-rewind tails); code left untouched.
688 void jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::gemm_loop_generate(
691     //       for (int ofm2 = 0; ofm2 < jcp.oc_block; ofm2++)
692     //           for (int ifm2 = 0; ifm2 < jcp.ic_block; ifm2++)
693     //               for (int nb_tile_block_ur = 0; nb_tile_block_ur <
694     //               jcp.nb_tile_block_ur; nb_tile_block_ur++)
695     //                   for (int tile_block_ur = 0; tile_block_ur <
696     //                   jcp.tile_block_ur; tile_block_ur++)
697     //                       for (int ifm3 = 0; ifm3 < jcp.ic_reg_block; ++ifm3)
698     //                           U[ofm2][ifm2][ofm3][ifm3][0:oc_simd_block] +=
699     //                               M[ofm2][ofm3][nb_tile_block_ur][tile_block_ur][0:oc_simd_block]
701     //                                broadcast(V[ifm2][nb_tile_block_ur][ifm3][tile_block_ur])
702     auto inner_loops = [=]() {
703         int inc_fma = jcp.ver == ver_4fma ? 4 : 1;
704         const int fma_ipc = jcp.ver == ver_4fma ? 1 : 2;
705         prefetcher_t<float> L1_pf(this, reg_srcB, L1,
706                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
707                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
710         prefetcher_t<float> L2_pf(this, reg_srcB, L2,
711                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma,
712                 jcp.dimK_reg_block * jcp.dimN_reg_block * jcp.dimK_4fma
716         auto load_A = [=](int reg_idx, int offset) {
717             for (int i = 0; i < inc_fma; i++) {
718                 vmovups(Zmm(reg_idx + i),
720                                 sizeof(float) * jcp.dimM_simd_block * (offset + i)]);
724         Label dimM_block_loop, dimK_block_loop, dimN_block_loop;
725         if (jcp.dimM_block > 1) {
726             mov(reg_dimM_block_loop_cnt, jcp.dimM_block);
729             { /************* OC_block (M) loop ***********/
730                 if (jcp.dimN_block > 1) {
731                     mov(reg_dimN_block_loop_cnt, jcp.dimN_block);
734                 { /*************** IC_block (N) loop *********/
// Zero or reload the U accumulators depending on the first-iteration flag.
735                     for (int dimN_reg_block = 0;
736                             dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
737                         Zmm zmm(jcp.zmm_start + dimN_reg_block);
739                             vpxord(zmm, zmm, zmm);
741                             vmovups(zmm, zword[reg_dstC +
742                                     dimN_reg_block * jcp.dimM_simd_block *
746                     if (jcp.dimK_block > 1) {
747                         mov(reg_dimK_block_loop_cnt, jcp.dimK_block);
750                     { /************* nb_tile_ur(K) loop ********/
752                         if (jcp.double_buffering) {
755                         for (int dimK_reg_block = 0;
756                                 dimK_reg_block < jcp.dimK_reg_block;
758                             int srcB_offset = dimK_reg_block * jcp.dimK_4fma
759                                     * jcp.dimN_reg_block;
760                             for (int dimK_4fma = 0; dimK_4fma < jcp.dimK_4fma;
761                                     dimK_4fma += inc_fma) {
763                                 if (jcp.double_buffering) {
764                                     next = (dimK_reg_block * jcp.dimK_4fma
765                                                    + dimK_4fma + inc_fma)
767                                     load_A(next, dimK_reg_block * jcp.dimK_4fma
768                                                     + dimK_4fma + inc_fma);
771                                     load_A(next, dimK_reg_block * jcp.dimK_4fma
774                                 for (int dimN_reg_block = 0;
775                                         dimN_reg_block < jcp.dimN_reg_block;
777                                     L1_pf.prefetch(srcB_offset / inc_fma
778                                             + dimK_4fma / inc_fma
781                                     L2_pf.prefetch(srcB_offset / inc_fma
782                                             + dimK_4fma / inc_fma
// 4fma path reads B from the transposed layout built by transpose_ker.
785                                     if (jcp.ver == ver_4fma) {
786                                         int srcB_trans_offset = (dimK_4fma / 4) * 64
789                                                 Zmm(jcp.zmm_start + dimN_reg_block),
791                                                 EVEX_compress_addr(reg_srcB,
795                                                         (dimN_reg_block % 4) * 16 +
796                                                         (dimN_reg_block / 4) * 4)));
799                                                 Zmm(jcp.zmm_start + dimN_reg_block),
801                                                 EVEX_compress_addr(reg_srcB,
802                                                         sizeof(float) * (srcB_offset + dimN_reg_block),
810                         add(reg_srcA, jcp.dimK_reg_block * jcp.dimK_4fma
811                                         * jcp.dimM_simd_block * sizeof(float));
812                         add(reg_srcB, jcp.dimK_reg_block * jcp.dimN_reg_block
813                                         * jcp.dimK_4fma * sizeof(float));
814                         if (jcp.dimK_block > 1) {
815                             sub(reg_dimK_block_loop_cnt, 1);
816                             jnz(dimK_block_loop);
819                     /******** Write C back to memory *******/
820                     for (int dimN_reg_block = 0;
821                             dimN_reg_block < jcp.dimN_reg_block; ++dimN_reg_block) {
822                         Zmm zmm(jcp.zmm_start + dimN_reg_block);
823                         vmovups(zword[reg_dstC +
824                                 dimN_reg_block * jcp.dimM_simd_block * sizeof(float)],
// Rewind A, advance C for the next N block; rewind B/advance A per M block.
828                     sub(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block *
829                                     jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
830                     add(reg_dstC, jcp.dimN_reg_block * jcp.dimM_simd_block
832                     if (jcp.dimN_block > 1) {
833                         sub(reg_dimN_block_loop_cnt, 1);
834                         jnz(dimN_block_loop);
838             if (jcp.dimM_block > 1) {
839                 sub(reg_srcB, jcp.dimN_block * jcp.dimK_block
840                                 * jcp.dimK_reg_block * jcp.dimN_reg_block
841                                 * jcp.dimK_4fma * sizeof(float));
842                 add(reg_srcA, jcp.dimK_block * jcp.dimK_reg_block
843                                 * jcp.dimK_4fma * jcp.dimM_simd_block * sizeof(float));
844                 sub(reg_dimM_block_loop_cnt, 1);
845                 jnz(dimM_block_loop);
851     // register used to handle long fma encoding
853     mov(reg_srcA, reg_srcA_const);
862 bool check_cond1_wu(int dimM_block, int dimM_simdw, int dimK_block,
863 int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
865 float lhs = 1.0f * dimM_block * dimN_reg_block * dimM_simdw;
866 lhs += dimM_block * dimK_block * dimK_reg_block * dimK_4fma * dimM_simdw;
867 lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
868 lhs *= sizeof(float);
869 float rhs = C * L1_cache_size;
873 bool check_cond1bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
874 int dimK_reg_block, int dimK_4fma, int dimN_reg_block, float C)
876 float lhs = 1.0f * dimM_block * dimK_block * dimK_reg_block * dimK_4fma
878 lhs += dimK_block * dimN_reg_block * dimK_reg_block * dimK_4fma;
879 lhs *= sizeof(float);
880 float rhs = C * L1_cache_size;
884 bool check_cond2bis_wu(int dimM_block, int dimM_simdw, int dimK_block,
885 int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
888 float lhs = 1.0f * dimM_block * dimM_simdw * dimK_block * dimK_reg_block
890 lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
892 lhs *= sizeof(float);
893 float rhs = C * L2_cache_size;
897 bool check_cond2_wu(int dimM_block, int dimM_simdw, int dimK_block,
898 int dimK_reg_block, int dimK_4fma, int dimN_block, int dimN_reg_block,
901 float lhs = 1.0f * dimM_block * dimM_simdw * dimN_block * dimN_reg_block;
902 lhs += dimM_block * dimM_simdw * dimK_block * dimK_reg_block * dimK_4fma;
903 lhs += dimK_block * dimK_reg_block * dimK_4fma * dimN_block
905 lhs *= sizeof(float);
906 float rhs = C * L2_cache_size;
// Selects the WEI_S_D_G_W blocking for backward weights: N register block
// pinned to ic_simd_block (16), K blocking chosen via the *_wu cache-fit
// conditions (L2 first, then L1 for the register sub-block), then N and M
// blocks. Also mirrors the K blocking into the tile_block* fields since
// K = tiles for this pass.
911 status_t set_wsched_WEI_S_D_G_W_avx512_common(jit_conv_winograd_conf_t &jcp)
913     /*************** Choose dimN_reg_block (ic_simd_block)
914      * *******************************/
916     /*Hardcoded to 16 because N = ic for bwd weights and
917       innermost dimension for ic is assumed 16 in src transforms. This
918       choice covers load latencies while maintaining simplicity of kernel
919       for POR topologies. FIXME in future??: Will not work for future topologies
921     jcp.dimN_reg_block = jcp.ic_simd_block;
923     /****************************** Choose dimK_block
924      * **************************/
925     // No freedom for choosing dimM_simd_block because ic_simd_block
926     // is determined by input data format
927     jcp.dimM_simd_block = jcp.oc_simd_block;
929     auto test_cond1bis_dimK_block = [](
930             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
931         return check_cond1bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
932                        jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
933                 && (dimK_block > current_best);
936     auto test_cond1_dimK_block = [](
937             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
938         return check_cond1_wu(1, jcp.dimM_simd_block, dimK_block, 1,
939                        jcp.dimK_4fma, jcp.dimN_reg_block, 0.4f)
940                 && (dimK_block > current_best);
943     auto test_cond2bis_dimK_block = [](
944             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
945         return check_cond2bis_wu(1, jcp.dimM_simd_block, dimK_block, 1,
946                        jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.5f)
947                 && (dimK_block > current_best);
950     auto test_cond2_dimK_block = [](
951             jit_conv_winograd_conf_t &jcp, int dimK_block, int current_best) {
952         return check_cond2_wu(1, jcp.dimM_simd_block, dimK_block, 1,
953                        jcp.dimK_4fma, 1, jcp.dimN_reg_block, 0.1f)
954                 && (dimK_block > current_best);
// L2-level choice first (bis = streaming variant, then with accumulators).
957     jcp.dimK_block = get_divisor_satisfying_cond(
958             jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2bis_dimK_block);
959     if (jcp.dimK_block < jcp.dimK / jcp.dimK_4fma)
960         jcp.dimK_block = get_divisor_satisfying_cond(
961                 jcp, jcp.dimK / jcp.dimK_4fma, 1, test_cond2_dimK_block);
// Then the register-level sub-block from the L1 conditions.
963     jcp.dimK_reg_block = get_divisor_satisfying_cond(
964             jcp, jcp.dimK_block, 1, test_cond1bis_dimK_block);
965     if (jcp.dimK_reg_block < jcp.dimK_block) {
966         jcp.dimK_reg_block = get_divisor_satisfying_cond(
967                 jcp, jcp.dimK_block, 1, test_cond1_dimK_block);
969     jcp.dimK_block /= jcp.dimK_reg_block;
971             = jcp.dimK / jcp.dimK_4fma / jcp.dimK_reg_block / jcp.dimK_block;
972     jcp.tile_block_ur = jcp.dimK_reg_block;
973     jcp.nb_tile_block_ur = jcp.dimK_block;
974     jcp.tile_block = jcp.dimK_nb_block;
976     /***************************** Chose dimN block
977      * ****************************/
978     auto test_cond2_dimN_block = [](
979             jit_conv_winograd_conf_t &jcp, int dimN_block, int current_best) {
980         return check_cond2_wu(1, jcp.dimM_simd_block, jcp.dimK_block,
981                        jcp.dimK_reg_block, jcp.dimK_4fma, dimN_block,
982                        jcp.dimN_reg_block, 0.5f)
983                 && (dimN_block > current_best);
986     jcp.dimN_block = get_divisor_satisfying_cond(
987             jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond2_dimN_block);
988     jcp.ic_block = jcp.dimN_block;
989     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_reg_block / jcp.dimN_block;
990     jcp.nb_ic = jcp.dimN_nb_block;
992     /********************************* Choose dimM block
993      * ************************/
996     auto test_cond1_dimM_block = [](
997             jit_conv_winograd_conf_t &jcp, int dimM_block, int current_best) {
998         return check_cond1_wu(dimM_block, jcp.dimM_simd_block, 1,
999                        jcp.dimK_reg_block, jcp.dimK_4fma, jcp.dimN_reg_block,
1001                 && (dimM_block > current_best)
// keep at least 2 M macro-blocks so the outer parallelism is not starved
1002                 && (jcp.dimM / jcp.dimM_simd_block / dimM_block) >= 2;
1005     jcp.dimM_block = get_divisor_satisfying_cond(
1006             jcp, jcp.dimM / jcp.dimM_simd_block, 1, test_cond1_dimM_block);
1007     jcp.dimM_nb_block = (jcp.dimM / jcp.dimM_simd_block) / jcp.dimM_block;
1009     jcp.sched_policy = WSCHED_WEI_S_D_G_W;
1010     return status::success;
// Backward-weights configuration: fills jcp from the descriptors, applies
// the same 3x3/stride-1/nChw16c restrictions as the data kernels, chooses
// the 4fma K unroll (dimK_4fma) with tile padding when tiles don't divide
// evenly, reserves double-buffering registers, and runs the WEI_S_D_G_W
// scheduling pass. GEMM dims: M = oc, N = ic, K = mb * tiles (padded).
1013 status_t jit_avx512_common_conv_winograd_bwd_weights_kernel_f32::init_conf(
1014         jit_conv_winograd_conf_t &jcp, const convolution_desc_t &cd,
1015         const memory_desc_wrapper &src_d, const memory_desc_wrapper &diff_dst_d,
1016         const memory_desc_wrapper &diff_weights_d)
1018     jcp.nthr = mkldnn_get_max_threads();
1020     const bool with_groups = diff_weights_d.ndims() == src_d.ndims() + 1;
1022     jcp.ngroups = with_groups ? diff_weights_d.dims()[0] : 1;
1023     jcp.mb = src_d.dims()[0];
1024     jcp.oc = diff_dst_d.dims()[1] / jcp.ngroups;
1025     jcp.oc_without_padding = jcp.oc;
1026     jcp.ic = src_d.dims()[1] / jcp.ngroups;
1027     jcp.ih = src_d.dims()[2];
1028     jcp.iw = src_d.dims()[3];
1029     jcp.oh = diff_dst_d.dims()[2];
1030     jcp.ow = diff_dst_d.dims()[3];
1031     jcp.kh = diff_weights_d.dims()[with_groups + 2];
1032     jcp.kw = diff_weights_d.dims()[with_groups + 3];
1033     jcp.t_pad = cd.padding[0][0];
1034     jcp.l_pad = cd.padding[0][1];
1035     jcp.stride_h = cd.strides[0];
1036     jcp.stride_w = cd.strides[1];
1037     jcp.r_pad = nstl::max(
1038             0, (jcp.ow - 1) * jcp.stride_w + jcp.kw - jcp.iw - jcp.l_pad);
1039     jcp.b_pad = nstl::max(
1040             0, (jcp.oh - 1) * jcp.stride_h + jcp.kh - jcp.ih - jcp.t_pad);
1041     jcp.ihp = jcp.ih + jcp.t_pad + jcp.b_pad;
1042     jcp.iwp = jcp.iw + jcp.l_pad + jcp.r_pad;
1045     jcp.with_bias = (cd.diff_bias_desc.format != memory_format::undef);
1046     jcp.dilate_h = cd.dilates[0];
1047     jcp.dilate_w = cd.dilates[1];
1049     bool ok_to_pad_channels = jcp.ngroups == 1;
1050     if (ok_to_pad_channels) {
1051         jcp.oc = rnd_up(jcp.oc, simd_w);
1052         jcp.ic = rnd_up(jcp.ic, simd_w);
// avx512_core parts are handled elsewhere; this kernel needs avx512_common.
1055     if (mayiuse(avx512_core))
1056         return status::unimplemented;
1057     if (!mayiuse(avx512_common))
1058         return status::unimplemented;
1059     else if (mayiuse(avx512_mic_4ops))
1064     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
1065             is_winograd_faster_than_direct(jcp)))
1066         return status::unimplemented;
1067     // Winograd specific initialization
1068     jcp.itiles = (jcp.ow + tile_size - 1) / tile_size;
1069     jcp.jtiles = (jcp.oh + tile_size - 1) / tile_size;
1070     jcp.ntiles = jcp.mb * jcp.itiles * jcp.jtiles;
1072     // Winograd kernel works only for 3x3 convolution with stride 1
1073     if (jcp.ngroups != 1)
1074         return status::unimplemented;
1075     if ((jcp.kh != 3) || (jcp.kw != 3))
1076         return status::unimplemented;
1077     if ((jcp.dilate_h != 0) || (jcp.dilate_w != 0))
1078         return status::unimplemented;
1079     if ((jcp.stride_h != 1) || (jcp.stride_w != 1))
1080         return status::unimplemented;
1081     if ((jcp.ic % simd_w) != 0 || (jcp.oc % simd_w) != 0)
1082         return status::unimplemented;
1083     if (src_d.format() != nChw16c)
1084         return status::unimplemented;
1085     if (diff_weights_d.format() != (with_groups ? gOIhw16i16o : OIhw16i16o))
1086         return status::unimplemented;
1087     if (diff_dst_d.format() != nChw16c)
1088         return status::unimplemented;
1090     bool layout_consistency = true
1091             && jcp.ic <= src_d.blocking_desc().padding_dims[1]
1092             && jcp.oc <= diff_dst_d.blocking_desc().padding_dims[1]
1093             && jcp.ic <= diff_weights_d.blocking_desc().padding_dims[with_groups + 1]
1094             && jcp.oc <= diff_weights_d.blocking_desc().padding_dims[with_groups + 0];
1095     if (!layout_consistency) return status::unimplemented;
1097     /*************************** New Kernel Parameters
1098      * *****************************/
1099     jcp.ic_simd_block = simd_w;
1100     jcp.oc_simd_block = simd_w;
1102     jcp.tile_4fma_padding = 0;
1104 #define MAX_4FMA_UR 8
// On 4fma HW, pick the largest multiple-of-4 divisor of the tile count
// (capped at MAX_4FMA_UR) as the K unroll; pad tiles when none divides.
1105     if (jcp.ver == ver_4fma) {
1106         auto test_cond_4fma = [](jit_conv_winograd_conf_t &jcp, int dimK_4fma,
1108             return (dimK_4fma % 4 == 0) && (dimK_4fma <= MAX_4FMA_UR)
1109                     && (dimK_4fma > current_best);
1111         jcp.dimK_4fma = get_divisor_satisfying_cond(
1112                 jcp, jcp.itiles * jcp.jtiles, 4, test_cond_4fma);
1113         if (jcp.dimK_4fma == 1)
1115         if ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma != 0)
1116             jcp.tile_4fma_padding = jcp.dimK_4fma
1117                     - ((jcp.itiles * jcp.jtiles) % jcp.dimK_4fma);
1120     jcp.tile_4fma = jcp.dimK_4fma;
1121     /*NOTE: When (itiles * jtiles) % dimK_4fma != 0, transpose in diff_src
1123      * will not work correctly, this is solved by applying padding.*/
1124     jcp.dimK = jcp.mb * (jcp.itiles * jcp.jtiles + jcp.tile_4fma_padding);
1128     jcp.double_buffering = true;
1129     if (jcp.double_buffering)
1130         jcp.zmm_start = jcp.ver == ver_4fma ? 8 : 2;
1132         jcp.zmm_start = jcp.ver == ver_4fma ? 4 : 1;
1133     jcp.nb_reg = 32 - jcp.zmm_start;
1135     jcp.sched_policy = WSCHED_INVALID;
1136     status_t res = set_wsched_WEI_S_D_G_W_avx512_common(jcp);
1137     assert(jcp.sched_policy == WSCHED_WEI_S_D_G_W);
// Map the chosen GEMM blocking back onto convolution-level fields.
1139     jcp.tile_block_ur = jcp.dimK_reg_block;
1140     jcp.nb_tile_block_ur = jcp.dimK_block;
1141     jcp.tile_block = jcp.dimK_nb_block;
1143     jcp.ic_block = jcp.dimN_block;
1144     jcp.nb_ic = jcp.dimN_nb_block;
1146     jcp.oc_block = jcp.dimM_block;
1147     jcp.nb_oc = jcp.dimM_nb_block;
1156 // vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s