From 0d2bc7b5fd2f19c2aa87533d64309d78baa8f2dc Mon Sep 17 00:00:00 2001 From: Liubov Batanina Date: Wed, 17 Jul 2019 15:50:50 +0300 Subject: [PATCH] Fix TF Split layer --- modules/dnn/include/opencv2/dnn/all_layers.hpp | 1 + modules/dnn/src/layers/slice_layer.cpp | 8 +++++--- modules/dnn/src/tensorflow/tf_importer.cpp | 3 +++ modules/dnn/test/test_tf_importer.cpp | 5 +++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 868a8f0..75cba09 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -366,6 +366,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN */ std::vector > sliceRanges; int axis; + int num_split; static Ptr create(const LayerParams ¶ms); }; diff --git a/modules/dnn/src/layers/slice_layer.cpp b/modules/dnn/src/layers/slice_layer.cpp index 73d6a30..7640d46 100644 --- a/modules/dnn/src/layers/slice_layer.cpp +++ b/modules/dnn/src/layers/slice_layer.cpp @@ -61,6 +61,7 @@ public: { setParamsFrom(params); axis = params.get("axis", 1); + num_split = params.get("num_split", 0); if (params.has("slice_point")) { CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end")); @@ -141,9 +142,10 @@ public: else // Divide input blob on equal parts by axis. { CV_Assert(0 <= axis && axis < inpShape.size()); - CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0); - inpShape[axis] /= requiredOutputs; - outputs.resize(requiredOutputs, inpShape); + int splits = num_split ? num_split : requiredOutputs; + CV_Assert(splits > 0 && inpShape[axis] % splits == 0); + inpShape[axis] /= splits; + outputs.resize(splits, inpShape); } return false; } diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp index c38b250..e546d9e 100644 --- a/modules/dnn/src/tensorflow/tf_importer.cpp +++ b/modules/dnn/src/tensorflow/tf_importer.cpp @@ -1410,6 +1410,9 @@ void TFImporter::populateNet(Net dstNet) axis = toNCHW(axis); layerParams.set("axis", axis); + if (hasLayerAttr(layer, "num_split")) + layerParams.set("num_split", getLayerAttr(layer, "num_split").i()); + int id = dstNet.addLayer(name, "Slice", layerParams); layer_id[name] = id; diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp index 2dae678..0357b8e 100644 --- a/modules/dnn/test/test_tf_importer.cpp +++ b/modules/dnn/test/test_tf_importer.cpp @@ -350,6 +350,11 @@ TEST_P(Test_TensorFlow_layers, l2_normalize_3d) runTensorFlowNet("l2_normalize_3d"); } +TEST_P(Test_TensorFlow_layers, Split) +{ + runTensorFlowNet("split"); +} + class Test_TensorFlow_nets : public DNNTestLayer {}; TEST_P(Test_TensorFlow_nets, MobileNet_SSD) -- 2.7.4