Sort text TensorFlow graphs
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 31 Dec 2019 08:43:32 +0000 (11:43 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 31 Dec 2019 08:43:32 +0000 (11:43 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_importer.cpp

index 0f9670e..a7845b2 100644 (file)
@@ -950,6 +950,7 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
     for (int i = 0; i < net.node_size(); ++i)
     {
         const tensorflow::NodeDef& node = net.node(i);
+        int numInputsInGraph = 0;
         for (int j = 0; j < node.input_size(); ++j)
         {
             std::string inpName = node.input(j);
@@ -957,22 +958,25 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
             inpName = inpName.substr(inpName.find('^') + 1);
 
             nodesMapIt = nodesMap.find(inpName);
-            CV_Assert(nodesMapIt != nodesMap.end());
-            edges[nodesMapIt->second].push_back(i);
+            if (nodesMapIt != nodesMap.end())
+            {
+                edges[nodesMapIt->second].push_back(i);
+                numInputsInGraph += 1;
+            }
         }
-        if (node.input_size() == 0)
+        if (numInputsInGraph == 0)
             nodesToAdd.push_back(i);
         else
         {
             if (node.op() == "Merge" || node.op() == "RefMerge")
             {
                 int numControlEdges = 0;
-                for (int j = 0; j < node.input_size(); ++j)
+                for (int j = 0; j < numInputsInGraph; ++j)
                     numControlEdges += node.input(j)[0] == '^';
                 numRefsToAdd[i] = numControlEdges + 1;
             }
             else
-                numRefsToAdd[i] = node.input_size();
+                numRefsToAdd[i] = numInputsInGraph;
         }
     }
 
index 192b94e..2775934 100644 (file)
@@ -715,6 +715,10 @@ void TFImporter::populateNet(Net dstNet)
         simplifySubgraphs(netBin);
         sortByExecutionOrder(netBin);
     }
+    else
+    {
+        sortByExecutionOrder(netTxt);
+    }
 
     std::set<String> layers_to_ignore;