gemm support transA and transB, and first input is constance.
authorzihaomu <zihaomu@outlook.com>
Tue, 29 Nov 2022 09:13:36 +0000 (17:13 +0800)
committerzihaomu <zihaomu@outlook.com>
Tue, 29 Nov 2022 09:13:36 +0000 (17:13 +0800)
modules/dnn/src/layers/fully_connected_layer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

index 71ca706..1be5bbe 100644 (file)
@@ -80,6 +80,9 @@ public:
     FullyConnectedLayerImpl(const LayerParams& params)
     {
         setParamsFrom(params);
+        transA = params.get<bool>("transA", false);
+        transB = params.get<bool>("transB", false);
+
         bias = params.get<bool>("bias_term", true);
         axis = params.get<int>("axis", 1);
         if (!blobs.empty())
@@ -116,30 +119,48 @@ public:
                          std::vector<MatShape> &) const CV_OVERRIDE
     {
         int numOutput, cAxis;
+
+        std::vector<MatShape> inputsTmp;
+        inputsTmp.assign(inputs.begin(), inputs.end());
+
         if (blobs.empty())
         {
-            CV_CheckEQ(inputs.size(), (size_t)2, "");
-            numOutput = inputs[1].back();
-            cAxis = inputs[0].size() - 1;
-            int dims = inputs[0].size();
-            CV_CheckEQ(inputs[1].size(), (size_t)dims, "");
+            CV_CheckEQ(inputsTmp.size(), (size_t)2, "");
+
+            if (transA)
+            {
+                CV_CheckEQ(inputsTmp[0].size(), (size_t)2, "");
+                std::swap(inputsTmp[0][0], inputsTmp[0][1]);
+            }
+
+            if (transB)
+            {
+                CV_CheckEQ(inputsTmp[1].size(), (size_t)2, "");
+                std::swap(inputsTmp[1][0], inputsTmp[1][1]);
+            }
+
+            numOutput = inputsTmp[1].back();
+            cAxis = inputsTmp[0].size() - 1;
+            int dims = inputsTmp[0].size();
+            CV_CheckEQ(inputsTmp[1].size(), (size_t)dims, "");
             CV_CheckGE(dims, 2, "");
             for (int i = 0; i < dims - 2; i++)
-                CV_CheckEQ(inputs[0][i], inputs[1][i], "");
-            CV_CheckEQ(inputs[0].back(), inputs[1][dims - 2], "");
+                CV_CheckEQ(inputsTmp[0][i], inputsTmp[1][i], "");
+            CV_CheckEQ(inputsTmp[0].back(), inputsTmp[1][dims - 2], "");
         }
         else
         {
-            CV_CheckEQ(inputs.size(), (size_t)1, "");
+            CV_Assert(!transA && !transB);
+            CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
             CV_CheckEQ(blobs[0].dims, 2, "");
             numOutput = blobs[0].size[0];
             CV_Assert(!bias || (size_t)numOutput == blobs[1].total());
-            cAxis = normalize_axis(axis, inputs[0]);
+            cAxis = normalize_axis(axis, inputsTmp[0]);
         }
 
         MatShape outShape(cAxis + 1);
         for (int i = 0; i < cAxis; ++i)
-            outShape[i] = inputs[0][i];
+            outShape[i] = inputsTmp[0][i];
         outShape.back() = numOutput;
 
         outputs.resize(1, outShape);
@@ -148,15 +169,15 @@ public:
 
     virtual bool supportBackend(int backendId) CV_OVERRIDE
     {
+        bool tranAorB = transA || transB;
 #ifdef HAVE_INF_ENGINE
         if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
-            return axis == 1;
+            return axis == 1 && !tranAorB;
 #endif
-
         return backendId == DNN_BACKEND_OPENCV ||
-               backendId == DNN_BACKEND_CUDA ||
-               (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1) ||
-               (backendId == DNN_BACKEND_WEBNN && axis == 1);
+               (backendId == DNN_BACKEND_CUDA && !tranAorB) ||
+               (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !tranAorB) ||
+               (backendId == DNN_BACKEND_WEBNN && axis == 1 && !tranAorB);
     }
 
     virtual bool setActivation(const Ptr<ActivationLayer>& layer) CV_OVERRIDE
@@ -497,6 +518,7 @@ public:
 
         if (!blobs.empty())
         {
+            CV_Assert(!transA && !transB);
             int axisCan = normalize_axis(axis, input[0].dims);
             int outerSize = input[0].total(0, axisCan);
 
@@ -511,15 +533,30 @@ public:
         }
         else
         {
-            float* inpData = input[0].ptr<float>();
-            float* weightData = input[1].ptr<float>();
+            Mat input0 = input[0];
+            Mat input1 = input[1];
+
+            if (transA)
+            {
+                CV_Assert(input0.dims == 2);
+                input0 = input0.t();
+            }
+
+            if (transB)
+            {
+                CV_Assert(input1.dims == 2);
+                input1 = input1.t();
+            }
+
+            float* inpData = input0.ptr<float>();
+            float* weightData = input1.ptr<float>();
             float* outData = output[0].ptr<float>();
 
             int dims = output[0].dims;
             int numSlice = output[0].total() / output[0].total(dims - 2);
-            int m = input[0].size[dims - 2];
-            int n = input[0].size[dims - 1];
-            int k = input[1].size[dims - 1];
+            int m = input0.size[dims - 2];
+            int n = input0.size[dims - 1];
+            int k = input1.size[dims - 1];
             for (int i = 0; i < numSlice; i++)
             {
                 Mat inpSlice(m, n, CV_32F, inpData);
@@ -716,6 +753,7 @@ public:
 
     bool bias;
     Mat weightsMat, biasMat;
+    bool transA, transB;
     Ptr<ActivationLayer> activ;
 };
 
index 6ef8894..c626a99 100644 (file)
@@ -2000,18 +2000,9 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
 {
     CV_Assert(node_proto.input_size() >= 2);
     layerParams.type = "InnerProduct";
-    Mat weights = getBlob(node_proto, 1);
+    int transA = layerParams.get<int>("transA", 0);
+    layerParams.set("transA", transA == 1);
 
-    if (!layerParams.get<int>("transB", 0))
-    {
-        transpose(weights, weights);
-    }
-    layerParams.blobs.push_back(weights);
-
-    if (node_proto.input_size() == 3) {
-        Mat bias = getBlob(node_proto, 2);
-        layerParams.blobs.push_back(bias);
-    }
     if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
     {
         Mat inputBuf = getBlob(node_proto, 0);
@@ -2026,7 +2017,43 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
         addLayer(constParams, proto);
     }
 
-    layerParams.set("num_output", layerParams.blobs[0].size[0]);
+    int transB = layerParams.get<int>("transB", 0);
+    if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
+    {
+        Mat weights = getBlob(node_proto, 1);
+
+        if (transA == 0) // optimized barnch, for now, we can only optimize the Gemm when transA = 0.
+        {
+            if (transB == 0)
+            {
+                transpose(weights, weights);
+            }
+            layerParams.set("transB", false);
+            layerParams.blobs.push_back(weights);
+            layerParams.set("num_output", layerParams.blobs[0].size[0]);
+        }
+        else // no optimized branch, TODO! optimize when the transA==1.
+        {
+            LayerParams constParams;
+            constParams.name = node_proto.input(1);
+            constParams.type = "Const";
+            constParams.blobs.push_back(weights);
+
+            opencv_onnx::NodeProto proto;
+            proto.add_output(constParams.name);
+            addLayer(constParams, proto);
+            layerParams.set("transB", transB == 1);
+        }
+    }
+    else
+        layerParams.set("transB", transB == 1);
+
+    if (node_proto.input_size() == 3)
+    {
+        Mat bias = getBlob(node_proto, 2);
+        layerParams.blobs.push_back(bias);
+    }
+
     layerParams.set("bias_term", node_proto.input_size() == 3);
     addLayer(layerParams, node_proto);
 }
index 43dc817..cee7cf0 100644 (file)
@@ -1829,6 +1829,7 @@ TEST_P(Test_ONNX_layers, Gemm)
 {
     testONNXModels("gemm_no_transB");
     testONNXModels("gemm_transB_0");
+    testONNXModels("gemm_first_const");
 }
 
 TEST_P(Test_ONNX_layers, Quantized_Convolution)