From ec41a4897a7feb1c83afe19078af174368464b07 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Wed, 3 Apr 2019 13:42:06 +0300 Subject: [PATCH] Remove Switch and Merge nodes from TensorFlow networks --- modules/dnn/src/tensorflow/tf_graph_simplifier.cpp | 81 +++++++++++++++++++++- modules/dnn/src/tensorflow/tf_graph_simplifier.hpp | 2 + modules/dnn/src/tensorflow/tf_importer.cpp | 3 + modules/dnn/test/test_tf_importer.cpp | 10 +++ 4 files changed, 95 insertions(+), 1 deletion(-) diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 59d0d57..37e5750 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -10,6 +10,7 @@ #ifdef HAVE_PROTOBUF #include "tf_graph_simplifier.hpp" +#include namespace cv { namespace dnn { CV__DNN_EXPERIMENTAL_NS_BEGIN @@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net) nodesToAdd.pop_back(); permIds.push_back(nodeToAdd); - // std::cout << net.node(nodeToAdd).name() << '\n'; for (int i = 0; i < edges[nodeToAdd].size(); ++i) { @@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net) permute(net.mutable_node(), permIds); } +// Remove training switches (Switch and Merge nodes and corresponding subgraphs). +void removePhaseSwitches(tensorflow::GraphDef& net) +{ + std::vector nodesToRemove; + std::map nodesMap; + std::map::iterator nodesMapIt; + std::queue mergeOpSubgraphNodes; + for (int i = 0; i < net.node_size(); ++i) + { + const tensorflow::NodeDef& node = net.node(i); + nodesMap.insert(std::make_pair(node.name(), i)); + if (node.op() == "Switch" || node.op() == "Merge") + { + CV_Assert(node.input_size() > 0); + // Replace consumers' inputs. + for (int j = 0; j < net.node_size(); ++j) + { + tensorflow::NodeDef* consumer = net.mutable_node(j); + for (int k = 0; k < consumer->input_size(); ++k) + { + std::string inpName = consumer->input(k); + inpName = inpName.substr(0, inpName.rfind(':')); + if (inpName == node.name()) + { + consumer->set_input(k, node.input(0)); + } + } + } + nodesToRemove.push_back(i); + if (node.op() == "Merge") + mergeOpSubgraphNodes.push(i); + } + } + + std::vector numConsumers(net.node_size(), 0); + for (int i = 0; i < net.node_size(); ++i) + { + const tensorflow::NodeDef& node = net.node(i); + for (int j = 0; j < node.input_size(); ++j) + { + std::string inpName = node.input(j); + inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':')); + nodesMapIt = nodesMap.find(inpName); + CV_Assert(nodesMapIt != nodesMap.end()); + numConsumers[nodesMapIt->second] += 1; + } + } + + // Remove subgraphs of unused nodes which are terminated by Merge nodes. + while (!mergeOpSubgraphNodes.empty()) + { + const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front()); + mergeOpSubgraphNodes.pop(); + for (int i = 0; i < node.input_size(); ++i) + { + std::string inpName = node.input(i); + inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':')); + nodesMapIt = nodesMap.find(inpName); + CV_Assert(nodesMapIt != nodesMap.end()); + + int inpNodeId = nodesMapIt->second; + if (numConsumers[inpNodeId] == 1) + { + mergeOpSubgraphNodes.push(inpNodeId); + nodesToRemove.push_back(inpNodeId); + } + else if (numConsumers[inpNodeId] > 0) + numConsumers[inpNodeId] -= 1; + } + } + std::sort(nodesToRemove.begin(), nodesToRemove.end()); + for (int i = nodesToRemove.size() - 1; i >= 0; --i) + { + if (nodesToRemove[i] < net.node_size()) // Ids might be repeated. + net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1); + } +} + + CV__DNN_EXPERIMENTAL_NS_END }} // namespace dnn, namespace cv diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp index 24c4bd5..5929d1f 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.hpp @@ -27,6 +27,8 @@ void releaseTensor(tensorflow::TensorProto* tensor); void sortByExecutionOrder(tensorflow::GraphDef& net); +void removePhaseSwitches(tensorflow::GraphDef& net); + CV__DNN_EXPERIMENTAL_NS_END }} // namespace dnn, namespace cv diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 9cca9e9..a162810 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -657,6 +657,9 @@ static int predictOutputDataLayout(const tensorflow::GraphDef& net, void TFImporter::populateNet(Net dstNet) { + if (!netTxt.ByteSize()) + removePhaseSwitches(netBin); + RemoveIdentityOps(netBin); RemoveIdentityOps(netTxt); diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 395a965..ef5206f 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -185,6 +185,16 @@ TEST_P(Test_TensorFlow_layers, batch_norm) runTensorFlowNet("mvn_batch_norm_1x1"); } +TEST_P(Test_TensorFlow_layers, slim_batch_norm) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE) + throw SkipTestException("Test is disabled for DLIE"); + // Output values range: [-40.0597, 207.827] + double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.041 : default_l1; + double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.33 : default_lInf; + runTensorFlowNet("slim_batch_norm", false, l1, lInf); +} + TEST_P(Test_TensorFlow_layers, pooling) { runTensorFlowNet("max_pool_even"); -- 2.7.4