1 /*******************************************************************************
2 * Copyright 2018-2019 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include <common/memory_tracking.hpp>
18 #include <common/primitive_attr.hpp>
19 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
23 #include "cpu_memory.hpp"
25 #include "jit_uni_x8s8s32x_conv_kernel.hpp"
27 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
33 using namespace mkldnn::impl::prop_kind;
34 using namespace mkldnn::impl::memory_format;
35 using namespace mkldnn::impl::memory_tracking::names;
36 using namespace mkldnn::impl::utils;
38 using namespace Xbyak;
// Load values addressed by `op` and convert them to packed f32 in `vmm_in`.
// `type_in` selects the load path (full vector load for f32/s32,
// sign-extending byte load for s8, zero-extending byte load for u8);
// `scalar_load` routes a single element through a GPR instead of a vector op.
// NOTE(review): this excerpt is elided — the switch/case labels and several
// closing braces are missing from the visible source; comments below are
// based on the instruction pattern and should be confirmed against the
// full file.
40 template <cpu_isa_t isa>
41 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
42         const Xbyak::Operand &op, bool scalar_load) {
43     Xmm xmm_in = Xmm(vmm_in.getIdx());
            // f32/s32 path: plain 32-bit vector load.
51             uni_vmovups(vmm_in, op);
            // s8 scalar path: sign-extend one byte via a GPR, then move to xmm.
56                 movsx(reg_tmp_32, op);
57                 movq(xmm_in, reg_tmp_64);
            // s8 vector path: sign-extend bytes to dwords.
59                 uni_vpmovsxbd(vmm_in, op);
            // u8 scalar path: zero-extend one byte via a GPR, then move to xmm.
64                 movzx(reg_tmp_32, op);
65                 movq(xmm_in, reg_tmp_64);
            // u8 vector path: zero-extend bytes to dwords.
67                 uni_vpmovzxbd(vmm_in, op);
70         default: assert(!"unsupported data type");
        // Integer sources still hold s32 lanes at this point; convert to f32.
73     if (type_in != data_type::f32)
74         uni_vcvtdq2ps(vmm_in, vmm_in);
// Store `vmm_dst` to the destination address `op`, narrowing to the output
// data type on the way (dword -> word -> byte packing for s8/u8 outputs).
// `scalar_store` writes a single element through a GPR; `need_pack` controls
// whether the cross-lane vpermq fixup is applied after the in-lane pack
// (AVX packs operate per 128-bit lane, so lanes must be reshuffled).
// NOTE(review): this excerpt is elided — the dst_dt switch arms and the
// actual store instructions between the pack sequences are missing from
// the visible source.
77 template <cpu_isa_t isa>
78 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store, bool need_pack) {
79     Ymm ymm_dst = Ymm(vmm_dst.getIdx());
80     Xmm xmm_dst = Xmm(vmm_dst.getIdx());
            // Scalar path: move the low element out through a GPR.
86             movq(reg_tmp_64, xmm_dst);
            // f32/s32 path: full vector store.
89             uni_vmovups(op, vmm_dst);
            // Signed-output path: saturating pack s32 -> s16 -> s8.
94             uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
96             if (isa != sse42 && !scalar_store && need_pack)
97                 vpermq(ymm_dst, ymm_dst, 0x08);
100             uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
103                 movq(reg_tmp_64, xmm_dst);
            // Unsigned-output path: saturating pack s32 -> u16 -> u8.
114             uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
116             if (isa != sse42 && !scalar_store && need_pack)
117                 vpermq(ymm_dst, ymm_dst, 0x08);
120             uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
123                 movq(reg_tmp_64, xmm_dst);
134             assert(!"unknown dst_dt");
// Inner compute step over the kernel width: for each kw position, broadcast
// 4 input channels per output pixel, multiply-accumulate against the weights
// of `oc_blocks` output-channel blocks, and add into the accumulator
// registers.  `h_padded` marks rows that lie entirely in vertical padding;
// padded pixels are zeroed (and, for signed input, shifted by vmm_shift so
// the compensation term stays consistent).
// NOTE(review): this excerpt is elided — several braces, the else branches
// and the sse42/non-sse42 split around uni_vpmaddubsw are partially
// missing from the visible source.
138 template <cpu_isa_t isa>
139 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step,
140         int tail_size, bool h_padded) {
144     int nb_ic = jcp.nb_ic;
145     int stride_w = jcp.stride_w;
146     int dilate_w = jcp.dilate_w + 1;
147     int ic_blk = jcp.ic_block;
148     int oc_blk = jcp.oc_block;
    // sse42 has 128-bit registers, so a full oc block needs two passes.
150     int repeats = isa == sse42 && oc_step > (oc_blk / 2) ? 2 : 1;
152     for (int ki = 0; ki < kw; ki++) {
        // jj_start/jj_end bound the output pixels whose input taps fall
        // inside the row (outside left/right padding) for this kw offset.
153         int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w));
154         int jj_end = ur_w - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w));
        // With signed input every pixel is processed (padded ones get the
        // shift value) so that the precomputed compensation matches.
156         int _start = (jcp.signed_input) ? 0 : jj_start;
157         int _end = (jcp.signed_input) ? ur_w : jj_end;
159         for (int r = 0; r < repeats; r++) {
160             for (int jj = _start; jj < _end; jj++) {
161                 int inp_off = (ki * dilate_w + jj * stride_w - pad_l) * jcp.ic * jcp.ngroups;
                    // Padded tap: source is zero; subtract vmm_shift for the
                    // signed-input compensation scheme.
163                 if (h_padded || jj < jj_start || jj >= jj_end) {
164                     uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
165                     uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
                    // In-bounds tap: broadcast 4 consecutive input bytes.
167                     uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
169                     if (jcp.signed_input) {
170                         uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
                    // NOTE(review): second padded/in-bounds branch pair below —
                    // presumably the tail-size (partial ic) variant; the
                    // distinguishing condition is elided from this excerpt.
174                 if (h_padded || jj < jj_start || jj >= jj_end) {
175                     uni_vpxor(get_src_reg(jj), get_src_reg(jj), get_src_reg(jj));
177                     uni_vpbroadcastd(get_src_reg(jj), ptr[aux1_reg_input + jcp.typesize_in * inp_off]);
180                     if (jcp.signed_input)
181                         uni_vpsubb(get_src_reg(jj), get_src_reg(jj), vmm_shift);
185             for (int ii = 0; ii < oc_blocks; ii++) {
                // Weight offset: oc-block stride over the whole kd*kh*kw*ic
                // volume, plus the kw position and the sse42 half-block pass.
186                 int ker_off = ii * nb_ic * kd * kh * kw * ic_blk * oc_blk + ki * ic_blk * oc_blk + r * ic_blk * (oc_blk / 2);
187                 uni_vmovups(get_ker_reg(0), ptr[aux1_reg_kernel + jcp.typesize_in * ker_off]);
189                 for (int jj = _start; jj < _end; jj++) {
190                     Vmm vmm_src = get_src_reg(jj);
                        // sse42 variant: vpmaddubsw is destructive, so work on
                        // a copy of the source register.
192                         uni_vmovups(get_tmp_reg(0), vmm_src);
193                         uni_vpmaddubsw(get_tmp_reg(0), get_tmp_reg(0), get_ker_reg(0));
195                         uni_vpmaddubsw(get_tmp_reg(0), vmm_src, get_ker_reg(0));
                    // u8*s8 -> s16 pairs, then pairwise add (x vmm_one) to s32
                    // and accumulate.
197                     uni_vpmaddwd(get_tmp_reg(0), get_tmp_reg(0), vmm_one);
198                     uni_vpaddd(get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
199                                get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj), get_tmp_reg(0));
// One output-row step: loop over input-channel blocks, calling apply_filter
// for each full ic block and once more for the ic tail.  Advances the aux
// input/kernel pointers by one ic block per iteration.
// NOTE(review): this excerpt is elided — the ic_main_loop label declaration
// and the loop-exit branch are missing from the visible source.
206 template <cpu_isa_t isa>
207 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::oh_step_unroll_kw(int ur_w,
208         int pad_l, int pad_r, int oc_blocks, int oc_step, bool h_padded) {
210     int ic_blk = jcp.ic_block;
211     int oc_blk = jcp.oc_block;
    // Work on copies so the caller's aux_reg_input/kernel stay intact.
217     mov(aux1_reg_input, aux_reg_input);
218     mov(aux1_reg_kernel, aux_reg_kernel);
    // Runtime countdown over input channels.
220     mov(reg_ic_iter, jcp.ic);
223     cmp(reg_ic_iter, ic_blk);
226     apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, 0, h_padded);
228     add(aux1_reg_input, ic_blk * jcp.typesize_in);
229     add(aux1_reg_kernel, kw * ic_blk * oc_blk * jcp.typesize_in);
230     sub(reg_ic_iter, ic_blk);
231     jmp(ic_main_loop, T_NEAR);
    // Remaining channels that do not fill a whole ic block.
235     int ic_tail_size = jcp.ic % jcp.ic_block;
237     if (ic_tail_size > 0)
238         apply_filter(ur_w, pad_l, pad_r, oc_blocks, oc_step, ic_tail_size, h_padded);
// Kernel-height loop: handle top/bottom padding overflow rows (signed input
// only — padded rows still contribute via the compensation shift, but only
// the kernel pointer advances), then iterate kh_padding real rows advancing
// both input and kernel pointers.
// NOTE(review): this excerpt is elided — kh_label/skip_kh_loop declarations,
// the reg_kj decrement and several branches are missing from the visible
// source.
243 template <cpu_isa_t isa>
244 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::kh_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
247     int dilate_h = jcp.dilate_h + 1;
248     const int inp_mult = jcp.ic * dilate_h * jcp.ngroups;
250     Label t_overflow_label, no_t_overflow_label,
251           b_overflow_label, no_b_overflow_label;
    // Shared handler for rows that fall entirely in vertical padding:
    // run the compute with h_padded=true and step only the kernel pointer.
253     auto h_overflow_func = [&] () {
254         Label h_overflow_label, no_h_overflow_label;
255         cmp(reg_overflow, 0);
256         je(no_h_overflow_label, T_NEAR);
257         L(h_overflow_label); {
258             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
260             add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
262             cmp(reg_overflow, 0);
263             jg(h_overflow_label, T_NEAR);
265         L(no_h_overflow_label);
    // Top padding overflow (signed input only).
268     if (jcp.signed_input) {
269         mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]);
    // Main kh loop over the non-padded rows.
274     mov(reg_kj, ptr[this->param1 + GET_OFF(kh_padding)]);
    // kh_padding can be zero (kernel fully in padding) — skip the loop then.
275     if ((jcp.signed_input) || (!jcp.signed_input &&
276          (jcp.kh - 1) * (jcp.dilate_h + 1) < nstl::max(jcp.t_pad, jcp.b_pad))) {
278         je(skip_kh_loop, T_NEAR);
284         oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, false);
286         add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
287         add(aux_reg_input, jcp.typesize_in * iw * inp_mult);
291         jg(kh_label, T_NEAR);
    // Bottom padding overflow (signed input only).
296     if (jcp.signed_input) {
297         mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]);
// Kernel-depth loop for 3D convolution: handle front/back depth-padding
// overflow (signed input), then iterate kd_padding real depth slices,
// calling kh_loop per slice.  reg_compensation_base is saved/restored
// because the inner loops reuse that register.
// NOTE(review): this excerpt is elided — kd_label/skip_kd_loop declarations,
// several decrements/branches and the aux_kh_label loop body inside
// d_overflow_func are partially missing from the visible source.
302 template <cpu_isa_t isa>
303 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::kd_loop(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step) {
308     int dilate_d = jcp.dilate_d + 1;
    // Depth-padding handler: compute with h_padded=true, advance only the
    // kernel pointers.
310     auto d_overflow_func = [&] () {
311         Label d_overflow_label, no_d_overflow_label, aux_kh_label;
312         cmp(reg_overflow, 0);
313         je(no_d_overflow_label, T_NEAR);
321             oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true);
323             add(aux_reg_kernel, jcp.typesize_in * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
326             jg(aux_kh_label, T_NEAR);
        // Step to the next kd slice of weights.
331         add(aux_reg_ker_d, jcp.typesize_in * kh * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
332         mov(aux_reg_kernel, aux_reg_ker_d);
334         cmp(reg_overflow, 0);
335         jg(d_overflow_label);
337         L(no_d_overflow_label);
    // Inner loops clobber this register — preserve it across the kd loop.
342     push(reg_compensation_base);
344     mov(aux_reg_inp_d, reg_input);
345     mov(aux_reg_ker_d, reg_kernel);
    // Front depth-padding overflow (signed input only).
347     if (jcp.signed_input) {
348         mov(reg_overflow, ptr[param1 + GET_OFF(front_overflow)]);
    // Main kd loop over non-padded depth slices.
353     mov(reg_kd, ptr[param1 + GET_OFF(kd_padding)]);
355     je(skip_kd_loop, T_NEAR);
360         kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
        // Advance input by one (dilated) depth slice, kernel by one kd slab.
362         add(aux_reg_inp_d, jcp.typesize_in * dilate_d * ih * iw * jcp.ic * jcp.ngroups);
363         add(aux_reg_ker_d, jcp.typesize_in * kh * kw * jcp.oc_block * rnd_up(jcp.ic, jcp.ic_block));
364         mov(aux_reg_input, aux_reg_inp_d);
365         mov(aux_reg_kernel, aux_reg_ker_d);
369         jg(kd_label, T_NEAR);
    // Back depth-padding overflow (signed input only).
374     if (jcp.signed_input) {
375         mov(reg_overflow, ptr[param1 + GET_OFF(back_overflow)]);
379     pop(reg_compensation_base);
// Compute one width block of `ur_w` output pixels for `oc_blocks` output
// channel blocks: zero the accumulators, run the kd/kh convolution loops,
// then the post-processing pipeline — compensation (signed input), bias,
// output scales, post-ops (eltwise / depthwise / sum), rounding, and the
// final typed store.
// NOTE(review): this excerpt is heavily elided — many closing braces, else
// branches and a few statements (e.g. the bias guard, sse42 vs avx2 splits)
// are missing from the visible source; comments are best-effort and should
// be confirmed against the full file.
384 template <cpu_isa_t isa>
385 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::width_blk_step(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step)
    // sse42 processes an oc block in two 4-wide halves.
387     int repeats = isa == sse42 && oc_step > (jcp.oc_block / 2) ? 2 : 1;
    // Zero all accumulator registers for this width block.
389     for (int r = 0; r < repeats; r++)
390         for (int ii = 0; ii < oc_blocks; ii++)
391             for (int jj = 0; jj < ur_w; jj++)
392                 uni_vpxor(get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
393                           get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj),
394                           get_acc_reg(r*jcp.ur_w*jcp.nb_oc_blocking + ur_w * ii + jj));
    // Load constants from the jit-emitted table: vmm_one (0x0001 words for
    // vpmaddwd) and vmm_shift (signed-input compensation shift).
396     mov(imm_addr64, l_table);
397     uni_vmovups(vmm_one, ptr[imm_addr64 + 0 * vlen]);
398     uni_vmovups(vmm_shift, ptr[imm_addr64 + 1 * vlen]);
400     mov(aux_reg_input, reg_input);
401     mov(aux_reg_kernel, reg_kernel);
    // 5D tensors go through the depth loop; 4D straight to the height loop.
403     if (jcp.ndims == 5) {
404         kd_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
406         kh_loop(ur_w, pad_l, pad_r, oc_blocks, oc_step);
    // NOTE(review): pop pairs with a push elided from this excerpt (and the
    // matching push at line 563 below) — solve_common also pushes this
    // register; verify the pairing in the full file.
410     pop(reg_scales_base);
    // vmm_bias_alpha = wei_adj_scale, used to rescale bias for signed input.
412     mov(imm_addr64, l_table);
413     uni_vmovups(vmm_bias_alpha, ptr[imm_addr64 + 2 * vlen]);
415     const auto &p = attr_.post_ops_;
416     const int sum_idx = p.find(primitive_kind::sum);
417     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
419     for (int r = 0; r < repeats; r++) {
        // Output offset: dw-conv fusion writes to a row buffer laid out by
        // oc block; the normal path writes interleaved nhwc output.
420         auto get_dst_off = [=](int ii, int jj) {
421             if (jcp.with_dw_conv)
422                 return (ii * jcp_dw.kh * jcp.ow + jj) * jcp.oc_block + r * (jcp.oc_block / 2);
424             return ii * jcp.oc_block + jj * jcp.oc * jcp.ngroups + r * (jcp.oc_block / 2);
427         int tail_size = isa == avx2 ? oc_step : nstl::min(jcp.oc_block / 2, oc_step - r * jcp.oc_block / 2);
428         bool is_scalar_store = isa == avx2 ? tail_size < jcp.oc_block : tail_size < jcp.oc_block / 2;
430         for (int ii = 0; ii < oc_blocks; ii++) {
            // Load bias for this oc block (guarded by jcp.with_bias in the
            // elided line above, presumably).
432             int b_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
433             cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], false);
435             if (jcp.signed_input)
436                 uni_vmulps(vmm_bias, vmm_bias, vmm_bias_alpha);
439             for (int jj = 0; jj < ur_w; jj++) {
440                 Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
441                 uni_vcvtdq2ps(vmm_dst, vmm_dst);
                    // Signed input: add the precomputed zero-point
                    // compensation term before bias.
443                 if (jcp.signed_input) {
444                     int c_off = ii * jcp.oc_block + r * (jcp.oc_block / 2);
445                     cvt2ps(data_type::s32, vmm_comp, ptr[reg_compensation_base + c_off * sizeof(int32_t)], false);
448                 if (jcp.signed_input)
449                     uni_vaddps(vmm_dst, vmm_dst, vmm_comp);
451                 uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
                    // Apply per-oc (or common) output scale.
453                 int s_off = jcp.is_oc_scale * (ii * jcp.oc_block + r * (jcp.oc_block / 2));
454                 cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], false);
455                 uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
        // Post-ops chain (up to the fused dw-conv boundary, if any).
459         int eltwise_inj_idx = 0;
460         int depthwise_inj_idx = 0;
461         int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
462         for (int i = 0; i < end_idx; i++) {
463             int start_idx = 1 + r * jcp.ur_w * jcp.nb_oc_blocking;
465             auto& post_op = p.entry_[i];
466             if (post_op.is_eltwise()) {
467                 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + oc_blocks * ur_w);
469             } else if (post_op.is_depthwise()) {
470                 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
471                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
473                 add(reg_d_weights, reg_oc_off);
474                 add(reg_d_bias, reg_oc_off);
                    // Second sse42 half-block starts halfway into the
                    // per-channel weight/bias arrays.
477                     add(reg_d_weights, (jcp.oc_block / 2) * sizeof(float));
478                     add(reg_d_bias, (jcp.oc_block / 2) * sizeof(float));
481                 for (int ii = 0; ii < oc_blocks; ii++) {
482                     depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii,
483                             start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
485                     add(reg_d_weights, jcp.oc_block * sizeof(float));
486                     add(reg_d_bias, jcp.oc_block * sizeof(float));
490             } else if (post_op.is_sum(false)) {
                // Sum post-op: accumulate previously stored output
                // (scaled by p_sum_scale) into the accumulators.
491                 for (int ii = 0; ii < oc_blocks; ii++) {
492                     for (int jj = 0; jj < ur_w; jj++) {
493                         Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
494                         int o_off = get_dst_off(ii, jj);
496                         if (is_scalar_store) {
                            // Tail path: load one element at a time and shift
                            // it into its lane position.
497                             for (int oc = 0; oc < tail_size; oc++) {
498                                 uni_vpxor(vmm_prev_dst, vmm_prev_dst, vmm_prev_dst);
499                                 cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + (o_off + oc) * jcp.typesize_out], true);
501                                 if (oc < jcp.oc_block / 2) {
502                                     uni_vpslldq(vmm_prev_dst, vmm_prev_dst, oc * sizeof(float));
                                    // Upper 128-bit lane: swap lanes first,
                                    // then shift within the lane.
504                                     Ymm ymm_prev_dst = Ymm(vmm_prev_dst.getIdx());
505                                     vperm2i128(ymm_prev_dst, ymm_prev_dst, ymm_prev_dst, 0x01);
506                                     vpslldq(vmm_prev_dst, vmm_prev_dst, (oc - jcp.oc_block / 2) * sizeof(float));
509                                 if (p_sum_scale == 1.f) {
510                                     uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
                                    // Scaled sum: scale constant lives at
                                    // table slot 3.
512                                     uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
516                             cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], false);
518                             if (p_sum_scale == 1.f) {
519                                 uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
521                                 uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 3 * vlen]);
        // Final store: round/convert f32 accumulators to the output type.
529         for (int ii = 0; ii < oc_blocks; ii++) {
530             for (int jj = 0; jj < ur_w; jj++) {
531                 Vmm vmm_dst = get_acc_reg(r * jcp.ur_w * jcp.nb_oc_blocking + ur_w * ii + jj);
532                 int o_off = get_dst_off(ii, jj);
534                 if (jcp.dst_dt != data_type::f32) {
535                     if (attr_.round_mode_ == round_mode::nearest)
536                         uni_vcvtps2dq(vmm_dst, vmm_dst);
537                     else if (attr_.round_mode_ == round_mode::down) {
538                         uni_vroundps(vmm_dst, vmm_dst, 1);
539                         uni_vcvtps2dq(vmm_dst, vmm_dst);
541                         assert(!"unimplemented");
544                 if (is_scalar_store) {
                    // Store tail elements one by one, rotating the vector so
                    // the next element lands in the low lane.
545                     for (int oc = 0; oc < tail_size; oc++) {
546                         store_dst(ptr[reg_output + (o_off + oc) * jcp.typesize_out], vmm_dst, true, oc == 0);
549                             Ymm ymm_dst = Ymm(vmm_dst.getIdx());
550                             vperm2i128(ymm_tmp, ymm_dst, ymm_dst, 0x01);
551                             vpalignr(ymm_dst, ymm_tmp, ymm_dst, jcp.typesize_out);
                            // sse42 path: simple right shift.
553                             psrldq(vmm_dst, jcp.typesize_out);
557                     store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, false);
    // Re-save reg_scales_base for the next width block (balances the pop
    // at line 410 above).
563     push(reg_scales_base);
// Drive the output-width loop for one oc configuration: left-padded block,
// n_oi middle blocks, right-padded block and the ur_w tail, calling
// width_blk_step for each and advancing input/output pointers between
// blocks.  Saves/restores the base registers around the walk.
// NOTE(review): this excerpt is elided — several labels, the l_pad>0 branch
// structure and parts of the r_pad arithmetic are missing from the visible
// source; also note pushes at 591-594 vs pops at 632-634 — the reg_input_base
// pop is outside this excerpt, presumably elided.
567 template <cpu_isa_t isa>
568 inline void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::solve_common(int oc_blocks, int oc_step)
571     int ur_w_tail = jcp.ur_w_tail;
572     int n_oi = jcp.ow / ur_w;
575     int dilate_w = jcp.dilate_w + 1;
576     int str_w = jcp.stride_w;
577     const int inp_mult = jcp.ic * jcp.ngroups;
578     const int out_mult = jcp.with_dw_conv ? jcp.oc_block : jcp.oc * jcp.ngroups;
580     int l_pad = jcp.l_pad;
    // r_pad: padding needed by the very last output pixel; r_pad1: padding
    // hit by the last full ur_w block (continuation lines elided).
581     int r_pad = nstl::max(0, (jcp.ow - 1) * str_w + (kw - 1) * dilate_w
583     int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w
585     if (r_pad1 > 0) n_oi--;
587     mov(reg_input, reg_input_base);
588     mov(reg_output, reg_output_base);
589     mov(reg_kernel, reg_kernel_base);
    // Preserve base registers — the width loop reuses them.
591     push(reg_input_base);
592     push(reg_output_base);
593     push(reg_kernel_base);
594     push(reg_scales_base);
    // First block carries the left padding (and the right padding too when
    // the whole row fits in one block).
599     if (n_oi < 0 && r_pad1 > 0)
600         width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad"
602         width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad"
603     add(reg_input, jcp.typesize_in * (ur_w * str_w - l_pad) * inp_mult);
604     add(reg_output, jcp.typesize_out * ur_w * out_mult);
    // Middle blocks: runtime loop, no padding on either side.
608     xor_(reg_oi_iter, reg_oi_iter);
613     width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle"
614     add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
615     add(reg_output, jcp.typesize_out * ur_w * out_mult);
618     cmp(reg_oi_iter, n_oi);
619     jl(ow_loop_label, T_NEAR);
    // Last full block with right padding.
622     if (r_pad1 > 0 && n_oi >=0) {
623         width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad"
624         add(reg_input, jcp.typesize_in * ur_w * str_w * inp_mult);
625         add(reg_output, jcp.typesize_out * ur_w * out_mult);
    // Remaining ow % ur_w pixels.
629         width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail"
632     pop(reg_scales_base);
633     pop(reg_kernel_base);
634     pop(reg_output_base);
// Kernel entry point: create the post-op injectors, load call arguments
// from jit_conv_call_s, then dispatch over the remaining oc work —
// nb_oc_blocking-wide fast path, single-block main loop, and an oc tail —
// advancing the per-oc base pointers between iterations.  Finishes by
// emitting the injector tables (and, elsewhere, the constant table).
// NOTE(review): this excerpt is elided — preamble/postamble, tail_label /
// exit_label declarations and the l_table emission are missing from the
// visible source.
638 template <cpu_isa_t isa>
639 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::generate()
641     const auto &p = attr_.post_ops_;
    // Only post-ops before the fused dw-conv (if any) belong to this kernel.
642     int end_idx = jcp.with_dw_conv ? p.find(primitive_kind::convolution) : p.len_;
643     for (int i = 0; i < end_idx; i++) {
644         auto &post_op = p.entry_[i];
645         if (post_op.is_eltwise()) {
646             eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
649                     post_op.eltwise.alpha,
652         } else if (post_op.is_depthwise()) {
653             depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
655                     post_op.depthwise.alg
    // Load kernel arguments from the call structure.
662     mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
663     mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
664     mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
665     mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]);
667     mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
668     mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
669     if (jcp.signed_input)
670         mov(reg_compensation_base, ptr[param1 + GET_OFF(compensation)]);
671     mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]);
673     Label main_loop_label;
    // Fast path: exactly nb_oc_blocking full oc blocks remain.
677     cmp(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
678     jne(main_loop_label, T_NEAR);
680     solve_common(jcp.nb_oc_blocking, jcp.oc_block);
682     sub(reg_oc_work, jcp.nb_oc_blocking * jcp.oc_block);
684     jmp(exit_label, T_NEAR);
    // General loop: one oc block at a time.
686     L(main_loop_label); {
687         cmp(reg_oc_work, jcp.oc_block);
688         jl(tail_label, T_NEAR);
690         solve_common(1, jcp.oc_block);
692         sub(reg_oc_work, jcp.oc_block);
        // Advance every per-oc base pointer by one oc block.
693         add(reg_kernel_base, jcp.oc_block * jcp.nb_ic * jcp.kd * jcp.kh * jcp.kw * jcp.ic_block * jcp.typesize_in);
694         if (jcp.with_dw_conv)
695             add(reg_output_base, jcp.oc_block * jcp_dw.kh * jcp.ow * jcp.typesize_out);
697             add(reg_output_base, jcp.oc_block * jcp.typesize_out);
698         add(reg_bias_base, jcp.oc_block * jcp.typesize_bia);
699         add(reg_scales_base, jcp.is_oc_scale * jcp.oc_block * sizeof(float));
700         add(reg_compensation_base, jcp.oc_block * sizeof(int32_t));
701         add(reg_oc_off, jcp.oc_block * sizeof(float));
703         jmp(main_loop_label, T_NEAR);
    // Tail: fewer than oc_block channels left.
708     if (jcp.oc % jcp.oc_block != 0)
709         solve_common(1, jcp.oc % jcp.oc_block);
    // Emit eltwise lookup tables after the code body.
717     for (auto& inj : eltwise_injectors)
718         inj->prepare_table();
// Emit the kernel's constant table (l_table): slot 0 = 0x0001 words for
// vpmaddwd, slot 1 = the signed-input shift bytes, slot 2 = wei_adj_scale
// as f32 bits, slot 3 = the sum post-op scale as f32 bits.  Each constant
// is broadcast across a full vector register width (vlen).
// NOTE(review): this excerpt is elided — the actual constant values in
// cvals_one/cvals_shift, the l_table alignment/label and the dw/db emit
// statements for the first three tables are missing from the visible source.
721 template <cpu_isa_t isa>
722 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::prepare_table() {
723     const auto &p = attr_.post_ops_;
724     const int sum_idx = p.find(primitive_kind::sum);
725     const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
727     const uint16_t cvals_one[] = {
731     const int8_t cvals_shift[] = {
735     const int32_t cvals_scale[] = {
736         float2int(jcp.wei_adj_scale)
739     const int32_t cvals_sum_scale[] = {
740         float2int(p_sum_scale)
    // Broadcast each constant across one vector width.
745     for (size_t i = 0; i < sizeof(cvals_one) / sizeof(cvals_one[0]); ++i) {
746         for (size_t d = 0; d < vlen / sizeof(uint16_t); ++d) {
751     for (size_t i = 0; i < sizeof(cvals_shift) / sizeof(cvals_shift[0]); ++i) {
752         for (size_t d = 0; d < vlen / sizeof(int8_t); ++d) {
757     for (size_t i = 0; i < sizeof(cvals_scale) / sizeof(cvals_scale[0]); ++i) {
758         for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
763     for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
764         for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
765             dd(cvals_sum_scale[i]);
// Validate the post-ops chain for this kernel: every entry before the fused
// dw-conv (if any) must be sum/eltwise/depthwise, at most one sum is
// allowed, and sum cannot be combined with a fused dw-conv.
// NOTE(review): this excerpt is elided — the initialization and return of
// `ok` inside the lambda are missing from the visible source.
770 template <cpu_isa_t isa>
771 bool jit_uni_x8s8s32x_conv_fwd_kernel<isa>::post_ops_ok(
772         jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
773     const auto &p = attr.post_ops_;
775     int dw_conv_idx = p.find(primitive_kind::convolution);
776     bool with_dw_conv = dw_conv_idx != -1;
778     auto all_post_ops_supported = [&]() {
        // Only check the ops this kernel executes (before the dw-conv).
781         int end_idx = with_dw_conv ? dw_conv_idx : p.len_;
782         for (int i = 0; i < end_idx; i++) {
783             ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
    // Helpers restricted to the pre-dw-conv range.
787     auto contain = [&](mkldnn::impl::primitive_kind_t kind) { return p.find(kind, 0, dw_conv_idx) != -1; };
788     auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind, 0, dw_conv_idx); };
790     return all_post_ops_supported() &&
791            count(primitive_kind::sum) <= 1 &&
792            IMPLICATION(with_dw_conv, !contain(primitive_kind::sum));
// Fill `jcp` from the convolution descriptor and memory descriptors,
// validate the configuration for this ISA, and fix the expected memory
// formats (nhwc/ndhwc activations, blocked 8o4i weights).  Returns
// status::unimplemented for any unsupported shape/format combination.
// NOTE(review): this excerpt is elided — several checks (e.g. the full
// args_ok expression, ic_block assignment, ih/iw assignment) and closing
// braces are missing from the visible source.
795 template <cpu_isa_t isa>
796 status_t jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
797         const convolution_desc_t &cd, cpu_memory_t::pd_t &src_pd,
798         cpu_memory_t::pd_t &weights_pd, cpu_memory_t::pd_t &dst_pd,
799         cpu_memory_t::pd_t &bias_pd,
800         const primitive_attr_t &attr)
802     if (!mayiuse(isa)) return status::unimplemented;
804     const memory_desc_wrapper src_d(&src_pd);
805     const memory_desc_wrapper weights_d(&weights_pd);
806     const memory_desc_wrapper dst_d(&dst_pd);
807     const memory_desc_wrapper bias_d(&bias_pd);
809     jcp.prop_kind = cd.prop_kind;
    // Grouped weights carry an extra leading dimension.
811     const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
812     int ndims = src_d.ndims();
815     jcp.ngroups = with_groups ? weights_d.dims()[0] : 1;
816     jcp.mb = src_d.dims()[0];
    // Per-group channel counts.
818     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
819     jcp.ic = src_d.dims()[1] / jcp.ngroups;
    // Spatial sizes; depth dimension exists only for 5D tensors.
821     jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
822     jcp.ih = src_d.dims()[ndims - 2];
823     jcp.iw = src_d.dims()[ndims - 1];
824     jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
825     jcp.oh = dst_d.dims()[ndims - 2];
826     jcp.ow = dst_d.dims()[ndims - 1];
828     jcp.kd = (ndims == 5) ? weights_d.dims()[with_groups + 2] : 1;
829     jcp.kh = weights_d.dims()[with_groups + ndims - 2];
830     jcp.kw = weights_d.dims()[with_groups + ndims - 1];
832     jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
833     jcp.t_pad = cd.padding[0][ndims - 4];
834     jcp.l_pad = cd.padding[0][ndims - 3];
836     jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
837     jcp.stride_h = cd.strides[ndims - 4];
838     jcp.stride_w = cd.strides[ndims - 3];
840     jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
841     jcp.dilate_h = cd.dilates[ndims - 4];
842     jcp.dilate_w = cd.dilates[ndims - 3];
844     jcp.src_fmt = src_d.format();
845     jcp.with_bias = cd.bias_desc.format != memory_format::undef;
    // s8 source implies the signed-input compensation scheme.
847     jcp.signed_input = src_d.data_type() == data_type::s8;
    // 8 f32 lanes per ymm on avx2 (sse42 handles a block in two halves).
849     const int simd_w = 8;
852     jcp.nb_ic = div_up(jcp.ic, jcp.ic_block);
854     jcp.oc_block = simd_w;
855     jcp.oc_padded = rnd_up(jcp.oc, jcp.oc_block);
856     jcp.nb_oc = div_up(jcp.oc, jcp.oc_block);
    // Grouped convolution requires whole channel blocks per group.
858     if (jcp.ngroups != 1) {
859         if (jcp.ic % jcp.ic_block != 0 || jcp.oc % jcp.oc_block != 0)
860             return status::unimplemented;
863     jcp.src_dt = cd.src_desc.data_type;
864     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
865     jcp.dst_dt = cd.dst_desc.data_type;
867     if (!post_ops_ok(jcp, attr))
868         return status::unimplemented;
870     const auto &p = attr.post_ops_;
    // Fused depthwise conv: this kernel produces a row buffer sized by the
    // dw-conv's input, so swap in its spatial dims and data type.
872     int dw_conv_ind = p.find(primitive_kind::convolution);
873     jcp.with_dw_conv = dw_conv_ind != -1;
874     if (jcp.with_dw_conv) {
875         if (ndims == 5) return status::unimplemented;
877         jcp.dw_conv_oh = jcp.oh;
878         jcp.dw_conv_ow = jcp.ow;
879         jcp.oh = p.entry_[dw_conv_ind].dw_conv.in_h;
880         jcp.ow = p.entry_[dw_conv_ind].dw_conv.in_w;
882         jcp.dw_conv_dst_dt = jcp.dst_dt;
883         jcp.dst_dt = p.entry_[dw_conv_ind].dw_conv.in_dt;
    // Expected layouts: channels-last activations, 8o4i-blocked weights
    // (with the _s8s8 compensation variant for signed input).
886     auto desired_act_fmt = (ndims == 5) ? ndhwc : nhwc;
887     auto desired_wei_fmt = (ndims == 5) ? with_groups ? jcp.signed_input ? gOdhIw8o4i_s8s8 : gOdhIw8o4i
888                                                       : jcp.signed_input ? OdhIw8o4i_s8s8 : OdhIw8o4i
889                                         : with_groups ? jcp.signed_input ? gOhIw8o4i_s8s8 : gOhIw8o4i
890                                                       : jcp.signed_input ? OhIw8o4i_s8s8 : OhIw8o4i;
892     if (src_d.format() == any)
893         CHECK(src_pd.set_format(desired_act_fmt));
894     if (src_d.format() != desired_act_fmt)
895         return status::unimplemented;
897     if (dst_d.format() == any)
898         CHECK(dst_pd.set_format(desired_act_fmt));
899     if (dst_d.format() != desired_act_fmt)
900         return status::unimplemented;
902     if (weights_d.format() == any)
903         CHECK(weights_pd.set_format(desired_wei_fmt));
904     if (weights_d.format() != desired_wei_fmt)
905         return status::unimplemented;
908     if (bias_d.format() == any)
909         CHECK(bias_pd.set_format(x));
910     if (bias_d.format() != x)
911         return status::unimplemented;
914     jcp.typesize_in = types::data_type_size(src_d.data_type());
915     jcp.typesize_out = types::data_type_size(dst_d.data_type());
916     jcp.typesize_acc = sizeof(int32_t);
917     jcp.typesize_bia = jcp.with_bias
918         ? types::data_type_size(bias_d.data_type())
    // Output scale mask: 1<<1 means per-oc scales, 0 means a common scale.
921     const auto &oscales = attr.output_scales_;
922     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
924     assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
926     jcp.ur_h = 1; /* no code-unrolling by h so far */
927     jcp.ur_w = isa == avx2 ? 4 : 2;
928     jcp.nb_oc_blocking = nstl::min(2, jcp.nb_oc);
929     jcp.max_regs_ur = 12;
931     // WA to prevent fallback on gemm implementation
932     if (isa == sse42 && jcp.ic == 3) {
934         jcp.nb_oc_blocking = 1;
937     if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow;
938     jcp.ur_w_tail = jcp.ow % jcp.ur_w;
    // args_ok: start of the expression is elided from this excerpt.
941         && jcp.l_pad <= jcp.ur_w
942         && IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0)
943                 || (jcp.stride_w == 1 && jcp.stride_h == 1));
944     if (!args_ok) return status::unimplemented;
    // The kernel handles right padding only up to one ur_w block.
