);
}
+ /** @brief Strided batched GEMM for colummn-major matrices
+ *
+ * \f$ C_i = \alpha A_i B_i + \beta C_i \f$ for a stack of matrices A, B and C indexed by i
+ *
+ * @tparam T matrix element type (must be `half` or `float`)
+ *
+ * @param handle valid cuBLAS Handle
+ * @param transa use transposed matrix of A_i for computation
+ * @param transb use transposed matrix of B_i for computation
+ * @param rows_c number of rows in C_i
+ * @param cols_c number of columns in C_i
+ * @param common_dim common dimension of A_i (or trans A_i) and B_i (or trans B_i)
+ * @param alpha scale factor for A_i B_i
+ * @param[in] A pointer to stack of column-major matrices A in device memory
+ * @param lda leading dimension of matrix A_i
+ * @param strideA stride between matrices in A
+ * @param[in] B pointer to stack of column-major matrices B in device memory
+ * @param ldb leading dimension of matrix B_i
+ * @param strideB stride between matrices in B
+ * @param beta scale factor for C_i
+ * @param[in,out] C pointer to stack of column-major matrices C in device memory
+ * @param ldc leading dimension of matrix C_i
+ * @param strideC stride between matrices in C
+ * @param batchCount number of matrices in the batch
+ *
+ * Exception Guarantee: Basic
+ */
+ template <class T>
+ void gemmStridedBatched(const Handle& handle,
+ bool transa, bool transb,
+ std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
+ T alpha, const DevicePtr<const T> A, std::size_t lda, std::size_t strideA,
+ const DevicePtr<const T> B, std::size_t ldb, std::size_t strideB,
+ T beta, const DevicePtr<T> C, std::size_t ldc, std::size_t strideC,
+ std::size_t batchCount);
+
+ template <> inline
+ void gemmStridedBatched<half>(const Handle& handle,
+ bool transa, bool transb,
+ std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
+ half alpha, const DevicePtr<const half> A, std::size_t lda, std::size_t strideA,
+ const DevicePtr<const half> B, std::size_t ldb, std::size_t strideB,
+ half beta, const DevicePtr<half> C, std::size_t ldc, std::size_t strideC,
+ std::size_t batchCount)
+ {
+ CV_Assert(handle);
+
+ const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
+ opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
+ const auto irows_c = static_cast<int>(rows_c),
+ icols_c = static_cast<int>(cols_c),
+ icommon_dim = static_cast<int>(common_dim),
+ ilda = static_cast<int>(lda),
+ ildb = static_cast<int>(ldb),
+ ildc = static_cast<int>(ldc);
+
+ const auto batch_count = static_cast<int>(batchCount);
+ const auto stride_a = static_cast<long long int>(strideA),
+ stride_b = static_cast<long long int>(strideB),
+ stride_c = static_cast<long long int>(strideC);
+
+ CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
+
+ CUDA4DNN_CHECK_CUBLAS(
+ cublasHgemmStridedBatched(
+ handle.get(),
+ opa, opb,
+ irows_c, icols_c, icommon_dim,
+ &alpha, A.get(), ilda, stride_a,
+ B.get(), ildb, stride_b,
+ &beta, C.get(), ildc, stride_c,
+ batch_count
+ )
+ );
+ }
+
+ template <> inline
+ void gemmStridedBatched<float>(const Handle& handle,
+ bool transa, bool transb,
+ std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
+ float alpha, const DevicePtr<const float> A, std::size_t lda, std::size_t strideA,
+ const DevicePtr<const float> B, std::size_t ldb, std::size_t strideB,
+ float beta, const DevicePtr<float> C, std::size_t ldc, std::size_t strideC,
+ std::size_t batchCount)
+ {
+ CV_Assert(handle);
+
+ const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
+ opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
+ const auto irows_c = static_cast<int>(rows_c),
+ icols_c = static_cast<int>(cols_c),
+ icommon_dim = static_cast<int>(common_dim),
+ ilda = static_cast<int>(lda),
+ ildb = static_cast<int>(ldb),
+ ildc = static_cast<int>(ldc);
+
+ const auto batch_count = static_cast<int>(batchCount);
+ const auto stride_a = static_cast<long long int>(strideA),
+ stride_b = static_cast<long long int>(strideB),
+ stride_c = static_cast<long long int>(strideC);
+
+ CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
+
+ CUDA4DNN_CHECK_CUBLAS(
+ cublasSgemmStridedBatched(
+ handle.get(),
+ opa, opb,
+ irows_c, icols_c, icommon_dim,
+ &alpha, A.get(), ilda, stride_a,
+ B.get(), ildb, stride_b,
+ &beta, C.get(), ildc, stride_c,
+ batch_count
+ )
+ );
+ }
+
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */
shape.erase(std::begin(shape) + axis);
}
+ /** @brief squeezes the tensor
+ *
+ * removes leading singleton axes until the tensor's rank is equal to the requested rank
+ *
+ * Pre-conditions:
+ * - the tensor must be non-empty
+ * - the tensor's rank must be at least two
+ * - the tensor's rank must be at least the requested rank
+ * - the tensor must be squeezable up to the requested rank
+ *
+ * Exception Guarantee: Strong
+ */
+ void squeeze_to(int r) {
+ CV_Assert(!empty());
+ CV_Assert(rank() >= r);
+ CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
+ std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
+ shape.resize(r);
+ }
+
/** @brief unsqueezes the tensor
*
* adds a axis of unit size at the requested before the specified axis
shape.erase(std::begin(shape) + axis);
}
+ /** @brief squeezes the tensor
+ *
+ * removes leading singleton axes until the tensor's rank is equal to the requested rank
+ *
+ * Pre-conditions:
+ * - the tensor must be non-empty
+ * - the tensor's rank must be at least two
+ * - the tensor's rank must be at least the requested rank
+ * - the tensor must be squeezable up to the requested rank
+ *
+ * Exception Guarantee: Strong
+ */
+ void squeeze_to(int r) {
+ CV_Assert(!empty());
+ CV_Assert(rank() >= r);
+ CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
+ std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
+ shape.resize(r);
+ }
+
/** @brief unsqueezes the tensor
*
* adds a axis of unit size at the requested before the specified axis
shape.erase(std::begin(shape) + axis);
}
+ /** @brief squeezes the tensor
+ *
+ * removes leading singleton axes until the tensor's rank is equal to the requested rank
+ *
+ * Pre-conditions:
+ * - the tensor must be non-empty
+ * - the tensor's rank must be at least two
+ * - the tensor's rank must be at least the requested rank
+ * - the tensor must be squeezable up to the requested rank
+ *
+ * Exception Guarantee: Strong
+ */
+ void squeeze_to(int r) {
+ CV_Assert(!empty());
+ CV_Assert(rank() >= r);
+ CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
+ std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
+ shape.resize(r);
+ }
+
/** @brief unsqueezes the tensor
*
* adds a axis of unit size at the requested before the specified axis
memcpy(dest.get(), src.get(), dest.size(), stream);
}
+ namespace detail {
+ template <class T>
+ void assertGEMMCompatiblity(const TensorSpan<T>& result, bool transa, const TensorView<T>& A, bool transb, const TensorView<T>& B) {
+ /* check dimension requirements for matrix multiplication */
+ if (!transa && !transb) {
+ CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
+ CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2));
+ CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
+ } else if (!transa && transb) {
+ CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
+ CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1));
+ CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
+ } else if (transa && !transb) {
+ CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
+ CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2));
+ CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
+ } else {
+ CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
+ CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
+ CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
+ }
+ }
+ }
+
/** @brief performs generalized matrix-multiplication
*
* Pre-conditions:
*/
template <class T> inline
void gemm(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
- /* matrix operations can be performed only on rank two or less tensors */
- CV_Assert(get_effective_rank(A) <= 2 &&
- get_effective_rank(B) <= 2 &&
- get_effective_rank(result) <= 2);
-
- /* check dimension requirements for matrix multiplication */
- if (!transa && !transb) {
- CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
- CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2));
- CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
- } else if (!transa && transb) {
- CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
- CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1));
- CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
- } else if (transa && !transb) {
- CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
- CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2));
- CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
- } else {
- CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
- CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
- CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
- }
+ /* matrix operations can be performed only on tensors with rank two or below */
+ CV_Assert(get_effective_rank(A) <= 2);
+ CV_Assert(get_effective_rank(B) <= 2);
+ CV_Assert(get_effective_rank(result) <= 2);
const auto result_nr = result.get_axis_size(-2);
const auto result_nc = result.get_axis_size(-1);
const auto A_nc = A.get_axis_size(-1);
const auto B_nc = B.get_axis_size(-1);
+ detail::assertGEMMCompatiblity(result, transa, A, transb, B);
+
/* tensors are stored in row-major but cublas::gemm operates on column-major matrices
* a row-major matrix when read as column-major matrix gives the transpose of the intended matrix
*
beta, result.get(), result_nc);
}
+ /** @brief performs generalized matrix-multiplication for a strided batch of matrices
+ *
+ * Pre-conditions:
+ * - A, B and C must be rank three tensors with dimensions (batch, rows, cols)
+ * - the last two axes of \p A and \p B must meet the mathematical requirements for matrix multiplication
+ * - \p result must be large enough to hold the result and the matrices must not overlap in memory
+ * - batch dimension should be same in \p A, \p B and \p result
+ *
+ * Exception Guarantee: Basic
+ */
+ template <class T> inline
+ void gemmStridedBatched(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
+ CV_Assert(A.rank() == 3);
+ CV_Assert(B.rank() == 3);
+ CV_Assert(result.rank() == 3);
+
+ const auto batch_size = result.get_axis_size(0);
+ CV_Assert(batch_size == A.get_axis_size(0));
+ CV_Assert(batch_size == B.get_axis_size(0));
+
+ detail::assertGEMMCompatiblity(result, transa, A, transb, B);
+
+ const auto result_nr = result.get_axis_size(-2);
+ const auto result_nc = result.get_axis_size(-1);
+ const auto common_dim = A.get_axis_size(transa ? -2 : -1);
+ const auto A_nc = A.get_axis_size(-1);
+ const auto B_nc = B.get_axis_size(-1);
+
+ std::size_t strideA = (A.size() / batch_size),
+ strideB = (B.size() / batch_size),
+ strideC = (result.size() / batch_size);
+
+ cublas::gemmStridedBatched<T>(handle,
+ transb, transa,
+ result_nc, result_nr, common_dim,
+ alpha, B.get(), B_nc, strideB,
+ A.get(), A_nc, strideA,
+ beta, result.get(), result_nc, strideC,
+ batch_size);
+ }
+
/** @brief performs element-wise addition with broadcasting
*
* Pre-conditions:
--- /dev/null
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
+
+#include "../../op_cuda.hpp"
+
+#include "../csl/stream.hpp"
+#include "../csl/cublas.hpp"
+#include "../csl/tensor.hpp"
+#include "../csl/tensor_ops.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <utility>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+ template <class T>
+ class MatMulOp final : public CUDABackendNode {
+ public:
+ using wrapper_type = GetCUDABackendWrapperType<T>;
+
+ MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
+ : stream(std::move(stream_)), cublasHandle(std::move(handle))
+ {
+ }
+
+ void forward(
+ const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+ const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+ csl::Workspace& workspace) override
+ {
+ CV_Assert(inputs.size() == 2 && outputs.size() == 1);
+
+ auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
+ auto input1 = input1_wrapper->getView();
+
+ auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
+ auto input2 = input2_wrapper->getView();
+
+ auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
+ auto output = output_wrapper->getSpan();
+
+ auto rank = output.rank();
+ CV_Assert(rank == input1.rank());
+ CV_Assert(rank == input2.rank());
+ CV_Assert(rank >= 2); // 1D MatMul not supported
+
+ for (int i = 0; i < rank - 2; i++)
+ {
+ // broadcasting not supported
+ auto size = output.get_axis_size(i);
+ CV_Assert(input1.get_axis_size(i) == size);
+ CV_Assert(input2.get_axis_size(i) == size);
+ }
+
+ auto m = input1.get_axis_size(-2);
+ auto n = input1.get_axis_size(-1);
+ auto k = input2.get_axis_size(-1);
+ auto b = input1.size() / m / n;
+ CV_Assert(input2.get_axis_size(-2) == n);
+ CV_Assert(output.get_axis_size(-2) == m);
+ CV_Assert(output.get_axis_size(-1) == k);
+
+ if (get_effective_rank(output) <= 2)
+ {
+ CV_Assert(b == 1);
+ CV_Assert(get_effective_rank(input1) <= 2);
+ CV_Assert(get_effective_rank(input2) <= 2);
+ csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+ }
+ else
+ {
+ CV_Assert(rank >= 3);
+ input1.reshape(b, m, n);
+ input2.reshape(b, n, k);
+ output.reshape(b, m, k);
+ input1.squeeze_to(3);
+ input2.squeeze_to(3);
+ output.squeeze_to(3);
+ csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+ }
+ }
+
+ private:
+ csl::Stream stream;
+ csl::cublas::Handle cublasHandle;
+ };
+
+}}} /* namespace cv::dnn::cuda4dnn */
+
+#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP */
#endif
#ifdef HAVE_CUDA
+#include "../cuda4dnn/primitives/matmul.hpp"
#include "../cuda4dnn/primitives/inner_product.hpp"
using namespace cv::dnn::cuda4dnn;
#endif
{
auto context = reinterpret_cast<csl::CSLContext*>(context_);
- auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
+ if (weightsMat.empty())
+ {
+ CV_Assert(!bias);
+ return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
+ }
+ auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
-
auto biasMat_ = bias ? biasMat : Mat();
return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
}
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
- if (backend == DNN_BACKEND_CUDA)
- applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); // not supported
testONNXModels("matmul_2d");
testONNXModels("matmul_3d");
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2020040000)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
#endif
- if (backend == DNN_BACKEND_CUDA)
- applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
testONNXModels("matmul_with_two_inputs");
}