Add support for slice from ONNX with multiple outputs
author: Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sat, 27 Jul 2019 19:10:13 +0000 (22:10 +0300)
committer: Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Sun, 28 Jul 2019 18:20:25 +0000 (21:20 +0300)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index ddea2ed..51f9aff 100644 (file)
@@ -465,6 +465,20 @@ void ONNXImporter::populateNet(Net dstNet)
             }
             layerParams.set("begin", DictValue::arrayInt(&begin[0], begin.size()));
             layerParams.set("end", DictValue::arrayInt(&end[0], end.size()));
+         }
+        else if (layer_type == "Split")
+        {
+            DictValue splits = layerParams.get("split");
+            const int numSplits = splits.size();
+            CV_Assert(numSplits > 1);
+
+            std::vector<int> slicePoints(numSplits - 1, splits.get<int>(0));
+            for (int i = 1; i < splits.size() - 1; ++i)
+            {
+                slicePoints[i] = slicePoints[i - 1] + splits.get<int>(i - 1);
+            }
+            layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
+            layerParams.type = "Slice";
         }
         else if (layer_type == "Add" || layer_type == "Sum")
         {
@@ -486,6 +500,11 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.type = "Eltwise";
             }
         }
+        else if (layer_type == "Max")
+        {
+            layerParams.type = "Eltwise";
+            layerParams.set("operation", "max");
+        }
         else if (layer_type == "Sub")
         {
             Mat blob = getBlob(node_proto, constBlobs, 1);
@@ -741,6 +760,16 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             layerParams.type = "Permute";
             replaceLayerParam(layerParams, "perm", "order");
+
+            CV_Assert(node_proto.input_size() == 1);
+            if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+            {
+                std::vector<Mat> inputs(1, getBlob(node_proto, constBlobs, 0)), transposed;
+                runLayer(layerParams, inputs, transposed);
+                CV_Assert(transposed.size() == 1);
+                constBlobs.insert(std::make_pair(layerParams.name, transposed[0]));
+                continue;
+            }
         }
         else if (layer_type == "Unsqueeze")
         {
@@ -906,8 +935,10 @@ void ONNXImporter::populateNet(Net dstNet)
         }
 
         int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
-        layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));
-
+        for (int i = 0; i < node_proto.output_size(); ++i)
+        {
+            layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
+        }
 
         std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
         for (int j = 0; j < node_proto.input_size(); j++) {
@@ -924,8 +955,10 @@ void ONNXImporter::populateNet(Net dstNet)
         // Compute shape of output blob for this layer.
         Ptr<Layer> layer = dstNet.getLayer(id);
         layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
-        CV_Assert(!layerOutShapes.empty());
-        outShapes[layerParams.name] = layerOutShapes[0];
+        for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
+        {
+            outShapes[node_proto.output(i)] = layerOutShapes[i];
+        }
     }
 }
 
index aeceb9a..fa6ee92 100644 (file)
@@ -348,6 +348,13 @@ TEST_P(Test_ONNX_layers, Softmax)
     testONNXModels("log_softmax", npy, 0, 0, false, false);
 }
 
+TEST_P(Test_ONNX_layers, Split_EltwiseMax)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE)
+        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
+    testONNXModels("split_max");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers