Fix Mobilenet v2 from TensorFlow slim
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 27 Mar 2019 12:10:57 +0000 (15:10 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 27 Mar 2019 12:10:57 +0000 (15:10 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 5cf65c8..59d0d57 100644 (file)
@@ -630,6 +630,21 @@ public:
     }
 };
 
+class SoftMaxSlimSubgraph : public Subgraph
+{
+public:
+    SoftMaxSlimSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Const");
+        int shapeOp = addNodeToMatch("Shape", input);
+        int reshape = addNodeToMatch("Reshape", input, shape);
+        int softmax = addNodeToMatch("Softmax", reshape);
+        addNodeToMatch("Reshape", softmax, shapeOp);
+        setFusedNode("Softmax", input);
+    }
+};
+
 void simplifySubgraphs(tensorflow::GraphDef& net)
 {
     std::vector<Ptr<Subgraph> > subgraphs;
@@ -646,6 +661,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
     subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimSubgraph()));
 
     int numNodes = net.node_size();
     std::vector<int> matchedNodesIds;
index 6ce99d6..480e8c7 100644 (file)
@@ -661,7 +661,10 @@ void TFImporter::populateNet(Net dstNet)
     RemoveIdentityOps(netTxt);
 
     if (!netTxt.ByteSize())
+    {
         simplifySubgraphs(netBin);
+        sortByExecutionOrder(netBin);
+    }
 
     std::set<String> layers_to_ignore;
 
index a5d5512..9a7b09c 100644 (file)
@@ -549,6 +549,7 @@ TEST_P(Test_TensorFlow_layers, slice)
 TEST_P(Test_TensorFlow_layers, softmax)
 {
     runTensorFlowNet("keras_softmax");
+    runTensorFlowNet("slim_softmax");
 }
 
 TEST_P(Test_TensorFlow_layers, relu6)