Fuse subgraphs from Keras
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Thu, 1 Mar 2018 16:47:50 +0000 (19:47 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Mon, 12 Mar 2018 07:53:06 +0000 (10:53 +0300)
modules/dnn/src/tensorflow/tf_graph_simplifier.cpp [moved from modules/dnn/src/tensorflow/tf_graph_editor.cpp with 70% similarity]
modules/dnn/src/tensorflow/tf_graph_simplifier.hpp [moved from modules/dnn/src/tensorflow/tf_graph_editor.hpp with 100% similarity]
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

@@ -7,7 +7,7 @@
 
 #ifdef HAVE_PROTOBUF
 
-#include "tf_graph_editor.hpp"
+#include "tf_graph_simplifier.hpp"
 
 namespace cv { namespace dnn {
 CV__DNN_EXPERIMENTAL_NS_BEGIN
@@ -28,11 +28,19 @@ public:
         int numInputs = 0;
         for (int i = 0; i < 4; ++i)
         {
-            CV_Assert(nodeInputs[i] < (int)nodes.size());
             numInputs += (int)(nodeInputs[i] != -1);
         }
+        return addNodeToMatch(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
+    }
+
+    int addNodeToMatch(const std::string& op, const std::vector<int>& inputs_)
+    {
+        for (int i = 0; i < inputs_.size(); ++i)
+        {
+            CV_Assert(inputs_[i] < (int)nodes.size());
+        }
         nodes.push_back(op);
-        inputs.push_back(std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
+        inputs.push_back(inputs_);
         return nodes.size() - 1;
     }
 
@@ -50,13 +58,18 @@ public:
             CV_Assert(nodeInputs[i] < (int)nodes.size());
             numInputs += (int)(nodeInputs[i] != -1);
         }
-        fusedNodeInputs = std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs);
+        setFusedNode(op, std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
+    }
 
+    void setFusedNode(const std::string& op, const std::vector<int>& inputs_)
+    {
+        fusedNodeInputs = inputs_;
         fusedNodeOp = op;
         nodesToFuse.clear();
         for (int i = 0; i < nodes.size(); ++i)
         {
-            if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end())
+            if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end() &&
+                nodes[i] != "Const")
                 nodesToFuse.push_back(i);
         }
     }
@@ -70,26 +83,32 @@ public:
         const int numNodes = net.node_size();
         for (int i = 0; i < numNodes; ++i)
         {
-            const tensorflow::NodeDef& node = net.node(i);
-            if (node.name() == name)
-                return node;
+            if (net.node(i).name() == name)
+                return net.node(i);
         }
         CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
         return net.node(0);  // just return something
     }
 
     // Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
-    // Returns true if nodes are matched and can be fused.
-    bool match(const tensorflow::GraphDef& net, int nodeId, int* numMatchedNodes)
+    // Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
+    virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds)
     {
-        *numMatchedNodes = 0;
+        matchedNodesIds.clear();
+        matchedNodesIds.reserve(nodesToFuse.size());
+
         int numNodes = net.node_size();
         for (int i = 0; i < nodesToFuse.size(); ++i)
         {
-            if (nodeId + i > numNodes - 1)
+            while (nodeId < numNodes && net.node(nodeId).op() == "Const")
+            {
+                nodeId += 1;
+            }
+            if (nodeId > numNodes - 1)
                 return false;
 
-            const tensorflow::NodeDef &node = net.node(nodeId + i);
+            const tensorflow::NodeDef& node = net.node(nodeId);
+
             if (node.op() != nodes[nodesToFuse[i]])
                 return false;
 
@@ -105,25 +124,24 @@ public:
                     return false;
             }
 
-            *numMatchedNodes += 1;
+            matchedNodesIds.push_back(nodeId);
+            nodeId += 1;
         }
         return true;
     }
 
     // Fuse matched subgraph.
