optimized cv::setIdentity
authorIlya Lavrenov <ilya.lavrenov@itseez.com>
Wed, 11 Jun 2014 12:50:38 +0000 (16:50 +0400)
committerIlya Lavrenov <ilya.lavrenov@itseez.com>
Mon, 16 Jun 2014 09:41:43 +0000 (13:41 +0400)
modules/core/src/matrix.cpp
modules/core/src/opencl/set_identity.cl

index 653efe6..7023b39 100644 (file)
@@ -2758,21 +2758,30 @@ namespace cv {
 
 static bool ocl_setIdentity( InputOutputArray _m, const Scalar& s )
 {
-    int type = _m.type(), depth = CV_MAT_DEPTH(type), cn = CV_MAT_CN(type),
-            sctype = CV_MAKE_TYPE(depth, cn == 3 ? 4 : cn),
+    int type = _m.type(), depth = CV_MAT_DEPTH(type), cn = CV_MAT_CN(type), kercn = cn;
+    if (cn == 1)
+    {
+        kercn = std::min(ocl::predictOptimalVectorWidth(_m), 4);
+        if (kercn != 4)
+            kercn = 1;
+    }
+    int sctype = CV_MAKE_TYPE(depth, cn == 3 ? 4 : cn),
             rowsPerWI = ocl::Device::getDefault().isIntel() ? 4 : 1;
 
     ocl::Kernel k("setIdentity", ocl::core::set_identity_oclsrc,
-                  format("-D T=%s -D T1=%s -D cn=%d -D ST=%s", ocl::memopTypeToStr(type),
-                         ocl::memopTypeToStr(depth), cn, ocl::memopTypeToStr(sctype)));
+                  format("-D T=%s -D T1=%s -D cn=%d -D ST=%s -D kercn=%d -D rowsPerWI=%d",
+                         ocl::memopTypeToStr(CV_MAKE_TYPE(depth, kercn)),
+                         ocl::memopTypeToStr(depth), cn,
+                         ocl::memopTypeToStr(sctype),
+                         kercn, rowsPerWI));
     if (k.empty())
         return false;
 
     UMat m = _m.getUMat();
-    k.args(ocl::KernelArg::WriteOnly(m), ocl::KernelArg::Constant(Mat(1, 1, sctype, s)),
-           rowsPerWI);
+    k.args(ocl::KernelArg::WriteOnly(m, cn, kercn),
+           ocl::KernelArg::Constant(Mat(1, 1, sctype, s)));
 
-    size_t globalsize[2] = { m.cols, (m.rows + rowsPerWI - 1) / rowsPerWI };
+    size_t globalsize[2] = { m.cols * cn / kercn, (m.rows + rowsPerWI - 1) / rowsPerWI };
     return k.run(2, globalsize, NULL, false);
 }
 
index 6b277fe..952204d 100644 (file)
 //
 //M*/
 
-#if cn != 3
-#define loadpix(addr) *(__global const T *)(addr)
+#if kercn != 3
 #define storepix(val, addr)  *(__global T *)(addr) = val
 #define TSIZE (int)sizeof(T)
 #define scalar scalar_
 #else
-#define loadpix(addr) vload3(0, (__global const T1 *)(addr))
 #define storepix(val, addr) vstore3(val, 0, (__global T1 *)(addr))
 #define TSIZE ((int)sizeof(T1)*3)
 #define scalar (T)(scalar_.x, scalar_.y, scalar_.z)
 #endif
 
 __kernel void setIdentity(__global uchar * srcptr, int src_step, int src_offset, int rows, int cols,
-                          ST scalar_, int rowsPerWI)
+                          ST scalar_)
 {
     int x = get_global_id(0);
     int y0 = get_global_id(1) * rowsPerWI;
@@ -65,7 +63,35 @@ __kernel void setIdentity(__global uchar * srcptr, int src_step, int src_offset,
     {
         int src_index = mad24(y0, src_step, mad24(x, TSIZE, src_offset));
 
-        for (int y = y0, y1 = min(rows, y0 + rowsPerWI); y < y1; ++y, src_index += src_step)
-            storepix(x == y ? scalar : (T)(0), srcptr + src_index);
+#if kercn == cn
+        #pragma unroll
+        for (int y = y0, i = 0, y1 = min(rows, y0 + rowsPerWI); i < rowsPerWI; ++y, ++i, src_index += src_step)
+            if (y < y1)
+                storepix(x == y ? scalar : (T)(0), srcptr + src_index);
+#elif kercn == 4 && cn == 1
+        if (y0 < rows)
+        {
+            storepix(x == y0 >> 2 ? (T)(scalar, 0, 0, 0) : (T)(0), srcptr + src_index);
+            if (++y0 < rows)
+            {
+                src_index += src_step;
+                storepix(x == y0 >> 2 ? (T)(0, scalar, 0, 0) : (T)(0), srcptr + src_index);
+
+                if (++y0 < rows)
+                {
+                    src_index += src_step;
+                    storepix(x == y0 >> 2 ? (T)(0, 0, scalar, 0) : (T)(0), srcptr + src_index);
+
+                    if (++y0 < rows)
+                    {
+                        src_index += src_step;
+                        storepix(x == y0 >> 2 ? (T)(0, 0, 0, scalar) : (T)(0), srcptr + src_index);
+                    }
+                }
+            }
+        }
+#else
+#error "Incorrect combination of cn && kercn"
+#endif
     }
 }