Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / memory_desc_wrapper.hpp
index 91e18cf..7c2f8ef 100644 (file)
@@ -46,12 +46,16 @@ struct memory_desc_wrapper: public c_compatible {
     memory_format_t format() const { return _md->format; }
     bool is_blocking_desc() const {
         return (format() != memory_format::wino_fmt
+                && format() != memory_format::rnn_packed
                 && format() != memory_format::any
                 && format() != memory_format::undef);
     }
     bool is_wino_desc() const {
         return (format() == memory_format::wino_fmt);
     }
+    bool is_rnn_packed_desc() const {
+        return (format() == memory_format::rnn_packed);
+    }
     const blocking_desc_t &blocking_desc() const {
         assert(is_blocking_desc());
         return _md->layout_desc.blocking;
@@ -60,6 +64,10 @@ struct memory_desc_wrapper: public c_compatible {
         assert(is_wino_desc());
         return _md->layout_desc.wino_desc;
     }
+    const rnn_packed_data_t &rnn_packed_desc() const {
+        assert(is_rnn_packed_desc());
+        return _md->layout_desc.rnn_packed_desc;
+    }
 
     /* some useful function */
 
@@ -67,7 +75,7 @@ struct memory_desc_wrapper: public c_compatible {
      * is true, and the number of data elements otherwise */
     size_t nelems(bool with_padding = false) const {
         if (is_zero()) return 0;
-        return (utils::array_product<int, size_t>(with_padding
+        return (utils::array_product<ptrdiff_t, size_t>(with_padding
                 ? blocking_desc().padding_dims : dims(), ndims()));
     }
 
@@ -85,7 +93,11 @@ struct memory_desc_wrapper: public c_compatible {
     size_t additional_buffer_data_size() const {
         using namespace mkldnn::impl::memory_format;
         return (utils::one_of(format(), hwio_s8s8, hwigo_s8s8,
-                    gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8, OhIw8o4i_s8s8, gOhIw8o4i_s8s8))
+                    gOIhw4o4i_s8s8,
+                    gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8,
+                    gOIhw2i8o4i_s8s8,
+                    gOhIw8o4i_s8s8, OhIw8o4i_s8s8,
+                    Goihw16g_s8s8))
             ? sizeof(int32_t) : 0;
     }
 
@@ -93,7 +105,11 @@ struct memory_desc_wrapper: public c_compatible {
     bool is_additional_buffer() const {
         using namespace mkldnn::impl::memory_format;
         return (utils::one_of(format(), hwio_s8s8, hwigo_s8s8,
-                    gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8, OhIw8o4i_s8s8, gOhIw8o4i_s8s8))
+                    gOIhw4o4i_s8s8,
+                    gOIhw4i16o4i_s8s8, OIhw4i16o4i_s8s8,
+                    gOIhw2i8o4i_s8s8,
+                    gOhIw8o4i_s8s8, OhIw8o4i_s8s8,
+                    Goihw16g_s8s8))
             ? true : false;
     }
 
@@ -103,10 +119,13 @@ struct memory_desc_wrapper: public c_compatible {
         const auto &padding_dims = blocking_desc().padding_dims;
         switch(format()) {
             case hwigo_s8s8:
+            case gOIhw4o4i_s8s8:
+            case gOIhw2i8o4i_s8s8:
             case gOIhw4i16o4i_s8s8:
             case gOhIw8o4i_s8s8:
                 return size_t(padding_dims[0]) * size_t(padding_dims[1])
                     * additional_buffer_data_size();
+            case Goihw16g_s8s8:
             case hwio_s8s8:
             case OIhw4i16o4i_s8s8:
             case OhIw8o4i_s8s8:
@@ -126,11 +145,14 @@ struct memory_desc_wrapper: public c_compatible {
         assert((false
                     || types::format_normalize(format()) == blocked
                     || types::is_format_double_blocked(format())
-                    || format() == wino_fmt)
+                    || format() == wino_fmt
+                    || format() == rnn_packed)
                 && "unknown format");
 
         if (format() == wino_fmt) {
             return wino_desc().size;
+        } else if (format() == rnn_packed) {
+            return rnn_packed_desc().size;
         } else {
             if (blocking_desc().offset_padding != 0) return 0;
 
@@ -147,7 +169,8 @@ struct memory_desc_wrapper: public c_compatible {
                     max_size = nstl::max(max_size,
                             size_t(block * strides[1][d]));
             }
-            return max_size * data_type_size() + additional_buffer_size();;
+
+            return max_size * data_type_size() + additional_buffer_size();
         }
     }
 
@@ -231,6 +254,13 @@ struct memory_desc_wrapper: public c_compatible {
             const int ic_4  = pos[with_groups + 1] % 4;
             phys_offset += 4 * oc_16 + ic_4 - (oc_16 + 16 * ic_4);
         }
+        if (utils::one_of(format(), gOIhw2i8o4i,  gOIhw2i8o4i_s8s8)) {
+            // TODO: Fix temporary workaround for formats with double blocking
+            const bool with_groups = true;
+            const int oc_8 = pos[with_groups + 0] % 8;
+            const int ic_4 = pos[with_groups + 1] % 4;
+            phys_offset += 4 * oc_8 + ic_4 - (oc_8 + 8 * ic_4);
+        }
         if (format() == gOIw8i16o2i || format() == OIw8i16o2i) {
             // TODO: Fix temporary workaround for formats with double blocking
             const bool with_groups = format() == gOIw8i16o2i;
@@ -362,13 +392,18 @@ inline bool memory_desc_wrapper::operator==(const memory_desc_wrapper &rhs)
             && utils::array_cmp(dims(), rhs.dims(), ndims())
             && data_type() == rhs.data_type()
             && ((is_blocking_desc() && rhs.is_blocking_desc())
-                       || (is_wino_desc() && rhs.is_wino_desc()))
+                       || (is_wino_desc() && rhs.is_wino_desc())
+                       || (is_rnn_packed_desc() && rhs.is_rnn_packed_desc()))
             && (is_blocking_desc() ? blocking_desc_is_equal(blocking_desc(),
                                              rhs.blocking_desc(), ndims()) :
                                      true)
             && (is_wino_desc() ? wino_desc_is_equal(
                                          wino_desc(), rhs.wino_desc()) :
-                                 true);
+                                 true)
+            && (is_rnn_packed_desc() ?
+                               rnn_packed_desc_is_equal(rnn_packed_desc(),
+                                       rhs.rnn_packed_desc()) :
+                               true);
 }
 
 inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
@@ -377,7 +412,8 @@ inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
     using namespace utils;
     if (utils::one_of(format(), memory_format::undef, memory_format::any))
         return false;
-    if (is_wino_desc() || rhs.is_wino_desc())
+    if (is_wino_desc() || rhs.is_wino_desc() || is_rnn_packed_desc()
+            || rhs.is_rnn_packed_desc())
         return false;
 
     const int ds = dim_start;