-    void replace(tensorflow::GraphDef& net, int nodeId, int* numReplacedNodes)
+    void replace(tensorflow::GraphDef& net, const std::vector<int>& matchedNodesIds)
     {
-        *numReplacedNodes = 0;
-
         // Extract names of input nodes.
         std::vector<std::string> inputsNames(fusedNodeInputs.size());
         for (int i = 0; i < fusedNodeInputs.size(); ++i)
         {
             std::string inpName;
             // Find input node name looking at inputs of fused nodes.
-            for (int j = 0; j < nodesToFuse.size() && inpName.empty(); ++j)
+            for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
             {
-                const tensorflow::NodeDef &node = net.node(nodeId + j);
+                const tensorflow::NodeDef &node = net.node(matchedNodesIds[j]);
                 std::vector<int>& inpIndices = inputs[nodesToFuse[j]];
 
                 CV_Assert(node.input_size() == inpIndices.size());
@@ -140,12 +158,12 @@ public:
             inputsNames[i] = inpName;
         }
 
-        // Remove all nodes except the last one.
-        *numReplacedNodes = nodesToFuse.size() - 1;
-        net.mutable_node()->DeleteSubrange(nodeId, *numReplacedNodes);
+        // Remove matched nodes except the last one. Indices in ascending order are expected.
+        tensorflow::NodeDef* node = net.mutable_node(matchedNodesIds.back());
+        for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
+            net.mutable_node()->DeleteSubrange(matchedNodesIds[i], 1);
 
         // Modify the last node to be a fused one.
-        tensorflow::NodeDef* node = net.mutable_node(nodeId);
         node->set_op(fusedNodeOp);
         node->clear_input();
         for (int i = 0; i < inputsNames.size(); ++i)
@@ -153,16 +171,16 @@ public:
             node->add_input(inputsNames[i]);
         }
 
-        std::vector<tensorflow::NodeDef> inputNodes(inputsNames.size());
+        std::vector<tensorflow::NodeDef*> inputNodes(inputsNames.size());
         for (int i = 0; i < inputsNames.size(); ++i)
         {
-            inputNodes[i] = getInputNode(net, *node, i);
+            inputNodes[i] = (tensorflow::NodeDef*)&getInputNode(net, *node, i);
         }
         finalize(net, node, inputNodes);
     }
 
     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*,
-                          const std::vector<tensorflow::NodeDef>&) {}
+                          std::vector<tensorflow::NodeDef*>&) {}
 
 private:
     std::vector<std::string> nodes;         // Nodes to be matched in the origin graph.
@@ -196,9 +214,9 @@ public:
     }
 
     virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
-                          const std::vector<tensorflow::NodeDef>& inputNodes)
+                          std::vector<tensorflow::NodeDef*>& inputNodes)
     {
-        Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor());
+        Mat epsMat = getTensorContent(inputNodes.back()->attr().at("value").tensor());
         CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1);
 
         fusedNode->mutable_input()->ReleaseLast();
@@ -231,9 +249,9 @@ public:
     }
 
     virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
-                          const std::vector<tensorflow::NodeDef>& inputNodes)
+                          std::vector<tensorflow::NodeDef*>& inputNodes)
     {
-        Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor());
+        Mat epsMat = getTensorContent(inputNodes.back()->attr().at("value").tensor());
         CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1);
 
         fusedNode->mutable_input()->ReleaseLast();
@@ -291,6 +309,97 @@ public:
     }
 };
 
