1 /*******************************************************************************
2 * Copyright 2018 Intel Corporation
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
17 #include "c_types_map.hpp"
19 #include "type_helpers.hpp"
21 #include "cpu_memory.hpp"
23 #include "jit_uni_x8s8s32x_dw_conv_kernel.hpp"
// Byte offset of a field inside the jit_conv_call_s structure that the
// driver passes to the generated kernel at runtime (read via this->param1).
17 #define GET_OFF(field) offsetof(jit_conv_call_s, field)
31 using namespace mkldnn::impl::prop_kind;
32 using namespace mkldnn::impl::memory_format;
33 using namespace mkldnn::impl::utils;
35 using namespace Xbyak;
// Zero-initialize the s32 accumulator vector registers for one output tile.
// Register layout: one accumulator per (repeat half, channel block, ow
// position). On SSE4.2 a full channel block (ch_step > ch_block/2) is
// processed in two 128-bit halves, hence two repeats of the register set.
// NOTE(review): the tail of this function (closing braces) is elided in
// this view of the file.
37 template <cpu_isa_t isa>
38 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::load_src(int ur_ch_blocks, int ch_step, int ur_w) {
39 int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
40 for (int i = 0; i < repeats; i++) {
41 for (int ch = 0; ch < ur_ch_blocks; ch++) {
42 for (int ow = 0; ow < ur_w; ow++) {
43 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
// xor-with-self: canonical idiom to zero a vector register.
45 uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
// Convolution inner loops for the non-unrolled (runtime-width) path:
// walks kd (reg_kd) / kh (reg_kh) / kw with runtime counters, loading s8
// weights sign-extended and u8 activations zero-extended into 32-bit
// lanes, then accumulating acc += src * ker in s32.
// NOTE(review): several lines are elided in this view — label definitions
// (kw_label, kh_label), counter decrements, and the conditions selecting
// the scalar (GPR) vs vector load paths. Comments below only describe
// what the visible lines establish.
51 template <cpu_isa_t isa>
52 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter(int ur_ch_blocks, int ch_step, int ur_w) {
53 int ch_blk = jcp.ch_block;
// Dilation is stored as "extra pixels"; +1 turns it into a step size.
54 int dilate_d = jcp.dilate_d + 1;
55 int dilate_h = jcp.dilate_h + 1;
56 int dilate_w = jcp.dilate_w + 1;
57 int stride_w = jcp.stride_w;
59 Label iter_exit_label;
60 Label kd_label, iter_d_exit_label;
// Depth (5D) setup: load the padded kd trip count; skip all work if zero.
67 mov(reg_kd, ptr[this->param1 + GET_OFF(kd_padding)]);
69 je(iter_d_exit_label, T_NEAR);
// Save input/kernel cursors for the per-depth-slice iteration.
71 mov(aux_reg_inp_d, aux_reg_input);
72 mov(aux_reg_ker_d, aux_reg_kernel);
76 mov(aux_reg_input, aux_reg_inp_d);
77 mov(aux_reg_kernel, aux_reg_ker_d);
// Height loop: padded kh trip count; bail out if there is nothing to do.
80 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
82 je(iter_exit_label, T_NEAR);
84 je(iter_exit_label, T_NEAR);
// Row-local cursors used by the kw loop below.
90 mov(aux1_reg_input, aux_reg_input);
91 mov(aux1_reg_kernel, aux_reg_kernel);
// SSE4.2 covers a full channel block in two 128-bit halves (repeats == 2).
95 int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
96 for (int i = 0; i < repeats; i++) {
97 for (int ch = 0; ch < ur_ch_blocks; ch++) {
// Weight offset: channel blocks are kd*kh*kw*ch_blk apart; the i term
// selects the upper/lower half of the block on SSE4.2.
98 int ker_off = ch*jcp.kd*jcp.kh*jcp.kw*ch_blk + i*(ch_blk / 2);
99 Vmm vmm_ker = get_ker_reg(0);
100 Xmm xmm_ker = Xmm(vmm_ker.getIdx());
// Scalar variant: sign-extend a single s8 weight through a GPR, then
// move it into the low qword of the xmm register.
103 movsx(reg_tmp_32, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
104 movq(xmm_ker, reg_tmp_64);
// Vector variant: sign-extend a packet of s8 weights to s32 lanes.
106 uni_vpmovsxbd(vmm_ker, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
109 for (int ow = 0; ow < ur_w; ow++) {
// NHWC activations: adjacent ow positions are stride_w*oc bytes apart.
110 int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + i*(ch_blk / 2);
111 Vmm vmm_src = get_src_reg(0);
112 Xmm xmm_src = Xmm(vmm_src.getIdx());
// u8 activations are zero-extended (scalar and vector variants).
115 movzx(reg_tmp_32, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
116 movq(xmm_src, reg_tmp_64);
118 uni_vpmovzxbd(vmm_src, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
// Multiply-accumulate in 32-bit integer lanes.
121 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
122 uni_vpmulld(vmm_src, vmm_src, vmm_ker);
123 uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
// Advance one kw tap: one channel block of weights, one dilated pixel.
127 add(aux1_reg_kernel, ch_blk*jcp.typesize_in);
128 add(aux1_reg_input, jcp.oc*dilate_w*jcp.typesize_in);
132 jg(kw_label, T_NEAR);
// Advance one kh row: a full kw row of weights, one dilated input row.
134 add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
135 add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);
139 jg(kh_label, T_NEAR);
// Advance one kd slice (5D only): one dilated input plane and one kh*kw
// weight slab, then refresh the per-row cursors.
144 if (jcp.ndims == 5) {
145 add(aux_reg_inp_d, dilate_d * jcp.ih * jcp.iw * jcp.ic * jcp.typesize_in);
146 add(aux_reg_ker_d, jcp.kh * jcp.kw * ch_blk * jcp.typesize_in);
147 mov(aux_reg_input, aux_reg_inp_d);
148 mov(aux_reg_kernel, aux_reg_ker_d);
152 jg(kd_label, T_NEAR);
154 L(iter_d_exit_label);
// Same tap accumulation as apply_filter, but with the kw loop fully
// unrolled at code-generation time (the C++ `for (kw ...)` below emits one
// load/mul/add group per tap), so only kd/kh remain runtime loops.
// NOTE(review): label definitions (kh_label, kd_label), counter decrements
// and the scalar-vs-vector branch conditions fall on lines elided from
// this view.
162 template <cpu_isa_t isa>
163 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter_unrolled(int ur_ch_blocks, int ch_step, int ur_w) {
164 int ch_blk = jcp.ch_block;
// Convert stored "extra pixels" dilation into step sizes.
165 int dilate_d = jcp.dilate_d + 1;
166 int dilate_h = jcp.dilate_h + 1;
167 int dilate_w = jcp.dilate_w + 1;
168 int stride_w = jcp.stride_w;
170 Label iter_exit_label;
171 Label kd_label, iter_d_exit_label;
// Depth (5D) setup mirrors apply_filter.
173 if (jcp.ndims == 5) {
178 mov(reg_kd, ptr[this->param1 + GET_OFF(kd_padding)]);
180 je(iter_d_exit_label, T_NEAR);
182 mov(aux_reg_inp_d, aux_reg_input);
183 mov(aux_reg_ker_d, aux_reg_kernel);
187 mov(aux_reg_input, aux_reg_inp_d);
188 mov(aux_reg_kernel, aux_reg_ker_d);
// Height loop: padded kh trip count kept in iter_kh.
191 mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
193 je(iter_exit_label, T_NEAR);
195 mov(iter_kh, reg_kh);
// SSE4.2 processes a full channel block in two 128-bit halves.
198 int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
199 for (int i = 0; i < repeats; i++) {
200 for (int ch = 0; ch < ur_ch_blocks; ch++) {
201 for (int kw = 0; kw < jcp.kw; kw++) {
// Weight offset now includes the compile-time kw tap index.
202 int ker_off = ch*jcp.kd*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*(ch_blk / 2);
203 Vmm vmm_ker = get_ker_reg(0);
204 Xmm xmm_ker = Xmm(vmm_ker.getIdx());
// Scalar variant: sign-extend one s8 weight through a GPR.
207 movsx(reg_tmp_32, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
208 movq(xmm_ker, reg_tmp_64);
// Vector variant: sign-extend a packet of s8 weights to s32 lanes.
210 uni_vpmovsxbd(vmm_ker, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
213 for (int ow = 0; ow < ur_w; ow++) {
// NHWC input offset including the (dilated) compile-time kw shift.
214 int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + kw*jcp.oc*dilate_w + i*(ch_blk / 2);
215 Vmm vmm_src = get_src_reg(0);
216 Xmm xmm_src = Xmm(vmm_src.getIdx());
// u8 activations are zero-extended (scalar and vector variants).
219 movzx(reg_tmp_32, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
220 movq(xmm_src, reg_tmp_64);
222 uni_vpmovzxbd(vmm_src, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
// Multiply-accumulate in 32-bit integer lanes.
225 Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
226 uni_vpmulld(vmm_src, vmm_src, vmm_ker);
227 uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
// Advance one kh row: a full kw row of weights, one dilated input row.
233 add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
234 add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);
238 jg(kh_label, T_NEAR);
// Advance one kd slice (5D only) and refresh the per-row cursors.
243 if (jcp.ndims == 5) {
244 add(aux_reg_inp_d, dilate_d * jcp.ih * jcp.iw * jcp.ic * jcp.typesize_in);
245 add(aux_reg_ker_d, jcp.kh * jcp.kw * ch_blk * jcp.typesize_in);
246 mov(aux_reg_input, aux_reg_inp_d);
247 mov(aux_reg_kernel, aux_reg_ker_d);
251 jg(kd_label, T_NEAR);
253 L(iter_d_exit_label);
// Store one vector of results to `op`, converting from 32-bit lanes to the
// destination data type selected by jcp.dst_dt. 8-bit outputs are produced
// by saturating packs: s32->s16 then ->s8 (signed packs) or ->u8 (unsigned
// packs). scalar_store selects the single-element path that goes through
// reg_tmp_64 instead of a full vector store.
// NOTE(review): the switch's case labels, the final byte stores of
// reg_tmp_64, and the break statements fall on lines elided from this
// view; which dst_dt each pack sequence serves cannot be confirmed here.
261 template <cpu_isa_t isa>
262 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
263 Ymm ymm_dst = Ymm(vmm_dst.getIdx());
264 Xmm xmm_dst = Xmm(vmm_dst.getIdx());
266 switch (jcp.dst_dt) {
270 movq(reg_tmp_64, xmm_dst);
// Full-width 32-bit store (no narrowing needed).
273 uni_vmovups(op, vmm_dst);
// Signed saturating pack s32 -> s16 (per 128-bit lane).
277 uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
// On AVX (non-scalar), gather the two per-lane pack results into the low
// half of the ymm before the final byte pack.
279 if (isa != sse42 && !scalar_store)
280 vpermq(ymm_dst, ymm_dst, 0x08);
// Signed saturating pack s16 -> s8.
282 uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
285 movq(reg_tmp_64, xmm_dst);
// Unsigned saturating pack s32 -> u16.
295 uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
297 if (isa != sse42 && !scalar_store)
298 vpermq(ymm_dst, ymm_dst, 0x08);
// Unsigned saturating pack -> u8.
300 uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
303 movq(reg_tmp_64, xmm_dst);
314 assert(!"unknown dst_dt");
// Load operand `op` into vmm_in and convert it to f32 lanes.
// Depending on type_in (case labels elided from this view): f32/s32 are
// loaded directly; s8 is sign-extended and u8 zero-extended to 32-bit
// lanes first, with a GPR-based single-element variant for scalar_load.
// Any non-f32 input is finished with an int->float conversion.
318 template <cpu_isa_t isa>
319 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
320 const Xbyak::Operand &op, bool scalar_load) {
321 Xmm xmm_in = Xmm(vmm_in.getIdx());
// Direct 32-bit load (f32 or s32 path).
329 uni_vmovups(vmm_in, op);
// s8: scalar sign-extend via GPR ...
334 movsx(reg_tmp_32, op);
335 movq(xmm_in, reg_tmp_64);
// ... or vector sign-extend to s32 lanes.
337 uni_vpmovsxbd(vmm_in, op);
// u8: scalar zero-extend via GPR ...
342 movzx(reg_tmp_32, op);
343 movq(xmm_in, reg_tmp_64);
// ... or vector zero-extend to s32 lanes.
345 uni_vpmovzxbd(vmm_in, op);
348 default: assert(!"unsupported data type");
// Integer inputs are converted to f32; f32 input is left untouched.
351 if (type_in != data_type::f32)
352 uni_vcvtdq2ps(vmm_in, vmm_in);
// Finalize and write one output tile: convert the s32 accumulators to f32,
// add bias, apply the output scales, run the post-op chain (eltwise /
// depthwise / sum), optionally round back to integer for non-f32
// destinations, and store through the type-converting
// store_dst(Address, Vmm, bool) helper.
// NOTE(review): the pop/push of reg_scales_base at entry/exit pairs with
// stack manipulation in loop_body; several conditional lines (e.g. the
// guards around bias/scale application) are elided from this view.
355 template <cpu_isa_t isa>
356 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int ch_step, int ur_w) {
357 int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
// Recover the scales base pointer saved on the stack by loop_body.
360 pop(reg_scales_base);
// imm_addr64 points at the constant table emitted by prepare_table()
// (holds the broadcast sum scale).
362 mov(imm_addr64, l_table);
364 const auto &p = attr_.post_ops_;
365 const int sum_idx = p.find(primitive_kind::sum);
366 const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
// A partial channel block (tail) is handled element-by-element.
368 bool is_scalar_store = ch_step < jcp.ch_block;
370 for (int r = 0; r < repeats; r++) {
371 for (int ii = 0; ii < ur_ch_blocks; ii++) {
// Bias for this channel block, converted to f32.
373 int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
374 cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], is_scalar_store);
377 for (int jj = 0; jj < ur_w; jj++) {
378 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
// s32 accumulator -> f32 before bias/scale arithmetic.
379 uni_vcvtdq2ps(vmm_dst, vmm_dst);
382 uni_vaddps(vmm_dst, vmm_dst, vmm_bias);
// Output scales: per-channel offset only when is_oc_scale is set.
384 int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
385 cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], is_scalar_store);
386 uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
// Post-op chain is applied in attribute order.
390 int eltwise_inj_idx = 0;
391 int depthwise_inj_idx = 0;
392 for (int i = 0; i < p.len_; i++) {
// Injectors operate on the accumulator register range, which starts at
// vmm index 4 here — presumably where get_acc_reg maps index 0; confirm
// against the register accessors declared in the header.
393 int start_idx = 4 + r * ur_ch_blocks*ur_w;
395 auto& post_op = p.entry_[i];
396 if (post_op.is_eltwise()) {
397 eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + ur_ch_blocks * ur_w);
399 } else if (post_op.is_depthwise()) {
// Depthwise post-op parameters are baked in as absolute addresses,
// offset by the current channel position (reg_oc_off).
400 mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
401 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
403 add(reg_d_weights, reg_oc_off);
404 add(reg_d_bias, reg_oc_off);
// Second SSE4.2 half: skip over the first half-block of parameters.
407 add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
408 add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
411 for (int ii = 0; ii < ur_ch_blocks; ii++) {
412 depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
413 start_idx + ur_w * ii, start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);
415 add(reg_d_weights, jcp.ch_block * sizeof(float));
416 add(reg_d_bias, jcp.ch_block * sizeof(float));
420 } else if (post_op.is_sum(false)) {
// Sum post-op: dst += (scaled) previous output contents.
421 for (int ii = 0; ii < ur_ch_blocks; ii++) {
422 for (int jj = 0; jj < ur_w; jj++) {
423 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
424 int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
426 cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], is_scalar_store);
428 if (p_sum_scale == 1.f) {
429 uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
// Scaled sum: fma with the broadcast scale from the constant table.
431 uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 0 * vlen]);
// Final conversion + store per (channel block, ow) position.
438 for (int ii = 0; ii < ur_ch_blocks; ii++) {
439 for (int jj = 0; jj < ur_w; jj++) {
440 Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
441 int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);
// Integer destinations need float->int rounding per attr round mode.
443 if (jcp.dst_dt != data_type::f32) {
444 if (attr_.round_mode_ == round_mode::nearest)
445 uni_vcvtps2dq(vmm_dst, vmm_dst);
446 else if (attr_.round_mode_ == round_mode::down) {
// roundps imm 1 = round toward -inf (floor), then convert.
447 uni_vroundps(vmm_dst, vmm_dst, 1);
448 uni_vcvtps2dq(vmm_dst, vmm_dst);
450 assert(!"unimplemented");
453 store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, is_scalar_store);
// Re-save the scales pointer for the next store_dst invocation.
458 push(reg_scales_base);
// Output-width driver for one channel-block group: runs the unrolled
// kernel (apply_filter_unrolled) while at least ur_w output columns
// remain, then falls into a per-column tail loop using apply_filter.
// Base registers are saved on the stack because store_dst() reuses them.
// NOTE(review): label declarations (tail_w_label, exit_label), the ur_w
// bookkeeping (cmp/sub of reg_ur_w) and the pop of reg_input_base fall on
// lines elided from this view.
462 template <cpu_isa_t isa>
463 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int ch_step) {
464 Label unrolled_w_label;
// Working cursors start from the per-group base pointers.
468 mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
469 mov(reg_input, reg_input_base);
470 mov(reg_output, reg_output_base);
471 mov(reg_kernel, reg_kernel_base);
// Preserve base pointers (store_dst pops/pushes reg_scales_base).
473 push(reg_input_base);
474 push(reg_output_base);
475 push(reg_kernel_base);
477 push(reg_scales_base);
// Main loop: process ur_w output columns per iteration.
480 L(unrolled_w_label); {
484 jl(tail_w_label, T_NEAR);
486 mov(aux_reg_input, reg_input);
487 mov(aux_reg_kernel, reg_kernel);
489 load_src(ur_ch_blocks, ch_step, ur_w);
490 apply_filter_unrolled(ur_ch_blocks, ch_step, ur_w);
491 store_dst(ur_ch_blocks, ch_step, ur_w);
// Advance NHWC cursors by ur_w output columns.
493 add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
494 add(reg_output, jcp.typesize_out * ur_w * jcp.oc);
497 jmp(unrolled_w_label);
// Tail: one output column at a time with the runtime-kw kernel.
504 jl(exit_label, T_NEAR);
506 mov(aux_reg_input, reg_input);
507 mov(aux_reg_kernel, reg_kernel);
509 load_src(ur_ch_blocks, ch_step, ur_w);
510 apply_filter(ur_ch_blocks, ch_step, ur_w);
511 store_dst(ur_ch_blocks, ch_step, ur_w);
513 add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
514 add(reg_output, jcp.typesize_out * ur_w * jcp.oc);
// Restore saved base pointers (reverse of the pushes above).
523 pop(reg_scales_base);
525 pop(reg_kernel_base);
526 pop(reg_output_base);
// Top-level code generator: builds the post-op injectors, loads the kernel
// arguments from jit_conv_call_s, then emits three channel phases —
// a fully-blocked pass (nb_ch_blocking blocks at once), a per-block main
// loop, and an element-wise tail — followed by the constant table.
// NOTE(review): the preamble/postamble, label declarations (exit_label),
// several loop-control lines and the tail-phase loop_body call are elided
// from this view.
530 template <cpu_isa_t isa>
531 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
532 const auto &p = attr_.post_ops_;
// One injector instance per eltwise/depthwise post-op, in order.
533 for (int i = 0; i < p.len_; i++) {
534 auto &post_op = p.entry_[i];
535 if (post_op.is_eltwise()) {
536 eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
539 post_op.eltwise.alpha,
542 } else if (post_op.is_depthwise()) {
543 depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
545 post_op.depthwise.alg
// Load all runtime arguments passed via jit_conv_call_s.
552 mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
553 mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
554 mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
556 mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
557 mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
558 mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
559 mov(reg_ch_work, ptr[this->param1 + GET_OFF(ch_work)]);
560 mov(reg_oc_off, ptr[this->param1 + GET_OFF(oc_off)]);
562 Label main_loop_label;
563 Label tail_loop_label;
// Fully-blocked fast path only when ch_work is exactly one full
// nb_ch_blocking group; otherwise fall through to the per-block loop.
566 cmp(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
567 jne(main_loop_label, T_NEAR);
569 loop_body(jcp.nb_ch_blocking, jcp.nb_ch_blocking * jcp.ch_block);
571 sub(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
573 jmp(exit_label, T_NEAR);
// Per-channel-block loop: one ch_block at a time.
575 L(main_loop_label); {
576 cmp(reg_ch_work, jcp.ch_block);
577 jl(tail_loop_label, T_NEAR);
579 loop_body(1, jcp.ch_block);
// Advance all base pointers by one channel block.
581 sub(reg_ch_work, jcp.ch_block);
582 add(reg_input_base, jcp.ch_block * jcp.typesize_in);
583 add(reg_output_base, jcp.ch_block * jcp.typesize_out);
584 add(reg_kernel_base, jcp.ch_block * jcp.kd * jcp.kh * jcp.kw * jcp.typesize_in);
585 add(reg_bias_base, jcp.ch_block * jcp.typesize_bia);
586 add(reg_scales_base, jcp.is_oc_scale * jcp.ch_block * sizeof(float));
587 add(reg_oc_off, jcp.ch_block * sizeof(float));
589 jmp(main_loop_label, T_NEAR);
// Channel tail: single channels, advancing pointers by one element.
592 L(tail_loop_label); {
594 jl(exit_label, T_NEAR);
599 add(reg_input_base, 1 * jcp.typesize_in);
600 add(reg_output_base, 1 * jcp.typesize_out);
601 add(reg_kernel_base, 1 * jcp.typesize_in);
602 add(reg_bias_base, 1 * jcp.typesize_bia);
603 add(reg_scales_base, jcp.is_oc_scale * 1 * sizeof(float));
604 add(reg_oc_off, 1 * sizeof(float));
606 jmp(tail_loop_label, T_NEAR);
// Emit lookup tables required by the eltwise injectors.
615 for (auto& inj : eltwise_injectors)
616 inj->prepare_table();
// Emit the kernel's constant table: the sum post-op scale, stored as raw
// f32 bits (float2int) and broadcast across one full vector width so the
// generated code can use it as a memory operand (see uni_vfmadd231ps in
// store_dst).
619 template <cpu_isa_t isa>
620 void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::prepare_table() {
621 const auto &p = attr_.post_ops_;
622 const int sum_idx = p.find(primitive_kind::sum);
// Defaults to 1.f when no sum post-op is attached.
623 const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;
625 const int32_t cvals_sum_scale[] = {
626 float2int(p_sum_scale)
// Broadcast each constant across vlen bytes (one dword per lane).
631 for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
632 for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
633 dd(cvals_sum_scale[i]);
// Validate the attached post-ops: every entry must be a sum, eltwise or
// depthwise op (the only kinds this kernel implements), and at most one
// sum is allowed.
638 template <cpu_isa_t isa>
639 bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::post_ops_ok(
640 jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
641 const auto &p = attr.post_ops_;
643 auto all_post_ops_supported = [&]() {
646 for (int i = 0; i < p.len_; i++) {
647 ok = ok && utils::one_of(p.entry_[i].kind, primitive_kind::sum, primitive_kind::eltwise, primitive_kind::depthwise);
651 auto count = [&](mkldnn::impl::primitive_kind_t kind) { return p.count(kind); };
653 return all_post_ops_supported() &&
654 count(primitive_kind::sum) <= 1;
// Populate the jit_conv_conf_t descriptor from the convolution descriptor
// and memory descriptors, and reject configurations this kernel does not
// support (wrong data types, no groups, unexpected formats, unsupported
// post-ops). Supports 4D (nhwc) and 5D (ndhwc) cases.
// NOTE(review): a number of lines (ndims/oc checks, args_ok opening,
// typesize_bia else-branch) are elided from this view.
657 template <cpu_isa_t isa>
658 status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
659 const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
660 const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
661 const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr)
663 if (!mayiuse(isa)) return status::unimplemented;
// Kernel computes u8 (src) x s8 (weights) -> {f32, s32, s8, u8}.
665 if (!(src_d.data_type() == data_type::u8 &&
666 weights_d.data_type() == data_type::s8 &&
667 one_of(dst_d.data_type(), data_type::f32, data_type::s32, data_type::s8, data_type::u8)))
668 return status::unimplemented;
670 jcp.prop_kind = cd.prop_kind;
// Depthwise requires grouped weights (extra leading groups dim).
672 const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
673 if (!with_groups) return status::unimplemented;
675 int ndims = src_d.ndims();
678 jcp.ngroups = weights_d.dims()[0];
679 jcp.mb = src_d.dims()[0];
681 jcp.oc = dst_d.dims()[1];
682 jcp.ic = src_d.dims()[1];
// Spatial dims; depth only exists in the 5D case.
684 jcp.id = (ndims == 5) ? src_d.dims()[2] : 1;
685 jcp.ih = src_d.dims()[ndims - 2];
686 jcp.iw = src_d.dims()[ndims - 1];
687 jcp.od = (ndims == 5) ? dst_d.dims()[2] : 1;
688 jcp.oh = dst_d.dims()[ndims - 2];
689 jcp.ow = dst_d.dims()[ndims - 1];
// Weights dims are shifted by one because of the groups dimension.
691 jcp.kd = (ndims == 5) ? weights_d.dims()[3] : 1;
692 jcp.kh = weights_d.dims()[ndims - 1];
693 jcp.kw = weights_d.dims()[ndims];
// padding[0] = front/top/left, padding[1] = back/bottom/right.
695 jcp.f_pad = (ndims == 5) ? cd.padding[0][0] : 0;
696 jcp.t_pad = cd.padding[0][ndims - 4];
697 jcp.l_pad = cd.padding[0][ndims - 3];
698 jcp.back_pad = (ndims == 5) ? cd.padding[1][0] : 0;
699 jcp.b_pad = cd.padding[1][ndims - 4];
700 jcp.r_pad = cd.padding[1][ndims - 3];
702 jcp.stride_d = (ndims == 5) ? cd.strides[0] : 1;
703 jcp.stride_h = cd.strides[ndims - 4];
704 jcp.stride_w = cd.strides[ndims - 3];
706 jcp.dilate_d = (ndims == 5) ? cd.dilates[0] : 0;
707 jcp.dilate_h = cd.dilates[ndims - 4];
708 jcp.dilate_w = cd.dilates[ndims - 3];
710 jcp.src_fmt = src_d.format();
711 jcp.with_bias = cd.bias_desc.format != memory_format::undef;
// Always false here: the u8-only src check above already rejected s8
// input, so this branch is effectively dead — kept for symmetry with
// other kernels, presumably.
713 jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;
715 if (jcp.signed_input)
716 return status::unimplemented;
// Channel block equals the vector width in 32-bit lanes.
718 const int simd_w = isa == avx512_common ? 16 : 8;
719 jcp.ch_block = simd_w;
720 jcp.nb_ch = div_up(jcp.oc, jcp.ch_block);
722 if (!post_ops_ok(jcp, attr))
723 return status::unimplemented;
725 const auto &p = attr.post_ops_;
726 jcp.with_sum = p.find(primitive_kind::sum) != -1;
727 const int eltwise_ind = p.find(primitive_kind::eltwise);
728 jcp.with_eltwise = eltwise_ind != -1;
729 if (jcp.with_eltwise)
730 jcp.eltwise = p.entry_[eltwise_ind].eltwise;
// Plain channels-last activations; grouped per-channel weights whose
// block size matches the ISA vector width.
732 auto desired_act_fmt = (ndims == 5) ? ndhwc : nhwc;
733 auto desired_wei_fmt = (ndims == 5) ? isa == avx512_common ? Goidhw16g : Goidhw8g
734 : isa == avx512_common ? Goihw16g : Goihw8g;
// Depthwise: channels == groups on both sides.
737 && jcp.oc == jcp.ngroups
738 && jcp.ic == jcp.ngroups
739 && src_d.format() == desired_act_fmt
740 && weights_d.format() == desired_wei_fmt
741 && one_of(cd.bias_desc.format, memory_format::undef, any, x)
742 && dst_d.format() == desired_act_fmt;
743 if (!args_ok) return status::unimplemented;
745 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
746 jcp.dst_dt = cd.dst_desc.data_type;
748 jcp.typesize_in = types::data_type_size(src_d.data_type());
749 jcp.typesize_out = types::data_type_size(dst_d.data_type());
750 jcp.typesize_acc = sizeof(int32_t);
751 jcp.typesize_bia = jcp.with_bias
752 ? types::data_type_size(bias_pd.data_type())
// Output scales: mask bit 1 selects per-output-channel scaling;
// otherwise a single common scale (mask 0) is expected.
755 const auto &oscales = attr.output_scales_;
756 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
758 assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));
// Unroll factors tuned per ISA.
760 jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;
762 jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
763 if (jcp.nb_ch < jcp.nb_ch_blocking)
764 jcp.nb_ch_blocking = jcp.nb_ch;
766 return status::success;
// Explicit template instantiations for the ISAs this translation unit
// provides.
// NOTE(review): init_conf handles avx512_common, but no avx512
// instantiation is visible here — confirm it is intentionally absent or
// instantiated elsewhere.
769 template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<avx2>;
770 template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<sse42>;