Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_convolution.cpp
index bc31a38..b102c53 100644 (file)
@@ -26,17 +26,17 @@ using namespace mkldnn::impl::status;
 using namespace mkldnn::impl::memory_format;
 using namespace mkldnn::impl::utils;
 
-template <cpu_isa_t isa, bool with_relu, data_type_t src_type, data_type_t dst_type>
-void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>::execute_forward() {
+template <cpu_isa_t isa, data_type_t src_type, data_type_t dst_type>
+void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward() const {
     auto src = reinterpret_cast<const src_data_t*>(this->input_memory(0));
     auto weights = reinterpret_cast<const wei_data_t*>(this->input_memory(1));
     auto bias = reinterpret_cast<const char*>(this->input_memory(2));
     auto dst = reinterpret_cast<dst_data_t*>(this->memory());
 
-    const memory_desc_wrapper src_d(conf_.src_pd());
-    const memory_desc_wrapper dst_d(conf_.dst_pd());
-    const memory_desc_wrapper weights_d(conf_.weights_pd(0));
-    const memory_desc_wrapper bias_d(conf_.weights_pd(1));
+    const memory_desc_wrapper src_d(pd()->src_pd());
+    const memory_desc_wrapper dst_d(pd()->dst_pd());
+    const memory_desc_wrapper weights_d(pd()->weights_pd(0));
+    const memory_desc_wrapper bias_d(pd()->weights_pd(1));
 
     const auto &jcp = kernel_->jcp;
 
@@ -45,10 +45,10 @@ void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>:
     int str_h = jcp.stride_h;
     int str_w = jcp.stride_w;
 
-    const size_t bia_dt_size = conf_.with_bias()
-        ? types::data_type_size(conf_.cdesc()->bias_desc.data_type) : 0;
+    const size_t bia_dt_size = pd()->with_bias()
+        ? types::data_type_size(pd()->desc()->bias_desc.data_type) : 0;
 
-    const auto &oscales = conf_.attr()->output_scales_;
+    const auto &oscales = pd()->attr()->output_scales_;
 
     int MB = jcp.mb;
     int chb_work = utils::div_up(jcp.nb_ch, jcp.nb_ch_blocking);
@@ -56,7 +56,7 @@ void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>:
 
     auto kernel_params = [&](int ur_w_step, int ow, int oh, int ih, int kh,
             int kh_padding, int ch, int ch_num, int n) {
-        jit_conv_call_s par_conv = {};
+        auto par_conv = jit_conv_call_s();
 
         const int i_l_overflow = nstl::max(0, (jcp.l_pad - ow * str_w));
         const int i_r_overflow = nstl::max(jcp.iw, (ow * str_w
@@ -86,6 +86,7 @@ void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>:
         par_conv.ch_work = nstl::min((ch + ch_num) * jcp.ch_block, jcp.oc) - ch*jcp.ch_block;
 
         par_conv.scales = &oscales.scales_[jcp.is_oc_scale * ch * jcp.ch_block];
+        par_conv.oc_off = ch * jcp.ch_block * sizeof(float);
 
         return par_conv;
     };
@@ -149,23 +150,25 @@ void _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>:
     parallel(0, ker);
 }
 
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, data_type::u8, data_type::f32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, data_type::u8, data_type::f32>::execute_forward();
-
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, data_type::u8, data_type::f32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::u8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::s8>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::s32>::execute_forward();
-template void _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, data_type::u8, data_type::f32>::execute_forward();
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::u8, data_type::f32>;
+
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, data_type::s8, data_type::f32>;
+
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::u8, data_type::f32>;
+
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::u8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s8>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::s32>;
+template struct _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, data_type::s8, data_type::f32>;
 
 }
 }