Publishing 2019 R1 content
[platform/upstream/dldt.git] / inference-engine / thirdparty / mkl-dnn / src / common / format_traits.hpp
index 0a13a33..7afe129 100644 (file)
@@ -35,12 +35,17 @@ enum class data_kind_t {
 
 enum class block_format_t {
     _,
+    _4c, _4i, _4o,
     _8c, _8g, _8i, _8o,
-    _8i8o, _8o8i, _8o4i, _8o4i_s8s8,
-    _16c, _16g, _16i, _16o,
+    _4i4o, _4o4i, _4o4i_s8s8,
+    _8i8o, _8o8i,
+    _8o4i, _8o4i_s8s8,
+    _8o32i, _16o32i,
+    _16c, _16g, _16g_s8s8, _16i, _16o,
     _16i16o, _16o16i,
     _8i16o2i, _8o16i2o,
     _4i16o4i, _4i16o4i_s8s8,
+    _2i8o4i, _2i8o4i_s8s8
 };
 
 template <block_format_t f> struct block_format_traits {
@@ -48,15 +53,20 @@ template <block_format_t f> struct block_format_traits {
     static constexpr int levels = f == bf::_
         ? 0
         : utils::one_of(f, bf::_8i16o2i, bf::_8o16i2o,
-                           bf::_4i16o4i, bf::_4i16o4i_s8s8) ? 2 : 1;
+                           bf::_4i16o4i, bf::_4i16o4i_s8s8,
+                           bf::_2i8o4i, bf::_2i8o4i_s8s8) ? 2 : 1;
     static constexpr int blk_ndims = f == bf::_
         ? 0
-        : utils::one_of(f, bf::_8c, bf::_8g, bf::_8i, bf::_8o, bf::_16c,
-                bf::_16g, bf::_16i, bf::_16o) ? 1 : 2;
+        : utils::one_of(f, bf::_4c, bf::_4i, bf::_4o, bf::_8c, bf::_8g, bf::_8i, bf::_8o, bf::_16c,
+                bf::_16g, bf::_16g_s8s8, bf::_16i, bf::_16o) ? 1 : 2;
     static constexpr int blk_size = f == bf::_
         ? 1
-        : utils::one_of(f, bf::_8c, bf::_8g, bf::_8i, bf::_8o, bf::_8i8o,
-                bf::_8o8i, bf::_8o4i, bf::_8o4i_s8s8) ? 8 : 16;
+        : (utils::one_of(f, bf::_4c, bf::_4i, bf::_4o, bf::_4i4o, bf::_4o4i, bf::_4o4i_s8s8) ? 4
+                : (utils::one_of(f, bf::_8c, bf::_8g, bf::_8i, bf::_8o,
+                        bf::_8i8o, bf::_8o8i,
+                        bf::_8o4i, bf::_8o4i_s8s8,
+                        bf::_2i8o4i, bf::_2i8o4i_s8s8,
+                        bf::_8o32i) ? 8 : 16));
 };
 
 template <memory_format_t> struct format_traits {
@@ -64,7 +74,7 @@ template <memory_format_t> struct format_traits {
     // block_format_t blk_fmt;  -- the format of blocks (e.g. 8c or 4i16o4i)
     // int ndims;               -- # of dimensions
     // int ndims_sp;            -- # of spatial dimensions
-    // int blk_size;            -- block size (1, 8, or 16)
+    // int blk_size;            -- block size (1, 4, 8, or 16)
 };
 
 #define DECL_TRAITS(_fmt, _data_kind, _blk_fmt, _ndims, _ndims_sp) \
@@ -87,6 +97,7 @@ DECL_TRAITS(nc, data, _, 2, 0);
 /* data: 3D */
 DECL_TRAITS(ncw, data, _, 3, 1);
 DECL_TRAITS(nwc, data, _, 3, 1);
+DECL_TRAITS(nCw4c, data, _4c, 3, 1);
 DECL_TRAITS(nCw8c, data, _8c, 3, 1);
 DECL_TRAITS(nCw16c, data, _16c, 3, 1);
 
@@ -94,12 +105,14 @@ DECL_TRAITS(nCw16c, data, _16c, 3, 1);
 DECL_TRAITS(nchw, data, _, 4, 2);
 DECL_TRAITS(nhwc, data, _, 4, 2);
 DECL_TRAITS(chwn, data, _, 4, 2);
+DECL_TRAITS(nChw4c, data, _4c, 4, 2);
 DECL_TRAITS(nChw8c, data, _8c, 4, 2);
 DECL_TRAITS(nChw16c, data, _16c, 4, 2);
 
 /* data: 5D */
 DECL_TRAITS(ncdhw, data, _, 5, 3);
 DECL_TRAITS(ndhwc, data, _, 5, 3);
+DECL_TRAITS(nCdhw4c, data, _4c, 5, 3);
 DECL_TRAITS(nCdhw8c, data, _8c, 5, 3);
 DECL_TRAITS(nCdhw16c, data, _16c, 5, 3);
 
@@ -110,11 +123,14 @@ DECL_TRAITS(io, wei, _, 2, 0);
 /* wei: 3D */
 DECL_TRAITS(oiw, wei, _, 3, 1);
 DECL_TRAITS(wio, wei, _, 3, 1);
+DECL_TRAITS(Owi4o, wei, _4o, 3, 1);
+DECL_TRAITS(OIw4i4o, wei, _4i4o, 3, 1);
 DECL_TRAITS(Owi8o, wei, _8o, 3, 1);
 DECL_TRAITS(OIw8i8o, wei, _8i8o, 3, 1);
 DECL_TRAITS(OIw8o8i, wei, _8o8i, 3, 1);
 DECL_TRAITS(OIw16i16o, wei, _16i16o, 3, 1);
 DECL_TRAITS(OIw16o16i, wei, _16o16i, 3, 1);
+DECL_TRAITS(Oiw4o, wei, _4o, 3, 1);
 DECL_TRAITS(Oiw16o, wei, _16o, 3, 1);
 DECL_TRAITS(Owi16o, wei, _16o, 3, 1);
 DECL_TRAITS(OIw8i16o2i, wei, _8i16o2i, 3, 1);
@@ -125,10 +141,14 @@ DECL_TRAITS(OIw8o16i2o, wei, _8o16i2o, 3, 1);
 DECL_TRAITS(oihw, wei, _, 4, 2);
 DECL_TRAITS(ihwo, wei, _, 4, 2);
 DECL_TRAITS(hwio, wei, _, 4, 2);
+DECL_TRAITS(iohw, wei, _, 4, 2);
 DECL_TRAITS(hwio_s8s8, wei, _, 4, 2);
 DECL_TRAITS(oIhw8i, wei, _8i, 4, 2);
 DECL_TRAITS(oIhw16i, wei, _16i, 4, 2);
+DECL_TRAITS(OIhw4i4o, wei, _4i4o, 4, 2);
 DECL_TRAITS(OIhw8i8o, wei, _8i8o, 4, 2);
+DECL_TRAITS(OhIw8o32i, wei, _8o32i, 4, 2);
+DECL_TRAITS(OhIw16o32i, wei, _16o32i, 4, 2);
 DECL_TRAITS(OhIw8o4i, wei, _8o4i, 4, 2);
 DECL_TRAITS(OhIw8o4i_s8s8, wei, _8o4i_s8s8, 4, 2);
 DECL_TRAITS(OIhw16i16o, wei, _16i16o, 4, 2);
@@ -139,18 +159,23 @@ DECL_TRAITS(OIhw8o16i2o, wei, _8o16i2o, 4, 2);
 DECL_TRAITS(OIhw8o8i, wei, _8o8i, 4, 2);
 DECL_TRAITS(OIhw16o16i, wei, _16o16i, 4, 2);
 DECL_TRAITS(IOhw16o16i, wei, _16o16i, 4, 2);
+DECL_TRAITS(Oihw4o, wei, _4o, 4, 2);
 DECL_TRAITS(Oihw16o, wei, _16o, 4, 2);
 DECL_TRAITS(Ohwi8o, wei, _8o, 4, 2);
+DECL_TRAITS(Ohwi4o, wei, _4o, 4, 2);
 DECL_TRAITS(Ohwi16o, wei, _16o, 4, 2);
 
 /* wei: 5D */
 DECL_TRAITS(dhwio, wei, _, 5, 3);
 DECL_TRAITS(oidhw, wei, _, 5, 3);
+DECL_TRAITS(OIdhw4i4o, wei, _4i4o, 5, 3);
+DECL_TRAITS(Odhwi4o, wei, _4o, 5, 3);
 DECL_TRAITS(OIdhw8i8o, wei, _8i8o, 5, 3);
 DECL_TRAITS(OIdhw8o8i, wei, _8o8i, 5, 3);
 DECL_TRAITS(Odhwi8o, wei, _8o, 5, 3);
 DECL_TRAITS(OIdhw16i16o, wei, _16i16o, 5, 3);
 DECL_TRAITS(OIdhw16o16i, wei, _16o16i, 5, 3);
+DECL_TRAITS(Oidhw4o, wei, _4o, 5, 3);
 DECL_TRAITS(Oidhw16o, wei, _16o, 5, 3);
 DECL_TRAITS(Odhwi16o, wei, _16o, 5, 3);
 DECL_TRAITS(oIdhw8i, wei, _8i, 5, 3);
@@ -159,11 +184,14 @@ DECL_TRAITS(OIdhw8i16o2i, wei, _8i16o2i, 5, 3);
 
 /* gwei: 4D */
 DECL_TRAITS(goiw, gwei, _, 4, 1);
+DECL_TRAITS(gOwi4o, gwei, _4o, 4, 1);
+DECL_TRAITS(gOIw4i4o, gwei, _4i4o, 4, 1);
 DECL_TRAITS(gOwi8o, gwei, _8o, 4, 1);
 DECL_TRAITS(gOIw8i8o, gwei, _8i8o, 4, 1);
 DECL_TRAITS(gOIw8o8i, gwei, _8o8i, 4, 1);
 DECL_TRAITS(gOIw16i16o, gwei, _16i16o, 4, 1);
 DECL_TRAITS(gOIw16o16i, gwei, _16o16i, 4, 1);
+DECL_TRAITS(gOiw4o, gwei, _4o, 4, 1);
 DECL_TRAITS(gOiw16o, gwei, _16o, 4, 1);
 DECL_TRAITS(gOwi16o, gwei, _16o, 4, 1);
 DECL_TRAITS(gOIw8i16o2i, gwei, _8i16o2i, 4, 1);
@@ -173,32 +201,43 @@ DECL_TRAITS(gOIw8o16i2o, gwei, _8o16i2o, 4, 1);
 /* gwei: 5D */
 DECL_TRAITS(goihw, gwei, _, 5, 2);
 DECL_TRAITS(hwigo, gwei, _, 5, 2);
+DECL_TRAITS(giohw, gwei, _, 5, 2);
 DECL_TRAITS(hwigo_s8s8, gwei, _, 5, 2);
+DECL_TRAITS(gOIhw4i4o, gwei, _4i4o, 5, 2);
 DECL_TRAITS(gOIhw8i8o, gwei, _8i8o, 5, 2);
 DECL_TRAITS(gOhIw8o4i, gwei, _8o4i, 5, 2);
 DECL_TRAITS(gOhIw8o4i_s8s8, gwei, _8o4i_s8s8, 5, 2);
 DECL_TRAITS(gOIhw16i16o, gwei, _16i16o, 5, 2);
 DECL_TRAITS(gOIhw4i16o4i, gwei, _4i16o4i, 5, 2);
 DECL_TRAITS(gOIhw4i16o4i_s8s8, gwei, _4i16o4i_s8s8, 5, 2);
+DECL_TRAITS(gOIhw2i8o4i, gwei, _2i8o4i, 5, 2);
+DECL_TRAITS(gOIhw2i8o4i_s8s8, gwei, _2i8o4i_s8s8, 5, 2);
 DECL_TRAITS(gOIhw8i16o2i, gwei, _8i16o2i, 5, 2);
 DECL_TRAITS(gOIdhw8i16o2i, gwei, _8i16o2i, 5, 2);
 DECL_TRAITS(gOIhw8o16i2o, gwei, _8o16i2o, 5, 2);
 DECL_TRAITS(gOIhw8o8i, gwei, _8o8i, 5, 2);
+DECL_TRAITS(gOIhw4o4i, gwei, _4o4i, 5, 2);
+DECL_TRAITS(gOIhw4o4i_s8s8, gwei, _4o4i_s8s8, 5, 2);
 DECL_TRAITS(gOIhw16o16i, gwei, _16o16i, 5, 2);
 DECL_TRAITS(gIOhw16o16i, gwei, _16o16i, 5, 2);
+DECL_TRAITS(gOihw4o, gwei, _4o, 5, 2);
 DECL_TRAITS(gOihw16o, gwei, _16o, 5, 2);
 DECL_TRAITS(gOhwi8o, gwei, _8o, 5, 2);
+DECL_TRAITS(gOhwi4o, gwei, _4o, 5, 2);
 DECL_TRAITS(gOhwi16o, gwei, _16o, 5, 2);
 DECL_TRAITS(Goihw8g, gwei, _8g, 5, 2);
 DECL_TRAITS(Goihw16g, gwei, _16g, 5, 2);
+DECL_TRAITS(Goihw16g_s8s8, gwei, _16g_s8s8, 5, 2);
 
 /* gwei: 6D */
 DECL_TRAITS(goidhw, gwei, _, 6, 3);
+DECL_TRAITS(gOIdhw4i4o, gwei, _4i4o, 6, 3);
 DECL_TRAITS(gOIdhw8i8o, gwei, _8i8o, 6, 3);
 DECL_TRAITS(gOIdhw8o8i, gwei, _8o8i, 6, 3);
 DECL_TRAITS(gOdhwi8o, gwei, _8o, 6, 3);
 DECL_TRAITS(gOIdhw16i16o, gwei, _16i16o, 6, 3);
 DECL_TRAITS(gOIdhw16o16i, gwei, _16o16i, 6, 3);
+DECL_TRAITS(gOidhw4o, gwei, _4o, 6, 3);
 DECL_TRAITS(gOidhw16o, gwei, _16o, 6, 3);
 DECL_TRAITS(gOdhwi16o, gwei, _16o, 6, 3);
 
@@ -216,21 +255,28 @@ DECL_TRAITS(ldgo, rnn, _, 4, 0);
 template <block_format_t f>
 constexpr int OI_blk_off(int oc, int ic) {
     using bf = block_format_t;
-    static_assert(utils::one_of(f, bf::_8i8o, bf::_8o8i, bf::_8o4i, bf::_8o4i_s8s8,
-                bf::_16i16o, bf::_16o16i, bf::_8i16o2i, bf::_8o16i2o,
-                bf::_4i16o4i, bf::_4i16o4i_s8s8),
+    static_assert(utils::one_of(f, bf::_4i4o, bf::_4o4i, bf::_4o4i_s8s8,
+                bf::_8i8o, bf::_8o8i, bf::_16i16o,
+                bf::_16o16i, bf::_8i16o2i, bf::_8o16i2o,
+                bf::_4i16o4i, bf::_4i16o4i_s8s8,
+                bf::_2i8o4i, bf::_2i8o4i_s8s8,
+                bf::_8o4i, bf::_8o4i_s8s8,
+                bf::_8o32i, bf::_16o32i),
             "unexpected blocked format");
 #   define blksize block_format_traits<f>::blk_size
     return f == bf::_8i16o2i
         ? (ic / 2) * blksize * 2 + 2 * oc + ic % 2
-        : (f == bf::_4i16o4i || f == bf::_4i16o4i_s8s8)
+        : (f == bf::_4i16o4i || f == bf::_4i16o4i_s8s8
+        || f == bf::_2i8o4i || f == bf::_2i8o4i_s8s8)
         ? (ic / 4) * blksize * 4 + oc * 4 + ic % 4
         : f == bf::_8o16i2o
         ? (oc / 2) * blksize * 2 + 2 * ic + oc % 2
-        : utils::one_of(f, bf::_8i8o, bf::_16i16o)
+        : utils::one_of(f, bf::_4i4o, bf::_8i8o, bf::_16i16o)
         ? ic * blksize + oc
         : (f == bf::_8o4i || f == bf::_8o4i_s8s8)
         ? (ic / 4) * blksize * 4 + 4 * oc + ic % 4
+        : (f == bf::_8o32i || f == bf::_16o32i)
+        ? 32 * oc + 32
         : oc * blksize + ic;
 #   undef blksize // if only we program in C++14...
 }