make MatMul support 3D and 4D inputs with broadcast
author	zoom <zhongwl2018@mail.sustech.edu.cn>
	Thu, 15 Dec 2022 02:36:08 +0000 (10:36 +0800)
committer	zoom <zhongwl2018@mail.sustech.edu.cn>
	Thu, 15 Dec 2022 02:36:08 +0000 (10:36 +0800)
modules/dnn/src/cuda4dnn/primitives/matmul.hpp
modules/dnn/src/layers/fully_connected_layer.cpp
modules/dnn/src/onnx/onnx_importer.cpp
modules/dnn/test/test_onnx_importer.cpp

modules/dnn/src/cuda4dnn/primitives/matmul.hpp
index e29036d5f48e2374b00c19ec9a219122ec916a0e..e4ab3d272115fd4d371057ab54bba06f1547372c 100644
@@ -23,9 +23,14 @@ namespace cv { namespace dnn { namespace cuda4dnn {
     public:
         using wrapper_type = GetCUDABackendWrapperType<T>;
 
-        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
+        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp)
             : stream(std::move(stream_)), cublasHandle(std::move(handle))
         {
+            if (!constInp.empty())
+            {
+                constTensor = csl::makeTensorHeader<T>(constInp);
+                csl::copyMatToTensor<T>(constInp, constTensor, stream);
+            }
         }
 
         void forward(
@@ -33,13 +38,20 @@ namespace cv { namespace dnn { namespace cuda4dnn {
             const std::vector<cv::Ptr<BackendWrapper>>& outputs,
             csl::Workspace& workspace) override
         {
-            CV_Assert(inputs.size() == 2 && outputs.size() == 1);
+            CV_Assert(((inputs.size() == 2 && constTensor.empty()) ||
+                       (inputs.size() == 1 && !constTensor.empty())) && outputs.size() == 1);
 
             auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
             auto input1 = input1_wrapper->getView();
 
-            auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
-            auto input2 = input2_wrapper->getView();
+            csl::TensorView<T> input2;
+            if (constTensor.empty())
+            {
+                auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
+                input2 = input2_wrapper->getView();
+            }
+            else
+                input2 = csl::TensorView<T>(constTensor);
 
             auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
             auto output = output_wrapper->getSpan();
@@ -59,9 +71,18 @@ namespace cv { namespace dnn { namespace cuda4dnn {
 
             auto m = input1.get_axis_size(-2);
             auto n = input1.get_axis_size(-1);
-            auto k = input2.get_axis_size(-1);
             auto b = input1.size() / m / n;
-            CV_Assert(input2.get_axis_size(-2) == n);
+            int k;
+            if (constTensor.empty())
+            {
+                k = input2.get_axis_size(-1);
+                CV_Assert(input2.get_axis_size(-2) == n);
+            }
+            else
+            {
+                k = input2.get_axis_size(-2);
+                CV_Assert(input2.get_axis_size(-1) == n);
+            }
             CV_Assert(output.get_axis_size(-2) == m);
             CV_Assert(output.get_axis_size(-1) == k);
 
@@ -70,24 +91,28 @@ namespace cv { namespace dnn { namespace cuda4dnn {
                 CV_Assert(b == 1);
                 CV_Assert(get_effective_rank(input1) <= 2);
                 CV_Assert(get_effective_rank(input2) <= 2);
-                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
             }
             else
             {
                 CV_Assert(rank >= 3);
                 input1.reshape(b, m, n);
-                input2.reshape(b, n, k);
+                if (constTensor.empty())
+                    input2.reshape(b, n, k);
+                else
+                    input2.reshape(b, k, n);
                 output.reshape(b, m, k);
                 input1.squeeze_to(3);
                 input2.squeeze_to(3);
                 output.squeeze_to(3);
-                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
             }
         }
 
     private:
         csl::Stream stream;
         csl::cublas::Handle cublasHandle;
+        csl::Tensor<T> constTensor;
     };
 
 }}} /* namespace cv::dnn::cuda4dnn */
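Note: in the constant-weight path above the initializer arrives pre-transposed (see the importer change below), so k is read from axis -2 and the assert pins axis -1 to n. A minimal standalone sketch of that shape bookkeeping — hypothetical GemmDims/matmulDims names, not OpenCV API:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Mirrors the m/n/k/b computation in MatMulOp::forward when the weight
    // is a constant stored pre-transposed as [..., k, n].
    struct GemmDims { std::size_t b, m, n, k; };

    GemmDims matmulDims(const std::vector<std::size_t>& in1,
                        const std::vector<std::size_t>& w)
    {
        std::size_t m = in1[in1.size() - 2], n = in1[in1.size() - 1];
        std::size_t total = 1;
        for (std::size_t d : in1) total *= d;
        std::size_t k = w[w.size() - 2];
        assert(w[w.size() - 1] == n); // reduced axes must agree
        return {total / m / n, m, n, k};
    }

    int main()
    {
        // input1 is a 2x3x4 batch; the ONNX weight 4x5 was stored as 5x4.
        GemmDims d = matmulDims({2, 3, 4}, {5, 4});
        assert(d.b == 2 && d.m == 3 && d.n == 4 && d.k == 5);
    }

The gemm/gemmStridedBatched calls then pass transB = !constTensor.empty(), so cuBLAS undoes the stored transpose on the fly.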
modules/dnn/src/layers/fully_connected_layer.cpp
index 1be5bbe3661513f2dc3cb2a380b250ca4eef4d9e..321994cbb753c7f44e8514e3390bab3e6dca3669 100644
@@ -85,6 +85,7 @@ public:
 
         bias = params.get<bool>("bias_term", true);
         axis = params.get<int>("axis", 1);
+        isMatMul = params.get<bool>("is_matmul", false);
         if (!blobs.empty())
         {
             CV_Assert(1 <= blobs.size() && blobs.size() <= 2);
@@ -94,6 +95,7 @@ public:
             CV_Assert(blobs[0].dims >= 2 && (size_t)(innerSize * numOutput) == blobs[0].total());
             CV_Assert(!bias || (blobs.size() == 2 && (size_t)numOutput == blobs[1].total()));
 
+            blobs[0].copyTo(oriMat);
             weightsMat = blobs[0] = blobs[0].reshape(1, numOutput);
             int vecsize = weightsMat.cols;
             if (vecsize % VEC_ALIGN != 0)
@@ -108,6 +110,8 @@ public:
 
             if (bias)
                 biasMat = blobs[1] = blobs[1].reshape(1, 1);
+            else if (isMatMul)
+                biasMat = Mat::zeros(1, oriMat.size[oriMat.dims - 2], weightsMat.type());
             else
                 biasMat = Mat::zeros(1, numOutput, weightsMat.type());
         }
@@ -153,7 +157,10 @@ public:
             CV_Assert(!transA && !transB);
             CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
             CV_CheckEQ(blobs[0].dims, 2, "");
-            numOutput = blobs[0].size[0];
+            if (isMatMul)
+                numOutput = oriMat.size[oriMat.dims - 2];
+            else
+                numOutput = blobs[0].size[0];
             CV_Assert(!bias || (size_t)numOutput == blobs[1].total());
             cAxis = normalize_axis(axis, inputsTmp[0]);
         }
@@ -519,16 +526,40 @@ public:
         if (!blobs.empty())
         {
             CV_Assert(!transA && !transB);
-            int axisCan = normalize_axis(axis, input[0].dims);
-            int outerSize = input[0].total(0, axisCan);
+            int inp1Dim = input[0].dims;
+            if (isMatMul)
+            {
+                int matNum = input[0].total(0, inp1Dim - 2);
+                int rowMatMul = oriMat.size[oriMat.dims - 2];
+                Mat srcMatTmp = input[0].reshape(1, matNum);
+                Mat dstMatTmp = output[0].reshape(1, matNum);
+
+                int outerSize = input[0].size[inp1Dim - 2];
+                int rowStart = -rowMatMul;
+                for (int n = 0; n < matNum; ++n)
+                {
+                    Mat srcMat = srcMatTmp.row(n).reshape(1, outerSize);
+                    Mat dstMat = dstMatTmp.row(n).reshape(1, outerSize);
+                    rowStart = (rowStart + rowMatMul) % weightsMat.rows;
+                    Mat weiMat = weightsMat.rowRange(rowStart, rowStart + rowMatMul);
 
-            for (size_t i = 0; i < input.size(); i++)
+                    const int nstripes = getNumThreads();
+                    FullyConnected::run(srcMat, weiMat, biasMat, dstMat, activ.get(), nstripes);
+                }
+            }
+            else
             {
-                Mat srcMat = input[i].reshape(1, outerSize);
-                Mat dstMat = output[i].reshape(1, outerSize);
+                int axisCan = normalize_axis(axis, inp1Dim);
+                int outerSize = input[0].total(0, axisCan);
+
+                for (size_t i = 0; i < input.size(); i++)
+                {
+                    Mat srcMat = input[i].reshape(1, outerSize);
+                    Mat dstMat = output[i].reshape(1, outerSize);
 
-                const int nstripes = getNumThreads();
-                FullyConnected::run(srcMat, weightsMat, biasMat, dstMat, activ.get(), nstripes);
+                    const int nstripes = getNumThreads();
+                    FullyConnected::run(srcMat, weightsMat, biasMat, dstMat, activ.get(), nstripes);
+                }
             }
         }
         else
@@ -579,14 +610,26 @@ public:
     ) override
     {
         auto context = reinterpret_cast<csl::CSLContext*>(context_);
+        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
 
-        if (weightsMat.empty())
+        if (weightsMat.empty() || isMatMul)
         {
             CV_Assert(!bias);
-            return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
+            int inp2Dim;
+            // broadcasting is not supported by the CUDA backend, so the input ranks must match
+            if (weightsMat.empty())
+            {
+                auto input_wrapper2 = inputs[1].dynamicCast<CUDABackendWrapper>();
+                inp2Dim = input_wrapper2->getRank();
+            } else
+                inp2Dim = oriMat.dims;
+
+            if (input_wrapper->getRank() == inp2Dim)
+                return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat);
+            else
+                return Ptr<BackendNode>();
         }
 
-        auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
         auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
         auto biasMat_ = bias ? biasMat : Mat();
         return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
@@ -752,8 +795,9 @@ public:
     }
 
     bool bias;
-    Mat weightsMat, biasMat;
+    Mat weightsMat, biasMat, oriMat;
     bool transA, transB;
+    bool isMatMul = false;
     Ptr<ActivationLayer> activ;
 };
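Note: the CPU forward path broadcasts a smaller weight batch by sliding a rowMatMul-high window over weightsMat and wrapping it modulo weightsMat.rows. A standalone sketch of just that index arithmetic, with made-up sizes:

    #include <cassert>

    int main()
    {
        // 6 input matrices (matNum) against a weight blob holding 2 stacked
        // k x n matrices, i.e. weightsMat has 2 * rowMatMul rows in total.
        const int matNum = 6, rowMatMul = 5, totalWeightRows = 2 * rowMatMul;

        int rowStart = -rowMatMul;
        for (int n = 0; n < matNum; ++n)
        {
            rowStart = (rowStart + rowMatMul) % totalWeightRows;
            // the window [rowStart, rowStart + rowMatMul) cycles 0, 5, 0, 5, ...
            assert(rowStart == (n % 2) * rowMatMul);
        }
        // With a 2D weight (totalWeightRows == rowMatMul) the window stays at
        // row 0: one matrix broadcast over every input slice.
    }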
 
modules/dnn/src/onnx/onnx_importer.cpp
index 39c274e4fa41ef1f92b5ee4bb731671ba7782240..fe4d4660f3ef8e8bcec64936ff5ceba6572500c5 100644
@@ -2088,30 +2088,21 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
     if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
     {
         Mat blob = getBlob(node_proto, 1);
+        Mat transBlob;
         secondInpDims = blob.dims;
-        if (secondInpDims == 2)
-        {
-            layerParams.blobs.push_back(blob.t());
-            layerParams.set("num_output", layerParams.blobs[0].size[0]);
-        }
-        else
-        {
-            LayerParams constParams;
-            constParams.name = layerParams.name + "/const_1";
-            constParams.type = "Const";
-            constParams.blobs.push_back(blob);
-
-            opencv_onnx::NodeProto tmpProto;
-            tmpProto.add_output(constParams.name);
-            addLayer(constParams, tmpProto);
-
-            node_proto.set_input(1, constParams.name);
-        }
-    }
-    else
+        // build an axis order that swaps the last two dimensions
+        std::vector<int> order(secondInpDims);
+        std::iota(order.begin(), order.end(), 0);
+        std::swap(order[secondInpDims - 2], order[secondInpDims - 1]);
+        transposeND(blob, order, transBlob);
+        layerParams.blobs.push_back(transBlob);
+        int numOutput = layerParams.blobs[0].total(0, secondInpDims - 1);
+        layerParams.set("num_output", numOutput);
+        layerParams.set("is_matmul", true);
+    } else
         secondInpDims = outShapes[node_proto.input(1)].size();
 
-    layerParams.set("axis", firstInpDims - secondInpDims + 1);
+    layerParams.set("axis", firstInpDims - 1);
     addLayer(layerParams, node_proto);
 }
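Note: instead of emitting a Const layer for non-2D initializers, parseMatMul now transposes the last two axes of the blob once at import time, for any rank. An illustrative stand-in for what the std::iota/std::swap order plus transposeND amount to on a rank-3 array:

    #include <algorithm>
    #include <cassert>
    #include <numeric>
    #include <vector>

    int main()
    {
        // Build the axis order [0, 2, 1]: identity, then swap the last two.
        const int dims = 3;
        std::vector<int> order(dims);
        std::iota(order.begin(), order.end(), 0);
        std::swap(order[dims - 2], order[dims - 1]);
        assert(order[0] == 0 && order[1] == 2 && order[2] == 1);

        // Apply it to a 2x2x3 array: out[b][j][i] = in[b][i][j].
        const int B = 2, R = 2, C = 3;
        std::vector<int> in(B * R * C), out(B * C * R);
        std::iota(in.begin(), in.end(), 0);
        for (int b = 0; b < B; ++b)
            for (int i = 0; i < R; ++i)
                for (int j = 0; j < C; ++j)
                    out[(b * C + j) * R + i] = in[(b * R + i) * C + j];
        assert(out[(0 * C + 2) * R + 1] == in[(0 * R + 1) * C + 2]);
    }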
 
modules/dnn/test/test_onnx_importer.cpp
index 80db735f4f1de0aab0deb13764683740737492e4..e8350e418d587109a64c1cbdd93e13c73456e16b 100644
@@ -921,6 +921,7 @@ TEST_P(Test_ONNX_layers, MatMul_init)
     testONNXModels("matmul_4d_init");
 
     testONNXModels("matmul_init_2");
+    testONNXModels("matmul_init_bcast");
 }
 
 TEST_P(Test_ONNX_layers, MatMulAdd)
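Note: matmul_init_bcast covers a batched input against a 2D initializer, i.e. C[b] = A[b] * W for every batch index b. A naive reference of that broadcast semantics — hypothetical matmulBcast helper, not the test harness:

    #include <cassert>
    #include <vector>

    // 3D x 2D broadcast: each m x n slice of A is multiplied by the same n x k W.
    std::vector<float> matmulBcast(const std::vector<float>& A, int batch, int m, int n,
                                   const std::vector<float>& W, int k)
    {
        std::vector<float> C(batch * m * k, 0.f);
        for (int b = 0; b < batch; ++b)
            for (int i = 0; i < m; ++i)
                for (int j = 0; j < k; ++j)
                    for (int t = 0; t < n; ++t)
                        C[(b * m + i) * k + j] += A[(b * m + i) * n + t] * W[t * k + j];
        return C;
    }

    int main()
    {
        // Two 1x2 slices times a 2x1 weight -> two 1x1 results.
        std::vector<float> A = {1, 2, 3, 4}, W = {10, 100};
        std::vector<float> C = matmulBcast(A, 2, 1, 2, W, 1);
        assert(C[0] == 210.f && C[1] == 430.f);
    }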