Aligned TF Reshape layer behaviour
author Anastasia Murzova <anastasia.murzova@xperience.ai>
Sun, 28 Feb 2021 16:55:43 +0000 (19:55 +0300)
committer Anastasia Murzova <anastasia.murzova@xperience.ai>
Thu, 4 Mar 2021 22:01:37 +0000 (01:01 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index c03ac8a..53d62fc 100644 (file)
@@ -295,6 +295,22 @@ DataLayout getDataLayout(
     return it != data_layouts.end() ? it->second : DATA_LAYOUT_UNKNOWN;
 }
 
+static
+bool hasAllOnes(const Mat &inputs, int startPos, int endPos)
+{
+    CV_CheckLE(inputs.dims, 2, "");
+    CV_CheckGE(startPos, 0, "");
+    CV_CheckLE(startPos, endPos, "");
+    CV_CheckLT((size_t)endPos, inputs.total(), "");
+
+    for (int i = startPos; i < endPos; i++)
+    {
+        if (inputs.at<int>(i) != 1 || inputs.at<int>(i)!= -1)
+            return false;
+    }
+    return true;
+}
+
 void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
 {
     if (hasLayerAttr(layer, "strides"))
@@ -490,6 +506,9 @@ protected:
     std::map<String, Mat> sharedWeights;
 
     std::map<String, int> layer_id;
+
+private:
+    void addPermuteLayer(const int* order, const std::string& permName, Pin& inpId);
 };
 
 TFImporter::TFImporter(Net& net, const char *model, const char *config)
@@ -895,6 +914,17 @@ void TFImporter::populateNet()
     CV_LOG_DEBUG(NULL, "DNN/TF: ===================== Import completed =====================");
 }
 
+void TFImporter::addPermuteLayer(const int* order, const std::string& permName, Pin& inpId)  // adds a "Permute" layer named permName after inpId and re-points inpId at it
+{
+    LayerParams permLP;
+    permLP.set("order", DictValue::arrayInt<const int*>(order, 4));  // reads exactly 4 ints from `order`; the caller must supply at least that many
+    CV_Assert(layer_id.find(permName) == layer_id.end());  // the name must not already be registered in layer_id
+    int permId = dstNet.addLayer(permName, "Permute", permLP);
+    layer_id[permName] = permId;
+    connect(layer_id, dstNet, inpId, permId, 0);  // NOTE(review): trailing 0 is presumably the input index of the new layer - confirm against connect()
+    inpId = Pin(permName);  // in/out parameter: the caller's pin now names the permute layer
+}
+
 void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
 {
     tensorflow::NodeDef layer = layer_;
@@ -1276,37 +1306,49 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
             if (value_id.find(layer.input(1)) != value_id.end())
             {
                 Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
-                if (newShape.total() == 4)
+                int newShapeSize = newShape.total();
+                bool hasSwap = false;
+                if (newShapeSize == 4 && hasAllOnes(newShape, 0, 2))
                 {
                     // NHWC->NCHW
                     std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
                     std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
+                    hasSwap = true;
                 }
                 if (inpLayout == DATA_LAYOUT_NHWC)
                 {
-                    if (newShape.total() != 4 || newShape.at<int>(1) == 1)
+                    if (newShapeSize >= 2 || newShape.at<int>(1) == 1)
                     {
-                        LayerParams permLP;
                         int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
-                        permLP.set("order", DictValue::arrayInt<int*>(order, 4));
-
-                        std::string permName = name + "/nchw";
-                        CV_Assert(layer_id.find(permName) == layer_id.end());
-                        int permId = dstNet.addLayer(permName, "Permute", permLP);
-                        layer_id[permName] = permId;
-                        connect(layer_id, dstNet, inpId, permId, 0);
-                        inpId = Pin(permName);
-                        inpLayout = DATA_LAYOUT_NCHW;
+                        addPermuteLayer(order, name + "/nhwc", inpId);
+                        if (newShapeSize < 4)
+                        {
+                            inpLayout = DATA_LAYOUT_NCHW;
+                        }
+                        else
+                        {
+                            inpLayout = DATA_LAYOUT_NHWC;
+                        }
                     }
                 }
-                layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShape.total()));
+                layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShapeSize));
 
                 int id = dstNet.addLayer(name, "Reshape", layerParams);
                 layer_id[name] = id;
 
                 // one input only
                 connect(layer_id, dstNet, inpId, id, 0);
-                data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
+                inpId = Pin(name);
+
+                if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) &&
+                    newShapeSize == 4 && !hasSwap)
+                {
+                    int order[] = {0, 3, 1, 2};  // Transform back to OpenCV's NCHW.
+                    addPermuteLayer(order, name + "/nchw", inpId);
+                    inpLayout = DATA_LAYOUT_NCHW;
+                }
+
+                data_layouts[name] = newShapeSize == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
             }
             else
             {
index 6163e89..6a1a44f 100644 (file)
@@ -457,6 +457,16 @@ TEST_P(Test_TensorFlow_layers, unfused_flatten)
     runTensorFlowNet("unfused_flatten_unknown_batch");
 }
 
+TEST_P(Test_TensorFlow_layers, reshape_layer)  // runs the "reshape_layer" model through runTensorFlowNet
+{
+    runTensorFlowNet("reshape_layer");
+}
+
+TEST_P(Test_TensorFlow_layers, reshape_nchw)  // runs the "reshape_nchw" model through runTensorFlowNet
+{
+    runTensorFlowNet("reshape_nchw");
+}
+
 TEST_P(Test_TensorFlow_layers, leaky_relu)
 {
 #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_EQ(2018050000)