Refactored TensorFlow subgraphs fusion
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 29 Apr 2019 15:55:09 +0000 (18:55 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 30 Apr 2019 06:21:05 +0000 (09:21 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 7f10018..086f0ae 100644 (file)
@@ -79,9 +79,9 @@ public:
         }
     }
 
-    static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net,
-                                                   const tensorflow::NodeDef& node,
-                                                   int inpId)
+    static int getInputNodeId(const tensorflow::GraphDef& net,
+                              const tensorflow::NodeDef& node,
+                              int inpId)
     {
         CV_Assert(inpId < node.input_size());
         std::string name = node.input(inpId);
@@ -92,7 +92,7 @@ public:
         for (int i = 0; i < numNodes; ++i)
         {
             if (net.node(i).name() == name)
-                return net.node(i);
+                return i;
         }
         CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
     }
@@ -104,36 +104,46 @@ public:
         matchedNodesIds.clear();
         matchedNodesIds.reserve(nodesToFuse.size());
 
-        int numNodes = net.node_size();
-        for (int i = 0; i < nodesToFuse.size(); ++i)
+        std::queue<int> nodesToMatch;
+        std::queue<int> targetNodes;
+        nodesToMatch.push(nodeId);
+        targetNodes.push(nodesToFuse.back());
+        while (!nodesToMatch.empty())
         {
-            while (nodeId < numNodes && net.node(nodeId).op() == "Const")
-            {
-                nodeId += 1;
-            }
-            if (nodeId > numNodes - 1)
-                return false;
+            int nodeToMatch = nodesToMatch.front();
+            int targetNodeId = targetNodes.front();
+            nodesToMatch.pop();
+            targetNodes.pop();
 
-            const tensorflow::NodeDef& node = net.node(nodeId);
+            if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) !=
+                matchedNodesIds.end())
+                continue;
 
-            if (node.op() != nodes[nodesToFuse[i]])
+            const tensorflow::NodeDef& node = net.node(nodeToMatch);
+            if (node.op() != nodes[targetNodeId])
                 return false;
 
-            std::vector<int>& inputNodes = inputs[nodesToFuse[i]];
+            std::vector<int>& inputNodes = inputs[targetNodeId];
             if (inputNodes.size() != node.input_size())
                 return false;
+
             for (int j = 0; j < inputNodes.size(); ++j)
             {
                 if (nodes[inputNodes[j]].empty())  // Unknown input node type.
                     continue;
-                const tensorflow::NodeDef& inpNode = getInputNode(net, node, j);
-                if (inpNode.op() != nodes[inputNodes[j]])
+                nodeId = getInputNodeId(net, node, j);
+                const tensorflow::NodeDef& inpNode = net.node(nodeId);
+                if (inpNode.op() != "Const")
+                {
+                    nodesToMatch.push(nodeId);
+                    targetNodes.push(inputNodes[j]);
+                }
+                else if (nodes[inputNodes[j]] != "Const")
                     return false;
             }
-
-            matchedNodesIds.push_back(nodeId);
-            nodeId += 1;
+            matchedNodesIds.push_back(nodeToMatch);
         }
+        std::sort(matchedNodesIds.begin(), matchedNodesIds.end());
         return true;
     }
 
@@ -181,7 +191,7 @@ public:
         std::vector<tensorflow::NodeDef*> inputNodes(inputsNames.size());
         for (int i = 0; i < inputsNames.size(); ++i)
         {
-            inputNodes[i] = (tensorflow::NodeDef*)&getInputNode(net, *node, i);
+            inputNodes[i] = net.mutable_node(getInputNodeId(net, *node, i));
         }
         finalize(net, node, inputNodes);
     }
@@ -354,7 +364,7 @@ public:
     {
         if (!Subgraph::match(net, nodeId, matchedNodesIds))
             return false;
-        Mat maxValue = getTensorContent(net.node(nodeId + 1).attr().at("value").tensor());
+        Mat maxValue = getTensorContent(net.node(matchedNodesIds.front() + 1).attr().at("value").tensor());
         return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
     }
 };
