Nearest neighbor resize from Keras
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 4 Jul 2018 08:53:24 +0000 (11:53 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 4 Jul 2018 08:53:24 +0000 (11:53 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/test/test_tf_importer.cpp

index a537358..3d8a97f 100644 (file)
@@ -571,6 +571,50 @@ public:
     }
 };
 
+// In case of resizing by factor.
+class UpsamplingKerasSubgraph : public Subgraph
+{
+public:
+    UpsamplingKerasSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", input);
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+        int factors = addNodeToMatch("Const");
+        int mul = addNodeToMatch("Mul", strided_slice, factors);
+        addNodeToMatch("ResizeNearestNeighbor", input, mul);
+        setFusedNode("ResizeNearestNeighbor", input, factors);
+    }
+
+    virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
+                          std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
+    {
+        Mat factorsMat = getTensorContent(inputNodes[1]->attr().at("value").tensor());
+        CV_Assert(factorsMat.total() == 2, factorsMat.type() == CV_32SC1);
+
+        // Height scale factor
+        tensorflow::TensorProto* factorY = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
+        factorY->clear_int_val();
+        factorY->clear_tensor_content();
+        factorY->add_int_val(factorsMat.at<int>(0, 0));
+
+        // Width scale factor.
+        tensorflow::NodeDef* factorXNode = net.add_node();
+        factorXNode->set_op("Const");
+        factorXNode->set_name(fusedNode->name() + "/factor_y");
+
+        tensorflow::AttrValue factorX;
+        factorX.mutable_tensor()->set_dtype(tensorflow::DT_INT32);
+        factorX.mutable_tensor()->add_int_val(factorsMat.at<int>(0, 1));
+        factorXNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", factorX));
+
+        fusedNode->add_input(factorXNode->name());
+    }
+};
+
 void simplifySubgraphs(tensorflow::GraphDef& net)
 {
     std::vector<Ptr<Subgraph> > subgraphs;
@@ -585,6 +629,7 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
     subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new ResizeBilinearSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new UpsamplingKerasSubgraph()));
 
     int numNodes = net.node_size();
     std::vector<int> matchedNodesIds;
index d4ffc94..bb60d46 100644 (file)
@@ -402,6 +402,7 @@ TEST(Test_TensorFlow, split)
 TEST(Test_TensorFlow, resize_nearest_neighbor)
 {
     runTensorFlowNet("resize_nearest_neighbor");
+    runTensorFlowNet("keras_upsampling2d");
 }
 
 TEST(Test_TensorFlow, slice)