Rewrite hex search function
authorYunqing Wang <yunqingwang@google.com>
Fri, 20 May 2011 19:26:32 +0000 (15:26 -0400)
committerYunqing Wang <yunqingwang@google.com>
Mon, 23 May 2011 20:18:52 +0000 (16:18 -0400)
Reduced some bound checks in hex search function.

Change-Id: Ie5f73a6c227590341c960a74dc508cff80f8aa06

vp8/encoder/mcomp.c

index 0cd165b..f8e574c 100644 (file)
@@ -793,12 +793,36 @@ int vp8_find_best_half_pixel_step(MACROBLOCK *mb, BLOCK *b, BLOCKD *d,
     return bestmse;
 }
 
+#define CHECK_BOUNDS(range) \
+{\
+    all_in = 1;\
+    all_in &= ((br-range) >= x->mv_row_min);\
+    all_in &= ((br+range) <= x->mv_row_max);\
+    all_in &= ((bc-range) >= x->mv_col_min);\
+    all_in &= ((bc+range) <= x->mv_col_max);\
+}
+
+#define CHECK_POINT \
+{\
+    if (this_mv.as_mv.col < x->mv_col_min) continue;\
+    if (this_mv.as_mv.col > x->mv_col_max) continue;\
+    if (this_mv.as_mv.row < x->mv_row_min) continue;\
+    if (this_mv.as_mv.row > x->mv_row_max) continue;\
+}
+
+#define CHECK_BETTER \
+{\
+    if (thissad < bestsad)\
+    {\
+        thissad += mvsad_err_cost(&this_mv, &fcenter_mv, mvsadcost, error_per_bit);\
+        if (thissad < bestsad)\
+        {\
+            bestsad = thissad;\
+            best_site = i;\
+        }\
+    }\
+}
 
-#define MVC(r,c) (((mvsadcost[0][r-rr] + mvsadcost[1][c-rc]) * error_per_bit + 128 )>>8 ) // estimated cost of a motion vector (r,c)
-#define PRE(r,c) (*(d->base_pre) + d->pre + (r) * d->pre_stride + (c)) // pointer to predictor base of a motionvector
-#define DIST(r,c,v) vfp->sdf( src,src_stride,PRE(r,c),d->pre_stride, v) // returns sad error score.
-#define ERR(r,c,v) (MVC(r,c)+DIST(r,c,v)) // returns distortion + motion vector cost
-#define CHECK_BETTER(v,r,c) if ((v = ERR(r,c,besterr)) < besterr) { besterr = v; br=r; bc=c; } // checks if (r,c) has better score than previous best
 static const MV next_chkpts[6][3] =
 {
     {{ -2, 0}, { -1, -2}, {1, -2}},
@@ -808,6 +832,7 @@ static const MV next_chkpts[6][3] =
     {{1, 2}, { -1, 2}, { -2, 0}},
     {{ -1, 2}, { -2, 0}, { -1, -2}}
 };
