namespace impl {
namespace cpu {
-template <cpu_isa_t isa, bool with_relu, impl::data_type_t src_type, impl::data_type_t dst_type>
+template <cpu_isa_t isa, impl::data_type_t src_type, impl::data_type_t dst_type>
struct _jit_uni_x8s8s32x_dw_convolution_fwd_t: public cpu_primitive_t {
- struct pd_t: public _cpu_convolution_fwd_pd_t<with_relu> {
- pd_t(engine_t *engine, const typename pd_t::base_desc_t *adesc,
+ struct pd_t: public cpu_convolution_fwd_pd_t {
+ pd_t(engine_t *engine, const convolution_desc_t *adesc,
const primitive_attr_t *attr,
const typename pd_t::base_class *hint_fwd_pd)
- : _cpu_convolution_fwd_pd_t<with_relu>(engine, adesc, attr,
+ : cpu_convolution_fwd_pd_t(engine, adesc, attr,
hint_fwd_pd)
- , jcp_({}) {}
+ , jcp_() {}
DECLARE_COMMON_PD_T(
JIT_IMPL_NAME_HELPER("jit_dw:", isa, ""),
- _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, with_relu, src_type, dst_type>);
+ _jit_uni_x8s8s32x_dw_convolution_fwd_t<isa, src_type, dst_type>);
virtual status_t init() override {
using namespace prop_kind;
assert(this->engine()->kind() == engine_kind::cpu);
bool ok = true
&& this->set_default_params() == status::success
- && utils::one_of(this->cdesc_().prop_kind, forward_training,
+ && utils::one_of(this->desc()->prop_kind, forward_training,
forward_inference)
- && this->cdesc_().alg_kind == alg_kind::convolution_direct
- && this->cdesc_().dst_desc.data_type == dst_type
+ && this->desc()->alg_kind == alg_kind::convolution_direct
+ && this->desc()->dst_desc.data_type == dst_type
&& IMPLICATION(this->with_bias(), utils::one_of(
- this->cdesc_().bias_desc.data_type, data_type::f32,
+ this->desc()->bias_desc.data_type, data_type::f32,
data_type::s32, data_type::s8, data_type::u8))
- && this->cdesc_().accum_data_type == data_type::s32;
+ && this->desc()->accum_data_type == data_type::s32;
if (!ok) return status::unimplemented;
return jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>::init_conf(jcp_,
- this->cdesc_(),
- this->src_pd_.desc(), *this->weights_pd_.desc(),
+ *this->desc(),
+ *this->src_pd_.desc(), *this->weights_pd_.desc(),
*this->dst_pd_.desc(), *this->bias_pd_.desc(),
- *this->attr(), with_relu, this->negative_slope());
+ *this->attr());
}
jit_conv_conf_t jcp_;
}
};
- _jit_uni_x8s8s32x_dw_convolution_fwd_t(const pd_t *pd, const input_vector &inputs,
- const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd)
- { kernel_ = new jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>(conf_.jcp_, *conf_.attr()); }
+ _jit_uni_x8s8s32x_dw_convolution_fwd_t(const pd_t *apd,
+ const input_vector &inputs, const output_vector &outputs)
+ : cpu_primitive_t(apd, inputs, outputs)
+ {
+ kernel_ = new jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa>(pd()->jcp_, *pd()->attr());
+ }
+
~_jit_uni_x8s8s32x_dw_convolution_fwd_t() { delete kernel_; };
typedef typename prec_traits<data_type::u8>::type src_data_t;
typedef typename prec_traits<data_type::s8>::type wei_data_t;
typedef typename prec_traits<dst_type>::type dst_data_t;
- virtual void execute(event_t *e) {
+ virtual void execute(event_t *e) const {
execute_forward();
e->set_state(event_t::ready);
}
private:
- void execute_forward();
- pd_t conf_;
+ void execute_forward() const ;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
jit_uni_x8s8s32x_dw_conv_fwd_kernel<isa> *kernel_;
};
template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_avx2_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, false, src_type, dst_type>;
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_sse42_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, false, src_type, dst_type>;
-template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_avx2_x8s8s32x_dw_convolution_relu_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, true, src_type, dst_type>;
+using jit_avx2_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<avx2, src_type, dst_type>;
template <impl::data_type_t src_type, impl::data_type_t dst_type>
-using jit_sse42_x8s8s32x_dw_convolution_relu_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, true, src_type, dst_type>;
+using jit_sse42_x8s8s32x_dw_convolution_fwd_t = _jit_uni_x8s8s32x_dw_convolution_fwd_t<sse42, src_type, dst_type>;
}
}