Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / jit_uni_x8s8s32x_dw_convolution.hpp
index 17d70c1..a6c3cf6 100644 (file)
@@ -28,40 +28,40 @@ namespace mkldnn {
 namespace impl {
 namespace cpu {
 
-template <cpu_isa_t isa, bool with_relu, impl::data_type_t src_type, impl::data_type_t dst_type>
+template <cpu_isa_t isa, impl::data_type_t src_type, impl::data_type_t dst_type>
 struct _jit_uni_x8s8s32x_dw_convolution_fwd_t: public cpu_primitive_t {
-    struct pd_t: public _cpu_convolution_fwd_pd_t<with_relu> {
-        pd_t(engine_t *engine, const typename pd_t::base_desc_t *adesc,
+    struct pd_t: public cpu_convolution_fwd_pd_t {
+        pd_t(engine_t *engine, const convolution_desc_t *adesc,
                 const primitive_attr_t *attr,
                 const typename pd_t::base_class *hint_fwd_pd)
-            : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
+            : cpu_convolution_fwd_pd_t(engine, adesc, attr,
                 hint_fwd_pd)
-            , jcp_({}) {}
+            , jcp_() {}
 
         DECLARE_COMMON_PD_T(
                 JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
-                _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>);
+                _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>);
 
         virtual status_t init() override {
             using namespace prop_kind;
             assert(this->engine()->kind() == engine_kind::cpu);
             bool ok = true
                 && this->set_default_params() == status::success
-                && utils::one_of(this->cdesc_().prop_kind, forward_training,
+                && utils::one_of(this->desc()->prop_kind, forward_training,
                         forward_inference)
-                && this->cdesc_().alg_kind == alg_kind::convolution_direct
-                && this->cdesc_().dst_desc.data_type == dst_type
+                && this->desc()->alg_kind == alg_kind::convolution_direct
+                && this->desc()->dst_desc.data_type == dst_type
                 && IMPLICATION(this->with_bias(), utils::one_of(
-                    this->cdesc_().bias_desc.data_type, data_type::f32,
+                    this->desc()->bias_desc.data_type, data_type::f32,
                     data_type::s32, data_type::s8, data_type::u8))
-                && this->cdesc_().accum_data_type == data_type::s32;
+                && this->desc()->accum_data_type == data_type::s32;
             if (!ok) return status::unimplemented;
 
             return jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jcp_,
-                        this->cdesc_(),
-                        this->src_pd_.desc(), *this->weights_pd_.desc(),
+                        *this->desc(),
+                        *this->src_pd_.desc(), *this->weights_pd_.desc(),
                         *this->dst_pd_.desc(), *this->bias_pd_.desc(),
-                        *this->attr(), with_relu, this->negative_slope());
+                        *this->attr());
         }
 
         jit_conv_conf_t jcp_;
@@ -84,35 +84,34 @@ struct _jit_uni_x8s8s32x_dw_convolution_fwd_t: public cpu_primitive_t {
         }
     };
 
-    _jit_uni_x8s8s32x_dw_convolution_fwd_t(const pd_t *pd, const input_vector &inputs,
-                                    const output_vector &outputs)
-            : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
-    { kernel_ = new jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>(conf_.jcp_, *conf_.attr()); }
+    _jit_uni_x8s8s32x_dw_convolution_fwd_t(const pd_t *apd,
+            const input_vector &inputs, const output_vector &outputs)
+        : cpu_primitive_t(apd, inputs, outputs)
+    {
+        kernel_ = new jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>(pd()->jcp_, *pd()->attr());
+    }
+
     ~_jit_uni_x8s8s32x_dw_convolution_fwd_t() { delete kernel_; };
 
     typedef typename prec_traits<data_type::u8>::type src_data_t;
     typedef typename prec_traits<data_type::s8>::type wei_data_t;
     typedef typename prec_traits<dst_type>::type dst_data_t;
 
-    virtual void execute(event_t *e) {
+    virtual void execute(event_t *e) const {
         execute_forward();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_forward();
-    pd_t conf_;
+    void execute_forward() const ;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
     jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa> *kernel_;
 };
 
 template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_avx2_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, src_type, dst_type>;
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_sse42_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, src_type, dst_type>;
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_avx2_x8s8s32x_dw_convolution_relu_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, src_type, dst_type>;
+using jit_avx2_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, src_type, dst_type>;
 template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_sse42_x8s8s32x_dw_convolution_relu_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, src_type, dst_type>;
+using jit_sse42_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, src_type, dst_type>;
 
 }
 }