added tests for gpu::sum, it supports all data types, but single channel images only
authorAlexey Spizhevoy <no@email>
Mon, 13 Dec 2010 12:00:58 +0000 (12:00 +0000)
committerAlexey Spizhevoy <no@email>
Mon, 13 Dec 2010 12:00:58 +0000 (12:00 +0000)
modules/gpu/include/opencv2/gpu/gpu.hpp
modules/gpu/src/arithm.cpp
modules/gpu/src/cuda/mathfunc.cu
tests/gpu/src/arithm.cpp
tests/gpu/src/gputest_main.cpp

index dafa5e1..f0d4dd3 100644 (file)
@@ -421,9 +421,12 @@ namespace cv
         CV_EXPORTS void flip(const GpuMat& a, GpuMat& b, int flipCode);\r
 \r
         //! computes sum of array elements\r
-        //! supports CV_8UC1, CV_8UC4 types\r
-        //! disabled until fix crash\r
-        CV_EXPORTS Scalar sum(const GpuMat& m);\r
+        //! supports only single channel images\r
+        CV_EXPORTS Scalar sum(const GpuMat& src);\r
+\r
+        //! computes sum of array elements\r
+        //! supports only single channel images\r
+        CV_EXPORTS Scalar sum(const GpuMat& src, GpuMat& buf);\r
 \r
         //! finds global minimum and maximum array elements and returns their values\r
         CV_EXPORTS void minMax(const GpuMat& src, double* minVal, double* maxVal=0, const GpuMat& mask=GpuMat());\r
index 3dcae2c..049bfa4 100644 (file)
@@ -65,6 +65,7 @@ double cv::gpu::norm(const GpuMat&, int) { throw_nogpu(); return 0.0; }
 double cv::gpu::norm(const GpuMat&, const GpuMat&, int) { throw_nogpu(); return 0.0; }\r
 void cv::gpu::flip(const GpuMat&, GpuMat&, int) { throw_nogpu(); }\r
 Scalar cv::gpu::sum(const GpuMat&) { throw_nogpu(); return Scalar(); }\r
+Scalar cv::gpu::sum(const GpuMat&, GpuMat&) { throw_nogpu(); return Scalar(); }\r
 void cv::gpu::minMax(const GpuMat&, double*, double*, const GpuMat&) { throw_nogpu(); }\r
 void cv::gpu::minMax(const GpuMat&, double*, double*, const GpuMat&, GpuMat&) { throw_nogpu(); }\r
 void cv::gpu::minMaxLoc(const GpuMat&, double*, double*, Point*, Point*, const GpuMat&) { throw_nogpu(); }\r
@@ -480,36 +481,50 @@ void cv::gpu::flip(const GpuMat& src, GpuMat& dst, int flipCode)
 ////////////////////////////////////////////////////////////////////////\r
 // sum\r
 \r
-Scalar cv::gpu::sum(const GpuMat& src)\r
+namespace cv { namespace gpu { namespace mathfunc\r
 {\r
-    CV_Assert(!"disabled until fix crash");\r
+    template <typename T>\r
+    void sum_caller(const DevMem2D src, PtrStep buf, double* sum);\r
 \r
-    CV_Assert(src.type() == CV_8UC1 || src.type() == CV_8UC4);\r
+    template <typename T>\r
+    void sum_multipass_caller(const DevMem2D src, PtrStep buf, double* sum);\r
 \r
-    NppiSize sz;\r
-    sz.width  = src.cols;\r
-    sz.height = src.rows;\r
+    namespace sum\r
+    {\r
+        void get_buf_size_required(int cols, int rows, int& bufcols, int& bufrows);\r
+    }\r
+}}}\r
 \r
-    Scalar res;\r
+Scalar cv::gpu::sum(const GpuMat& src) \r
+{\r
+    GpuMat buf;\r
+    return sum(src, buf);\r
+}\r
 \r
