Fix LSTM from ONNX with batch==1
authorDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 17 Mar 2020 21:00:24 +0000 (00:00 +0300)
committerDmitry Kurtaev <dmitry.kurtaev+github@gmail.com>
Tue, 17 Mar 2020 21:00:24 +0000 (00:00 +0300)
modules/dnn/src/layers/recurrent_layers.cpp
modules/dnn/src/onnx/onnx_importer.cpp

index 3f9a229..26d2ea9 100644 (file)
@@ -110,10 +110,11 @@ public:
             const Mat& Wh = blobs[0];
             const Mat& Wx = blobs[1];
             const Mat& bias = blobs[2];
-            CV_Assert(Wh.dims == 2 && Wx.dims == 2);
-            CV_Assert(Wh.rows == Wx.rows);
-            CV_Assert(Wh.rows == 4*Wh.cols);
-            CV_Assert(Wh.rows == (int)bias.total());
+            CV_CheckEQ(Wh.dims, 2, "");
+            CV_CheckEQ(Wx.dims, 2, "");
+            CV_CheckEQ(Wh.rows, Wx.rows, "");
+            CV_CheckEQ(Wh.rows, 4*Wh.cols, "");
+            CV_CheckEQ(Wh.rows, (int)bias.total(), "");
             CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
 
             // Peephole weights.
index 2bcba9e..b243a98 100644 (file)
@@ -49,6 +49,11 @@ class ONNXImporter
     LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto);
     bool isCeilMode(const LayerParams& layerParams);
 
+    void addLayer(Net& dstNet, LayerParams& layerParams,
+                  const opencv_onnx::NodeProto& node_proto,
+                  std::map<std::string, LayerInfo>& layer_id,
+                  std::map<std::string, MatShape>& outShapes);
+
 public:
 
     ONNXImporter(const char *onnxFile)
@@ -259,6 +264,42 @@ Mat ONNXImporter::getBlob(const opencv_onnx::NodeProto& node_proto,
     return constBlob->second;
 }
 
+void ONNXImporter::addLayer(Net& dstNet, LayerParams& layerParams,
+                            const opencv_onnx::NodeProto& node_proto,
+                            std::map<std::string, LayerInfo>& layer_id,
+                            std::map<std::string, MatShape>& outShapes)
+{
+    std::map<std::string, LayerInfo>::iterator layerId;
+    std::map<std::string, MatShape>::iterator shapeIt;
+
+    int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
+    for (int i = 0; i < node_proto.output_size(); ++i)
+    {
+        layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
+    }
+
+    std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
+    int inpNum = 0;
+    for (int j = 0; j < node_proto.input_size(); j++) {
+        layerId = layer_id.find(node_proto.input(j));
+        if (layerId != layer_id.end()) {
+            dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
+            ++inpNum;
+            // Collect input shapes.
+            shapeIt = outShapes.find(node_proto.input(j));
+            CV_Assert(shapeIt != outShapes.end());
+            layerInpShapes.push_back(shapeIt->second);
+        }
+    }
+    // Compute shape of output blob for this layer.
+    Ptr<Layer> layer = dstNet.getLayer(id);
+    layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
+    for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
+    {
+        outShapes[node_proto.output(i)] = layerOutShapes[i];
+    }
+}
+
 void ONNXImporter::populateNet(Net dstNet)
 {
     CV_Assert(model_proto.has_graph());
@@ -581,13 +622,16 @@ void ONNXImporter::populateNet(Net dstNet)
         }
         else if (layer_type == "LSTM")
         {
+            LayerParams lstmParams = layerParams;
+            lstmParams.name += "/lstm";
+
             // https://pytorch.org/docs/stable/nn.html#lstm
             CV_Assert(node_proto.input_size() == 7);
             Mat Wx = getBlob(node_proto, constBlobs, 1);
             Mat Wh = getBlob(node_proto, constBlobs, 2);
             Mat b = getBlob(node_proto, constBlobs, 3);
 
-            const int numHidden = Wh.size[2];
+            const int numHidden = lstmParams.get<int>("hidden_size");
 
             Wx = Wx.reshape(1, Wx.size[1]);
             Wh = Wh.reshape(1, Wh.size[1]);
@@ -612,10 +656,24 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
                 std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
             }
-            layerParams.blobs.resize(3);
-            layerParams.blobs[0] = Wh;
-            layerParams.blobs[1] = Wx;
-            layerParams.blobs[2] = b;
+
+            lstmParams.blobs.resize(3);
+            lstmParams.blobs[0] = Wh;
+            lstmParams.blobs[1] = Wx;
+            lstmParams.blobs[2] = b;
+
+            node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
+            addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
+
+            MatShape lstmShape = outShapes[node_proto.output(0)];
+
+            // Add fake 1 as it is done in ONNX
+            lstmShape.insert(lstmShape.begin() + 1, 1);
+
+            layerParams.type = "Reshape";
+            layerParams.set("dim", DictValue::arrayInt(&lstmShape[0], lstmShape.size()));
+            node_proto.set_input(0, lstmParams.name);  // redirect input to LSTM
+            node_proto.set_output(0, layerParams.name);  // keep origin LSTM's name
         }
         else if (layer_type == "ImageScaler")
         {
@@ -1228,34 +1286,7 @@ void ONNXImporter::populateNet(Net dstNet)
                     layerParams.blobs.push_back(getBlob(node_proto, constBlobs, j));
             }
         }
-
-        int id = dstNet.addLayer(layerParams.name, layerParams.type, layerParams);
-        for (int i = 0; i < node_proto.output_size(); ++i)
-        {
-            layer_id.insert(std::make_pair(node_proto.output(i), LayerInfo(id, i)));
-        }
-
-        std::vector<MatShape> layerInpShapes, layerOutShapes, layerInternalShapes;
-        int inpNum = 0;
-        for (int j = 0; j < node_proto.input_size(); j++) {
-            layerId = layer_id.find(node_proto.input(j));
-            if (layerId != layer_id.end()) {
-                dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, inpNum);
-                ++inpNum;
-                // Collect input shapes.
-                shapeIt = outShapes.find(node_proto.input(j));
-                CV_Assert(shapeIt != outShapes.end());
-                layerInpShapes.push_back(shapeIt->second);
-            }
-        }
-
-        // Compute shape of output blob for this layer.
-        Ptr<Layer> layer = dstNet.getLayer(id);
-        layer->getMemoryShapes(layerInpShapes, 0, layerOutShapes, layerInternalShapes);
-        for (int i = 0; i < node_proto.output_size() && i < (int)layerOutShapes.size(); ++i)
-        {
-            outShapes[node_proto.output(i)] = layerOutShapes[i];
-        }
+        addLayer(dstNet, layerParams, node_proto, layer_id, outShapes);
     }
 }