@@ -384,6 +394,17 @@ public:
         setFusedNode("Reshape", ids);
     }
 
+    virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds) CV_OVERRIDE
+    {
+        const tensorflow::NodeDef& node = net.node(nodeId);
+        if (node.input_size() == 0)
+            return false;
+
+        inpName = node.input(0);
+        return Subgraph::match(net, nodeId, matchedNodesIds);
+    }
+
+
     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
                           std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
     {
@@ -395,6 +416,7 @@ public:
         }
         tensorflow::TensorProto* shapeTensor = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
         fusedNode->mutable_input()->DeleteSubrange(2, numOutDims - 1);
+        fusedNode->set_input(0, inpName);
 
         shapeTensor->clear_int_val();
         for (int i = 0; i < shape.size(); ++i)
@@ -405,6 +427,7 @@ public:
 
 private:
     int numOutDims;
+    std::string inpName;
 };
 
 class L2NormalizeSubgraph : public Subgraph
@@ -685,9 +708,9 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
     subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
-    subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
 
     int numNodes = net.node_size();
     std::vector<int> matchedNodesIds;
index ef0b196..0ff155f 100644 (file)
@@ -1079,25 +1079,28 @@ void TFImporter::populateNet(Net dstNet)
             {
                 Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
 
-                if (newShape.total() != 4 && inpLayout == DATA_LAYOUT_NHWC)
-                {
-                    LayerParams permLP;
-                    int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
-                    permLP.set("order", DictValue::arrayInt<int*>(order, 4));
-
-                    std::string permName = name + "/nchw";
-                    CV_Assert(layer_id.find(permName) == layer_id.end());
-                    int permId = dstNet.addLayer(permName, "Permute", permLP);
-                    layer_id[permName] = permId;
-                    connect(layer_id, dstNet, inpId, permId, 0);
-                    inpId = Pin(permName);
-                    inpLayout = DATA_LAYOUT_NCHW;
-                }
-                else if (newShape.total() == 4 && inpLayout == DATA_LAYOUT_NHWC)
+                if (inpLayout == DATA_LAYOUT_NHWC)
                 {
-                    // NHWC->NCHW
-                    std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
-                    std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
+                    if (newShape.total() == 4)
+                    {
+                        // NHWC->NCHW
+                        std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
+                        std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
+                    }
+                    if (newShape.total() != 4 || newShape.at<int>(1) == 1)
+                    {
+                        LayerParams permLP;
+                        int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
+                        permLP.set("order", DictValue::arrayInt<int*>(order, 4));
+
+                        std::string permName = name + "/nchw";
+                        CV_Assert(layer_id.find(permName) == layer_id.end());
+                        int permId = dstNet.addLayer(permName, "Permute", permLP);
+                        layer_id[permName] = permId;
+                        connect(layer_id, dstNet, inpId, permId, 0);
+                        inpId = Pin(permName);
+                        inpLayout = DATA_LAYOUT_NCHW;
+                    }
                 }
                 layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShape.total()));
 
@@ -1335,7 +1338,9 @@ void TFImporter::populateNet(Net dstNet)
             // num_split
             // 1st blob is dims tensor
             int axis = getConstBlob(layer, value_id, 0).int_val().Get(0);
-            layerParams.set("axis", toNCHW(axis));
+            if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
+                axis = toNCHW(axis);
+            layerParams.set("axis", axis);
 
             int id = dstNet.addLayer(name, "Slice", layerParams);
             layer_id[name] = id;
index 8b750bb..4973008 100644 (file)
@@ -654,6 +654,13 @@ TEST_P(Test_TensorFlow_layers, relu6)
     runTensorFlowNet("keras_relu6", /*hasText*/ true);
 }
 
+TEST_P(Test_TensorFlow_layers, subpixel)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE)
+        throw SkipTestException("");
+    runTensorFlowNet("subpixel");
+}
+
 TEST_P(Test_TensorFlow_layers, keras_mobilenet_head)
 {
     runTensorFlowNet("keras_mobilenet_head");