Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / type_helpers.hpp
index a7cf1a1..06a0e2f 100644 (file)
@@ -64,6 +64,7 @@ inline size_t data_type_size(data_type_t data_type) {
     case s16: return sizeof(prec_traits<s16>::type);
     case s8: return sizeof(prec_traits<s8>::type);
     case u8: return sizeof(prec_traits<u8>::type);
+    case bin: return sizeof(prec_traits<u8>::type);
     case data_type::undef:
     default: assert(!"unknown data_type");
     }
@@ -94,26 +95,32 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
             nc,
             ncw,
             nwc,
+            nCw4c,
             nCw8c,
             nCw16c,
             nchw,
             nhwc,
             chwn,
+            nChw4c,
             nChw8c,
             nChw16c,
             ncdhw,
             ndhwc,
+            nCdhw4c,
             nCdhw8c,
             nCdhw16c,
             oi,
             io,
             oiw,
             wio,
+            Owi4o,
+            OIw4i4o,
             Owi8o,
             OIw8i8o,
             OIw8o8i,
             OIw16i16o,
             OIw16o16i,
+            Oiw4o,
             Oiw16o,
             Owi16o,
             OIw8i16o2i,
@@ -122,20 +129,25 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
             oihw,
             ihwo,
             hwio,
+            iohw,
             hwio_s8s8,
             dhwio,
             oidhw,
+            OIdhw4i4o,
+            Odhwi4o,
             OIdhw8i8o,
             OIdhw8o8i,
             Odhwi8o,
             OIdhw16i16o,
             OIdhw16o16i,
+            Oidhw4o,
             Oidhw16o,
             Odhwi16o,
             oIhw8i,
             oIhw16i,
             oIdhw8i,
             oIdhw16i,
+            OIhw4i4o,
             OIhw8i8o,
             OIhw16i16o,
             OIhw4i16o4i,
@@ -145,18 +157,25 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
             OIhw8o16i2o,
             OIhw8o8i,
             OhIw8o4i,
+            OhIw8o32i,
+            OhIw16o32i,
             OhIw8o4i_s8s8,
             OIhw16o16i,
             IOhw16o16i,
+            Oihw4o,
             Oihw16o,
             Ohwi8o,
+            Ohwi4o,
             Ohwi16o,
             goiw,
+            gOwi4o,
+            gOIw4i4o,
             gOwi8o,
             gOIw8i8o,
             gOIw8o8i,
             gOIw16i16o,
             gOIw16o16i,
+            gOiw4o,
             gOiw16o,
             gOwi16o,
             gOIw8i16o2i,
@@ -164,31 +183,43 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
             gIOw16o16i,
             goihw,
             hwigo,
+            giohw,
             hwigo_s8s8,
+            gOIhw4i4o,
             gOIhw8i8o,
             gOIhw16i16o,
             gOIhw4i16o4i,
             gOIhw4i16o4i_s8s8,
+            gOIhw2i8o4i,
+            gOIhw2i8o4i_s8s8,
             gOIhw8i16o2i,
             gOIdhw8i16o2i,
             gOIhw8o16i2o,
+            gOIhw4o4i,
+            gOIhw4o4i_s8s8,
             gOIhw8o8i,
             gOhIw8o4i,
             gOhIw8o4i_s8s8,
             gOIhw16o16i,
             gIOhw16o16i,
+            gOihw4o,
             gOihw16o,
             gOhwi8o,
+            gOhwi4o,
             gOhwi16o,
             Goihw8g,
             Goihw16g,
+            Goihw16g_s8s8,
             goidhw,
+            gOIdhw4i4o,
+            gOdhwi4o,
             gOIdhw8i8o,
             gOIdhw8o8i,
             gOdhwi8o,
             gOIdhw16i16o,
             gOIdhw16o16i,
             gOidhw16o,
+            gOidhw4o,
             gOdhwi16o,
             ntc,
             tnc,
