void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
- int axis = 0;
- std::vector<int> begin;
- std::vector<int> end;
+ MatShape inpShape = outShapes[node_proto.input(0)];
+ int dims = inpShape.size();
+ std::vector<int> begin(dims, 0);
+ std::vector<int> end(dims, INT_MAX);
std::vector<int> steps;
int inp_size = node_proto.input_size();
+ int axis = 0;
+ bool has_axes = false;
+ DictValue starts_, ends_, axes_, steps_;
+ // opset = 1
if (inp_size == 1)
{
- if (layerParams.has("axes")) {
- DictValue axes = layerParams.get("axes");
- for (int i = 1; i < axes.size(); ++i) {
- CV_Assert(axes.get<int>(i - 1) == axes.get<int>(i) - 1);
- }
- axis = axes.get<int>(0);
- }
-
- DictValue starts = layerParams.get("starts");
- DictValue ends = layerParams.get("ends");
- CV_Assert(starts.size() == ends.size());
-
- if (axis > 0) {
- CV_CheckLE(axis, 1024, "Slice layer can't have more than 1024 axes"); // arbitrary limit
- begin.resize(axis, 0);
- end.resize(axis, INT_MAX);
- }
- for (int i = 0; i < starts.size(); ++i)
+ starts_ = layerParams.get("starts");
+ ends_ = layerParams.get("ends");
+ CV_Assert(starts_.size() == ends_.size());
+ if (layerParams.has("axes"))
{
- begin.push_back(starts.get<int>(i));
- end.push_back(ends.get<int>(i));
+ axes_ = layerParams.get("axes");
+ CV_Assert(axes_.size() == starts_.size());
+ axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
+ has_axes = true;
}
- } else { // inp_size > 1
+ }
+ // opset > 1
+ else
+ {
CV_Assert(inp_size >= 3);
- for (int i = 1; i < inp_size; i++) {
+ for (int i = 1; i < inp_size; ++i)
+ {
CV_Assert(constBlobs.find(node_proto.input(i)) != constBlobs.end());
}
Mat start_blob = getBlob(node_proto, 1);
- Mat end_blob = getBlob(node_proto, 2);
+ Mat end_blob = getBlob(node_proto, 2);
CV_Assert(start_blob.total() == end_blob.total());
+ starts_ = DictValue::arrayInt(start_blob.begin<int>(), start_blob.total());
+ ends_ = DictValue::arrayInt(end_blob.begin<int>(), end_blob.total());
- if (inp_size > 3) {
+ if (inp_size > 3)
+ {
Mat axes_blob = getBlob(node_proto, 3);
- const int* axes = (int*)axes_blob.data;
- for (int i = 1; i < axes_blob.total(); ++i) {
- CV_Assert(axes[i - 1] == axes[i] - 1);
- }
- axis = axes[0];
- }
-
- const int* starts = start_blob.ptr<int>();
- const int* ends = end_blob.ptr<int>();
- if (axis > 0) {
- begin.resize(axis, 0);
- end.resize(axis, INT_MAX);
+ CV_Assert(axes_blob.total() == start_blob.total());
+ axes_ = DictValue::arrayInt(axes_blob.begin<int>(), axes_blob.total());
+ axis = axes_.getIntValue(0) < 0 ? axes_.getIntValue(0) + dims : axes_.getIntValue(0);
+ has_axes = true;
}
- std::copy(starts, starts + start_blob.total(), std::back_inserter(begin));
- std::copy(ends, ends + end_blob.total(), std::back_inserter(end));
- if (inp_size == 5) {
- CV_Assert(constBlobs.find(node_proto.input(4)) != constBlobs.end());
+ if (inp_size == 5)
+ {
Mat step_blob = getBlob(node_proto, 4);
- const int* steps_ptr = step_blob.ptr<int>();
-
- if (axis > 0)
- steps.resize(axis, 1);
-
- std::copy(steps_ptr, steps_ptr + step_blob.total(), std::back_inserter(steps));
+ CV_Assert(step_blob.total() == start_blob.total());
+ steps_ = DictValue::arrayInt(step_blob.begin<int>(), step_blob.total());
+ steps.resize(dims, 1);
// Very strange application for Slice op with tensor reversing.
// We just workaround it for 2d constants.
}
}
}
+
+ if (!has_axes)
+ {
+ // make a default axes [0, 1, 2...]
+ Mat axes_tmp(1, starts_.size(), CV_32S);
+ std::iota(axes_tmp.begin<int>(), axes_tmp.end<int>(), 0);
+ axes_ = DictValue::arrayInt(axes_tmp.begin<int>(), axes_tmp.total());
+ }
+
+ int cur_axe;
+ std::vector<bool> flag(dims, false);
+ Mat axes(1, starts_.size(), CV_32S);
+ auto axes_ptr = axes.ptr<int>();
+ // resize begin and end
+ for (int i = 0; i < axes_.size(); ++i)
+ {
+ // dims should be added to the negative axes
+ cur_axe = axes_.getIntValue(i) < 0 ? axes_.getIntValue(i) + dims : axes_.getIntValue(i);
+ CV_CheckGE(cur_axe, 0, "Axes should be grater or equal to '-dims'.");
+ CV_CheckLT(cur_axe, dims, "Axes should be less than 'dim'.");
+ CV_CheckEQ(flag[cur_axe], false, "Axes shouldn't have duplicated values.");
+ flag[cur_axe] = true;
+ // change axis to the minimum axe
+ if (cur_axe < axis) axis = cur_axe;
+ axes_ptr[i] = cur_axe;
+ begin[cur_axe] = starts_.getIntValue(i);
+ end[cur_axe] = ends_.getIntValue(i);
+ }
+
layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
layerParams.set("axis", axis);
if (!steps.empty())
+ {
+ for (int i = 0; i < axes.total(); ++i)
+ steps[axes_ptr[i]] = steps_.getIntValue(i);
layerParams.set("steps", DictValue::arrayInt(&steps[0], steps.size()));
+ }
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{