Add Reshape layer tests
author     Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
           Tue, 3 Jul 2018 05:26:43 +0000 (08:26 +0300)
committer  Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
           Tue, 3 Jul 2018 05:26:43 +0000 (08:26 +0300)
modules/dnn/src/layers/reshape_layer.cpp
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/src/torch/torch_importer.cpp
modules/dnn/test/test_layers.cpp
modules/dnn/test/test_tf_importer.cpp

diff --git a/modules/dnn/src/layers/reshape_layer.cpp b/modules/dnn/src/layers/reshape_layer.cpp
index 65a81c7..c9e632d 100644
@@ -82,17 +82,26 @@ static void computeShapeByReshapeMask(const MatShape &srcShape,
         {
             if (matched)
             {
-                if (i == 0 || total(srcShape, i, srcRange.end) != maskTotal)
+                if (total(srcShape, i, srcRange.end) != maskTotal)
                 {
                     srcRange.start = i + 1;
                     break;
                 }
+                else if (i == 0)
+                {
+                    srcRange.start = 0;
+                    break;
+                }
             }
             else
             {
                 matched = total(srcShape, i, srcRange.end) == maskTotal;
             }
         }
+        while (total(srcShape, srcRange.start, srcRange.end) != maskTotal && srcRange.start > 0)
+        {
+            srcRange.start -= 1;
+        }
         CV_Assert(total(srcShape, srcRange.start, srcRange.end) == maskTotal);
     }
 
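The reshape_layer.cpp hunk is the heart of the commit: when the mask's explicit total matches the product of the source dimensions all the way down to index 0, the matched range must include dimension 0 instead of stopping at index 1, and any leading dimensions the loop skipped (e.g. a batch of 1) are absorbed by the new while loop. Below is a minimal standalone sketch of the corrected search, reusing the hunk's names (total, maskTotal, matched) but with simplified loop bounds, so treat it as an illustration rather than the exact OpenCV function. The driver values mirror the new test case added further down: a {1, 2, 3} blob reshaped with mask {3, 1, 2}.

    #include <cassert>
    #include <vector>

    typedef std::vector<int> MatShape;

    static int total(const MatShape& shape, int start, int end)
    {
        int p = 1;
        for (int i = start; i < end; ++i)
            p *= shape[i];
        return p;
    }

    int main()
    {
        const int dims[] = {1, 2, 3};
        const MatShape srcShape(dims, dims + 3);
        const int maskTotal = 3 * 1 * 2;  // product of the explicit mask {3, 1, 2}
        const int end = 3;
        int start = end;

        bool matched = false;
        for (int i = end - 1; i >= 0; --i)
        {
            if (matched)
            {
                if (total(srcShape, i, end) != maskTotal)
                {
                    start = i + 1;
                    break;
                }
                else if (i == 0)  // the fix: keep dimension 0 when totals still match
                {
                    start = 0;
                    break;
                }
            }
            else
                matched = total(srcShape, i, end) == maskTotal;
        }
        // Second part of the fix: absorb leading dimensions the loop may have skipped.
        while (total(srcShape, start, end) != maskTotal && start > 0)
            start -= 1;

        // The old code unconditionally set start = i + 1 at i == 0, i.e. start == 1,
        // leaving the leading dimension outside the mask range.
        assert(start == 0);
        return 0;
    }

Note that for this input the old code's final assertion still passed (total over [1, 3) is also 6), which appears to be why the bug went unnoticed: the leading dimension simply stayed outside the mask range, so the layer produced a 4-D shape {1, 3, 1, 2} instead of the expected {3, 1, 2}.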
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index 1faa7fb..7d7d300 100644
@@ -262,6 +262,18 @@ static int getDataLayout(const tensorflow::NodeDef& layer)
     return DATA_LAYOUT_UNKNOWN;
 }
 
+static inline std::string getNodeName(const std::string& tensorName)
+{
+    return tensorName.substr(0, tensorName.rfind(':'));
+}
+
+static inline int getDataLayout(const std::string& layerName,
+                                const std::map<String, int>& data_layouts)
+{
+    std::map<String, int>::const_iterator it = data_layouts.find(getNodeName(layerName));
+    return it != data_layouts.end() ? it->second : DATA_LAYOUT_UNKNOWN;
+}
+
 void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
 {
     if (hasLayerAttr(layer, "strides"))
@@ -604,11 +616,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
     }
 }
 
-static inline std::string getNodeName(const std::string& tensorName)
-{
-    return tensorName.substr(0, tensorName.rfind(':'));
-}
-
 // If all inputs of a specific layer have the same data layout, we can say that
 // this layer's output has that data layout too. Returns DATA_LAYOUT_UNKNOWN otherwise.
 static int predictOutputDataLayout(const tensorflow::GraphDef& net,
@@ -830,7 +837,7 @@ void TFImporter::populateNet(Net dstNet)
             // one input only
             connect(layer_id, dstNet, parsePin(input), id, 0);
 
-            if (data_layouts[name] == DATA_LAYOUT_UNKNOWN)
+            if (getDataLayout(name, data_layouts) == DATA_LAYOUT_UNKNOWN)
                 data_layouts[name] = DATA_LAYOUT_NHWC;
         }
         else if (type == "BiasAdd" || type == "Add")
@@ -956,7 +964,8 @@ void TFImporter::populateNet(Net dstNet)
             Pin inpId = parsePin(layer.input(0));
             Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
 
-            if (newShape.total() != 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            int inpLayout = getDataLayout(layer.input(0), data_layouts);
+            if (newShape.total() != 4 && inpLayout == DATA_LAYOUT_NHWC)
             {
                 LayerParams permLP;
                 int order[] = {0, 2, 3, 1};  // From OpenCV's NCHW to NHWC.
@@ -969,7 +978,7 @@ void TFImporter::populateNet(Net dstNet)
                 connect(layer_id, dstNet, inpId, permId, 0);
                 inpId = Pin(permName);
             }
-            else if (newShape.total() == 4 && data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+            else if (newShape.total() == 4 && inpLayout == DATA_LAYOUT_NHWC)
             {
                 // NHWC->NCHW
                 std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
@@ -987,7 +996,7 @@ void TFImporter::populateNet(Net dstNet)
         else if (type == "Flatten" || type == "Squeeze")
         {
             Pin inpId = parsePin(layer.input(0));
-            int inpLayout = data_layouts[layer.input(0)];
+            int inpLayout = getDataLayout(layer.input(0), data_layouts);
             if (type == "Squeeze")
             {
                 CV_Assert(hasLayerAttr(layer, "squeeze_dims"));
@@ -1032,7 +1041,8 @@ void TFImporter::populateNet(Net dstNet)
             {
                 // Only NHWC <-> NCHW permutations are allowed. OpenCV always
                 // keeps the NCHW layout this way.
-                if (data_layouts[layer.input(0)] == DATA_LAYOUT_NHWC)
+                int inpLayout = getDataLayout(layer.input(0), data_layouts);
+                if (inpLayout == DATA_LAYOUT_NHWC)
                 {
                     if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
                     {
@@ -1049,7 +1059,7 @@ void TFImporter::populateNet(Net dstNet)
                     else
                         CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
                 }
-                else if (data_layouts[layer.input(0)] == DATA_LAYOUT_NCHW)
+                else if (inpLayout == DATA_LAYOUT_NCHW)
                 {
                     if (permData[0] == 0 && permData[1] == 2 && permData[2] == 3 && permData[3] == 1)
                     {
@@ -1112,7 +1122,7 @@ void TFImporter::populateNet(Net dstNet)
             int axisId = (type == "Concat" ? 0 : layer.input_size() - 1);
             int axis = getConstBlob(layer, value_id, axisId).int_val().Get(0);
 
-            if (data_layouts[name] == DATA_LAYOUT_NHWC)
+            if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
                 axis = toNCHW(axis);
             layerParams.set("axis", axis);
 
@@ -1197,7 +1207,7 @@ void TFImporter::populateNet(Net dstNet)
             CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
                       sizes.type() == CV_32SC1);
 
-            if (begins.total() == 4 && data_layouts[name] == DATA_LAYOUT_NHWC)
+            if (begins.total() == 4 && getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
             {
                 // Swap NHWC parameters' order to NCHW.
                 std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
@@ -1597,7 +1607,7 @@ void TFImporter::populateNet(Net dstNet)
             CV_Assert(reductionIndices.type() == CV_32SC1);
 
             const int numAxes = reductionIndices.total();
-            if (data_layouts[name] == DATA_LAYOUT_NHWC)
+            if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
                 for (int i = 0; i < numAxes; ++i)
                     reductionIndices.at<int>(i) = toNCHW(reductionIndices.at<int>(i));
 
diff --git a/modules/dnn/src/torch/torch_importer.cpp b/modules/dnn/src/torch/torch_importer.cpp
index 3607e6c..88779e9 100644
@@ -592,8 +592,8 @@ struct TorchImporter
                 DictValue dimParam = scalarParams.get("size");
                 layerParams.set("dim", dimParam);
 
-                if (scalarParams.has("batchMode") && scalarParams.get<bool>("batchMode"))
-                    layerParams.set("axis", 1);
+                int axis = (int)scalarParams.get<bool>("batchMode", true);
+                layerParams.set("axis", axis);
 
                 curModule->modules.push_back(newModule);
             }
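The torch_importer.cpp change flips batchMode from opt-in to default-on: a Torch Reshape module that does not mention the flag now gets axis = 1 (the new dims apply after the batch axis), while an explicit batchMode = false yields axis = 0 (the whole blob is reshaped). A behavioral sketch, with the hypothetical getOr standing in for OpenCV's internal scalarParams.get<bool>(key, default):

    #include <cassert>
    #include <map>
    #include <string>

    // Hypothetical stand-in for scalarParams.get<bool>(key, defaultValue).
    static bool getOr(const std::map<std::string, bool>& params,
                      const std::string& key, bool defaultValue)
    {
        std::map<std::string, bool>::const_iterator it = params.find(key);
        return it != params.end() ? it->second : defaultValue;
    }

    int main()
    {
        std::map<std::string, bool> noFlag, flagOff;
        flagOff["batchMode"] = false;

        // Old behavior: axis was set to 1 only if batchMode was present and true.
        // New behavior: batch mode is the default, matching the cast in the hunk.
        assert((int)getOr(noFlag, "batchMode", true) == 1);   // axis = 1
        assert((int)getOr(flagOff, "batchMode", true) == 0);  // axis = 0
        return 0;
    }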
diff --git a/modules/dnn/test/test_layers.cpp b/modules/dnn/test/test_layers.cpp
index 720447a..963206b 100644
@@ -201,6 +201,13 @@ TEST(Layer_Test_Reshape, Accuracy)
         testReshape(MatShape(inp, inp + 4), MatShape(out, out + 2), 0, -1,
                     MatShape(mask, mask + 2));
     }
+    {
+        int inp[] = {1, 2, 3};
+        int out[] = {3, 1, 2};
+        int mask[] = {3, 1, 2};
+        testReshape(MatShape(inp, inp + 3), MatShape(out, out + 3), 0, -1,
+                    MatShape(mask, mask + 3));
+    }
 }
 
 TEST(Layer_Test_BatchNorm, Accuracy)
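The new test_layers.cpp case above is the direct regression test for the reshape_layer.cpp fix: mask {3, 1, 2} consumes all six elements of the {1, 2, 3} input, so the leading dimension must be absorbed into the mask range rather than preserved. A quick numeric check of that invariant:

    #include <cassert>
    #include <functional>
    #include <numeric>
    #include <vector>

    int main()
    {
        const int inpDims[] = {1, 2, 3}, maskDims[] = {3, 1, 2};
        std::vector<int> inp(inpDims, inpDims + 3), mask(maskDims, maskDims + 3);

        int inpTotal  = std::accumulate(inp.begin(), inp.end(), 1,
                                        std::multiplies<int>());
        int maskTotal = std::accumulate(mask.begin(), mask.end(), 1,
                                        std::multiplies<int>());

        // 6 == 6: the matched range must start at index 0. Before the fix it
        // started at index 1, yielding {1, 3, 1, 2} instead of the expected {3, 1, 2}.
        assert(inpTotal == maskTotal);
        return 0;
    }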
diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp
index d4ffc94..4087822 100644
@@ -198,6 +198,7 @@ TEST_P(Test_TensorFlow_layers, reshape)
 {
     int targetId = GetParam();
     runTensorFlowNet("shift_reshape_no_reorder", targetId);
+    runTensorFlowNet("reshape_no_reorder", targetId);
     runTensorFlowNet("reshape_reduce", targetId);
     runTensorFlowNet("flatten", targetId, true);
     runTensorFlowNet("unfused_flatten", targetId);