Fuse batch normalization and flatten TensorFlow subgraphs at runtime
author     Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
           Wed, 21 Feb 2018 16:52:48 +0000 (19:52 +0300)
committer  Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
           Mon, 12 Mar 2018 07:51:35 +0000 (10:51 +0300)
modules/dnn/src/tensorflow/tf_graph_editor.cpp [new file with mode: 0644]
modules/dnn/src/tensorflow/tf_graph_editor.hpp [new file with mode: 0644]
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

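This patch teaches the TensorFlow importer to rewrite the imported GraphDef before layers are created: unfused batch-normalization patterns are collapsed into a single FusedBatchNorm node, and the Shape/StridedSlice/Pack/Reshape sequence produced by tf.contrib.layers.flatten is collapsed into a Flatten node. The rewrite runs only when no .pbtxt text graph is supplied. A minimal usage sketch from the caller's side (the model path and input size are placeholders; the fusion itself is transparent to the user):

#include <opencv2/dnn.hpp>

int main()
{
    // Loading a frozen .pb without a .pbtxt config triggers simplifySubgraphs()
    // inside TFImporter::populateNet(), so unfused batch norm and flatten
    // patterns are fused during import.
    cv::dnn::Net net = cv::dnn::readNetFromTensorflow("frozen_graph.pb");  // hypothetical path

    cv::Mat img(224, 224, CV_8UC3, cv::Scalar::all(127));                  // dummy input
    net.setInput(cv::dnn::blobFromImage(img, 1.0 / 255));
    cv::Mat out = net.forward();
    return 0;
}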
diff --git a/modules/dnn/src/tensorflow/tf_graph_editor.cpp b/modules/dnn/src/tensorflow/tf_graph_editor.cpp
new file mode 100644
index 0000000..6e841f2
--- /dev/null
@@ -0,0 +1,434 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2018, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+
+#ifdef HAVE_PROTOBUF
+
+#include "tf_graph_editor.hpp"
+
+namespace cv { namespace dnn {
+CV__DNN_EXPERIMENTAL_NS_BEGIN
+
+using ::google::protobuf::RepeatedField;
+using ::google::protobuf::MapPair;
+
+class Subgraph  // Interface to match and replace TensorFlow subgraphs.
+{
+public:
+    // Add a node to be matched in the original graph. Specify the ids of nodes
+    // that are expected to be its inputs. Returns the id of the newly added node.
+    // TODO: Replace the separate input arguments with a std::vector<int> once C++11 is available.
+    int addNodeToMatch(const std::string& op, int input_0 = -1, int input_1 = -1,
+                       int input_2 = -1, int input_3 = -1)
+    {
+        int nodeInputs[] = {input_0, input_1, input_2, input_3};
+        int numInputs = 0;
+        for (int i = 0; i < 4; ++i)
+        {
+            CV_Assert(nodeInputs[i] < (int)nodes.size());
+            numInputs += (int)(nodeInputs[i] != -1);
+        }
+        nodes.push_back(op);
+        inputs.push_back(std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs));
+        return nodes.size() - 1;
+    }
+
+    // Specify the resulting node. All matched nodes in the subgraph, excluding
+    // the input nodes, will be fused into this single node.
+    // TODO: Replace the separate input arguments with a std::vector<int> once C++11 is available.
+    void setFusedNode(const std::string& op, int input_0 = -1, int input_1 = -1,
+                      int input_2 = -1, int input_3 = -1, int input_4 = -1,
+                      int input_5 = -1)
+    {
+        int nodeInputs[] = {input_0, input_1, input_2, input_3, input_4, input_5};
+        int numInputs = 0;
+        for (int i = 0; i < 6; ++i)
+        {
+            CV_Assert(nodeInputs[i] < (int)nodes.size());
+            numInputs += (int)(nodeInputs[i] != -1);
+        }
+        fusedNodeInputs = std::vector<int>(&nodeInputs[0], &nodeInputs[0] + numInputs);
+
+        fusedNodeOp = op;
+        nodesToFuse.clear();
+        for (int i = 0; i < nodes.size(); ++i)
+        {
+            if (std::find(fusedNodeInputs.begin(), fusedNodeInputs.end(), i) == fusedNodeInputs.end())
+                nodesToFuse.push_back(i);
+        }
+    }
+
+    static const tensorflow::NodeDef& getInputNode(const tensorflow::GraphDef& net,
+                                                   const tensorflow::NodeDef& node,
+                                                   int inpId)
+    {
+        CV_Assert(inpId < node.input_size());
+        std::string name = node.input(inpId);
+        const int numNodes = net.node_size();
+        for (int i = 0; i < numNodes; ++i)
+        {
+            const tensorflow::NodeDef& node = net.node(i);
+            if (node.name() == name)
+                return node;
+        }
+        CV_Error(Error::StsParseError, "Input node with name " + name + " not found");
+        return net.node(0);  // just return something
+    }
+
+    // Try to match the TensorFlow subgraph starting from node <nodeId> against the set of nodes to be fused.
+    // Returns true if the nodes are matched and can be fused.
+    bool match(const tensorflow::GraphDef& net, int nodeId, int* numMatchedNodes)
+    {
+        *numMatchedNodes = 0;
+        int numNodes = net.node_size();
+        for (int i = 0; i < nodesToFuse.size(); ++i)
+        {
+            if (nodeId + i > numNodes - 1)
+                return false;
+
+            const tensorflow::NodeDef &node = net.node(nodeId + i);
+            if (node.op() != nodes[nodesToFuse[i]])
+                return false;
+
+            std::vector<int>& inputNodes = inputs[nodesToFuse[i]];
+            if (inputNodes.size() != node.input_size())
+                return false;
+            for (int j = 0; j < inputNodes.size(); ++j)
+            {
+                if (nodes[inputNodes[j]].empty())  // Unknown input node type.
+                    continue;
+                const tensorflow::NodeDef& inpNode = getInputNode(net, node, j);
+                if (inpNode.op() != nodes[inputNodes[j]])
+                    return false;
+            }
+
+            *numMatchedNodes += 1;
+        }
+        return true;
+    }
+
+    // Fuse matched subgraph.
+    void replace(tensorflow::GraphDef& net, int nodeId, int* numReplacedNodes)
+    {
+        *numReplacedNodes = 0;
+
+        // Extract names of input nodes.
+        std::vector<std::string> inputsNames(fusedNodeInputs.size());
+        for (int i = 0; i < fusedNodeInputs.size(); ++i)
+        {
+            std::string inpName;
+            // Find the input node name by looking at the inputs of the fused nodes.
+            for (int j = 0; j < nodesToFuse.size() && inpName.empty(); ++j)
+            {
+                const tensorflow::NodeDef &node = net.node(nodeId + j);
+                std::vector<int>& inpIndices = inputs[nodesToFuse[j]];
+
+                CV_Assert(node.input_size() == inpIndices.size());
+                for (int k = 0; k < inpIndices.size(); ++k)
+                {
+                    if (inpIndices[k] == fusedNodeInputs[i])
+                    {
+                        inpName = node.input(k);
+                        break;
+                    }
+                }
+            }
+            CV_Assert(!inpName.empty());
+            inputsNames[i] = inpName;
+        }
+
+        // Remove all matched nodes except the last one.
+        *numReplacedNodes = nodesToFuse.size() - 1;
+        net.mutable_node()->DeleteSubrange(nodeId, *numReplacedNodes);
+
+        // Modify the last node to be a fused one.
+        tensorflow::NodeDef* node = net.mutable_node(nodeId);
+        node->set_op(fusedNodeOp);
+        node->clear_input();
+        for (int i = 0; i < inputsNames.size(); ++i)
+        {
+            node->add_input(inputsNames[i]);
+        }
+
+        std::vector<tensorflow::NodeDef> inputNodes(inputsNames.size());
+        for (int i = 0; i < inputsNames.size(); ++i)
+        {
+            inputNodes[i] = getInputNode(net, *node, i);
+        }
+        finalize(net, node, inputNodes);
+    }
+
+    virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef*,
+                          const std::vector<tensorflow::NodeDef>&) {}
+
+private:
+    std::vector<std::string> nodes;         // Nodes to be matched in the original graph.
+    std::vector<std::vector<int> > inputs;  // Connections of every node to its inputs.
+
+    std::string fusedNodeOp;           // Operation name of resulting fused node.
+    std::vector<int> nodesToFuse;      // Set of nodes to be fused.
+    std::vector<int> fusedNodeInputs;  // Inputs of fused node.
+};
+
+class BatchNormSubgraph : public Subgraph
+{
+public:
+    BatchNormSubgraph()
+    {
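+        // Matches the expanded form of batch normalization,
+        //   y = x * (gamma / sqrt(variance + epsilon)) + (beta - mean * gamma / sqrt(variance + epsilon)),
+        // expressed below as a chain of Add, Rsqrt, Mul, Mul, Mul, Sub, Add nodes.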
+        int input = addNodeToMatch("");
+        int epsilon = addNodeToMatch("Const");
+        int moving_variance = addNodeToMatch("Const");
+        int moving_mean = addNodeToMatch("Const");
+        int beta = addNodeToMatch("Const");
+        int gamma = addNodeToMatch("Const");
+        int add = addNodeToMatch("Add", moving_variance, epsilon);
+        int rsqrt = addNodeToMatch("Rsqrt", add);
+        int mul = addNodeToMatch("Mul", rsqrt, gamma);
+        int mul_1 = addNodeToMatch("Mul", input, mul);
+        int mul_2 = addNodeToMatch("Mul", moving_mean, mul);
+        int sub = addNodeToMatch("Sub", beta, mul_2);
+        addNodeToMatch("Add", mul_1, sub);
+
+        setFusedNode("FusedBatchNorm", input, gamma, beta, moving_mean, moving_variance, epsilon);
+    }
+
+    virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
+                          const std::vector<tensorflow::NodeDef>& inputNodes)
+    {
+        Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor());
+        CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1);
+
+        fusedNode->mutable_input()->ReleaseLast();
+        fusedNode->clear_attr();
+        tensorflow::AttrValue epsilon;
+        epsilon.set_f(epsMat.at<float>(0));
+        fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
+    }
+};
+
+class BatchNormNoGammaSubgraph : public Subgraph
+{
+public:
+    BatchNormNoGammaSubgraph()
+    {
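+        // Same pattern as BatchNormSubgraph but without the gamma (scale) factor:
+        //   y = x / sqrt(variance + epsilon) + (beta - mean / sqrt(variance + epsilon)).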
+        int input = addNodeToMatch("");
+        int epsilon = addNodeToMatch("Const");
+        int moving_variance = addNodeToMatch("Const");
+        int moving_mean = addNodeToMatch("Const");
+        int beta = addNodeToMatch("Const");
+        int add = addNodeToMatch("Add", moving_variance, epsilon);
+        int rsqrt = addNodeToMatch("Rsqrt", add);
+        int mul = addNodeToMatch("Mul", input, rsqrt);
+        int mul_1 = addNodeToMatch("Mul", moving_mean, rsqrt);
+        int sub = addNodeToMatch("Sub", beta, mul_1);
+        addNodeToMatch("Add", mul, sub);
+
+        // The second reference to beta is a placeholder that will be replaced by a newly created gamma tensor.
+        setFusedNode("FusedBatchNorm", input, beta, beta, moving_mean, moving_variance, epsilon);
+    }
+
+    virtual void finalize(tensorflow::GraphDef& net, tensorflow::NodeDef* fusedNode,
+                          const std::vector<tensorflow::NodeDef>& inputNodes)
+    {
+        Mat epsMat = getTensorContent(inputNodes.back().attr().at("value").tensor());
+        CV_Assert(epsMat.total() == 1, epsMat.type() == CV_32FC1);
+
+        fusedNode->mutable_input()->ReleaseLast();
+        fusedNode->clear_attr();
+        tensorflow::AttrValue epsilon;
+        epsilon.set_f(epsMat.at<float>(0));
+        fusedNode->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("epsilon", epsilon));
+
+        tensorflow::NodeDef* gamma = net.add_node();
+        gamma->set_op("Const");
+        gamma->set_name(fusedNode->name() + "/gamma");
+        // Put a single value so this node is recognized as a Const.
+        gamma->mutable_attr()->insert(MapPair<std::string, tensorflow::AttrValue>("value", epsilon));
+        fusedNode->set_input(1, gamma->name());
+    }
+};
+
+// tf.contrib.layers.flatten
+class FlattenSubgraph : public Subgraph
+{
+public:
+    FlattenSubgraph()
+    {
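+        // Here the input shape is baked into the graph as a Const node;
+        // FlattenShapeSubgraph below handles the case where it is computed
+        // by a Shape op (unknown batch size).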
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Const");
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+        int shape_pack = addNodeToMatch("Const");
+        int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
+        addNodeToMatch("Reshape", input, pack);
+
+        setFusedNode("Flatten", input);
+    }
+};
+
+// tf.contrib.layers.flatten in case of unknown batch size
+class FlattenShapeSubgraph : public Subgraph
+{
+public:
+    FlattenShapeSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int shape = addNodeToMatch("Shape", input);
+        int stack = addNodeToMatch("Const");
+        int stack_1 = addNodeToMatch("Const");
+        int stack_2 = addNodeToMatch("Const");
+        int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
+        int shape_pack = addNodeToMatch("Const");
+        int pack = addNodeToMatch("Pack", strided_slice, shape_pack);
+        addNodeToMatch("Reshape", input, pack);
+
+        setFusedNode("Flatten", input);
+    }
+};
+
+void simplifySubgraphs(tensorflow::GraphDef& net)
+{
+    std::vector<Ptr<Subgraph> > subgraphs;
+    subgraphs.push_back(Ptr<Subgraph>(new BatchNormSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new BatchNormNoGammaSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new FlattenSubgraph()));
+    subgraphs.push_back(Ptr<Subgraph>(new FlattenShapeSubgraph()));
+
+    int numNodes = net.node_size();
+    int numMatchedNodes, numReplacedNodes;
+    for (int i = 0; i < numNodes; ++i)
+    {
+        for (int j = 0; j < subgraphs.size(); ++j)
+        {
+            if (subgraphs[j]->match(net, i, &numMatchedNodes))
+            {
+                subgraphs[j]->replace(net, i, &numReplacedNodes);
+                numNodes -= numReplacedNodes;
+                break;
+            }
+        }
+    }
+}
+
+void RemoveIdentityOps(tensorflow::GraphDef& net)
+{
+    typedef std::map<String, String>  IdentityOpsMap;
+    IdentityOpsMap identity_ops;
+
+    std::vector<int> identity_ops_idx;
+
+    int layersCount = net.node_size();
+    for (int li = 0; li < layersCount; li++)
+    {
+        const tensorflow::NodeDef &layer = net.node(li);
+        String type = layer.op();
+
+        if (type == "Identity" || type == "Dropout") {
+            identity_ops_idx.push_back(li);
+            identity_ops[layer.name()] = layer.input(0);
+        }
+    }
+
+    for (int li = 0; li < layersCount; li++)
+    {
+        tensorflow::NodeDef* layer = net.mutable_node(li);
+        for (int input_id = 0; input_id < layer->input_size(); input_id++) {
+            String input_op_name = layer->input(input_id);
+            IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
+
+            if (it != identity_ops.end()) {
+                layer->set_input(input_id, it->second);
+            }
+        }
+    }
+
+    std::sort(identity_ops_idx.begin(), identity_ops_idx.end());
+
+    int removed_nodes = 0;
+    for(size_t i = 0; i < identity_ops_idx.size(); i++) {
+        int start_id = identity_ops_idx[i] - removed_nodes;
+        net.mutable_node()->DeleteSubrange(start_id, 1);
+        removed_nodes++;
+    }
+}
+
+Mat getTensorContent(const tensorflow::TensorProto &tensor)
+{
+    std::string content = tensor.tensor_content();
+    switch (tensor.dtype())
+    {
+        case tensorflow::DT_FLOAT:
+        {
+            if (!content.empty())
+                return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
+            else
+            {
+                const RepeatedField<float>& field = tensor.float_val();
+                CV_Assert(!field.empty());
+                return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
+            }
+        }
+        case tensorflow::DT_DOUBLE:
+        {
+            if (!content.empty())
+                return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
+            else
+            {
+                const RepeatedField<double>& field = tensor.double_val();
+                CV_Assert(!field.empty());
+                return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
+            }
+        }
+        case tensorflow::DT_INT32:
+        {
+            if (!content.empty())
+                return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
+            else
+            {
+                const RepeatedField<int32_t>& field = tensor.int_val();
+                CV_Assert(!field.empty());
+                return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
+            }
+        }
+        case tensorflow::DT_HALF:
+        {
+            Mat halfs;
+            if (!content.empty())
+            {
+                static const int kHalfSize = 2;
+                halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
+            }
+            else
+            {
+                const RepeatedField<int32_t>& field = tensor.half_val();
+                CV_Assert(!field.empty());
+                Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
+                ints.convertTo(halfs, CV_16UC1);
+            }
+            // Reinterpret as signed shorts just for the convertFp16 call.
+            Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
+            Mat floats(halfs.size(), CV_32FC1);
+            convertFp16(halfsSigned, floats);
+            return floats;
+        }
+        case tensorflow::DT_QUINT8:
+        {
+            CV_Assert(!content.empty());
+            return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
+        }
+        default:
+            CV_Error(Error::StsError, "Tensor's data type is not supported");
+            break;
+    }
+    return Mat();
+}
+
+CV__DNN_EXPERIMENTAL_NS_END
+}}  // namespace dnn, namespace cv
+
+#endif  // HAVE_PROTOBUF
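Registering an additional pattern only requires subclassing Subgraph above and appending an instance to the list built in simplifySubgraphs(). The sketch below is purely illustrative and not part of this patch: the "Swish" op name and the x * sigmoid(x) pattern are hypothetical, and the matcher still requires the nodes to appear consecutively in the GraphDef.

// Hypothetical extra pattern, shown only to illustrate the Subgraph API.
class SwishSubgraph : public Subgraph
{
public:
    SwishSubgraph()
    {
        int input = addNodeToMatch("");                 // any producer of x
        int sigmoid = addNodeToMatch("Sigmoid", input);
        addNodeToMatch("Mul", input, sigmoid);          // x * sigmoid(x)

        setFusedNode("Swish", input);                   // fuse into one node
    }
};

// ...and inside simplifySubgraphs():
// subgraphs.push_back(Ptr<Subgraph>(new SwishSubgraph()));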
diff --git a/modules/dnn/src/tensorflow/tf_graph_editor.hpp b/modules/dnn/src/tensorflow/tf_graph_editor.hpp
new file mode 100644
index 0000000..5568c09
--- /dev/null
@@ -0,0 +1,30 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+// Copyright (C) 2018, Intel Corporation, all rights reserved.
+// Third party copyrights are property of their respective owners.
+
+#ifndef __OPENCV_DNN_TF_SIMPLIFIER_HPP__
+#define __OPENCV_DNN_TF_SIMPLIFIER_HPP__
+
+#include "../precomp.hpp"
+
+#ifdef HAVE_PROTOBUF
+
+#include "tf_io.hpp"
+
+namespace cv { namespace dnn {
+CV__DNN_EXPERIMENTAL_NS_BEGIN
+
+void RemoveIdentityOps(tensorflow::GraphDef& net);
+
+void simplifySubgraphs(tensorflow::GraphDef& net);
+
+Mat getTensorContent(const tensorflow::TensorProto &tensor);
+
+CV__DNN_EXPERIMENTAL_NS_END
+}}  // namespace dnn, namespace cv
+
+#endif  // HAVE_PROTOBUF
+#endif  // __OPENCV_DNN_TF_SIMPLIFIER_HPP__
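For context, the three functions declared here are invoked from TFImporter::populateNet() in roughly the following order (see the tf_importer.cpp hunk below); netBin is the binary GraphDef and netTxt the optional text graph:

RemoveIdentityOps(netBin);
RemoveIdentityOps(netTxt);

if (!netTxt.ByteSize())          // no .pbtxt supplied
    simplifySubgraphs(netBin);   // fuse batch norm / flatten patterns in place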
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index 5309ec4..9be29b9 100644
@@ -22,6 +22,7 @@ Implementation of Tensorflow models parser
 #include <google/protobuf/text_format.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
 #include "tf_io.hpp"
+#include "tf_graph_editor.hpp"
 #endif
 
 namespace cv {
@@ -87,77 +88,6 @@ void blobShapeFromTensor(const tensorflow::TensorProto &tensor, MatShape& shape)
     }
 }
 
-static Mat getTensorContent(const tensorflow::TensorProto &tensor)
-{
-    std::string content = tensor.tensor_content();
-    switch (tensor.dtype())
-    {
-        case tensorflow::DT_FLOAT:
-        {
-            if (!content.empty())
-                return Mat(1, content.size() / sizeof(float), CV_32FC1, (void*)content.c_str()).clone();
-            else
-            {
-                const RepeatedField<float>& field = tensor.float_val();
-                CV_Assert(!field.empty());
-                return Mat(1, field.size(), CV_32FC1, (void*)field.data()).clone();
-            }
-        }
-        case tensorflow::DT_DOUBLE:
-        {
-            if (!content.empty())
-                return Mat(1, content.size() / sizeof(double), CV_64FC1, (void*)content.c_str()).clone();
-            else
-            {
-                const RepeatedField<double>& field = tensor.double_val();
-                CV_Assert(!field.empty());
-                return Mat(1, field.size(), CV_64FC1, (void*)field.data()).clone();
-            }
-        }
-        case tensorflow::DT_INT32:
-        {
-            if (!content.empty())
-                return Mat(1, content.size() / sizeof(int32_t), CV_32SC1, (void*)content.c_str()).clone();
-            else
-            {
-                const RepeatedField<int32_t>& field = tensor.int_val();
-                CV_Assert(!field.empty());
-                return Mat(1, field.size(), CV_32SC1, (void*)field.data()).clone();
-            }
-        }
-        case tensorflow::DT_HALF:
-        {
-            Mat halfs;
-            if (!content.empty())
-            {
-                static const int kHalfSize = 2;
-                halfs = Mat(1, content.size() / kHalfSize, CV_16UC1, (void*)content.c_str());
-            }
-            else
-            {
-                const RepeatedField<int32_t>& field = tensor.half_val();
-                CV_Assert(!field.empty());
-                Mat ints(1, field.size(), CV_32SC1, (void*)field.data());
-                ints.convertTo(halfs, CV_16UC1);
-            }
-            // Reinterpret as a signed shorts just for a convertFp16 call.
-            Mat halfsSigned(halfs.size(), CV_16SC1, halfs.data);
-            Mat floats(halfs.size(), CV_32FC1);
-            convertFp16(halfsSigned, floats);
-            return floats;
-        }
-        case tensorflow::DT_QUINT8:
-        {
-            CV_Assert(!content.empty());
-            return Mat(1, content.size(), CV_8UC1, (void*)content.c_str()).clone();
-        }
-        default:
-            CV_Error(Error::StsError, "Tensor's data type is not supported");
-            break;
-    }
-    return Mat();
-}
-
 template <typename T>
 void parseTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
 {
@@ -364,47 +294,6 @@ void setPadding(LayerParams &layerParams, const tensorflow::NodeDef &layer)
         layerParams.set("pad_mode", getLayerAttr(layer, "padding").s());
 }
 
-void RemoveIdentityOps(tensorflow::GraphDef& net) {
-    typedef std::map<String, String>  IdentityOpsMap;
-    IdentityOpsMap identity_ops;
-
-    std::vector<int> identity_ops_idx;
-
-    int layersCount = net.node_size();
-    for (int li = 0; li < layersCount; li++)
-    {
-        const tensorflow::NodeDef &layer = net.node(li);
-        String type = layer.op();
-
-        if (type == "Identity" || type == "Dropout") {
-            identity_ops_idx.push_back(li);
-            identity_ops[layer.name()] = layer.input(0);
-        }
-    }
-
-    for (int li = 0; li < layersCount; li++)
-    {
-        tensorflow::NodeDef* layer = net.mutable_node(li);
-        for (int input_id = 0; input_id < layer->input_size(); input_id++) {
-            String input_op_name = layer->input(input_id);
-            IdentityOpsMap::iterator it = identity_ops.find(input_op_name);
-
-            if (it != identity_ops.end()) {
-                layer->set_input(input_id, it->second);
-            }
-        }
-    }
-
-    std::sort(identity_ops_idx.begin(), identity_ops_idx.end());
-
-    int removed_nodes = 0;
-    for(size_t i = 0; i < identity_ops_idx.size(); i++) {
-        int start_id = identity_ops_idx[i] - removed_nodes;
-        net.mutable_node()->DeleteSubrange(start_id, 1);
-        removed_nodes++;
-    }
-}
-
 Pin parsePin(const std::string &name)
 {
     Pin pin(name);
@@ -697,6 +586,9 @@ void TFImporter::populateNet(Net dstNet)
     RemoveIdentityOps(netBin);
     RemoveIdentityOps(netTxt);
 
+    if (!netTxt.ByteSize())
+        simplifySubgraphs(netBin);
+
     std::set<String> layers_to_ignore;
 
     tensorflow::GraphDef& net = netTxt.ByteSize() != 0 ? netTxt : netBin;
@@ -936,10 +828,28 @@ void TFImporter::populateNet(Net dstNet)
             connect(layer_id, dstNet, inpId, id, 0);
             data_layouts[name] = DATA_LAYOUT_UNKNOWN;
         }
-        else if (type == "Flatten")
+        else if (type == "Flatten" || type == "Squeeze")
         {
             Pin inpId = parsePin(layer.input(0));
-            if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            int inpLayout = data_layouts[layer.input(0)];
+            if (type == "Squeeze")
+            {
+                CV_Assert(hasLayerAttr(layer, "squeeze_dims"));
+                const tensorflow::AttrValue& dims = getLayerAttr(layer, "squeeze_dims");
+                if (inpLayout == DATA_LAYOUT_NHWC)
+                {
+                    if (dims.list().i_size() != 2 || dims.list().i(0) != 1 || dims.list().i(1) != 2)
+                        CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration");
+                }
+                else if (inpLayout == DATA_LAYOUT_NCHW)
+                {
+                    if (dims.list().i_size() != 2 || dims.list().i(0) != 2 || dims.list().i(1) != 3)
+                        CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration");
+                }
+                else
+                    CV_Error(Error::StsNotImplemented, "Unsupported squeeze configuration");
+            }
+            if (inpLayout == DATA_LAYOUT_NHWC)
             {
                 LayerParams permLP;
                 int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
@@ -1274,14 +1184,36 @@ void TFImporter::populateNet(Net dstNet)
 
             bool isTraining = hasLayerAttr(layer, "is_training") && getLayerAttr(layer, "is_training").b();
 
-            layerParams.blobs.resize(4);
-            Mat gamma, beta, mean, std;
-            blobFromTensor(getConstBlob(layer, value_id, 1), gamma);
-            blobFromTensor(getConstBlob(layer, value_id, 2), beta);
+            layerParams.blobs.resize(2);
+
+            const tensorflow::TensorProto& gammaTensor = getConstBlob(layer, value_id, 1);
+            if (!gammaTensor.tensor_content().empty())
+            {
+                layerParams.blobs.resize(layerParams.blobs.size() + 1);
+                layerParams.set("has_weight", true);
+                blobFromTensor(gammaTensor, layerParams.blobs.back());
+            }
+            else
+                layerParams.set("has_weight", false);
+
+            const tensorflow::TensorProto& betaTensor = getConstBlob(layer, value_id, 2);
+            if (!betaTensor.tensor_content().empty())
+            {
+                layerParams.blobs.resize(layerParams.blobs.size() + 1);
+                layerParams.set("has_bias", true);
+                blobFromTensor(betaTensor, layerParams.blobs.back());
+            }
+            else
+                layerParams.set("has_bias", false);
+
+            Mat mean, std;
             if (isTraining)
             {
-                mean = Mat::zeros(1, beta.total(), CV_32F);
-                std = Mat::ones(1, beta.total(), CV_32F);
+                if (layerParams.blobs.size() == 2)
+                    CV_Error(Error::StsNotImplemented, "Cannot determine number "
+                             "of parameters for batch normalization layer.");
+                mean = Mat::zeros(1, layerParams.blobs[3].total(), CV_32F);
+                std = Mat::ones(1, layerParams.blobs[3].total(), CV_32F);
 
                 // Add an extra layer: Mean-Variance normalization
                 LayerParams mvnParams;
@@ -1299,15 +1231,10 @@ void TFImporter::populateNet(Net dstNet)
             }
             layerParams.blobs[0] = mean;
             layerParams.blobs[1] = std;
-            layerParams.blobs[2] = gamma;
-            layerParams.blobs[3] = beta;
 
             if (hasLayerAttr(layer, "epsilon"))
                 layerParams.set("eps", getLayerAttr(layer, "epsilon").f());
 
-            layerParams.set("has_weight", true);
-            layerParams.set("has_bias", true);
-
             int id = dstNet.addLayer(name, "BatchNorm", layerParams);
             layer_id[name] = id;
 
diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp
index b3b9959..b5c9567 100644
@@ -150,6 +150,9 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
     runTensorFlowNet("batch_norm_text", targetId, true);
     runTensorFlowNet("mvn_batch_norm", targetId);
     runTensorFlowNet("mvn_batch_norm_1x1", targetId);
+    runTensorFlowNet("unfused_batch_norm", targetId);
+    runTensorFlowNet("fused_batch_norm_no_gamma", targetId);
+    runTensorFlowNet("unfused_batch_norm_no_gamma", targetId);
 }
 
 TEST_P(Test_TensorFlow_layers, pooling)
@@ -185,6 +188,8 @@ TEST_P(Test_TensorFlow_layers, reshape)
     runTensorFlowNet("shift_reshape_no_reorder", targetId);
     runTensorFlowNet("reshape_reduce", targetId);
     runTensorFlowNet("flatten", targetId, true);
+    runTensorFlowNet("unfused_flatten", targetId);
+    runTensorFlowNet("unfused_flatten_unknown_batch", targetId);
 }
 
 INSTANTIATE_TEST_CASE_P(/**/, Test_TensorFlow_layers, availableDnnTargets());