From 2f34bb9aa0c24c08302e168a15dfa0a6ec4933cc Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Sat, 14 Dec 2013 23:16:53 +0400 Subject: [PATCH] added cv::gemm to T-API --- modules/core/include/opencv2/core/ocl.hpp | 1 + modules/core/src/matmul.cpp | 92 ++++++++++++++++++ modules/core/src/ocl.cpp | 104 +++++++++++++++++---- modules/core/test/ocl/test_gemm.cpp | 149 ++++++++++++++++++++++++++++++ 4 files changed, 329 insertions(+), 17 deletions(-) create mode 100644 modules/core/test/ocl/test_gemm.cpp diff --git a/modules/core/include/opencv2/core/ocl.hpp b/modules/core/include/opencv2/core/ocl.hpp index f253594..2a90b85 100644 --- a/modules/core/include/opencv2/core/ocl.hpp +++ b/modules/core/include/opencv2/core/ocl.hpp @@ -48,6 +48,7 @@ namespace cv { namespace ocl { CV_EXPORTS bool haveOpenCL(); CV_EXPORTS bool useOpenCL(); +CV_EXPORTS bool haveAmdBlas(); CV_EXPORTS void setUseOpenCL(bool flag); CV_EXPORTS void finish2(); diff --git a/modules/core/src/matmul.cpp b/modules/core/src/matmul.cpp index 7d832cb..16eb6e0 100644 --- a/modules/core/src/matmul.cpp +++ b/modules/core/src/matmul.cpp @@ -41,6 +41,7 @@ //M*/ #include "precomp.hpp" +#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp" #ifdef HAVE_IPP #include "ippversion.h" @@ -693,11 +694,102 @@ static void GEMMStore_64fc( const Complexd* c_data, size_t c_step, GEMMStore(c_data, c_step, d_buf, d_buf_step, d_data, d_step, d_size, alpha, beta, flags); } +#ifdef HAVE_CLAMDBLAS + +static bool ocl_gemm( InputArray matA, InputArray matB, double alpha, + InputArray matC, double beta, OutputArray matD, int flags ) +{ + int type = matA.type(), esz = CV_ELEM_SIZE(type); + bool haveC = matC.kind() != cv::_InputArray::NONE; + Size sizeA = matA.size(), sizeB = matB.size(), sizeC = haveC ? matC.size() : Size(0, 0); + bool atrans = (flags & GEMM_1_T) != 0, btrans = (flags & GEMM_2_T) != 0, ctrans = (flags & GEMM_3_T) != 0; + + if (atrans) + sizeA = Size(sizeA.height, sizeA.width); + if (btrans) + sizeB = Size(sizeB.height, sizeB.width); + if (haveC && ctrans) + sizeC = Size(sizeC.height, sizeC.width); + + Size sizeD(sizeB.width, sizeA.height); + + CV_Assert( matB.type() == type && (!haveC || matC.type() == type) ); + CV_Assert( sizeA.width == sizeB.height && (!haveC || sizeC == sizeD) ); + + matD.create(sizeD, type); + if ( matA.offset() % esz != 0 || matA.step() % esz != 0 || + matB.offset() % esz != 0 || matB.step() % esz != 0 || + (haveC && (matC.offset() % esz != 0 || matC.step() % esz != 0)) ) + return false; + + UMat A = matA.getUMat(), B = matB.getUMat(), D = matD.getUMat(); + if (haveC) + ctrans ? transpose(matC, D) : matC.getMat().copyTo(D); // TODO fix it as soon as .copyTo works as expected + else + D.setTo(Scalar::all(0)); + + int M = sizeD.height, N = sizeD.width, K = sizeA.width; + int lda = (int)A.step / esz, ldb = (int)B.step / esz, ldc = (int)D.step / esz; + int offa = (int)A.offset / esz, offb = (int)B.offset / esz, offc = (int)D.offset / esz; + + cl_command_queue clq = (cl_command_queue)ocl::Queue::getDefault().ptr(); + clAmdBlasTranspose transA = atrans ? clAmdBlasTrans : clAmdBlasNoTrans; + clAmdBlasTranspose transB = btrans ? clAmdBlasTrans : clAmdBlasNoTrans; + clAmdBlasOrder order = clAmdBlasRowMajor; + clAmdBlasStatus status = clAmdBlasSuccess; + + if (type == CV_32FC1) + status = clAmdBlasSgemmEx(order, transA, transB, M, N, K, + (cl_float)alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda, + (const cl_mem)B.handle(ACCESS_READ), offb, ldb, + (cl_float)beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc, + 1, &clq, 0, NULL, NULL); + else if (type == CV_64FC1) + status = clAmdBlasDgemmEx(order, transA, transB, M, N, K, + alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda, + (const cl_mem)B.handle(ACCESS_READ), offb, ldb, + beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc, + 1, &clq, 0, NULL, NULL); + else if (type == CV_32FC2) + { + cl_float2 alpha_2 = { { (cl_float)alpha, 0 } }; + cl_float2 beta_2 = { { (cl_float)beta, 0 } }; + status = clAmdBlasCgemmEx(order, transA, transB, M, N, K, + alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda, + (const cl_mem)B.handle(ACCESS_READ), offb, ldb, + beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc, + 1, &clq, 0, NULL, NULL); + } + else if (type == CV_64FC2) + { + cl_double2 alpha_2 = { { alpha, 0 } }; + cl_double2 beta_2 = { { beta, 0 } }; + status = clAmdBlasZgemmEx(order, transA, transB, M, N, K, + alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda, + (const cl_mem)B.handle(ACCESS_READ), offb, ldb, + beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc, + 1, &clq, 0, NULL, NULL); + } + else + CV_Error(Error::StsUnsupportedFormat, ""); + + return status == clAmdBlasSuccess; +} + +#endif + } void cv::gemm( InputArray matA, InputArray matB, double alpha, InputArray matC, double beta, OutputArray _matD, int flags ) { +#ifdef HAVE_CLAMDBLAS + if (ocl::haveAmdBlas() && matA.dims() <= 2 && matB.dims() <= 2 && matC.dims() <= 2 && + ocl::useOpenCL() && _matD.isUMat() && + ocl_gemm(matA, matB, alpha, matC, beta, _matD, flags)) + return; +#endif + const int block_lin_size = 128; const int block_size = block_lin_size * block_lin_size; diff --git a/modules/core/src/ocl.cpp b/modules/core/src/ocl.cpp index 6681e81..835cd2a 100644 --- a/modules/core/src/ocl.cpp +++ b/modules/core/src/ocl.cpp @@ -42,6 +42,8 @@ #include "precomp.hpp" #include +#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp" + #ifdef HAVE_OPENCL #include "opencv2/core/opencl/runtime/opencl_core.hpp" #else @@ -1309,29 +1311,23 @@ inline bool operator < (const HashKey& h1, const HashKey& h2) return h1.a < h2.a || (h1.a == h2.a && h1.b < h2.b); } -static bool g_isInitialized = false; +static bool g_isOpenCLInitialized = false; static bool g_isOpenCLAvailable = false; + bool haveOpenCL() { - if (!g_isInitialized) + if (!g_isOpenCLInitialized) { - if (!g_isInitialized) + try { - try - { - cl_uint n = 0; - cl_int err = ::clGetPlatformIDs(0, NULL, &n); - if (err != CL_SUCCESS) - g_isOpenCLAvailable = false; - else - g_isOpenCLAvailable = true; - } - catch (...) - { - g_isOpenCLAvailable = false; - } - g_isInitialized = true; + cl_uint n = 0; + g_isOpenCLAvailable = ::clGetPlatformIDs(0, NULL, &n) == CL_SUCCESS; + } + catch (...) + { + g_isOpenCLAvailable = false; } + g_isOpenCLInitialized = true; } return g_isOpenCLAvailable; } @@ -1353,6 +1349,80 @@ void setUseOpenCL(bool flag) } } +#ifdef HAVE_CLAMDBLAS + +class AmdBlasHelper +{ +public: + static AmdBlasHelper & getInstance() + { + static AmdBlasHelper amdBlas; + return amdBlas; + } + + bool isAvailable() const + { + return g_isAmdBlasAvailable; + } + + ~AmdBlasHelper() + { + try + { + clAmdBlasTeardown(); + } + catch (...) { } + } + +protected: + AmdBlasHelper() + { + if (!g_isAmdBlasInitialized) + { + AutoLock lock(m); + + if (!g_isAmdBlasInitialized && haveOpenCL()) + { + try + { + g_isAmdBlasAvailable = clAmdBlasSetup() == clAmdBlasSuccess; + } + catch (...) + { + g_isAmdBlasAvailable = false; + } + } + else + g_isAmdBlasAvailable = false; + + g_isAmdBlasInitialized = true; + } + } + +private: + static Mutex m; + static bool g_isAmdBlasInitialized; + static bool g_isAmdBlasAvailable; +}; + +bool AmdBlasHelper::g_isAmdBlasAvailable = false; +bool AmdBlasHelper::g_isAmdBlasInitialized = false; +Mutex AmdBlasHelper::m; + +bool haveAmdBlas() +{ + return AmdBlasHelper::getInstance().isAvailable(); +} + +#else + +bool haveAmdBlas() +{ + return false; +} + +#endif + void finish2() { Queue::getDefault().finish(); diff --git a/modules/core/test/ocl/test_gemm.cpp b/modules/core/test/ocl/test_gemm.cpp new file mode 100644 index 0000000..4d453f3 --- /dev/null +++ b/modules/core/test/ocl/test_gemm.cpp @@ -0,0 +1,149 @@ +/*M/////////////////////////////////////////////////////////////////////////////////////// +// +// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. +// +// By downloading, copying, installing or using the software you agree to this license. +// If you do not agree to this license, do not download, install, +// copy or use the software. +// +// +// License Agreement +// For Open Source Computer Vision Library +// +// Copyright (C) 2010-2012, Multicoreware, Inc., all rights reserved. +// Copyright (C) 2010-2012, Advanced Micro Devices, Inc., all rights reserved. +// Third party copyrights are property of their respective owners. +// +// @Authors +// Peng Xiao, pengxiao@multicorewareinc.com +// Redistribution and use in source and binary forms, with or without modification, +// are permitted provided that the following conditions are met: +// +// * Redistribution's of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistribution's in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * The name of the copyright holders may not be used to endorse or promote products +// derived from this software without specific prior written permission. +// +// This software is provided by the copyright holders and contributors as is and +// any express or implied warranties, including, but not limited to, the implied +// warranties of merchantability and fitness for a particular purpose are disclaimed. +// In no event shall the Intel Corporation or contributors be liable for any direct, +// indirect, incidental, special, exemplary, or consequential damages +// (including, but not limited to, procurement of substitute goods or services; +// loss of use, data, or profits; or business interruption) however caused +// and on any theory of liability, whether in contract, strict liability, +// or tort (including negligence or otherwise) arising in any way out of +// the use of this software, even if advised of the possibility of such damage. +// +//M*/ + +#include "test_precomp.hpp" +#include "opencv2/ts/ocl_test.hpp" + +#ifdef HAVE_OPENCL + +namespace cvtest { +namespace ocl { + +//////////////////////////////////////////////////////////////////////////// +// GEMM + +PARAM_TEST_CASE(Gemm, + MatType, + bool, // GEMM_1_T + bool, // GEMM_2_T + bool, // GEMM_3_T + bool // ROI + ) +{ + bool use_roi; + int type, flags; + bool atrans, btrans, ctrans; + + double alpha, beta; + + TEST_DECLARE_INPUT_PARAMETER(A) + TEST_DECLARE_INPUT_PARAMETER(B) + TEST_DECLARE_INPUT_PARAMETER(C) + TEST_DECLARE_OUTPUT_PARAMETER(D) + + virtual void SetUp() + { + atrans = btrans = ctrans = false; + + type = GET_PARAM(0); + use_roi = GET_PARAM(4); + + flags = 0; + if (GET_PARAM(1)) + flags |= GEMM_1_T, atrans = true; + if (GET_PARAM(2)) + flags |= GEMM_2_T, btrans = true; + if (GET_PARAM(3)) + flags |= GEMM_3_T, ctrans = true; + } + + void generateTestData() + { + Size ARoiSize = randomSize(1, MAX_VALUE); + Border ABorder = randomBorder(0, use_roi ? MAX_VALUE : 0); + randomSubMat(A, A_roi, ARoiSize, ABorder, type, -11, 11); + + if (atrans) + ARoiSize = Size(ARoiSize.height, ARoiSize.width); + + Size BRoiSize = randomSize(1, MAX_VALUE); + if (btrans) + BRoiSize.width = ARoiSize.width; + else + BRoiSize.height = ARoiSize.width; + + Border BBorder = randomBorder(0, use_roi ? MAX_VALUE : 0); + randomSubMat(B, B_roi, BRoiSize, BBorder, type, -11, 11); + + if (btrans) + BRoiSize = Size(BRoiSize.height, BRoiSize.width); + + Size DRoiSize = Size(BRoiSize.width, ARoiSize.height), CRoiSizeT(DRoiSize.height, DRoiSize.width); + Border CBorder = randomBorder(0, use_roi ? MAX_VALUE : 0); + randomSubMat(C, C_roi, ctrans ? CRoiSizeT : DRoiSize, CBorder, type, -11, 11); + + Border DBorder = randomBorder(0, use_roi ? MAX_VALUE : 0); + randomSubMat(D, D_roi, DRoiSize, DBorder, type, -11, 11); + + alpha = randomDouble(-4, 4); + beta = randomDouble(-4, 4); + + UMAT_UPLOAD_INPUT_PARAMETER(A) + UMAT_UPLOAD_INPUT_PARAMETER(B) + UMAT_UPLOAD_INPUT_PARAMETER(C) + UMAT_UPLOAD_OUTPUT_PARAMETER(D) + } +}; + +OCL_TEST_P(Gemm, Accuracy) +{ + for (int i = 0; i < test_loop_times; ++i) + { + generateTestData(); + + OCL_OFF(cv::gemm(A_roi, B_roi, alpha, C_roi, beta, D_roi, flags)); + OCL_ON(cv::gemm(uA_roi, uB_roi, alpha, uC_roi, beta, uD_roi, flags)); + + double eps = D_roi.size().area() * 1e-4; + OCL_EXPECT_MATS_NEAR(D, eps); + } +} + +OCL_INSTANTIATE_TEST_CASE_P(Core, Gemm, ::testing::Combine( + testing::Values(CV_32FC1, CV_32FC2, CV_64FC1, CV_64FC2), + Bool(), Bool(), Bool(), Bool())); + +} } // namespace cvtest::ocl + +#endif // HAVE_OPENCL -- 2.7.4