/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "cpu_convolution_pd.hpp"
#include "cpu_engine.hpp"
#include "mkldnn_thread.hpp"
#include "type_helpers.hpp"

#include "jit_avx512_core_fp32_wino_conv_2x3.hpp"
#include "jit_generator.hpp"
namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::memory_tracking::names;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
/// SRC TRANSFORMS /////////////////////////////////////////////////////////////
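/* A sketch of the math the kernels below implement: the standard Winograd
 * F(2x2, 3x3) minimal filtering scheme (Lavin & Gray). Each 4x4 input tile d
 * is transformed as V = B^T * d * B, with
 *
 *         [ 1  0 -1  0 ]
 *   B^T = [ 0  1  1  0 ]
 *         [ 0 -1  1  0 ]
 *         [ 0  1  0 -1 ]
 *
 * applied along rows and then along columns over 16-channel blocks. */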
struct jit_avx512_core_fp32_wino_conv_2x3_src_trans_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_fp32_wino_conv_2x3_src_trans_t)

    jit_conv_conf_2x3_wino_t jcp;

    struct call_params_t {
        const void *src;
        const void *wino_src;
        const void *v_y_masks;
        const void *v_x_masks;
    };

    void (*ker_)(const call_params_t *);

    jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
            jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(
                const_cast<uint8_t *>(getCode()));
    }

    void generate();
    Zmm vreg_inp(int i) {
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }
    Zmm vreg_tmp(int i) {
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(15 - i);
    }
    Zmm vreg_out(int i) {
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }
    Opmask y_mask = Opmask(1);
    Opmask r_mask = Opmask(2);
    Opmask x_mask(int id) {
        assert(id < jcp.alpha);
        return Opmask(3 + id);
    }

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_ic_block = r8;
};
void jit_avx512_core_fp32_wino_conv_2x3_src_trans_t::generate() {
    Label ic_block_label;

    const int load_block = 16;
    int out_offset = 0, inp_offset = 0;

    preamble();

#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_aux_ptr_src, src);
    READ_PARAM(reg_aux_ptr_dst, wino_src);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
#undef READ_PARAM

    for (int i = 0; i < jcp.alpha; i++) {
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);
    }

    mov(reg_ic_block, jcp.ic / load_block);
    L(ic_block_label);
    {
        for (int y = 0; y < jcp.alpha; y++) {
            kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
            for (int x = 0; x < jcp.alpha; x++) {
                Zmm zmm = vreg_inp(y * jcp.alpha + x);

                vxorps(zmm, zmm, zmm);
                kandw(r_mask, y_mask, x_mask(x));
                inp_offset = sizeof(float)
                        * ((-jcp.t_pad + y) * jcp.iw * load_block
                                + (-jcp.l_pad + x) * load_block);
                vmovups(zmm | r_mask,
                        EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
            }
        }
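        /* Row pass then column pass of the input transform: the first loop
         * computes tmp = d * B (combining along x within each tile row), the
         * second computes out = B^T * tmp (combining along y), so that
         * out = B^T * d * B with the B^T given above. */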
        for (int y = 0; y < jcp.alpha; y++) {
            vsubps(vreg_tmp(y * jcp.alpha + 0), vreg_inp(y * jcp.alpha + 0),
                    vreg_inp(y * jcp.alpha + 2));
            vaddps(vreg_tmp(y * jcp.alpha + 1), vreg_inp(y * jcp.alpha + 1),
                    vreg_inp(y * jcp.alpha + 2));
            vsubps(vreg_tmp(y * jcp.alpha + 2), vreg_inp(y * jcp.alpha + 2),
                    vreg_inp(y * jcp.alpha + 1));
            vsubps(vreg_tmp(y * jcp.alpha + 3), vreg_inp(y * jcp.alpha + 1),
                    vreg_inp(y * jcp.alpha + 3));
        }
        for (int x = 0; x < jcp.alpha; x++) {
            vsubps(vreg_out(x + 0 * jcp.alpha), vreg_tmp(x + jcp.alpha * 0),
                    vreg_tmp(x + jcp.alpha * 2));
            vaddps(vreg_out(x + 1 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
                    vreg_tmp(x + jcp.alpha * 2));
            vsubps(vreg_out(x + 2 * jcp.alpha), vreg_tmp(x + jcp.alpha * 2),
                    vreg_tmp(x + jcp.alpha * 1));
            vsubps(vreg_out(x + 3 * jcp.alpha), vreg_tmp(x + jcp.alpha * 1),
                    vreg_tmp(x + jcp.alpha * 3));
        }
        for (int i = 0; i < 16; i++) {
            out_offset = sizeof(float) * (jcp.inp_stride * i);
            vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
                    vreg_out(i));
        }

        add(reg_aux_ptr_src, sizeof(float) * jcp.ih * jcp.iw * load_block);
        add(reg_aux_ptr_dst, sizeof(float) * load_block);
    }
    dec(reg_ic_block);
    cmp(reg_ic_block, 0);
    jg(ic_block_label, T_NEAR);

    postamble();
}
/// DST TRANSFORMS /////////////////////////////////////////////////////////////
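/* Output transform: each 4x4 Winograd-domain tile M is reduced to a 2x2
 * output tile as Y = A^T * M * A, with
 *
 *   A^T = [ 1  1  1  0 ]
 *         [ 0  1 -1 -1 ]
 *
 * followed by bias, per-channel scaling and the ReLU/sum post-ops. */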
struct jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t)

    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;

    struct call_params_t {
        const void *wino_dst;
        const void *dst;
        const void *v_y_masks;
        const void *v_x_masks;
        const void *bias;
        const void *scales;
    };

    void (*ker_)(const call_params_t *);

    jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
            jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(
                const_cast<uint8_t *>(getCode()));
    }

    void generate();
    bool maybe_relu(int position);

    Zmm vreg_inp(int i) { // 16
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }

    Zmm vreg_stg(int id) { // 8
        const int id_reg_stg = jcp.alpha * jcp.alpha + id;
        assert(id_reg_stg < jcp.alpha * jcp.alpha + 8);
        return Zmm(31 - id_reg_stg);
    }

    Zmm vreg_out(int id) { // 4
        const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
        assert(id_reg_out < jcp.alpha * jcp.alpha + 12);
        return Zmm(31 - id_reg_out);
    }

    Zmm vreg_tmp(int id) { // 2
        const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
        assert(id_reg_tmp < jcp.alpha * jcp.alpha + 14);
        return Zmm(31 - id_reg_tmp);
    }
    Zmm vreg_zero = Zmm(0);
    Zmm vreg_prev_dst = Zmm(0);
    Zmm vreg_bias = Zmm(2);

    Opmask y_mask = Opmask(1);
    Opmask r_mask = Opmask(2);
    Opmask x_mask(int id) {
        assert(id < jcp.alpha);
        return Opmask(3 + id);
    }

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_oc_block = r8;

    Reg64 reg_ptr_bias = rbx;
    Reg64 reg_ptr_scales = abi_not_param1;
    Reg64 reg_ptr_sum_scale = rdx;
};
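/* maybe_relu(0) asks whether a ReLU must be applied before the sum post-op,
 * maybe_relu(1) whether one must be applied after it. For example, the
 * post-op chain {relu, sum, relu} makes both return true, while a plain
 * {sum, relu} only sets the latter. */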
bool jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::maybe_relu(int position) {
    using namespace primitive_kind;
    const auto &p = attr_.post_ops_;

    if (position == 0) {
        /* relu before sum */
        return false
                || p.contain(eltwise, 0);
    } else if (position == 1) {
        /* relu after sum */
        const int sum_idx = p.contain(sum, 0)
                ? 0 : (p.contain(sum, 1) ? 1 : -1);
        if (sum_idx == -1)
            return false;

        return false
                || p.contain(eltwise, sum_idx + 1);
    }

    return false;
}
void jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t::generate() {
    Label oc_block_label;

    const int load_block = 16;

    preamble();

    auto loop_body = [=]() {
        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale = (sum_idx != -1)
                ? &p.entry_[sum_idx].sum.scale : nullptr;
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_ptr_sum_scale, (size_t)p_sum_scale);

        for (int i = 0; i < 16; i++) {
            int internal_offset = sizeof(float) * jcp.out_stride * i;
            vmovups(vreg_inp(i),
                    EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
        }
        for (int y = 0; y < jcp.alpha; y++) {
            vaddps(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
            vaddps(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));

            vsubps(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
            vsubps(vreg_stg(y * 2 + 1), vreg_tmp(1), vreg_inp(y * 4 + 3));
        }

        for (int x = 0; x < jcp.m; x++) {
            vaddps(vreg_tmp(0), vreg_stg(x), vreg_stg(x + 2 * 1));
            vaddps(vreg_out(x), vreg_tmp(0), vreg_stg(x + 2 * 2));

            vsubps(vreg_tmp(1), vreg_stg(x + 2 * 1), vreg_stg(x + 2 * 2));
            vsubps(vreg_out(x + 2), vreg_tmp(1), vreg_stg(x + 2 * 3));
        }

        auto bias_addr = ptr[reg_ptr_bias];
        vmovups(vreg_bias, bias_addr);
        for (int y = 0; y < jcp.m; y++) {
            kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(int16_t) * y]);
            for (int x = 0; x < jcp.m; x++) {
                kandw(r_mask, y_mask, x_mask(x));

                int i = y * jcp.m + x;
                int offset = sizeof(float)
                        * (y * jcp.ow * jcp.oc_block + x * jcp.oc_block);
                Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);

                Zmm zmm = vreg_out(i);
                vaddps(zmm, zmm, vreg_bias);
                vmulps(zmm, zmm, ptr[reg_ptr_scales]);

                if (maybe_relu(0)) {
                    vxorps(vreg_zero, vreg_zero, vreg_zero);
                    vmaxps(zmm, vreg_zero, zmm);
                }
                if (p_sum_scale) { // post_op: sum
                    vxorps(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
                    vmovups(vreg_prev_dst | r_mask, addr);
                    if (*p_sum_scale == 1.f)
                        vaddps(zmm, vreg_prev_dst);
                    else
                        vfmadd231ps(zmm, vreg_prev_dst,
                                zword_b[reg_ptr_sum_scale]);
                }
                if (maybe_relu(1)) {
                    vxorps(vreg_zero, vreg_zero, vreg_zero);
                    vmaxps(zmm, vreg_zero, zmm);
                }

                vmovups(addr, zmm | r_mask);
            }
        }
    };
#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_aux_ptr_src, wino_dst);
    READ_PARAM(reg_aux_ptr_dst, dst);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
    READ_PARAM(reg_ptr_bias, bias);
    READ_PARAM(reg_ptr_scales, scales);
#undef READ_PARAM

    for (int i = 0; i < jcp.alpha * jcp.alpha; i++)
        vxorps(vreg_inp(i), vreg_inp(i), vreg_inp(i));

    for (int i = 0; i < jcp.alpha; i++)
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(int16_t) * i]);

    int oc_blocks = jcp.oc / load_block;
    mov(reg_oc_block, oc_blocks);
    L(oc_block_label);
    {
        loop_body();

        add(reg_aux_ptr_src, sizeof(float) * load_block);
        add(reg_aux_ptr_dst, sizeof(float) * jcp.oh * jcp.ow * load_block);

        add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
        add(reg_ptr_bias, jcp.typesize_bia * load_block);
    }
    dec(reg_oc_block);
    cmp(reg_oc_block, 0);
    jg(oc_block_label, T_NEAR);

    sub(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
    sub(reg_ptr_bias, oc_blocks * jcp.typesize_bia * load_block);

    postamble();
}
/// GEMM kernel ////////////////////////////////////////////////////////////////
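/* In the Winograd domain the convolution reduces to alpha^2 = 16 independent
 * GEMMs, one per transform point xi: M(xi) = V(xi) * U(xi), where V(xi) is
 * (M = tiles) x (K = ic), U(xi) is K x (N = oc), and M(xi) then feeds the
 * output transform above. The kernel below computes one such GEMM, blocked
 * into m_block rows of accumulators and n2_block column chunks of 16. */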
struct jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t: public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t)

    jit_conv_conf_2x3_wino_t jcp;

    struct call_params_t {
        const void *src;
        const void *dst;
        const void *wei;
    };

    void (*ker_)(const call_params_t *);

    void generate();
    static bool post_ops_ok(jit_conv_conf_2x3_wino_t &jcp,
            const primitive_attr_t &attr);

    jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
            jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp) {
        generate();
        ker_ = reinterpret_cast<decltype(ker_)>(
                const_cast<uint8_t *>(getCode()));
    }

    static status_t init_conf(
            jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
            cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &weights_pd,
            cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
            const primitive_attr_t &attr,
            memory_desc_t &expect_wei_md);
    Zmm vreg_out(int n, int m) {
        const int id_reg_out = n * jcp.m_block + m;
        assert(id_reg_out < jcp.n2_block * jcp.m_block);
        return Zmm(31 - id_reg_out);
    }
    Zmm vreg_wei(int i) {
        assert(31 - jcp.n2_block * jcp.m_block - i > 1);
        return Zmm(31 - jcp.n2_block * jcp.m_block - i);
    }

    Zmm vreg_src = Zmm(0);
    Zmm vreg_one = Zmm(1);
    Zmm vreg_tmp = Zmm(2);
    Reg64 reg_ptr_src = r15;

    Reg64 reg_aux_dst = r12;
    Reg64 reg_aux_dst2 = r11;
    Reg64 reg_aux_wei = r10;
    Reg64 reg_aux_wei2 = r9;
    Reg64 reg_aux_src = r8;
    Reg64 reg_aux_src2 = rax;
    Reg64 reg_mb = r13;
    Reg64 reg_nnb = r14;
    Reg64 reg_K = rbx;
};
bool jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::post_ops_ok(
        jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;

    auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };

    switch (p.len_) {
    case 0: return true;
    case 1: return is_relu(0) || p.contain(sum, 0);
    case 2: return (p.contain(sum, 0) && is_relu(1)) ||
                   (p.contain(sum, 1) && is_relu(0));
    case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
    default: return false;
    }
}
void jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::generate() {
    Label nnb_loop_label, K_loop_label, mb_loop_label;

    preamble();
#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, src);
    READ_PARAM(reg_aux_dst, dst);
    READ_PARAM(reg_aux_wei, wei);
#undef READ_PARAM

    mov(reg_nnb, jcp.n_chunks);
    L(nnb_loop_label);

    mov(reg_aux_dst2, reg_aux_dst);
    mov(reg_aux_src, reg_ptr_src);
    mov(reg_mb, jcp.M / jcp.m_block);
    L(mb_loop_label);
    for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
        for (int m = 0; m < jcp.m_block; m++) {
            vxorps(vreg_out(nb2, m), vreg_out(nb2, m), vreg_out(nb2, m));
        }
    }
    mov(reg_aux_src2, reg_aux_src);
    mov(reg_aux_wei2, reg_aux_wei);

    mov(reg_K, jcp.k_chunks);
    L(K_loop_label);
    for (int _i = 0; _i < jcp.k2_block; _i++) {
        int wei_offset = _i * jcp.n2_block * jcp.oc_block;
        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
            if (jcp.small_mb) {
                int wei_offset = sizeof(float)
                        * ((nb2 * jcp.nb_ic * jcp.ic_block
                                   * jcp.oc_block)
                                + _i * jcp.oc_block);
                vmovups(vreg_wei(nb2),
                        EVEX_compress_addr(reg_aux_wei2, wei_offset));
            } else {
                vmovups(vreg_wei(nb2),
                        EVEX_compress_addr(reg_aux_wei2,
                                sizeof(float) * wei_offset));
                wei_offset += jcp.oc_block;
            }
        }
        for (int m = 0; m < jcp.m_block; m++) {
            int inp_offset = sizeof(float) * (m * jcp.K + _i);
            if (jcp.n2_block > 1) {
                vbroadcastss(vreg_src,
                        EVEX_compress_addr(reg_aux_src2, inp_offset));
                for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
                    vfmadd231ps(vreg_out(nb2, m), vreg_wei(nb2),
                            vreg_src);
            } else {
                vfmadd231ps(vreg_out(0, m), vreg_wei(0),
                        EVEX_compress_addr(reg_aux_src2, inp_offset, true));
            }
        }
    }
    add(reg_aux_src2, sizeof(float) * jcp.ic_block);
    if (jcp.small_mb)
        add(reg_aux_wei2, sizeof(float) * jcp.oc_block * jcp.ic_block);
    else
        add(reg_aux_wei2,
                sizeof(float) * jcp.k2_block * jcp.n2_block
                        * jcp.oc_block);

    dec(reg_K);
    cmp(reg_K, 0);
    jg(K_loop_label, T_NEAR);
    for (int m = 0; m < jcp.m_block; m++) {
        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
            int offset = sizeof(float)
                    * (m * jcp.N + nb2 * jcp.oc_block);
            vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
                    vreg_out(nb2, m));
        }
    }
    add(reg_aux_src, sizeof(float) * jcp.m_block * jcp.K);
    add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N);
548 add(reg_aux_dst2, sizeof(float) * jcp.m_block * jcp.N);
552 jg(mb_loop_label, T_NEAR);
555 add(reg_aux_dst, sizeof(float) * jcp.n2_block * jcp.oc_block);
557 sizeof(float) * jcp.k_chunks * jcp.ic_block * jcp.n2_block
562 jg(nnb_loop_label, T_NEAR);
bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
    return jcp.mb >= 4;
}
status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
        jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
        cpu_memory_t::pd_t &src_pd, cpu_memory_t::pd_t &wei_pd,
        cpu_memory_t::pd_t &dst_pd, cpu_memory_t::pd_t &bias_pd,
        const primitive_attr_t &attr, memory_desc_t &expect_wei_md) {
    const memory_desc_wrapper src_d(&src_pd);
    const memory_desc_wrapper wei_d(&wei_pd);
    const memory_desc_wrapper dst_d(&dst_pd);
    const memory_desc_wrapper bias_d(&bias_pd);

    const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;

    jcp.nthr = mkldnn_get_max_threads();

    jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.oc_without_padding = jcp.oc;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = wei_d.dims()[with_groups + 2];
    jcp.kw = wei_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.b_pad = cd.padding[1][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.r_pad = cd.padding[1][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];
    jcp.m = 2;
    jcp.r = 3;
    jcp.alpha = jcp.m + jcp.r - 1;
    const int simdw = 16;

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;
    bool ok_to_pad_channels = jcp.ngroups == 1;
    if (ok_to_pad_channels) {
        jcp.oc = rnd_up(jcp.oc, simdw);
        jcp.ic = rnd_up(jcp.ic, simdw);
    }

    if (src_d.format() != nChw16c
            || dst_d.format() != nChw16c
            || !IMPLICATION(jcp.with_bias,
                       bias_d.format() == x))
        return status::unimplemented;
    jcp.ver = ver_avx512_core;
    if (!(mayiuse(avx512_core)))
        return status::unimplemented;

    if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
                is_winograd_faster_than_direct(jcp)))
        return status::unimplemented;

    if (src_d.data_type() != data_type::f32)
        return status::unimplemented;
    if (wei_d.data_type() != data_type::f32)
        return status::unimplemented;
    if (dst_d.data_type() != data_type::f32)
        return status::unimplemented;
    if (mayiuse(avx512_core_vnni))
        jcp.ver = ver_vnni;

    jcp.ic_block = simdw;
    jcp.oc_block = simdw;
    bool ok = true && jcp.kh == 3 && jcp.kw == 3 && jcp.ngroups == 1
            && jcp.oc % jcp.oc_block == 0 && jcp.ic % jcp.ic_block == 0
            && jcp.stride_h == 1 && jcp.stride_w == 1 && jcp.dilate_h == 0
            && jcp.dilate_w == 0 && jcp.t_pad == jcp.b_pad
            && jcp.l_pad == jcp.r_pad && jcp.t_pad < 2 && jcp.t_pad >= 0
            && jcp.l_pad < 2 && jcp.l_pad >= 0;
    if (!ok)
        return status::unimplemented;
    const int L2_cap = get_cache_size(2, true) / sizeof(float);
    const int L3_capacity = get_cache_size(3, false) / sizeof(float);

    int a = jcp.alpha, aa = a * a;
    int mb = jcp.mb;
    int ic = jcp.ic, oc = jcp.oc;
    int ih = jcp.ih, iw = jcp.iw;
    int oh = jcp.oh, ow = jcp.ow;
    auto wei_sz = (float)aa * ic * oc;
    auto inp_sz = (float)mb * ih * iw * ic;
    auto sp_sz = (float)mb * ih * iw;

    /* Heuristics. The numbers '28' and '196' are empirical observations
     * from measured data. */
    if (wei_sz / inp_sz > 5)
        jcp.small_mb = true;
    else
        jcp.small_mb = false;

    if ((mb > nstl::min(jcp.nthr, 28)
                && (wei_sz >= 0.9f * L2_cap
                        || inp_sz > L2_cap * jcp.nthr + L3_capacity))
            || (jcp.small_mb && sp_sz > 196))
        return status::unimplemented;
    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_bia
            = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;

    jcp.nb_oc = jcp.oc / jcp.oc_block;
    jcp.nb_ic = jcp.ic / jcp.ic_block;
    const int skx_free_regs = 30;
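    /* The blocking search below keeps the m_block x n2_block accumulator
     * grid plus one weight register per n2 chunk within the free zmm budget:
     * used_regs = in2 * im + in2. For example, im = 8 and in2 = 3 use
     * 3 * 8 + 3 = 27 of the 30 registers and score
     * (3 * 8) / (8 + 3) / 2.5 ~= 0.87. */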
    auto find_m_n2_blocks = [=](int xb, int yb, int &M, int &m_block,
            int &n2_block, float &reg_eff) {
        M = (xb * yb) / jcp.alpha;
        int max_m_block = m_block = nstl::min(M, skx_free_regs);
        int max_n2_block = n2_block = nstl::min(jcp.nb_oc, skx_free_regs);
        reg_eff = 0.f;
        for (int im = max_m_block; im > 0; im--) {
            for (int in2 = max_n2_block; in2 > 0; in2--) {
                int used_regs = in2 * im + in2;
                float cur_reg_eff = ((float)in2 * im) / (im + in2) / 2.5f;
                if (M % im || jcp.nb_oc % in2 || used_regs > skx_free_regs
                        || cur_reg_eff <= reg_eff)
                    continue;

                reg_eff = cur_reg_eff;
                m_block = im;
                n2_block = in2;
            }
        }
    };
    int nb_oc = jcp.nb_oc;
    int Z = ic + oc;
    int Y = ic * oc;
    const int L3_cap_per_core = get_cache_size(3, true) / sizeof(float);
    /* Selecting xb and yb blocking */
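    /* Every even tile-block candidate (yb, xb) is scored by a sum of
     * efficiency terms: work_eff penalizes blocks that round far past the
     * output, thr_eff measures how evenly the blocks divide across threads,
     * reg_eff comes from find_m_n2_blocks, and mem_eff discounts working
     * sets that spill out of L2. The small-mb path scores its inner
     * parallelization instead. */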
    int min_yb = jcp.alpha;
    int min_xb = jcp.alpha;
    int max_yb = nstl::max(min_yb, rnd_up(ih, 2));
    int max_xb = nstl::max(min_xb, rnd_up(iw, 2));
    float best_eff = 0.f;
    for (int ix = max_xb; ix >= min_xb; ix -= 2) {
        if (rnd_up(ow, ix) < iw - 2)
            continue;
        for (int iy = max_yb; iy >= min_yb; iy -= 2) {
            if (rnd_up(oh, iy) < ih - 2)
                continue;

            int ex_y = rnd_up(oh, iy);
            int ex_x = rnd_up(ow, ix);
            float work_eff = (float)(ih * iw) / (ex_y * ex_x);
            int M, m_block, n2_b;
            float reg_eff, thr_eff, par_eff, mem_eff, req_mem;

            find_m_n2_blocks(ix, iy, M, m_block, n2_b, reg_eff);

            /* outer parallelization */
            int nblocks = mb * div_up(ih, iy) * div_up(iw, ix);
            thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);

            mem_eff = 1.f;
            req_mem = (((float)ix + 2) * (iy + 2) + aa * M) * Z + aa * Y;
            if (req_mem > L2_cap / 2) {
                if (req_mem > ((L2_cap + L3_cap_per_core) * 4) / 7)
                    mem_eff /= (n2_b + 1) / 2.f;
                else
                    mem_eff /= (n2_b + 1) / 3.f;
            }
            float outer_eff = thr_eff + work_eff + reg_eff + mem_eff;

            /* inner parallelization */
            int bsz = iy * ix / a;
            int gemmw = aa * (nb_oc / n2_b);
            int bsz_r = rnd_up(bsz, jcp.nthr);
            int gemmw_r = rnd_up(gemmw, jcp.nthr);
            thr_eff = ((float)Z * bsz / bsz_r + Y * gemmw / gemmw_r) / (Z + Y);

            req_mem = (float)ix * iy * (ic + simdw * n2_b) + simdw * n2_b * ic;
            mem_eff = nstl::min(1.f, L2_cap / req_mem);
            int M_per_thr = nstl::max(2, div_up(aa, jcp.nthr));
            int oc_per_thr =
                    nstl::min(oc, div_up(aa * (nb_oc / n2_b), jcp.nthr));
            req_mem = (float)aa * oc_per_thr * ic + M_per_thr * M * Z;
            par_eff = 1.f;
            if (req_mem > L2_cap)
                par_eff = 1 / (2.f * nblocks);

            float inner_eff = thr_eff + work_eff + mem_eff + par_eff;
            float eff = jcp.small_mb ? inner_eff : outer_eff;
            if (eff > best_eff) {
                best_eff = eff;
                jcp.yb = iy;
                jcp.xb = ix;
                jcp.M = M;
                jcp.m_block = m_block;
                jcp.n2_block = n2_b;
            }
        }
    }
    assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);

    jcp.inp_stride = jcp.M * jcp.ic;
    jcp.out_stride = jcp.M * jcp.oc;
    jcp.wei_stride = jcp.ic * jcp.oc;
    jcp.bia_stride = jcp.oc;

    jcp.N = jcp.oc;
    jcp.K = jcp.ic;

    jcp.n_block = jcp.oc_block;
    jcp.k_block = jcp.ic_block;

    assert(jcp.M % jcp.m_block == 0);
    assert(jcp.nb_oc % jcp.n2_block == 0);

    jcp.n_chunks = jcp.nb_oc / jcp.n2_block;
    jcp.k2_block = jcp.ic_block;
    jcp.k_chunks = jcp.K / jcp.k2_block;

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;
    assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
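    /* mask_ == 1 << 1 requests one scale per output channel (dimension 1 of
     * the destination); mask_ == 0 means a single common scale for the whole
     * tensor. */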
    /* re-create the weights primitive descriptor
       and set the Winograd blocking */
    expect_wei_md.format = mkldnn_wino_fmt;
    expect_wei_md.data_type = data_type::f32;
    mkldnn_wino_desc_t &wd = expect_wei_md.layout_desc.wino_desc;
    wd.wino_format
            = jcp.small_mb ? mkldnn_wino_wei_aaOio : mkldnn_wino_wei_aaOBiOo;
    wd.r = jcp.r;
    wd.alpha = jcp.alpha;
    wd.ic = jcp.ic;
    wd.oc = jcp.oc;
    wd.ic_block = jcp.ic_block;
    wd.oc_block = jcp.oc_block;
    wd.oc2_block = jcp.n2_block;
    wd.ic2_block = 1;
    wd.adj_scale = 1.f;

    size_t max_size = sizeof(float) * jcp.alpha * jcp.alpha * jcp.ic * jcp.oc;
    wd.size = max_size;

    return status::success;
}
////////////////////////////////////////////////////////////////////////////////
status_t jit_avx512_core_fp32_wino_conv_2x3_fwd_t
        ::pd_t::jit_conf(memory_desc_t &expect_wei_md) {
    return jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t::init_conf(
            jcp_, *this->desc(), this->src_pd_, this->weights_pd_,
            this->dst_pd_, this->bias_pd_, *this->attr(), expect_wei_md);
}
jit_avx512_core_fp32_wino_conv_2x3_fwd_t::
        jit_avx512_core_fp32_wino_conv_2x3_fwd_t(const pd_t *apd,
                const input_vector &inputs, const output_vector &outputs)
    : cpu_primitive_t(apd, inputs, outputs)
{
    kernel_ = new jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t(
            pd()->jcp_, *pd()->attr());
    src_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_src_trans_t(
            pd()->jcp_, *pd()->attr());
    dst_trans_ = new jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t(
            pd()->jcp_, *pd()->attr());
}

jit_avx512_core_fp32_wino_conv_2x3_fwd_t
        ::~jit_avx512_core_fp32_wino_conv_2x3_fwd_t() {
    delete kernel_;
    delete src_trans_;
    delete dst_trans_;
}
void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward() const {
    const auto &jcp = kernel_->jcp;

    if (jcp.small_mb)
        execute_forward_small_mb();
    else
        execute_forward_mbN();
}
void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_mbN() const {
    auto src = reinterpret_cast<const float *>(input_memory(0));
    auto wei = reinterpret_cast<const float *>(input_memory(1));
    auto bia = reinterpret_cast<const float *>(input_memory(2));
    auto dst = reinterpret_cast<float *>(memory(0));

    auto scratchpad = this->scratchpad();

    const auto &jcp = kernel_->jcp;
    const auto &oscales = pd()->attr()->output_scales_;

    const size_t wino_size_offset =
            (size_t)(pd()->jcp_.yb / 2) * (pd()->jcp_.xb / 2) + (pd()->jcp_.xb);
    const size_t size_wino_src = wino_size_offset * pd()->jcp_.ic * 16;
    const size_t size_wino_dst = wino_size_offset * pd()->jcp_.oc * 16;
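    /* Each thread transforms one yb x xb block at a time, so it only needs
     * private V (Winograd-domain source) and M (Winograd-domain product)
     * buffers of (yb/2) * (xb/2) tiles, each tile holding alpha^2 = 16
     * vectors of ic (resp. oc) floats; the buffers are indexed by ithr
     * below. */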
    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bia = padded_bias;
    }

    auto ptr_V = scratchpad.get<float>(key_wino_V);
    auto ptr_M = scratchpad.get<float>(key_wino_M);
    parallel_nd(jcp.mb, div_up(jcp.oh, jcp.yb), div_up(jcp.ow, jcp.xb),
            [&](int mb, int tile_y_b, int tile_x_b) {

        int tile_y = tile_y_b * jcp.yb;
        int tile_x = tile_x_b * jcp.xb;

        int ithr = mkldnn_get_thread_num();
        auto wino_src = ptr_V + size_wino_src * ithr;
        auto wino_dst = ptr_M + size_wino_dst * ithr;

        auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t ::
                call_params_t();
        auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t ::
                call_params_t();
        auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
                call_params_t();
        /* transformation of the input tensor to the Winograd domain */
        for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
            for (int x_in_block = 0; x_in_block < jcp.xb;
                    x_in_block += 2) {

                unsigned short v_y_masks[4], v_x_masks[4];

                int y = y_in_block + tile_y;
                int x = x_in_block + tile_x;
                int m = (y_in_block / 2) * (jcp.xb / 2)
                        + (x_in_block / 2);

                int v_ys = nstl::max(0, jcp.t_pad - y);
                int v_ye = nstl::min(jcp.alpha,
                        nstl::max(0, jcp.ih + jcp.t_pad - y));

                int v_xs = nstl::max(0, jcp.l_pad - x);
                int v_xe = nstl::min(jcp.alpha,
                        nstl::max(0, jcp.iw + jcp.l_pad - x));
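                /* e.g. with l_pad == 1 and x == 0, v_xs == 1, so lane 0 (the
                 * padded column to the left of the image) is masked off and
                 * the transform kernel reads it as zero */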
                for (int i = 0; i < jcp.alpha; i++) {
                    v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
                    v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
                }

                auto local_s = src
                        + mb * jcp.nb_ic * jcp.ih * jcp.iw
                                * jcp.ic_block
                        + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
                auto local_w = wino_src + m * jcp.ic;

                src_trans_p.src = local_s;
                src_trans_p.wino_src = local_w;
                src_trans_p.v_y_masks = v_y_masks;
                src_trans_p.v_x_masks = v_x_masks;

                src_trans_->ker_(&src_trans_p);
            }
        }
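        /* GEMMs over all 16 Winograd transform points; the (tile_ij + ithr)
         * rotation staggers each thread's starting point, presumably so that
         * concurrent threads stream different weight slices and share the
         * caches better */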
        for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
            int offset = (tile_ij + ithr) % 16;
            gemm_p.src = wino_src + jcp.inp_stride * offset;
            gemm_p.dst = wino_dst + jcp.out_stride * offset;
            gemm_p.wei = wei + jcp.wei_stride * offset;

            kernel_->ker_(&gemm_p);
        }
        /* transformation from the Winograd domain to the output tensor */
        for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
            for (int x_in_block = 0; x_in_block < jcp.xb;
                    x_in_block += 2) {

                unsigned short v_y_masks[2], v_x_masks[2];

                int y = y_in_block + tile_y;
                int x = x_in_block + tile_x;
                int m = (y_in_block / 2) * (jcp.xb / 2)
                        + (x_in_block / 2);

                for (int i = 0; i < jcp.m; i++) {
                    v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
                    v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
                }

                auto local_d = dst
                        + mb * jcp.nb_oc * jcp.oh * jcp.ow
                                * jcp.oc_block
                        + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
                auto local_w = wino_dst + m * jcp.oc;

                auto scales = oscales.scales_;
                dst_trans_p.dst = local_d;
                dst_trans_p.wino_dst = local_w;
                dst_trans_p.v_y_masks = v_y_masks;
                dst_trans_p.v_x_masks = v_x_masks;

                dst_trans_p.scales = scales;
                dst_trans_p.bias = bia;

                dst_trans_->ker_(&dst_trans_p);
            }
        }
    });
}
void jit_avx512_core_fp32_wino_conv_2x3_fwd_t::execute_forward_small_mb() const
{
    auto src = reinterpret_cast<const float *>(input_memory(0));
    auto wei = reinterpret_cast<const float *>(input_memory(1));
    auto bia = reinterpret_cast<const float *>(input_memory(2));
    auto dst = reinterpret_cast<float *>(memory(0));

    auto scratchpad = this->scratchpad();

    const auto &jcp = kernel_->jcp;
    const auto &oscales = pd()->attr()->output_scales_;

    if (pd()->wants_padded_bias()) {
        auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
        utils::array_copy(padded_bias, bia, jcp.oc_without_padding);
        utils::array_set(padded_bias + jcp.oc_without_padding, 0.f,
                jcp.oc - jcp.oc_without_padding);
        bia = padded_bias;
    }

    auto ptr_V = scratchpad.get<float>(key_wino_V);
    auto ptr_M = scratchpad.get<float>(key_wino_M);
    for (int mb = 0; mb < jcp.mb; mb++) {
    for (int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb) {
    for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
        /* transformation of the input tensor to the Winograd domain */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
                [&](int y_in_block_b, int x_in_block_b) {
            int y_in_block = y_in_block_b * 2;
            int x_in_block = x_in_block_b * 2;

            auto src_trans_p = jit_avx512_core_fp32_wino_conv_2x3_src_trans_t ::
                    call_params_t();

            unsigned short v_y_masks[4], v_x_masks[4];

            int y = y_in_block + tile_y;
            int x = x_in_block + tile_x;
            int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);

            int v_ys = nstl::max(0, jcp.t_pad - y);
            int v_ye = nstl::min(
                    jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));

            int v_xs = nstl::max(0, jcp.l_pad - x);
            int v_xe = nstl::min(
                    jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));

            for (int i = 0; i < jcp.alpha; i++) {
                v_y_masks[i] = (i < v_ys || i >= v_ye) ? 0 : 0xffff;
                v_x_masks[i] = (i < v_xs || i >= v_xe) ? 0 : 0xffff;
            }

            auto local_s = src
                    + mb * jcp.nb_ic * jcp.ih * jcp.iw * jcp.ic_block
                    + y * jcp.iw * jcp.ic_block + x * jcp.ic_block;
            auto local_w = ptr_V + m * jcp.ic;

            src_trans_p.src = local_s;
            src_trans_p.wino_src = local_w;
            src_trans_p.v_y_masks = v_y_masks;
            src_trans_p.v_x_masks = v_x_masks;

            src_trans_->ker_(&src_trans_p);
        });
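        /* inner parallelization: with a small minibatch there are too few
         * spatial blocks to keep every thread busy, so the 16 transform
         * points and the oc chunks of each GEMM are parallelized instead */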
        parallel_nd(16, jcp.n_chunks, [&](int tile_ij, int nnb) {
            auto gemm_p = jit_avx512_core_fp32_wino_conv_2x3_fwd_ker_t ::
                    call_params_t();

            gemm_p.src = ptr_V + jcp.inp_stride * tile_ij;
            gemm_p.dst = ptr_M + jcp.out_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block;
            gemm_p.wei = wei + jcp.wei_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block * jcp.K;

            kernel_->ker_(&gemm_p);
        });
        /* transformation from the Winograd domain to the output tensor */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2),
                [&](int y_in_block_b, int x_in_block_b) {
            int y_in_block = y_in_block_b * 2;
            int x_in_block = x_in_block_b * 2;

            auto dst_trans_p = jit_avx512_core_fp32_wino_conv_2x3_dst_trans_t ::
                    call_params_t();

            unsigned short v_y_masks[2], v_x_masks[2];

            int y = y_in_block + tile_y;
            int x = x_in_block + tile_x;
            int m = (y_in_block / 2) * (jcp.xb / 2) + (x_in_block / 2);

            for (int i = 0; i < jcp.m; i++) {
                v_x_masks[i] = (x + i < jcp.ow) ? 0xffff : 0;
                v_y_masks[i] = (y + i < jcp.oh) ? 0xffff : 0;
            }

            auto local_d = dst
                    + mb * jcp.nb_oc * jcp.oh * jcp.ow * jcp.oc_block
                    + y * jcp.ow * jcp.oc_block + x * jcp.oc_block;
            auto local_w = ptr_M + m * jcp.oc;

            auto scales = oscales.scales_;
            dst_trans_p.dst = local_d;
            dst_trans_p.wino_dst = local_w;
            dst_trans_p.v_y_masks = v_y_masks;
            dst_trans_p.v_x_masks = v_x_masks;

            dst_trans_p.scales = scales;
            dst_trans_p.bias = bia;

            dst_trans_->ker_(&dst_trans_p);
        });
    }}}
}
} // namespace cpu
} // namespace impl
} // namespace mkldnn