Manage TensorFlow's NHWC data layout more smoothly
author Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 20 Dec 2017 11:13:40 +0000 (14:13 +0300)
committer Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 20 Dec 2017 11:13:40 +0000 (14:13 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 202958d..66943dd 100644 (file)
@@ -42,6 +42,14 @@ namespace
 
 static int toNCHW[] = {0, 2, 3, 1};
 
+// These values are used to indicate a layer output's data layout where it's possible.
+enum DataLayout
+{
+    DATA_LAYOUT_NHWC,    // Set for data_format "NHWC" or when the attribute is absent. NOTE(review): value 0, which is also what data_layouts[key] yields for a missing key.
+    DATA_LAYOUT_NCHW,    // Set for data_format "NCHW" and after an NHWC->NCHW Transpose.
+    DATA_LAYOUT_UNKNOWN  // Layout cannot be determined, e.g. inputs disagree or the layer is a Reshape/Flatten.
+};
+
 typedef std::vector<std::pair<String, int> > StrIntVector;
 
 struct Pin
@@ -608,6 +616,31 @@ static void addConstNodes(const tensorflow::GraphDef& net, std::map<String, int>
     }
 }
 
+// If all inputs of specific layer have the same data layout we can say that
+// this layer's output has this data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
+static int predictOutputDataLayout(const tensorflow::NodeDef& layer, const std::map<String, int>& data_layouts)
+{
+    int layout = DATA_LAYOUT_UNKNOWN;
+    std::map<String, int>::const_iterator it;
+    for (int i = 0, n = layer.input_size(); i < n; ++i)
+    {
+        it = data_layouts.find(layer.input(i));
+        if (it != data_layouts.end())
+        {
+            if (it->second == DATA_LAYOUT_UNKNOWN)
+                return DATA_LAYOUT_UNKNOWN;
+            else if (it->second != layout)
+            {
+                if (layout == DATA_LAYOUT_UNKNOWN)
+                    layout = it->second;
+                else
+                    return DATA_LAYOUT_UNKNOWN;
+            }
+        }
+    }
+    return layout;
+}
+
 void TFImporter::populateNet(Net dstNet)
 {
     RemoveIdentityOps(netBin);
@@ -619,6 +652,8 @@ void TFImporter::populateNet(Net dstNet)
 
     int layersSize = net.node_size();
 
+    std::map<String, int> data_layouts;
+
     // find all Const layers for params
     std::map<String, int> value_id;
     addConstNodes(netBin, value_id, layers_to_ignore);
@@ -636,6 +671,8 @@ void TFImporter::populateNet(Net dstNet)
         if(layers_to_ignore.find(name) != layers_to_ignore.end())
             continue;
 
+        data_layouts[name] = predictOutputDataLayout(layer, data_layouts);
+
         if (type == "Conv2D" || type == "SpaceToBatchND" || type == "DepthwiseConv2dNative")
         {
             // The first node of dilated convolution subgraph.
@@ -731,6 +768,19 @@ void TFImporter::populateNet(Net dstNet)
 
             // one input only
             connect(layer_id, dstNet, parsePin(input), id, 0);
+
+            if (hasLayerAttr(layer, "data_format"))
+            {
+                std::string format = getLayerAttr(layer, "data_format").s();
+                if (format == "NHWC")
+                    data_layouts[name] = DATA_LAYOUT_NHWC;
+                else if (format == "NCHW")
+                    data_layouts[name] = DATA_LAYOUT_NCHW;
+                else
+                    CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
+            }
+            else
+                data_layouts[name] = DATA_LAYOUT_NHWC;
         }
         else if (type == "BiasAdd" || type == "Add")
         {
@@ -806,22 +856,55 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             int input_blob_index = kernel_blob_index == 0 ? 1 : 0;
             connect(layer_id, dstNet, parsePin(layer.input(input_blob_index)), id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Reshape")
         {
-            layerParams.set("dim", parseDims(getConstBlob(layer, value_id, 1)));
+            Pin inpId = parsePin(layer.input(0));
+            DictValue newShape = parseDims(getConstBlob(layer, value_id, 1));
+
+            if (newShape.size() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            {
+                LayerParams permLP;
+                int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
+                permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+                std::string permName = name + "/nchw";
+                CV_Assert(layer_id.find(permName) == layer_id.end());
+                int permId = dstNet.addLayer(permName, "Permute", permLP);
+                layer_id[permName] = permId;
+                connect(layer_id, dstNet, inpId, permId, 0);
+                inpId = Pin(permName);
+            }
+            layerParams.set("dim", newShape);
 
             int id = dstNet.addLayer(name, "Reshape", layerParams);
             layer_id[name] = id;
 
             // one input only
-            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, inpId, id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Flatten")
         {
+            Pin inpId = parsePin(layer.input(0));
+            if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            {
+                LayerParams permLP;
+                int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
+                permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+                std::string permName = name + "/nchw";
+                CV_Assert(layer_id.find(permName) == layer_id.end());
+                int permId = dstNet.addLayer(permName, "Permute", permLP);
+                layer_id[permName] = permId;
+                connect(layer_id, dstNet, inpId, permId, 0);
+                inpId = Pin(permName);
+            }
             int id = dstNet.addLayer(name, "Flatten", layerParams);
             layer_id[name] = id;
-            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, inpId, id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Transpose")
         {
@@ -830,16 +913,57 @@ void TFImporter::populateNet(Net dstNet)
             int* permData = (int*)perm.data;
             if (perm.total() == 4)
             {
-                for (int i = 0; i < 4; ++i)
-                    permData[i] = toNCHW[permData[i]];
+                // Only NHWC <-> NCHW permutations are allowed. OpenCV is always
+                // keep NCHW layout this way.
+                if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+                {
+                    if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
+                    {
+                        // in TensorFlow: NHWC->NCHW
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NCHW;
+                    }
+                    else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
+                    {
+                        // in TensorFlow: NHWC->NHWC
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NHWC;
+                    }
+                    else
+                        CV_Assert(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
+                }
+                else if (data_layouts[layer.input(0)] == DATA_LAYOUT_NCHW)
+                {
+                    if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
+                    {
+                        // in TensorFlow: NCHW->NHWC
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NHWC;
+                    }
+                    else if (permData[0] == 0 && permData[1] == 1 && permData[2] == 2 && permData[3] == 3)
+                    {
+                        // in TensorFlow: NCHW->NCHW
+                        // in OpenCV: NCHW->NCHW
+                        data_layouts[name] = DATA_LAYOUT_NCHW;
+                    }
+                    else
+                        CV_Assert(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
+                }
+                int id = dstNet.addLayer(name, "Identity", layerParams);
+                layer_id[name] = id;
+                connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
             }
-            layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
+            else
+            {
+                layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
 
-            int id = dstNet.addLayer(name, "Permute", layerParams);
-            layer_id[name] = id;
+                int id = dstNet.addLayer(name, "Permute", layerParams);
+                layer_id[name] = id;
 
-            // one input only
-            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+                // one input only
+                connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+                data_layouts[name] = DATA_LAYOUT_UNKNOWN;
+            }
         }
         else if (type == "Const")
         {
@@ -1207,6 +1331,7 @@ void TFImporter::populateNet(Net dstNet)
 
             // one input only
             connect(layer_id, dstNet, parsePin(layer.input(1)), id, 0);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "ResizeNearestNeighbor")
         {
@@ -1258,6 +1383,7 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
             connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "DetectionOutput")
         {
@@ -1288,6 +1414,7 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             for (int i = 0; i < 3; ++i)
                 connect(layer_id, dstNet, parsePin(layer.input(i)), id, i);
+            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
                  type == "Relu" || type == "Elu" || type == "Softmax" ||
index 1badf74..05adbf0 100644 (file)
@@ -159,6 +159,8 @@ TEST(Test_TensorFlow, deconvolution)
 TEST(Test_TensorFlow, matmul)
 {
     runTensorFlowNet("matmul");
+    runTensorFlowNet("nhwc_reshape_matmul");  // presumably a Reshape fed by an NHWC tensor before MatMul -- confirm against the test data
+    runTensorFlowNet("nhwc_transpose_reshape_matmul");  // presumably adds a 4D Transpose ahead of the Reshape -- confirm against the test data
 }
 
 TEST(Test_TensorFlow, defun)