From: Liubov Batanina Date: Tue, 21 Apr 2020 07:34:56 +0000 (+0300) Subject: Supported TF concat 3d X-Git-Tag: submit/tizen/20210224.033012~2^2~198^2~9^2~2 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aa08900ac8c3399534d57f55b7e7b3b1a6d8c132;p=platform%2Fupstream%2Fopencv.git Supported TF concat 3d --- diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index 0dd2177..6aadc57 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -46,6 +46,14 @@ static int toNCHW(int idx) else return (4 + idx) % 3 + 1; } +static int toNCDHW(int idx) +{ + CV_Assert(-5 <= idx && idx < 5); + if (idx == 0) return 0; + else if (idx > 0) return idx % 4 + 1; + else return (5 + idx) % 4 + 1; +} + // This values are used to indicate layer output's data layout where it's possible. enum DataLayout { @@ -1313,6 +1321,8 @@ void TFImporter::populateNet(Net dstNet) if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC) axis = toNCHW(axis); + else if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NDHWC) + axis = toNCDHW(axis); layerParams.set("axis", axis); // input(0) or input(n-1) is concat_dim diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index b71dfbc..8c48743 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -196,6 +196,7 @@ TEST_P(Test_TensorFlow_layers, pad_and_concat) TEST_P(Test_TensorFlow_layers, concat_axis_1) { runTensorFlowNet("concat_axis_1"); + runTensorFlowNet("concat_3d"); } TEST_P(Test_TensorFlow_layers, batch_norm_1)