/*******************************************************************************
* Copyright 2017-2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
20 #include "c_types_map.hpp"
21 #include "memory_tracking.hpp"
22 #include "mkldnn_thread.hpp"
24 #include "type_helpers.hpp"
27 #include "cpu_memory.hpp"
28 #include "cpu_barrier.hpp"
30 #include "jit_uni_1x1_conv_utils.hpp"
31 #include "jit_avx512_common_1x1_conv_kernel.hpp"
33 #define GET_OFF(field) offsetof(jit_1x1_conv_call_s, field)
39 using namespace mkldnn::impl::prop_kind;
40 using namespace mkldnn::impl::memory_format;
41 using namespace mkldnn::impl::utils;
43 using namespace Xbyak;
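
// bcast_loop() steps over the broadcast dimension in chunks of
// jcp.bcast_block, emitting one reduce_loop() per jcp.ur sub-step and a
// separate reduce_loop() over jcp.ur_tail for the remainder; the 4fma
// flavour keeps an extra wraparound path for the final iteration.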
45 void jit_avx512_common_1x1_conv_kernel::bcast_loop(int load_loop_blk)
47 mov(aux1_reg_bcast_data, reg_bcast_data);
48 mov(aux_reg_bcast_data, reg_bcast_data);
50 mov(aux_reg_output_data, reg_output_data);
51 mov(bcast_loop_iter, EVEX_compress_addr(rsp, bcast_loop_work_offt));
53 if (jcp.ver == ver_4fma)
56 Label bcast_loop_wraparound;
58 Label bcast_loop_ur_full;
60 cmp(bcast_loop_iter, jcp.ur);
61 jle(bcast_loop_wraparound, T_NEAR);
64 assert(jcp.bcast_block % jcp.ur == 0);
65 int num_substeps = jcp.bcast_block / jcp.ur;
66 assert(num_substeps > 0 && num_substeps < 10);
67 for (int i = 0; i < num_substeps; i++) {
68 reduce_loop(load_loop_blk, jcp.ur, i, false);
69 if (i < num_substeps - 1) {
70 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
71 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
74 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
75 - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
76 add(aux_reg_output_data, jcp.bcast_loop_output_step
77 - (num_substeps - 1) * jcp.bcast_loop_output_substep);
80 sub(bcast_loop_iter, jcp.bcast_block);
81 cmp(bcast_loop_iter, jcp.bcast_block);
82 jg(bcast_loop, T_NEAR);
85 L(bcast_loop_wraparound);
87 je(bcast_loop_ur_full, T_NEAR);
88 reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
89 jmp(bcast_loop_out, T_NEAR);
91 L(bcast_loop_ur_full);
92 reduce_loop(load_loop_blk, jcp.ur, 0, true);
98 Label bcast_loop_tail;
100 cmp(bcast_loop_iter, jcp.ur);
101 jl(bcast_loop_tail, T_NEAR);
104 assert(jcp.bcast_block % jcp.ur == 0);
105 int num_substeps = jcp.bcast_block / jcp.ur;
106 assert(num_substeps > 0 && num_substeps < 10);
107 for (int i = 0; i < num_substeps; i++) {
108 reduce_loop(load_loop_blk, jcp.ur, i, false);
109 if (i < num_substeps - 1) {
110 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_substep);
111 add(aux_reg_output_data, jcp.bcast_loop_output_substep);
114 add(aux1_reg_bcast_data, jcp.bcast_loop_bcast_step
115 - (num_substeps - 1) * jcp.bcast_loop_bcast_substep);
116 add(aux_reg_output_data, jcp.bcast_loop_output_step
117 - (num_substeps - 1) * jcp.bcast_loop_output_substep);
120 sub(bcast_loop_iter, jcp.bcast_block);
121 cmp(bcast_loop_iter, jcp.bcast_block);
122 jge(bcast_loop, T_NEAR);
127 Label bcast_loop_tail_out;
128 cmp(bcast_loop_iter, 0);
129 jz(bcast_loop_tail_out, T_NEAR);
130 reduce_loop(load_loop_blk, jcp.ur_tail, 0, true);
131 L(bcast_loop_tail_out);
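
// reduce_loop() emits the innermost GEMM-like block: ur * load_loop_blk zmm
// accumulators (vreg_accum), the zmm registers above them used for loads
// (vreg_load), and the reduce dimension unrolled by jcp.reduce_loop_unroll
// with software prefetches interleaved into the fma stream.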
136 void jit_avx512_common_1x1_conv_kernel::reduce_loop(int load_loop_blk,
137 int ur, int substep, bool wraparound)
139 auto vreg_load = [=](int i_load, int i_fma) {
140 return Zmm(utils::rnd_up(ur * load_loop_blk, jcp.fma_step)
141 + jcp.fma_step * i_load + i_fma);
144 auto vreg_accum = [=](int i_load, int i_ur) {
145 return Zmm(i_ur + i_load * ur);
148 auto bias_ptr = [=](int i_load) {
149 return EVEX_compress_addr(reg_bias_data,
150 jcp.typesize_out * jcp.oc_block * i_load);
    auto bcast_ptr = [=](int i_reduce, int i_ur, bool bcast) {
        assert(i_ur < jcp.ur);
        assert(i_reduce <= jcp.reduce_loop_unroll);
        int offt;
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                   backward_data)) {
            assert(jcp.reduce_loop_unroll == jcp.reduce_block);
            offt = (i_reduce == jcp.reduce_loop_unroll)
                    ? (jcp.bcast_dim + i_ur) * jcp.reduce_loop_unroll
                    : i_ur * jcp.reduce_loop_unroll + i_reduce;
        } else {
            if (jcp.transpose_src) {
                const int reduce_group = i_reduce / 4;
                const int reduce_shift = i_reduce % 4;
                offt = 4 * (reduce_group * jcp.ic_block + i_ur) + reduce_shift;
            } else
                offt = i_reduce * jcp.ic_block + i_ur;
        }
        return EVEX_compress_addr(aux_reg_bcast_data, jcp.typesize_in * offt,
                bcast);
    };
    auto load_ptr = [=](int i_reduce, int i_load) {
        int offt;
        int u0 = i_reduce % jcp.reduce_loop_unroll;
        int u1 = i_reduce / jcp.reduce_loop_unroll;
        if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni)
            offt = (i_load * jcp.reduce_block + u0) * jcp.load_block;
        else
            offt = (i_load * jcp.reduce_dim + u0) * jcp.load_block;
        return EVEX_compress_addr(aux_reg_load_data,
                u1 * jcp.reduce_loop_load_step
                + jcp.typesize_in * offt);
    };
    auto output_ptr = [=](int i_load, int i_ur) {
        if (one_of(jcp.prop_kind, forward_training, forward_inference,
                   backward_data))
            return EVEX_compress_addr(aux_reg_output_data,
                    (i_load * jcp.bcast_dim + i_ur) * jcp.load_block
                    * jcp.typesize_out);
        else
            return ptr[aux_reg_output_data +
                       (i_load
                            ? reg_output_stride * i_load
                            : 0) // TODO: Xbyak should allow 0 scale
                       + jcp.typesize_out * jcp.load_block * i_ur];
    };
208 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
209 for (int i_ur = 0; i_ur < ur; ++i_ur) {
210 mic_prefetcht1(output_ptr(i_load, i_ur));
216 && one_of(jcp.prop_kind, forward_training, forward_inference)) {
217 test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
218 jz(init_zero, T_NEAR);
            for (int i_load = 0; i_load < load_loop_blk; i_load++)
                for (int i_ur = 0; i_ur < ur; ++i_ur)
                    vmovups(vreg_accum(i_load, i_ur), bias_ptr(i_load));
            jmp(init_done, T_NEAR);
        }

        L(init_zero);
        for (int i_load = 0; i_load < load_loop_blk; ++i_load)
            for (int i_ur = 0; i_ur < ur; ++i_ur) {
                auto r = vreg_accum(i_load, i_ur);
                vpxord(r, r, r);
            }
        L(init_done);
    };
    auto vadd = [=](const Xmm& x1, const Xmm& x2, const Operand& op) {
        if (jcp.ver == ver_4vnni)
            vpaddd(x1, x2, op);
        else
            vaddps(x1, x2, op);
    };
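
    // Store path: unless this is the first reduction chunk
    // (FLAG_REDUCE_FIRST), the accumulators are added to what is already in
    // memory; eltwise/depthwise post ops are applied only on the last chunk
    // (FLAG_REDUCE_LAST).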
246 test(reg_reduce_pos_flag, FLAG_REDUCE_FIRST);
247 jnz(store_noadd, T_NEAR);
250 for (int i_ur = 0; i_ur < ur; ++i_ur)
251 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
252 auto r = vreg_accum(i_load, i_ur);
253 vadd(r, r, output_ptr(i_load, i_ur));
258 Label store_nopostproc;
259 test(reg_reduce_pos_flag, FLAG_REDUCE_LAST);
260 jz(store_nopostproc, T_NEAR);
262 int eltwise_inj_idx = 0;
263 int depthwise_inj_idx = 0;
264 const auto &p = attr_.post_ops_;
266 for (int i = 0; i < p.len_; i++) {
267 auto& post_op = p.entry_[i];
268 if (post_op.is_eltwise()) {
                if (jcp.ver == ver_4vnni) {
                    zmm_t zmm_zero = vreg_bcast;
                    vpxord(zmm_zero, zmm_zero, zmm_zero);

                    for (int i_ur = 0; i_ur < ur; ++i_ur) {
                        for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
                            Zmm zmm = vreg_accum(i_load, i_ur);
                            vpcmpd(k1, zmm, zmm_zero, _cmp_lt_os);
                            vpmulld(zmm | k1, zmm, zmm_zero);
                        }
                    }
                } else {
                    eltwise_injectors[eltwise_inj_idx]->compute_vector_range(0, ur * load_loop_blk);
                }

                eltwise_inj_idx++;
284 } else if (post_op.is_depthwise()) {
285 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
286 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
288 add(reg_d_weights, reg_oc_off);
289 add(reg_d_bias, reg_oc_off);
                for (int j = 0; j < load_loop_blk; ++j) {
                    int start_idx = vreg_accum(j, 0).getIdx();
                    int end_idx = start_idx + ur;

                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                            start_idx, end_idx, reg_d_weights, reg_d_bias);

                    add(reg_d_weights, jcp.oc_block * sizeof(float));
                    add(reg_d_bias, jcp.oc_block * sizeof(float));
                }

                depthwise_inj_idx++;
            }
        }

        L(store_nopostproc);
        auto store_output = [=](bool output_is_aligned) {
            for (int i_ur = 0; i_ur < ur; ++i_ur)
                for (int i_load = 0; i_load < load_loop_blk; ++i_load)
                    if (output_is_aligned && jcp.use_vmovntps)
                        vmovntps(output_ptr(i_load, i_ur),
                                vreg_accum(i_load, i_ur));
                    else
                        vmovups(output_ptr(i_load, i_ur),
                                vreg_accum(i_load, i_ur));
        };

        Label unaligned_store, end_store;
        test(aux_reg_output_data, cpu_isa_traits<avx512_common>::vlen - 1);
        jnz(unaligned_store, T_NEAR);
        store_output(true);
        jmp(end_store, T_NEAR);
        L(unaligned_store); {
            store_output(false);
        }
        L(end_store);
    };
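
    // prefetch_callback() spreads software prefetches over the fma stream:
    // bcast (input) prefetches get the first half of the operations and the
    // highest priority; the remaining operations interleave weight
    // prefetches into L1/L2 and, when non-temporal stores are not used,
    // prefetches of the output tile.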
330 auto prefetch_callback = [=](int ur, int i_reduce, int i_ur, int i_load,
331 bool last_block, bool wraparound, int reduce_step)
333 bool pf_ker_l1 = true;
334 bool pf_ker_l2 = wraparound;
335 int n_ops = (jcp.reduce_loop_unroll / reduce_step) * ur * load_loop_blk;
336 int i_op = (i_reduce / reduce_step) * ur * load_loop_blk +
337 i_ur * load_loop_blk + i_load;
339 int n_pf_ker_l1 = pf_ker_l1 ? jcp.reduce_block : 0;
340 int n_pf_ker_l2 = pf_ker_l2 && wraparound ? jcp.reduce_block : 0;
341 int n_pf_out_l1 = jcp.use_vmovntps ? 0 : ur;
343 int pf_inp_ops = n_ops / 2; // # of operations during which to pf input
345 if (jcp.prop_kind == backward_weights)
346 pf_inp_trigger = nstl::max(1, pf_inp_ops / jcp.reduce_block);
348 pf_inp_trigger = nstl::max(1, pf_inp_ops / ur);
351 load_loop_blk * (n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1);
352 int n_other_pf_ops = n_ops - pf_inp_ops;
354 = n_other_pf ? nstl::max(1, n_other_pf_ops / n_other_pf) : 0;
356 if (i_op < pf_inp_ops && i_op % pf_inp_trigger == 0) {
357 // input prefetches have the highest priority b/c the
358 // first iteration of the kernel block touches all the
            int i_pf = i_op / pf_inp_trigger;
            auto pf_reg = wraparound && last_block
                                  ? reg_bcast_data
                                  : (last_block ? aux1_reg_bcast_data
                                                : aux_reg_bcast_data);
            int offt = i_pf;
            if (jcp.prop_kind == backward_weights) {
                offt += wraparound && last_block
                                    ? 0
                                    : (last_block ? jcp.is : jcp.reduce_block);
                offt *= jcp.bcast_block;
            } else {
                offt += wraparound && last_block
                                    ? 0
                                    : (last_block ? jcp.ur : jcp.bcast_dim);
                offt *= jcp.reduce_block;
            }
            mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
378 } else if (i_op >= pf_inp_ops && n_other_pf) {
379 // remaining prefetches are spread among the rest of the
380 // operations; prefetches for output take priority
381 // TODO: spread L2 prefetches among L1 prefetches
383 if (i_op % other_pf_trigger == 0) {
384 int i_pf = i_op / (load_loop_blk * other_pf_trigger);
385 if (i_pf < n_pf_ker_l2) {
                    int offt = (i_pf + (i_load + 1) * jcp.reduce_dim)
                            * jcp.load_block;
                    if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni)
                        offt = (i_pf + (i_load + 1) * jcp.reduce_block)
                                * jcp.load_block;

                    mic_prefetcht1(ptr[aux_reg_load_data
                                    + offt * jcp.typesize_in]);
                } else if (i_pf < n_pf_ker_l2 + n_pf_ker_l1) {
                    i_pf -= n_pf_ker_l2;
                    auto pf_reg = last_block ? reg_load_data
                                             : aux_reg_load_data;
                    int offt = (i_pf + i_load * jcp.reduce_dim
                                       + (last_block
                                            ? (wraparound ? jcp.reduce_dim : 0)
                                            : jcp.reduce_block))
                            * jcp.load_block;
                    mic_prefetcht0(ptr[pf_reg + offt * jcp.typesize_in]);
404 } else if (i_pf < n_pf_ker_l1 + n_pf_ker_l2 + n_pf_out_l1) {
405 i_pf -= n_pf_ker_l1 + n_pf_ker_l2;
406 int offt = i_pf * jcp.load_block;
407 mic_prefetcht0(ptr[aux_reg_output_data
408 + offt * jcp.typesize_out]);
414 auto fma_block = [=](bool last_block) {
415 assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
        int reduce_step = jcp.fma_step;
        if (jcp.ver == ver_4vnni)
            reduce_step *= 2;
421 for (int i_reduce = 0; i_reduce < jcp.reduce_loop_unroll;
422 i_reduce += reduce_step) {
423 int load_scale = (jcp.ver == ver_4vnni) ? 2 : 1;
424 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
425 // if transposed input data used and if spatial size is
426 // not divided by transpose step (4) then for last reduce step
427 // we should load only needed load_registers data
428 // and clear remaining
429 if (jcp.transpose_src && jcp.is % jcp.fma_step && last_block
430 && i_reduce == jcp.reduce_loop_unroll - reduce_step) {
433 test(reg_reduce_pos_flag, FLAG_SP_LAST);
434 jz(load_all, T_NEAR);
436 const int n_loads = jcp.is % jcp.fma_step;
437 for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
439 vmovups(vreg_load(i_load, i_fma),
440 load_ptr(i_reduce + load_scale * i_fma,
443 vpxord(vreg_load(i_load, i_fma),
444 vreg_load(i_load, i_fma),
445 vreg_load(i_load, i_fma));
450 for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
451 vmovups(vreg_load(i_load, i_fma),
452 load_ptr(i_reduce + load_scale * i_fma, i_load));
456 for (int i_fma = 0; i_fma < jcp.fma_step; i_fma++) {
457 vmovups(vreg_load(i_load, i_fma),
459 + load_scale * i_fma,
465 for (int i_ur = 0; i_ur < ur; ++i_ur) {
466 if (jcp.ver == ver_avx512_core && jcp.expl_bcast
467 && load_loop_blk > 1)
468 vbroadcastss(vreg_bcast, bcast_ptr(i_reduce, i_ur, false));
469 for (int i_load = 0; i_load < load_loop_blk; ++i_load) {
470 if (jcp.ver == ver_4fma)
471 v4fmaddps(vreg_accum(i_load, i_ur),
472 vreg_load(i_load, 0),
473 bcast_ptr(i_reduce, i_ur, false));
474 else if (jcp.ver == ver_4vnni)
475 vp4dpwssd(vreg_accum(i_load, i_ur),
476 vreg_load(i_load, 0),
477 bcast_ptr(i_reduce, i_ur, false));
478 else if (jcp.ver == ver_avx512_core && jcp.expl_bcast
479 && load_loop_blk > 1)
480 vfmadd231ps(vreg_accum(i_load, i_ur),
                            vreg_load(i_load, 0), vreg_bcast);
                else
                    vfmadd231ps(vreg_accum(i_load, i_ur),
484 vreg_load(i_load, 0),
485 bcast_ptr(i_reduce, i_ur, true));
486 prefetch_callback(ur, i_reduce, i_ur, i_load,
487 last_block, wraparound, reduce_step);
    Label reduce_loop;
    Label reduce_loop_tail;

    mov(aux_reg_load_data, reg_load_data);
    mov(aux_reg_bcast_data, aux1_reg_bcast_data);
    init();

    mov(reduce_loop_iter, reg_reduce_loop_work);
    sub(reduce_loop_iter, jcp.reduce_loop_unroll);
    jle(reduce_loop_tail, T_NEAR);

    L(reduce_loop); {
        fma_block(false);
        add(aux_reg_bcast_data, jcp.reduce_loop_bcast_step);
        add(aux_reg_load_data, jcp.reduce_loop_load_step);
        sub(reduce_loop_iter, jcp.reduce_loop_unroll);
        jg(reduce_loop, T_NEAR);
    }

    L(reduce_loop_tail);
    fma_block(true);
    store();
}
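
// generate() emits the complete kernel: it builds the post-op injectors,
// loads the call arguments from jit_1x1_conv_call_s (via GET_OFF), and then
// dispatches at runtime to one of several load_loop_body() specializations
// depending on how many 16-wide load blocks remain in load_dim.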
void jit_avx512_common_1x1_conv_kernel::generate()
{
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<avx512_common>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<avx512_common>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    preamble();
540 mov(reg_bcast_data, ptr[param1 + GET_OFF(bcast_data)]);
541 mov(reg_load_data, ptr[param1 + GET_OFF(load_data)]);
542 mov(reg_output_data, ptr[param1 + GET_OFF(output_data)]);
544 sub(rsp, stack_space_needed);
547 mov(reg_bias_data, ptr[param1 + GET_OFF(bias_data)]);
549 mov(reg_load_loop_work, ptr[param1 + GET_OFF(load_dim)]);
550 mov(reg_bcast_loop_work, ptr[param1 + GET_OFF(bcast_dim)]);
551 mov(EVEX_compress_addr(rsp, bcast_loop_work_offt), reg_bcast_loop_work);
552 mov(reg_reduce_loop_work, ptr[param1 + GET_OFF(reduce_dim)]);
553 mov(reg_reduce_pos_flag, ptr[param1 + GET_OFF(first_last_flag)]);
554 if (one_of(jcp.prop_kind, forward_training, forward_inference))
555 mov(reg_relu_ns, reinterpret_cast<size_t>(&jcp.eltwise.alpha));
556 if (jcp.prop_kind == backward_weights)
557 mov(reg_output_stride, ptr[param1 + GET_OFF(output_stride)]);
558 mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
    auto load_loop_body = [=](int load_loop_blk) {
        bcast_loop(load_loop_blk);
        add(reg_load_data, load_loop_blk * jcp.load_loop_load_step);
        switch (jcp.prop_kind) {
        case forward_training:
        case forward_inference:
            add(reg_bias_data,
                load_loop_blk * jcp.load_block * jcp.typesize_out);
            add(reg_output_data,
                load_loop_blk * jcp.bcast_dim * jcp.load_block *
                    jcp.typesize_out);
            break;
        case backward_data:
            add(reg_output_data,
                load_loop_blk * jcp.bcast_dim * jcp.load_block *
                    jcp.typesize_out);
            break;
        case backward_weights:
            for (int i_load = 0; i_load < load_loop_blk; i_load++)
                add(reg_output_data, reg_output_stride);
            break;
        default:
            assert(!"invalid prop_kind");
        }
        sub(reg_load_loop_work, load_loop_blk * jcp.load_loop_iter_step);
        add(reg_oc_off, load_loop_blk * jcp.oc_block * jcp.typesize_out);
    };
588 const int simd_w = 16;
590 Label load_loop_blk[7];
592 static const int ur_cases_fma_embd_bcast[] = { 2, 4, 5, 8, 14, 32 };
593 static const int ur_cases_fma_expl_bcast[] = { 2, 5, 6, 9, 14, 32 };
594 static const int ur_cases_4fma[] = { 2, 4, 6, 12, 32 };
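
    // Dispatch tables: the larger jcp.ur is, the fewer 16-wide load blocks a
    // single load_loop iteration may handle, so that the ur * load_loop_blk
    // accumulators plus the load registers still fit into the 32 zmm
    // register file.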
596 const int size_ur_cases_fma
597 = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
598 sizeof(ur_cases_fma_expl_bcast) :
599 sizeof(ur_cases_fma_embd_bcast);
600 const int size_ur_cases_4fma = sizeof(ur_cases_4fma);
602 const int *ur_cases_fma = (jcp.ver == ver_avx512_core && jcp.expl_bcast) ?
603 ur_cases_fma_expl_bcast :
604 ur_cases_fma_embd_bcast;
605 const int *ur_cases = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni)
606 ? ur_cases_4fma : ur_cases_fma;
    const int num_ur_cases = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni ?
            size_ur_cases_4fma : size_ur_cases_fma) / sizeof(*ur_cases);
612 for (int ur_idx = num_ur_cases - 1; ur_idx > 0; ur_idx--) {
613 int label_idx = num_ur_cases - ur_idx - 1;
614 if (jcp.ur <= ur_cases[ur_idx]) {
615 cmp(reg_load_loop_work, simd_w * (label_idx + 1));
616 jle(load_loop_blk[label_idx], T_NEAR);
620 for (int ur_idx = 0; ur_idx < num_ur_cases; ur_idx++) {
621 if (jcp.ur <= ur_cases[ur_idx]) {
622 int label_idx = num_ur_cases - ur_idx - 1;
623 L(load_loop_blk[label_idx]);
625 if (label_idx == 0) {
626 cmp(reg_load_loop_work, 0);
627 je(load_loop_blk[num_ur_cases], T_NEAR);
629 load_loop_body(label_idx + 1);
630 if (label_idx - 1 > 0) {
631 cmp(reg_load_loop_work, 2 * label_idx * simd_w);
632 je(load_loop_blk[label_idx - 1], T_NEAR);
634 cmp(reg_load_loop_work, (label_idx + 1) * simd_w);
635 jge(load_loop_blk[label_idx]);
637 for (int idx = label_idx - 1; idx > 0; --idx) {
638 cmp(reg_load_loop_work, simd_w * (idx + 1));
639 je(load_loop_blk[idx], T_NEAR);
641 if (ur_idx < num_ur_cases - 2) {
642 cmp(reg_load_loop_work, simd_w);
643 jle(load_loop_blk[0], T_NEAR);
    L(load_loop_blk[num_ur_cases]);

    add(rsp, stack_space_needed);

    postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();
}
657 bool jit_avx512_common_1x1_conv_kernel::post_ops_ok(
658 jit_1x1_conv_conf_t &jcp, const primitive_attr_t &attr) {
659 const auto &p = attr.post_ops_;
661 auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
662 auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
663 auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(); };
664 auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };
    switch (p.len_) {
    case 0: return true;
    case 1: return is_simple(0) || is_sum(0);
    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_simple(1));
    case 3: return is_sum(0) && is_simple(1) && is_simple(2);
    default: return false;
    }
}
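
/* init_conf() fills jit_1x1_conv_conf_t: it validates the 1x1 restrictions
   and memory formats, picks the code generation flavour (ver_4vnni,
   ver_4fma, ver_avx512_core or ver_fma), and derives the register tiling
   (jcp.ur) plus the load/bcast/reduce blocking used by the driver. */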
677 status_t jit_avx512_common_1x1_conv_kernel::init_conf(jit_1x1_conv_conf_t &jcp,
678 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
679 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
680 const primitive_attr_t &attr, int nthreads, bool reduce_src) {
681 if (!mayiuse(avx512_common)) return status::unimplemented;
683 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
684 const int simd_w = cpu_isa_traits<avx512_common>::vlen / sizeof(float);
685 const int ndims = src_d.ndims();
687 jcp.prop_kind = cd.prop_kind;
689 jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
690 jcp.mb = src_d.dims()[0];
692 jcp.oc_without_padding = dst_d.dims()[1] / jcp.ngroups;
693 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
694 jcp.ic = src_d.dims()[1] / jcp.ngroups;
    bool ok_to_pad_channels = true
        && jcp.ngroups == 1
        && src_d.data_type() == data_type::f32;
699 if (ok_to_pad_channels) {
700 jcp.oc = rnd_up(jcp.oc, simd_w);
701 jcp.ic = rnd_up(jcp.ic, simd_w);
704 jcp.ih = (ndims == 3) ? 1 : src_d.dims()[2];
705 jcp.iw = src_d.dims()[ndims - 1];
706 jcp.oh = (ndims == 3) ? 1 : dst_d.dims()[2];
707 jcp.ow = dst_d.dims()[ndims - 1];
709 jcp.kh = (ndims == 3) ? 1 : weights_d.dims()[with_groups + 2];
710 jcp.kw = weights_d.dims()[with_groups + ndims - 1];
712 jcp.t_pad = (ndims == 3) ? 0 : cd.padding[0][0];
713 jcp.l_pad = cd.padding[0][ndims - 3];
715 jcp.stride_h = (ndims == 3) ? 1 : cd.strides[0];
716 jcp.stride_w = cd.strides[ndims - 3];
718 jcp.src_fmt = src_d.format();
719 jcp.with_bias = pick_by_prop_kind(jcp.prop_kind, cd.bias_desc.format,
720 memory_format::undef, cd.diff_bias_desc.format)
721 != memory_format::undef;
723 jcp.os = jcp.oh * jcp.ow;
724 jcp.is = jcp.ih * jcp.iw;
725 jcp.tr_is = rnd_up(jcp.is, 4);
727 if (!post_ops_ok(jcp, attr))
728 return status::unimplemented;
730 const auto &p = attr.post_ops_;
731 jcp.with_sum = p.find(primitive_kind::sum) != -1;
732 const int eltwise_ind = p.find(primitive_kind::eltwise);
733 jcp.with_eltwise = eltwise_ind != -1;
734 if (jcp.with_eltwise) {
735 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
        if (dst_d.data_type() == data_type::s32) return status::unimplemented;
    }

    bool args_ok = true
        && everyone_is(pick(ndims - 3, nCw16c, nChw16c), src_d.format(),
            dst_d.format())
        && one_of(cd.bias_desc.format, memory_format::undef, any, x);
    if (!args_ok) return status::unimplemented;

    args_ok = true
        && jcp.oc % simd_w == 0 && jcp.ic % simd_w == 0
        && jcp.t_pad == 0 && jcp.l_pad == 0
        && jcp.stride_w == 1 && jcp.stride_h == 1 // TODO: support some strides
        && jcp.kh == 1 && jcp.kw == 1;
    if (!args_ok) return status::unimplemented;
753 jcp.ic_block = jcp.oc_block = simd_w;
754 jcp.transpose_src = false;
756 if (mayiuse(avx512_mic_4ops)
757 && ((one_of(jcp.prop_kind, forward_training, forward_inference)
758 && src_d.data_type() == data_type::s16
759 && weights_d.data_type() == data_type::s16
760 && dst_d.data_type() == data_type::s32)
761 || (jcp.prop_kind == backward_data
762 && src_d.data_type() == data_type::s32
763 && weights_d.data_type() == data_type::s16
764 && dst_d.data_type() == data_type::s16)))
766 const int is_bwd_d = jcp.prop_kind == backward_data;
767 memory_format_t weights_format = with_groups
768 ? pick(2 * ndims - 6 + is_bwd_d, gOIw8i16o2i, gOIw8o16i2o,
769 gOIhw8i16o2i, gOIhw8o16i2o)
770 : pick(2 * ndims - 6 + is_bwd_d, OIw8i16o2i, OIw8o16i2o,
771 OIhw8i16o2i, OIhw8o16i2o);
773 if (weights_d.format() != weights_format)
774 return status::unimplemented;
        jcp.ver = ver_4vnni;
        jcp.fma_step = 4;
        jcp.typesize_in = sizeof(prec_traits<data_type::s16>::type);
        jcp.typesize_out = sizeof(prec_traits<data_type::s32>::type);
781 else if (everyone_is(data_type::f32, src_d.data_type(),
782 weights_d.data_type(), dst_d.data_type()))
784 const int is_bwd_d = jcp.prop_kind == backward_data;
785 memory_format_t weights_format = with_groups
786 ? pick(2 * ndims - 6 + is_bwd_d, gOIw16i16o, gIOw16o16i,
787 gOIhw16i16o, gIOhw16o16i)
788 : pick(2 * ndims - 6 + is_bwd_d, OIw16i16o, IOw16o16i,
789 OIhw16i16o, IOhw16o16i);
791 if (weights_d.format() != weights_format)
792 return status::unimplemented;
        if (jcp.prop_kind != backward_weights && mayiuse(avx512_mic_4ops) &&
            ((jcp.prop_kind == backward_data) ? jcp.oc_block : jcp.ic_block) % 4
            == 0) {
            jcp.ver = ver_4fma;
            jcp.fma_step = 4;
        } else if (jcp.prop_kind == backward_weights && mayiuse(avx512_mic_4ops)
                && !reduce_src
                /* Heuristic condition for relation of src size to oc. Otherwise
                   the src transposition overhead exceeds the benefit from 4fma.
                */
                && ((jcp.is * jcp.ic) / jcp.oc <= 2048)
                && mkldnn_thr_syncable()) {
            jcp.transpose_src = true;
            jcp.ver = ver_4fma;
            jcp.fma_step = 4;
        } else {
            jcp.ver = (mayiuse(avx512_core)) ? ver_avx512_core : ver_fma;
            jcp.fma_step = 1;
        }
        jcp.typesize_in = sizeof(prec_traits<data_type::f32>::type);
        jcp.typesize_out = sizeof(prec_traits<data_type::f32>::type);
    } else {
        return status::unimplemented;
    }
    /* once all the formats are set, check the padding consistency */
    args_ok = true
        && jcp.ic <= src_d.blocking_desc().padding_dims[1]
823 && jcp.oc <= dst_d.blocking_desc().padding_dims[1]
824 && jcp.ic <= weights_d.blocking_desc().padding_dims[with_groups + 1]
825 && jcp.oc <= weights_d.blocking_desc().padding_dims[with_groups + 0];
826 if (!args_ok) return status::unimplemented;
828 const int SMALL_SPATIAL = 10;
829 const int BIG_SPATIAL = 28;
830 const int BIG_REDUCE_DIM = 1024;
831 const int BIG_LOAD_DIM = 256;
833 int load_blocking{ 0 };
834 int load_blocking_max{ 0 };
835 int bcast_blocking{ 0 };
836 int bcast_blocking_max{ 0 };
837 int reduce_blocking{ 0 };
838 int reduce_blocking_max{ 0 };
840 jcp.load_grp_count = 1;
842 const int L1_capacity = get_cache_size(1, true) / sizeof(float);
843 const int L2_size = get_cache_size(2, true) / sizeof(float);
844 const int L2_capacity = (L2_size * 3) / 4;
846 if (one_of(jcp.prop_kind, forward_training, forward_inference,
848 if (one_of(jcp.prop_kind, forward_training, forward_inference)) {
849 jcp.reduce_dim = jcp.ic;
850 jcp.reduce_block = jcp.ic_block;
852 jcp.load_dim = jcp.oc;
853 jcp.load_block = jcp.oc_block;
855 jcp.bcast_dim = jcp.is;
857 jcp.reduce_dim = jcp.oc;
858 jcp.reduce_block = jcp.oc_block;
860 jcp.load_dim = jcp.ic;
861 jcp.load_block = jcp.ic_block;
863 jcp.bcast_dim = jcp.os;
865 jcp.reduce_loop_unroll = jcp.reduce_block;
866 jcp.reduce_loop_bcast_step
867 = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
869 if (jcp.prop_kind == backward_data && jcp.ver == ver_4vnni) {
870 jcp.reduce_loop_load_step
871 = jcp.reduce_loop_unroll * jcp.ic * jcp.typesize_in;
872 jcp.load_loop_load_step
873 = jcp.oc_block * jcp.ic_block * jcp.typesize_in;
875 jcp.reduce_loop_load_step
876 = jcp.reduce_loop_unroll * jcp.load_block * jcp.typesize_in;
877 jcp.load_loop_load_step
878 = jcp.reduce_dim * jcp.load_block * jcp.typesize_in;
881 // adjusting registry blocking
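    // jcp.ur is the number of broadcast rows kept live in zmm accumulators.
    // The search below picks the largest ur in [min_regs, max_regs] that
    // divides the spatial size (or jcp.os) evenly; if nothing divides
    // evenly, the fallback keeps the candidate that leaves the fullest tail
    // iteration.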
882 int max_regs, min_regs, size_treshold, ur_step;
884 = (one_of(jcp.prop_kind, forward_training, forward_inference)) ?
887 if (jcp.ver == ver_avx512_core && (8 * jcp.mb) / nthreads >= 1) {
892 jcp.expl_bcast = true;
894 if (jcp.load_dim > 128 && jcp.load_dim < BIG_LOAD_DIM
895 && spatial > SMALL_SPATIAL && spatial < BIG_SPATIAL) {
900 bool is4ops = (jcp.ver == ver_4fma || jcp.ver == ver_4vnni);
902 max_regs = is4ops ? 28 : 30;
904 size_treshold = is4ops ? 28 : 14;
905 ur_step = is4ops ? 4 : 1;
906 jcp.expl_bcast = false;
907 jcp.use_vmovntps = true;
910 for (int ur_w = max_regs; ur_w >= min_regs; ur_w -= ur_step) {
911 if ((spatial >= size_treshold && spatial % ur_w == 0)
912 || (spatial < size_treshold && jcp.os % ur_w == 0)) {
918 jcp.ur = nstl::min(max_regs, jcp.os);
919 int os_tail = jcp.os % max_regs;
920 for (int i = max_regs; i >= min_regs; i -= ur_step) {
921 int i_tail = jcp.os % i;
922 if (i_tail > os_tail || i_tail == 0) {
931 jcp.reduce_loop_unroll = jcp.reduce_block;
932 jcp.reduce_loop_bcast_step
933 = jcp.reduce_loop_unroll * jcp.bcast_dim * jcp.typesize_in;
935 jcp.bcast_block = jcp.ur;
937 jcp.bcast_loop_output_step = jcp.ur * jcp.load_block * jcp.typesize_out;
938 jcp.bcast_loop_output_substep = -1; // unused
939 jcp.bcast_loop_bcast_step = jcp.ur * jcp.reduce_block * jcp.typesize_in;
940 jcp.bcast_loop_bcast_substep = -1; // unused
942 jcp.load_loop_iter_step = jcp.load_block;
944 if (jcp.prop_kind == backward_data)
945 jcp.loop_order = loop_lbr;
947 jcp.loop_order = reduce_src ? loop_blr : loop_lbr;
949 int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
950 int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
951 int nb_load = div_up(jcp.load_dim, jcp.load_block);
953 if (jcp.ver == ver_avx512_core && jcp.expl_bcast) {
954 if (jcp.load_dim <= BIG_LOAD_DIM && spatial > SMALL_SPATIAL
955 && spatial < BIG_SPATIAL)
956 reduce_blocking = nstl::min(jcp.reduce_dim, 80);
957 else if (spatial > SMALL_SPATIAL)
958 reduce_blocking = nstl::min(jcp.reduce_dim, 512);
960 reduce_blocking = nstl::min(jcp.reduce_dim, 256);
962 if ((jcp.mb > 28 && spatial >= 28)
963 || (jcp.mb > 112 && spatial >= 17))
964 jcp.use_vmovntps = true;
966 jcp.use_vmovntps = false;
969 reduce_blocking = nb_reduce;
970 if (spatial <= SMALL_SPATIAL && jcp.reduce_dim >= BIG_REDUCE_DIM)
971 reduce_blocking = 16;
        else if (spatial > SMALL_SPATIAL
                && jcp.reduce_dim >= BIG_REDUCE_DIM)
            reduce_blocking = 8;

        reduce_blocking = best_divider(nb_reduce, 1, reduce_blocking, true);
976 reduce_blocking *= jcp.reduce_block;
979 // Check input data cache aliasing.
980 // For other ISA constants may be updated.
981 // 64 * 1024 is chosen due to 1MB L2 16-way cache.
982 // 7 is empirical value. It is about half of 16.
983 // So we leave about half of the set for other data - weights, dst
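    // Example: for f32 input typesize_in == 4, so way_size == 64 * 1024 / 4
    // == 16384 elements, and reduce_blocking only gets clipped once
    // bcast_dim * reduce_blocking exceeds 7 * 16384 elements.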
    int way_size = (64 * 1024) / jcp.typesize_in;
    int max_hits = 7;
    if (jcp.bcast_dim * reduce_blocking > way_size * max_hits) {
987 int nrb = reduce_blocking / simd_w;
988 int sp = jcp.bcast_dim;
989 int wl = way_size / simd_w;
990 for (int start_off = 0; start_off < jcp.ur; start_off++) {
            for (int off = start_off, hits = 0; off < sp * nrb; off += wl) {
                if (off % sp >= jcp.ur || ++hits < max_hits)
                    continue;
                int max_r_blocking = simd_w * nstl::max(1, (off + wl) / sp);
                reduce_blocking
                        = nstl::min(reduce_blocking, max_r_blocking);
                break;
            }
        }
    }
1002 if (reduce_blocking < jcp.reduce_dim) {
1003 jcp.use_vmovntps = false;
1004 if (jcp.prop_kind == backward_data)
1005 jcp.loop_order = reduce_src ? loop_lbr : loop_rlb;
1007 jcp.loop_order = reduce_src ? loop_rbl : loop_rlb;
1009 load_blocking = jcp.load_dim;
1011 int load_size = jcp.load_dim * jcp.reduce_dim;
1012 int bcast_size = jcp.mb * jcp.ngroups * jcp.bcast_dim * jcp.reduce_dim;
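
    // Thread grouping heuristic: when the minibatch alone cannot occupy all
    // threads but there are more load*bcast jobs than threads, the threads
    // are split into jcp.load_grp_count groups; calc_job_cost() below scores
    // a candidate split by a weighted sum of compute volume and data volume
    // per thread.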
1014 if (jcp.ver == ver_avx512_core && nthreads <= 28 && jcp.mb < nthreads
1015 && nb_load * nb_bcast > nthreads) {
1016 // Some heuristic here
1017 float calc_koef = 0.01, best_cost = FLT_MAX;
1018 int n_lgc = nthreads;
1019 float ratio = (float)load_size / (float)bcast_size;
1020 int best_lgc = ratio > 1 ? n_lgc : 1;
1021 auto calc_job_cost = [&](int lb, int tg, float mem_k) {
1022 int bb_size = jcp.mb * div_up(nb_bcast, tg);
1023 float calc_size = (float)(bb_size * jcp.ur)
1024 * (lb * jcp.load_block) * jcp.reduce_dim;
            float mem_size = (float)(bb_size * jcp.ur + lb * jcp.load_block)
                    * jcp.reduce_dim;
            return calc_koef * calc_size + mem_k * mem_size;
        };
1029 for (int lgc, ilgc = 0; ilgc < n_lgc; ilgc++) {
1030 lgc = ratio > 1 ? n_lgc - ilgc : ilgc + 1;
1031 int min_lb = nb_load / lgc;
1032 int max_lb = div_up(nb_load, lgc);
1033 int min_tg = nthreads / lgc;
1034 int max_tg = div_up(nthreads, lgc);
1035 // Some heuristic here
1036 float mem_koef = (max_tg == 1) ? 1.f : 1.3f;
1037 float job_cost = 0.;
1038 if (nthreads % lgc < nb_load % lgc) {
                job_cost = calc_job_cost(max_lb, min_tg, mem_koef);
            } else {
                auto job_cost1 = calc_job_cost(max_lb, max_tg, mem_koef);
                auto job_cost2 = calc_job_cost(min_lb, min_tg, mem_koef);
                job_cost = nstl::max(job_cost1, job_cost2);
            }

            if (job_cost < best_cost) {
                best_lgc = lgc;
                best_cost = job_cost;
            }
        }
        jcp.load_grp_count = best_lgc;
        load_blocking = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
    } else {
        jcp.load_grp_count = div_up(nthreads, jcp.mb * jcp.ngroups * nb_bcast);
        jcp.load_grp_count = best_divider(
                nthreads, jcp.load_grp_count, 2 * jcp.load_grp_count, false);
    }
1059 if (jcp.ver == ver_avx512_core && jcp.expl_bcast && jcp.bcast_dim <= 64
1060 && load_size >= L2_size) {
1061 jcp.load_grp_count = nstl::max(jcp.load_grp_count, 4);
1062 } else if (jcp.bcast_dim <= 49 && jcp.mb <= nthreads
1063 && jcp.load_dim > 512 && jcp.load_dim / jcp.reduce_dim >= 4) {
1064 jcp.load_grp_count = nstl::max(jcp.load_grp_count, 2);
1065 load_blocking = jcp.load_block;
1068 if (jcp.ver == ver_4fma && jcp.bcast_dim * jcp.mb < jcp.load_dim
1069 && jcp.oh * jcp.ow > 64
1070 && IMPLICATION(reduce_src, jcp.load_dim < 1024)) {
1071 /* Looking for best loading dimension blocking
1072 * to get the best thread and data read/write efficiency
1073 * by finding the optimal 'load_chunk' value
1075 * for 72 threads and convolution with mb=1, ih=iw=7, oc = 512
1076 * the 'best' load_chunk value should be 1
1077 * TODO: remove heuristic constants in above condition
1078 * TODO: check this blocking for other ISA
        float best_eff = -1.f;
        int best_lgc = 1;

        for (int load_chunk = 1; load_chunk <= nb_load; load_chunk++) {
1084 int lgc = div_up(nb_load, load_chunk);
1087 int thr_per_grp = div_up(nthreads, lgc);
            int bcast_per_thr = div_up(jcp.mb * nb_bcast, thr_per_grp)
                    * jcp.bcast_block;
            int load_per_thr = load_chunk * simd_w;
1091 float data_norm = (bcast_per_thr + load_per_thr) / 2.f;
1092 float data_eff = (bcast_per_thr * load_per_thr)
1093 / (data_norm * data_norm);
1094 float thr_eff_over_grp = (float)nstl::max(1, nthreads / lgc)
1095 / div_up(nthreads, lgc);
1096 float thr_eff_in_grp = ((float)jcp.mb * nb_bcast)
1097 / rnd_up(jcp.mb * nb_bcast, thr_per_grp);
1098 float thr_eff = thr_eff_over_grp * thr_eff_in_grp;
1099 float load_eff = (float)nb_load / rnd_up(nb_load, lgc);
1100 float overall_eff = data_eff + thr_eff + load_eff;
            if (overall_eff > best_eff) {
                best_eff = overall_eff;
                best_lgc = lgc;
            }
        }
        jcp.load_grp_count = best_lgc;
        load_blocking
                = div_up(nb_load, jcp.load_grp_count) * jcp.load_block;
        bcast_blocking = div_up(jcp.mb * jcp.ngroups * nb_bcast,
                                 div_up(nthreads, jcp.load_grp_count))
                * jcp.bcast_block;
        bcast_blocking = nstl::min(jcp.bcast_dim, bcast_blocking);
        bcast_blocking = rnd_up(bcast_blocking, jcp.bcast_block);

        int space_for_bcast
                = (L2_capacity - /* kernel_size - */
                    2 * jcp.load_block * reduce_blocking
                    - jcp.ur * reduce_blocking - 3 * 1024);
        if (jcp.reduce_dim * jcp.bcast_dim > L2_capacity)
            space_for_bcast /= 2;

        int bcast_in_cache
                = nstl::max(jcp.bcast_block, space_for_bcast / reduce_blocking);
        bcast_blocking = nstl::min(
                bcast_blocking, rnd_dn(bcast_in_cache, jcp.bcast_block));
1128 load_blocking_max = load_blocking;
1129 bcast_blocking_max = bcast_blocking * 3 / 2;
1130 reduce_blocking_max = reduce_blocking;
1132 } else if (jcp.prop_kind == backward_weights) {
1134 jcp.use_vmovntps = false;
1135 if (jcp.is > SMALL_SPATIAL * SMALL_SPATIAL && jcp.ver == ver_4fma)
1136 jcp.use_vmovntps = true;
1138 if (jcp.transpose_src)
1139 jcp.reduce_dim = jcp.tr_is;
1141 jcp.reduce_dim = jcp.is;
1143 if (jcp.ver == ver_4fma) {
1144 // reduce_block should be divided by fma_step
            jcp.reduce_block = best_divider(jcp.reduce_dim, 4, 16, true, 4);
        } else {
            jcp.reduce_block = best_divider(jcp.reduce_dim, 7, 16, true);
1148 if (jcp.reduce_dim % jcp.reduce_block != 0)
1149 jcp.reduce_block = best_divider(jcp.iw, 4, jcp.iw, false);
1150 if (jcp.reduce_block > 256) {
1151 jcp.reduce_block = 1;
1156 jcp.load_dim = jcp.oc;
1157 jcp.load_block = jcp.oc_block;
1159 jcp.bcast_dim = jcp.ic;
1160 jcp.bcast_block = jcp.ic_block;
1162 if (jcp.ver == ver_avx512_core && jcp.reduce_block <= 19) {
1163 // if reduce_block is big then generated JIT code may be big
1164 // for small values of ur because reduce_loop_unroll = reduce_block
1165 jcp.ur = jcp.bcast_block / 2;
1166 jcp.expl_bcast = true;
1168 jcp.ur = jcp.bcast_block;
1169 jcp.expl_bcast = false;
1172 jcp.reduce_loop_unroll = jcp.reduce_block;
1173 jcp.reduce_loop_bcast_step
1174 = jcp.reduce_loop_unroll * jcp.ic_block * jcp.typesize_in;
1175 jcp.reduce_loop_load_step
1176 = jcp.reduce_loop_unroll * jcp.oc_block * jcp.typesize_in;
1178 jcp.bcast_loop_output_step =
1179 jcp.oc_block * jcp.ic_block * jcp.typesize_out;
1180 jcp.bcast_loop_output_substep =
1181 jcp.oc_block * jcp.ur * jcp.typesize_out;
1182 jcp.bcast_loop_bcast_step =
1183 jcp.ic_block * jcp.reduce_dim * jcp.typesize_in;
1184 jcp.bcast_loop_bcast_substep = jcp.ur * jcp.typesize_in;
1186 jcp.load_loop_load_step = jcp.oc_block * jcp.os * jcp.typesize_in;
1187 jcp.load_loop_iter_step = jcp.oc_block;
1190 balance(jcp, nthreads);
1192 load_blocking = div_up(jcp.load_dim, jcp.load_block);
1193 load_blocking = best_divider(load_blocking, 16, load_blocking, false);
1194 load_blocking *= jcp.load_block;
1196 load_blocking_max = load_blocking;
1197 assert(jcp.load_dim % load_blocking == 0);
1199 int max_bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
1200 int min_bcast_blocking = 5;
1202 bcast_blocking = div_up(jcp.bcast_dim, jcp.bcast_block);
1203 bcast_blocking = best_divider(
1204 bcast_blocking, min_bcast_blocking, max_bcast_blocking, false);
1205 bcast_blocking *= jcp.bcast_block;
1206 bcast_blocking_max = bcast_blocking;
1207 assert(jcp.bcast_dim % bcast_blocking == 0);
1209 // for reduction balance
1210 if (jcp.ver == ver_avx512_core) {
1211 int max_reduce_blocking
1212 = nstl::min(L1_capacity / jcp.ur, jcp.reduce_dim);
1213 int min_reduce_blocking = nstl::min(
1214 L1_capacity / jcp.ur, nstl::max(jcp.iw, jcp.ih));
1215 reduce_blocking = best_divider(jcp.reduce_dim, min_reduce_blocking,
1216 max_reduce_blocking, true);
            reduce_blocking
                    = nstl::max(rnd_dn(reduce_blocking, jcp.reduce_block),
                            jcp.reduce_block);
        } else {
1221 int max_reduce_blocking = L2_capacity
1222 / ((bcast_blocking + load_blocking) * jcp.reduce_block);
1223 max_reduce_blocking = nstl::min(max_reduce_blocking,
1224 (L1_capacity / (jcp.bcast_block)) / jcp.reduce_block);
1226 int num_jobs = div_up(jcp.load_dim, load_blocking)
1227 * div_up(jcp.bcast_dim, bcast_blocking);
1228 int threads_per_job = nstl::max(1, nthreads / num_jobs);
1229 reduce_blocking = div_up(jcp.mb * jcp.reduce_dim, jcp.reduce_block);
1230 reduce_blocking = div_up(reduce_blocking, threads_per_job);
1232 reduce_blocking = best_divider(reduce_blocking,
1233 max_reduce_blocking - 2, max_reduce_blocking, true);
1234 reduce_blocking *= jcp.reduce_block;
1237 reduce_blocking_max = rnd_dn(reduce_blocking * 3 / 2, jcp.reduce_block);
    } else
        return status::unimplemented;
1241 assert(load_blocking);
1242 assert(load_blocking_max);
1243 assert(bcast_blocking);
1244 assert(bcast_blocking_max);
1245 assert(reduce_blocking);
1246 assert(reduce_blocking_max);
1247 assert(load_blocking % jcp.load_block == 0);
1248 assert(reduce_blocking % jcp.reduce_block == 0);
1249 assert(load_blocking_max % jcp.load_block == 0);
1250 assert(reduce_blocking_max % jcp.reduce_block == 0);
1251 if (jcp.ver == ver_4fma || jcp.ver == ver_4vnni) {
1252 if (jcp.ver == ver_4fma)
1253 assert(jcp.reduce_loop_unroll % jcp.fma_step == 0);
1254 if (jcp.ver == ver_4vnni)
1255 assert(jcp.reduce_loop_unroll % (2 * jcp.fma_step) == 0);
1256 assert(jcp.reduce_dim % jcp.reduce_loop_unroll == 0);
1259 assert(jcp.bcast_block % jcp.ur == 0);
1260 assert(jcp.reduce_dim % jcp.reduce_block == 0);
1262 jcp.ur_tail = jcp.bcast_dim % jcp.ur;
1264 jcp.nb_bcast_blocking = bcast_blocking / jcp.bcast_block;
1265 jcp.nb_bcast_blocking_max = bcast_blocking_max / jcp.bcast_block;
1266 jcp.nb_load_blocking = load_blocking / jcp.load_block;
1267 jcp.nb_load_blocking_max = load_blocking_max / jcp.load_block;
1268 jcp.nb_reduce_blocking = reduce_blocking / jcp.reduce_block;
1269 jcp.nb_reduce_blocking_max = reduce_blocking_max / jcp.reduce_block;
1271 jcp.nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1272 jcp.nb_load = div_up(jcp.load_dim, jcp.load_block);
1273 jcp.nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1275 return status::success;
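
// init_scratchpad() books the side buffers the driver needs: a padded bias
// copy when oc is padded to the simd width, per-thread weight-reduction
// buffers for backward weights, and the transposed-src buffer plus barrier
// contexts when transpose_src is enabled.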
1278 void jit_avx512_common_1x1_conv_kernel::init_scratchpad(
1279 memory_tracking::registrar_t &scratchpad,
1280 const jit_1x1_conv_conf_t &jcp) {
1281 using namespace mkldnn::impl::memory_tracking::names;
1283 if (jcp.prop_kind != backward_data && jcp.with_bias
1284 && jcp.oc != jcp.oc_without_padding)
1285 scratchpad.book(key_conv_padded_bias, jcp.typesize_out * jcp.oc);
1287 if (jcp.prop_kind == backward_weights) {
1288 const size_t wei_size = (size_t)jcp.ngroups * jcp.oc * jcp.ic;
1289 scratchpad.book(key_conv_wei_reduction,
1290 jcp.typesize_out * wei_size * (jcp.nthr_mb - 1));
1293 if (jcp.transpose_src) {
1294 const size_t tr_src_size =
1295 (size_t)jcp.nthr_mb * jcp.ngroups * jcp.ic * jcp.tr_is;
1296 scratchpad.book(key_conv_tr_src, jcp.typesize_out * tr_src_size);
1297 scratchpad.book(key_conv_tr_src_bctx,
1298 sizeof(simple_barrier::ctx_t) * jcp.nthr);
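
// balance() decomposes nthreads into nthr_mb * nthr_g * nthr_oc_b * nthr_ic_b
// for the backward-weights reduction, scoring every feasible split with
// calc_mem_cost() (an estimate of per-thread bytes read and written) and
// keeping the cheapest one.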
void jit_avx512_common_1x1_conv_kernel::balance(jit_1x1_conv_conf_t &jcp,
        int nthreads)
{
    // initialize jcp reduction threading properties
1306 jcp.nthr = jcp.nthr_mb = jcp.nthr_g = jcp.nthr_oc_b = jcp.nthr_ic_b = 1;
1307 if (nthreads < jcp.ngroups) {
        /* simplification... fortunately it doesn't hurt much */
        return;
    }
1311 const int nb_bcast = div_up(jcp.bcast_dim, jcp.bcast_block);
1312 const int nb_load = div_up(jcp.load_dim, jcp.load_block);
1313 const int nb_reduce = div_up(jcp.reduce_dim, jcp.reduce_block);
1315 jcp.nthr_g = jcp.ngroups;
1316 const int nthr = nthreads / jcp.nthr_g;
1318 auto calc_mem_cost = [=](int nthr_mb, int nthr_oc_b, int nthr_ic_b) {
1319 /* calculate per thread memory cost (read/write). high level
1320 * optimizer tries to minimize memory consumption. few notes: (n1)
1321 * unclear why, but that essentially helps first convolution...
1322 * (n2) assuming the reduction over minibatch is always there:
1323 * - instead of 8 it should be 5 here (write ~= 2 read):
1324 * kernel: temporal workspace 1 write
1325 * reduction: 1 read from workspace and 1 write to the diff_wei
1326 * - but experiments showed 8 works better than 5 or 6... */
        int bcast_koeff = 1;
        int load_koeff = 1;
        int output_koeff = 12;
        if (jcp.transpose_src) {
1336 + (size_t)bcast_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
1337 * div_up(jcp.ngroups, jcp.nthr_g)
1338 * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.reduce_block
1339 / jcp.stride_h / jcp.stride_w /* (n1) */
1340 + (size_t)load_koeff * div_up(jcp.mb * nb_reduce, nthr_mb)
1341 * div_up(jcp.ngroups, jcp.nthr_g)
1342 * div_up(nb_load, nthr_oc_b) * jcp.oc_block * jcp.reduce_block
1343 + (size_t)output_koeff /* (n2) */
1344 * div_up(jcp.ngroups, jcp.nthr_g) * div_up(nb_load, nthr_oc_b)
            * div_up(nb_bcast, nthr_ic_b) * jcp.ic_block * jcp.oc_block;
    };
1349 int nthr_mb = 1, nthr_oc_b = 1, nthr_ic_b = 1;
1350 auto best_mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
1352 /* step 1: find the best thread distribution with lowest memory cost */
1353 const int nthr_mb_max = nstl::min(nthr, jcp.mb * nb_reduce);
1354 for (nthr_mb = 1; nthr_mb <= nthr_mb_max; ++nthr_mb) {
1355 const int nthr_par = nthr / nthr_mb;
1356 const int nthr_oc_b_max = nstl::min(nthr_par, nb_load);
1357 for (nthr_oc_b = 1; nthr_oc_b <= nthr_oc_b_max; ++nthr_oc_b) {
1358 nthr_ic_b = nstl::min(nthr_par / nthr_oc_b, nb_bcast);
1359 auto mem_cost = calc_mem_cost(nthr_mb, nthr_oc_b, nthr_ic_b);
1360 if (mem_cost <= best_mem_cost) {
1361 best_mem_cost = mem_cost;
1362 jcp.nthr_mb = nthr_mb;
1363 jcp.nthr_oc_b = nthr_oc_b;
                jcp.nthr_ic_b = nthr_ic_b;
            }
        }

        if (!mkldnn_thr_syncable()) { assert(nthr_mb == 1); break; }
1370 if (jcp.nthr_mb > nthreads / 2 && jcp.nthr_mb < nthreads)
1371 jcp.nthr_mb = nstl::min(jcp.mb, nthreads);
1373 jcp.nthr = jcp.nthr_mb * jcp.nthr_g * jcp.nthr_oc_b * jcp.nthr_ic_b;
    assert(jcp.nthr <= nthreads);
}