#include "utils.hpp"
#include "primitive_iterator.hpp"
-#define DECLARE_DECONVOLUTION_PD_t(impl_name, ...) \
- virtual pd_t *clone() const override { return new pd_t(*this); } \
- virtual status_t create_primitive(primitive_t **primitive, \
- const primitive_at_t *inputs, \
- const primitive_t **outputs) const override { \
- double ms = get_msec(); \
- using namespace prop_kind;\
- primitive_t::input_vector ins(inputs, inputs + this->n_inputs()); \
- primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
- auto ret = safe_ptr_assign<primitive_t>(*primitive, \
- new (__VA_ARGS__)(this, ins, outs)); \
- primitive_t *conv_primitive; \
- if (this->desc()->prop_kind == backward_weights) {\
- primitive_at_t conv_inputs[2];\
- conv_inputs[0] = inputs[1];\
- conv_inputs[1] = inputs[0];\
- conv_pd_->create_primitive((&conv_primitive), conv_inputs, outputs);\
- } \
- else conv_pd_->create_primitive((&conv_primitive), inputs, outputs);\
- ((__VA_ARGS__ *)(*primitive))->conv_p_ = conv_primitive;\
- ms = get_msec() - ms; \
- if (mkldnn_verbose()->level >= 2) { \
- printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
- fflush(0); \
- } \
- return ret; \
- } \
-virtual const char *name() const override { return impl_name; }
-
-#define DECLARE_DECONVOLUTION_PD_T(impl_name, ...) \
- DECLARE_DECONVOLUTION_PD_t(impl_name, __VA_ARGS__)
-
-
namespace mkldnn {
namespace impl {
namespace cpu {
~pd_t() { delete conv_pd_; }
- DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_fwd_t);
+ DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_fwd_t);
status_t init_convolution(){
using namespace memory_format;
convolution_desc_t cd;
status_t status;
- status = conv_descr_create(this->cdesc(), &cd);
+ status = conv_descr_create(this->desc(), &cd);
if (status != status::success) return status;
mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
bool conv_supports_bias_;
};
- ref_deconvolution_fwd_t(const pd_t *pd, const input_vector &inputs,
+ ref_deconvolution_fwd_t(const pd_t *apd, const input_vector &inputs,
const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), conv_p_(nullptr) {}
+ : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
~ref_deconvolution_fwd_t() { delete this->conv_p_; }
- virtual void execute(event_t *e) {
- switch (conf_.desc()->prop_kind) {
+ virtual void execute(event_t *e) const {
+ switch (pd()->desc()->prop_kind) {
case prop_kind::forward_training:
case prop_kind::forward_inference:
(conv_p_)->execute(e);
- if (conf_.with_bias() && !conf_.conv_supports_bias_) {
- switch (conf_.dst_pd()->desc()->format) {
+ if (pd()->with_bias() && !pd()->conv_supports_bias_) {
+ switch (pd()->dst_pd()->desc()->format) {
case memory_format::nchw :
case memory_format::ncdhw :
compute_fwd_bias_ncdhw();
}
private:
- void compute_fwd_bias();
- void compute_fwd_bias_ncdhw();
- template <int blksize> void compute_fwd_bias_nCdhwXc();
- pd_t conf_;
+ void compute_fwd_bias() const;
+ void compute_fwd_bias_ncdhw() const;
+ template <int blksize> void compute_fwd_bias_nCdhwXc() const;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
primitive_t *conv_p_;
};
~pd_t() { delete conv_pd_; }
- DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_data_t);
+ DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_bwd_data_t);
status_t init_convolution(){
using namespace memory_format;
convolution_desc_t cd;
status_t status;
- status = conv_descr_create(this->cdesc(), &cd);
+ status = conv_descr_create(this->desc(), &cd);
if (status != status::success) return status;
mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
}
primitive_desc_t *conv_pd_;
};
- ref_deconvolution_bwd_data_t(const pd_t *pd, const input_vector &inputs,
+ ref_deconvolution_bwd_data_t(const pd_t *apd, const input_vector &inputs,
const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), conv_p_(nullptr) {}
+ : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
~ref_deconvolution_bwd_data_t() { delete this->conv_p_; }
- virtual void execute(event_t *e) {
- switch (conf_.desc()->prop_kind) {
+ virtual void execute(event_t *e) const {
+ switch (pd()->desc()->prop_kind) {
case prop_kind::backward_data:
(conv_p_)->execute(e);
break;
}
private:
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
primitive_t *conv_p_;
};
~pd_t() { delete conv_pd_; }
- DECLARE_DECONVOLUTION_PD_T("ref:any", ref_deconvolution_bwd_weights_t);
+ DECLARE_DECONVOLUTION_PD_T(ref_deconvolution_bwd_weights_t);
status_t init_convolution(){
using namespace memory_format;
convolution_desc_t cd;
status_t status;
- status = conv_descr_create(this->cdesc(), &cd);
+ status = conv_descr_create(this->desc(), &cd);
if (status != status::success) return status;
mkldnn_primitive_desc_iterator it(this->engine_, (op_desc_t *)&cd,
primitive_desc_t *conv_pd_;
};
- ref_deconvolution_bwd_weights_t(const pd_t *pd, const input_vector &inputs,
+ ref_deconvolution_bwd_weights_t(const pd_t *apd, const input_vector &inputs,
const output_vector &outputs)
- : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd), conv_p_(nullptr) {}
+ : cpu_primitive_t(apd, inputs, outputs), conv_p_(nullptr) {}
~ref_deconvolution_bwd_weights_t() { delete this->conv_p_; }
typedef typename prec_traits<data_type::f32>::type data_t;
- virtual void execute(event_t *e) {
- switch (conf_.desc()->prop_kind) {
+ virtual void execute(event_t *e) const {
+ switch (pd()->desc()->prop_kind) {
case prop_kind::backward_weights:
(conv_p_)->execute(e);
- if (conf_.with_bias()) {
- switch (conf_.diff_dst_pd()->desc()->format) {
+ if (pd()->with_bias()) {
+ switch (pd()->diff_dst_pd()->desc()->format) {
case memory_format::nchw :
case memory_format::ncdhw :
compute_bwd_bias_ncdhw();
}
private:
- pd_t conf_;
+ const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
primitive_t *conv_p_;
- void compute_bwd_bias();
- void compute_bwd_bias_ncdhw();
- template <int blksize> void compute_bwd_bias_nCdhwXc();
+ void compute_bwd_bias() const;
+ void compute_bwd_bias_ncdhw() const;
+ template <int blksize> void compute_bwd_bias_nCdhwXc() const;
};
}