From 2411b825b45d91440b25cce5d0b99dcf115cc8d8 Mon Sep 17 00:00:00 2001 From: Zihao Mu Date: Wed, 22 Jun 2022 15:00:17 +0800 Subject: [PATCH] bug fixed of GEMM node in ONNX_importer --- modules/dnn/src/onnx/onnx_importer.cpp | 10 ++++++---- modules/dnn/test/test_onnx_importer.cpp | 5 +++++ 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index f06ff32..7e035b3 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -2080,15 +2080,17 @@ void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const openc addLayer(layerParams, node_proto); } +// A * B + C = Y, we require that the dimension of A is [m, k], and the dimension of B is [n, k]. +// And the dim of output Y is [m, n] void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { CV_Assert(node_proto.input_size() >= 2); layerParams.type = "InnerProduct"; Mat weights = getBlob(node_proto, 1); - int ind_num_out = 0; - if (layerParams.has("transB") && !layerParams.get("transB")) { + + if (!layerParams.get("transB", 0)) + { transpose(weights, weights); - ind_num_out = 1; } layerParams.blobs.push_back(weights); @@ -2110,7 +2112,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr addLayer(constParams, proto); } - layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]); + layerParams.set("num_output", layerParams.blobs[0].size[0]); layerParams.set("bias_term", node_proto.input_size() == 3); addLayer(layerParams, node_proto); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 50540cd..bf59cbb 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1746,6 +1746,11 @@ TEST_P(Test_ONNX_layers, DivConst) testONNXModels("div_const"); } +TEST_P(Test_ONNX_layers, Gemm) +{ + testONNXModels("gemm_no_transB"); + testONNXModels("gemm_transB_0"); +} TEST_P(Test_ONNX_layers, Quantized_Convolution) { -- 2.7.4