vp9[loongarch]: Optimize fdct4x4/8x8_lsx
[platform/upstream/libvpx.git] / test / yuv_temporal_filter_test.cc
1 /*
2  *  Copyright (c) 2019 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10
11 #include "third_party/googletest/src/include/gtest/gtest.h"
12
13 #include "./vp9_rtcd.h"
14 #include "test/acm_random.h"
15 #include "test/buffer.h"
16 #include "test/register_state_check.h"
17 #include "vpx_ports/vpx_timer.h"
18
19 namespace {
20
21 using ::libvpx_test::ACMRandom;
22 using ::libvpx_test::Buffer;
23
// Signature shared by every temporal filter implementation exercised by this
// test. Pixel planes are passed as uint8_t pointers; the high-bitdepth
// wrappers generated by WRAP_HIGHBD_FUNC reinterpret them as uint16_t.
//
// The *_src / *_pre planes are read, blk_fw supplies the filter weight(s),
// and the *_accumulator / *_count planes receive the filter output.
using YUVTemporalFilterFunc = void (*)(
    const uint8_t *y_src, int y_src_stride, const uint8_t *y_pre,
    int y_pre_stride, const uint8_t *u_src, const uint8_t *v_src,
    int uv_src_stride, const uint8_t *u_pre, const uint8_t *v_pre,
    int uv_pre_stride, unsigned int block_width, unsigned int block_height,
    int ss_x, int ss_y, int strength, const int *const blk_fw, int use_32x32,
    uint32_t *y_accumulator, uint16_t *y_count, uint32_t *u_accumulator,
    uint16_t *u_count, uint32_t *v_accumulator, uint16_t *v_count);
32
33 struct TemporalFilterWithBd {
34   TemporalFilterWithBd(YUVTemporalFilterFunc func, int bitdepth)
35       : temporal_filter(func), bd(bitdepth) {}
36
37   YUVTemporalFilterFunc temporal_filter;
38   int bd;
39 };
40
41 std::ostream &operator<<(std::ostream &os, const TemporalFilterWithBd &tf) {
42   return os << "Bitdepth: " << tf.bd;
43 }
44
// Returns the filter weight that applies to the pixel at (row, col).
//
// When use_32x32 is non-zero a single weight, blk_fw[0], covers the whole
// block. Otherwise the block is split at block_height / 2 and
// block_width / 2 into four quadrants and blk_fw is indexed as
//   0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right,
// so blk_fw must then hold at least four entries.
int GetFilterWeight(unsigned int row, unsigned int col,
                    unsigned int block_height, unsigned int block_width,
                    const int *const blk_fw, int use_32x32) {
  if (use_32x32) return blk_fw[0];

  const int in_bottom_half = (row >= block_height / 2) ? 1 : 0;
  const int in_right_half = (col >= block_width / 2) ? 1 : 0;
  return blk_fw[2 * in_bottom_half + in_right_half];
}
54
55 template <typename PixelType>
56 int GetModIndex(int sum_dist, int index, int rounding, int strength,
57                 int filter_weight) {
58   int mod = sum_dist * 3 / index;
59   mod += rounding;
60   mod >>= strength;
61
62   mod = VPXMIN(16, mod);
63
64   mod = 16 - mod;
65   mod *= filter_weight;
66
67   return mod;
68 }
69
70 template <>
71 int GetModIndex<uint8_t>(int sum_dist, int index, int rounding, int strength,
72                          int filter_weight) {
73   unsigned int index_mult[14] = { 0,     0,     0,     0,     49152,
74                                   39322, 32768, 28087, 24576, 21846,
75                                   19661, 17874, 0,     15124 };
76
77   assert(index >= 0 && index <= 13);
78   assert(index_mult[index] != 0);
79
80   int mod = (clamp(sum_dist, 0, UINT16_MAX) * index_mult[index]) >> 16;
81   mod += rounding;
82   mod >>= strength;
83
84   mod = VPXMIN(16, mod);
85
86   mod = 16 - mod;
87   mod *= filter_weight;
88
89   return mod;
90 }
91
92 template <>
93 int GetModIndex<uint16_t>(int sum_dist, int index, int rounding, int strength,
94                           int filter_weight) {
95   int64_t index_mult[14] = { 0U,          0U,          0U,          0U,
96                              3221225472U, 2576980378U, 2147483648U, 1840700270U,
97                              1610612736U, 1431655766U, 1288490189U, 1171354718U,
98                              0U,          991146300U };
99
100   assert(index >= 0 && index <= 13);
101   assert(index_mult[index] != 0);
102
103   int mod = static_cast<int>((sum_dist * index_mult[index]) >> 32);
104   mod += rounding;
105   mod >>= strength;
106
107   mod = VPXMIN(16, mod);
108
109   mod = 16 - mod;
110   mod *= filter_weight;
111
112   return mod;
113 }
114
115 template <typename PixelType>
116 void ApplyReferenceFilter(
117     const Buffer<PixelType> &y_src, const Buffer<PixelType> &y_pre,
118     const Buffer<PixelType> &u_src, const Buffer<PixelType> &v_src,
119     const Buffer<PixelType> &u_pre, const Buffer<PixelType> &v_pre,
120     unsigned int block_width, unsigned int block_height, int ss_x, int ss_y,
121     int strength, const int *const blk_fw, int use_32x32,
122     Buffer<uint32_t> *y_accumulator, Buffer<uint16_t> *y_counter,
123     Buffer<uint32_t> *u_accumulator, Buffer<uint16_t> *u_counter,
124     Buffer<uint32_t> *v_accumulator, Buffer<uint16_t> *v_counter) {
125   const PixelType *y_src_ptr = y_src.TopLeftPixel();
126   const PixelType *y_pre_ptr = y_pre.TopLeftPixel();
127   const PixelType *u_src_ptr = u_src.TopLeftPixel();
128   const PixelType *u_pre_ptr = u_pre.TopLeftPixel();
129   const PixelType *v_src_ptr = v_src.TopLeftPixel();
130   const PixelType *v_pre_ptr = v_pre.TopLeftPixel();
131
132   const int uv_block_width = block_width >> ss_x,
133             uv_block_height = block_height >> ss_y;
134   const int y_src_stride = y_src.stride(), y_pre_stride = y_pre.stride();
135   const int uv_src_stride = u_src.stride(), uv_pre_stride = u_pre.stride();
136   const int y_diff_stride = block_width, uv_diff_stride = uv_block_width;
137
138   Buffer<int> y_dif = Buffer<int>(block_width, block_height, 0);
139   Buffer<int> u_dif = Buffer<int>(uv_block_width, uv_block_height, 0);
140   Buffer<int> v_dif = Buffer<int>(uv_block_width, uv_block_height, 0);
141
142   ASSERT_TRUE(y_dif.Init());
143   ASSERT_TRUE(u_dif.Init());
144   ASSERT_TRUE(v_dif.Init());
145   y_dif.Set(0);
146   u_dif.Set(0);
147   v_dif.Set(0);
148
149   int *y_diff_ptr = y_dif.TopLeftPixel();
150   int *u_diff_ptr = u_dif.TopLeftPixel();
151   int *v_diff_ptr = v_dif.TopLeftPixel();
152
153   uint32_t *y_accum = y_accumulator->TopLeftPixel();
154   uint32_t *u_accum = u_accumulator->TopLeftPixel();
155   uint32_t *v_accum = v_accumulator->TopLeftPixel();
156   uint16_t *y_count = y_counter->TopLeftPixel();
157   uint16_t *u_count = u_counter->TopLeftPixel();
158   uint16_t *v_count = v_counter->TopLeftPixel();
159
160   const int y_accum_stride = y_accumulator->stride();
161   const int u_accum_stride = u_accumulator->stride();
162   const int v_accum_stride = v_accumulator->stride();
163   const int y_count_stride = y_counter->stride();
164   const int u_count_stride = u_counter->stride();
165   const int v_count_stride = v_counter->stride();
166
167   const int rounding = (1 << strength) >> 1;
168
169   // Get the square diffs
170   for (int row = 0; row < static_cast<int>(block_height); row++) {
171     for (int col = 0; col < static_cast<int>(block_width); col++) {
172       const int diff = y_src_ptr[row * y_src_stride + col] -
173                        y_pre_ptr[row * y_pre_stride + col];
174       y_diff_ptr[row * y_diff_stride + col] = diff * diff;
175     }
176   }
177
178   for (int row = 0; row < uv_block_height; row++) {
179     for (int col = 0; col < uv_block_width; col++) {
180       const int u_diff = u_src_ptr[row * uv_src_stride + col] -
181                          u_pre_ptr[row * uv_pre_stride + col];
182       const int v_diff = v_src_ptr[row * uv_src_stride + col] -
183                          v_pre_ptr[row * uv_pre_stride + col];
184       u_diff_ptr[row * uv_diff_stride + col] = u_diff * u_diff;
185       v_diff_ptr[row * uv_diff_stride + col] = v_diff * v_diff;
186     }
187   }
188
189   // Apply the filter to luma
190   for (int row = 0; row < static_cast<int>(block_height); row++) {
191     for (int col = 0; col < static_cast<int>(block_width); col++) {
192       const int uv_row = row >> ss_y;
193       const int uv_col = col >> ss_x;
194       const int filter_weight = GetFilterWeight(row, col, block_height,
195                                                 block_width, blk_fw, use_32x32);
196
197       // First we get the modifier for the current y pixel
198       const int y_pixel = y_pre_ptr[row * y_pre_stride + col];
199       int y_num_used = 0;
200       int y_mod = 0;
201
202       // Sum the neighboring 3x3 y pixels
203       for (int row_step = -1; row_step <= 1; row_step++) {
204         for (int col_step = -1; col_step <= 1; col_step++) {
205           const int sub_row = row + row_step;
206           const int sub_col = col + col_step;
207
208           if (sub_row >= 0 && sub_row < static_cast<int>(block_height) &&
209               sub_col >= 0 && sub_col < static_cast<int>(block_width)) {
210             y_mod += y_diff_ptr[sub_row * y_diff_stride + sub_col];
211             y_num_used++;
212           }
213         }
214       }
215
216       // Sum the corresponding uv pixels to the current y modifier
217       // Note we are rounding down instead of rounding to the nearest pixel.
218       y_mod += u_diff_ptr[uv_row * uv_diff_stride + uv_col];
219       y_mod += v_diff_ptr[uv_row * uv_diff_stride + uv_col];
220
221       y_num_used += 2;
222
223       // Set the modifier
224       y_mod = GetModIndex<PixelType>(y_mod, y_num_used, rounding, strength,
225                                      filter_weight);
226
227       // Accumulate the result
228       y_count[row * y_count_stride + col] += y_mod;
229       y_accum[row * y_accum_stride + col] += y_mod * y_pixel;
230     }
231   }
232
233   // Apply the filter to chroma
234   for (int uv_row = 0; uv_row < uv_block_height; uv_row++) {
235     for (int uv_col = 0; uv_col < uv_block_width; uv_col++) {
236       const int y_row = uv_row << ss_y;
237       const int y_col = uv_col << ss_x;
238       const int filter_weight = GetFilterWeight(
239           uv_row, uv_col, uv_block_height, uv_block_width, blk_fw, use_32x32);
240
241       const int u_pixel = u_pre_ptr[uv_row * uv_pre_stride + uv_col];
242       const int v_pixel = v_pre_ptr[uv_row * uv_pre_stride + uv_col];
243
244       int uv_num_used = 0;
245       int u_mod = 0, v_mod = 0;
246
247       // Sum the neighboring 3x3 chromal pixels to the chroma modifier
248       for (int row_step = -1; row_step <= 1; row_step++) {
249         for (int col_step = -1; col_step <= 1; col_step++) {
250           const int sub_row = uv_row + row_step;
251           const int sub_col = uv_col + col_step;
252
253           if (sub_row >= 0 && sub_row < uv_block_height && sub_col >= 0 &&
254               sub_col < uv_block_width) {
255             u_mod += u_diff_ptr[sub_row * uv_diff_stride + sub_col];
256             v_mod += v_diff_ptr[sub_row * uv_diff_stride + sub_col];
257             uv_num_used++;
258           }
259         }
260       }
261
262       // Sum all the luma pixels associated with the current luma pixel
263       for (int row_step = 0; row_step < 1 + ss_y; row_step++) {
264         for (int col_step = 0; col_step < 1 + ss_x; col_step++) {
265           const int sub_row = y_row + row_step;
266           const int sub_col = y_col + col_step;
267           const int y_diff = y_diff_ptr[sub_row * y_diff_stride + sub_col];
268
269           u_mod += y_diff;
270           v_mod += y_diff;
271           uv_num_used++;
272         }
273       }
274
275       // Set the modifier
276       u_mod = GetModIndex<PixelType>(u_mod, uv_num_used, rounding, strength,
277                                      filter_weight);
278       v_mod = GetModIndex<PixelType>(v_mod, uv_num_used, rounding, strength,
279                                      filter_weight);
280
281       // Accumulate the result
282       u_count[uv_row * u_count_stride + uv_col] += u_mod;
283       u_accum[uv_row * u_accum_stride + uv_col] += u_mod * u_pixel;
284       v_count[uv_row * v_count_stride + uv_col] += v_mod;
285       v_accum[uv_row * v_accum_stride + uv_col] += v_mod * v_pixel;
286     }
287   }
288 }
289
290 class YUVTemporalFilterTest
291     : public ::testing::TestWithParam<TemporalFilterWithBd> {
292  public:
293   virtual void SetUp() {
294     filter_func_ = GetParam().temporal_filter;
295     bd_ = GetParam().bd;
296     use_highbd_ = (bd_ != 8);
297
298     rnd_.Reset(ACMRandom::DeterministicSeed());
299     saturate_test_ = 0;
300     num_repeats_ = 10;
301
302     ASSERT_TRUE(bd_ == 8 || bd_ == 10 || bd_ == 12);
303   }
304
305  protected:
306   template <typename PixelType>
307   void CompareTestWithParam(int width, int height, int ss_x, int ss_y,
308                             int filter_strength, int use_32x32,
309                             const int *filter_weight);
310   template <typename PixelType>
311   void RunTestFilterWithParam(int width, int height, int ss_x, int ss_y,
312                               int filter_strength, int use_32x32,
313                               const int *filter_weight);
314   YUVTemporalFilterFunc filter_func_;
315   ACMRandom rnd_;
316   int saturate_test_;
317   int num_repeats_;
318   int use_highbd_;
319   int bd_;
320 };
321
322 template <typename PixelType>
323 void YUVTemporalFilterTest::CompareTestWithParam(int width, int height,
324                                                  int ss_x, int ss_y,
325                                                  int filter_strength,
326                                                  int use_32x32,
327                                                  const int *filter_weight) {
328   const int uv_width = width >> ss_x, uv_height = height >> ss_y;
329
330   Buffer<PixelType> y_src = Buffer<PixelType>(width, height, 0);
331   Buffer<PixelType> y_pre = Buffer<PixelType>(width, height, 0);
332   Buffer<uint16_t> y_count_ref = Buffer<uint16_t>(width, height, 0);
333   Buffer<uint32_t> y_accum_ref = Buffer<uint32_t>(width, height, 0);
334   Buffer<uint16_t> y_count_tst = Buffer<uint16_t>(width, height, 0);
335   Buffer<uint32_t> y_accum_tst = Buffer<uint32_t>(width, height, 0);
336
337   Buffer<PixelType> u_src = Buffer<PixelType>(uv_width, uv_height, 0);
338   Buffer<PixelType> u_pre = Buffer<PixelType>(uv_width, uv_height, 0);
339   Buffer<uint16_t> u_count_ref = Buffer<uint16_t>(uv_width, uv_height, 0);
340   Buffer<uint32_t> u_accum_ref = Buffer<uint32_t>(uv_width, uv_height, 0);
341   Buffer<uint16_t> u_count_tst = Buffer<uint16_t>(uv_width, uv_height, 0);
342   Buffer<uint32_t> u_accum_tst = Buffer<uint32_t>(uv_width, uv_height, 0);
343
344   Buffer<PixelType> v_src = Buffer<PixelType>(uv_width, uv_height, 0);
345   Buffer<PixelType> v_pre = Buffer<PixelType>(uv_width, uv_height, 0);
346   Buffer<uint16_t> v_count_ref = Buffer<uint16_t>(uv_width, uv_height, 0);
347   Buffer<uint32_t> v_accum_ref = Buffer<uint32_t>(uv_width, uv_height, 0);
348   Buffer<uint16_t> v_count_tst = Buffer<uint16_t>(uv_width, uv_height, 0);
349   Buffer<uint32_t> v_accum_tst = Buffer<uint32_t>(uv_width, uv_height, 0);
350
351   ASSERT_TRUE(y_src.Init());
352   ASSERT_TRUE(y_pre.Init());
353   ASSERT_TRUE(y_count_ref.Init());
354   ASSERT_TRUE(y_accum_ref.Init());
355   ASSERT_TRUE(y_count_tst.Init());
356   ASSERT_TRUE(y_accum_tst.Init());
357   ASSERT_TRUE(u_src.Init());
358   ASSERT_TRUE(u_pre.Init());
359   ASSERT_TRUE(u_count_ref.Init());
360   ASSERT_TRUE(u_accum_ref.Init());
361   ASSERT_TRUE(u_count_tst.Init());
362   ASSERT_TRUE(u_accum_tst.Init());
363
364   ASSERT_TRUE(v_src.Init());
365   ASSERT_TRUE(v_pre.Init());
366   ASSERT_TRUE(v_count_ref.Init());
367   ASSERT_TRUE(v_accum_ref.Init());
368   ASSERT_TRUE(v_count_tst.Init());
369   ASSERT_TRUE(v_accum_tst.Init());
370
371   y_accum_ref.Set(0);
372   y_accum_tst.Set(0);
373   y_count_ref.Set(0);
374   y_count_tst.Set(0);
375   u_accum_ref.Set(0);
376   u_accum_tst.Set(0);
377   u_count_ref.Set(0);
378   u_count_tst.Set(0);
379   v_accum_ref.Set(0);
380   v_accum_tst.Set(0);
381   v_count_ref.Set(0);
382   v_count_tst.Set(0);
383
384   for (int repeats = 0; repeats < num_repeats_; repeats++) {
385     if (saturate_test_) {
386       const int max_val = (1 << bd_) - 1;
387       y_src.Set(max_val);
388       y_pre.Set(0);
389       u_src.Set(max_val);
390       u_pre.Set(0);
391       v_src.Set(max_val);
392       v_pre.Set(0);
393     } else {
394       y_src.Set(&rnd_, 0, 7 << (bd_ - 8));
395       y_pre.Set(&rnd_, 0, 7 << (bd_ - 8));
396       u_src.Set(&rnd_, 0, 7 << (bd_ - 8));
397       u_pre.Set(&rnd_, 0, 7 << (bd_ - 8));
398       v_src.Set(&rnd_, 0, 7 << (bd_ - 8));
399       v_pre.Set(&rnd_, 0, 7 << (bd_ - 8));
400     }
401
402     ApplyReferenceFilter<PixelType>(
403         y_src, y_pre, u_src, v_src, u_pre, v_pre, width, height, ss_x, ss_y,
404         filter_strength, filter_weight, use_32x32, &y_accum_ref, &y_count_ref,
405         &u_accum_ref, &u_count_ref, &v_accum_ref, &v_count_ref);
406
407     ASM_REGISTER_STATE_CHECK(filter_func_(
408         reinterpret_cast<const uint8_t *>(y_src.TopLeftPixel()), y_src.stride(),
409         reinterpret_cast<const uint8_t *>(y_pre.TopLeftPixel()), y_pre.stride(),
410         reinterpret_cast<const uint8_t *>(u_src.TopLeftPixel()),
411         reinterpret_cast<const uint8_t *>(v_src.TopLeftPixel()), u_src.stride(),
412         reinterpret_cast<const uint8_t *>(u_pre.TopLeftPixel()),
413         reinterpret_cast<const uint8_t *>(v_pre.TopLeftPixel()), u_pre.stride(),
414         width, height, ss_x, ss_y, filter_strength, filter_weight, use_32x32,
415         y_accum_tst.TopLeftPixel(), y_count_tst.TopLeftPixel(),
416         u_accum_tst.TopLeftPixel(), u_count_tst.TopLeftPixel(),
417         v_accum_tst.TopLeftPixel(), v_count_tst.TopLeftPixel()));
418
419     EXPECT_TRUE(y_accum_tst.CheckValues(y_accum_ref));
420     EXPECT_TRUE(y_count_tst.CheckValues(y_count_ref));
421     EXPECT_TRUE(u_accum_tst.CheckValues(u_accum_ref));
422     EXPECT_TRUE(u_count_tst.CheckValues(u_count_ref));
423     EXPECT_TRUE(v_accum_tst.CheckValues(v_accum_ref));
424     EXPECT_TRUE(v_count_tst.CheckValues(v_count_ref));
425
426     if (HasFailure()) {
427       if (use_32x32) {
428         printf("SS_X: %d, SS_Y: %d, Strength: %d, Weight: %d\n", ss_x, ss_y,
429                filter_strength, *filter_weight);
430       } else {
431         printf("SS_X: %d, SS_Y: %d, Strength: %d, Weights: %d,%d,%d,%d\n", ss_x,
432                ss_y, filter_strength, filter_weight[0], filter_weight[1],
433                filter_weight[2], filter_weight[3]);
434       }
435       y_accum_tst.PrintDifference(y_accum_ref);
436       y_count_tst.PrintDifference(y_count_ref);
437       u_accum_tst.PrintDifference(u_accum_ref);
438       u_count_tst.PrintDifference(u_count_ref);
439       v_accum_tst.PrintDifference(v_accum_ref);
440       v_count_tst.PrintDifference(v_count_ref);
441
442       return;
443     }
444   }
445 }
446
447 template <typename PixelType>
448 void YUVTemporalFilterTest::RunTestFilterWithParam(int width, int height,
449                                                    int ss_x, int ss_y,
450                                                    int filter_strength,
451                                                    int use_32x32,
452                                                    const int *filter_weight) {
453   const int uv_width = width >> ss_x, uv_height = height >> ss_y;
454
455   Buffer<PixelType> y_src = Buffer<PixelType>(width, height, 0);
456   Buffer<PixelType> y_pre = Buffer<PixelType>(width, height, 0);
457   Buffer<uint16_t> y_count = Buffer<uint16_t>(width, height, 0);
458   Buffer<uint32_t> y_accum = Buffer<uint32_t>(width, height, 0);
459
460   Buffer<PixelType> u_src = Buffer<PixelType>(uv_width, uv_height, 0);
461   Buffer<PixelType> u_pre = Buffer<PixelType>(uv_width, uv_height, 0);
462   Buffer<uint16_t> u_count = Buffer<uint16_t>(uv_width, uv_height, 0);
463   Buffer<uint32_t> u_accum = Buffer<uint32_t>(uv_width, uv_height, 0);
464
465   Buffer<PixelType> v_src = Buffer<PixelType>(uv_width, uv_height, 0);
466   Buffer<PixelType> v_pre = Buffer<PixelType>(uv_width, uv_height, 0);
467   Buffer<uint16_t> v_count = Buffer<uint16_t>(uv_width, uv_height, 0);
468   Buffer<uint32_t> v_accum = Buffer<uint32_t>(uv_width, uv_height, 0);
469
470   ASSERT_TRUE(y_src.Init());
471   ASSERT_TRUE(y_pre.Init());
472   ASSERT_TRUE(y_count.Init());
473   ASSERT_TRUE(y_accum.Init());
474
475   ASSERT_TRUE(u_src.Init());
476   ASSERT_TRUE(u_pre.Init());
477   ASSERT_TRUE(u_count.Init());
478   ASSERT_TRUE(u_accum.Init());
479
480   ASSERT_TRUE(v_src.Init());
481   ASSERT_TRUE(v_pre.Init());
482   ASSERT_TRUE(v_count.Init());
483   ASSERT_TRUE(v_accum.Init());
484
485   y_accum.Set(0);
486   y_count.Set(0);
487
488   u_accum.Set(0);
489   u_count.Set(0);
490
491   v_accum.Set(0);
492   v_count.Set(0);
493
494   y_src.Set(&rnd_, 0, 7 << (bd_ - 8));
495   y_pre.Set(&rnd_, 0, 7 << (bd_ - 8));
496   u_src.Set(&rnd_, 0, 7 << (bd_ - 8));
497   u_pre.Set(&rnd_, 0, 7 << (bd_ - 8));
498   v_src.Set(&rnd_, 0, 7 << (bd_ - 8));
499   v_pre.Set(&rnd_, 0, 7 << (bd_ - 8));
500
501   for (int repeats = 0; repeats < num_repeats_; repeats++) {
502     ASM_REGISTER_STATE_CHECK(filter_func_(
503         reinterpret_cast<const uint8_t *>(y_src.TopLeftPixel()), y_src.stride(),
504         reinterpret_cast<const uint8_t *>(y_pre.TopLeftPixel()), y_pre.stride(),
505         reinterpret_cast<const uint8_t *>(u_src.TopLeftPixel()),
506         reinterpret_cast<const uint8_t *>(v_src.TopLeftPixel()), u_src.stride(),
507         reinterpret_cast<const uint8_t *>(u_pre.TopLeftPixel()),
508         reinterpret_cast<const uint8_t *>(v_pre.TopLeftPixel()), u_pre.stride(),
509         width, height, ss_x, ss_y, filter_strength, filter_weight, use_32x32,
510         y_accum.TopLeftPixel(), y_count.TopLeftPixel(), u_accum.TopLeftPixel(),
511         u_count.TopLeftPixel(), v_accum.TopLeftPixel(),
512         v_count.TopLeftPixel()));
513   }
514 }
515
516 TEST_P(YUVTemporalFilterTest, Use32x32) {
517   const int width = 32, height = 32;
518   const int use_32x32 = 1;
519
520   for (int ss_x = 0; ss_x <= 1; ss_x++) {
521     for (int ss_y = 0; ss_y <= 1; ss_y++) {
522       for (int filter_strength = 0; filter_strength <= 6;
523            filter_strength += 2) {
524         for (int filter_weight = 0; filter_weight <= 2; filter_weight++) {
525           if (use_highbd_) {
526             const int adjusted_strength = filter_strength + 2 * (bd_ - 8);
527             CompareTestWithParam<uint16_t>(width, height, ss_x, ss_y,
528                                            adjusted_strength, use_32x32,
529                                            &filter_weight);
530           } else {
531             CompareTestWithParam<uint8_t>(width, height, ss_x, ss_y,
532                                           filter_strength, use_32x32,
533                                           &filter_weight);
534           }
535           ASSERT_FALSE(HasFailure());
536         }
537       }
538     }
539   }
540 }
541
542 TEST_P(YUVTemporalFilterTest, Use16x16) {
543   const int width = 32, height = 32;
544   const int use_32x32 = 0;
545
546   for (int ss_x = 0; ss_x <= 1; ss_x++) {
547     for (int ss_y = 0; ss_y <= 1; ss_y++) {
548       for (int filter_idx = 0; filter_idx < 3 * 3 * 3 * 3; filter_idx++) {
549         // Set up the filter
550         int filter_weight[4];
551         int filter_idx_cp = filter_idx;
552         for (int idx = 0; idx < 4; idx++) {
553           filter_weight[idx] = filter_idx_cp % 3;
554           filter_idx_cp /= 3;
555         }
556
557         // Test each parameter
558         for (int filter_strength = 0; filter_strength <= 6;
559              filter_strength += 2) {
560           if (use_highbd_) {
561             const int adjusted_strength = filter_strength + 2 * (bd_ - 8);
562             CompareTestWithParam<uint16_t>(width, height, ss_x, ss_y,
563                                            adjusted_strength, use_32x32,
564                                            filter_weight);
565           } else {
566             CompareTestWithParam<uint8_t>(width, height, ss_x, ss_y,
567                                           filter_strength, use_32x32,
568                                           filter_weight);
569           }
570
571           ASSERT_FALSE(HasFailure());
572         }
573       }
574     }
575   }
576 }
577
578 TEST_P(YUVTemporalFilterTest, SaturationTest) {
579   const int width = 32, height = 32;
580   const int use_32x32 = 1;
581   const int filter_weight = 1;
582   saturate_test_ = 1;
583
584   for (int ss_x = 0; ss_x <= 1; ss_x++) {
585     for (int ss_y = 0; ss_y <= 1; ss_y++) {
586       for (int filter_strength = 0; filter_strength <= 6;
587            filter_strength += 2) {
588         if (use_highbd_) {
589           const int adjusted_strength = filter_strength + 2 * (bd_ - 8);
590           CompareTestWithParam<uint16_t>(width, height, ss_x, ss_y,
591                                          adjusted_strength, use_32x32,
592                                          &filter_weight);
593         } else {
594           CompareTestWithParam<uint8_t>(width, height, ss_x, ss_y,
595                                         filter_strength, use_32x32,
596                                         &filter_weight);
597         }
598
599         ASSERT_FALSE(HasFailure());
600       }
601     }
602   }
603 }
604
605 TEST_P(YUVTemporalFilterTest, DISABLED_Speed) {
606   const int width = 32, height = 32;
607   num_repeats_ = 1000;
608
609   for (int use_32x32 = 0; use_32x32 <= 1; use_32x32++) {
610     const int num_filter_weights = use_32x32 ? 3 : 3 * 3 * 3 * 3;
611     for (int ss_x = 0; ss_x <= 1; ss_x++) {
612       for (int ss_y = 0; ss_y <= 1; ss_y++) {
613         for (int filter_idx = 0; filter_idx < num_filter_weights;
614              filter_idx++) {
615           // Set up the filter
616           int filter_weight[4];
617           int filter_idx_cp = filter_idx;
618           for (int idx = 0; idx < 4; idx++) {
619             filter_weight[idx] = filter_idx_cp % 3;
620             filter_idx_cp /= 3;
621           }
622
623           // Test each parameter
624           for (int filter_strength = 0; filter_strength <= 6;
625                filter_strength += 2) {
626             vpx_usec_timer timer;
627             vpx_usec_timer_start(&timer);
628
629             if (use_highbd_) {
630               RunTestFilterWithParam<uint16_t>(width, height, ss_x, ss_y,
631                                                filter_strength, use_32x32,
632                                                filter_weight);
633             } else {
634               RunTestFilterWithParam<uint8_t>(width, height, ss_x, ss_y,
635                                               filter_strength, use_32x32,
636                                               filter_weight);
637             }
638
639             vpx_usec_timer_mark(&timer);
640             const int elapsed_time =
641                 static_cast<int>(vpx_usec_timer_elapsed(&timer));
642
643             printf(
644                 "Bitdepth: %d, Use 32X32: %d, SS_X: %d, SS_Y: %d, Weight Idx: "
645                 "%d, Strength: %d, Time: %5d\n",
646                 bd_, use_32x32, ss_x, ss_y, filter_idx, filter_strength,
647                 elapsed_time);
648           }
649         }
650       }
651     }
652   }
653 }
654
655 #if CONFIG_VP9_HIGHBITDEPTH
656 #define WRAP_HIGHBD_FUNC(func, bd)                                            \
657   void wrap_##func##_##bd(                                                    \
658       const uint8_t *y_src, int y_src_stride, const uint8_t *y_pre,           \
659       int y_pre_stride, const uint8_t *u_src, const uint8_t *v_src,           \
660       int uv_src_stride, const uint8_t *u_pre, const uint8_t *v_pre,          \
661       int uv_pre_stride, unsigned int block_width, unsigned int block_height, \
662       int ss_x, int ss_y, int strength, const int *const blk_fw,              \
663       int use_32x32, uint32_t *y_accumulator, uint16_t *y_count,              \
664       uint32_t *u_accumulator, uint16_t *u_count, uint32_t *v_accumulator,    \
665       uint16_t *v_count) {                                                    \
666     func(reinterpret_cast<const uint16_t *>(y_src), y_src_stride,             \
667          reinterpret_cast<const uint16_t *>(y_pre), y_pre_stride,             \
668          reinterpret_cast<const uint16_t *>(u_src),                           \
669          reinterpret_cast<const uint16_t *>(v_src), uv_src_stride,            \
670          reinterpret_cast<const uint16_t *>(u_pre),                           \
671          reinterpret_cast<const uint16_t *>(v_pre), uv_pre_stride,            \
672          block_width, block_height, ss_x, ss_y, strength, blk_fw, use_32x32,  \
673          y_accumulator, y_count, u_accumulator, u_count, v_accumulator,       \
674          v_count);                                                            \
675   }
676
677 WRAP_HIGHBD_FUNC(vp9_highbd_apply_temporal_filter_c, 10)
678 WRAP_HIGHBD_FUNC(vp9_highbd_apply_temporal_filter_c, 12)
679
680 INSTANTIATE_TEST_SUITE_P(
681     C, YUVTemporalFilterTest,
682     ::testing::Values(
683         TemporalFilterWithBd(&wrap_vp9_highbd_apply_temporal_filter_c_10, 10),
684         TemporalFilterWithBd(&wrap_vp9_highbd_apply_temporal_filter_c_12, 12)));
685 #if HAVE_SSE4_1
686 WRAP_HIGHBD_FUNC(vp9_highbd_apply_temporal_filter_sse4_1, 10)
687 WRAP_HIGHBD_FUNC(vp9_highbd_apply_temporal_filter_sse4_1, 12)
688
689 INSTANTIATE_TEST_SUITE_P(
690     SSE4_1, YUVTemporalFilterTest,
691     ::testing::Values(
692         TemporalFilterWithBd(&wrap_vp9_highbd_apply_temporal_filter_sse4_1_10,
693                              10),
694         TemporalFilterWithBd(&wrap_vp9_highbd_apply_temporal_filter_sse4_1_12,
695                              12)));
696 #endif  // HAVE_SSE4_1
697 #else
698 INSTANTIATE_TEST_SUITE_P(
699     C, YUVTemporalFilterTest,
700     ::testing::Values(TemporalFilterWithBd(&vp9_apply_temporal_filter_c, 8)));
701
702 #if HAVE_SSE4_1
703 INSTANTIATE_TEST_SUITE_P(SSE4_1, YUVTemporalFilterTest,
704                          ::testing::Values(TemporalFilterWithBd(
705                              &vp9_apply_temporal_filter_sse4_1, 8)));
706 #endif  // HAVE_SSE4_1
707 #endif  // CONFIG_VP9_HIGHBITDEPTH
708 }  // namespace