bug fixed of GEMM node in ONNX_importer
authorZihao Mu <zihaomu@outlook.com>
Wed, 22 Jun 2022 13:08:48 +0000 (21:08 +0800)
committerZihao Mu <zihaomu@outlook.com>
Wed, 22 Jun 2022 13:08:48 +0000 (21:08 +0800)
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index e4cbe02..7a0532f 100644 (file)
@@ -1759,15 +1759,15 @@ void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const openc
     addLayer(layerParams, node_proto);
 }
 
+// A * B + C = Y, we require that the dimension of A is [m, k], and the dimension of B is [n, k].
+// And the dim of output Y is [m, n]
 void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
     CV_Assert(node_proto.input_size() >= 2);
     layerParams.type = "InnerProduct";
     Mat weights = getBlob(node_proto, 1);
-    int ind_num_out = 0;
-    if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
+    if (!layerParams.get<int>("transB", 0)) {
         transpose(weights, weights);
-        ind_num_out = 1;
     }
     layerParams.blobs.push_back(weights);
 
@@ -1789,7 +1789,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
         addLayer(constParams, proto);
     }
 
-    layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
+    layerParams.set("num_output", layerParams.blobs[0].size[0]);
     layerParams.set("bias_term", node_proto.input_size() == 3);
     addLayer(layerParams, node_proto);
 }
index 60473ed..56203cb 100644 (file)
@@ -1389,6 +1389,12 @@ TEST_P(Test_ONNX_layers, DivConst)
     testONNXModels("div_const");
 }
 
+TEST_P(Test_ONNX_layers, Gemm)
+{
+    testONNXModels("gemm_no_transB");
+    testONNXModels("gemm_transB_0");
+}
+
 TEST_P(Test_ONNX_layers, OutputRegistration)
 {
     testONNXModels("output_registration", npy, 0, 0, false, true, 2);