From: Dmitry Kurtaev Date: Mon, 29 Apr 2019 15:55:09 +0000 (+0300) Subject: Refactored TensorFlow subgraphs fusion X-Git-Tag: accepted/tizen/6.0/unified/20201030.111113~1^2~252^2~2^2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9408c3e6405c2a921923e12fa9c58382df5dc910;p=platform%2Fupstream%2Fopencv.git Refactored TensorFlow subgraphs fusion --- diff --git a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp index 7f10018..086f0ae 100644 --- a/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp +++ b/modules/dnn/src/tensorflow/tf_graph_simplifier.cpp @@ -79,9 +79,9 @@ public: } } - static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net, - const tensorflow::NodeDef& node, - int inpId) + static int getInputNodeId(const tensorflow::GraphDef& net, + const tensorflow::NodeDef& node, + int inpId) { CV_Assert(inpId < node.input_size()); std::string name = node.input(inpId); @@ -92,7 +92,7 @@ public: for (int i = 0; i < numNodes; ++i) { if (net.node(i).name() == name) - return net.node(i); + return i; } CV_Error(Error::StsParseError, "Input node with name " + name + " not found"); } @@ -104,36 +104,46 @@ public: matchedNodesIds.clear(); matchedNodesIds.reserve(nodesToFuse.size()); - int numNodes = net.node_size(); - for (int i = 0; i < nodesToFuse.size(); ++i) + std::queue nodesToMatch; + std::queue targetNodes; + nodesToMatch.push(nodeId); + targetNodes.push(nodesToFuse.back()); + while (!nodesToMatch.empty()) { - while (nodeId < numNodes && net.node(nodeId).op() == "Const") - { - nodeId += 1; - } - if (nodeId > numNodes - 1) - return false; + int nodeToMatch = nodesToMatch.front(); + int targetNodeId = targetNodes.front(); + nodesToMatch.pop(); + targetNodes.pop(); - const tensorflow::NodeDef& node = net.node(nodeId); + if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) != + matchedNodesIds.end()) + continue; - if (node.op() != nodes[nodesToFuse[i]]) + const tensorflow::NodeDef& node = net.node(nodeToMatch); + if (node.op() != nodes[targetNodeId]) return false; - std::vector& inputNodes = inputs[nodesToFuse[i]]; + std::vector& inputNodes = inputs[targetNodeId]; if (inputNodes.size() != node.input_size()) return false; + for (int j = 0; j < inputNodes.size(); ++j) { if (nodes[inputNodes[j]].empty()) // Unknown input node type. continue; - const tensorflow::NodeDef& inpNode = getInputNode(net, node, j); - if (inpNode.op() != nodes[inputNodes[j]]) + nodeId = getInputNodeId(net, node, j); + const tensorflow::NodeDef& inpNode = net.node(nodeId); + if (inpNode.op() != "Const") + { + nodesToMatch.push(nodeId); + targetNodes.push(inputNodes[j]); + } + else if (nodes[inputNodes[j]] != "Const") return false; } - - matchedNodesIds.push_back(nodeId); - nodeId += 1; + matchedNodesIds.push_back(nodeToMatch); } + std::sort(matchedNodesIds.begin(), matchedNodesIds.end()); return true; } @@ -181,7 +191,7 @@ public: std::vector inputNodes(inputsNames.size()); for (int i = 0; i < inputsNames.size(); ++i) { - inputNodes[i] = (tensorflow::NodeDef*)&getInputNode(net, *node, i); + inputNodes[i] = net.mutable_node(getInputNodeId(net, *node, i)); } finalize(net, node, inputNodes); } @@ -354,7 +364,7 @@ public: { if (!Subgraph::match(net, nodeId, matchedNodesIds)) return false; - Mat maxValue = getTensorContent(net.node(nodeId + 1).attr().at("value").tensor()); + Mat maxValue = getTensorContent(net.node(matchedNodesIds.front() + 1).attr().at("value").tensor()); return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at(0) == 6; } }; @@ -384,6 +394,17 @@ public: setFusedNode("Reshape", ids); } + virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector& matchedNodesIds) CV_OVERRIDE + { + const tensorflow::NodeDef& node = net.node(nodeId); + if (node.input_size() == 0) + return false; + + inpName = node.input(0); + return Subgraph::match(net, nodeId, matchedNodesIds); + } + + virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode, std::vector& inputNodes) CV_OVERRIDE { @@ -395,6 +416,7 @@ public: } tensorflow::TensorProto* shapeTensor = inputNodes[1]->mutable_attr()->at("value").mutable_tensor(); fusedNode->mutable_input()->DeleteSubrange(2, numOutDims - 1); + fusedNode->set_input(0, inpName); shapeTensor->clear_int_val(); for (int i = 0; i < shape.size(); ++i) @@ -405,6 +427,7 @@ public: private: int numOutDims; + std::string inpName; }; class L2NormalizeSubgraph : public Subgraph @@ -685,9 +708,9 @@ void simplifySubgraphs(tensorflow::GraphDef& net) subgraphs.push_back(Ptr(new DeconvolutionSameKerasSubgraph())); subgraphs.push_back(Ptr(new ResizeBilinearSubgraph())); subgraphs.push_back(Ptr(new UpsamplingKerasSubgraph())); - subgraphs.push_back(Ptr(new ReshapeAsShapeSubgraph())); subgraphs.push_back(Ptr(new SoftMaxSlimSubgraph())); subgraphs.push_back(Ptr(new SoftMaxSlimV2Subgraph())); + subgraphs.push_back(Ptr(new ReshapeAsShapeSubgraph())); int numNodes = net.node_size(); std::vector matchedNodesIds; diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index ef0b196..0ff155f 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1079,25 +1079,28 @@ void TFImporter::populateNet(Net dstNet) { Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1)); - if (newShape.total() != 4 && inpLayout == DATA_LAYOUT_NHWC) - { - LayerParams permLP; - int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC. - permLP.set("order", DictValue::arrayInt(order, 4)); - - std::string permName = name + "/nchw"; - CV_Assert(layer_id.find(permName) == layer_id.end()); - int permId = dstNet.addLayer(permName, "Permute", permLP); - layer_id[permName] = permId; - connect(layer_id, dstNet, inpId, permId, 0); - inpId = Pin(permName); - inpLayout = DATA_LAYOUT_NCHW; - } - else if (newShape.total() == 4 && inpLayout == DATA_LAYOUT_NHWC) + if (inpLayout == DATA_LAYOUT_NHWC) { - // NHWC->NCHW - std::swap(*newShape.ptr(0, 2), *newShape.ptr(0, 3)); - std::swap(*newShape.ptr(0, 1), *newShape.ptr(0, 2)); + if (newShape.total() == 4) + { + // NHWC->NCHW + std::swap(*newShape.ptr(0, 2), *newShape.ptr(0, 3)); + std::swap(*newShape.ptr(0, 1), *newShape.ptr(0, 2)); + } + if (newShape.total() != 4 || newShape.at(1) == 1) + { + LayerParams permLP; + int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC. + permLP.set("order", DictValue::arrayInt(order, 4)); + + std::string permName = name + "/nchw"; + CV_Assert(layer_id.find(permName) == layer_id.end()); + int permId = dstNet.addLayer(permName, "Permute", permLP); + layer_id[permName] = permId; + connect(layer_id, dstNet, inpId, permId, 0); + inpId = Pin(permName); + inpLayout = DATA_LAYOUT_NCHW; + } } layerParams.set("dim", DictValue::arrayInt(newShape.ptr(), newShape.total())); @@ -1335,7 +1338,9 @@ void TFImporter::populateNet(Net dstNet) // num_split // 1st blob is dims tensor int axis = getConstBlob(layer, value_id, 0).int_val().Get(0); - layerParams.set("axis", toNCHW(axis)); + if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC) + axis = toNCHW(axis); + layerParams.set("axis", axis); int id = dstNet.addLayer(name, "Slice", layerParams); layer_id[name] = id; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 8b750bb..4973008 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -654,6 +654,13 @@ TEST_P(Test_TensorFlow_layers, relu6) runTensorFlowNet("keras_relu6", /*hasText*/ true); } +TEST_P(Test_TensorFlow_layers, subpixel) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE) + throw SkipTestException(""); + runTensorFlowNet("subpixel"); +} + TEST_P(Test_TensorFlow_layers, keras_mobilenet_head) { runTensorFlowNet("keras_mobilenet_head");