Make Unsqueeze layer support negative axes.
authorzoom <zhongwl2018@mail.sustech.edu.cn>
Fri, 14 Oct 2022 08:46:25 +0000 (16:46 +0800)
committerzoom <zhongwl2018@mail.sustech.edu.cn>
Fri, 14 Oct 2022 10:00:19 +0000 (18:00 +0800)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 013520d..b324ad9 100644 (file)
@@ -2313,8 +2313,9 @@ void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::N
         }
 //        CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size());
         for (int j = 0; j < axes.size(); j++) {
-            const int idx = axes.getIntValue(j);
-            CV_Assert(idx <= dims.size());
+            int idx = axes.getIntValue(j);
+            idx = idx < 0 ? idx + input_dims + 1 : idx;
+            CV_Assert(0 <= idx && idx <= dims.size());
             dims.insert(dims.begin() + idx, 1);
         }
 
@@ -2331,6 +2332,7 @@ void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::N
 
     MatShape inpShape = outShapes[node_proto.input(0)];
     int axis = axes.getIntValue(0);
+    axis = axis < 0 ? axis + (int)inpShape.size() + 1 : axis;
     CV_Assert(0 <= axis && axis <= inpShape.size());
     std::vector<int> outShape = inpShape;
     outShape.insert(outShape.begin() + axis, 1);
index 6372bcc..9fe0078 100644 (file)
@@ -1096,6 +1096,11 @@ TEST_P(Test_ONNX_layers, Reshape)
     testONNXModels("unsqueeze_opset_13");
 }
 
+TEST_P(Test_ONNX_layers, Unsqueeze_Neg_Axes)
+{
+    testONNXModels("unsqueeze_neg_axes");
+}
+
 TEST_P(Test_ONNX_layers, Squeeze)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)