added cv::gemm to T-API
authorIlya Lavrenov <ilya.lavrenov@itseez.com>
Sat, 14 Dec 2013 19:16:53 +0000 (23:16 +0400)
committerIlya Lavrenov <ilya.lavrenov@itseez.com>
Mon, 16 Dec 2013 15:12:27 +0000 (19:12 +0400)
modules/core/include/opencv2/core/ocl.hpp
modules/core/src/matmul.cpp
modules/core/src/ocl.cpp
modules/core/test/ocl/test_gemm.cpp [new file with mode: 0644]

index f253594..2a90b85 100644 (file)
@@ -48,6 +48,7 @@ namespace cv { namespace ocl {
 
 CV_EXPORTS bool haveOpenCL();
 CV_EXPORTS bool useOpenCL();
+CV_EXPORTS bool haveAmdBlas();
 CV_EXPORTS void setUseOpenCL(bool flag);
 CV_EXPORTS void finish2();
 
index 7d832cb..16eb6e0 100644 (file)
@@ -41,6 +41,7 @@
 //M*/
 
 #include "precomp.hpp"
+#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp"
 
 #ifdef HAVE_IPP
 #include "ippversion.h"
@@ -693,11 +694,102 @@ static void GEMMStore_64fc( const Complexd* c_data, size_t c_step,
     GEMMStore(c_data, c_step, d_buf, d_buf_step, d_data, d_step, d_size, alpha, beta, flags);
 }
 
+#ifdef HAVE_CLAMDBLAS
+
+static bool ocl_gemm( InputArray matA, InputArray matB, double alpha,
+                      InputArray matC, double beta, OutputArray matD, int flags )
+{
+    int type = matA.type(), esz = CV_ELEM_SIZE(type);
+    bool haveC = matC.kind() != cv::_InputArray::NONE;
+    Size sizeA = matA.size(), sizeB = matB.size(), sizeC = haveC ? matC.size() : Size(0, 0);
+    bool atrans = (flags & GEMM_1_T) != 0, btrans = (flags & GEMM_2_T) != 0, ctrans = (flags & GEMM_3_T) != 0;
+
+    if (atrans)
+        sizeA = Size(sizeA.height, sizeA.width);
+    if (btrans)
+        sizeB = Size(sizeB.height, sizeB.width);
+    if (haveC && ctrans)
+        sizeC = Size(sizeC.height, sizeC.width);
+
+    Size sizeD(sizeB.width, sizeA.height);
+
+    CV_Assert( matB.type() == type && (!haveC || matC.type() == type) );
+    CV_Assert( sizeA.width == sizeB.height && (!haveC || sizeC == sizeD) );
+
+    matD.create(sizeD, type);
+    if ( matA.offset() % esz != 0 || matA.step() % esz != 0 ||
+         matB.offset() % esz != 0 || matB.step() % esz != 0 ||
+         (haveC && (matC.offset() % esz != 0 || matC.step() % esz != 0)) )
+        return false;
+
+    UMat A = matA.getUMat(), B = matB.getUMat(), D = matD.getUMat();
+    if (haveC)
+        ctrans ? transpose(matC, D) : matC.getMat().copyTo(D); // TODO fix it as soon as .copyTo works as expected
+    else
+        D.setTo(Scalar::all(0));
+
+    int M = sizeD.height, N = sizeD.width, K = sizeA.width;
+    int lda = (int)A.step / esz, ldb = (int)B.step / esz, ldc = (int)D.step / esz;
+    int offa = (int)A.offset / esz, offb = (int)B.offset / esz, offc = (int)D.offset / esz;
+
+    cl_command_queue clq = (cl_command_queue)ocl::Queue::getDefault().ptr();
+    clAmdBlasTranspose transA = atrans ? clAmdBlasTrans : clAmdBlasNoTrans;
+    clAmdBlasTranspose transB = btrans ? clAmdBlasTrans : clAmdBlasNoTrans;
+    clAmdBlasOrder order = clAmdBlasRowMajor;
+    clAmdBlasStatus status = clAmdBlasSuccess;
+
+    if (type == CV_32FC1)
+        status = clAmdBlasSgemmEx(order, transA, transB, M, N, K,
+                                  (cl_float)alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
+                                  (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
+                                  (cl_float)beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
+                                  1, &clq, 0, NULL, NULL);
+    else if (type == CV_64FC1)
+        status = clAmdBlasDgemmEx(order, transA, transB, M, N, K,
+                                  alpha, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
+                                  (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
+                                  beta, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
+                                  1, &clq, 0, NULL, NULL);
+    else if (type == CV_32FC2)
+    {
+         cl_float2 alpha_2 = { { (cl_float)alpha, 0 } };
+         cl_float2 beta_2  = { { (cl_float)beta, 0 } };
+         status = clAmdBlasCgemmEx(order, transA, transB, M, N, K,
+                                   alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
+                                   (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
+                                   beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
+                                   1, &clq, 0, NULL, NULL);
+    }
+    else if (type == CV_64FC2)
+    {
+        cl_double2 alpha_2 = { { alpha, 0 } };
+        cl_double2 beta_2  = { { beta, 0 } };
+        status = clAmdBlasZgemmEx(order, transA, transB, M, N, K,
+                                  alpha_2, (const cl_mem)A.handle(ACCESS_READ), offa, lda,
+                                  (const cl_mem)B.handle(ACCESS_READ), offb, ldb,
+                                  beta_2, (cl_mem)D.handle(ACCESS_RW), offc, ldc,
+                                  1, &clq, 0, NULL, NULL);
+    }
+    else
+        CV_Error(Error::StsUnsupportedFormat, "");
+
+    return status == clAmdBlasSuccess;
+}
+
+#endif
+
 }
 
 void cv::gemm( InputArray matA, InputArray matB, double alpha,
            InputArray matC, double beta, OutputArray _matD, int flags )
 {
+#ifdef HAVE_CLAMDBLAS
+    if (ocl::haveAmdBlas() && matA.dims() <= 2 && matB.dims() <= 2 && matC.dims() <= 2 &&
+            ocl::useOpenCL() && _matD.isUMat() &&
+            ocl_gemm(matA, matB, alpha, matC, beta, _matD, flags))
+        return;
+#endif
+
     const int block_lin_size = 128;
     const int block_size = block_lin_size * block_lin_size;
 
index 6681e81..835cd2a 100644 (file)
@@ -42,6 +42,8 @@
 #include "precomp.hpp"
 #include <map>
 
+#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp"
+
 #ifdef HAVE_OPENCL
 #include "opencv2/core/opencl/runtime/opencl_core.hpp"
 #else
@@ -1309,29 +1311,23 @@ inline bool operator < (const HashKey& h1, const HashKey& h2)
     return h1.a < h2.a || (h1.a == h2.a && h1.b < h2.b);
 }
 
-static bool g_isInitialized = false;
+static bool g_isOpenCLInitialized = false;
 static bool g_isOpenCLAvailable = false;
+
 bool haveOpenCL()
 {
-    if (!g_isInitialized)
+    if (!g_isOpenCLInitialized)
     {
-        if (!g_isInitialized)
+        try
         {
-            try
-            {
-                cl_uint n = 0;
-                cl_int err = ::clGetPlatformIDs(0, NULL, &n);
-                if (err != CL_SUCCESS)
-                    g_isOpenCLAvailable = false;
-                else
-                    g_isOpenCLAvailable = true;
-            }
-            catch (...)
-            {
-                g_isOpenCLAvailable = false;
-            }
-            g_isInitialized = true;
+            cl_uint n = 0;
+            g_isOpenCLAvailable = ::clGetPlatformIDs(0, NULL, &n) == CL_SUCCESS;
+        }
+        catch (...)
+        {
+            g_isOpenCLAvailable = false;
         }
+        g_isOpenCLInitialized = true;
     }
     return g_isOpenCLAvailable;
 }
@@ -1353,6 +1349,80 @@ void setUseOpenCL(bool flag)
     }
 }
 
+#ifdef HAVE_CLAMDBLAS
+
+class AmdBlasHelper
+{
+public:
+    static AmdBlasHelper & getInstance()
+    {
+        static AmdBlasHelper amdBlas;
+        return amdBlas;
+    }
+
+    bool isAvailable() const
+    {
+        return g_isAmdBlasAvailable;
+    }
+
+    ~AmdBlasHelper()
+    {
+        try
+        {
+            clAmdBlasTeardown();
+        }
+        catch (...) { }
+    }
+
+protected:
+    AmdBlasHelper()
+    {
+        if (!g_isAmdBlasInitialized)
+        {
+            AutoLock lock(m);
+
+            if (!g_isAmdBlasInitialized && haveOpenCL())
+            {
+                try
+                {
+                    g_isAmdBlasAvailable = clAmdBlasSetup() == clAmdBlasSuccess;
+                }
+                catch (...)
+                {
+                    g_isAmdBlasAvailable = false;
+                }
+            }
+            else
+                g_isAmdBlasAvailable = false;
+
+            g_isAmdBlasInitialized = true;
+        }
+    }
+
+private:
+    static Mutex m;
+    static bool g_isAmdBlasInitialized;
+    static bool g_isAmdBlasAvailable;
+};
+
+bool AmdBlasHelper::g_isAmdBlasAvailable = false;
+bool AmdBlasHelper::g_isAmdBlasInitialized = false;
+Mutex AmdBlasHelper::m;
+
+bool haveAmdBlas()
+{
+    return AmdBlasHelper::getInstance().isAvailable();
+}
+
+#else
+
+bool haveAmdBlas()
+{
+    return false;
+}
+
+#endif
+
 void finish2()
 {
     Queue::getDefault().finish();
diff --git a/modules/core/test/ocl/test_gemm.cpp b/modules/core/test/ocl/test_gemm.cpp
new file mode 100644 (file)
index 0000000..4d453f3
--- /dev/null
@@ -0,0 +1,149 @@
+/*M///////////////////////////////////////////////////////////////////////////////////////
+//
+//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
+//
+//  By downloading, copying, installing or using the software you agree to this license.
+//  If you do not agree to this license, do not download, install,
+//  copy or use the software.
+//
+//
+//                           License Agreement
+//                For Open Source Computer Vision Library
+//
+// Copyright (C) 2010-2012, Multicoreware, Inc., all rights reserved.
+// Copyright (C) 2010-2012, Advanced Micro Devices, Inc., all rights reserved.
+// Third party copyrights are property of their respective owners.
+//
+// @Authors
+//    Peng Xiao, pengxiao@multicorewareinc.com
+// Redistribution and use in source and binary forms, with or without modification,
+// are permitted provided that the following conditions are met:
+//
+//   * Redistribution's of source code must retain the above copyright notice,
+//     this list of conditions and the following disclaimer.
+//
+//   * Redistribution's in binary form must reproduce the above copyright notice,
+//     this list of conditions and the following disclaimer in the documentation
+//     and/or other materials provided with the distribution.
+//
+//   * The name of the copyright holders may not be used to endorse or promote products
+//     derived from this software without specific prior written permission.
+//
+// This software is provided by the copyright holders and contributors as is and
+// any express or implied warranties, including, but not limited to, the implied
+// warranties of merchantability and fitness for a particular purpose are disclaimed.
+// In no event shall the Intel Corporation or contributors be liable for any direct,
+// indirect, incidental, special, exemplary, or consequential damages
+// (including, but not limited to, procurement of substitute goods or services;
+// loss of use, data, or profits; or business interruption) however caused
+// and on any theory of liability, whether in contract, strict liability,
+// or tort (including negligence or otherwise) arising in any way out of
+// the use of this software, even if advised of the possibility of such damage.
+//
+//M*/
+
+#include "test_precomp.hpp"
+#include "opencv2/ts/ocl_test.hpp"
+
+#ifdef HAVE_OPENCL
+
+namespace cvtest {
+namespace ocl {
+
+////////////////////////////////////////////////////////////////////////////
+// GEMM
+
+PARAM_TEST_CASE(Gemm,
+                MatType,
+                bool, // GEMM_1_T
+                bool, // GEMM_2_T
+                bool, // GEMM_3_T
+                bool // ROI
+                )
+{
+    bool use_roi;
+    int type, flags;
+    bool atrans, btrans, ctrans;
+
+    double alpha, beta;
+
+    TEST_DECLARE_INPUT_PARAMETER(A)
+    TEST_DECLARE_INPUT_PARAMETER(B)
+    TEST_DECLARE_INPUT_PARAMETER(C)
+    TEST_DECLARE_OUTPUT_PARAMETER(D)
+
+    virtual void SetUp()
+    {
+        atrans = btrans = ctrans = false;
+
+        type = GET_PARAM(0);
+        use_roi = GET_PARAM(4);
+
+        flags = 0;
+        if (GET_PARAM(1))
+            flags |= GEMM_1_T, atrans = true;
+        if (GET_PARAM(2))
+            flags |= GEMM_2_T, btrans = true;
+        if (GET_PARAM(3))
+            flags |= GEMM_3_T, ctrans = true;
+    }
+
+    void generateTestData()
+    {
+        Size ARoiSize = randomSize(1, MAX_VALUE);
+        Border ABorder = randomBorder(0, use_roi ? MAX_VALUE : 0);
+        randomSubMat(A, A_roi, ARoiSize, ABorder, type, -11, 11);
+
+        if (atrans)
+            ARoiSize = Size(ARoiSize.height, ARoiSize.width);
+
+        Size BRoiSize = randomSize(1, MAX_VALUE);
+        if (btrans)
+            BRoiSize.width = ARoiSize.width;
+        else
+            BRoiSize.height = ARoiSize.width;
+
+        Border BBorder = randomBorder(0, use_roi ? MAX_VALUE : 0);
+        randomSubMat(B, B_roi, BRoiSize, BBorder, type, -11, 11);
+
+        if (btrans)
+            BRoiSize = Size(BRoiSize.height, BRoiSize.width);
+
+        Size DRoiSize = Size(BRoiSize.width, ARoiSize.height), CRoiSizeT(DRoiSize.height, DRoiSize.width);
+        Border CBorder = randomBorder(0, use_roi ? MAX_VALUE : 0);
+        randomSubMat(C, C_roi, ctrans ? CRoiSizeT : DRoiSize, CBorder, type, -11, 11);
+
+        Border DBorder = randomBorder(0, use_roi ? MAX_VALUE : 0);
+        randomSubMat(D, D_roi, DRoiSize, DBorder, type, -11, 11);
+
+        alpha = randomDouble(-4, 4);
+        beta = randomDouble(-4, 4);
+
+        UMAT_UPLOAD_INPUT_PARAMETER(A)
+        UMAT_UPLOAD_INPUT_PARAMETER(B)
+        UMAT_UPLOAD_INPUT_PARAMETER(C)
+        UMAT_UPLOAD_OUTPUT_PARAMETER(D)
+    }
+};
+
+OCL_TEST_P(Gemm, Accuracy)
+{
+    for (int i = 0; i < test_loop_times; ++i)
+    {
+        generateTestData();
+
+        OCL_OFF(cv::gemm(A_roi, B_roi, alpha, C_roi, beta, D_roi, flags));
+        OCL_ON(cv::gemm(uA_roi, uB_roi, alpha, uC_roi, beta, uD_roi, flags));
+
+        double eps = D_roi.size().area() * 1e-4;
+        OCL_EXPECT_MATS_NEAR(D, eps);
+    }
+}
+
+OCL_INSTANTIATE_TEST_CASE_P(Core, Gemm, ::testing::Combine(
+                            testing::Values(CV_32FC1, CV_32FC2, CV_64FC1, CV_64FC2),
+                            Bool(), Bool(), Bool(), Bool()));
+
+} } // namespace cvtest::ocl
+
+#endif // HAVE_OPENCL