Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / cpu / rnn / cpu_rnn_pd.hpp
 #define CPU_RNN_PD_HPP
 
 #include "c_types_map.hpp"
-#include "cpu_engine.hpp"
-#include "cpu_memory.hpp"
-#include "cpu_primitive.hpp"
+#include "../cpu_engine.hpp"
+#include "../cpu_memory.hpp"
+#include "../cpu_primitive.hpp"
 #include "nstl.hpp"
 #include "rnn_pd.hpp"
 #include "type_helpers.hpp"
 #include "utils.hpp"
+#include "rnn_utils.hpp"
 
 namespace mkldnn {
 namespace impl {
@@ -87,10 +88,6 @@ protected:
         using namespace memory_format;
         if (src_layer_pd_.desc()->format == any)
             CHECK(src_layer_pd_.set_format(tnc));
-        if (weights_layer_pd_.desc()->format == any)
-            CHECK(weights_layer_pd_.set_format(ldigo));
-        if (weights_iter_pd_.desc()->format == any)
-            CHECK(weights_iter_pd_.set_format(ldigo));
         if (dst_layer_pd_.desc()->format == any)
             CHECK(dst_layer_pd_.set_format(tnc));
 
@@ -104,14 +101,51 @@ protected:
 
         return status::success;
     }
+
+    status_t check_layout_consistency() {
+        using namespace memory_format;
+        using namespace utils;
+        using namespace data_type;
+        bool ok = true;
+        ok = ok && src_layer_pd_.desc()->format == tnc
+                && dst_layer_pd_.desc()->format == tnc;
+        ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
+                           src_iter_pd_.desc()->format == ldsnc)
+                && IMPLICATION(!dst_iter_pd_.is_zero(),
+                           dst_iter_pd_.desc()->format == ldsnc);
+
+        ok = ok && one_of(weights_layer_pd_.desc()->format, ldigo, rnn_packed)
+                && one_of(weights_iter_pd_.desc()->format, ldigo, rnn_packed);
+        ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
+                           weights_iter_pd_.desc()
+                                           ->layout_desc.rnn_packed_desc.format
+                                   == mkldnn_ldigo_p);
+        ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
+                           weights_layer_pd_.desc()
+                                           ->layout_desc.rnn_packed_desc.format
+                                   == mkldnn_ldigo_p);
+
+        ok = ok && IMPLICATION(!bias_pd_.is_zero(),
+                           bias_pd_.desc()->format == ldgo);
+
+        /* Int8 is supported only for packed weights */
+        data_type_t weights_iter_dt = weights_iter_pd_.desc()->data_type;
+        data_type_t weights_layer_dt = weights_layer_pd_.desc()->data_type;
+        ok = ok && IMPLICATION(weights_iter_dt == s8,
+                           weights_iter_pd_.desc()->format == rnn_packed);
+        ok = ok && IMPLICATION(weights_layer_dt == s8,
+                           weights_layer_pd_.desc()->format == rnn_packed);
+
+        return ok ? status::success : status::unimplemented;
+    }
 };
 
 struct cpu_rnn_bwd_pd_t : public rnn_bwd_pd_t {
     using cpu_memory_pd_t = cpu_memory_t::pd_t;
 
     cpu_rnn_bwd_pd_t(engine_t *engine, const rnn_desc_t *adesc,
-            const primitive_attr_t *attr, const rnn_bwd_pd_t *hint_bwd_pd)
-        : rnn_bwd_pd_t(engine, adesc, attr, hint_bwd_pd)
+            const primitive_attr_t *attr, const rnn_fwd_pd_t *hint_fwd_pd)
+        : rnn_bwd_pd_t(engine, adesc, attr, hint_fwd_pd)
         , src_layer_pd_(engine, &desc_.src_layer_desc)
         , src_iter_pd_(engine, &desc_.src_iter_desc)
         , weights_layer_pd_(engine, &desc_.weights_layer_desc)
