Parse strides and convolution kernel shapes considering data layout
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 26 Jun 2018 13:13:40 +0000 (16:13 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 26 Jun 2018 13:18:21 +0000 (16:18 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 986225a..6a4a0ab 100644 (file)
@@ -246,16 +246,41 @@ const tensorflow::AttrValue& getLayerAttr(const tensorflow::NodeDef &layer, cons
     return layer.attr().at(name);
 }
 
+static int getDataLayout(const tensorflow::NodeDef& layer)
+{
+    if (hasLayerAttr(layer, "data_format"))
+    {
+        std::string format = getLayerAttr(layer, "data_format").s();
+        if (format == "NHWC" || format == "channels_last")
+            return DATA_LAYOUT_NHWC;
+        else if (format == "NCHW" || format == "channels_first")
+            return DATA_LAYOUT_NCHW;
+        else
+            CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
+    }
+    return DATA_LAYOUT_UNKNOWN;
+}
+
 void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
 {
     if (hasLayerAttr(layer, "strides"))
     {
         const tensorflow::AttrValue& val = getLayerAttr(layer, "strides");
+        int dimX, dimY, dimC;
+        int layout = getDataLayout(layer);
+        if (layout == DATA_LAYOUT_NCHW)
+        {
+            dimC = 1; dimY = 2; dimX = 3;
+        }
+        else
+        {
+            dimY = 1; dimX = 2; dimC = 3;
+        }
         if (val.list().i_size() != 4 ||
-            val.list().i(0) != 1 || val.list().i(3) != 1)
+            val.list().i(0) != 1 || val.list().i(dimC) != 1)
             CV_Error(Error::StsError, "Unsupported strides");
-        layerParams.set("stride_h", static_cast<int>(val.list().i(1)));
-        layerParams.set("stride_w", static_cast<int>(val.list().i(2)));
+        layerParams.set("stride_h", static_cast<int>(val.list().i(dimY)));
+        layerParams.set("stride_w", static_cast<int>(val.list().i(dimX)));
     }
 }
 
@@ -278,11 +303,21 @@ void setKSize(LayerParams &layerParams, const tensorflow::NodeDef &layer)
     if (hasLayerAttr(layer, "ksize"))
     {
         const tensorflow::AttrValue& val = getLayerAttr(layer, "ksize");
+        int dimX, dimY, dimC;
+        int layout = getDataLayout(layer);
+        if (layout == DATA_LAYOUT_NCHW)
+        {
+            dimC = 1; dimY = 2; dimX = 3;
+        }
+        else
+        {
+            dimY = 1; dimX = 2; dimC = 3;
+        }
         if (val.list().i_size() != 4 ||
-            val.list().i(0) != 1 || val.list().i(3) != 1)
+            val.list().i(0) != 1 || val.list().i(dimC) != 1)
             CV_Error(Error::StsError, "Unsupported ksize");
-        layerParams.set("kernel_h", static_cast<int>(val.list().i(1)));
-        layerParams.set("kernel_w", static_cast<int>(val.list().i(2)));
+        layerParams.set("kernel_h", static_cast<int>(val.list().i(dimY)));
+        layerParams.set("kernel_w", static_cast<int>(val.list().i(dimX)));
     }
     else
     {
@@ -568,21 +603,6 @@ static void addConstNodes(tensorflow::GraphDef& net, std::map<String, int>& cons
     }
 }
 
-static int getDataLayout(const tensorflow::NodeDef& layer)
-{
-    if (hasLayerAttr(layer, "data_format"))
-    {
-        std::string format = getLayerAttr(layer, "data_format").s();
-        if (format == "NHWC" || format == "channels_last")
-            return DATA_LAYOUT_NHWC;
-        else if (format == "NCHW" || format == "channels_first")
-            return DATA_LAYOUT_NCHW;
-        else
-            CV_Error(Error::StsParseError, "Unknown data_format value: " + format);
-    }
-    return DATA_LAYOUT_UNKNOWN;
-}
-
 static inline std::string getNodeName(const std::string& tensorName)
 {
     return tensorName.substr(0, tensorName.rfind(':'));
index 747fefd..d4ffc94 100644 (file)
@@ -127,6 +127,7 @@ TEST_P(Test_TensorFlow_layers, conv)
     runTensorFlowNet("atrous_conv2d_same", targetId);
     runTensorFlowNet("depthwise_conv2d", targetId);
     runTensorFlowNet("keras_atrous_conv2d_same", targetId);
+    runTensorFlowNet("conv_pool_nchw", targetId);
 }
 
 TEST_P(Test_TensorFlow_layers, padding)