enhance slice layer
authorzoom <zhongwl2018@mail.sustech.edu.cn>
Thu, 22 Sep 2022 06:40:39 +0000 (14:40 +0800)
committerzoom <zhongwl2018@mail.sustech.edu.cn>
Sat, 1 Oct 2022 09:12:07 +0000 (17:12 +0800)
refactor the code for parsing Slice layer
add test for Slice layer
let 'begin' and 'end' resize to dims
add opset message comment

modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 0b104d1..4a1ebdb 100644 (file)
@@ -1326,72 +1326,59 @@ void ONNXImporter::parseReduce(LayerParams& layerParams, const opencv_onnx::Node
 
 void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
-    int axis = 0;
-    std::vector<int> begin;
-    std::vector<int> end;
+    MatShape inpShape = outShapes[node_proto.input(0)];
+    int dims = inpShape.size();
+    std::vector<int> begin(dims, 0);
+    std::vector<int> end(dims, INT_MAX);
     std::vector<int> steps;
     int inp_size = node_proto.input_size();
+    int axis = 0;
+    bool has_axes = false;
+    DictValue starts_, ends_, axes_, steps_;
 
+    // opset = 1
     if (inp_size == 1)
     {
-        if (layerParams.has("axes")) {
-            DictValue axes = layerParams.get("axes");
-            for (int i = 1; i < axes.size(); ++i) {
-                CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
-            }
-            axis = axes.get<int>(0);
-        }
-
-        DictValue starts = layerParams.get("starts");
-        DictValue ends = layerParams.get("ends");
-        CV_Assert(starts.size() == ends.size());
-
-        if (axis > 0) {
-            CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
-            begin.resize(axis, 0);
-            end.resize(axis, INT_MAX);
-        }
-        for (int i = 0; i < starts.size(); ++i)
+        starts_ = layerParams.get("starts");
+        ends_ = layerParams.get("ends");
+        CV_Assert(starts_.size() == ends_.size());
+        if (layerParams.has("axes"))
         {
-            begin.push_back(starts.get<int>(i));
-            end.push_back(ends.get<int>(i));
+            axes_ = layerParams.get("axes");
+            CV_Assert(axes_.size() == starts_.size());
+            axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
+            has_axes = true;
         }
-    } else { // inp_size > 1
+    }
+    // opset > 1
+    else
+    {
         CV_Assert(inp_size >= 3);
-        for (int i = 1; i < inp_size; i++) {
+        for (int i = 1; i < inp_size; ++i)
+        {
             CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
         }
         Mat start_blob = getBlob(node_proto, 1);
-        Mat end_blob   = getBlob(node_proto, 2);
+        Mat end_blob = getBlob(node_proto, 2);
         CV_Assert(start_blob.total() == end_blob.total());
+        starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
+        ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
 
-        if (inp_size > 3) {
+        if (inp_size > 3)
+        {
             Mat axes_blob = getBlob(node_proto, 3);
-            const int* axes = (int*)axes_blob.data;
-            for (int i = 1; i < axes_blob.total(); ++i) {
-                CV_Assert(axes[i - 1] == axes[i] - 1);
-            }
-            axis = axes[0];
-        }
-
-        const int* starts = start_blob.ptr<int>();
-        const int* ends   = end_blob.ptr<int>();
-        if (axis > 0) {
-            begin.resize(axis, 0);
-            end.resize(axis, INT_MAX);
+            CV_Assert(axes_blob.total() == start_blob.total());
+            axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
+            axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
+            has_axes = true;
         }
-        std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
-        std::copy(ends, ends + end_blob.total(), std::back_inserter(end));
 
-        if (inp_size == 5) {
-            CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
+        if (inp_size == 5)
+        {
             Mat step_blob = getBlob(node_proto, 4);
-            const int* steps_ptr = step_blob.ptr<int>();
-
-            if (axis > 0)
-                steps.resize(axis, 1);
-
-            std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
+            CV_Assert(step_blob.total() == start_blob.total());
+            steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
+            steps.resize(dims, 1);
 
             // Very strange application for Slice op with tensor reversing.
             // We just workaround it for 2d constants.
@@ -1411,12 +1398,45 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
             }
         }
     }
+
+    if (!has_axes)
+    {
+        // make a default axes [0, 1, 2...]
+        Mat axes_tmp(1, starts_.size(), CV_32S);
+        std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
+        axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
+    }
+
+    int cur_axe;
+    std::vector<bool> flag(dims, false);
+    Mat axes(1, starts_.size(), CV_32S);
+    auto axes_ptr = axes.ptr<int>();
+    // resize begin and end
+    for (int i = 0; i < axes_.size(); ++i)
+    {
+        // dims should be added to the negative axes
+        cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
+        CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
+        CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
+        CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
+        flag[cur_axe] = true;
+        // change axis to the minimum axe
+        if (cur_axe < axis) axis = cur_axe;
+        axes_ptr[i] = cur_axe;
+        begin[cur_axe] = starts_.getIntValue(i);
+        end[cur_axe] = ends_.getIntValue(i);
+    }
+
     layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
     layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
     layerParams.set("axis", axis);
 
     if (!steps.empty())
+    {
+        for (int i = 0; i < axes.total(); ++i)
+            steps[axes_ptr[i]] = steps_.getIntValue(i);
         layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
+    }
 
     if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
     {
index 8090eba..4eb6b5a 100644 (file)
@@ -1172,6 +1172,20 @@ TEST_P(Test_ONNX_layers, Slice_Steps_5DInput)
     testONNXModels("slice_opset_11_steps_5d");
 }
 
+TEST_P(Test_ONNX_layers, Slice_Nonseq_Axes)
+{
+    testONNXModels("slice_nonseq_axes");
+    testONNXModels("slice_nonseq_axes_steps");
+    testONNXModels("slice_nonseq_miss_axes_steps");
+}
+
+TEST_P(Test_ONNX_layers, Slice_Neg_Axes)
+{
+    testONNXModels("slice_neg_axes");
+    testONNXModels("slice_neg_axes_steps");
+    testONNXModels("slice_neg_miss_axes_steps");
+}
+
 TEST_P(Test_ONNX_layers, Softmax)
 {
     testONNXModels("softmax");