+// K.layers.Softmax
+class SoftMaxKerasSubgraph : public Subgraph
+{
+public:
+    SoftMaxKerasSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int maxReductionIndices = addNodeToMatch("Const");
+        int smMax = addNodeToMatch("Max", input, maxReductionIndices);
+        int smSub = addNodeToMatch("Sub", input, smMax);
+        int smExp = addNodeToMatch("Exp", smSub);
+        int sumReductionIndices = addNodeToMatch("Const");
+        int smSum = addNodeToMatch("Sum", smExp, sumReductionIndices);
+        addNodeToMatch("RealDiv", smExp, smSum);
+
+        setFusedNode("Softmax", input);
+    }
+};
+
+class ReLU6KerasSubgraph : public Subgraph
+{
+public:
+    ReLU6KerasSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int relu = addNodeToMatch("Relu", input);
+        int maxValue = addNodeToMatch("Const");
+        int clipValue = addNodeToMatch("Const");
+        int minimum = addNodeToMatch("Minimum", relu, maxValue);
+        addNodeToMatch("Maximum", minimum, clipValue);
+
+        setFusedNode("Relu6", input);
+    }
+
+    virtual bool match(const tensorflow::GraphDef& net, int nodeId, std::vector<int>& matchedNodesIds)
+    {
+        if (!Subgraph::match(net, nodeId, matchedNodesIds))
+            return false;
+        Mat maxValue = getTensorContent(net.node(nodeId + 1).attr().at("value").tensor());
+        return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
+    }
+};
+
+// Keras' reshape stores output shape in separate Const nodes by one value.
+// Need to merge them into a single Const node.
+class ReshapeKerasSubgraph : public Subgraph
+{
+public:
+    ReshapeKerasSubgraph(int _numOutDims) : numOutDims(_numOutDims)
+    {
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", input);
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+
+        std::vector<int> ids(1 + numOutDims);
+        ids[0] = strided_slice;
+        for (int i = 0; i < numOutDims; ++i)
+            ids[1 + i] = addNodeToMatch("Const");
+        int pack = addNodeToMatch("Pack", ids);
+        addNodeToMatch("Reshape", input, pack);
+
+        ids[0] = input;
+        setFusedNode("Reshape", ids);
+    }
+
+    virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
+                          std::vector<tensorflow::NodeDef*>& inputNodes)
+    {
+        std::vector<int> shape(numOutDims + 1);  // batch size in Keras is implicit.
+        shape[0] = -1;
+        for (int i = 0; i < numOutDims; ++i)
+        {
+            shape[1 + i] = inputNodes[1 + i]->attr().at("value").tensor().int_val(0);
+        }
+        tensorflow::TensorProto* shapeTensor = inputNodes[1]->mutable_attr()->at("value").mutable_tensor();
+        fusedNode->mutable_input()->DeleteSubrange(2, numOutDims - 1);
+
+        shapeTensor->clear_int_val();
+        for (int i = 0; i < shape.size(); ++i)
+        {
+            shapeTensor->add_int_val(shape[i]);
+        }
+    }
+
+private:
+    int numOutDims;
+};
+
 void simplifySubgraphs(tensorflow::GraphDef& net)
 {
     std::vector<Ptr<Subgraph> > subgraphs;
@@ -298,17 +407,20 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
     subgraphs.push_back(Ptr<Subgraph>(new BatchNormNoGammaSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new FlattenSubgraph()));
     subgraphs.push_back(Ptr<Subgraph>(new FlattenShapeSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new SoftMaxKerasSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new ReLU6KerasSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new ReshapeKerasSubgraph(3)));
 
     int numNodes = net.node_size();
-    int numMatchedNodes, numReplacedNodes;
+    std::vector<int> matchedNodesIds;
     for (int i = 0; i < numNodes; ++i)
     {
         for (int j = 0; j < subgraphs.size(); ++j)
         {
-            if (subgraphs[j]->match(net, i, &numMatchedNodes))
+            if (subgraphs[j]->match(net, i, matchedNodesIds))
             {
-                subgraphs[j]->replace(net, i, &numReplacedNodes);
-                numNodes -= numReplacedNodes;
+                subgraphs[j]->replace(net, matchedNodesIds);
+                numNodes -= matchedNodesIds.size() - 1;  // #matchedNodes removed and one added.
                 break;
             }
         }
index 9be29b9..4f7e6f4 100644 (file)
@@ -22,7 +22,7 @@ Implementation of Tensorflow models parser
 #include <google/protobuf/text_format.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include "tf_io.hpp"
-#include "tf_graph_editor.hpp"
+#include "tf_graph_simplifier.hpp"
 #endif
 
 namespace cv {
@@ -715,9 +715,9 @@ void TFImporter::populateNet(Net dstNet)
             if (hasLayerAttr(layer, "data_format"))
             {
                 std::string format = getLayerAttr(layer, "data_format").s();
-                if (format == "NHWC")
+                if (format == "NHWC" || format == "channels_last")
                     data_layouts[name] = DATA_LAYOUT_NHWC;
-                else if (format == "NCHW")
+                else if (format == "NCHW" || format == "channels_first")
                     data_layouts[name] = DATA_LAYOUT_NCHW;
                 else
                     CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
@@ -804,9 +804,9 @@ void TFImporter::populateNet(Net dstNet)
         else if (type == "Reshape")
         {
             Pin inpId = parsePin(layer.input(0));
-            DictValue newShape = parseDims(getConstBlob(layer, value_id, 1));
+            Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
 
-            if (newShape.size() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            if (newShape.total() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
             {
                 LayerParams permLP;
                 int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
@@ -819,14 +819,19 @@ void TFImporter::populateNet(Net dstNet)
                 connect(layer_id, dstNet, inpId, permId, 0);
                 inpId = Pin(permName);
             }
-            layerParams.set("dim", newShape);
+            else if (newShape.total() == 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            {
+                // NHWC->NCHW
+                std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
+                std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
+            }
+            layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShape.total()));
 
             int id = dstNet.addLayer(name, "Reshape", layerParams);
             layer_id[name] = id;
 
             // one input only
             connect(layer_id, dstNet, inpId, id, 0);
-            data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
         else if (type == "Flatten" || type == "Squeeze")
         {
@@ -1488,6 +1493,39 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
         }
+        else if (type == "Mean")
+        {
+            Mat indices = getTensorContent(getConstBlob(layer, value_id, 1));
+            CV_Assert(indices.type() == CV_32SC1);
+
+            if (indices.total() != 2 || indices.at<int>(0) != 1 || indices.at<int>(1) != 2)
+                CV_Error(Error::StsNotImplemented, "Unsupported mode of reduce_mean operation.");
+
+            layerParams.set("pool", "ave");
+            layerParams.set("global_pooling", true);
+
+            int id = dstNet.addLayer(name, "Pooling", layerParams);
+            layer_id[name] = id;
+
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+
+            // There are two attributes, "keepdims" and a deprecated "keep_dims".
+            bool keepDims = false;
+            if (hasLayerAttr(layer, "keepdims"))
+                keepDims = getLayerAttr(layer, "keepdims").b();
+            else if (hasLayerAttr(layer, "keep_dims"))
+                keepDims = getLayerAttr(layer, "keep_dims").b();
+
+            if (!keepDims)
+            {
+                LayerParams flattenLp;
+                std::string flattenName = name + "/flatten";
+                CV_Assert(layer_id.find(flattenName) == layer_id.end());
+                int flattenId = dstNet.addLayer(flattenName, "Flatten", flattenLp);
+                layer_id[flattenName] = flattenId;
+                connect(layer_id, dstNet, Pin(name), flattenId, 0);
+            }
+        }
         else if (type == "Abs" || type == "Tanh" || type == "Sigmoid" ||
                  type == "Relu" || type == "Elu" ||
                  type == "Identity" || type == "Relu6")
index b5c9567..ff21228 100644 (file)
@@ -162,6 +162,7 @@ TEST_P(Test_TensorFlow_layers, pooling)
     runTensorFlowNet("max_pool_odd_valid", targetId);
     runTensorFlowNet("ave_pool_same", targetId);
     runTensorFlowNet("max_pool_odd_same", targetId);
+    runTensorFlowNet("reduce_mean", targetId);  // an average pooling over all spatial dimensions.
 }
 
 TEST_P(Test_TensorFlow_layers, deconvolution)
@@ -337,6 +338,21 @@ TEST(Test_TensorFlow, slice)
     runTensorFlowNet("slice_4d");
 }
 
+TEST(Test_TensorFlow, softmax)
+{
+    runTensorFlowNet("keras_softmax");
+}
+
+TEST(Test_TensorFlow, relu6)
+{
+    runTensorFlowNet("keras_relu6");
+}
+
+TEST(Test_TensorFlow, keras_mobilenet_head)
+{
+    runTensorFlowNet("keras_mobilenet_head");
+}
+
 TEST(Test_TensorFlow, memory_read)
 {
     double l1 = 1e-5;