Fuse deconvolution layer subgraphs from Keras
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 20 Apr 2018 13:44:05 +0000 (16:44 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Fri, 20 Apr 2018 13:51:38 +0000 (16:51 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index cfb472e..80bb6b5 100644 (file)
@@ -419,6 +419,125 @@ public:
     }
 };
 
+class DeconvolutionValidKerasSubgraph : public Subgraph
+{
+public:
+    DeconvolutionValidKerasSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", input);
+        int kernel = addNodeToMatch("Const");
+
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        stack = addNodeToMatch("Const");
+        stack_1 = addNodeToMatch("Const");
+        stack_2 = addNodeToMatch("Const");
+        int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        stack = addNodeToMatch("Const");
+        stack_1 = addNodeToMatch("Const");
+        stack_2 = addNodeToMatch("Const");
+        int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
+        int add = addNodeToMatch("Add", mul, addNodeToMatch("Const"));
+
+        int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
+        int add_1 = addNodeToMatch("Add", mul_1, addNodeToMatch("Const"));
+        int pack = addNodeToMatch("Pack", strided_slice, add, add_1, addNodeToMatch("Const"));
+        addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
+        // Put any unused Const op to the first input.
+        setFusedNode("Conv2DBackpropInput", stack, kernel, input);
+    }
+
+    virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
+                          std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
+    {
+        // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
+        // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
+        // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
+        // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
+        std::string padMode = fusedNode->attr().at("padding").s();
+        CV_Assert(padMode == "VALID");
+
+        const tensorflow::TensorShapeProto& kernelShape =
+            inputNodes[1]->mutable_attr()->at("value").tensor().tensor_shape();
+
+        CV_Assert(kernelShape.dim_size() == 4);
+        const int kernelHeight = kernelShape.dim(0).size();
+        const int kernelWidth = kernelShape.dim(1).size();
+
+        tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
+        outShape->clear_int_val();
+        outShape->add_int_val(-1);
+        outShape->add_int_val(kernelHeight);
+        outShape->add_int_val(kernelWidth);
+        outShape->add_int_val(-1);
+    }
+};
+
+class DeconvolutionSameKerasSubgraph : public Subgraph
+{
+public:
+    DeconvolutionSameKerasSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", input);
+        int kernel = addNodeToMatch("Const");
+
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        stack = addNodeToMatch("Const");
+        stack_1 = addNodeToMatch("Const");
+        stack_2 = addNodeToMatch("Const");
+        int strided_slice_1 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        stack = addNodeToMatch("Const");
+        stack_1 = addNodeToMatch("Const");
+        stack_2 = addNodeToMatch("Const");
+        int strided_slice_2 = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        int mul = addNodeToMatch("Mul", strided_slice_1, addNodeToMatch("Const"));
+
+        int mul_1 = addNodeToMatch("Mul", strided_slice_2, addNodeToMatch("Const"));
+        int pack = addNodeToMatch("Pack", strided_slice, mul, mul_1, addNodeToMatch("Const"));
+        addNodeToMatch("Conv2DBackpropInput", pack, kernel, input);
+        // Put any unused Const op to the first input.
+        setFusedNode("Conv2DBackpropInput", stack, kernel, input);
+    }
+
+    virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
+                          std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
+    {
+        // Disable adjusted paddings (see Conv2DBackpropInput layer at tf_importer.cpp)
+        // adj_w = (outW - (pad == "SAME") ? 1 : kernelW) % strideX;
+        // adj_h = (outH - (pad == "SAME") ? 1 : kernelH) % strideY;
+        // Where outH and outW are 1st and 2nd dimensions (NHWC) or 2nd and third (NCHW).
+        std::string padMode = fusedNode->attr().at("padding").s();
+        CV_Assert(padMode == "SAME");
+
+        const tensorflow::AttrValue_ListValue& strides = fusedNode->attr().at("strides").list();
+        CV_Assert(strides.i_size() == 4);
+
+        const int strideY = strides.i(1);
+        const int strideX = strides.i(2);
+
+        tensorflow::TensorProto* outShape = inputNodes[0]->mutable_attr()->at("value").mutable_tensor();
+        outShape->clear_int_val();
+        outShape->add_int_val(-1);
+        outShape->add_int_val(strideY);
+        outShape->add_int_val(strideX);
+        outShape->add_int_val(-1);
+    }
+};
+
 void simplifySubgraphs(tensorflow::GraphDef& net)
 {
     std::vector<Ptr<Subgraph> > subgraphs;
@@ -430,6 +549,8 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
     subgraphs.push_back(Ptr<Subgraph>(new ReLU6KerasSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new ReshapeKerasSubgraph(3)));
     subgraphs.push_back(Ptr<Subgraph>(new L2NormalizeSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionValidKerasSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new DeconvolutionSameKerasSubgraph()));
 
     int numNodes = net.node_size();
     std::vector<int> matchedNodesIds;
index f580916..a401f71 100644 (file)
@@ -1303,8 +1303,8 @@ void TFImporter::populateNet(Net dstNet)
             const int strideY = layerParams.get<int>("stride_h");
             const int strideX = layerParams.get<int>("stride_w");
             Mat outShape = getTensorContent(getConstBlob(layer, value_id, 0));
-            const int outH = outShape.at<int>(2);
-            const int outW = outShape.at<int>(1);
+            const int outH = outShape.at<int>(1);
+            const int outW = outShape.at<int>(2);
             if (layerParams.get<String>("pad_mode") == "SAME")
             {
                 layerParams.set("adj_w", (outW - 1) % strideX);
index 43e30eb..b7f3c5c 100644 (file)
@@ -173,6 +173,8 @@ TEST_P(Test_TensorFlow_layers, deconvolution)
     runTensorFlowNet("deconvolution_stride_2_same", targetId);
     runTensorFlowNet("deconvolution_adj_pad_valid", targetId);
     runTensorFlowNet("deconvolution_adj_pad_same", targetId);
+    runTensorFlowNet("keras_deconv_valid", targetId);
+    runTensorFlowNet("keras_deconv_same", targetId);
 }
 
 TEST_P(Test_TensorFlow_layers, matmul)