using namespace mkldnn::impl::memory_format;
using namespace mkldnn::impl::utils;
-template <cpu_isa_t isa, bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>::execute_forward() {
+template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
+void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
auto bias = reinterpret_cast<const char*>(this->input_memory(2));
auto dst = reinterpret_cast<dst_data_t*>(this->memory());
- const memory_desc_wrapper src_d(conf_.src_pd());
- const memory_desc_wrapper dst_d(conf_.dst_pd());
- const memory_desc_wrapper weights_d(conf_.weights_pd(0));
- const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+ const memory_desc_wrapper src_d(pd()->src_pd());
+ const memory_desc_wrapper dst_d(pd()->dst_pd());
+ const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+ const memory_desc_wrapper bias_d(pd()->weights_pd(1));
const auto &jcp = kernel_->jcp;
int str_h = jcp.stride_h;
int str_w = jcp.stride_w;
- const size_t bia_dt_size = conf_.with_bias()
- ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
+ const size_t bia_dt_size = pd()->with_bias()
+ ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
- const auto &oscales = conf_.attr()->output_scales_;
+ const auto &oscales = pd()->attr()->output_scales_;
int MB = jcp.mb;
int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
int kh_padding, int ch, int ch_num, int n) {
- jit_conv_call_s par_conv = {};
+ auto par_conv = jit_conv_call_s();
const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
par_conv.ch_work = nstl::min((ch + ch_num) * jcp.ch_block, jcp.oc) - ch*jcp.ch_block;
par_conv.scales = &oscales.scales_[jcp.is_oc_scale * ch * jcp.ch_block];
+ par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
return par_conv;
};
parallel(0, ker);
}
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::f32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::f32>::execute_forward();
-
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::f32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::f32>::execute_forward();
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
+
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
+
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
+
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;
}
}