-    int bufsz;\r
+Scalar cv::gpu::sum(const GpuMat& src, GpuMat& buf) \r
+{\r
+    using namespace mathfunc;\r
+    CV_Assert(src.channels() == 1);\r
 \r
-    if (src.type() == CV_8UC1)\r
-    {\r
-        nppiReductionGetBufferHostSize_8u_C1R(sz, &bufsz);\r
-        GpuMat buf(1, bufsz, CV_32S);\r
+    typedef void (*Caller)(const DevMem2D, PtrStep, double*);\r
+    static const Caller callers[2][7] = \r
+        { { sum_multipass_caller<unsigned char>, sum_multipass_caller<char>, \r
+            sum_multipass_caller<unsigned short>, sum_multipass_caller<short>, \r
+            sum_multipass_caller<int>, sum_multipass_caller<float>, 0 },\r
+          { sum_caller<unsigned char>, sum_caller<char>, \r
+            sum_caller<unsigned short>, sum_caller<short>, \r
+            sum_caller<int>, sum_caller<float>, sum_caller<double> } };\r
 \r
-        nppSafeCall( nppiSum_8u_C1R(src.ptr<Npp8u>(), src.step, sz, buf.ptr<Npp32s>(), res.val) );\r
-    }\r
-    else\r
-    {\r
-        nppiReductionGetBufferHostSize_8u_C4R(sz, &bufsz);\r
-        GpuMat buf(1, bufsz, CV_32S);\r
+    Size bufSize;\r
+    sum::get_buf_size_required(src.cols, src.rows, bufSize.width, bufSize.height); \r
+    buf.create(bufSize, CV_8U);\r
 \r
-        nppSafeCall( nppiSum_8u_C4R(src.ptr<Npp8u>(), src.step, sz, buf.ptr<Npp32s>(), res.val) );\r
-    }\r
+    Caller caller = callers[hasAtomicsSupport(getDevice())][src.type()];\r
+    if (!caller) CV_Error(CV_StsBadArg, "sum: unsupported type");\r
 \r
-    return res;\r
+    double result;\r
+    caller(src, buf, &result);\r
+    return result;\r
 }\r
 \r
 ////////////////////////////////////////////////////////////////////////\r
