From a6f3a2125607ea8bb7750b3f889c8331e811fc98 Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Sat, 15 Feb 2020 16:42:20 +0530 Subject: [PATCH] Merge pull request #16424 from czgdp1807:issue-16370 * fixed Split layer in ONNXImporter * added test for fix of split layer * fixed tests for Split layer * applied reviews * updated tests * fixed paths in tests --- modules/dnn/src/onnx/onnx_importer.cpp | 21 ++++++++++++++------- modules/dnn/test/test_onnx_importer.cpp | 12 ++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 039b578..91f954a 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -485,16 +485,23 @@ void ONNXImporter::populateNet(Net dstNet) } else if (layer_type == "Split") { - DictValue splits = layerParams.get("split"); - const int numSplits = splits.size(); - CV_Assert(numSplits > 1); + if (layerParams.has("split")) + { + DictValue splits = layerParams.get("split"); + const int numSplits = splits.size(); + CV_Assert(numSplits > 1); - std::vector slicePoints(numSplits - 1, splits.get(0)); - for (int i = 1; i < splits.size() - 1; ++i) + std::vector slicePoints(numSplits - 1, splits.get(0)); + for (int i = 1; i < splits.size() - 1; ++i) + { + slicePoints[i] = slicePoints[i - 1] + splits.get(i - 1); + } + layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); + } + else { - slicePoints[i] = slicePoints[i - 1] + splits.get(i - 1); + layerParams.set("num_split", node_proto.output_size()); } - layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size())); layerParams.type = "Slice"; } else if (layer_type == "Add" || layer_type == "Sum") diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 7607e54..ba2a882 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -386,6 +386,18 @@ TEST_P(Test_ONNX_layers, ReduceL2) testONNXModels("reduceL2"); } +TEST_P(Test_ONNX_layers, Split) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH); + testONNXModels("split_1"); + testONNXModels("split_2"); + testONNXModels("split_3"); + testONNXModels("split_4"); +} + TEST_P(Test_ONNX_layers, Slice) { #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000) -- 2.7.4