Add Tune for SSIM
authorsdeng <sdeng@google.com>
Fri, 29 Mar 2019 16:18:56 +0000 (09:18 -0700)
committersdeng <sdeng@google.com>
Fri, 12 Apr 2019 22:02:05 +0000 (15:02 -0700)
Implementation with some tuning of the paper:
C. Yeo, H. L. Tan, and Y. H. Tan, "On rate distortion optimization using
SSIM," Circuits and Systems for Video Technology, IEEE Transactions on,
vol. 23, no. 7, pp. 1170-1181, 2013.

Test results:
           avg_psnr      ssim      ms-ssim
lowres      2.516       -2.622     -2.450
midres      2.312       -3.062     -3.882
hdres       2.292       -4.293     -5.246

The encoding time is about the same as the baseline.

Change-Id: Ida2c380ade79b6c15cf12b88bf090069da8765d8

vp9/encoder/vp9_encodeframe.c
vp9/encoder/vp9_encoder.c
vp9/encoder/vp9_encoder.h
vp9/vp9_cx_iface.c

index 933cd3c..7d50122 100644 (file)
@@ -259,6 +259,24 @@ static INLINE void set_mode_info_offsets(VP9_COMMON *const cm,
   x->mbmi_ext = x->mbmi_ext_base + (mi_row * cm->mi_cols + mi_col);
 }
 
