Add a planar data layout tracking for TensorFlow importer
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 29 Jun 2018 06:50:14 +0000 (09:50 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 29 Jun 2018 06:50:14 +0000 (09:50 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp

index 6a4a0ab..1faa7fb 100644 (file)
@@ -51,7 +51,8 @@ enum DataLayout
 {
     DATA_LAYOUT_NHWC,
     DATA_LAYOUT_NCHW,
-    DATA_LAYOUT_UNKNOWN
+    DATA_LAYOUT_UNKNOWN,
+    DATA_LAYOUT_PLANAR  // 2-dimensional outputs (matmul, flatten, reshape to 2d)
 };
 
 typedef std::vector<std::pair<String, int> > StrIntVector;
@@ -948,7 +949,7 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             int input_blob_index = kernel_blob_index == 0 ? 1 : 0;
             connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0);
-            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
+            data_layouts[name] = DATA_LAYOUT_PLANAR;
         }
         else if (type == "Reshape")
         {
@@ -981,7 +982,7 @@ void TFImporter::populateNet(Net dstNet)
 
             // one input only
             connect(layer_id, dstNet, inpId, id, 0);
-            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
+            data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Flatten" || type == "Squeeze")
         {
@@ -1020,7 +1021,7 @@ void TFImporter::populateNet(Net dstNet)
             int id = dstNet.addLayer(name, "Flatten", layerParams);
             layer_id[name] = id;
             connect(layer_id, dstNet, inpId, id, 0);
-            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
+            data_layouts[name] = DATA_LAYOUT_PLANAR;
         }
         else if (type == "Transpose")
         {