optimize stereoBP kernel
authoryao <bitwangyaoyao@gmail.com>
Thu, 4 Jul 2013 06:46:38 +0000 (14:46 +0800)
committeryao <bitwangyaoyao@gmail.com>
Thu, 4 Jul 2013 06:46:38 +0000 (14:46 +0800)
modules/ocl/src/opencl/stereobp.cl
modules/ocl/src/stereobp.cpp

index 4d13f80..8a71629 100644 (file)
 
 #ifdef T_FLOAT
 #define T float
+#define T4 float4
 #else
 #define T short
+#define T4 short4
 #endif
 
 ///////////////////////////////////////////////////////////////
@@ -71,6 +73,14 @@ T saturate_cast(float v){
 #endif
 }
 
+T4 saturate_cast4(float4 v){
+#ifdef T_SHORT
+    return convert_short4_sat_rte(v); 
+#else
+    return v;
+#endif
+}
+
 #define FLOAT_MAX 3.402823466e+38f
 typedef struct
 {
@@ -84,29 +94,14 @@ typedef struct
 ////////////////////////// comp data //////////////////////////
 ///////////////////////////////////////////////////////////////
 
-float pix_diff_1(__global const uchar *ls, __global const uchar *rs)
+inline float pix_diff_1(const uchar4 l, __global const uchar *rs)
 {
-    return abs((int)(*ls) - *rs); 
+    return abs((int)(l.x) - *rs); 
 }
 
-float pix_diff_3(__global const uchar *ls, __global const uchar *rs)
+float pix_diff_4(const uchar4 l, __global const uchar *rs)
 {
-    const float tr = 0.299f;
-    const float tg = 0.587f;
-    const float tb = 0.114f;
-
-    float val;
-
-    val =  tb * abs((int)ls[0] - rs[0]);
-    val += tg * abs((int)ls[1] - rs[1]);
-    val += tr * abs((int)ls[2] - rs[2]);
-
-    return val;
-}
-float pix_diff_4(__global const uchar *ls, __global const uchar *rs)
-{
-    uchar4 l, r;
-    l = *((__global uchar4 *)ls);
+    uchar4 r;
     r = *((__global uchar4 *)rs);
 
     const float tr = 0.299f;
@@ -122,11 +117,19 @@ float pix_diff_4(__global const uchar *ls, __global const uchar *rs)
     return val;
 }
 
+inline float pix_diff_3(const uchar4 l, __global const uchar *rs)
+{
+    return pix_diff_4(l, rs);
+}
 
 #ifndef CN
 #define CN 4
 #endif
 
+#ifndef CNDISP
+#define CNDISP 64
+#endif
+
 #define CAT(X,Y) X##Y
 #define CAT2(X,Y) CAT(X,Y)
 
@@ -149,19 +152,20 @@ __kernel void comp_data(__global uchar *left,  int left_rows,  int left_cols,  i
         __global T *ds = data + y * data_step + x;
 
         const unsigned int disp_step = data_step * left_rows;
+        const float weightXterm = con_st -> cdata_weight * con_st -> cmax_data_term;
+        const uchar4 ls_data = vload4(0, ls);
 
         for (int disp = 0; disp < con_st -> cndisp; disp++)
         {
             if (x - disp >= 1)
             {
                 float val = 0;
-                val = PIX_DIFF(ls, rs - disp * CN);
-                ds[disp * disp_step] =  saturate_cast(fmin(con_st -> cdata_weight * val, 
-                    con_st -> cdata_weight * con_st -> cmax_data_term));
+                val = PIX_DIFF(ls_data, rs - disp * CN);
+                ds[disp * disp_step] =  saturate_cast(fmin(con_st -> cdata_weight * val, weightXterm));
             }
             else
             {
-                ds[disp * disp_step] =  saturate_cast(con_st -> cdata_weight * con_st -> cmax_data_term);
+                ds[disp * disp_step] =  saturate_cast(weightXterm);
             }
         }
     }
@@ -182,13 +186,20 @@ __kernel void data_step_down(__global T *src, int src_rows,
     {
         src_step /= sizeof(T);
         dst_step /= sizeof(T);
+        int4 coor_step = (int4)(src_rows * src_step);
+        int4 coor = (int4)(min(2*y+0, src_rows-1) * src_step + 2*x+0,
+                           min(2*y+1, src_rows-1) * src_step + 2*x+0,
+                           min(2*y+0, src_rows-1) * src_step + 2*x+1,
+                           min(2*y+1, src_rows-1) * src_step + 2*x+1);
+
         for (int d = 0; d < cndisp; ++d)
         {
             float dst_reg;
-            dst_reg  = src[(d * src_rows + min(2*y+0, src_rows-1)) * src_step + 2*x+0];
-            dst_reg += src[(d * src_rows + min(2*y+1, src_rows-1)) * src_step + 2*x+0];
-            dst_reg += src[(d * src_rows + min(2*y+0, src_rows-1)) * src_step + 2*x+1];
-            dst_reg += src[(d * src_rows + min(2*y+1, src_rows-1)) * src_step + 2*x+1];
+            dst_reg  = src[coor.x];
+            dst_reg += src[coor.y];
+            dst_reg += src[coor.z];
+            dst_reg += src[coor.w];
+            coor += coor_step;
 
             dst[(d * dst_rows + y) * dst_step + x] = saturate_cast(dst_reg);
         }
@@ -224,85 +235,95 @@ __kernel void level_up_message(__global T *src, int src_rows, int src_step,
 ///////////////////////////////////////////////////////////////
 ////////////////////  calc all iterations /////////////////////
 ///////////////////////////////////////////////////////////////
-void calc_min_linear_penalty(__global T * dst, int disp_step, 
-                             int cndisp, float cdisc_single_jump)
+void message(__global T *us_, __global T *ds_, __global T *ls_, __global T *rs_,
+              const __global T *dt,
+              int u_step, int msg_disp_step, int data_disp_step,
+              float4 cmax_disc_term, float4 cdisc_single_jump)
 {
-    float prev = dst[0];
-    float cur;
+    __global T *us = us_ + u_step;
+    __global T *ds = ds_ - u_step;
+    __global T *ls = ls_ + 1;
+    __global T *rs = rs_ - 1;
 
-    for (int disp = 1; disp < cndisp; ++disp)
-    {
-        prev += cdisc_single_jump;
-        cur = dst[disp_step * disp];
+    float4 minimum = (float4)(FLOAT_MAX);
 
-        if (prev < cur)
-        {
-            cur = prev;
-            dst[disp_step * disp] = saturate_cast(prev);
-        }
+    T4 t_dst[CNDISP];
+    float4 dst_reg;
+    float4 prev;
+    float4 cur;
 
-        prev = cur;
-    }
+    T t_us = us[0];
+    T t_ds = ds[0];
+    T t_ls = ls[0];
+    T t_rs = rs[0];
+    T t_dt = dt[0];
+
+    prev = (float4)(t_us + t_ls + t_rs + t_dt,
+                    t_ds + t_ls + t_rs + t_dt,
+                    t_us + t_ds + t_rs + t_dt,
+                    t_us + t_ds + t_ls + t_dt);
+
+    minimum = min(prev, minimum);
+
+    t_dst[0] = saturate_cast4(prev);
 
-    prev = dst[(cndisp - 1) * disp_step];
-    for (int disp = cndisp - 2; disp >= 0; disp--)
+    for(int i = 1, idx = msg_disp_step; i < CNDISP; ++i, idx+=msg_disp_step)
     {
+        t_us = us[idx];
+        t_ds = ds[idx];
+        t_ls = ls[idx];
+        t_rs = rs[idx];
+        t_dt = dt[data_disp_step * i];
+
+        dst_reg = (float4)(t_us + t_ls + t_rs + t_dt,
+                           t_ds + t_ls + t_rs + t_dt,
+                           t_us + t_ds + t_rs + t_dt,
+                           t_us + t_ds + t_ls + t_dt);
+
+        minimum = min(dst_reg, minimum);
+
         prev += cdisc_single_jump;
-        cur = dst[disp_step * disp];
+        prev = min(prev, dst_reg);
 
-        if (prev < cur)
-        {
-            cur = prev;
-            dst[disp_step * disp] = saturate_cast(prev);
-        }
-        prev = cur;
+        t_dst[i] = saturate_cast4(prev);
     }
-}
-void message(const __global T *msg1, const __global T *msg2,
-             const __global T *msg3, const __global T *data, __global T *dst,
-             int msg_disp_step, int data_disp_step, int cndisp, float cmax_disc_term, float cdisc_single_jump)
-{
-    float minimum = FLOAT_MAX;
 
-    for(int i = 0; i < cndisp; ++i)
+    minimum += cmax_disc_term;
+    
+    float4 sum = 0;
+    prev = convert_float4(t_dst[CNDISP - 1]);
+    for (int disp = CNDISP - 2; disp >= 0; disp--)
     {
-        float dst_reg;
-        dst_reg  = msg1[msg_disp_step * i];
-        dst_reg += msg2[msg_disp_step * i];
-        dst_reg += msg3[msg_disp_step * i];
-        dst_reg += data[data_disp_step * i];
-
-        if (dst_reg < minimum)
-            minimum = dst_reg;
+        prev += cdisc_single_jump;
+        cur = convert_float4(t_dst[disp]);
+        prev = min(prev, cur);
+        cur = min(prev, minimum);
+        sum += cur;
 
-        dst[msg_disp_step * i] = saturate_cast(dst_reg);
+        t_dst[disp] = saturate_cast4(cur);
     }
 
-    calc_min_linear_penalty(dst, msg_disp_step, cndisp, cdisc_single_jump);
+    dst_reg = convert_float4(t_dst[CNDISP - 1]);
+    dst_reg = min(dst_reg, minimum);
+    t_dst[CNDISP - 1] = saturate_cast4(dst_reg);
+    sum += dst_reg;
 
-    minimum += cmax_disc_term;
-
-    float sum = 0;
-    for(int i = 0; i < cndisp; ++i)
+    sum /= CNDISP;
+#pragma unroll
+    for(int i = 0, idx = 0; i < CNDISP; ++i, idx+=msg_disp_step)
     {
-        float dst_reg = dst[msg_disp_step * i];
-        if (dst_reg > minimum)
-        {
-            dst_reg = minimum;
-            dst[msg_disp_step * i] = saturate_cast(minimum);
-        }
-        sum += dst_reg;
+        T4 dst = t_dst[i];
+        us_[idx] = dst.x - sum.x;
+        ds_[idx] = dst.y - sum.y;
+        rs_[idx] = dst.z - sum.z;
+        ls_[idx] = dst.w - sum.w;
     }
-    sum /= cndisp;
-
-    for(int i = 0; i < cndisp; ++i)
-        dst[msg_disp_step * i] -= sum;
 }
 __kernel void one_iteration(__global T *u,    int u_step,
                             __global T *data, int data_step,
                             __global T *d,    __global T *l, __global T *r,
                             int t, int cols, int rows, 
-                            int cndisp, float cmax_disc_term, float cdisc_single_jump)
+                            float cmax_disc_term, float cdisc_single_jump)
 {
     const int y = get_global_id(1);
     const int x = ((get_global_id(0)) << 1) + ((y + t) & 1);
@@ -321,15 +342,9 @@ __kernel void one_iteration(__global T *u,    int u_step,
         int msg_disp_step = u_step * rows;
         int data_disp_step = data_step * rows;
 
-        message(us + u_step, ls      + 1, rs - 1, dt, us, msg_disp_step, data_disp_step, cndisp, 
-            cmax_disc_term, cdisc_single_jump);
-        message(ds - u_step, ls      + 1, rs - 1, dt, ds, msg_disp_step, data_disp_step, cndisp,
-            cmax_disc_term, cdisc_single_jump);
-
-        message(us + u_step, ds - u_step, rs - 1, dt, rs, msg_disp_step, data_disp_step, cndisp,
-            cmax_disc_term, cdisc_single_jump);
-        message(us + u_step, ds - u_step, ls + 1, dt, ls, msg_disp_step, data_disp_step, cndisp,
-            cmax_disc_term, cdisc_single_jump);
+        message(us, ds, ls, rs, dt,
+                u_step, msg_disp_step, data_disp_step,
+                (float4)(cmax_disc_term), (float4)(cdisc_single_jump));
     }
 }
 
index bd88ec0..cca1db3 100644 (file)
@@ -236,13 +236,13 @@ namespace cv
                 args.push_back( make_pair( sizeof(cl_int) , (void *)&t));
                 args.push_back( make_pair( sizeof(cl_int) , (void *)&cols));
                 args.push_back( make_pair( sizeof(cl_int) , (void *)&rows));
-                args.push_back( make_pair( sizeof(cl_int) , (void *)&cndisp));
                 args.push_back( make_pair( sizeof(cl_float) , (void *)&cmax_disc_term));
                 args.push_back( make_pair( sizeof(cl_float) , (void *)&cdisc_single_jump));
 
                 size_t gt[3] = {cols, rows, 1}, lt[3] = {16, 16, 1};
-                const char* t_opt  = data_type == CV_16S ? "-D T_SHORT":"-D T_FLOAT";
-                openCLExecuteKernel(clCxt, &stereobp, kernelName, gt, lt, args, -1, -1, t_opt);
+                char opt[80] = "";
+                sprintf(opt, "-D %s -D CNDISP=%d", data_type == CV_16S ? "T_SHORT":"T_FLOAT", cndisp);
+                openCLExecuteKernel(clCxt, &stereobp, kernelName, gt, lt, args, -1, -1, opt);
             }
 
             static void calc_all_iterations_calls(int cols, int rows, int iters, oclMat &u,