From 7379152afbc21f616d7a53cf34fe92607861a940 Mon Sep 17 00:00:00 2001 From: Ilya Lavrenov Date: Mon, 30 Sep 2013 16:54:09 +0400 Subject: [PATCH] fixed ocl::setIdentity --- modules/ocl/include/opencv2/ocl/ocl.hpp | 3 +- modules/ocl/src/arithm.cpp | 70 ++++++-------------- modules/ocl/src/opencl/arithm_setidentity.cl | 47 +++---------- modules/ocl/test/test_arithm.cpp | 60 +++++++++++------ 4 files changed, 70 insertions(+), 110 deletions(-) diff --git a/modules/ocl/include/opencv2/ocl/ocl.hpp b/modules/ocl/include/opencv2/ocl/ocl.hpp index d3dbded34d..bb23e1323f 100644 --- a/modules/ocl/include/opencv2/ocl/ocl.hpp +++ b/modules/ocl/include/opencv2/ocl/ocl.hpp @@ -584,7 +584,8 @@ namespace cv CV_EXPORTS void cvtColor(const oclMat &src, oclMat &dst, int code , int dcn = 0); - CV_EXPORTS void setIdentity(oclMat& src, double val); + //! initializes a scaled identity matrix + CV_EXPORTS void setIdentity(oclMat& src, const Scalar & val = Scalar(1)); //////////////////////////////// Filter Engine //////////////////////////////// diff --git a/modules/ocl/src/arithm.cpp b/modules/ocl/src/arithm.cpp index 6467040f1c..883ac8f0da 100644 --- a/modules/ocl/src/arithm.cpp +++ b/modules/ocl/src/arithm.cpp @@ -1709,63 +1709,35 @@ void cv::ocl::pow(const oclMat &x, double p, oclMat &y) /////////////////////////////// setIdentity ////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// -void cv::ocl::setIdentity(oclMat& src, double scalar) +void cv::ocl::setIdentity(oclMat& src, const Scalar & scalar) { - CV_Assert(src.empty() == false && src.rows == src.cols); - CV_Assert(src.type() == CV_32SC1 || src.type() == CV_32FC1); - int src_step = src.step/src.elemSize(); Context *clCxt = Context::getContext(); - size_t local_threads[] = {16, 16, 1}; - size_t global_threads[] = {src.cols, src.rows, 1}; - - string kernelName = "setIdentityKernel"; - if (src.type() == CV_32FC1) - kernelName += "_F1"; - else if (src.type() == CV_32SC1) - 
kernelName += "_I1"; - else + if (!clCxt->supportsFeature(Context::CL_DOUBLE) && src.depth() == CV_64F) { - kernelName += "_D1"; - if (!(clCxt->supportsFeature(Context::CL_DOUBLE))) - { - oclMat temp; - src.convertTo(temp, CV_32FC1); - temp.copyTo(src); - } - + CV_Error(CV_GpuNotSupported, "Selected device doesn't support double\r\n"); + return; } + CV_Assert(src.step % src.elemSize() == 0); + + int src_step1 = src.step / src.elemSize(), src_offset1 = src.offset / src.elemSize(); + size_t local_threads[] = { 16, 16, 1 }; + size_t global_threads[] = { src.cols, src.rows, 1 }; + + const char * const typeMap[] = { "uchar", "char", "ushort", "short", "int", "float", "double" }; + const char * const channelMap[] = { "", "", "2", "4", "4" }; + string buildOptions = format("-D T=%s%s", typeMap[src.depth()], channelMap[src.oclchannels()]); + vector<pair<size_t , const void *> > args; args.push_back( make_pair( sizeof(cl_mem), (void *)&src.data )); - args.push_back( make_pair( sizeof(cl_int), (void *)&src.rows)); + args.push_back( make_pair( sizeof(cl_int), (void *)&src_step1 )); + args.push_back( make_pair( sizeof(cl_int), (void *)&src_offset1 )); args.push_back( make_pair( sizeof(cl_int), (void *)&src.cols)); - args.push_back( make_pair( sizeof(cl_int), (void *)&src_step )); + args.push_back( make_pair( sizeof(cl_int), (void *)&src.rows)); - int scalar_i = 0; - float scalar_f = 0.0f; - if (clCxt->supportsFeature(Context::CL_DOUBLE)) - { - if (src.type() == CV_32SC1) - { - scalar_i = (int)scalar; - args.push_back(make_pair(sizeof(cl_int), (void*)&scalar_i)); - } - else - args.push_back(make_pair(sizeof(cl_double), (void*)&scalar)); - } - else - { - if (src.type() == CV_32SC1) - { - scalar_i = (int)scalar; - args.push_back(make_pair(sizeof(cl_int), (void*)&scalar_i)); - } - else - { - scalar_f = (float)scalar; - args.push_back(make_pair(sizeof(cl_float), (void*)&scalar_f)); - } - } + oclMat sc(1, 1, src.type(), scalar); + args.push_back( make_pair( sizeof(cl_mem), (void *)&sc.data )); - 
openCLExecuteKernel(clCxt, &arithm_setidentity, kernelName, global_threads, local_threads, args, -1, -1); + openCLExecuteKernel(clCxt, &arithm_setidentity, "setIdentity", global_threads, local_threads, + args, -1, -1, buildOptions.c_str()); } diff --git a/modules/ocl/src/opencl/arithm_setidentity.cl b/modules/ocl/src/opencl/arithm_setidentity.cl index 0604ae81dd..921026b40d 100644 --- a/modules/ocl/src/opencl/arithm_setidentity.cl +++ b/modules/ocl/src/opencl/arithm_setidentity.cl @@ -42,6 +42,7 @@ // the use of this software, even if advised of the possibility of such damage. // //M*/ + #if defined (DOUBLE_SUPPORT) #ifdef cl_khr_fp64 #pragma OPENCL EXTENSION cl_khr_fp64:enable @@ -50,51 +51,19 @@ #endif #endif - -#if defined (DOUBLE_SUPPORT) -#define DATA_TYPE double -#else -#define DATA_TYPE float -#endif - -__kernel void setIdentityKernel_F1(__global float* src, int src_row, int src_col, int src_step, DATA_TYPE scalar) -{ - int x = get_global_id(0); - int y = get_global_id(1); - - if(x < src_col && y < src_row) - { - if(x == y) - src[y * src_step + x] = scalar; - else - src[y * src_step + x] = 0 * scalar; - } -} - -__kernel void setIdentityKernel_D1(__global DATA_TYPE* src, int src_row, int src_col, int src_step, DATA_TYPE scalar) +__kernel void setIdentity(__global T * src, int src_step, int src_offset, + int cols, int rows, __global const T * scalar) { int x = get_global_id(0); int y = get_global_id(1); - if(x < src_col && y < src_row) + if (x < cols && y < rows) { - if(x == y) - src[y * src_step + x] = scalar; - else - src[y * src_step + x] = 0 * scalar; - } -} + int src_index = mad24(y, src_step, src_offset + x); -__kernel void setIdentityKernel_I1(__global int* src, int src_row, int src_col, int src_step, int scalar) -{ - int x = get_global_id(0); - int y = get_global_id(1); - - if(x < src_col && y < src_row) - { - if(x == y) - src[y * src_step + x] = scalar; + if (x == y) + src[src_index] = *scalar; else - src[y * src_step + x] = 0 * scalar; + 
src[src_index] = 0; } } diff --git a/modules/ocl/test/test_arithm.cpp b/modules/ocl/test/test_arithm.cpp index 2438148033..ee45cf5e3f 100644 --- a/modules/ocl/test/test_arithm.cpp +++ b/modules/ocl/test/test_arithm.cpp @@ -1423,34 +1423,52 @@ TEST_P(AddWeighted, Mat) } } +//////////////////////////////// setIdentity ///////////////////////////////////////////////// + +typedef ArithmTestBase SetIdentity; + +TEST_P(SetIdentity, Mat) +{ + for (int j = 0; j < LOOP_TIMES; j++) + { + random_roi(); + + cv::setIdentity(dst1_roi, val); + cv::ocl::setIdentity(gdst1, val); + + Near(0); + } +} + //////////////////////////////////////// Instantiation ///////////////////////////////////////// -INSTANTIATE_TEST_CASE_P(Arithm, Lut, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool(), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Exp, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Log, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Add, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Sub, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + +INSTANTIATE_TEST_CASE_P(Arithm, Lut, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool(), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Exp, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Log, Combine(testing::Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Add, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Sub, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, Mul, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); 
INSTANTIATE_TEST_CASE_P(Arithm, Div, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); -INSTANTIATE_TEST_CASE_P(Arithm, Absdiff, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, CartToPolar, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, PolarToCart, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Magnitude, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Transpose, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Flip, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + +INSTANTIATE_TEST_CASE_P(Arithm, Absdiff, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, CartToPolar, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, PolarToCart, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Magnitude, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Transpose, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Flip, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, MinMax, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); -INSTANTIATE_TEST_CASE_P(Arithm, MinMaxLoc, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); // + +INSTANTIATE_TEST_CASE_P(Arithm, MinMaxLoc, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, Sum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, SqrSum, 
Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); INSTANTIATE_TEST_CASE_P(Arithm, AbsSum, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); -INSTANTIATE_TEST_CASE_P(Arithm, CountNonZero, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Phase, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_and, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_or, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_xor, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_not, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Compare, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, Pow, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); // + -INSTANTIATE_TEST_CASE_P(Arithm, AddWeighted, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); // + +INSTANTIATE_TEST_CASE_P(Arithm, CountNonZero, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Phase, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_and, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_or, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_xor, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Bitwise_not, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); 
+INSTANTIATE_TEST_CASE_P(Arithm, Compare, Combine(testing::Range(CV_8U, CV_USRTYPE1), Values(1), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, Pow, Combine(Values(CV_32F, CV_64F), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, AddWeighted, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); +INSTANTIATE_TEST_CASE_P(Arithm, SetIdentity, Combine(testing::Range(CV_8U, CV_USRTYPE1), testing::Range(1, 5), Bool())); #endif // HAVE_OPENCL -- 2.34.1