+
 int vp8_hex_search
 (
     MACROBLOCK *x,
@@ -825,135 +850,160 @@ int vp8_hex_search
 )
 {
     MV hex[6] = { { -1, -2}, {1, -2}, {2, 0}, {1, 2}, { -1, 2}, { -2, 0} } ;
-    //MV neighbors[8] = { { -1, -1}, {0, -1}, {1, -1}, { -1, 0}, {1, 0}, { -1, 1}, {0, 1}, {1, 1} } ;
     MV neighbors[4] = {{0, -1}, { -1, 0}, {1, 0}, {0, 1}} ;
-
     int i, j;
-    unsigned char *src = (*(b->base_src) + b->src);
-    int src_stride = b->src_stride;
-    int rr = center_mv->as_mv.row, rc = center_mv->as_mv.col;
-    int br = ref_mv->as_mv.row >> 3, bc = ref_mv->as_mv.col >> 3, tr, tc;
-    unsigned int besterr, thiserr = 0x7fffffff;
-    int k = -1, tk;
-
-    if (bc < x->mv_col_min) bc = x->mv_col_min;
-
-    if (bc > x->mv_col_max) bc = x->mv_col_max;
-
-    if (br < x->mv_row_min) br = x->mv_row_min;
 
-    if (br > x->mv_row_max) br = x->mv_row_max;
+    unsigned char *what = (*(b->base_src) + b->src);
+    int what_stride = b->src_stride;
+    int in_what_stride = d->pre_stride;
+    int br = ref_mv->as_mv.row >> 3, bc = ref_mv->as_mv.col >> 3;
+    int_mv this_mv;
+    unsigned int bestsad = 0x7fffffff;
+    unsigned int thissad;
+    unsigned char *base_offset;
+    unsigned char *this_offset;
+    int k = -1;
+    int all_in;
+    int best_site = -1;
 
-    rr >>= 3;
-    rc >>= 3;
+    int_mv fcenter_mv;
+    fcenter_mv.as_mv.row = center_mv->as_mv.row >> 3;
+    fcenter_mv.as_mv.col = center_mv->as_mv.col >> 3;
 
-    besterr = ERR(br, bc, thiserr);
+    // Work out the start point for the search
+    base_offset = (unsigned char *)(*(d->base_pre) + d->pre);
+    this_offset = base_offset + (br * (d->pre_stride)) + bc;
+    this_mv.as_mv.row = br;
+    this_mv.as_mv.col = bc;
+    bestsad = vfp->sdf( what, what_stride, this_offset, in_what_stride, 0x7fffffff) + mvsad_err_cost(&this_mv, &fcenter_mv, mvsadcost, error_per_bit);
 
     // hex search
     //j=0
-    tr = br;
-    tc = bc;
+    CHECK_BOUNDS(2)
 
-    for (i = 0; i < 6; i++)
+    if(all_in)
     {
-        int nr = tr + hex[i].row, nc = tc + hex[i].col;
-
-        if (nc < x->mv_col_min) continue;
-
-        if (nc > x->mv_col_max) continue;
-
-        if (nr < x->mv_row_min) continue;
-
-        if (nr > x->mv_row_max) continue;
-
-        //CHECK_BETTER(thiserr,nr,nc);
-        if ((thiserr = ERR(nr, nc, besterr)) < besterr)
+        for (i = 0; i < 6; i++)
         {
-            besterr = thiserr;
-            br = nr;
-            bc = nc;
-            k = i;
+            this_mv.as_mv.row = br + hex[i].row;
+            this_mv.as_mv.col = bc + hex[i].col;
+            this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) + this_mv.as_mv.col;
+            thissad=vfp->sdf( what, what_stride, this_offset, in_what_stride, bestsad);
+            CHECK_BETTER
+        }
+    }else
+    {
+        for (i = 0; i < 6; i++)
+        {
+            this_mv.as_mv.row = br + hex[i].row;
+            this_mv.as_mv.col = bc + hex[i].col;
+            CHECK_POINT
+            this_offset = base_offset + (this_mv.as_mv.row * in_what_stride) + this_mv.as_mv.col;
+            thissad=vfp->sdf( what, what_stride, this_offset, in_what_stride, bestsad);
+            CHECK_BETTER
         }
     }
 
-    if (tr == br && tc == bc)
+    if (best_site == -1)
         goto cal_neighbors;
