Make MatMul layer support 3d or 4d operation with const input
authorzoom <zhongwl2018@mail.sustech.edu.cn>
Wed, 9 Nov 2022 08:23:42 +0000 (16:23 +0800)
committerzoom <zhongwl2018@mail.sustech.edu.cn>
Thu, 10 Nov 2022 03:41:44 +0000 (11:41 +0800)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 66f524a..a1a99f6 100644 (file)
@@ -2031,8 +2031,9 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
     addLayer(layerParams, node_proto);
 }
 
-void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
 {
+    opencv_onnx::NodeProto node_proto = node_proto_;
     CV_Assert(node_proto.input_size() == 2);
     layerParams.type = "InnerProduct";
     layerParams.set("bias_term", false);
@@ -2044,8 +2045,24 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
     {
         Mat blob = getBlob(node_proto, 1);
         secondInpDims = blob.dims;
-        layerParams.blobs.push_back(blob.t());
-        layerParams.set("num_output", layerParams.blobs[0].size[0]);
+        if (secondInpDims == 2)
+        {
+            layerParams.blobs.push_back(blob.t());
+            layerParams.set("num_output", layerParams.blobs[0].size[0]);
+        }
+        else
+        {
+            LayerParams constParams;
+            constParams.name = layerParams.name + "/const";
+            constParams.type = "Const";
+            constParams.blobs.push_back(blob);
+
+            opencv_onnx::NodeProto tmpProto;
+            tmpProto.add_output(constParams.name);
+            addLayer(constParams, tmpProto);
+
+            node_proto.set_input(1, constParams.name);
+        }
     } else {
         secondInpDims = outShapes[node_proto.input(1)].size();
     }
index b310dce..be14041 100644 (file)
@@ -956,6 +956,13 @@ TEST_P(Test_ONNX_layers, MatMul)
     testONNXModels("matmul_4d");
 }
 
+TEST_P(Test_ONNX_layers, MatMul_init)
+{
+    testONNXModels("matmul_2d_init");
+    testONNXModels("matmul_3d_init");
+    testONNXModels("matmul_4d_init");
+}
+
 TEST_P(Test_ONNX_layers, MatMulAdd)
 {
 #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_EQ(2022010000)