support ReduceSum with two input and dynamic shape batch size in ReduceLayer.
authorZihao Mu <zihaomu@outlook.com>
Wed, 13 Jul 2022 05:46:16 +0000 (13:46 +0800)
committerZihao Mu <zihaomu@outlook.com>
Wed, 13 Jul 2022 05:46:16 +0000 (13:46 +0800)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index ebbda98..7390d03 100644 (file)
@@ -1180,32 +1180,43 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
     layerParams.set("reduce", reduceType);
     bool keepdims = layerParams.get<int>("keepdims", 1) == 1;
 
-    if (layer_type == "ReduceSum" && node_proto.input_size() == 2)
-    {
-        // TODO support the opset 13 of ReduceSum.
-        //  in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute
-        //  details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687
-        CV_Error(Error::StsNotImplemented, "Unsupported " + layer_type + " operation of opset 13, please try to "
-                                                                         "re-export the onnx model with opset 11.");
-    }
-
     MatShape inpShape = outShapes[node_proto.input(0)];
     std::vector<bool> shouldDelete(inpShape.size(), false);
 
-    if (layerParams.has("axes"))
+    if (layer_type == "ReduceSum" && node_proto.input_size() == 2)
     {
-        DictValue axes = layerParams.get("axes");
-        for (int i = 0; i < axes.size(); i++)
+        if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
         {
-            int axis = normalize_axis(axes.get<int>(i), inpShape.size());
-            shouldDelete[axis] = true;
+            Mat axesMat = getBlob(node_proto, 1);
+            int axesNum = axesMat.total();
+            for (int i = 0; i < axesNum; i++)
+            {
+                int axis = normalize_axis(static_cast<int>(axesMat.at<float>(i)), inpShape.size());
+                shouldDelete[axis] = true;
+            }
         }
+        else
+            //  in opset 13, the ReduceSum has two input, it takes axes as input instead of attribute
+            //  details:https://github.com/onnx/onnx/issues/3420#issuecomment-844295687
+            CV_Error(Error::StsNotImplemented, "Non-constant axis values in ReduceSum are not supported.");
     }
     else
     {
-        for (int i = 0; i < inpShape.size(); i++)
+        if (layerParams.has("axes"))
         {
-            shouldDelete[i] = true;
+            DictValue axes = layerParams.get("axes");
+            for (int i = 0; i < axes.size(); i++)
+            {
+                int axis = normalize_axis(axes.get<int>(i), inpShape.size());
+                shouldDelete[axis] = true;
+            }
+        }
+        else
+        {
+            for (int i = 0; i < inpShape.size(); i++)
+            {
+                shouldDelete[i] = true;
+            }
         }
     }
 
@@ -1291,6 +1302,17 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
     layerParams.type = (depth == CV_8S) ? "ReshapeInt8" : "Reshape";
     layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
 
+    // Set batchsize dim as dynamic to be compatible with batch size >= 2.
+    if (targetShape[0] == 1 && targetShape.size() > 1)
+    {
+        std::vector<int> dynamicAxes = {0};  // The index of batchsize dim is 0.
+        std::vector<int> inputIndices = {0};
+
+        layerParams.set("has_dynamic_shapes", true);
+        layerParams.set("dynamic_axes", DictValue::arrayInt(dynamicAxes.data(), dynamicAxes.size()));
+        layerParams.set("input_indices", DictValue::arrayInt(inputIndices.data(), inputIndices.size()));
+    }
+
     node_proto.set_input(0, node_proto.output(0));
     node_proto.set_output(0, output_name);
 
index 578e044..5f94f98 100644 (file)
@@ -411,6 +411,8 @@ TEST_P(Test_ONNX_layers, ReduceMean)
 TEST_P(Test_ONNX_layers, ReduceSum)
 {
     testONNXModels("reduce_sum");
+    testONNXModels("reduce_sum_axis");
+    testONNXModels("reduce_sum_axis_dynamic_batch");
 }
 
 TEST_P(Test_ONNX_layers, ReduceMax)