Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / cpu_deconvolution_pd.hpp
index cd9cdfe..d236c23 100644 (file)
 #include "type_helpers.hpp"
 #include "utils.hpp"
 
+#define DECLARE_DECONVOLUTION_PD_t(...)                                        \
+    virtual pd_t *clone() const override { return new pd_t(*this); }           \
+    virtual status_t create_primitive(primitive_t **primitive,                 \
+            const primitive_at_t *inputs, const primitive_t **outputs)         \
+            const override {                                                   \
+        double ms = get_msec();                                                \
+        using namespace prop_kind;                                             \
+        primitive_t::input_vector ins(inputs, inputs + this->n_inputs());      \
+        primitive_t::output_vector outs(outputs, outputs + this->n_outputs()); \
+        auto ret = safe_ptr_assign<primitive_t>(                               \
+                *primitive, new (__VA_ARGS__)(this, ins, outs));               \
+        primitive_t *conv_primitive;                                           \
+        if (this->desc()->prop_kind == backward_weights) {                     \
+            primitive_at_t conv_inputs[2];                                     \
+            conv_inputs[0] = inputs[1];                                        \
+            conv_inputs[1] = inputs[0];                                        \
+            conv_pd_->create_primitive(                                        \
+                    (&conv_primitive), conv_inputs, outputs);                  \
+        } else                                                                 \
+            conv_pd_->create_primitive((&conv_primitive), inputs, outputs);    \
+        ((__VA_ARGS__ *)(*primitive))->conv_p_ = conv_primitive;               \
+        ms = get_msec() - ms;                                                  \
+        if (mkldnn_verbose()->level >= 2) {                                    \
+            printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms);         \
+            fflush(0);                                                         \
+        }                                                                      \
+        return ret;                                                            \
+    }                                                                          \
+    virtual const char *name() const override { return conv_pd_->name(); }
+
+#define DECLARE_DECONVOLUTION_PD_T(...) DECLARE_DECONVOLUTION_PD_t(__VA_ARGS__)
+
 namespace mkldnn {
 namespace impl {
 namespace cpu {