let MatMul can work when both two inputs are const
authorzoom <zhongwl2018@mail.sustech.edu.cn>
Sun, 27 Nov 2022 09:32:41 +0000 (17:32 +0800)
committerzoom <zhongwl2018@mail.sustech.edu.cn>
Sun, 27 Nov 2022 09:32:41 +0000 (17:32 +0800)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index a1a99f6..6ef8894 100644 (file)
@@ -2037,9 +2037,25 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
     CV_Assert(node_proto.input_size() == 2);
     layerParams.type = "InnerProduct";
     layerParams.set("bias_term", false);
-    CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
-    int firstInpDims = outShapes[node_proto.input(0)].size();
-    int secondInpDims;
+    int firstInpDims, secondInpDims;
+
+    if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
+    {
+        Mat blob = getBlob(node_proto, 0);
+        firstInpDims = blob.dims;
+        LayerParams constParams;
+        constParams.name = layerParams.name + "/const_0";
+        constParams.type = "Const";
+        constParams.blobs.push_back(blob);
+
+        opencv_onnx::NodeProto tmpProto;
+        tmpProto.add_output(constParams.name);
+        addLayer(constParams, tmpProto);
+
+        node_proto.set_input(0, constParams.name);
+    }
+    else
+        firstInpDims = outShapes[node_proto.input(0)].size();
 
     if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
     {
@@ -2053,7 +2069,7 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
         else
         {
             LayerParams constParams;
-            constParams.name = layerParams.name + "/const";
+            constParams.name = layerParams.name + "/const_1";
             constParams.type = "Const";
             constParams.blobs.push_back(blob);
 
@@ -2063,9 +2079,10 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
 
             node_proto.set_input(1, constParams.name);
         }
-    } else {
-        secondInpDims = outShapes[node_proto.input(1)].size();
     }
+    else
+        secondInpDims = outShapes[node_proto.input(1)].size();
+
     layerParams.set("axis", firstInpDims - secondInpDims + 1);
     addLayer(layerParams, node_proto);
 }
index be14041..43dc817 100644 (file)
@@ -961,6 +961,8 @@ TEST_P(Test_ONNX_layers, MatMul_init)
     testONNXModels("matmul_2d_init");
     testONNXModels("matmul_3d_init");
     testONNXModels("matmul_4d_init");
+
+    testONNXModels("matmul_init_2");
 }
 
 TEST_P(Test_ONNX_layers, MatMulAdd)