Remove Switch and Merge nodes from TensorFlow networks
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 3 Apr 2019 10:42:06 +0000 (13:42 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 5 Apr 2019 09:32:35 +0000 (12:32 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_graph_simplifier.hpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 59d0d57..37e5750 100644 (file)
@@ -10,6 +10,7 @@
 #ifdef HAVE_PROTOBUF
 
 #include "tf_graph_simplifier.hpp"
+#include <queue>
 
 namespace cv { namespace dnn {
 CV__DNN_EXPERIMENTAL_NS_BEGIN
@@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
         nodesToAdd.pop_back();
 
         permIds.push_back(nodeToAdd);
-        // std::cout << net.node(nodeToAdd).name() << '\n';
 
         for (int i = 0; i < edges[nodeToAdd].size(); ++i)
         {
@@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
     permute(net.mutable_node(), permIds);
 }
 
+// Remove training switches (Switch and Merge nodes and corresponding subgraphs).
+void removePhaseSwitches(tensorflow::GraphDef& net)
+{
+    std::vector<int> nodesToRemove;
+    std::map<std::string, int> nodesMap;
+    std::map<std::string, int>::iterator nodesMapIt;
+    std::queue<int> mergeOpSubgraphNodes;
+    for (int i = 0; i < net.node_size(); ++i)
+    {
+        const tensorflow::NodeDef& node = net.node(i);
+        nodesMap.insert(std::make_pair(node.name(), i));
+        if (node.op() == "Switch" || node.op() == "Merge")
+        {
+            CV_Assert(node.input_size() > 0);
+            // Replace consumers' inputs.
+            for (int j = 0; j < net.node_size(); ++j)
+            {
+                tensorflow::NodeDef* consumer = net.mutable_node(j);
+                for (int k = 0; k < consumer->input_size(); ++k)
+                {
+                    std::string inpName = consumer->input(k);
+                    inpName = inpName.substr(0, inpName.rfind(':'));
+                    if (inpName == node.name())
+                    {
+                        consumer->set_input(k, node.input(0));
+                    }
+                }
+            }
+            nodesToRemove.push_back(i);
+            if (node.op() == "Merge")
+                mergeOpSubgraphNodes.push(i);
+        }
+    }
+
+    std::vector<int> numConsumers(net.node_size(), 0);
+    for (int i = 0; i < net.node_size(); ++i)
+    {
+        const tensorflow::NodeDef& node = net.node(i);
+        for (int j = 0; j < node.input_size(); ++j)
+        {
+            std::string inpName = node.input(j);
+            inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
+            nodesMapIt = nodesMap.find(inpName);
+            CV_Assert(nodesMapIt != nodesMap.end());
+            numConsumers[nodesMapIt->second] += 1;
+        }
+    }
+
+    // Remove subgraphs of unused nodes which are terminated by Merge nodes.
+    while (!mergeOpSubgraphNodes.empty())
+    {
+        const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front());
+        mergeOpSubgraphNodes.pop();
+        for (int i = 0; i < node.input_size(); ++i)
+        {
+            std::string inpName = node.input(i);
+            inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
+            nodesMapIt = nodesMap.find(inpName);
+            CV_Assert(nodesMapIt != nodesMap.end());
+
+            int inpNodeId = nodesMapIt->second;
+            if (numConsumers[inpNodeId] == 1)
+            {
+                mergeOpSubgraphNodes.push(inpNodeId);
+                nodesToRemove.push_back(inpNodeId);
+            }
+            else if (numConsumers[inpNodeId] > 0)
+                numConsumers[inpNodeId] -= 1;
+        }
+    }
+    std::sort(nodesToRemove.begin(), nodesToRemove.end());
+    for (int i = nodesToRemove.size() - 1; i >= 0; --i)
+    {
+        if (nodesToRemove[i] < net.node_size())  // Ids might be repeated.
+            net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1);
+    }
+}
+
+
 CV__DNN_EXPERIMENTAL_NS_END
 }}  // namespace dnn, namespace cv
 
index 24c4bd5..5929d1f 100644 (file)
@@ -27,6 +27,8 @@ void releaseTensor(tensorflow::TensorProto* tensor);
 
 void sortByExecutionOrder(tensorflow::GraphDef& net);
 
+void removePhaseSwitches(tensorflow::GraphDef& net);
+
 CV__DNN_EXPERIMENTAL_NS_END
 }}  // namespace dnn, namespace cv
 
index 9cca9e9..a162810 100644 (file)
@@ -657,6 +657,9 @@ static int predictOutputDataLayout(const tensorflow::GraphDef& net,
 
 void TFImporter::populateNet(Net dstNet)
 {
+    if (!netTxt.ByteSize())
+        removePhaseSwitches(netBin);
+
     RemoveIdentityOps(netBin);
     RemoveIdentityOps(netTxt);
 
index 395a965..ef5206f 100644 (file)
@@ -185,6 +185,16 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
     runTensorFlowNet("mvn_batch_norm_1x1");
 }
 
+TEST_P(Test_TensorFlow_layers, slim_batch_norm)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE)
+        throw SkipTestException("Test is disabled for DLIE");
+    // Output values range: [-40.0597, 207.827]
+    double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.041 : default_l1;
+    double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.33 : default_lInf;
+    runTensorFlowNet("slim_batch_norm", false, l1, lInf);
+}
+
 TEST_P(Test_TensorFlow_layers, pooling)
 {
     runTensorFlowNet("max_pool_even");