index b06bef0..3c620c0 100644 (file)
@@ -1419,6 +1419,15 @@ namespace cv { namespace gpu { namespace mathfunc
     namespace sum \r
     {\r
 \r
+    template <typename T> struct SumType {};\r
+    template <> struct SumType<unsigned char> { typedef unsigned int R; };\r
+    template <> struct SumType<char> { typedef int R; };\r
+    template <> struct SumType<unsigned short> { typedef unsigned int R; };\r
+    template <> struct SumType<short> { typedef int R; };\r
+    template <> struct SumType<int> { typedef int R; };\r
+    template <> struct SumType<float> { typedef float R; };\r
+    template <> struct SumType<double> { typedef double R; };\r
+\r
     __constant__ int ctwidth;\r
     __constant__ int ctheight;\r
     __device__ unsigned int blocks_finished = 0;\r
@@ -1436,12 +1445,11 @@ namespace cv { namespace gpu { namespace mathfunc
     }\r
 \r
 \r
-    template <typename T>\r
     void get_buf_size_required(int cols, int rows, int& bufcols, int& bufrows)\r
     {\r
         dim3 threads, grid;\r
         estimate_thread_cfg(cols, rows, threads, grid);\r
-        bufcols = grid.x * grid.y * sizeof(T);\r
+        bufcols = grid.x * grid.y * sizeof(double);\r
         bufrows = 1;\r
     }\r
 \r
@@ -1454,17 +1462,17 @@ namespace cv { namespace gpu { namespace mathfunc
         cudaSafeCall(cudaMemcpyToSymbol(ctheight, &theight, sizeof(theight))); \r
     }\r
 \r
-    template <typename T, int nthreads>\r
-    __global__ void sum_kernel(const DevMem2D_<T> src, T* result)\r
+    template <typename T, typename R, int nthreads>\r
+    __global__ void sum_kernel(const DevMem2D_<T> src, R* result)\r
     {\r
-        __shared__ T smem[nthreads];\r
+        __shared__ R smem[nthreads];\r
 \r
         const int x0 = blockIdx.x * blockDim.x * ctwidth + threadIdx.x;\r
         const int y0 = blockIdx.y * blockDim.y * ctheight + threadIdx.y;\r
         const int tid = threadIdx.y * blockDim.x + threadIdx.x;\r
         const int bid = blockIdx.y * gridDim.x + blockIdx.x;\r
 \r
-        T sum = 0;\r
+        R sum = 0;\r
         for (int y = 0; y < ctheight && y0 + y * blockDim.y < src.rows; ++y)\r
         {\r
             const T* ptr = src.ptr(y0 + y * blockDim.y);\r
@@ -1475,7 +1483,7 @@ namespace cv { namespace gpu { namespace mathfunc
         smem[tid] = sum;\r
         __syncthreads();\r
 \r
-        sum_in_smem<nthreads, T>(smem, tid);\r
+        sum_in_smem<nthreads, R>(smem, tid);\r
 \r
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 110\r
         __shared__ bool is_last;\r
@@ -1496,7 +1504,7 @@ namespace cv { namespace gpu { namespace mathfunc
             smem[tid] = tid < gridDim.x * gridDim.y ? result[tid] : 0;\r
             __syncthreads();\r
 \r
-            sum_in_smem<nthreads, T>(smem, tid);\r
+            sum_in_smem<nthreads, R>(smem, tid);\r
 \r
             if (tid == 0) \r
             {\r
@@ -1510,14 +1518,16 @@ namespace cv { namespace gpu { namespace mathfunc
     }\r
 \r
 \r
-    template <typename T, int nthreads>\r
-    __global__ void sum_pass2_kernel(T* result, int size)\r
+    template <typename T, typename R, int nthreads>\r
+    __global__ void sum_pass2_kernel(R* result, int size)\r
     {\r
-        __shared__ T smem[nthreads];\r
+        __shared__ R smem[nthreads];\r
         int tid = threadIdx.y * blockDim.x + threadIdx.x;\r
 \r
         smem[tid] = tid < size ? result[tid] : 0;\r
-        sum_in_smem<nthreads, T>(smem, tid);\r
+        __syncthreads();\r
+\r
+        sum_in_smem<nthreads, R>(smem, tid);\r
 \r
         if (tid == 0) \r
             result[0] = smem[0];\r
@@ -1527,60 +1537,61 @@ namespace cv { namespace gpu { namespace mathfunc
 \r
 \r
     template <typename T>\r
-    T sum_multipass_caller(const DevMem2D_<T> src, PtrStep buf)\r
+    void sum_multipass_caller(const DevMem2D src, PtrStep buf, double* sum)\r
     {\r
         using namespace sum;\r
+        typedef typename SumType<T>::R R;\r
 \r
         dim3 threads, grid;\r
         estimate_thread_cfg(src.cols, src.rows, threads, grid);\r
         set_kernel_consts(src.cols, src.rows, threads, grid);\r
 \r
-        T* buf_ = (T*)buf.ptr(0);\r
+        R* buf_ = (R*)buf.ptr(0);\r
 \r
-        sum_kernel<T, threads_x * threads_y><<<grid, threads>>>(src, buf_);\r
-        sum_pass2_kernel<T, threads_x * threads_y><<<1, threads_x * threads_y>>>(\r
+        sum_kernel<T, R, threads_x * threads_y><<<grid, threads>>>((const DevMem2D_<T>)src, buf_);\r
+        sum_pass2_kernel<T, R, threads_x * threads_y><<<1, threads_x * threads_y>>>(\r
                 buf_, grid.x * grid.y);\r
         cudaSafeCall(cudaThreadSynchronize());\r
 \r
-        T sum;\r
-        cudaSafeCall(cudaMemcpy(&sum, buf_, sizeof(T), cudaMemcpyDeviceToHost));\r
-        \r
-        return sum;\r
+        R result = 0;\r
+        cudaSafeCall(cudaMemcpy(&result, buf_, result, cudaMemcpyDeviceToHost));\r
+        sum[0] = result;\r
     }  \r
 \r
-    template unsigned char sum_multipass_caller<unsigned char>(const DevMem2D_<unsigned char>, PtrStep);\r
-    template char sum_multipass_caller<char>(const DevMem2D_<char>, PtrStep);\r
-    template unsigned short sum_multipass_caller<unsigned short>(const DevMem2D_<unsigned short>, PtrStep);\r
-    template short sum_multipass_caller<short>(const DevMem2D_<short>, PtrStep);\r
-    template int sum_multipass_caller<int>(const DevMem2D_<int>, PtrStep);\r
-    template float sum_multipass_caller<float>(const DevMem2D_<float>, PtrStep);\r
+    template void sum_multipass_caller<unsigned char>(const DevMem2D, PtrStep, double*);\r
+    template void sum_multipass_caller<char>(const DevMem2D, PtrStep, double*);\r
+    template void sum_multipass_caller<unsigned short>(const DevMem2D, PtrStep, double*);\r
+    template void sum_multipass_caller<short>(const DevMem2D, PtrStep, double*);\r
+    template void sum_multipass_caller<int>(const DevMem2D, PtrStep, double*);\r
+    template void sum_multipass_caller<float>(const DevMem2D, PtrStep, double*);\r
 \r
 \r
     template <typename T>\r
-    T sum_caller(const DevMem2D_<T> src, PtrStep buf)\r
+    void sum_caller(const DevMem2D src, PtrStep buf, double* sum)\r
     {\r
         using namespace sum;\r
+        typedef typename SumType<T>::R R;\r
 \r
         dim3 threads, grid;\r
         estimate_thread_cfg(src.cols, src.rows, threads, grid);\r
         set_kernel_consts(src.cols, src.rows, threads, grid);\r
 \r
-        T* buf_ = (T*)buf.ptr(0);\r
+        R* buf_ = (R*)buf.ptr(0);\r
 \r
-        sum_kernel<T, threads_x * threads_y><<<grid, threads>>>(src, buf_);\r
+        sum_kernel<T, R, threads_x * threads_y><<<grid, threads>>>((const DevMem2D_<T>)src, buf_);\r
         cudaSafeCall(cudaThreadSynchronize());\r
 \r
-        T sum;\r
-        cudaSafeCall(cudaMemcpy(&sum, buf_, sizeof(T), cudaMemcpyDeviceToHost));\r
-\r
-        return sum;\r
+        R result = 0;\r
+        cudaSafeCall(cudaMemcpy(&result, buf_, sizeof(result), cudaMemcpyDeviceToHost));\r
+        sum[0] = result;\r
     }  \r
 \r
-    template unsigned char sum_caller<unsigned char>(const DevMem2D_<unsigned char>, PtrStep);\r
-    template char sum_caller<char>(const DevMem2D_<char>, PtrStep);\r
-    template unsigned short sum_caller<unsigned short>(const DevMem2D_<unsigned short>, PtrStep);\r
-    template short sum_caller<short>(const DevMem2D_<short>, PtrStep);\r
-    template int sum_caller<int>(const DevMem2D_<int>, PtrStep);\r
-    template float sum_caller<float>(const DevMem2D_<float>, PtrStep);\r
-    template double sum_caller<double>(const DevMem2D_<double>, PtrStep);\r
+    template void sum_caller<unsigned char>(const DevMem2D, PtrStep, double*);\r
+    template void sum_caller<char>(const DevMem2D, PtrStep, double*);\r
+    template void sum_caller<unsigned short>(const DevMem2D, PtrStep, double*);\r
+    template void sum_caller<short>(const DevMem2D, PtrStep, double*);\r
+    template void sum_caller<int>(const DevMem2D, PtrStep, double*);\r
+    template void sum_caller<float>(const DevMem2D, PtrStep, double*);\r
+    template void sum_caller<double>(const DevMem2D, PtrStep, double*);\r
 }}}\r
+\r
index 5b7d5d6..521120c 100644 (file)
@@ -459,29 +459,6 @@ struct CV_GpuNppImageFlipTest : public CV_GpuArithmTest
 };\r
 \r
 ////////////////////////////////////////////////////////////////////////////////\r
-// sum\r
-struct CV_GpuNppImageSumTest : public CV_GpuArithmTest\r
-{\r
-    CV_GpuNppImageSumTest() : CV_GpuArithmTest( "GPU-NppImageSum", "sum" ) {}\r
-\r
-    int test( const Mat& mat1, const Mat& )\r
-    {\r
-        if (mat1.type() != CV_8UC1 && mat1.type() != CV_8UC4)\r
-        {\r
-            ts->printf(CvTS::LOG, "\tUnsupported type\t");\r
-            return CvTS::OK;\r
-        }\r
-\r
-        Scalar cpures = cv::sum(mat1);\r
-\r
-        GpuMat gpu1(mat1);\r
-        Scalar gpures = cv::gpu::sum(gpu1);\r
-\r
-        return CheckNorm(cpures, gpures);\r
-    }\r
-};\r
-\r
-////////////////////////////////////////////////////////////////////////////////\r
 // LUT\r
 struct CV_GpuNppImageLUTTest : public CV_GpuArithmTest\r
 {\r
@@ -949,27 +926,49 @@ struct CV_GpuCountNonZeroTest: CvTest
     }\r
 };\r
 \r
-////////////////////////////////////////////////////////////////////////////////\r
-// min/max\r
 \r
-struct CV_GpuImageMinMaxTest : public CV_GpuArithmTest\r
+//////////////////////////////////////////////////////////////////////////////\r
+// sum\r
+\r
+struct CV_GpuSumTest: CvTest \r
 {\r
-    CV_GpuImageMinMaxTest() : CV_GpuArithmTest( "GPU-ImageMinMax", "min/max" ) {}\r
+    CV_GpuSumTest(): CvTest("GPU-SumTest", "sum") {}\r
 \r
-    int test( const Mat& mat1, const Mat& mat2 )\r
+    void run(int) \r
     {\r
-        cv::Mat cpuMinRes, cpuMaxRes;\r
-        cv::min(mat1, mat2, cpuMinRes);\r
-        cv::max(mat1, mat2, cpuMaxRes);\r
+        try\r
+        {\r
+            Mat src;\r
+            Scalar a, b;\r
+            double max_err = 1e-6;\r
 \r
-        GpuMat gpu1(mat1);\r
-        GpuMat gpu2(mat2);\r
-        GpuMat gpuMinRes, gpuMaxRes;\r
-        cv::gpu::min(gpu1, gpu2, gpuMinRes);\r
-        cv::gpu::max(gpu1, gpu2, gpuMaxRes);\r
+            int typemax = hasNativeDoubleSupport(getDevice()) ? CV_64F : CV_32F;\r
+            for (int type = CV_8U; type <= typemax; ++type) \r
+            {\r
+                gen(1 + rand() % 1000, 1 + rand() % 1000, type, src);\r
+                a = sum(src);\r
+                b = sum(GpuMat(src));\r
+                if (abs(a[0] - b[0]) > src.size().area() * max_err)\r
+                {\r
+                    ts->printf(CvTS::CONSOLE, "cols: %d, rows: %d, expected: %f, actual: %f\n", src.cols, src.rows, a[0], b[0]);\r
+                    ts->set_failed_test_info(CvTS::FAIL_INVALID_OUTPUT);\r
+                    return;\r
+                }\r
+            }\r
+        }\r
+        catch (const Exception& e)\r
+        {\r
+            if (!check_and_treat_gpu_exception(e, ts)) throw;\r
+            return;\r
+        }\r
+    }\r
+\r
+    void gen(int cols, int rows, int type, Mat& m)\r
+    {\r
+        m.create(rows, cols, type);\r
+        RNG rng;\r
+        rng.fill(m, RNG::UNIFORM, Scalar::all(0), Scalar::all(20));\r
 \r
-        return CheckNorm(cpuMinRes, gpuMinRes) == CvTS::OK && CheckNorm(cpuMaxRes, gpuMaxRes) == CvTS::OK ?\r
-            CvTS::OK : CvTS::FAIL_GENERIC;\r
     }\r
 };\r
 \r
@@ -992,7 +991,6 @@ CV_GpuNppImageCompareTest CV_GpuNppImageCompare_test;
 CV_GpuNppImageMeanStdDevTest CV_GpuNppImageMeanStdDev_test;\r
 CV_GpuNppImageNormTest CV_GpuNppImageNorm_test;\r
 CV_GpuNppImageFlipTest CV_GpuNppImageFlip_test;\r
-CV_GpuNppImageSumTest CV_GpuNppImageSum_test;\r
 CV_GpuNppImageLUTTest CV_GpuNppImageLUT_test;\r
 CV_GpuNppImageExpTest CV_GpuNppImageExp_test;\r
 CV_GpuNppImageLogTest CV_GpuNppImageLog_test;\r
@@ -1003,4 +1001,4 @@ CV_GpuNppImagePolarToCartTest CV_GpuNppImagePolarToCart_test;
 CV_GpuMinMaxTest CV_GpuMinMaxTest_test;\r
 CV_GpuMinMaxLocTest CV_GpuMinMaxLocTest_test;\r
 CV_GpuCountNonZeroTest CV_CountNonZero_test;\r
-CV_GpuImageMinMaxTest CV_GpuImageMinMax_test;\r
+CV_GpuSumTest CV_GpuSum_test;\r
index a634ef0..a388fa7 100644 (file)
@@ -46,9 +46,6 @@ CvTS test_system("gpu");
 const char* blacklist[] =
 {
     "GPU-AsyncGpuMatOperator",     // crash
-
-    "GPU-NppImageSum",              // crash, probably npp bug
-
     "GPU-NppImageCanny",            // NPP_TEXTURE_BIND_ERROR
     0
 };