+    else
+    {
+        br += hex[best_site].row;
+        bc += hex[best_site].col;
+        k = best_site;
+    }
 
     for (j = 1; j < 127; j++)
     {
-        tr = br;
-        tc = bc;
-        tk = k;
+        best_site = -1;
+        CHECK_BOUNDS(2)
 
-        for (i = 0; i < 3; i++)
+        if(all_in)
         {
-            int nr = tr + next_chkpts[tk][i].row, nc = tc + next_chkpts[tk][i].col;
-
-            if (nc < x->mv_col_min) continue;
-
-            if (nc > x->mv_col_max) continue;
-
-            if (nr < x->mv_row_min) continue;
-
-            if (nr > x->mv_row_max) continue;
-
-            //CHECK_BETTER(thiserr,nr,nc);
-            if ((thiserr = ERR(nr, nc, besterr)) < besterr)
+            for (i = 0; i < 3; i++)
             {
-                besterr = thiserr;
-                br = nr;
-                bc = nc; //k=(tk+5+i)%6;}
-                k = tk + 5 + i;
-
-                if (k >= 12) k -= 12;
-                else if (k >= 6) k -= 6;
+                this_mv.as_mv.row = br + next_chkpts[k][i].row;
+                this_mv.as_mv.col = bc + next_chkpts[k][i].col;
+                this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
+                thissad = vfp->sdf( what, what_stride, this_offset, in_what_stride, bestsad);
+                CHECK_BETTER
+            }
+        }else
+        {
+            for (i = 0; i < 3; i++)
+            {
+                this_mv.as_mv.row = br + next_chkpts[k][i].row;
+                this_mv.as_mv.col = bc + next_chkpts[k][i].col;
+                CHECK_POINT
+                this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
+                thissad = vfp->sdf( what, what_stride, this_offset, in_what_stride, bestsad);
+                CHECK_BETTER
             }
         }
 
-        if (tr == br && tc == bc)
+        if (best_site == -1)
             break;
+        else
+        {
+            br += next_chkpts[k][best_site].row;
+            bc += next_chkpts[k][best_site].col;
+            k += 5 + best_site;
+            if (k >= 12) k -= 12;
+            else if (k >= 6) k -= 6;
+        }
     }
 
     // check 4 1-away neighbors
 cal_neighbors:
-
     for (j = 0; j < 32; j++)
     {
-        tr = br;
-        tc = bc;
+        best_site = -1;
+        CHECK_BOUNDS(1)
 
-        for (i = 0; i < 4; i++)
+        if(all_in)
         {
-            int nr = tr + neighbors[i].row, nc = tc + neighbors[i].col;
-
-            if (nc < x->mv_col_min) continue;
-
-            if (nc > x->mv_col_max) continue;
-
-            if (nr < x->mv_row_min) continue;
-
-            if (nr > x->mv_row_max) continue;
-
-            CHECK_BETTER(thiserr, nr, nc);
+            for (i = 0; i < 4; i++)
+            {
+                this_mv.as_mv.row = br + neighbors[i].row;
+                this_mv.as_mv.col = bc + neighbors[i].col;
+                this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
+                thissad = vfp->sdf( what, what_stride, this_offset, in_what_stride, bestsad);
+                CHECK_BETTER
+            }
+        }else
+        {
+            for (i = 0; i < 4; i++)
+            {
+                this_mv.as_mv.row = br + neighbors[i].row;
+                this_mv.as_mv.col = bc + neighbors[i].col;
+                CHECK_POINT
+                this_offset = base_offset + (this_mv.as_mv.row * (in_what_stride)) + this_mv.as_mv.col;
+                thissad = vfp->sdf( what, what_stride, this_offset, in_what_stride, bestsad);
+                CHECK_BETTER
+            }
         }
 
-        if (tr == br && tc == bc)
+        if (best_site == -1)
             break;
+        else
+        {
+            br += neighbors[best_site].row;
+            bc += neighbors[best_site].col;
+        }
     }
 
     best_mv->as_mv.row = br;
     best_mv->as_mv.col = bc;
+    this_mv.as_mv.row = br<<3;
+    this_mv.as_mv.col = bc<<3;
 
-    return vfp->vf(src, src_stride, PRE(br, bc), d->pre_stride, &thiserr) + mv_err_cost(best_mv, center_mv, mvcost, error_per_bit) ;
+    this_offset = (unsigned char *)(*(d->base_pre) + d->pre + (br * (in_what_stride)) + bc);
+    return vfp->vf(what, what_stride, this_offset, in_what_stride, &bestsad) + mv_err_cost(&this_mv, center_mv, mvcost, error_per_bit) ;
 }
-#undef MVC
-#undef PRE
-#undef SP
-#undef DIST
-#undef ERR
+#undef CHECK_BOUNDS
+#undef CHECK_POINT
 #undef CHECK_BETTER
 
-
 int vp8_diamond_search_sad
 (
     MACROBLOCK *x,