/*******************************************************************************
* Copyright 2018 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "c_types_map.hpp"
#include "type_helpers.hpp"
#include "cpu_memory.hpp"

#include "jit_uni_x8s8s32x_dw_conv_kernel.hpp"

#define GET_OFF(field) offsetof(jit_conv_call_s, field)

namespace mkldnn {
namespace impl {
namespace cpu {

using namespace mkldnn::impl::prop_kind;
using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;

using namespace Xbyak;
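
// load_src(): zero the accumulator registers, one vector accumulator per
// (channel block, output pixel) pair. On sse42 each channel block is processed
// in two half-register passes, which is what the extra 'repeats' dimension covers.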
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::load_src(int ur_ch_blocks, int ch_step, int ur_w) {
    int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
    for (int i = 0; i < repeats; i++) {
        for (int ch = 0; ch < ur_ch_blocks; ch++) {
            for (int ow = 0; ow < ur_w; ow++) {
                Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);

                uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
            }
        }
    }
}
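
// apply_filter(): non-unrolled accumulation path used for the width tail.
// Runs the kh/kw loops at execution time (iter_kh/iter_kw), widening s8 weights
// and u8 inputs to 32-bit lanes and accumulating their products.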
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter(int ur_ch_blocks, int ch_step, int ur_w) {
    int ch_blk = jcp.ch_block;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);
    cmp(reg_kw, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label); {
        mov(aux1_reg_input, aux_reg_input);
        mov(aux1_reg_kernel, aux_reg_kernel);

        mov(iter_kw, reg_kw);
        Label kw_label;
        L(kw_label); {
            int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
            for (int i = 0; i < repeats; i++) {
                for (int ch = 0; ch < ur_ch_blocks; ch++) {
                    int ker_off = ch*jcp.kh*jcp.kw*ch_blk + i*(ch_blk / 2);
                    Vmm vmm_ker = get_ker_reg(0);
                    Xmm xmm_ker = Xmm(vmm_ker.getIdx());

                    if (ch_step == 1) {
                        movsx(reg_tmp_32, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
                        movq(xmm_ker, reg_tmp_64);
                    } else {
                        uni_vpmovsxbd(vmm_ker, ptr[aux1_reg_kernel + ker_off*jcp.typesize_in]);
                    }

                    for (int ow = 0; ow < ur_w; ow++) {
                        int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + i*(ch_blk / 2);
                        Vmm vmm_src = get_src_reg(0);
                        Xmm xmm_src = Xmm(vmm_src.getIdx());

                        if (ch_step == 1) {
                            movzx(reg_tmp_32, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
                            movq(xmm_src, reg_tmp_64);
                        } else {
                            uni_vpmovzxbd(vmm_src, ptr[aux1_reg_input + inp_off * jcp.typesize_in]);
                        }

                        Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
                        uni_vpmulld(vmm_src, vmm_src, vmm_ker);
                        uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
                    }
                }
            }

            add(aux1_reg_kernel, ch_blk*jcp.typesize_in);
            add(aux1_reg_input, jcp.oc*dilate_w*jcp.typesize_in);

            dec(iter_kw);
            cmp(iter_kw, 0);
            jg(kw_label, T_NEAR);
        }

        add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
        add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);

        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}
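
// apply_filter_unrolled(): same accumulation as apply_filter(), but with the kw
// loop unrolled at code-generation time; only the kh loop remains a run-time loop.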
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::apply_filter_unrolled(int ur_ch_blocks, int ch_step, int ur_w) {
    int ch_blk = jcp.ch_block;
    int dilate_h = jcp.dilate_h + 1;
    int dilate_w = jcp.dilate_w + 1;
    int stride_w = jcp.stride_w;

    Label iter_exit_label;

    cmp(reg_kh, 0);
    je(iter_exit_label, T_NEAR);

    mov(iter_kh, reg_kh);
    Label kh_label;
    L(kh_label); {
        int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;
        for (int i = 0; i < repeats; i++) {
            for (int ch = 0; ch < ur_ch_blocks; ch++) {
                for (int kw = 0; kw < jcp.kw; kw++) {
                    int ker_off = ch*jcp.kh*jcp.kw*ch_blk + kw*ch_blk + i*(ch_blk / 2);
                    Vmm vmm_ker = get_ker_reg(0);
                    Xmm xmm_ker = Xmm(vmm_ker.getIdx());

                    if (ch_step == 1) {
                        movsx(reg_tmp_32, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
                        movq(xmm_ker, reg_tmp_64);
                    } else {
                        uni_vpmovsxbd(vmm_ker, ptr[aux_reg_kernel + ker_off*jcp.typesize_in]);
                    }

                    for (int ow = 0; ow < ur_w; ow++) {
                        int inp_off = ch*ch_blk + ow*stride_w*jcp.oc + kw*jcp.oc*dilate_w + i*(ch_blk / 2);
                        Vmm vmm_src = get_src_reg(0);
                        Xmm xmm_src = Xmm(vmm_src.getIdx());

                        if (ch_step == 1) {
                            movzx(reg_tmp_32, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
                            movq(xmm_src, reg_tmp_64);
                        } else {
                            uni_vpmovzxbd(vmm_src, ptr[aux_reg_input + inp_off * jcp.typesize_in]);
                        }

                        Vmm vmm_acc = get_acc_reg(i*ur_ch_blocks*ur_w + ch*ur_w + ow);
                        uni_vpmulld(vmm_src, vmm_src, vmm_ker);
                        uni_vpaddd(vmm_acc, vmm_acc, vmm_src);
                    }
                }
            }
        }

        add(aux_reg_kernel, jcp.kw*ch_blk*jcp.typesize_in);
        add(aux_reg_input, jcp.iw*jcp.oc*dilate_h*jcp.typesize_in);

        dec(iter_kh);
        cmp(iter_kh, 0);
        jg(kh_label, T_NEAR);
    }

    L(iter_exit_label);
}
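
// store_dst(op, vmm, scalar_store): write one vector (or a single element when
// scalar_store is set) to memory, packing down to s8/u8 when the destination
// data type requires it.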
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) {
    Ymm ymm_dst = Ymm(vmm_dst.getIdx());
    Xmm xmm_dst = Xmm(vmm_dst.getIdx());

    switch (jcp.dst_dt) {
    case data_type::f32:
    case data_type::s32:
        if (scalar_store) {
            movq(reg_tmp_64, xmm_dst);
            mov(op, reg_tmp_32);
        } else {
            uni_vmovups(op, vmm_dst);
        }
        break;
    case data_type::s8:
        uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst);
        if (isa != sse42 && !scalar_store)
            vpermq(ymm_dst, ymm_dst, 0x08);
        uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst);
        if (scalar_store) {
            movq(reg_tmp_64, xmm_dst);
            mov(op, reg_tmp_8);
        } else {
            if (isa != sse42) vmovq(op, xmm_dst);
            else movd(op, xmm_dst);
        }
        break;
    case data_type::u8:
        uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst);
        if (isa != sse42 && !scalar_store)
            vpermq(ymm_dst, ymm_dst, 0x08);
        uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst);
        if (scalar_store) {
            movq(reg_tmp_64, xmm_dst);
            mov(op, reg_tmp_8);
        } else {
            if (isa != sse42) vmovq(op, xmm_dst);
            else movd(op, xmm_dst);
        }
        break;
    default:
        assert(!"unknown dst_dt");
    }
}
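
// cvt2ps(): load a scalar or a vector of the given data type from memory and
// convert it to f32 lanes in vmm_in.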
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::cvt2ps(data_type_t type_in, Vmm vmm_in,
        const Xbyak::Operand &op, bool scalar_load) {
    Xmm xmm_in = Xmm(vmm_in.getIdx());

    switch (type_in) {
    case data_type::f32:
    case data_type::s32:
        if (scalar_load) {
            mov(reg_tmp_32, op);
            movq(xmm_in, reg_tmp_64);
        } else {
            uni_vmovups(vmm_in, op);
        }
        break;
    case data_type::s8:
        if (scalar_load) {
            movsx(reg_tmp_32, op);
            movq(xmm_in, reg_tmp_64);
        } else {
            uni_vpmovsxbd(vmm_in, op);
        }
        break;
    case data_type::u8:
        if (scalar_load) {
            movzx(reg_tmp_32, op);
            movq(xmm_in, reg_tmp_64);
        } else {
            uni_vpmovzxbd(vmm_in, op);
        }
        break;
    default: assert(!"unsupported data type");
    }

    if (type_in != data_type::f32)
        uni_vcvtdq2ps(vmm_in, vmm_in);
}
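
// store_dst(ur_ch_blocks, ch_step, ur_w): post-processing and store. Converts the
// s32 accumulators to f32, adds bias, applies the output scales, runs the post-op
// chain (eltwise, depthwise, sum), rounds back to the destination type and stores.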
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::store_dst(int ur_ch_blocks, int ch_step, int ur_w) {
    int repeats = isa == sse42 && ch_step > (jcp.ch_block / 2) ? 2 : 1;

    pop(reg_scales_base);

    mov(imm_addr64, l_table);

    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;

    bool is_scalar_store = ch_step < jcp.ch_block;

    for (int r = 0; r < repeats; r++) {
        for (int ii = 0; ii < ur_ch_blocks; ii++) {
            if (jcp.with_bias) {
                int b_off = ii * jcp.ch_block + r * (jcp.ch_block / 2);
                cvt2ps(jcp.bia_dt, vmm_bias, ptr[reg_bias_base + b_off * jcp.typesize_bia], is_scalar_store);
            }

            for (int jj = 0; jj < ur_w; jj++) {
                Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
                uni_vcvtdq2ps(vmm_dst, vmm_dst);

                if (jcp.with_bias)
                    uni_vaddps(vmm_dst, vmm_dst, vmm_bias);

                int s_off = jcp.is_oc_scale * (ii * jcp.ch_block + r * (jcp.ch_block / 2));
                cvt2ps(mkldnn_f32, vmm_scale, ptr[reg_scales_base + s_off * sizeof(float)], is_scalar_store);
                uni_vmulps(vmm_dst, vmm_dst, vmm_scale);
            }
        }

        int eltwise_inj_idx = 0;
        int depthwise_inj_idx = 0;
        for (int i = 0; i < p.len_; i++) {
            int start_idx = 4 + r * ur_ch_blocks*ur_w;

            auto& post_op = p.entry_[i];
            if (post_op.is_eltwise()) {
                eltwise_injectors[eltwise_inj_idx]->compute_vector_range(start_idx, start_idx + ur_ch_blocks * ur_w);
                eltwise_inj_idx++;
            } else if (post_op.is_depthwise()) {
                mov(reg_d_weights, reinterpret_cast<size_t>(post_op.depthwise.weights_data));
                mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));

                add(reg_d_weights, reg_oc_off);
                add(reg_d_bias, reg_oc_off);

                if (r == 1) {
                    add(reg_d_weights, (jcp.ch_block / 2) * sizeof(float));
                    add(reg_d_bias, (jcp.ch_block / 2) * sizeof(float));
                }

                for (int ii = 0; ii < ur_ch_blocks; ii++) {
                    depthwise_injectors[depthwise_inj_idx]->compute_vector_range(
                            start_idx + ur_w * ii, start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_bias);

                    add(reg_d_weights, jcp.ch_block * sizeof(float));
                    add(reg_d_bias, jcp.ch_block * sizeof(float));
                }

                depthwise_inj_idx++;
            } else if (post_op.is_sum(false)) {
                for (int ii = 0; ii < ur_ch_blocks; ii++) {
                    for (int jj = 0; jj < ur_w; jj++) {
                        Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks*ur_w + ur_w * ii + jj);
                        int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);

                        cvt2ps(jcp.dst_dt, vmm_prev_dst, ptr[reg_output + o_off * jcp.typesize_out], is_scalar_store);

                        if (p_sum_scale == 1.f) {
                            uni_vaddps(vmm_dst, vmm_dst, vmm_prev_dst);
                        } else {
                            uni_vfmadd231ps(vmm_dst, vmm_prev_dst, ptr[imm_addr64 + 0 * vlen]);
                        }
                    }
                }
            }
        }

        for (int ii = 0; ii < ur_ch_blocks; ii++) {
            for (int jj = 0; jj < ur_w; jj++) {
                Vmm vmm_dst = get_acc_reg(r * ur_ch_blocks * ur_w + ur_w * ii + jj);
                int o_off = ii * jcp.ch_block + jj * jcp.oc + r * (jcp.ch_block / 2);

                if (jcp.dst_dt != data_type::f32) {
                    if (attr_.round_mode_ == round_mode::nearest)
                        uni_vcvtps2dq(vmm_dst, vmm_dst);
                    else if (attr_.round_mode_ == round_mode::down) {
                        uni_vroundps(vmm_dst, vmm_dst, 1);
                        uni_vcvtps2dq(vmm_dst, vmm_dst);
                    } else
                        assert(!"unimplemented");
                }

                store_dst(ptr[reg_output + o_off * jcp.typesize_out], vmm_dst, is_scalar_store);
            }
        }
    }

    push(reg_scales_base);
}
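
// loop_body(): process one output row for a fixed channel step, first in chunks of
// jcp.ur_w pixels with the unrolled kernel, then pixel by pixel for the remainder.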
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::loop_body(int ur_ch_blocks, int ch_step) {
    Label unrolled_w_label;
    Label tail_w_label;
    Label exit_label;

    mov(reg_ur_w, ptr[this->param1 + GET_OFF(ur_w)]);
    mov(reg_input, reg_input_base);
    mov(reg_output, reg_output_base);
    mov(reg_kernel, reg_kernel_base);

    push(reg_input_base);
    push(reg_output_base);
    push(reg_kernel_base);
    push(reg_bias_base);
    push(reg_scales_base);

    L(unrolled_w_label); {
        int ur_w = jcp.ur_w;

        cmp(reg_ur_w, ur_w);
        jl(tail_w_label, T_NEAR);

        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_ch_blocks, ch_step, ur_w);
        apply_filter_unrolled(ur_ch_blocks, ch_step, ur_w);
        store_dst(ur_ch_blocks, ch_step, ur_w);

        add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
        add(reg_output, jcp.typesize_out * ur_w * jcp.oc);

        sub(reg_ur_w, ur_w);
        jmp(unrolled_w_label);
    }

    L(tail_w_label); {
        int ur_w = 1;

        cmp(reg_ur_w, ur_w);
        jl(exit_label, T_NEAR);

        mov(aux_reg_input, reg_input);
        mov(aux_reg_kernel, reg_kernel);

        load_src(ur_ch_blocks, ch_step, ur_w);
        apply_filter(ur_ch_blocks, ch_step, ur_w);
        store_dst(ur_ch_blocks, ch_step, ur_w);

        add(reg_input, jcp.typesize_in * ur_w * jcp.ic * jcp.stride_w);
        add(reg_output, jcp.typesize_out * ur_w * jcp.oc);

        sub(reg_ur_w, ur_w);
        jmp(tail_w_label);
    }

    L(exit_label);

    pop(reg_scales_base);
    pop(reg_bias_base);
    pop(reg_kernel_base);
    pop(reg_output_base);
    pop(reg_input_base);
}
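
// generate(): kernel entry point. Creates the post-op injectors, loads the call
// arguments from jit_conv_call_s, and dispatches between the multi-block path,
// the per-block main loop and the per-channel tail loop.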
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::generate() {
    const auto &p = attr_.post_ops_;
    for (int i = 0; i < p.len_; i++) {
        auto &post_op = p.entry_[i];
        if (post_op.is_eltwise()) {
            eltwise_injectors.push_back(new jit_uni_eltwise_injector_f32<isa>(
                    this,
                    post_op.eltwise.alg,
                    post_op.eltwise.alpha,
                    post_op.eltwise.beta
            ));
        } else if (post_op.is_depthwise()) {
            depthwise_injectors.push_back(new jit_uni_depthwise_injector_f32<isa>(
                    this,
                    post_op.depthwise.alg
            ));
        }
    }

    this->preamble();

    mov(reg_input_base, ptr[this->param1 + GET_OFF(src)]);
    mov(reg_output_base, ptr[this->param1 + GET_OFF(dst)]);
    mov(reg_kernel_base, ptr[this->param1 + GET_OFF(filt)]);
    mov(reg_bias_base, ptr[this->param1 + GET_OFF(bias)]);
    mov(reg_scales_base, ptr[this->param1 + GET_OFF(scales)]);
    mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]);
    mov(reg_kw, ptr[this->param1 + GET_OFF(kw_padding)]);
    mov(reg_ch_work, ptr[this->param1 + GET_OFF(ch_work)]);
    mov(reg_oc_off, ptr[this->param1 + GET_OFF(oc_off)]);

    Label main_loop_label;
    Label tail_loop_label;
    Label exit_label;

    cmp(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);
    jne(main_loop_label, T_NEAR);

    loop_body(jcp.nb_ch_blocking, jcp.nb_ch_blocking * jcp.ch_block);

    sub(reg_ch_work, jcp.nb_ch_blocking * jcp.ch_block);

    jmp(exit_label, T_NEAR);

    L(main_loop_label); {
        cmp(reg_ch_work, jcp.ch_block);
        jl(tail_loop_label, T_NEAR);

        loop_body(1, jcp.ch_block);

        sub(reg_ch_work, jcp.ch_block);
        add(reg_input_base, jcp.ch_block * jcp.typesize_in);
        add(reg_output_base, jcp.ch_block * jcp.typesize_out);
        add(reg_kernel_base, jcp.ch_block * jcp.kh * jcp.kw * jcp.typesize_in);
        add(reg_bias_base, jcp.ch_block * jcp.typesize_bia);
        add(reg_scales_base, jcp.is_oc_scale * jcp.ch_block * sizeof(float));
        add(reg_oc_off, jcp.ch_block * sizeof(float));

        jmp(main_loop_label, T_NEAR);
    }

    L(tail_loop_label); {
        cmp(reg_ch_work, 1);
        jl(exit_label, T_NEAR);

        loop_body(1, 1);

        sub(reg_ch_work, 1);
        add(reg_input_base, 1 * jcp.typesize_in);
        add(reg_output_base, 1 * jcp.typesize_out);
        add(reg_kernel_base, 1 * jcp.typesize_in);
        add(reg_bias_base, 1 * jcp.typesize_bia);
        add(reg_scales_base, jcp.is_oc_scale * 1 * sizeof(float));
        add(reg_oc_off, 1 * sizeof(float));

        jmp(tail_loop_label, T_NEAR);
    }

    L(exit_label);

    this->postamble();

    for (auto& inj : eltwise_injectors)
        inj->prepare_table();

    prepare_table();
}
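
// prepare_table(): emit the constant table referenced through l_table; it holds the
// sum post-op scale broadcast across one vector of int32 lanes.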
template <cpu_isa_t isa>
void jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::prepare_table() {
    const auto &p = attr_.post_ops_;
    const int sum_idx = p.find(primitive_kind::sum);
    const float p_sum_scale = (sum_idx != -1) ? p.entry_[sum_idx].sum.scale : 1.f;

    const int32_t cvals_sum_scale[] = {
        float2int(p_sum_scale)
    };

    align(64);
    L(l_table);
    for (size_t i = 0; i < sizeof(cvals_sum_scale) / sizeof(cvals_sum_scale[0]); ++i) {
        for (size_t d = 0; d < vlen / sizeof(int32_t); ++d) {
            dd(cvals_sum_scale[i]);
        }
    }
}
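
// post_ops_ok(): accept chains of up to three post-ops built from eltwise/depthwise
// ops and at most one sum.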
template <cpu_isa_t isa>
bool jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::post_ops_ok(
        jit_conv_conf_t &jcp, const primitive_attr_t &attr) {
    const auto &p = attr.post_ops_;

    auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(); };
    auto is_depthwise = [&](int idx) { return p.entry_[idx].is_depthwise(); };
    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
    auto is_simple = [&](int idx) { return is_eltwise(idx) || is_depthwise(idx); };

    switch (p.len_) {
    case 0: return true;
    case 1: return is_simple(0) || is_sum(0);
    case 2: return (is_sum(0) && is_simple(1)) || (is_simple(0) && is_sum(1)) ||
                   (is_simple(0) && is_simple(1));
    case 3: return (is_simple(0) && is_sum(1) && is_simple(2));
    default: return false;
    }
}
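
// init_conf(): fill jit_conv_conf_t from the descriptors and reject cases this
// implementation does not handle (non-depthwise groups, unexpected data types or
// memory formats, signed input).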
template <cpu_isa_t isa>
status_t jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jit_conv_conf_t &jcp,
        const convolution_desc_t &cd, const memory_desc_wrapper &src_d,
        const memory_desc_wrapper &weights_d, const memory_desc_wrapper &dst_d,
        const memory_desc_wrapper &bias_pd, const primitive_attr_t &attr)
{
    if (!mayiuse(isa)) return status::unimplemented;

    if (!(src_d.data_type() == data_type::u8 &&
          weights_d.data_type() == data_type::s8 &&
          one_of(dst_d.data_type(), data_type::f32, data_type::s32, data_type::s8, data_type::u8)))
        return status::unimplemented;

    jcp.prop_kind = cd.prop_kind;

    const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
    if (!with_groups) return status::unimplemented;

    jcp.ngroups = weights_d.dims()[0];
    jcp.mb = src_d.dims()[0];

    jcp.oc = dst_d.dims()[1];
    jcp.ic = src_d.dims()[1];

    jcp.ih = src_d.dims()[2];
    jcp.iw = src_d.dims()[3];
    jcp.oh = dst_d.dims()[2];
    jcp.ow = dst_d.dims()[3];

    jcp.kh = weights_d.dims()[3];
    jcp.kw = weights_d.dims()[4];

    jcp.t_pad = cd.padding[0][0];
    jcp.l_pad = cd.padding[0][1];
    jcp.b_pad = cd.padding[1][0];
    jcp.r_pad = cd.padding[1][1];

    jcp.stride_h = cd.strides[0];
    jcp.stride_w = cd.strides[1];

    jcp.dilate_h = cd.dilates[0];
    jcp.dilate_w = cd.dilates[1];

    jcp.src_fmt = src_d.format();
    jcp.with_bias = cd.bias_desc.format != memory_format::undef;

    jcp.signed_input = (src_d.data_type() == data_type::s8) ? true : false;

    if (jcp.signed_input)
        return status::unimplemented;

    const int simd_w = isa == avx512_common ? 16 : 8;
    jcp.ch_block = simd_w;
    jcp.nb_ch = div_up(jcp.oc, jcp.ch_block);

    if (!post_ops_ok(jcp, attr))
        return status::unimplemented;

    const auto &p = attr.post_ops_;
    jcp.with_sum = p.find(primitive_kind::sum) != -1;
    const int eltwise_ind = p.find(primitive_kind::eltwise);
    jcp.with_eltwise = eltwise_ind != -1;
    if (jcp.with_eltwise)
        jcp.eltwise = p.entry_[eltwise_ind].eltwise;

    auto desired_act_fmt = nhwc;
    auto desired_wei_fmt = isa == avx512_common ? Goihw16g : Goihw8g;

    bool args_ok = true
        && jcp.oc == jcp.ngroups
        && jcp.ic == jcp.ngroups
        && src_d.format() == desired_act_fmt
        && weights_d.format() == desired_wei_fmt
        && one_of(cd.bias_desc.format, memory_format::undef, any, x)
        && dst_d.format() == desired_act_fmt;
    if (!args_ok) return status::unimplemented;

    jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
    jcp.dst_dt = cd.dst_desc.data_type;

    jcp.typesize_in = types::data_type_size(src_d.data_type());
    jcp.typesize_out = types::data_type_size(dst_d.data_type());
    jcp.typesize_acc = sizeof(int32_t);
    jcp.typesize_bia = jcp.with_bias
        ? types::data_type_size(bias_pd.data_type())
        : 0;

    const auto &oscales = attr.output_scales_;
    jcp.is_oc_scale = oscales.mask_ == 1 << 1;

    assert(IMPLICATION(!jcp.is_oc_scale, oscales.mask_ == 0));

    jcp.ur_w = isa == avx512_common ? 6 : isa == avx2 ? 4 : 3;

    jcp.nb_ch_blocking = isa == avx512_common ? 4 : isa == avx2 ? 3 : 2;
    if (jcp.nb_ch < jcp.nb_ch_blocking)
        jcp.nb_ch_blocking = jcp.nb_ch;

    return status::success;
}
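
// Explicit instantiations for the supported ISAs.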
template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<avx2>;
template struct jit_uni_x8s8s32x_dw_conv_fwd_kernel<sse42>;

}
}
}