Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / primitive_attr.hpp
index 3f56d99..949449f 100644 (file)
 namespace mkldnn {
 namespace impl {
 
+struct rnn_data_qparams_t : public c_compatible {
+    rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
+    bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
+
+    status_t set(float scale, float shift) {
+        scale_ = scale;
+        shift_ = shift;
+        return status::success;
+    }
+
+    float scale_;
+    float shift_;
+};
+
 struct scales_t: public c_compatible {
     scales_t(): count_(1), mask_(0), scales_(scales_buf_)
     { set(1.); }
@@ -54,7 +68,6 @@ struct scales_t: public c_compatible {
 
     status_t set(int count, int mask, const float *scales);
     status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
-    status_t scale(float factor);
 
     int count_;
     int mask_;
@@ -79,13 +92,15 @@ private:
 
 struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
     struct entry_t {
+        struct eltwise_t {
+            mkldnn::impl::alg_kind_t alg;
+            float scale, alpha, beta;
+        };
+
         mkldnn::impl::primitive_kind_t kind;
         union {
             struct { float scale; } sum;
-            struct {
-                mkldnn::impl::alg_kind_t alg;
-                float scale, alpha, beta;
-            } eltwise;
+            eltwise_t eltwise;
             struct {
                 mkldnn::impl::alg_kind_t alg;
                 const float* weights_data;
@@ -101,34 +116,45 @@ struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
                 const float* weights_data;
                 const float* biases_data;
             } dw_conv;
+            struct {
+                mkldnn::impl::alg_kind_t alg;
+                const float* weights_data;
+            } binarization;
         };
 
+        bool is_eltwise(bool require_scale_one = true) const {
+            using namespace mkldnn::impl;
+            return kind == primitive_kind::eltwise
+                && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
+        }
+
         bool is_relu(bool require_scale_one = true,
                 bool require_nslope_zero = true) const {
             using namespace mkldnn::impl;
-            return kind == primitive_kind::eltwise
-                && IMPLICATION(require_scale_one, eltwise.scale == 1.f)
+            return is_eltwise(require_scale_one)
                 && eltwise.alg == alg_kind::eltwise_relu
                 && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
         }
+
         bool is_sum(bool require_scale_one = true) const {
             using namespace mkldnn::impl;
             return kind == primitive_kind::sum
                 && IMPLICATION(require_scale_one, sum.scale == 1.f);
         }
-        bool is_eltwise(bool require_scale_one = true) const {
-            using namespace mkldnn::impl;
-            return kind == primitive_kind::eltwise
-                   && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
-        }
+
         bool is_depthwise() const {
             using namespace mkldnn::impl;
             return kind == primitive_kind::depthwise;
         }
+
         bool is_dw_conv() const {
             using namespace mkldnn::impl;
             return kind == primitive_kind::convolution;
         }
+        bool is_binarization() const {
+            using namespace mkldnn::impl;
+            return kind == primitive_kind::binarization;
+        }
     };
 
     mkldnn_post_ops(): len_(0) {}
@@ -141,6 +167,7 @@ struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
     mkldnn::impl::status_t append_dw_conv(int in_h, int in_w, int ker_h, int ker_w, int str_h, int str_w,
                                           const float* weights_data,
                                           const float* biases_data);
+    mkldnn::impl::status_t append_binarization(mkldnn::impl::alg_kind_t alg, const float* weights_data);
 
     int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
             int stop = -1) const {
@@ -173,7 +200,9 @@ struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
        return true
             && round_mode_ == mkldnn::impl::round_mode::nearest
             && output_scales_.has_default_values()
-            && post_ops_.has_default_values() ;
+            && post_ops_.has_default_values()
+            && rnn_data_qparams_.has_default_values()
+            && rnn_weights_qparams_.has_default_values();
     }
 
     mkldnn::impl::status_t set_round_mode(
@@ -184,6 +213,8 @@ struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
     mkldnn::impl::round_mode_t round_mode_;
     mkldnn::impl::scales_t output_scales_;
     mkldnn::impl::post_ops_t post_ops_;
+    mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+    mkldnn::impl::scales_t rnn_weights_qparams_;
 };
 
 #endif