Optimize SSE4_1 lowbd temporal filter implementation
authorchiyotsai <chiyotsai@google.com>
Mon, 4 Mar 2019 18:40:14 +0000 (10:40 -0800)
committerchiyotsai <chiyotsai@google.com>
Thu, 7 Mar 2019 23:55:06 +0000 (15:55 -0800)
 - Change some unaligned loads to aligned loads
 - Preload filter weights

BUG=webm:1591

Change-Id: I4e5e755e1fa5613d1c14191265bf80b0bfd0b75c

vp9/encoder/x86/temporal_filter_sse4.c

index 3a1142f..437f49f 100644 (file)
@@ -75,11 +75,11 @@ static INLINE void read_dist_16(const uint16_t *dist, __m128i *reg_first,
 // by weight.
 static INLINE __m128i average_8(__m128i sum, const __m128i *mul_constants,
                                 const int strength, const int rounding,
-                                const int weight) {
+                                const __m128i *weight) {
   // _mm_srl_epi16 uses the lower 64 bit value for the shift.
   const __m128i strength_u128 = _mm_set_epi32(0, 0, 0, strength);
   const __m128i rounding_u16 = _mm_set1_epi16(rounding);
-  const __m128i weight_u16 = _mm_set1_epi16(weight);
+  const __m128i weight_u16 = *weight;
   const __m128i sixteen = _mm_set1_epi16(16);
 
   // modifier * 3 / index;
@@ -98,62 +98,6 @@ static INLINE __m128i average_8(__m128i sum, const __m128i *mul_constants,
   return _mm_mullo_epi16(sum, weight_u16);
 }
 
-static __m128i average_4_4(__m128i sum, const __m128i *mul_constants,
-                           const int strength, const int rounding,
-                           const int weight_0, const int weight_1) {
-  // _mm_srl_epi16 uses the lower 64 bit value for the shift.
-  const __m128i strength_u128 = _mm_set_epi32(0, 0, 0, strength);
-  const __m128i rounding_u16 = _mm_set1_epi16(rounding);
-  const __m128i weight_u16 =
-      _mm_setr_epi16(weight_0, weight_0, weight_0, weight_0, weight_1, weight_1,
-                     weight_1, weight_1);
-  const __m128i sixteen = _mm_set1_epi16(16);
-
-  // modifier * 3 / index;
-  sum = _mm_mulhi_epu16(sum, *mul_constants);
-
-  sum = _mm_adds_epu16(sum, rounding_u16);
-  sum = _mm_srl_epi16(sum, strength_u128);
-
-  // The maximum input to this comparison is UINT16_MAX * NEIGHBOR_CONSTANT_4
-  // >> 16 (also NEIGHBOR_CONSTANT_4 -1) which is 49151 / 0xbfff / -16385
-  // So this needs to use the epu16 version which did not come until SSE4.
-  sum = _mm_min_epu16(sum, sixteen);
-
-  sum = _mm_sub_epi16(sixteen, sum);
-
-  return _mm_mullo_epi16(sum, weight_u16);
-}
-
-static INLINE void average_16(__m128i *sum_0_u16, __m128i *sum_1_u16,
-                              const __m128i *mul_constants_0,
-                              const __m128i *mul_constants_1,
-                              const int strength, const int rounding,
-                              const int weight) {
-  const __m128i strength_u128 = _mm_set_epi32(0, 0, 0, strength);
-  const __m128i rounding_u16 = _mm_set1_epi16(rounding);
-  const __m128i weight_u16 = _mm_set1_epi16(weight);
-  const __m128i sixteen = _mm_set1_epi16(16);
-  __m128i input_0, input_1;
-
-  input_0 = _mm_mulhi_epu16(*sum_0_u16, *mul_constants_0);
-  input_0 = _mm_adds_epu16(input_0, rounding_u16);
-
-  input_1 = _mm_mulhi_epu16(*sum_1_u16, *mul_constants_1);
-  input_1 = _mm_adds_epu16(input_1, rounding_u16);
-
-  input_0 = _mm_srl_epi16(input_0, strength_u128);
-  input_1 = _mm_srl_epi16(input_1, strength_u128);
-
-  input_0 = _mm_min_epu16(input_0, sixteen);
-  input_1 = _mm_min_epu16(input_1, sixteen);
-  input_0 = _mm_sub_epi16(sixteen, input_0);
-  input_1 = _mm_sub_epi16(sixteen, input_1);
-
-  *sum_0_u16 = _mm_mullo_epi16(input_0, weight_u16);
-  *sum_1_u16 = _mm_mullo_epi16(input_1, weight_u16);
-}
-
 // Add 'sum_u16' to 'count'. Multiply by 'pred' and add to 'accumulator.'
 static void accumulate_and_store_8(const __m128i sum_u16, const uint8_t *pred,
                                    uint16_t *count, uint32_t *accumulator) {
@@ -336,7 +280,7 @@ static void vp9_apply_temporal_filter_luma_16(
     const int16_t *const *neighbors_second, int top_weight, int bottom_weight,
     const int *blk_fw) {
   const int rounding = (1 << strength) >> 1;
-  int weight = top_weight;
+  __m128i weight_first, weight_second;
 
   __m128i mul_first, mul_second;
 
@@ -360,9 +304,18 @@ static void vp9_apply_temporal_filter_luma_16(
 
   (void)block_width;
 
+  // Initialize the weights
+  if (blk_fw) {
+    weight_first = _mm_set1_epi16(blk_fw[0]);
+    weight_second = _mm_set1_epi16(blk_fw[1]);
+  } else {
+    weight_first = _mm_set1_epi16(top_weight);
+    weight_second = weight_first;
+  }
+
   // First row
-  mul_first = _mm_loadu_si128((const __m128i *)neighbors_first[0]);
-  mul_second = _mm_loadu_si128((const __m128i *)neighbors_second[0]);
+  mul_first = _mm_load_si128((const __m128i *)neighbors_first[0]);
+  mul_second = _mm_load_si128((const __m128i *)neighbors_second[0]);
 
   // Add luma values
   get_sum_16(y_dist, &sum_row_2_first, &sum_row_2_second);
@@ -382,15 +335,10 @@ static void vp9_apply_temporal_filter_luma_16(
   sum_row_second = _mm_adds_epu16(sum_row_second, v_second);
 
   // Get modifier and store result
-  if (blk_fw) {
-    sum_row_first =
-        average_8(sum_row_first, &mul_first, strength, rounding, blk_fw[0]);
-    sum_row_second =
-        average_8(sum_row_second, &mul_second, strength, rounding, blk_fw[1]);
-  } else {
-    average_16(&sum_row_first, &sum_row_second, &mul_first, &mul_second,
-               strength, rounding, weight);
-  }
+  sum_row_first =
+      average_8(sum_row_first, &mul_first, strength, rounding, &weight_first);
+  sum_row_second = average_8(sum_row_second, &mul_second, strength, rounding,
+                             &weight_second);
   accumulate_and_store_16(sum_row_first, sum_row_second, y_pre, y_count,
                           y_accum);
 
@@ -408,16 +356,18 @@ static void vp9_apply_temporal_filter_luma_16(
   v_dist += DIST_STRIDE;
 
   // Then all the rows except the last one
-  mul_first = _mm_loadu_si128((const __m128i *)neighbors_first[1]);
-  mul_second = _mm_loadu_si128((const __m128i *)neighbors_second[1]);
+  mul_first = _mm_load_si128((const __m128i *)neighbors_first[1]);
+  mul_second = _mm_load_si128((const __m128i *)neighbors_second[1]);
 
   for (h = 1; h < block_height - 1; ++h) {
     // Move the weight to bottom half
     if (!use_whole_blk && h == block_height / 2) {
       if (blk_fw) {
-        blk_fw += 2;
+        weight_first = _mm_set1_epi16(blk_fw[2]);
+        weight_second = _mm_set1_epi16(blk_fw[3]);
       } else {
-        weight = bottom_weight;
+        weight_first = _mm_set1_epi16(bottom_weight);
+        weight_second = weight_first;
       }
     }
     // Shift the rows up
@@ -456,15 +406,10 @@ static void vp9_apply_temporal_filter_luma_16(
     sum_row_second = _mm_adds_epu16(sum_row_second, v_second);
 
     // Get modifier and store result
-    if (blk_fw) {
-      sum_row_first =
-          average_8(sum_row_first, &mul_first, strength, rounding, blk_fw[0]);
-      sum_row_second =
-          average_8(sum_row_second, &mul_second, strength, rounding, blk_fw[1]);
-    } else {
-      average_16(&sum_row_first, &sum_row_second, &mul_first, &mul_second,
-                 strength, rounding, weight);
-    }
+    sum_row_first =
+        average_8(sum_row_first, &mul_first, strength, rounding, &weight_first);
+    sum_row_second = average_8(sum_row_second, &mul_second, strength, rounding,
+                               &weight_second);
     accumulate_and_store_16(sum_row_first, sum_row_second, y_pre, y_count,
                             y_accum);
 
@@ -476,8 +421,8 @@ static void vp9_apply_temporal_filter_luma_16(
   }
 
   // The last row
-  mul_first = _mm_loadu_si128((const __m128i *)neighbors_first[0]);
-  mul_second = _mm_loadu_si128((const __m128i *)neighbors_second[0]);
+  mul_first = _mm_load_si128((const __m128i *)neighbors_first[0]);
+  mul_second = _mm_load_si128((const __m128i *)neighbors_second[0]);
 
   // Shift the rows up
   sum_row_1_first = sum_row_2_first;
@@ -503,15 +448,10 @@ static void vp9_apply_temporal_filter_luma_16(
   sum_row_second = _mm_adds_epu16(sum_row_second, v_second);
 
   // Get modifier and store result
-  if (blk_fw) {
-    sum_row_first =
-        average_8(sum_row_first, &mul_first, strength, rounding, blk_fw[0]);
-    sum_row_second =
-        average_8(sum_row_second, &mul_second, strength, rounding, blk_fw[1]);
-  } else {
-    average_16(&sum_row_first, &sum_row_second, &mul_first, &mul_second,
-               strength, rounding, weight);
-  }
+  sum_row_first =
+      average_8(sum_row_first, &mul_first, strength, rounding, &weight_first);
+  sum_row_second = average_8(sum_row_second, &mul_second, strength, rounding,
+                             &weight_second);
   accumulate_and_store_16(sum_row_first, sum_row_second, y_pre, y_count,
                           y_accum);
 }
@@ -634,7 +574,8 @@ static void vp9_apply_temporal_filter_chroma_8(
     const int16_t *const *neighbors, int top_weight, int bottom_weight,
     const int *blk_fw) {
   const int rounding = (1 << strength) >> 1;
-  int weight = top_weight;
+
+  __m128i weight;
 
   __m128i mul;
 
@@ -648,8 +589,16 @@ static void vp9_apply_temporal_filter_chroma_8(
 
   (void)uv_block_width;
 
+  // Initilize weight
+  if (blk_fw) {
+    weight = _mm_setr_epi16(blk_fw[0], blk_fw[0], blk_fw[0], blk_fw[0],
+                            blk_fw[1], blk_fw[1], blk_fw[1], blk_fw[1]);
+  } else {
+    weight = _mm_set1_epi16(top_weight);
+  }
+
   // First row
-  mul = _mm_loadu_si128((const __m128i *)neighbors[0]);
+  mul = _mm_load_si128((const __m128i *)neighbors[0]);
 
   // Add chroma values
   get_sum_8(u_dist, &u_sum_row_2);
@@ -666,15 +615,9 @@ static void vp9_apply_temporal_filter_chroma_8(
   add_luma_dist_to_8_chroma_mod(y_dist, ss_x, ss_y, &u_sum_row, &v_sum_row);
 
   // Get modifier and store result
-  if (blk_fw) {
-    u_sum_row =
-        average_4_4(u_sum_row, &mul, strength, rounding, blk_fw[0], blk_fw[1]);
-    v_sum_row =
-        average_4_4(v_sum_row, &mul, strength, rounding, blk_fw[0], blk_fw[1]);
-  } else {
-    u_sum_row = average_8(u_sum_row, &mul, strength, rounding, weight);
-    v_sum_row = average_8(v_sum_row, &mul, strength, rounding, weight);
-  }
+  u_sum_row = average_8(u_sum_row, &mul, strength, rounding, &weight);
+  v_sum_row = average_8(v_sum_row, &mul, strength, rounding, &weight);
+
   accumulate_and_store_8(u_sum_row, u_pre, u_count, u_accum);
   accumulate_and_store_8(v_sum_row, v_pre, v_count, v_accum);
 
@@ -694,15 +637,16 @@ static void vp9_apply_temporal_filter_chroma_8(
   y_dist += DIST_STRIDE * (1 + ss_y);
 
   // Then all the rows except the last one
-  mul = _mm_loadu_si128((const __m128i *)neighbors[1]);
+  mul = _mm_load_si128((const __m128i *)neighbors[1]);
 
   for (h = 1; h < uv_block_height - 1; ++h) {
     // Move the weight pointer to the bottom half of the blocks
     if (h == uv_block_height / 2) {
       if (blk_fw) {
-        blk_fw += 2;
+        weight = _mm_setr_epi16(blk_fw[2], blk_fw[2], blk_fw[2], blk_fw[2],
+                                blk_fw[3], blk_fw[3], blk_fw[3], blk_fw[3]);
       } else {
-        weight = bottom_weight;
+        weight = _mm_set1_epi16(bottom_weight);
       }
     }
 
@@ -726,15 +670,8 @@ static void vp9_apply_temporal_filter_chroma_8(
     add_luma_dist_to_8_chroma_mod(y_dist, ss_x, ss_y, &u_sum_row, &v_sum_row);
 
     // Get modifier and store result
-    if (blk_fw) {
-      u_sum_row = average_4_4(u_sum_row, &mul, strength, rounding, blk_fw[0],
-                              blk_fw[1]);
-      v_sum_row = average_4_4(v_sum_row, &mul, strength, rounding, blk_fw[0],
-                              blk_fw[1]);
-    } else {
-      u_sum_row = average_8(u_sum_row, &mul, strength, rounding, weight);
-      v_sum_row = average_8(v_sum_row, &mul, strength, rounding, weight);
-    }
+    u_sum_row = average_8(u_sum_row, &mul, strength, rounding, &weight);
+    v_sum_row = average_8(v_sum_row, &mul, strength, rounding, &weight);
 
     accumulate_and_store_8(u_sum_row, u_pre, u_count, u_accum);
     accumulate_and_store_8(v_sum_row, v_pre, v_count, v_accum);
@@ -756,7 +693,7 @@ static void vp9_apply_temporal_filter_chroma_8(
   }
 
   // The last row
-  mul = _mm_loadu_si128((const __m128i *)neighbors[0]);
+  mul = _mm_load_si128((const __m128i *)neighbors[0]);
 
   // Shift the rows up
   u_sum_row_1 = u_sum_row_2;
@@ -773,15 +710,8 @@ static void vp9_apply_temporal_filter_chroma_8(
   add_luma_dist_to_8_chroma_mod(y_dist, ss_x, ss_y, &u_sum_row, &v_sum_row);
 
   // Get modifier and store result
-  if (blk_fw) {
-    u_sum_row =
-        average_4_4(u_sum_row, &mul, strength, rounding, blk_fw[0], blk_fw[1]);
-    v_sum_row =
-        average_4_4(v_sum_row, &mul, strength, rounding, blk_fw[0], blk_fw[1]);
-  } else {
-    u_sum_row = average_8(u_sum_row, &mul, strength, rounding, weight);
-    v_sum_row = average_8(v_sum_row, &mul, strength, rounding, weight);
-  }
+  u_sum_row = average_8(u_sum_row, &mul, strength, rounding, &weight);
+  v_sum_row = average_8(v_sum_row, &mul, strength, rounding, &weight);
 
   accumulate_and_store_8(u_sum_row, u_pre, u_count, u_accum);
   accumulate_and_store_8(v_sum_row, v_pre, v_count, v_accum);