@@ -203,14 +237,22 @@ protected:
             CHECK(src_layer_pd_.set_format(tnc));
         if (diff_src_layer_pd_.desc()->format == any)
             CHECK(diff_src_layer_pd_.set_format(tnc));
-        if (weights_layer_pd_.desc()->format == any)
-            CHECK(weights_layer_pd_.set_format(ldgoi));
-        if (diff_weights_layer_pd_.desc()->format == any)
-            CHECK(diff_weights_layer_pd_.set_format(ldigo));
-        if (weights_iter_pd_.desc()->format == any)
-            CHECK(weights_iter_pd_.set_format(ldgoi));
-        if (diff_weights_iter_pd_.desc()->format == any)
-            CHECK(diff_weights_iter_pd_.set_format(ldigo));
+        if (diff_weights_layer_pd_.desc()->format == any) {
+            memory_desc_t md = *(diff_weights_layer_pd_.desc());
+            md.format = ldigo;
+            CHECK(memory_desc_wrapper::compute_blocking(md));
+            CHECK(rnn_utils::set_good_strides(md));
+            cpu_memory_t::pd_t new_pd(engine_, &md);
+            diff_weights_layer_pd_ = new_pd;
+        }
+        if (diff_weights_iter_pd_.desc()->format == any) {
+            memory_desc_t md = *(diff_weights_iter_pd_.desc());
+            md.format = ldigo;
+            CHECK(memory_desc_wrapper::compute_blocking(md));
+            CHECK(rnn_utils::set_good_strides(md));
+            cpu_memory_t::pd_t new_pd(engine_, &md);
+            diff_weights_iter_pd_ = new_pd;
+        }
         if (dst_layer_pd_.desc()->format == any)
             CHECK(dst_layer_pd_.set_format(tnc));
         if (diff_dst_layer_pd_.desc()->format == any)
@@ -234,6 +276,45 @@ protected:
 
         return status::success;
     }
+
+    status_t check_layout_consistency() {
+        using namespace memory_format;
+        using namespace utils;
+        bool ok = true;
+        ok = ok && src_layer_pd_.desc()->format == tnc
+                && dst_layer_pd_.desc()->format == tnc;
+        ok = ok && IMPLICATION(!src_iter_pd_.is_zero(),
+                           src_iter_pd_.desc()->format == ldsnc)
+                && IMPLICATION(!dst_iter_pd_.is_zero(),
+                           dst_iter_pd_.desc()->format == ldsnc);
+
+        ok = ok && one_of(weights_layer_pd_.desc()->format, ldgoi, rnn_packed)
+                && one_of(weights_iter_pd_.desc()->format, ldgoi, rnn_packed);
+        ok = ok && IMPLICATION(weights_iter_pd_.desc()->format == rnn_packed,
+                           weights_iter_pd_.desc()
+                                           ->layout_desc.rnn_packed_desc.format
+                                   == mkldnn_ldgoi_p);
+        ok = ok && IMPLICATION(weights_layer_pd_.desc()->format == rnn_packed,
+                           weights_layer_pd_.desc()
+                                           ->layout_desc.rnn_packed_desc.format
+                                   == mkldnn_ldgoi_p);
+
+        ok = ok && IMPLICATION(!bias_pd_.is_zero(),
+                           bias_pd_.desc()->format == ldgo);
+
+        ok = ok && diff_src_layer_pd_.desc()->format == tnc
+                && diff_dst_layer_pd_.desc()->format == tnc;
+        ok = ok && IMPLICATION(!diff_states_pd_.is_zero(),
+                           diff_states_pd_.desc()->format == ldsnc)
+                && IMPLICATION(!diff_dst_iter_pd_.is_zero(),
+                           diff_dst_iter_pd_.desc()->format == ldsnc);
+        ok = ok && diff_weights_layer_pd_.desc()->format == ldigo
+                && diff_weights_iter_pd_.desc()->format == ldigo;
+        ok = ok && IMPLICATION(!diff_bias_pd_.is_zero(),
+                           diff_bias_pd_.desc()->format == ldgo);
+
+        return ok ? status::success : status::unimplemented;
+    }
 };
 }
 }