+static double get_ssim_rdmult_scaling_factor(VP9_COMP *const cpi, int mi_row,
+                                             int mi_col) {
+  const VP9EncoderConfig *const oxcf = &cpi->oxcf;
+  if (oxcf->tuning == VP8_TUNE_SSIM) {
+    const VP9_COMMON *const cm = &cpi->common;
+    // SSIM rdmult scaling factors are currently 64x64 based.
+    const int num_8x8_w = 8;
+    const int num_8x8_h = 8;
+    const int num_cols = (cm->mi_cols + num_8x8_w - 1) / num_8x8_w;
+    const int row = mi_row / num_8x8_h;
+    const int col = mi_col / num_8x8_w;
+    const int index = row * num_cols + col;
+
+    return cpi->mi_ssim_rdmult_scaling_factors[index];
+  }
+  return 1.0;
+}
+
 static void set_offsets(VP9_COMP *cpi, const TileInfo *const tile,
                         MACROBLOCK *const x, int mi_row, int mi_col,
                         BLOCK_SIZE bsize) {
@@ -267,6 +285,8 @@ static void set_offsets(VP9_COMP *cpi, const TileInfo *const tile,
   const int mi_width = num_8x8_blocks_wide_lookup[bsize];
   const int mi_height = num_8x8_blocks_high_lookup[bsize];
   MvLimits *const mv_limits = &x->mv_limits;
+  const double ssim_factor =
+      get_ssim_rdmult_scaling_factor(cpi, mi_row, mi_col);
 
   set_skip_context(xd, mi_row, mi_col);
 
@@ -293,6 +313,7 @@ static void set_offsets(VP9_COMP *cpi, const TileInfo *const tile,
   // R/D setup.
   x->rddiv = cpi->rd.RDDIV;
   x->rdmult = cpi->rd.RDMULT;
+  x->rdmult = (int)(ssim_factor * x->rdmult);
 
   // required by vp9_append_sub8x8_mvs_for_idx() and vp9_find_best_ref_mvs()
   xd->tile = *tile;
@@ -1916,6 +1937,8 @@ static void set_segment_rdmult(VP9_COMP *const cpi, MACROBLOCK *const x,
   VP9_COMMON *const cm = &cpi->common;
   const uint8_t *const map =
       cm->seg.update_map ? cpi->segmentation_map : cm->last_frame_seg_map;
+  const double ssim_factor =
+      get_ssim_rdmult_scaling_factor(cpi, mi_row, mi_col);
 
   vp9_init_plane_quantizers(cpi, x);
   vpx_clear_system_state();
@@ -1923,25 +1946,22 @@ static void set_segment_rdmult(VP9_COMP *const cpi, MACROBLOCK *const x,
   if (aq_mode == NO_AQ || aq_mode == PSNR_AQ) {
     if (cpi->sf.enable_tpl_model || cpi->sf.enable_wiener_variance)
       x->rdmult = x->cb_rdmult;
-    return;
-  }
-
-  if (aq_mode == CYCLIC_REFRESH_AQ) {
+  } else if (aq_mode == CYCLIC_REFRESH_AQ) {
     // If segment is boosted, use rdmult for that segment.
     if (cyclic_refresh_segment_id_boosted(
             get_segment_id(cm, map, bsize, mi_row, mi_col)))
       x->rdmult = vp9_cyclic_refresh_get_rdmult(cpi->cyclic_refresh);
-    return;
+  } else {
+    x->rdmult = vp9_compute_rd_mult(cpi, cm->base_qindex + cm->y_dc_delta_q);
+    if (cpi->sf.enable_wiener_variance && cm->show_frame) {
+      if (cm->seg.enabled)
+        x->rdmult = vp9_compute_rd_mult(
+            cpi, vp9_get_qindex(&cm->seg, x->e_mbd.mi[0]->segment_id,
+                                cm->base_qindex));
+    }
   }
 
-  x->rdmult = vp9_compute_rd_mult(cpi, cm->base_qindex + cm->y_dc_delta_q);
-
-  if (cpi->sf.enable_wiener_variance && cm->show_frame) {
-    if (cm->seg.enabled)
-      x->rdmult = vp9_compute_rd_mult(
-          cpi, vp9_get_qindex(&cm->seg, x->e_mbd.mi[0]->segment_id,
-                              cm->base_qindex));
-  }
+  x->rdmult = (int)(ssim_factor * x->rdmult);
 }
 
 static void rd_pick_sb_modes(VP9_COMP *cpi, TileDataEnc *tile_data,
@@ -2179,8 +2199,12 @@ static void encode_b(VP9_COMP *cpi, const TileInfo *const tile, ThreadData *td,
   set_offsets(cpi, tile, x, mi_row, mi_col, bsize);
 
   if ((cpi->sf.enable_tpl_model || cpi->sf.enable_wiener_variance) &&
-      cpi->oxcf.aq_mode == NO_AQ)
+      cpi->oxcf.aq_mode == NO_AQ) {
+    const double ssim_factor =
+        get_ssim_rdmult_scaling_factor(cpi, mi_row, mi_col);
     x->rdmult = x->cb_rdmult;
+    x->rdmult = (int)(ssim_factor * x->rdmult);
+  }
 
   update_state(cpi, td, ctx, mi_row, mi_col, bsize, output_enabled);
   encode_superblock(cpi, td, tp, output_enabled, mi_row, mi_col, bsize, ctx);
@@ -3783,11 +3807,16 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
   int64_t dist_breakout_thr = cpi->sf.partition_search_breakout_thr.dist;
   int rate_breakout_thr = cpi->sf.partition_search_breakout_thr.rate;
   int must_split = 0;
-  int partition_mul = x->cb_rdmult;
+
   // Ref frames picked in the [i_th] quarter subblock during square partition
   // RD search. It may be used to prune ref frame selection of rect partitions.
   uint8_t ref_frames_used[4] = { 0, 0, 0, 0 };
 
+  int partition_mul = x->cb_rdmult;
+  const double ssim_factor =
+      get_ssim_rdmult_scaling_factor(cpi, mi_row, mi_col);
+  partition_mul = (int)(ssim_factor * partition_mul);
+
   (void)*tp_orig;
 
   assert(num_8x8_blocks_wide_lookup[bsize] ==
index 292f2ee..8cae609 100644 (file)
@@ -990,6 +990,9 @@ static void dealloc_compressor_data(VP9_COMP *cpi) {
   vpx_free(cpi->mb_wiener_variance);
   cpi->mb_wiener_variance = NULL;
 
+  vpx_free(cpi->mi_ssim_rdmult_scaling_factors);
+  cpi->mi_ssim_rdmult_scaling_factors = NULL;
+
   vp9_free_ref_frame_buffers(cm->buffer_pool);
 #if CONFIG_VP9_POSTPROC
   vp9_free_postproc_buffers(cm);
@@ -2388,6 +2391,17 @@ VP9_COMP *vp9_create_compressor(VP9EncoderConfig *oxcf,
                                sizeof(*cpi->mb_wiener_variance)));
   }
 
+  {
+    const int bsize = BLOCK_64X64;
+    const int w = num_8x8_blocks_wide_lookup[bsize];
+    const int h = num_8x8_blocks_high_lookup[bsize];
+    const int num_cols = (cm->mi_cols + w - 1) / w;
+    const int num_rows = (cm->mi_rows + h - 1) / h;
+    CHECK_MEM_ERROR(cm, cpi->mi_ssim_rdmult_scaling_factors,
+                    vpx_calloc(num_rows * num_cols,
+                               sizeof(*cpi->mi_ssim_rdmult_scaling_factors)));
+  }
+
   cpi->kmeans_data_arr_alloc = 0;
 #if CONFIG_NON_GREEDY_MV
   cpi->feature_score_loc_alloc = 0;
@@ -4717,6 +4731,72 @@ static void set_frame_index(VP9_COMP *cpi, VP9_COMMON *cm) {
   }
 }
 
+// Implementation and modifications of C. Yeo, H. L. Tan, and Y. H. Tan, "On
+// rate distortion optimization using SSIM," Circuits and Systems for Video
+// Technology, IEEE Transactions on, vol. 23, no. 7, pp. 1170-1181, 2013.
+// SSIM_VAR_SCALE defines the strength of the bias towards SSIM in RDO.
+// Some sample values are:
+// SSIM_VAR_SCALE  avg_psnr   ssim   ms_ssim  (for midres test set)
+//     16.0          2.312   -3.062  -3.882
+//     32.0          0.852   -2.260  -2.821
+//     64.0          0.294   -1.606  -1.925
+#define SSIM_VAR_SCALE 16.0
+static void set_mb_ssim_rdmult_scaling(VP9_COMP *cpi) {
+  const double c2 = 0.03 * 0.03 * 255 * 255;
+  VP9_COMMON *cm = &cpi->common;
+  ThreadData *td = &cpi->td;
+  MACROBLOCK *x = &td->mb;
+  MACROBLOCKD *xd = &x->e_mbd;
+  uint8_t *y_buffer = cpi->Source->y_buffer;
+  const int y_stride = cpi->Source->y_stride;
+  const int block_size = BLOCK_64X64;
+
+  const int num_8x8_w = num_8x8_blocks_wide_lookup[block_size];
+  const int num_8x8_h = num_8x8_blocks_high_lookup[block_size];
+  const int num_cols = (cm->mi_cols + num_8x8_w - 1) / num_8x8_w;
+  const int num_rows = (cm->mi_rows + num_8x8_h - 1) / num_8x8_h;
+  double log_sum = 0.0;
+  int row, col;
+
+  // Loop through each 64x64 block.
+  for (row = 0; row < num_rows; ++row) {
+    for (col = 0; col < num_cols; ++col) {
+      int mi_row, mi_col;
+      double var = 0.0, num_of_var = 0.0;
+      const int index = row * num_cols + col;
+
+      for (mi_row = row * num_8x8_h;
+           mi_row < cm->mi_rows && mi_row < (row + 1) * num_8x8_h; ++mi_row) {
+        for (mi_col = col * num_8x8_w;
+             mi_col < cm->mi_cols && mi_col < (col + 1) * num_8x8_w; ++mi_col) {
+          struct buf_2d buf;
+          const int row_offset_y = mi_row << 3;
+          const int col_offset_y = mi_col << 3;
+
+          buf.buf = y_buffer + row_offset_y * y_stride + col_offset_y;
+          buf.stride = y_stride;
+          var += vp9_get_sby_variance(cpi, &buf, BLOCK_8X8) / 64.0;
+          num_of_var += 1.0;
+        }
+      }
+      var = var / num_of_var / SSIM_VAR_SCALE;
+      var = 2.0 * var + c2;
+      cpi->mi_ssim_rdmult_scaling_factors[index] = var;
+      log_sum += log(var);
+    }
+  }
+  log_sum = exp(log_sum / (double)(num_rows * num_cols));
+
+  for (row = 0; row < num_rows; ++row) {
+    for (col = 0; col < num_cols; ++col) {
+      const int index = row * num_cols + col;
+      cpi->mi_ssim_rdmult_scaling_factors[index] /= log_sum;
+    }
+  }
+
+  (void)xd;
+}
+
 // Process the wiener variance in 16x16 block basis.
 static void set_mb_wiener_variance(VP9_COMP *cpi) {
   VP9_COMMON *cm = &cpi->common;
@@ -4906,6 +4986,8 @@ static void encode_frame_to_data_rate(VP9_COMP *cpi, size_t *size,
     }
   }
 
+  if (oxcf->tuning == VP8_TUNE_SSIM) set_mb_ssim_rdmult_scaling(cpi);
+
   set_mb_wiener_variance(cpi);
 
   vpx_clear_system_state();
index 67d764f..5eec17d 100644 (file)
@@ -649,6 +649,7 @@ typedef struct VP9_COMP {
 
   int64_t norm_wiener_variance;
   int64_t *mb_wiener_variance;
+  double *mi_ssim_rdmult_scaling_factors;
   int *stack_rank_buffer;
 
   YV12_BUFFER_CONFIG last_frame_uf;
index 4d358b7..ab5d53d 100644 (file)
@@ -263,9 +263,11 @@ static vpx_codec_err_t validate_config(vpx_codec_alg_priv_t *ctx,
   RANGE_CHECK(extra_cfg, content, VP9E_CONTENT_DEFAULT,
               VP9E_CONTENT_INVALID - 1);
 
-  // TODO(yaowu): remove this when ssim tuning is implemented for vp9
+  // TODO(sdeng): remove this when ssim tuning is implemented for highbd
+#if CONFIG_VP9_HIGHBITDEPTH
   if (extra_cfg->tuning == VP8_TUNE_SSIM)
-    ERROR("Option --tune=ssim is not currently supported in VP9.");
+    ERROR("Option --tune=ssim is not currently supported in highbd VP9.");
+#endif
 
 #if !CONFIG_REALTIME_ONLY
   if (cfg->g_pass == VPX_RC_LAST_PASS) {