946     int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w
947             + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1));
948     if (r_pad_no_tail > jcp.ur_w)
949         return status::unimplemented;
951     if (jcp.l_pad > jcp.ur_w)
952         return status::unimplemented;
    // Signed input halves the weights at reorder time; compensate in bias
    // and scales via this factor.
954     jcp.wei_adj_scale = (jcp.signed_input) ? (1.0f / 2.0f) : 1.0f;
956     return status::success;
// Register scratchpad buffers: padded bias when oc is not a multiple of the
// block, adjusted scales and padded compensation for signed input, and the
// intermediate row buffer (+ padded dw bias) for the fused depthwise conv.
// NOTE(review): closing braces of some if-blocks are elided from this
// excerpt.
959 template <cpu_isa_t isa>
960 void jit_uni_x8s8s32x_conv_fwd_kernel<isa>::init_scratchpad(
961         memory_tracking::registrar_t &scratchpad, const jit_conv_conf_t &jcp, const jit_conv_conf_t &jcp_dw,
962         const primitive_attr_t &attr) {
963     if (jcp.oc != jcp.oc_padded)
964         scratchpad.book(key_conv_padded_bias, (size_t)jcp.typesize_bia * jcp.oc_padded);
966     if (jcp.signed_input) {
        // At least one full vector of scales, even for a common scale.
967         size_t count = nstl::max(attr.output_scales_.count_, 8);
968         scratchpad.book(key_conv_adjusted_scales, sizeof(float) * count);
970         if (jcp.oc != jcp.oc_padded)
971             scratchpad.book(key_conv_padded_compensation, sizeof(int32_t) * jcp.oc_padded);
974     if (jcp.with_dw_conv) {
        // One row buffer per thread for the 1x1 -> dw fusion.
975         const int nthreads = mkldnn_get_max_threads();
976         size_t dw_conv_buffer_size_ = (size_t)jcp_dw.kh * jcp_dw.iw * jcp_dw.ch_block * jcp.nb_oc_blocking;
977         scratchpad.book(key_dw_conv_buffer, jcp_dw.typesize_in * dw_conv_buffer_size_ * nthreads);
979         if (jcp.oc != jcp.oc_padded)
980             scratchpad.book(key_dw_conv_padded_bias, (size_t)jcp_dw.typesize_bia * jcp.oc_padded);
984 template struct jit_uni_x8s8s32x_conv_fwd_kernel<avx2>;
985 template struct jit_uni_x8s8s32x_conv_fwd_kernel<sse42>;