1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
20 #include "mkldnn_thread.hpp"
22 #include "cpu_memory.hpp"
24 #include "jit_uni_1x1_conv_utils.hpp"
25 #include "jit_avx512_core_x8s8s32x_1x1_conv_kernel.hpp"
27 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
33 using namespace mkldnn::impl::memory_format;
34 using namespace mkldnn::impl::utils;
36 using namespace Xbyak;
38 bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::maybe_relu(int position)
40 using namespace primitive_kind;
41 const auto &p = attr_.post_ops_;
47 || p.contain(eltwise, 0)
48 || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
49 } else if (position == 1) {
51 const int sum_idx = p.contain(sum, 0)
52 ? 0 : (p.contain(sum, 1) ? 1 : -1);
57 || p.contain(eltwise, sum_idx + 1)
58 || jcp.dst_dt == data_type::u8;
64 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
66 mov(aux1_reg_bcast_data, reg_bcast_data);
67 mov(aux_reg_bcast_data, reg_bcast_data);
69 mov(aux_reg_output_data, reg_output_data);
70 mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_off));
73 Label bcast_loop_tail;
75 cmp(bcast_loop_iter, jcp.ur);
76 jl(bcast_loop_tail, T_NEAR);
79 assert(jcp.bcast_block % jcp.ur == 0);
80 int num_substeps = jcp.bcast_block / jcp.ur;
81 assert(num_substeps > 0 && num_substeps < 10);
82 for (int i = 0; i < num_substeps; i++) {
83 reduce_loop(load_loop_blk, jcp.ur, i, false);
84 if (i < num_substeps - 1) {
85 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
86 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
89 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
90 - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
91 int output_offset = jcp.bcast_loop_output_step
92 - (num_substeps - 1) * jcp.bcast_loop_output_substep;
94 add(aux_reg_output_data, output_offset);
97 sub(bcast_loop_iter, jcp.bcast_block);
98 cmp(bcast_loop_iter, jcp.bcast_block);
99 jge(bcast_loop, T_NEAR);
104 Label bcast_loop_tail_out;
105 cmp(bcast_loop_iter, 0);
106 jz(bcast_loop_tail_out, T_NEAR);
107 reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
108 L(bcast_loop_tail_out);
112 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::cvt2ps(data_type_t type_in,
113 zmm_t zmm_in, const Xbyak::Operand &op, bool mask_flag) {
114 zmm_t zmm = mask_flag ? zmm_in | ktail_mask | T_z : zmm_in;
117 case data_type::s32: vmovups(zmm, op); break;
118 case data_type::s8: vpmovsxbd(zmm, op); break;
119 case data_type::u8: vpmovzxbd(zmm, op); break;
120 default: assert(!"unsupported data type");
122 if (type_in != data_type::f32)
123 vcvtdq2ps(zmm_in, zmm_in);
// Emit the inner reduction loop of the int8 1x1 convolution: accumulate
// src(u8/s8) x wei(s8) products over the IC ("reduce") dimension into s32
// accumulators, then apply bias / s8-src compensation / output scales /
// post-ops (ReLU, sum) and store the results in jcp.dst_dt.
// NOTE(review): this extraction is missing lines (the embedded original
// numbering has gaps), so several statements and closing braces are absent.
// Code is kept byte-identical; only comments are added.
126 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
127 int ur, int substep, bool wraparound)
// --- register-allocation helpers ------------------------------------------
// Accumulators occupy Zmm(0 .. ur*load_loop_blk-1); weight registers follow.
129 auto vreg_load = [=](int i_load) {
130 return Zmm(ur * load_loop_blk + i_load);
133 auto vreg_accum = [=](int i_load, int i_ur) {
134 return Zmm(i_ur * load_loop_blk + i_load);
// Scratch register (reuses the first weight slot) holding the weights
// adjustment scale used on non-VNNI targets with signed input.
137 auto zmm_bias_alpha = [=]() {
138 return Zmm(ur * load_loop_blk);
141 auto xmm_bias_alpha = [=]() {
142 return Xmm(ur * load_loop_blk);
// --- address helpers ------------------------------------------------------
144 auto bias_ptr = [=](int i_load) {
145 return EVEX_compress_addr(reg_bias_data,
146 jcp.typesize_bia * jcp.oc_block * i_load);
// Per-oc-block compensation values (only used when jcp.signed_input).
149 auto comp_ptr = [=](int i_load) {
150 return EVEX_compress_addr(reg_comp_data,
151 sizeof(int32_t) * jcp.oc_block * i_load);
// Output scales; stride is zero when a single common scale is used
// (jcp.is_oc_scale == 0).
154 auto scale_ptr = [=](int i_load) {
155 return EVEX_compress_addr(reg_ptr_scales,
156 jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
// Source (broadcast) element address for row i_ur, channel i_reduce.
159 auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
160 assert(i_ur < jcp.ur);
161 assert(i_reduce <= jcp.reduce_loop_unroll);
162 assert(jcp.reduce_loop_unroll == jcp.reduce_block);
164 int offt = (jcp.ic_without_padding * i_ur + i_reduce);
166 return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
// Weight address for reduction step i_reduce, output block i_load.
170 auto load_ptr = [=](int i_reduce, int i_load) {
171 int u0 = i_reduce % jcp.reduce_loop_unroll;
172 int u1 = i_reduce / jcp.reduce_loop_unroll;
174 int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
176 return EVEX_compress_addr(aux_reg_load_data,
177 u1 * jcp.reduce_loop_load_step
178 + jcp.typesize_in * offt);
// Destination address (nhwc layout: row stride is oc_without_padding).
181 auto output_ptr = [=](int i_load, int i_ur) {
182 return EVEX_compress_addr(aux_reg_output_data,
183 jcp.typesize_out * (jcp.oc_without_padding * i_ur
184 + i_load * jcp.load_block));
// --- accumulator init (zeroing presumably via vpxord on r; the zeroing
// statement itself falls in a gap of this extraction — TODO confirm) -------
188 for (int i_load = 0; i_load < load_loop_blk; ++i_load)
189 for (int i_ur = 0; i_ur < ur; ++i_ur) {
190 auto r = vreg_accum(i_load, i_ur);
// For signed input, broadcast the constant -128 used to shift s8 src into
// the u8 range expected by vpmaddubsw/vpdpbusd.
193 if (jcp.signed_input) {
194 xor_(reg_scratch, reg_scratch);
195 Reg8 _t8 = reg_scratch.cvt8();
196 mov(_t8, (int8_t)-128);
197 vpbroadcastb(zmm_shift, _t8);
// --- store: bias/compensation/scale/post-ops + final conversion/store -----
201 auto store = [=](const bool mask_flag_in) {
202 const auto &p = attr_.post_ops_;
203 const int sum_idx = p.find(primitive_kind::sum);
204 const float *p_sum_scale = (sum_idx != -1)
205 ? &p.entry_[sum_idx].sum.scale
// Spill/restore data pointers around the store because registers are reused.
207 mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
208 mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
209 if (p_sum_scale && *p_sum_scale != 1.f) {
210 mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
211 mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
// Non-VNNI path pre-scales the bias by wei_adj_scale (weights were halved).
213 if (jcp.signed_input && jcp.ver != ver_vnni) {
214 mov(reg_scratch, float2int(jcp.wei_adj_scale));
215 vmovq(xmm_bias_alpha(), reg_scratch);
216 vbroadcastss(zmm_bias_alpha(), xmm_bias_alpha());
218 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
// Only the last oc block may be partial -> masked.
219 const bool mask_flag = mask_flag_in && i_load == load_loop_blk - 1;
220 auto zmm_bias = zmm_tmp;
221 auto zmm_comp = zmm_bcast;
223 if (jcp.signed_input)
225 EVEX_compress_addr(rsp,reg_bias_data_off));
226 cvt2ps(jcp.bia_dt, zmm_bias, bias_ptr(i_load), mask_flag);
227 if (jcp.signed_input && jcp.ver != ver_vnni)
228 vmulps(zmm_bias, zmm_bias, zmm_bias_alpha());
230 if (jcp.signed_input) {
231 mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
232 cvt2ps(data_type::s32, zmm_comp, comp_ptr(i_load), mask_flag);
235 for (int i_ur = 0; i_ur < ur; ++i_ur) {
236 auto r = vreg_accum(i_load, i_ur);
// (accumulator int->float conversion falls in an extraction gap here)
238 if (jcp.signed_input)
239 vaddps(r, r, zmm_comp);
241 vaddps(r, r, zmm_bias);
// Apply output scales; masked lanes are zeroed for the tail block.
243 zmm_t mask_zmm = mask_flag ? r | ktail_mask | T_z : r;
244 vmulps(mask_zmm, r, scale_ptr(i_load));
// ReLU before sum (guard presumably maybe_relu(0) — in an extraction gap).
246 vpxord(zmm_zero, zmm_zero, zmm_zero);
247 vmaxps(r, zmm_zero, r);
249 if (p_sum_scale) { // post_op: sum
250 vpxord(zmm_zero, zmm_zero, zmm_zero);
251 auto zmm_prev_dst = zmm_zero;
253 cvt2ps(jcp.dst_dt, zmm_prev_dst, output_ptr(i_load, i_ur),
256 if (*p_sum_scale == 1.f)
257 vaddps(r, zmm_prev_dst);
// r += prev_dst * sum_scale (scale broadcast from memory).
259 vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
// ReLU after sum (guard presumably maybe_relu(1) — in an extraction gap).
262 vpxord(zmm_zero, zmm_zero, zmm_zero);
263 vmaxps(r, zmm_zero, r);
// Convert back to s32 with the attribute-selected rounding mode.
265 if (jcp.dst_dt != data_type::f32) {
266 if (attr_.round_mode_ == round_mode::nearest) {
267 vcvtps2dq(r | T_rn_sae, r);
268 } else if (attr_.round_mode_ == round_mode::down) {
269 vcvtps2dq(r | T_rd_sae, r);
271 assert(!"unimplemented");
// Final down-conversion + store per destination data type.
274 for (int i_ur = 0; i_ur < ur; ++i_ur) {
275 auto r = vreg_accum(i_load, i_ur);
276 zmm_t r_zmm = mask_flag ? r | ktail_mask : r;
277 switch (jcp.dst_dt) {
280 vmovups(output_ptr(i_load, i_ur), r_zmm); break;
282 vpmovsdb(output_ptr(i_load, i_ur), r_zmm); break;
284 vpmovusdb(output_ptr(i_load, i_ur), r_zmm); break;
285 default: assert(!"unknown dst_dt");
// Restore spilled pointers.
289 mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
290 if (p_sum_scale && *p_sum_scale != 1.f)
291 mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
// --- one multiply-accumulate: VNNI fused dot-product vs. 3-insn emulation -
294 auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
295 if (jcp.ver == ver_vnni) {
296 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
298 vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
299 vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
300 vpaddd(vreg_acc, vreg_acc, zmm_tmp);
// --- one unrolled FMA block over reduce_step-sized channel groups ---------
304 auto fma_block = [=](bool last_block) {
// (reduce_step definition falls in an extraction gap; 4 int8 per dword)
306 int tail_size = jcp.ic_without_padding % reduce_step;
307 int loop_unroll = last_block && jcp.ic != jcp.ic_without_padding
308 ? rnd_up(jcp.ic_without_padding % jcp.ic_block, reduce_step)
309 : jcp.reduce_loop_unroll;
310 for (int i_reduce = 0; i_reduce < loop_unroll;
311 i_reduce += reduce_step) {
312 for (int i_load = 0; i_load < load_loop_blk; ++i_load)
313 vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
314 for (int i_ur = 0; i_ur < ur; ++i_ur) {
// Partial last channel group: gather the tail bytes one by one, then
// broadcast the assembled dword.
315 if (last_block && tail_size != 0
316 && i_reduce == loop_unroll - reduce_step) {
317 Xmm xmm_bcast = Xmm(zmm_bcast.getIdx());
318 for (int r = 0; r < tail_size; ++r)
319 vpinsrb(xmm_bcast, xmm_bcast, ptr[aux_reg_bcast_data
320 + jcp.ic_without_padding * i_ur + i_reduce + r], r);
321 vpbroadcastd(zmm_bcast, xmm_bcast);
323 vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
// Shift s8 source into u8 range for the unsigned-times-signed multiply.
325 if (jcp.signed_input)
326 vpsubb(zmm_bcast, zmm_bcast, zmm_shift);
327 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
328 compute(vreg_accum(i_load, i_ur),
329 vreg_load(i_load), zmm_bcast);
// --- drive the reduction loop ---------------------------------------------
336 Label reduce_loop_tail;
338 mov(aux_reg_load_data, reg_load_data);
340 mov(aux_reg_bcast_data, aux1_reg_bcast_data);
// (accumulator init call and L(reduce_loop) fall in extraction gaps)
343 mov(reduce_loop_iter, reg_reduce_loop_work);
344 sub(reduce_loop_iter, jcp.reduce_loop_unroll);
345 jle(reduce_loop_tail, T_NEAR);
349 add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
350 add(aux_reg_load_data, jcp.reduce_loop_load_step);
351 sub(reduce_loop_iter, jcp.reduce_loop_unroll);
352 jg(reduce_loop, T_NEAR);
// Tail handles the padded-channel case with a masked/partial fma_block.
356 if (jcp.ic != jcp.ic_without_padding) {
// --- store path selection for the padded-oc case --------------------------
// When oc is padded, only the very last oc block of the last ocb uses the
// masked store; otherwise the common (unmasked) store is used.
362 if (jcp.oc_without_padding != jcp.oc) {
363 Label end_store, common_store;
364 mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
366 /*Check if it is the last load_loop_blk*/
367 sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
368 cmp(reg_load_loop_work, 0);
369 jg(common_store, T_NEAR);
371 /*Check if it is the last ocb*/
372 test(reg_reduce_pos_flag, FLAG_OC_LAST);
373 jz(common_store, T_NEAR);
// (masked store(true) call falls in an extraction gap)
376 jmp(end_store, T_NEAR);
// (L(common_store); store(false); L(end_store); restore fall in gaps)
383 add(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
// Top-level kernel generator: emits the preamble, loads the jit_1x1_conv
// call arguments from the param struct, then dispatches to one of several
// load-loop bodies specialized by the number of oc blocks processed at once.
// NOTE(review): this extraction is missing lines (embedded numbering gaps),
// e.g. preamble()/postamble() and some label closers. Code kept byte-identical.
389 void jit_avx512_core_x8s8s32x_1x1_conv_kernel::generate()
// Build the vector of 16-bit ones used by vpmaddwd in the non-VNNI path.
// (the mov of 0x1 into _t falls in an extraction gap — TODO confirm)
393 xor_(reg_scratch, reg_scratch);
394 Reg16 _t = reg_scratch.cvt16();
396 vpbroadcastw(zmm_one, _t);
// Reserve the local spill area used for saved data pointers.
398 sub(rsp, stack_space_needed);
// Prepare the k-mask for the partial last oc block.
400 if (jcp.oc_without_padding != jcp.oc) {
401 int tail_size = jcp.oc_without_padding % jcp.oc_block;
402 int mask = (1 << tail_size) - 1;
403 Reg32 regw_tmp = reg_last_load.cvt32();
405 kmovw(ktail_mask, regw_tmp);
// --- unpack call arguments -------------------------------------------------
409 mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
410 if (jcp.signed_input) {
411 mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
412 mov(reg_comp_data, ptr[param1 + GET_OFF(compensation)]);
413 mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
415 mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
416 mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
417 mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
418 mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
419 mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
421 mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
422 mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
423 mov(EVEX_compress_addr(rsp, bcast_loop_work_off), reg_bcast_loop_work);
424 mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
425 mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
// --- one iteration over a group of oc blocks ------------------------------
// Runs the bcast loop, then advances weight/bias/compensation/scale/output
// pointers by load_loop_blk blocks (pointers spilled to stack slots are
// reloaded/updated/re-spilled around each advance).
428 auto load_loop_body = [=](int load_loop_blk) {
429 bcast_loop(load_loop_blk);
430 add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
432 if (jcp.signed_input)
433 mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_off));
435 load_loop_blk * jcp.load_block * jcp.typesize_bia);
436 if (jcp.signed_input)
437 mov(EVEX_compress_addr(rsp, reg_bias_data_off), reg_bias_data);
439 if (jcp.signed_input) {
440 mov(reg_comp_data, EVEX_compress_addr(rsp, reg_comp_data_off));
442 load_loop_blk * jcp.load_block * sizeof(int32_t));
443 mov(EVEX_compress_addr(rsp, reg_comp_data_off), reg_comp_data);
445 mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
446 mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
448 jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float));
449 mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
450 mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
452 load_loop_blk * jcp.load_block * jcp.typesize_out);
453 sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
// --- dispatch table over the remaining load dimension ---------------------
456 const int simd_w = 16;
458 Label load_loop_blk[7];
// ur thresholds: larger ur leaves fewer registers for weights, so fewer oc
// blocks can be processed at once.
460 static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
461 const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
462 const int *ur_cases_fma = ur_cases_fma_expl_bcast;
463 const int *ur_cases = ur_cases_fma;
464 const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);
// First pass: jump to the label matching the current load_loop_work size.
466 for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
467 int label_idx = num_ur_cases - ur_idx - 1;
468 if (jcp.ur <= ur_cases[ur_idx]) {
469 cmp(reg_load_loop_work, simd_w * (label_idx + 1));
470 jle(load_loop_blk[label_idx], T_NEAR);
// Second pass: emit each specialized body with its own iteration/fallthrough
// chain down to smaller block counts.
474 for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
475 if (jcp.ur <= ur_cases[ur_idx]) {
476 int label_idx = num_ur_cases - ur_idx - 1;
477 L(load_loop_blk[label_idx]);
479 if (label_idx == 0) {
480 cmp(reg_load_loop_work, 0);
481 je(load_loop_blk[num_ur_cases], T_NEAR);
483 load_loop_body(label_idx + 1);
484 if (label_idx - 1 > 0) {
485 cmp(reg_load_loop_work, 2 * label_idx * simd_w);
486 je(load_loop_blk[label_idx - 1], T_NEAR);
488 cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
489 jge(load_loop_blk[label_idx]);
491 for (int idx = label_idx - 1; idx > 0; --idx) {
492 cmp(reg_load_loop_work, simd_w * (idx + 1));
493 je(load_loop_blk[idx], T_NEAR);
495 if (ur_idx < num_ur_cases - 2) {
496 cmp(reg_load_loop_work, simd_w);
497 jle(load_loop_blk[0], T_NEAR);
// Exit label for the whole dispatch chain.
501 L(load_loop_blk[num_ur_cases]);
// Release the spill area (postamble presumably follows — in a gap).
503 add(rsp, stack_space_needed);
508 bool jit_avx512_core_x8s8s32x_1x1_conv_kernel::post_ops_ok(
509 jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
510 using namespace primitive_kind;
511 const auto &p = attr.post_ops_;
513 auto is_relu = [&](int idx) {
514 return p.entry_[idx].kind == eltwise
515 && p.entry_[idx].eltwise.scale == 1.
516 && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
517 && p.entry_[idx].eltwise.alpha == 0.;
523 && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0))
524 && IMPLICATION(!jcp.with_eltwise, is_relu(0) || p.contain(sum, 0));
526 && IMPLICATION(jcp.with_eltwise, p.contain(sum, 0) && is_relu(1))
527 && IMPLICATION(!jcp.with_eltwise, false
528 || (p.contain(sum, 0) && is_relu(1))
529 || (p.contain(sum, 1) && is_relu(0)));
531 && jcp.with_eltwise == false
532 && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
533 default: return false;
// Fill the jit_1x1_conv_conf_t descriptor: validate data types/formats,
// derive problem sizes from the memory descriptors, pick the unroll factor
// (jcp.ur) and the load/bcast/reduce blocking used by the generated kernel.
// Returns status::success or status::unimplemented for unsupported configs.
// NOTE(review): this extraction is missing lines (embedded numbering gaps),
// e.g. some declarations and closing braces. Code kept byte-identical.
539 status_t jit_avx512_core_x8s8s32x_1x1_conv_kernel::init_conf(
540 jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
541 const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
542 const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
543 const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
544 int nthreads, bool reduce_src)
// --- capability / data-type / format checks -------------------------------
546 if (!mayiuse(avx512_core)) return status::unimplemented;
548 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
549 if (!one_of(src_d.data_type(), data_type::u8, data_type::s8)
550 || weights_d.data_type() != data_type::s8
551 || !one_of(dst_d.data_type(),
552 data_type::f32, data_type::s32, data_type::s8, data_type::u8))
553 return status::unimplemented;
554 if (!one_of(weights_d.format(), gOIhw4i16o4i, OIhw4i16o4i,
555 gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8)) {
556 return status::unimplemented;
// Prefer the VNNI code path when available.
558 jcp.ver = ver_avx512_core;
559 if (mayiuse(avx512_core_vnni))
// --- problem geometry from the descriptors --------------------------------
562 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
563 jcp.mb = src_d.dims()[0];
564 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
565 jcp.oc_without_padding = jcp.oc;
566 jcp.ic = src_d.dims()[1] / jcp.ngroups;
567 jcp.ic_without_padding = jcp.ic;
568 jcp.ih = src_d.dims()[2];
569 jcp.iw = src_d.dims()[3];
570 jcp.oh = dst_d.dims()[2];
571 jcp.ow = dst_d.dims()[3];
572 jcp.kh = weights_d.dims()[with_groups + 2];
573 jcp.kw = weights_d.dims()[with_groups + 3];
574 jcp.t_pad = cd.padding[0][0];
575 jcp.l_pad = cd.padding[0][1];
576 jcp.stride_h = cd.strides[0];
577 jcp.stride_w = cd.strides[1];
578 jcp.src_fmt = src_d.format();
579 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
580 jcp.with_eltwise = with_relu;
581 jcp.eltwise_alpha = relu_negative_slope;
// Only plain ReLU (slope 0) is supported as a fused eltwise.
582 if (!IMPLICATION(with_relu, relu_negative_slope == 0.))
583 return status::unimplemented;
585 jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
587 jcp.os = jcp.oh * jcp.ow;
588 jcp.is = jcp.ih * jcp.iw;
589 jcp.tr_is = rnd_up(jcp.is, 4);
591 if (!post_ops_ok(jcp, attr))
592 return status::unimplemented;
// nhwc-only kernel (args_ok declaration falls in an extraction gap).
596 && src_d.format() == nhwc
597 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
598 && dst_d.format() == nhwc;
599 if (!args_ok) return status::unimplemented;
601 const int simd_w = 16;
// Pad channels to the SIMD width; *_without_padding keep the true sizes.
603 jcp.oc = rnd_up(jcp.oc, simd_w);
604 jcp.ic = rnd_up(jcp.ic, simd_w);
// 1x1, unit-stride, no-padding convolution only.
607 && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
608 && jcp.t_pad == 0 && jcp.l_pad == 0
609 && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
610 && jcp.kh == 1 && jcp.kw == 1;
611 if (!args_ok) return status::unimplemented;
613 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
614 jcp.dst_dt = cd.dst_desc.data_type;
616 jcp.ic_block = jcp.oc_block = simd_w;
618 jcp.typesize_in = types::data_type_size(src_d.data_type());
619 jcp.typesize_out = types::data_type_size(dst_d.data_type());
620 jcp.typesize_bia = jcp.with_bias
621 ? types::data_type_size(bias_d.data_type())
// --- blocking heuristics ---------------------------------------------------
624 const int SMALL_SPATIAL = 7 * 7;
625 const int BIG_REDUCE_DIM = 1024;
627 int load_blocking = 0;
628 int load_blocking_max = 0;
629 int bcast_blocking = 0;
630 int bcast_blocking_max = 0;
631 int reduce_blocking = 0;
632 int reduce_blocking_max = 0;
633 jcp.load_grp_count = 1;
634 jcp.use_vmovntps = false;
636 const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in);
637 const int L2_capacity = (L2_size * 3) / 4;
// Pick jcp.ur (rows per kernel invocation); prefer divisors of the spatial
// size. (min_regs/max_regs declarations fall in an extraction gap.)
639 int size_treshold = 28;
642 if (jcp.ver == ver_vnni)
643 max_regs = ((jcp.oh > size_treshold && jcp.ow > size_treshold)
644 && (jcp.oc < 128 || jcp.ic < 128)) ? min_regs : 9;
647 jcp.expl_bcast = true;
649 const int spatial = jcp.oh;
651 for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
652 if ((spatial >= size_treshold && spatial % ur_w == 0)
653 || (spatial < size_treshold && jcp.os % ur_w == 0)) {
// Fallback: smallest-tail divisor of jcp.os.
659 jcp.ur = nstl::min(max_regs, jcp.os);
660 int os_tail = jcp.os % max_regs;
661 for (int i = max_regs; i >= min_regs; i--) {
662 int i_tail = jcp.os % i;
663 if (i_tail > os_tail || i_tail == 0) {
// Map conv dims onto the generic 1x1 driver dims:
// reduce = ic, load = oc, bcast = spatial.
672 jcp.reduce_dim = jcp.ic;
673 jcp.reduce_block = jcp.ic_block;
675 jcp.load_dim = jcp.oc;
676 jcp.load_block = jcp.oc_block;
678 jcp.bcast_dim = jcp.is;
680 jcp.bcast_block = jcp.ur;
682 jcp.reduce_loop_unroll = jcp.reduce_block;
683 jcp.reduce_loop_bcast_step
684 = jcp.reduce_loop_unroll * jcp.typesize_in;
686 jcp.reduce_loop_load_step
687 = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
// nhwc row strides use the un-padded channel counts.
689 jcp.bcast_loop_output_step = jcp.ur * jcp.oc_without_padding * jcp.typesize_out;
690 jcp.bcast_loop_output_substep = -1; // unused
691 jcp.bcast_loop_bcast_step = jcp.ur * jcp.ic_without_padding * jcp.typesize_in;
692 jcp.bcast_loop_bcast_substep = -1; // unused
694 jcp.load_loop_load_step
695 = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
697 jcp.load_loop_iter_step = jcp.load_block;
699 jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
701 int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
702 int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
// Cap the reduce blocking for large reduce dims to keep data L2-resident.
704 reduce_blocking = nb_reduce;
705 if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
706 reduce_blocking = 64;
707 else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
708 reduce_blocking = 16;
709 reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
710 reduce_blocking *= jcp.reduce_block;
// Switch loop order when the reduce dim is split.
712 bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
714 jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
715 load_blocking = jcp.load_dim;
// Spread threads over oc groups when spatial parallelism is insufficient.
717 jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
718 jcp.load_grp_count = best_divider(
719 nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
721 if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.load_dim * jcp.reduce_dim >= L2_size) {
722 jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
723 } else if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.mb <= nthreads
724 && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
725 jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2); //
726 load_blocking = jcp.load_block;
// Balance bcast blocking across the thread groups, then clamp it so the
// working set fits the estimated free L2 space.
729 bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
730 div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block;
731 bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
732 bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);
// (space_for_bcast declaration falls in an extraction gap)
735 = (L2_capacity - /* kernel_size - */
736 2 * jcp.load_block * reduce_blocking
737 - jcp.ur * reduce_blocking - 3 * 1024);
738 if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
739 space_for_bcast /= 2;
// (bcast_in_cache declaration falls in an extraction gap)
742 = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
743 bcast_blocking = nstl::min(
744 bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
746 load_blocking_max = load_blocking;
747 bcast_blocking_max = bcast_blocking * 3 / 2;
748 reduce_blocking_max = reduce_blocking;
// --- sanity checks on the chosen blocking ---------------------------------
750 assert(load_blocking);
751 assert(load_blocking_max);
752 assert(bcast_blocking);
753 assert(bcast_blocking_max);
754 assert(reduce_blocking);
755 assert(reduce_blocking_max);
756 assert(load_blocking % jcp.load_block == 0);
757 assert(reduce_blocking % jcp.reduce_block == 0);
758 assert(load_blocking_max % jcp.load_block == 0);
759 assert(reduce_blocking_max % jcp.reduce_block == 0);
761 assert(jcp.reduce_loop_unroll % 4 == 0);
762 assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
764 assert(jcp.bcast_block % jcp.ur == 0);
765 assert(jcp.reduce_dim % jcp.reduce_block == 0);
767 jcp.ur_tail = jcp.bcast_dim % jcp.ur;
// Export the blocking in units of blocks.
769 jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
770 jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
771 jcp.nb_load_blocking = load_blocking / jcp.load_block;
772 jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
773 jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
774 jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
776 jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
777 jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
778 jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
// Output scales: mask 1<<1 means per-oc, mask 0 means a single common scale.
780 const auto &oscales = attr.output_scales_;
781 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
782 assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
// Non-VNNI path halves the weights to avoid overflow; compensated here.
784 jcp.wei_adj_scale = (jcp.signed_input) ? (1.f / 2.f) : 1.f;
786 return status::success;