From 72d06080c6fc8009e38eb991c05ad075b44be5e5 Mon Sep 17 00:00:00 2001 From: Liubov Batanina Date: Tue, 17 Nov 2020 14:45:36 +0300 Subject: [PATCH] [ONNX] Added Reduce ops for batch and channel --- modules/dnn/src/onnx/onnx_importer.cpp | 33 +++++++++++++++++++++++++++++---- modules/dnn/test/test_onnx_importer.cpp | 4 +++- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 9443336..32f7f02 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -551,11 +551,36 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_) CV_Assert(axes.size() <= inpShape.size() - 2); std::vector kernel_size(inpShape.size() - 2, 1); - for (int i = 0; i < axes.size(); i++) { - int axis = clamp(axes.get(i), inpShape.size()); - CV_Assert_N(axis >= 2 + i, axis < inpShape.size()); - kernel_size[axis - 2] = inpShape[axis]; + if (axes.size() == 1 && (clamp(axes.get(0), inpShape.size()) <= 1)) + { + int axis = clamp(axes.get(0), inpShape.size()); + MatShape newShape = inpShape; + newShape[axis + 1] = total(newShape, axis + 1); + newShape.resize(axis + 2); + newShape.insert(newShape.begin(), 2 - axis, 1); + + LayerParams reshapeLp; + reshapeLp.type = "Reshape"; + reshapeLp.name = layerParams.name + "/reshape"; + CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end()); + reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size())); + + node_proto.set_output(0, reshapeLp.name); + addLayer(reshapeLp, node_proto); + + kernel_size.resize(2); + kernel_size[0] = inpShape[axis]; + node_proto.set_input(0, node_proto.output(0)); + } + else + { + for (int i = 0; i < axes.size(); i++) { + int axis = clamp(axes.get(i), inpShape.size()); + CV_Assert_N(axis >= 2 + i, axis < inpShape.size()); + kernel_size[axis - 2] = inpShape[axis]; + } } + LayerParams poolLp = layerParams; poolLp.name = layerParams.name + "/avg"; CV_Assert(layer_id.find(poolLp.name) == layer_id.end()); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 14d2d28..1c5d2e5 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -267,9 +267,11 @@ TEST_P(Test_ONNX_layers, ReduceSum) testONNXModels("reduce_sum"); } -TEST_P(Test_ONNX_layers, ReduceMaxGlobal) +TEST_P(Test_ONNX_layers, ReduceMax) { testONNXModels("reduce_max"); + testONNXModels("reduce_max_axis_0"); + testONNXModels("reduce_max_axis_1"); } TEST_P(Test_ONNX_layers, Scale) -- 2.7.4