Supported TF concat 3d
authorLiubov Batanina <piccione-mail@yandex.ru>
Tue, 21 Apr 2020 07:34:56 +0000 (10:34 +0300)
committerLiubov Batanina <piccione-mail@yandex.ru>
Tue, 21 Apr 2020 12:15:22 +0000 (15:15 +0300)
modules/dnn/src/tensorflow/tf_importer.cpp
modules/dnn/test/test_tf_importer.cpp

index 0dd2177..6aadc57 100644 (file)
@@ -46,6 +46,14 @@ static int toNCHW(int idx)
     else return (4 + idx) % 3 + 1;
 }
 
+static int toNCDHW(int idx)
+{
+    CV_Assert(-5 <= idx && idx < 5);
+    if (idx == 0) return 0;
+    else if (idx > 0) return idx % 4 + 1;
+    else return (5 + idx) % 4 + 1;
+}
+
 // This values are used to indicate layer output's data layout where it's possible.
 enum DataLayout
 {
@@ -1313,6 +1321,8 @@ void TFImporter::populateNet(Net dstNet)
 
             if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
                 axis = toNCHW(axis);
+            else if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NDHWC)
+                axis = toNCDHW(axis);
             layerParams.set("axis", axis);
 
             // input(0) or input(n-1) is concat_dim
index b71dfbc..8c48743 100644 (file)
@@ -196,6 +196,7 @@ TEST_P(Test_TensorFlow_layers, pad_and_concat)
 TEST_P(Test_TensorFlow_layers, concat_axis_1)
 {
     runTensorFlowNet("concat_axis_1");
+    runTensorFlowNet("concat_3d");
 }
 
 TEST_P(Test_TensorFlow_layers, batch_norm_1)