Support Gather for variable inputs
authorLiubov Batanina <piccione-mail@yandex.ru>
Mon, 20 Jul 2020 09:04:20 +0000 (12:04 +0300)
committerLiubov Batanina <piccione-mail@yandex.ru>
Mon, 20 Jul 2020 11:02:45 +0000 (14:02 +0300)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 682418bffaf9bf002aae3c8a896f72a0d041383a..220cae813e420675c7405de059dc7ee0c2305a2e 100644 (file)
@@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet)
         else if (layer_type == "Gather")
         {
             CV_Assert(node_proto.input_size() == 2);
-            Mat input = getBlob(node_proto, constBlobs, 0);
             Mat indexMat = getBlob(node_proto, constBlobs, 1);
             CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
             int index = indexMat.at<int>(0);
+            int axis = layerParams.get<int>("axis", 0);
 
-            Mat out;
-            if (layerParams.has("axis"))
+            if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
             {
-                int axis = layerParams.get<int>("axis");
-
+                Mat input = getBlob(node_proto, constBlobs, 0);
+                Mat out;
                 std::vector<cv::Range> ranges(input.dims, Range::all());
                 ranges[axis] = Range(index, index + 1);
 
                 out = input(ranges);
+                MatShape outShape = shape(out);
+                if (outShape.size() > 1)
+                {
+                    outShape.erase(outShape.begin() + axis);
+                    out.reshape(0, outShape);
+                }
+                addConstant(layerParams.name, out, constBlobs, outShapes);
+                continue;
             }
             else
             {
-                CV_Assert(index < input.total());
-                const int dims = input.dims;
-                input = input.reshape(1, 1);
-                input.dims = 2;
-                out = input.reshape(1, 1).colRange(index, index + 1);
-                out.dims = dims;
+                shapeIt = outShapes.find(node_proto.input(0));
+                CV_Assert(shapeIt != outShapes.end());
+                MatShape inpShape = shapeIt->second;
+
+                LayerParams sliceLp;
+                sliceLp.type = "Slice";
+                sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
+                std::vector<int> begin(inpShape.size(), 0);
+                std::vector<int> end(inpShape.size(), -1);
+                begin[axis] = index;
+                end[axis] = index + 1;
+
+                cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
+                cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
+                sliceLp.set("begin", paramBegin);
+                sliceLp.set("end", paramEnd);
+
+                if (inpShape.size() > 1)
+                {
+                    opencv_onnx::NodeProto proto;
+                    proto.add_input(node_proto.input(0));
+                    proto.add_output(sliceLp.name);
+                    addLayer(dstNet, sliceLp, proto, layer_id, outShapes);
+
+                    inpShape.erase(inpShape.begin() + axis);
+                    layerParams.type = "Reshape";
+                    layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
+                    node_proto.set_input(0, sliceLp.name);
+                }
+                else
+                {
+                    layerParams = sliceLp;
+                }
             }
-            addConstant(layerParams.name, out, constBlobs, outShapes);
-            continue;
         }
         else if (layer_type == "Concat")
         {
index 4c8e66aae10533011a020d715830f70c6d74cdc2..e932bc691959ad353a1180d4afb9e714cd1ce6d2 100644 (file)
@@ -111,6 +111,17 @@ TEST_P(Test_ONNX_layers, Convolution)
     testONNXModels("convolution");
 }
 
+TEST_P(Test_ONNX_layers, Gather)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    testONNXModels("gather");
+    // GPU plugin unsupported slice for constant
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    testONNXModels("gather_scalar", npy, 0, 0, false, false);
+}
+
 TEST_P(Test_ONNX_layers, Convolution3D)
 {
 #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)