Operate with shapes in ONNX models
author: Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 12 Dec 2018 14:36:17 +0000 (17:36 +0300)
committer: Dmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Wed, 12 Dec 2018 15:34:22 +0000 (18:34 +0300)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 22eda50..18e26f1 100644 (file)
@@ -6,6 +6,7 @@
 // Third party copyrights are property of their respective owners.
 
 #include "../precomp.hpp"
+#include <opencv2/dnn/shape_utils.hpp>
 
 #ifdef HAVE_PROTOBUF
 
@@ -134,9 +135,38 @@ Mat getMatFromTensor(opencv_onnx::TensorProto& tensor_proto)
     else
         CV_Error(Error::StsUnsupportedFormat, "Unsupported data type: " +
                         opencv_onnx::TensorProto_DataType_Name(datatype));
+    if (tensor_proto.dims_size() == 0)
+        blob.dims = 1;  // To force 1-dimensional cv::Mat for scalars.
     return blob;
 }
 
+void runLayer(Ptr<Layer> layer, const std::vector<Mat>& inputs,
+              std::vector<Mat>& outputs)
+{
+    std::vector<MatShape> inpShapes(inputs.size());
+    int ddepth = CV_32F;
+    for (size_t i = 0; i < inputs.size(); ++i)
+    {
+        inpShapes[i] = shape(inputs[i]);
+        if (i > 0 && ddepth != inputs[i].depth())
+            CV_Error(Error::StsNotImplemented, "Mixed input data types.");
+        ddepth = inputs[i].depth();
+    }
+
+    std::vector<MatShape> outShapes, internalShapes;
+    layer->getMemoryShapes(inpShapes, 0, outShapes, internalShapes);
+
+    std::vector<Mat> internals(internalShapes.size());
+    outputs.resize(outShapes.size());
+    for (size_t i = 0; i < outShapes.size(); ++i)
+        outputs[i].create(outShapes[i], ddepth);
+    for (size_t i = 0; i < internalShapes.size(); ++i)
+        internals[i].create(internalShapes[i], ddepth);
+
+    layer->finalize(inputs, outputs);
+    layer->forward(inputs, outputs, internals);
+}
+
 std::map<std::string, Mat> ONNXImporter::getGraphTensors(
                                         const opencv_onnx::GraphProto& graph_proto)
 {
@@ -292,6 +322,26 @@ void ONNXImporter::populateNet(Net dstNet)
     CV_Assert(model_proto.has_graph());
     opencv_onnx::GraphProto graph_proto = model_proto.graph();
     std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
+    // Shapes of the network's internal blobs, keyed by blob name.
+    std::map<std::string, MatShape> outShapes;
+    // Add all the input shapes: both constant blobs and the network's inputs.
+    for (int i = 0; i < graph_proto.input_size(); ++i)
+    {
+        opencv_onnx::ValueInfoProto valueInfoProto = graph_proto.input(i);
+        CV_Assert(valueInfoProto.has_type());
+        opencv_onnx::TypeProto typeProto = valueInfoProto.type();
+        CV_Assert(typeProto.has_tensor_type());
+        opencv_onnx::TypeProto::Tensor tensor = typeProto.tensor_type();
+        CV_Assert(tensor.has_shape());
+        opencv_onnx::TensorShapeProto tensorShape = tensor.shape();
+
+        MatShape inpShape(tensorShape.dim_size());
+        for (int j = 0; j < inpShape.size(); ++j)
+        {
+            inpShape[j] = tensorShape.dim(j).dim_value();
+        }
+        outShapes[valueInfoProto.name()] = inpShape;
+    }
 
     std::string framework_name;
     if (model_proto.has_producer_name()) {
@@ -301,6 +351,7 @@ void ONNXImporter::populateNet(Net dstNet)
     // create map with network inputs (without const blobs)
     std::map<std::string, LayerInfo> layer_id;
     std::map<std::string, LayerInfo>::iterator layerId;
+    std::map<std::string, MatShape>::iterator shapeIt;
     // fill map: push layer name, layer id and output id
     std::vector<String> netInputs;
     for (int j = 0; j < graph_proto.input_size(); j++)
@@ -317,9 +368,9 @@ void ONNXImporter::populateNet(Net dstNet)
     LayerParams layerParams;
     opencv_onnx::NodeProto node_proto;
 
-    for(int i = 0; i < layersSize; i++)
+    for(int li = 0; li < layersSize; li++)
     {
-        node_proto = graph_proto.node(i);
+        node_proto = graph_proto.node(li);
         layerParams = getLayerParams(node_proto);
         CV_Assert(node_proto.output_size() >= 1);
         layerParams.name = node_proto.output(0);
@@ -598,6 +649,65 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             layerParams.type = "Padding";
         }
+        else if (layer_type == "Shape")
+        {
+            CV_Assert(node_proto.input_size() == 1);
+            shapeIt = outShapes.find(node_proto.input(0));
+            CV_Assert(shapeIt != outShapes.end());
+            MatShape inpShape = shapeIt->second;
+
+            Mat shapeMat(inpShape.size(), 1, CV_32S);
+            for (int j = 0; j < inpShape.size(); ++j)
+                shapeMat.at<int>(j) = inpShape[j];
+            shapeMat.dims = 1;
+
+            constBlobs.insert(std::make_pair(layerParams.name, shapeMat));
+            continue;
+        }
+        else if (layer_type == "Gather")
+        {
+            CV_Assert(node_proto.input_size() == 2);
+            CV_Assert(layerParams.has("axis"));
+            Mat input = getBlob(node_proto, constBlobs, 0);
+            Mat indexMat = getBlob(node_proto, constBlobs, 1);
+            CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
+            int index = indexMat.at<int>(0);
+            int axis = layerParams.get<int>("axis");
+
+            std::vector<cv::Range> ranges(input.dims, Range::all());
+            ranges[axis] = Range(index, index + 1);
+
+            Mat out = input(ranges);
+            constBlobs.insert(std::make_pair(layerParams.name, out));
+            continue;
+        }
+        else if (layer_type == "Concat")
+        {
+            bool hasVariableInps = false;
+            for (int i = 0; i < node_proto.input_size(); ++i)
+            {
+                if (layer_id.find(node_proto.input(i)) != layer_id.end())
+                {
+                    hasVariableInps = true;
+                    break;
+                }
+            }
+
+            if (!hasVariableInps)
+            {
+                std::vector<Mat> inputs(node_proto.input_size()), concatenated;
+                for (size_t i = 0; i < inputs.size(); ++i)
+                {
+                    inputs[i] = getBlob(node_proto, constBlobs, i);
+                }
+                Ptr<Layer> concat = ConcatLayer::create(layerParams);
+                runLayer(concat, inputs, concatenated);
+
+                CV_Assert(concatenated.size() == 1);
+                constBlobs.insert(std::make_pair(layerParams.name, concatenated[0]));
+                continue;
+            }
+        }
         else
         {
             for (int j = 0; j < node_proto.input_size(); j++) {
@@ -609,12 +719,24 @@ void ONNXImporter::populateNet(Net dstNet)
          int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
          layer_id.insert(std::make_pair(layerParams.name, LayerInfo(id, 0)));
 
+
+         std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
          for (int j = 0; j < node_proto.input_size(); j++) {
              layerId = layer_id.find(node_proto.input(j));
              if (layerId != layer_id.end()) {
                  dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, j);
+                 // Collect input shapes.
+                 shapeIt = outShapes.find(node_proto.input(j));
+                 CV_Assert(shapeIt != outShapes.end());
+                 layerInpShapes.push_back(shapeIt->second);
              }
          }
+
+         // Compute shape of output blob for this layer.
+         Ptr<Layer> layer = dstNet.getLayer(id);
+         layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
+         CV_Assert(!layerOutShapes.empty());
+         outShapes[layerParams.name] = layerOutShapes[0];
      }
  }
 
index 61b06cc..36e7450 100644 (file)
@@ -162,6 +162,10 @@ TEST_P(Test_ONNX_layers, MultyInputs)
     normAssert(ref, out, "", default_l1,  default_lInf);
 }
 
+TEST_P(Test_ONNX_layers, DynamicReshape)
+{
+    testONNXModels("dynamic_reshape");
+}
 
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());