/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include "c_types_map.hpp"
#include "nstl.hpp"
#include "type_helpers.hpp"
#include "utils.hpp"
#include "mkldnn_thread.hpp"

#include "cpu_memory.hpp"

#include "jit_uni_1x1_conv_utils.hpp"
#include "jit_avx512_core_u8s8s32x_1x1_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;

using namespace Xbyak;
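
/* This kernel computes a 1x1 convolution as a blocked GEMM over three
 * dimensions: "bcast" is the output spatial extent (rows of the u8 source),
 * "load" is the output channels (s8 weights), and "reduce" is the input
 * channels, accumulated in s32. Partial sums live in the acc_s32 scratchpad
 * until the last reduction chunk, when they are scaled, post-processed, and
 * converted to the destination data type. */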
bool jit_avx512_core_u8s8s32x_1x1_conv_kernel::maybe_relu(int position)
{
    using namespace primitive_kind;
    const auto &p = attr_.post_ops_;
    if (position == 0) {
        /* relu applied before an eventual sum post-op */
        return jcp.with_eltwise || p.contain(eltwise, 0)
            || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
    } else if (position == 1) {
        /* relu applied after the sum post-op */
        const int sum_idx = p.contain(sum, 0)
            ? 0 : (p.contain(sum, 1) ? 1 : -1);
        if (sum_idx == -1) return false;
        return p.contain(eltwise, sum_idx + 1)
            || jcp.dst_dt == data_type::u8;
    }
    return false;
}
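
/* maybe_relu(0) decides whether the accumulator is clamped at zero before
 * the optional sum post-op, maybe_relu(1) whether it is clamped after it;
 * u8 destinations force the clamp, since the result is stored through an
 * unsigned saturating downconvert. */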
void jit_avx512_core_u8s8s32x_1x1_conv_kernel::bcast_loop(int load_loop_blk)
{
    mov(aux1_reg_bcast_data, reg_bcast_data);
    mov(aux_reg_bcast_data, reg_bcast_data);

    mov(aux_reg_output_data, reg_output_data);
    mov(aux_reg_acc_s32, reg_acc_s32);

    mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));

    Label bcast_loop;
    Label bcast_loop_tail;

    cmp(bcast_loop_iter, jcp.ur);
    jl(bcast_loop_tail, T_NEAR);

    L(bcast_loop); {
        assert(jcp.bcast_block % jcp.ur == 0);
        int num_substeps = jcp.bcast_block / jcp.ur;
        assert(num_substeps > 0 && num_substeps < 10);
        for (int i = 0; i < num_substeps; i++) {
            reduce_loop(load_loop_blk, jcp.ur, i, false);
            if (i < num_substeps - 1) {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
                add(aux_reg_output_data, jcp.bcast_loop_output_substep);
                int ws_offset = (jcp.bcast_loop_output_substep
                        / jcp.typesize_out) * jcp.typesize_acc;
                add(aux_reg_acc_s32, ws_offset);
            } else {
                add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
                    - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
                int output_offset = jcp.bcast_loop_output_step
                    - (num_substeps - 1) * jcp.bcast_loop_output_substep;
                add(aux_reg_output_data, output_offset);
                int ws_offset = (output_offset / jcp.typesize_out)
                        * jcp.typesize_acc;
                add(aux_reg_acc_s32, ws_offset);
            }
        }
        sub(bcast_loop_iter, jcp.bcast_block);
        cmp(bcast_loop_iter, jcp.bcast_block);
        jge(bcast_loop, T_NEAR);
    }

    L(bcast_loop_tail);
    if (jcp.ur_tail) {
        Label bcast_loop_tail_out;
        cmp(bcast_loop_iter, 0);
        jz(bcast_loop_tail_out, T_NEAR);
        reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
        L(bcast_loop_tail_out);
    }
}
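
/* Register layout used by reduce_loop below: the s32 accumulators occupy
 * Zmm(0) .. Zmm(ur * load_loop_blk - 1), addressed as
 * Zmm(i_ur * load_loop_blk + i_load); the weight registers returned by
 * vreg_load() start just above that range. */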
void jit_avx512_core_u8s8s32x_1x1_conv_kernel::reduce_loop(int load_loop_blk,
        int ur, int substep, bool wraparound)
{
    auto vreg_load = [=](int i_load) {
        return Zmm(ur * load_loop_blk + i_load);
    };

    auto vreg_accum = [=](int i_load, int i_ur) {
        return Zmm(i_ur * load_loop_blk + i_load);
    };

    auto xreg_accum = [=](int i_load, int i_ur) {
        return Xmm(i_ur * load_loop_blk + i_load);
    };

    auto bias_ptr = [=](int i_load) {
        return EVEX_compress_addr(reg_bias_data,
                jcp.typesize_bia * jcp.oc_block * i_load);
    };

    auto scale_ptr = [=](int i_load) {
        return EVEX_compress_addr(reg_ptr_scales,
                jcp.is_oc_scale * (sizeof(float) * jcp.oc_block * i_load));
    };

    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        assert(jcp.reduce_loop_unroll == jcp.reduce_block);

        int offt = (jcp.reduce_dim * i_ur + i_reduce);

        return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
                bcast);
    };

    auto load_ptr = [=](int i_reduce, int i_load) {
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;

        int offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;

        return EVEX_compress_addr(aux_reg_load_data,
                u1 * jcp.reduce_loop_load_step
                + jcp.typesize_in * offt);
    };

    auto output_ptr = [=](int i_load, int i_ur) {
        return EVEX_compress_addr(aux_reg_output_data,
            jcp.typesize_out * (jcp.load_dim * i_ur + i_load * jcp.load_block));
    };

    auto acc_s32_ptr = [=](int i_load, int i_ur) {
        return EVEX_compress_addr(aux_reg_acc_s32,
            jcp.typesize_acc * (jcp.load_dim * i_ur + i_load * jcp.load_block));
    };
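
    /* init: reload the s32 partial sums from the acc_s32 scratchpad, or, on
     * the first reduction chunk (FLAG_REDUCE_FIRST), zero the accumulators. */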
    auto init = [=]() {
        Label l_first_load, l_ret;

        test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
        jnz(l_first_load, T_NEAR); // FIRST load: if not zero jump to <l_first_load>

        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vmovups(r, acc_s32_ptr(i_load, i_ur));
            }
        jmp(l_ret, T_NEAR);

        L(l_first_load);
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vpxord(r, r, r);
            }

        L(l_ret);
    };

    /* store: on the last reduction chunk (FLAG_REDUCE_LAST) convert the s32
     * accumulators to f32, apply bias, scales and post-ops, and write the
     * final destination; otherwise spill the partial sums back to acc_s32. */
    auto store = [=]() {
        Label l_update_acc, l_ret;

        test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
        jz(l_update_acc, T_NEAR); // LAST channel: if zero jump to <l_update_acc>

        const auto &p = attr_.post_ops_;
        const int sum_idx = p.find(primitive_kind::sum);
        const float *p_sum_scale = (sum_idx != -1)
            ? &p.entry_[sum_idx].sum.scale
            : nullptr;

        mov(EVEX_compress_addr(rsp, aux_reg_acc_s32_offt), aux_reg_acc_s32);
        mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_offt));

        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        if (p_sum_scale && *p_sum_scale != 1.f) {
            mov(EVEX_compress_addr(rsp, reg_load_data_off), reg_load_data);
            mov(reg_ptr_sum_scale, (size_t)p_sum_scale);
        }
        vpxord(zmm_zero, zmm_zero, zmm_zero);
        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
            auto zmm_bias = zmm_tmp;
            if (jcp.with_bias) {
                switch (jcp.bia_dt) {
                case data_type::f32:
                case data_type::s32: vmovups(zmm_bias,
                                        bias_ptr(i_load)); break;
                case data_type::s8: vpmovsxbd(zmm_bias,
                                        bias_ptr(i_load)); break;
                case data_type::u8: vpmovzxbd(zmm_bias,
                                        bias_ptr(i_load)); break;
                default: assert(!"unsupported bias data type");
                }
                if (jcp.bia_dt != data_type::f32)
                    vcvtdq2ps(zmm_bias, zmm_bias);
            }
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                auto x = xreg_accum(i_load, i_ur);
                vcvtdq2ps(r, r);
                if (jcp.with_bias)
                    vaddps(r, r, zmm_bias);
                vmulps(r, r, scale_ptr(i_load));
                if (maybe_relu(0))
                    vmaxps(r, zmm_zero, r);
                if (p_sum_scale) { // post_op: sum
                    auto zmm_prev_dst = zmm_bcast;
                    switch (jcp.dst_dt) {
                    case data_type::f32:
                    case data_type::s32: vmovups(zmm_prev_dst,
                                            output_ptr(i_load, i_ur)); break;
                    case data_type::s8: vpmovsxbd(zmm_prev_dst,
                                            output_ptr(i_load, i_ur)); break;
                    case data_type::u8: vpmovzxbd(zmm_prev_dst,
                                            output_ptr(i_load, i_ur)); break;
                    default: assert(!"unsupported dst data type");
                    }
                    if (jcp.dst_dt != data_type::f32)
                        vcvtdq2ps(zmm_prev_dst, zmm_prev_dst);
                    if (*p_sum_scale == 1.f)
                        vaddps(r, zmm_prev_dst);
                    else
                        vfmadd231ps(r, zmm_prev_dst, zword_b[reg_ptr_sum_scale]);
                }
                if (maybe_relu(1))
                    vmaxps(r, zmm_zero, r);
                if (jcp.dst_dt != data_type::f32) {
                    if (attr_.round_mode_ == round_mode::nearest) {
                        vcvtps2dq(r | T_rn_sae, r);
                    } else if (attr_.round_mode_ == round_mode::down) {
                        vcvtps2dq(r | T_rd_sae, r);
                    } else
                        assert(!"unimplemented");
                }
                switch (jcp.dst_dt) {
                case data_type::f32:
                case data_type::s32: vmovups(output_ptr(i_load, i_ur), r); break;
                case data_type::s8: vpmovsdb(x, r);
                                    vmovups(output_ptr(i_load, i_ur), x); break;
                case data_type::u8: vpmovusdb(x, r);
                                    vmovups(output_ptr(i_load, i_ur), x); break;
                default: assert(!"unknown dst_dt");
                }
            }
        }
        mov(aux_reg_acc_s32, EVEX_compress_addr(rsp, aux_reg_acc_s32_offt));
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        if (p_sum_scale && *p_sum_scale != 1.f)
            mov(reg_load_data, EVEX_compress_addr(rsp, reg_load_data_off));
        jmp(l_ret, T_NEAR);

        L(l_update_acc);
        mov(aux_reg_bcast_data, EVEX_compress_addr(rsp, aux_reg_acc_s32_offt));
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vmovups(acc_s32_ptr(i_load, i_ur), r);
            }

        L(l_ret);
    };
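
    /* compute: one int8 FMA step. With VNNI, a single vpdpbusd multiplies
     * u8 source by s8 weights and accumulates dword-wise into s32. Without
     * VNNI the same step takes three instructions, and the intermediate
     * vpmaddubsw result saturates to s16, which is the usual accuracy
     * caveat of the non-VNNI int8 path. */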
    auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            vpmaddubsw(zmm_tmp, vreg_src, vreg_wei);
            vpmaddwd(zmm_tmp, zmm_tmp, zmm_one);
            vpaddd(vreg_acc, vreg_acc, zmm_tmp);
        }
    };
    auto fma_block = [=](bool last_block) {
        int reduce_step = 4; // 4 input channels per dword of the 4i16o4i layout
        for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
                i_reduce += reduce_step) {
            for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                vmovups(vreg_load(i_load), load_ptr(i_reduce, i_load));
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                vpbroadcastd(zmm_bcast, bcast_ptr(i_reduce, i_ur, false));
                for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                    compute(vreg_accum(i_load, i_ur),
                            vreg_load(i_load), zmm_bcast);
                }
            }
        }
    };
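
    /* Driver for one reduction pass: init() primes the accumulators, the
     * main loop consumes the reduce dimension in reduce_loop_unroll chunks,
     * and the tail block finishes the remainder before store() writes out. */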
    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop); {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);
    store();
}
void jit_avx512_core_u8s8s32x_1x1_conv_kernel::generate()
{
    preamble();

    xor_(reg_scratch, reg_scratch);
    Reg16 _t = reg_scratch.cvt16();
    mov(_t, 0x1);
    vpbroadcastw(zmm_one, _t); // zmm_one = s16 ones, the vpmaddwd multiplier

    sub(rsp, stack_space_needed);
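
    /* There are not enough general-purpose registers for all kernel
     * arguments, so several of them are spilled to a reserved stack area
     * and reloaded inside the loops. */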
    mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
    mov(EVEX_compress_addr(rsp, reg_bias_data_offt), reg_bias_data);

    mov(reg_ptr_scales, ptr[param1 + GET_OFF(scales)]);
    mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
    mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
    mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
    mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);

    mov(reg_acc_s32, ptr[param1 + GET_OFF(acc_s32)]);
    mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
    mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
    mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
    mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
    mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(reduce_pos_flag)]);
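
    /* load_loop_body processes load_loop_blk blocks of 16 output channels:
     * it runs the bcast loop, then advances the weight, bias, scale, output
     * and accumulator pointers past the work just done. */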
    auto load_loop_body = [=](int load_loop_blk) {
        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        if (jcp.with_bias) {
            mov(reg_bias_data, EVEX_compress_addr(rsp, reg_bias_data_offt));
            add(reg_bias_data,
                load_loop_blk * jcp.load_block * jcp.typesize_bia);
            mov(EVEX_compress_addr(rsp, reg_bias_data_offt), reg_bias_data);
        }
        mov(EVEX_compress_addr(rsp, reg_bcast_data_off), reg_bcast_data);
        mov(reg_ptr_scales, EVEX_compress_addr(rsp, reg_ptr_sum_scale_off));
        add(reg_ptr_scales,
            jcp.is_oc_scale * load_loop_blk * jcp.load_block * sizeof(float));
        mov(EVEX_compress_addr(rsp, reg_ptr_sum_scale_off), reg_ptr_scales);
        mov(reg_bcast_data, EVEX_compress_addr(rsp, reg_bcast_data_off));
        add(reg_output_data,
            load_loop_blk * jcp.load_block * jcp.typesize_out);
        add(reg_acc_s32,
            load_loop_blk * jcp.load_block * jcp.typesize_acc);
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
    };
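
    /* Dispatch on the remaining load dimension: jcp.ur bounds how many
     * output-channel blocks fit in the register budget, and the label table
     * below selects smaller load_loop_blk values as reg_load_loop_work
     * shrinks. */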
    const int simd_w = 16;

    Label load_loop_blk[7];

    static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
    const int size_ur_cases_fma = sizeof(ur_cases_fma_expl_bcast);
    const int *ur_cases_fma = ur_cases_fma_expl_bcast;
    const int *ur_cases = ur_cases_fma;
    const int num_ur_cases = (size_ur_cases_fma) / sizeof(*ur_cases);

    for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
        int label_idx = num_ur_cases - ur_idx - 1;
        if (jcp.ur <= ur_cases[ur_idx]) {
            cmp(reg_load_loop_work, simd_w * (label_idx + 1));
            jle(load_loop_blk[label_idx], T_NEAR);
        }
    }

    for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
        if (jcp.ur <= ur_cases[ur_idx]) {
            int label_idx = num_ur_cases - ur_idx - 1;
            L(load_loop_blk[label_idx]);
            {
                if (label_idx == 0) {
                    cmp(reg_load_loop_work, 0);
                    je(load_loop_blk[num_ur_cases], T_NEAR);
                }
                load_loop_body(label_idx + 1);
                if (label_idx - 1 > 0) {
                    cmp(reg_load_loop_work, 2 * label_idx * simd_w);
                    je(load_loop_blk[label_idx - 1], T_NEAR);
                }
                cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
                jge(load_loop_blk[label_idx]);
            }
            for (int idx = label_idx - 1; idx > 0; --idx) {
                cmp(reg_load_loop_work, simd_w * (idx + 1));
                je(load_loop_blk[idx], T_NEAR);
            }
            if (ur_idx < num_ur_cases - 2) {
                cmp(reg_load_loop_work, simd_w);
                jle(load_loop_blk[0], T_NEAR);
            }
        }
    }
    L(load_loop_blk[num_ur_cases]);

    add(rsp, stack_space_needed);

    postamble();
}

bool jit_avx512_core_u8s8s32x_1x1_conv_kernel::post_ops_ok(
        jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
    using namespace primitive_kind;
    const auto &p = attr.post_ops_;

    auto is_relu = [&](int idx) {
        return p.entry_[idx].kind == eltwise
            && p.entry_[idx].eltwise.scale == 1.
            && p.entry_[idx].eltwise.alg == alg_kind::eltwise_relu
            && p.entry_[idx].eltwise.alpha == 0.;
    };

    switch (p.len_) {
    case 0: return true;
    case 1: return true
                && implication(jcp.with_eltwise, p.contain(sum, 0))
                && implication(!jcp.with_eltwise, is_relu(0)
                        || p.contain(sum, 0));
    case 2: return true
                && implication(jcp.with_eltwise, p.contain(sum, 0)
                        && is_relu(1))
                && implication(!jcp.with_eltwise, false
                        || (p.contain(sum, 0) && is_relu(1))
                        || (p.contain(sum, 1) && is_relu(0)));
    case 3: return true
                && jcp.with_eltwise == false
                && (is_relu(0) && p.contain(sum, 1) && is_relu(2));
    default: return false;
    }

    return false;
}
status_t jit_avx512_core_u8s8s32x_1x1_conv_kernel::init_conf(
        jit_1x1_conv_conf_t &jcp, const convolution_desc_t &cd,
        const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d,
        const memory_desc_wrapper &dst_d, const memory_desc_wrapper &bias_d,
        const primitive_attr_t &attr, bool with_relu, float relu_negative_slope,
        int nthreads, bool reduce_src)
{
    if (!mayiuse(avx512_core)) return status::unimplemented;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (src_d.data_type() != data_type::u8
            || weights_d.data_type() != data_type::s8
            || !one_of(dst_d.data_type(),
                data_type::f32, data_type::s32, data_type::s8, data_type::u8))
        return status::unimplemented;
    if (!one_of(weights_d.format(), gOIhw4i16o4i, OIhw4i16o4i))
        return status::unimplemented;

    jcp.ver = ver_avx512_core;
    if (mayiuse(avx512_core_vnni))
        jcp.ver = ver_vnni;

    jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
    jcp.mb = src_d.dims()[0];
    jcp.oc = dst_d.dims()[1] / jcp.ngroups;
    jcp.ic = src_d.dims()[1] / jcp.ngroups;
    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];
    jcp.kh = weights_d.dims()[with_groups + 2];
    jcp.kw = weights_d.dims()[with_groups + 3];
    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];
    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;
    jcp.with_eltwise = with_relu;
    jcp.eltwise_alpha = relu_negative_slope;
    if (!implication(with_relu, relu_negative_slope == 0.))
        return status::unimplemented;

    jcp.os = jcp.oh * jcp.ow;
    jcp.is = jcp.ih * jcp.iw;
    jcp.tr_is = rnd_up(jcp.is, 4);

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    bool args_ok = true
        && src_d.format() == nhwc
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && dst_d.format() == nhwc;
    if (!args_ok) return status::unimplemented;

    const int simd_w = 16;

    args_ok = true
        && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
        && jcp.t_pad == 0 && jcp.l_pad == 0
        && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
        && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.ic_block = jcp.oc_block = simd_w;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_acc = sizeof(int32_t);
    jcp.typesize_bia = jcp.with_bias
        ? types::data_type_size(bias_d.data_type())
        : 0;
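
    /* The rest of init_conf picks the blocking for the load (oc), bcast
     * (spatial) and reduce (ic) dimensions; the constants below are cache-
     * and register-pressure heuristics. */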
    const int SMALL_SPATIAL = 7 * 7;
    const int BIG_REDUCE_DIM = 1024;

    int load_blocking = 0;
    int load_blocking_max = 0;
    int bcast_blocking = 0;
    int bcast_blocking_max = 0;
    int reduce_blocking = 0;
    int reduce_blocking_max = 0;
    jcp.load_grp_count = 1;
    jcp.use_vmovntps = false;

    const int L2_size = get_cache_size(2, true) / sizeof(jcp.typesize_in);
    const int L2_capacity = (L2_size * 3) / 4;

    int size_treshold = 28;
    int max_regs = (jcp.ver == ver_vnni) ? 9 : 8;
    int min_regs = 6;
    jcp.expl_bcast = true;
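
    /* Choose the spatial unroll ur: prefer the largest value in
     * [min_regs, max_regs] that divides the spatial extent evenly; otherwise
     * fall back to the candidate that minimizes the tail of jcp.os. */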
    const int spatial = jcp.oh;
    jcp.ur = 1;
    for (int ur_w = max_regs; ur_w >= min_regs; ur_w--) {
        if ((spatial >= size_treshold && spatial % ur_w == 0)
                || (spatial < size_treshold && jcp.os % ur_w == 0)) {
            jcp.ur = ur_w;
            break;
        }
    }
    if (jcp.ur == 1) {
        jcp.ur = nstl::min(max_regs, jcp.os);
        int os_tail = jcp.os % max_regs;
        for (int i = max_regs; i >= min_regs; i--) {
            int i_tail = jcp.os % i;
            if (i_tail > os_tail || i_tail == 0) {
                jcp.ur = i;
                os_tail = i_tail;
                if (i_tail == 0)
                    break;
            }
        }
    }

    jcp.reduce_dim = jcp.ic;
    jcp.reduce_block = jcp.ic_block;

    jcp.load_dim = jcp.oc;
    jcp.load_block = jcp.oc_block;

    jcp.bcast_dim = jcp.is;

    jcp.bcast_block = jcp.ur;

    jcp.reduce_loop_unroll = jcp.reduce_block;
    jcp.reduce_loop_bcast_step
            = jcp.reduce_loop_unroll * jcp.typesize_in;

    jcp.reduce_loop_load_step
            = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;

    jcp.bcast_loop_output_step = jcp.ur * jcp.load_dim * jcp.typesize_out;
    jcp.bcast_loop_output_substep = -1; // unused
    jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_dim * jcp.typesize_in;
    jcp.bcast_loop_bcast_substep = -1; // unused

    jcp.load_loop_load_step
            = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;

    jcp.load_loop_iter_step = jcp.load_block;

    jcp.loop_order = reduce_src ? loop_blr : loop_lbr;

    int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    reduce_blocking = nb_reduce;
    if (jcp.bcast_dim <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
        reduce_blocking = 64;
    else if (jcp.bcast_dim > SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
        reduce_blocking = 16;
    reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
    reduce_blocking *= jcp.reduce_block;

    bool cmp_reduce = reduce_blocking <= jcp.reduce_dim;
    if (cmp_reduce)
        jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
    load_blocking = jcp.load_dim;

    jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
    jcp.load_grp_count = best_divider(
            nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);

    if (jcp.bcast_dim <= 64 && jcp.load_dim * jcp.reduce_dim >= L2_size) {
        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
    } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
            && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
        jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
        load_blocking = jcp.load_block;
    }

    bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
            div_up(nthreads, jcp.load_grp_count)) * jcp.bcast_block;
    bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
    bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

    int space_for_bcast
            = (L2_capacity - /* kernel_size - */
                2 * jcp.load_block * reduce_blocking
                - jcp.ur * reduce_blocking - 3 * 1024);
    if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
        space_for_bcast /= 2;

    int bcast_in_cache
            = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
    bcast_blocking = nstl::min(
            bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));

    load_blocking_max = load_blocking;
    bcast_blocking_max = bcast_blocking * 3 / 2;
    reduce_blocking_max = reduce_blocking;

    assert(load_blocking);
    assert(load_blocking_max);
    assert(bcast_blocking);
    assert(bcast_blocking_max);
    assert(reduce_blocking);
    assert(reduce_blocking_max);
    assert(load_blocking % jcp.load_block == 0);
    assert(reduce_blocking % jcp.reduce_block == 0);
    assert(load_blocking_max % jcp.load_block == 0);
    assert(reduce_blocking_max % jcp.reduce_block == 0);

    assert(jcp.reduce_loop_unroll % 4 == 0);
    assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);

    assert(jcp.bcast_block % jcp.ur == 0);
    assert(jcp.reduce_dim % jcp.reduce_block == 0);

    jcp.ur_tail = jcp.bcast_dim % jcp.ur;

    jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
    jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
    jcp.nb_load_blocking = load_blocking / jcp.load_block;
    jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
    jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
    jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;

    jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
    jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
    jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;
    assert(utils::implication(!jcp.is_oc_scale, oscales.mask_ == 0));

    return status::success;
}

}
}
}