Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / ref_eltwise.hpp
index bd90dc1..718844b 100644 (file)
@@ -31,13 +31,16 @@ namespace cpu {
 
 struct ref_eltwise_scalar_fwd_t {
 public:
-    ref_eltwise_scalar_fwd_t(const alg_kind_t alg, float alpha, float beta);
+    ref_eltwise_scalar_fwd_t(alg_kind_t alg, float alpha, float beta);
+
+    // note that eltwise.scale is ignored
+    ref_eltwise_scalar_fwd_t(const post_ops_t::entry_t::eltwise_t &eltwise);
+
     float compute_scalar(float s);
 
-private:
-    alg_kind_t alg;
-    float alpha;
-    float beta;
+    const alg_kind_t alg_;
+    const float alpha_;
+    const float beta_;
 };
 
 template <impl::data_type_t data_type>
@@ -87,15 +90,15 @@ struct ref_eltwise_fwd_t: public cpu_primitive_t {
         bool use_dense_, use_nCspBc_padded_;
     };
 
-    ref_eltwise_fwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_eltwise_fwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
+        : cpu_primitive_t(apd, inputs, outputs) {}
     typedef typename prec_traits<data_type>::type data_t;
 
-    virtual void execute(event_t *e) {
-        if (conf_.use_dense_)
+    virtual void execute(event_t *e) const {
+        if (pd()->use_dense_)
             execute_forward_dense();
-        else if (conf_.use_nCspBc_padded_)
+        else if (pd()->use_nCspBc_padded_)
             execute_forward_nCspBc_padded();
         else
             execute_forward_generic();
@@ -103,10 +106,10 @@ struct ref_eltwise_fwd_t: public cpu_primitive_t {
     }
 
 private:
-    void execute_forward_nCspBc_padded();
-    void execute_forward_dense();
-    void execute_forward_generic();
-    pd_t conf_;
+    void execute_forward_nCspBc_padded() const;
+    void execute_forward_dense() const;
+    void execute_forward_generic() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 template <impl::data_type_t data_type>
@@ -142,27 +145,30 @@ struct ref_eltwise_bwd_t: public cpu_primitive_t {
             if (use_generic && !one_of(diff_dst_d.ndims(), 4, 5))
                 return status::unimplemented;
 
+            if (desc()->alg_kind == mkldnn_eltwise_not)
+                return status::unimplemented;
+
             return status::success;
         }
 
         bool use_dense_;
     };
 
-    ref_eltwise_bwd_t(const pd_t *pd, const input_vector &inputs,
+    ref_eltwise_bwd_t(const pd_t *apd, const input_vector &inputs,
             const output_vector &outputs)
-        : cpu_primitive_t(&conf_, inputs, outputs), conf_(*pd) {}
+        : cpu_primitive_t(apd, inputs, outputs) {}
     typedef typename prec_traits<data_type>::type data_t;
 
-    virtual void execute(event_t *e) {
-        if (conf_.use_dense_) execute_backward_dense();
+    virtual void execute(event_t *e) const {
+        if (pd()->use_dense_) execute_backward_dense();
         else execute_backward_generic();
         e->set_state(event_t::ready);
     }
 
 private:
-    void execute_backward_dense();
-    void execute_backward_generic();
-    pd_t conf_;
+    void execute_backward_dense() const;
+    void execute_backward_generic() const;
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd(); }
 };
 
 }