@@ -202,9 +233,9 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
 inline bool is_format_double_blocked(memory_format_t fmt) {
     using namespace memory_format;
     return utils::one_of(OIw8o16i2o, OIw8i16o2i, OIhw8i16o2i, OIdhw8i16o2i,
-            OIhw8o16i2o, OIhw4i16o4i, OIhw4i16o4i_s8s8, gOIw8o16i2o, gOIw8i16o2i,
-            gOIhw8i16o2i, gOIdhw8i16o2i, gOIhw8o16i2o, gOIhw4i16o4i,
-            gOIhw4i16o4i_s8s8);
+            OIhw8o16i2o, OIhw4i16o4i, OIhw4i16o4i_s8s8,
+            gOIw8o16i2o, gOIw8i16o2i, gOIhw8i16o2i, gOIdhw8i16o2i, gOIhw8o16i2o,
+            gOIhw4i16o4i, gOIhw4i16o4i_s8s8, gOIhw2i8o4i, gOIhw2i8o4i_s8s8);
 }
 
 inline bool blocking_desc_is_equal(const blocking_desc_t &lhs,
@@ -232,6 +263,22 @@ inline bool wino_desc_is_equal(const wino_data_t &lhs,
         && lhs.r == rhs.r;
 }
 
+inline bool rnn_packed_desc_is_equal(
+        const rnn_packed_data_t &lhs, const rnn_packed_data_t &rhs) {
+    bool ok = lhs.format == rhs.format && lhs.n_parts == rhs.n_parts
+            && lhs.offset_compensation == rhs.offset_compensation
+            && lhs.size == rhs.size
+            && lhs.n == rhs.n;
+    if (!ok)
+        return false;
+
+    for (int i = 0; i < rhs.n_parts; i++)
+        ok = ok && lhs.parts[i] == rhs.parts[i];
+    for (int i = 0; i < rhs.n_parts; i++)
+        ok = ok && lhs.part_pack_size[i] == rhs.part_pack_size[i];
+    return ok;
+}
+
 inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
     assert(lhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
     assert(rhs.primitive_kind == mkldnn::impl::primitive_kind::memory);
@@ -247,6 +294,9 @@ inline bool operator==(const memory_desc_t &lhs, const memory_desc_t &rhs) {
     else if (lhs.format == memory_format::wino_fmt)
         return wino_desc_is_equal(lhs.layout_desc.wino_desc,
             rhs.layout_desc.wino_desc);
+    else if (lhs.format == memory_format::rnn_packed)
+        return rnn_packed_desc_is_equal(lhs.layout_desc.rnn_packed_desc,
+                rhs.layout_desc.rnn_packed_desc);
     return true;
 }
 
@@ -276,6 +326,7 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
     if (one_of(f32, src_dt, dst_dt)) return f32;
     if (one_of(s32, src_dt, dst_dt)) return s32;
     if (one_of(s16, src_dt, dst_dt)) return s32;
+    if (one_of(bin, src_dt, dst_dt)) return s32;
 
     if (one_of(s8, src_dt, dst_dt) || one_of(u8, src_dt, dst_dt)) return s32;
 
@@ -298,10 +349,13 @@ inline data_type_t default_accum_data_type(data_type_t src_dt,
         if ((src_dt == u8 || src_dt == s8)
             && wei_dt == s8 && one_of(dst_dt, f32, s32, s8, u8))
             return s32;
+        if (src_dt == bin && wei_dt == bin && (dst_dt == f32 || dst_dt == bin))
+            return s32;
     } else if (prop_kind == backward_data) {
         if (src_dt == s32 && wei_dt == s16 && dst_dt == s16)
             return s32;
-        if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 && dst_dt == u8)
+        if (one_of(src_dt, f32, s32, s8, u8) && wei_dt == s8 &&
+                one_of(dst_dt, s8, u8))
             return s32;
     } else if (prop_kind == backward_weights) {
         if (src_dt == s16 && wei_dt == s32 && dst_dt == s16)