Import TF2.0 network from Keras
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 25 Mar 2020 12:34:28 +0000 (15:34 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 25 Mar 2020 12:34:28 +0000 (15:34 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/test/test_tf_importer.cpp

index 2e7bb57..b0978c2 100644 (file)
@@ -682,6 +682,15 @@ void RemoveIdentityOps(tensorflow::GraphDef& net)
             IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
 
             if (it != identity_ops.end()) {
+                // In case of Identity after Identity
+                while (true)
+                {
+                    IdentityOpsMap::iterator nextIt = identity_ops.find(it->second);
+                    if (nextIt != identity_ops.end())
+                        it = nextIt;
+                    else
+                        break;
+                }
                 layer->set_input(input_id, it->second);
             }
         }
@@ -847,7 +856,7 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
             nodesToAdd.push_back(i);
         else
         {
-            if (node.op() == "Merge" || node.op() == "RefMerge")
+            if (node.op() == "Merge" || node.op() == "RefMerge" || node.op() == "NoOp")
             {
                 int numControlEdges = 0;
                 for (int j = 0; j < numInputsInGraph; ++j)
@@ -896,7 +905,7 @@ void removePhaseSwitches(tensorflow::GraphDef& net)
     {
         const tensorflow::NodeDef& node = net.node(i);
         nodesMap.insert(std::make_pair(node.name(), i));
-        if (node.op() == "Switch" || node.op() == "Merge")
+        if (node.op() == "Switch" || node.op() == "Merge" || node.op() == "NoOp")
         {
             CV_Assert(node.input_size() > 0);
             // Replace consumers' inputs.
@@ -914,7 +923,7 @@ void removePhaseSwitches(tensorflow::GraphDef& net)
                 }
             }
             nodesToRemove.push_back(i);
-            if (node.op() == "Merge" || node.op() == "Switch")
+            if (node.op() == "Merge" || node.op() == "Switch" || node.op() == "NoOp")
                 mergeOpSubgraphNodes.push(i);
         }
     }
index 8cacae8..0088cfd 100644 (file)
@@ -867,6 +867,11 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
     runTensorFlowNet("resize_bilinear_factor");
 }
 
+TEST_P(Test_TensorFlow_layers, tf2_keras)
+{
+    runTensorFlowNet("tf2_dense");
+}
+
 TEST_P(Test_TensorFlow_layers, squeeze)
 {
 #if defined(INF_ENGINE_RELEASE)