Imported Upstream version 0.9.0
[platform/upstream/libjxl.git] / lib / jxl / enc_ac_strategy.cc
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_ac_strategy.h"
7
8 #include <stdint.h>
9 #include <string.h>
10
11 #include <algorithm>
12 #include <cmath>
13 #include <cstdio>
14
15 #undef HWY_TARGET_INCLUDE
16 #define HWY_TARGET_INCLUDE "lib/jxl/enc_ac_strategy.cc"
17 #include <hwy/foreach_target.h>
18 #include <hwy/highway.h>
19
20 #include "lib/jxl/ac_strategy.h"
21 #include "lib/jxl/ans_params.h"
22 #include "lib/jxl/base/bits.h"
23 #include "lib/jxl/base/compiler_specific.h"
24 #include "lib/jxl/base/fast_math-inl.h"
25 #include "lib/jxl/base/status.h"
26 #include "lib/jxl/coeff_order_fwd.h"
27 #include "lib/jxl/convolve.h"
28 #include "lib/jxl/dct_scales.h"
29 #include "lib/jxl/dec_transforms-inl.h"
30 #include "lib/jxl/enc_aux_out.h"
31 #include "lib/jxl/enc_debug_image.h"
32 #include "lib/jxl/enc_params.h"
33 #include "lib/jxl/enc_transforms-inl.h"
34 #include "lib/jxl/entropy_coder.h"
35 #include "lib/jxl/simd_util.h"
36
37 // Some of the floating point constants in this file and in other
38 // files in the libjxl project have been obtained using the
39 // tools/optimizer/simplex_fork.py tool. It is a variation of
40 // Nelder-Mead optimization, and we generally try to minimize
41 // BPP * pnorm aggregate as reported by the benchmark_xl tool,
42 // but occasionally the values are optimized by using additional
43 // constraints such as maintaining a certain density, or ratio of
44 // popularity of integral transforms. Jyrki visually reviews all
45 // such changes and often makes manual changes to maintain good
46 // visual quality to changes where butteraugli was not sufficiently
47 // sensitive to some kind of degradation. Unfortunately image quality
48 // is still more of an art than science.
49
50 // Set JXL_DEBUG_AC_STRATEGY to 1 to enable debugging.
51 #ifndef JXL_DEBUG_AC_STRATEGY
52 #define JXL_DEBUG_AC_STRATEGY 0
53 #endif
54
55 // This must come before the begin/end_target, but HWY_ONCE is only true
56 // after that, so use an "include guard".
57 #ifndef LIB_JXL_ENC_AC_STRATEGY_
58 #define LIB_JXL_ENC_AC_STRATEGY_
59 // Parameters of the heuristic are marked with a OPTIMIZE comment.
60 namespace jxl {
61 namespace {
62
63 // Debugging utilities.
64
65 // Returns a linear sRGB color (as bytes) for each AC strategy.
66 const uint8_t* TypeColor(const uint8_t& raw_strategy) {
67   JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy));
68   static_assert(AcStrategy::kNumValidStrategies == 27, "Change colors");
69   static constexpr uint8_t kColors[][3] = {
70       {0xFF, 0xFF, 0x00},  // DCT8
71       {0xFF, 0x80, 0x80},  // HORNUSS
72       {0xFF, 0x80, 0x80},  // DCT2x2
73       {0xFF, 0x80, 0x80},  // DCT4x4
74       {0x80, 0xFF, 0x00},  // DCT16x16
75       {0x00, 0xC0, 0x00},  // DCT32x32
76       {0xC0, 0xFF, 0x00},  // DCT16x8
77       {0xC0, 0xFF, 0x00},  // DCT8x16
78       {0x00, 0xFF, 0x00},  // DCT32x8
79       {0x00, 0xFF, 0x00},  // DCT8x32
80       {0x00, 0xFF, 0x00},  // DCT32x16
81       {0x00, 0xFF, 0x00},  // DCT16x32
82       {0xFF, 0x80, 0x00},  // DCT4x8
83       {0xFF, 0x80, 0x00},  // DCT8x4
84       {0xFF, 0xFF, 0x80},  // AFV0
85       {0xFF, 0xFF, 0x80},  // AFV1
86       {0xFF, 0xFF, 0x80},  // AFV2
87       {0xFF, 0xFF, 0x80},  // AFV3
88       {0x00, 0xC0, 0xFF},  // DCT64x64
89       {0x00, 0xFF, 0xFF},  // DCT64x32
90       {0x00, 0xFF, 0xFF},  // DCT32x64
91       {0x00, 0x40, 0xFF},  // DCT128x128
92       {0x00, 0x80, 0xFF},  // DCT128x64
93       {0x00, 0x80, 0xFF},  // DCT64x128
94       {0x00, 0x00, 0xC0},  // DCT256x256
95       {0x00, 0x00, 0xFF},  // DCT256x128
96       {0x00, 0x00, 0xFF},  // DCT128x256
97   };
98   return kColors[raw_strategy];
99 }
100
101 const uint8_t* TypeMask(const uint8_t& raw_strategy) {
102   JXL_ASSERT(AcStrategy::IsRawStrategyValid(raw_strategy));
103   static_assert(AcStrategy::kNumValidStrategies == 27, "Add masks");
104   // implicitly, first row and column is made dark
105   static constexpr uint8_t kMask[][64] = {
106       {
107           0, 0, 0, 0, 0, 0, 0, 0,  //
108           0, 0, 0, 0, 0, 0, 0, 0,  //
109           0, 0, 0, 0, 0, 0, 0, 0,  //
110           0, 0, 0, 0, 0, 0, 0, 0,  //
111           0, 0, 0, 0, 0, 0, 0, 0,  //
112           0, 0, 0, 0, 0, 0, 0, 0,  //
113           0, 0, 0, 0, 0, 0, 0, 0,  //
114           0, 0, 0, 0, 0, 0, 0, 0,  //
115       },                           // DCT8
116       {
117           0, 0, 0, 0, 0, 0, 0, 0,  //
118           0, 0, 0, 0, 0, 0, 0, 0,  //
119           0, 0, 1, 0, 0, 1, 0, 0,  //
120           0, 0, 1, 0, 0, 1, 0, 0,  //
121           0, 0, 1, 1, 1, 1, 0, 0,  //
122           0, 0, 1, 0, 0, 1, 0, 0,  //
123           0, 0, 1, 0, 0, 1, 0, 0,  //
124           0, 0, 0, 0, 0, 0, 0, 0,  //
125       },                           // HORNUSS
126       {
127           1, 1, 1, 1, 1, 1, 1, 1,  //
128           1, 0, 1, 0, 1, 0, 1, 0,  //
129           1, 1, 1, 1, 1, 1, 1, 1,  //
130           1, 0, 1, 0, 1, 0, 1, 0,  //
131           1, 1, 1, 1, 1, 1, 1, 1,  //
132           1, 0, 1, 0, 1, 0, 1, 0,  //
133           1, 1, 1, 1, 1, 1, 1, 1,  //
134           1, 0, 1, 0, 1, 0, 1, 0,  //
135       },                           // 2x2
136       {
137           0, 0, 0, 0, 1, 0, 0, 0,  //
138           0, 0, 0, 0, 1, 0, 0, 0,  //
139           0, 0, 0, 0, 1, 0, 0, 0,  //
140           0, 0, 0, 0, 1, 0, 0, 0,  //
141           1, 1, 1, 1, 1, 1, 1, 1,  //
142           0, 0, 0, 0, 1, 0, 0, 0,  //
143           0, 0, 0, 0, 1, 0, 0, 0,  //
144           0, 0, 0, 0, 1, 0, 0, 0,  //
145       },                           // 4x4
146       {},                          // DCT16x16 (unused)
147       {},                          // DCT32x32 (unused)
148       {},                          // DCT16x8 (unused)
149       {},                          // DCT8x16 (unused)
150       {},                          // DCT32x8 (unused)
151       {},                          // DCT8x32 (unused)
152       {},                          // DCT32x16 (unused)
153       {},                          // DCT16x32 (unused)
154       {
155           0, 0, 0, 0, 0, 0, 0, 0,  //
156           0, 0, 0, 0, 0, 0, 0, 0,  //
157           0, 0, 0, 0, 0, 0, 0, 0,  //
158           0, 0, 0, 0, 0, 0, 0, 0,  //
159           1, 1, 1, 1, 1, 1, 1, 1,  //
160           0, 0, 0, 0, 0, 0, 0, 0,  //
161           0, 0, 0, 0, 0, 0, 0, 0,  //
162           0, 0, 0, 0, 0, 0, 0, 0,  //
163       },                           // DCT4x8
164       {
165           0, 0, 0, 0, 1, 0, 0, 0,  //
166           0, 0, 0, 0, 1, 0, 0, 0,  //
167           0, 0, 0, 0, 1, 0, 0, 0,  //
168           0, 0, 0, 0, 1, 0, 0, 0,  //
169           0, 0, 0, 0, 1, 0, 0, 0,  //
170           0, 0, 0, 0, 1, 0, 0, 0,  //
171           0, 0, 0, 0, 1, 0, 0, 0,  //
172           0, 0, 0, 0, 1, 0, 0, 0,  //
173       },                           // DCT8x4
174       {
175           1, 1, 1, 1, 1, 0, 0, 0,  //
176           1, 1, 1, 1, 0, 0, 0, 0,  //
177           1, 1, 1, 0, 0, 0, 0, 0,  //
178           1, 1, 0, 0, 0, 0, 0, 0,  //
179           1, 0, 0, 0, 0, 0, 0, 0,  //
180           0, 0, 0, 0, 0, 0, 0, 0,  //
181           0, 0, 0, 0, 0, 0, 0, 0,  //
182           0, 0, 0, 0, 0, 0, 0, 0,  //
183       },                           // AFV0
184       {
185           0, 0, 0, 0, 1, 1, 1, 1,  //
186           0, 0, 0, 0, 0, 1, 1, 1,  //
187           0, 0, 0, 0, 0, 0, 1, 1,  //
188           0, 0, 0, 0, 0, 0, 0, 1,  //
189           0, 0, 0, 0, 0, 0, 0, 0,  //
190           0, 0, 0, 0, 0, 0, 0, 0,  //
191           0, 0, 0, 0, 0, 0, 0, 0,  //
192           0, 0, 0, 0, 0, 0, 0, 0,  //
193       },                           // AFV1
194       {
195           0, 0, 0, 0, 0, 0, 0, 0,  //
196           0, 0, 0, 0, 0, 0, 0, 0,  //
197           0, 0, 0, 0, 0, 0, 0, 0,  //
198           0, 0, 0, 0, 0, 0, 0, 0,  //
199           1, 0, 0, 0, 0, 0, 0, 0,  //
200           1, 1, 0, 0, 0, 0, 0, 0,  //
201           1, 1, 1, 0, 0, 0, 0, 0,  //
202           1, 1, 1, 1, 0, 0, 0, 0,  //
203       },                           // AFV2
204       {
205           0, 0, 0, 0, 0, 0, 0, 0,  //
206           0, 0, 0, 0, 0, 0, 0, 0,  //
207           0, 0, 0, 0, 0, 0, 0, 0,  //
208           0, 0, 0, 0, 0, 0, 0, 0,  //
209           0, 0, 0, 0, 0, 0, 0, 0,  //
210           0, 0, 0, 0, 0, 0, 0, 1,  //
211           0, 0, 0, 0, 0, 0, 1, 1,  //
212           0, 0, 0, 0, 0, 1, 1, 1,  //
213       },                           // AFV3
214   };
215   return kMask[raw_strategy];
216 }
217
218 void DumpAcStrategy(const AcStrategyImage& ac_strategy, size_t xsize,
219                     size_t ysize, const char* tag, AuxOut* aux_out,
220                     const CompressParams& cparams) {
221   Image3F color_acs(xsize, ysize);
222   for (size_t y = 0; y < ysize; y++) {
223     float* JXL_RESTRICT rows[3] = {
224         color_acs.PlaneRow(0, y),
225         color_acs.PlaneRow(1, y),
226         color_acs.PlaneRow(2, y),
227     };
228     const AcStrategyRow acs_row = ac_strategy.ConstRow(y / kBlockDim);
229     for (size_t x = 0; x < xsize; x++) {
230       AcStrategy acs = acs_row[x / kBlockDim];
231       const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy());
232       for (size_t c = 0; c < 3; c++) {
233         rows[c][x] = color[c] / 255.f;
234       }
235     }
236   }
237   size_t stride = color_acs.PixelsPerRow();
238   for (size_t c = 0; c < 3; c++) {
239     for (size_t by = 0; by < DivCeil(ysize, kBlockDim); by++) {
240       float* JXL_RESTRICT row = color_acs.PlaneRow(c, by * kBlockDim);
241       const AcStrategyRow acs_row = ac_strategy.ConstRow(by);
242       for (size_t bx = 0; bx < DivCeil(xsize, kBlockDim); bx++) {
243         AcStrategy acs = acs_row[bx];
244         if (!acs.IsFirstBlock()) continue;
245         const uint8_t* JXL_RESTRICT color = TypeColor(acs.RawStrategy());
246         const uint8_t* JXL_RESTRICT mask = TypeMask(acs.RawStrategy());
247         if (acs.covered_blocks_x() == 1 && acs.covered_blocks_y() == 1) {
248           for (size_t iy = 0; iy < kBlockDim && by * kBlockDim + iy < ysize;
249                iy++) {
250             for (size_t ix = 0; ix < kBlockDim && bx * kBlockDim + ix < xsize;
251                  ix++) {
252               if (mask[iy * kBlockDim + ix]) {
253                 row[iy * stride + bx * kBlockDim + ix] = color[c] / 800.f;
254               }
255             }
256           }
257         }
258         // draw block edges
259         for (size_t ix = 0; ix < kBlockDim * acs.covered_blocks_x() &&
260                             bx * kBlockDim + ix < xsize;
261              ix++) {
262           row[0 * stride + bx * kBlockDim + ix] = color[c] / 350.f;
263         }
264         for (size_t iy = 0; iy < kBlockDim * acs.covered_blocks_y() &&
265                             by * kBlockDim + iy < ysize;
266              iy++) {
267           row[iy * stride + bx * kBlockDim + 0] = color[c] / 350.f;
268         }
269       }
270     }
271   }
272   DumpImage(cparams, tag, color_acs);
273 }
274
275 }  // namespace
276 }  // namespace jxl
277 #endif  // LIB_JXL_ENC_AC_STRATEGY_
278
279 HWY_BEFORE_NAMESPACE();
280 namespace jxl {
281 namespace HWY_NAMESPACE {
282
283 // These templates are not found via ADL.
284 using hwy::HWY_NAMESPACE::AbsDiff;
285 using hwy::HWY_NAMESPACE::Eq;
286 using hwy::HWY_NAMESPACE::IfThenElseZero;
287 using hwy::HWY_NAMESPACE::IfThenZeroElse;
288 using hwy::HWY_NAMESPACE::Round;
289 using hwy::HWY_NAMESPACE::Sqrt;
290
291 bool MultiBlockTransformCrossesHorizontalBoundary(
292     const AcStrategyImage& ac_strategy, size_t start_x, size_t y,
293     size_t end_x) {
294   if (start_x >= ac_strategy.xsize() || y >= ac_strategy.ysize()) {
295     return false;
296   }
297   if (y % 8 == 0) {
298     // Nothing crosses 64x64 boundaries, and the memory on the other side
299     // of the 64x64 block may still uninitialized.
300     return false;
301   }
302   end_x = std::min(end_x, ac_strategy.xsize());
303   // The first multiblock might be before the start_x, let's adjust it
304   // to point to the first IsFirstBlock() == true block we find by backward
305   // tracing.
306   AcStrategyRow row = ac_strategy.ConstRow(y);
307   const size_t start_x_limit = start_x & ~7;
308   while (start_x != start_x_limit && !row[start_x].IsFirstBlock()) {
309     --start_x;
310   }
311   for (size_t x = start_x; x < end_x;) {
312     if (row[x].IsFirstBlock()) {
313       x += row[x].covered_blocks_x();
314     } else {
315       return true;
316     }
317   }
318   return false;
319 }
320
321 bool MultiBlockTransformCrossesVerticalBoundary(
322     const AcStrategyImage& ac_strategy, size_t x, size_t start_y,
323     size_t end_y) {
324   if (x >= ac_strategy.xsize() || start_y >= ac_strategy.ysize()) {
325     return false;
326   }
327   if (x % 8 == 0) {
328     // Nothing crosses 64x64 boundaries, and the memory on the other side
329     // of the 64x64 block may still uninitialized.
330     return false;
331   }
332   end_y = std::min(end_y, ac_strategy.ysize());
333   // The first multiblock might be before the start_y, let's adjust it
334   // to point to the first IsFirstBlock() == true block we find by backward
335   // tracing.
336   const size_t start_y_limit = start_y & ~7;
337   while (start_y != start_y_limit &&
338          !ac_strategy.ConstRow(start_y)[x].IsFirstBlock()) {
339     --start_y;
340   }
341
342   for (size_t y = start_y; y < end_y;) {
343     AcStrategyRow row = ac_strategy.ConstRow(y);
344     if (row[x].IsFirstBlock()) {
345       y += row[x].covered_blocks_y();
346     } else {
347       return true;
348     }
349   }
350   return false;
351 }
352
353 float EstimateEntropy(const AcStrategy& acs, float entropy_mul, size_t x,
354                       size_t y, const ACSConfig& config,
355                       const float* JXL_RESTRICT cmap_factors, float* block,
356                       float* scratch_space, uint32_t* quantized) {
357   const size_t size = (1 << acs.log2_covered_blocks()) * kDCTBlockSize;
358
359   // Apply transform.
360   for (size_t c = 0; c < 3; c++) {
361     float* JXL_RESTRICT block_c = block + size * c;
362     TransformFromPixels(acs.Strategy(), &config.Pixel(c, x, y),
363                         config.src_stride, block_c, scratch_space);
364   }
365   HWY_FULL(float) df;
366
367   const size_t num_blocks = acs.covered_blocks_x() * acs.covered_blocks_y();
368   // avoid large blocks when there is a lot going on in red-green.
369   float quant_norm16 = 0;
370   if (num_blocks == 1) {
371     // When it is only one 8x8, we don't need aggregation of values.
372     quant_norm16 = config.Quant(x / 8, y / 8);
373   } else if (num_blocks == 2) {
374     // Taking max instead of 8th norm seems to work
375     // better for smallest blocks up to 16x8. Jyrki couldn't get
376     // improvements in trying the same for 16x16 blocks.
377     if (acs.covered_blocks_y() == 2) {
378       quant_norm16 =
379           std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8, y / 8 + 1));
380     } else {
381       quant_norm16 =
382           std::max(config.Quant(x / 8, y / 8), config.Quant(x / 8 + 1, y / 8));
383     }
384   } else {
385     // Load QF value, calculate empirical heuristic on masking field
386     // for weighting the information loss. Information loss manifests
387     // itself as ringing, and masking could hide it.
388     for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
389       for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
390         float qval = config.Quant(x / 8 + ix, y / 8 + iy);
391         qval *= qval;
392         qval *= qval;
393         qval *= qval;
394         quant_norm16 += qval * qval;
395       }
396     }
397     quant_norm16 /= num_blocks;
398     quant_norm16 = FastPowf(quant_norm16, 1.0f / 16.0f);
399   }
400   const auto quant = Set(df, quant_norm16);
401
402   // Compute entropy.
403   float entropy = 0.0f;
404   const HWY_CAPPED(float, 8) df8;
405
406   auto mem_alloc = hwy::AllocateAligned<float>(AcStrategy::kMaxCoeffArea);
407   float* mem = mem_alloc.get();
408   auto loss = Zero(df8);
409   for (size_t c = 0; c < 3; c++) {
410     const float* inv_matrix = config.dequant->InvMatrix(acs.RawStrategy(), c);
411     const float* matrix = config.dequant->Matrix(acs.RawStrategy(), c);
412     const auto cmap_factor = Set(df, cmap_factors[c]);
413
414     auto entropy_v = Zero(df);
415     auto nzeros_v = Zero(df);
416     for (size_t i = 0; i < num_blocks * kDCTBlockSize; i += Lanes(df)) {
417       const auto in = Load(df, block + c * size + i);
418       const auto in_y = Mul(Load(df, block + size + i), cmap_factor);
419       const auto im = Load(df, inv_matrix + i);
420       const auto val = Mul(Sub(in, in_y), Mul(im, quant));
421       const auto rval = Round(val);
422       const auto diff = Sub(val, rval);
423       const auto m = Load(df, matrix + i);
424       Store(Mul(m, diff), df, &mem[i]);
425       const auto q = Abs(rval);
426       const auto q_is_zero = Eq(q, Zero(df));
427       // We used to have q * C here, but that cost model seems to
428       // be punishing large values more than necessary. Sqrt tries
429       // to avoid large values less aggressively.
430       entropy_v = Add(Sqrt(q), entropy_v);
431       nzeros_v = Add(nzeros_v, IfThenZeroElse(q_is_zero, Set(df, 1.0f)));
432     }
433
434     {
435       auto lossc = Zero(df8);
436       TransformToPixels(acs.Strategy(), &mem[0], block,
437                         acs.covered_blocks_x() * 8, scratch_space);
438
439       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
440         for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
441           for (size_t dy = 0; dy < kBlockDim; ++dy) {
442             for (size_t dx = 0; dx < kBlockDim; dx += Lanes(df8)) {
443               auto in = Load(df8, block +
444                                       (iy * kBlockDim + dy) *
445                                           (acs.covered_blocks_x() * kBlockDim) +
446                                       ix * kBlockDim + dx);
447               auto masku = Abs(Load(
448                   df8, config.MaskingPtr1x1(x + ix * 8 + dx, y + iy * 8 + dy)));
449               in = Mul(masku, in);
450               in = Mul(in, in);
451               in = Mul(in, in);
452               in = Mul(in, in);
453               lossc = Add(lossc, in);
454             }
455           }
456         }
457       }
458       static const double kChannelMul[3] = {
459           10.2,
460           1.0,
461           1.03,
462       };
463       lossc = Mul(Set(df8, pow(kChannelMul[c], 8.0)), lossc);
464       loss = Add(loss, lossc);
465     }
466     entropy += config.cost_delta * GetLane(SumOfLanes(df, entropy_v));
467     size_t num_nzeros = GetLane(SumOfLanes(df, nzeros_v));
468     // Add #bit of num_nonzeros, as an estimate of the cost for encoding the
469     // number of non-zeros of the block.
470     size_t nbits = CeilLog2Nonzero(num_nzeros + 1) + 1;
471     // Also add #bit of #bit of num_nonzeros, to estimate the ANS cost, with a
472     // bias.
473     entropy += config.zeros_mul * (CeilLog2Nonzero(nbits + 17) + nbits);
474   }
475   float loss_scalar =
476       pow(GetLane(SumOfLanes(df8, loss)) / (num_blocks * kDCTBlockSize),
477           1.0 / 8.0) *
478       (num_blocks * kDCTBlockSize) / quant_norm16;
479   float ret = entropy * entropy_mul;
480   ret += config.info_loss_multiplier * loss_scalar;
481   return ret;
482 }
483
484 uint8_t FindBest8x8Transform(size_t x, size_t y, int encoding_speed_tier,
485                              float butteraugli_target, const ACSConfig& config,
486                              const float* JXL_RESTRICT cmap_factors,
487                              AcStrategyImage* JXL_RESTRICT ac_strategy,
488                              float* block, float* scratch_space,
489                              uint32_t* quantized, float* entropy_out) {
490   struct TransformTry8x8 {
491     AcStrategy::Type type;
492     int encoding_speed_tier_max_limit;
493     double entropy_mul;
494   };
495   static const TransformTry8x8 kTransforms8x8[] = {
496       {
497           AcStrategy::Type::DCT,
498           9,
499           0.8,
500       },
501       {
502           AcStrategy::Type::DCT4X4,
503           5,
504           1.08,
505       },
506       {
507           AcStrategy::Type::DCT2X2,
508           5,
509           0.95,
510       },
511       {
512           AcStrategy::Type::DCT4X8,
513           4,
514           0.85931637428340035,
515       },
516       {
517           AcStrategy::Type::DCT8X4,
518           4,
519           0.85931637428340035,
520       },
521       {
522           AcStrategy::Type::IDENTITY,
523           5,
524           1.0427542510634957,
525       },
526       {
527           AcStrategy::Type::AFV0,
528           4,
529           0.81779489591359944,
530       },
531       {
532           AcStrategy::Type::AFV1,
533           4,
534           0.81779489591359944,
535       },
536       {
537           AcStrategy::Type::AFV2,
538           4,
539           0.81779489591359944,
540       },
541       {
542           AcStrategy::Type::AFV3,
543           4,
544           0.81779489591359944,
545       },
546   };
547   double best = 1e30;
548   uint8_t best_tx = kTransforms8x8[0].type;
549   for (auto tx : kTransforms8x8) {
550     if (tx.encoding_speed_tier_max_limit < encoding_speed_tier) {
551       continue;
552     }
553     AcStrategy acs = AcStrategy::FromRawStrategy(tx.type);
554     float entropy_mul = tx.entropy_mul / kTransforms8x8[0].entropy_mul;
555     if ((tx.type == AcStrategy::Type::DCT2X2 ||
556          tx.type == AcStrategy::Type::IDENTITY) &&
557         butteraugli_target < 5.0) {
558       static const float kFavor2X2AtHighQuality = 0.4;
559       float weight = pow((5.0f - butteraugli_target) / 5.0f, 2.0);
560       entropy_mul -= kFavor2X2AtHighQuality * weight;
561     }
562     if ((tx.type != AcStrategy::Type::DCT &&
563          tx.type != AcStrategy::Type::DCT2X2 &&
564          tx.type != AcStrategy::Type::IDENTITY) &&
565         butteraugli_target > 4.0) {
566       static const float kAvoidEntropyOfTransforms = 0.5;
567       float mul = 1.0;
568       if (butteraugli_target < 12.0) {
569         mul *= (12.0 - 4.0) / (butteraugli_target - 4.0);
570       }
571       entropy_mul += kAvoidEntropyOfTransforms * mul;
572     }
573     float entropy =
574         EstimateEntropy(acs, entropy_mul, x, y, config, cmap_factors, block,
575                         scratch_space, quantized);
576     if (entropy < best) {
577       best_tx = tx.type;
578       best = entropy;
579     }
580   }
581   *entropy_out = best;
582   return best_tx;
583 }
584
585 // bx, by addresses the 64x64 block at 8x8 subresolution
586 // cx, cy addresses the left, upper 8x8 block position of the candidate
587 // transform.
588 void TryMergeAcs(AcStrategy::Type acs_raw, size_t bx, size_t by, size_t cx,
589                  size_t cy, const ACSConfig& config,
590                  const float* JXL_RESTRICT cmap_factors,
591                  AcStrategyImage* JXL_RESTRICT ac_strategy,
592                  const float entropy_mul, const uint8_t candidate_priority,
593                  uint8_t* priority, float* JXL_RESTRICT entropy_estimate,
594                  float* block, float* scratch_space, uint32_t* quantized) {
595   AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw);
596   float entropy_current = 0;
597   for (size_t iy = 0; iy < acs.covered_blocks_y(); ++iy) {
598     for (size_t ix = 0; ix < acs.covered_blocks_x(); ++ix) {
599       if (priority[(cy + iy) * 8 + (cx + ix)] >= candidate_priority) {
600         // Transform would reuse already allocated blocks and
601         // lead to invalid overlaps, for example DCT64X32 vs.
602         // DCT32X64.
603         return;
604       }
605       entropy_current += entropy_estimate[(cy + iy) * 8 + (cx + ix)];
606     }
607   }
608   float entropy_candidate =
609       EstimateEntropy(acs, entropy_mul, (bx + cx) * 8, (by + cy) * 8, config,
610                       cmap_factors, block, scratch_space, quantized);
611   if (entropy_candidate >= entropy_current) return;
612   // Accept the candidate.
613   for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
614     for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
615       entropy_estimate[(cy + iy) * 8 + cx + ix] = 0;
616       priority[(cy + iy) * 8 + cx + ix] = candidate_priority;
617     }
618   }
619   ac_strategy->Set(bx + cx, by + cy, acs_raw);
620   entropy_estimate[cy * 8 + cx] = entropy_candidate;
621 }
622
623 static void SetEntropyForTransform(size_t cx, size_t cy,
624                                    const AcStrategy::Type acs_raw,
625                                    float entropy,
626                                    float* JXL_RESTRICT entropy_estimate) {
627   const AcStrategy acs = AcStrategy::FromRawStrategy(acs_raw);
628   for (size_t dy = 0; dy < acs.covered_blocks_y(); ++dy) {
629     for (size_t dx = 0; dx < acs.covered_blocks_x(); ++dx) {
630       entropy_estimate[(cy + dy) * 8 + cx + dx] = 0.0;
631     }
632   }
633   entropy_estimate[cy * 8 + cx] = entropy;
634 }
635
636 AcStrategy::Type AcsSquare(size_t blocks) {
637   if (blocks == 2) {
638     return AcStrategy::Type::DCT16X16;
639   } else if (blocks == 4) {
640     return AcStrategy::Type::DCT32X32;
641   } else {
642     return AcStrategy::Type::DCT64X64;
643   }
644 }
645
646 AcStrategy::Type AcsVerticalSplit(size_t blocks) {
647   if (blocks == 2) {
648     return AcStrategy::Type::DCT16X8;
649   } else if (blocks == 4) {
650     return AcStrategy::Type::DCT32X16;
651   } else {
652     return AcStrategy::Type::DCT64X32;
653   }
654 }
655
656 AcStrategy::Type AcsHorizontalSplit(size_t blocks) {
657   if (blocks == 2) {
658     return AcStrategy::Type::DCT8X16;
659   } else if (blocks == 4) {
660     return AcStrategy::Type::DCT16X32;
661   } else {
662     return AcStrategy::Type::DCT32X64;
663   }
664 }
665
666 // The following function tries to merge smaller transforms into
667 // squares and the rectangles originating from a single middle division
668 // (horizontal or vertical) fairly.
669 //
670 // This is now generalized to concern about squares
671 // of blocks X blocks size, where a block is 8x8 pixels.
672 void FindBestFirstLevelDivisionForSquare(
673     size_t blocks, bool allow_square_transform, size_t bx, size_t by, size_t cx,
674     size_t cy, const ACSConfig& config, const float* JXL_RESTRICT cmap_factors,
675     AcStrategyImage* JXL_RESTRICT ac_strategy, const float entropy_mul_JXK,
676     const float entropy_mul_JXJ, float* JXL_RESTRICT entropy_estimate,
677     float* block, float* scratch_space, uint32_t* quantized) {
678   // We denote J for the larger dimension here, and K for the smaller.
679   // For example, for 32x32 block splitting, J would be 32, K 16.
680   const size_t blocks_half = blocks / 2;
681   const AcStrategy::Type acs_rawJXK = AcsVerticalSplit(blocks);
682   const AcStrategy::Type acs_rawKXJ = AcsHorizontalSplit(blocks);
683   const AcStrategy::Type acs_rawJXJ = AcsSquare(blocks);
684   const AcStrategy acsJXK = AcStrategy::FromRawStrategy(acs_rawJXK);
685   const AcStrategy acsKXJ = AcStrategy::FromRawStrategy(acs_rawKXJ);
686   const AcStrategy acsJXJ = AcStrategy::FromRawStrategy(acs_rawJXJ);
687   AcStrategyRow row0 = ac_strategy->ConstRow(by + cy + 0);
688   AcStrategyRow row1 = ac_strategy->ConstRow(by + cy + blocks_half);
689   // Let's check if we can consider a JXJ block here at all.
690   // This is not necessary in the basic use of hierarchically merging
691   // blocks in the simplest possible way, but is needed when we try other
692   // 'floating' options of merging, possibly after a simple hierarchical
693   // merge has been explored.
694   if (MultiBlockTransformCrossesHorizontalBoundary(*ac_strategy, bx + cx,
695                                                    by + cy, bx + cx + blocks) ||
696       MultiBlockTransformCrossesHorizontalBoundary(
697           *ac_strategy, bx + cx, by + cy + blocks, bx + cx + blocks) ||
698       MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx, by + cy,
699                                                  by + cy + blocks) ||
700       MultiBlockTransformCrossesVerticalBoundary(*ac_strategy, bx + cx + blocks,
701                                                  by + cy, by + cy + blocks)) {
702     return;  // not suitable for JxJ analysis, some transforms leak out.
703   }
704   // For floating transforms there may be
705   // already blocks selected that make either or both JXK and
706   // KXJ not feasible for this location.
707   const bool allow_JXK = !MultiBlockTransformCrossesVerticalBoundary(
708       *ac_strategy, bx + cx + blocks_half, by + cy, by + cy + blocks);
709   const bool allow_KXJ = !MultiBlockTransformCrossesHorizontalBoundary(
710       *ac_strategy, bx + cx, by + cy + blocks_half, bx + cx + blocks);
711   // Current entropies aggregated on NxN resolution.
712   float entropy[2][2] = {};
713   for (size_t dy = 0; dy < blocks; ++dy) {
714     for (size_t dx = 0; dx < blocks; ++dx) {
715       entropy[dy / blocks_half][dx / blocks_half] +=
716           entropy_estimate[(cy + dy) * 8 + (cx + dx)];
717     }
718   }
719   float entropy_JXK_left = std::numeric_limits<float>::max();
720   float entropy_JXK_right = std::numeric_limits<float>::max();
721   float entropy_KXJ_top = std::numeric_limits<float>::max();
722   float entropy_KXJ_bottom = std::numeric_limits<float>::max();
723   float entropy_JXJ = std::numeric_limits<float>::max();
724   if (allow_JXK) {
725     if (row0[bx + cx + 0].RawStrategy() != acs_rawJXK) {
726       entropy_JXK_left = EstimateEntropy(
727           acsJXK, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
728           cmap_factors, block, scratch_space, quantized);
729     }
730     if (row0[bx + cx + blocks_half].RawStrategy() != acs_rawJXK) {
731       entropy_JXK_right =
732           EstimateEntropy(acsJXK, entropy_mul_JXK, (bx + cx + blocks_half) * 8,
733                           (by + cy + 0) * 8, config, cmap_factors, block,
734                           scratch_space, quantized);
735     }
736   }
737   if (allow_KXJ) {
738     if (row0[bx + cx].RawStrategy() != acs_rawKXJ) {
739       entropy_KXJ_top = EstimateEntropy(
740           acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8, (by + cy + 0) * 8, config,
741           cmap_factors, block, scratch_space, quantized);
742     }
743     if (row1[bx + cx].RawStrategy() != acs_rawKXJ) {
744       entropy_KXJ_bottom =
745           EstimateEntropy(acsKXJ, entropy_mul_JXK, (bx + cx + 0) * 8,
746                           (by + cy + blocks_half) * 8, config, cmap_factors,
747                           block, scratch_space, quantized);
748     }
749   }
750   if (allow_square_transform) {
751     // We control the exploration of the square transform separately so that
752     // we can turn it off at high decoding speeds for 32x32, but still allow
753     // exploring 16x32 and 32x16.
754     entropy_JXJ = EstimateEntropy(acsJXJ, entropy_mul_JXJ, (bx + cx + 0) * 8,
755                                   (by + cy + 0) * 8, config, cmap_factors,
756                                   block, scratch_space, quantized);
757   }
758
759   // Test if this block should have JXK or KXJ transforms,
760   // because it can have only one or the other.
761   float costJxN = std::min(entropy_JXK_left, entropy[0][0] + entropy[1][0]) +
762                   std::min(entropy_JXK_right, entropy[0][1] + entropy[1][1]);
763   float costNxJ = std::min(entropy_KXJ_top, entropy[0][0] + entropy[0][1]) +
764                   std::min(entropy_KXJ_bottom, entropy[1][0] + entropy[1][1]);
765   if (entropy_JXJ < costJxN && entropy_JXJ < costNxJ) {
766     ac_strategy->Set(bx + cx, by + cy, acs_rawJXJ);
767     SetEntropyForTransform(cx, cy, acs_rawJXJ, entropy_JXJ, entropy_estimate);
768   } else if (costJxN < costNxJ) {
769     if (entropy_JXK_left < entropy[0][0] + entropy[1][0]) {
770       ac_strategy->Set(bx + cx, by + cy, acs_rawJXK);
771       SetEntropyForTransform(cx, cy, acs_rawJXK, entropy_JXK_left,
772                              entropy_estimate);
773     }
774     if (entropy_JXK_right < entropy[0][1] + entropy[1][1]) {
775       ac_strategy->Set(bx + cx + blocks_half, by + cy, acs_rawJXK);
776       SetEntropyForTransform(cx + blocks_half, cy, acs_rawJXK,
777                              entropy_JXK_right, entropy_estimate);
778     }
779   } else {
780     if (entropy_KXJ_top < entropy[0][0] + entropy[0][1]) {
781       ac_strategy->Set(bx + cx, by + cy, acs_rawKXJ);
782       SetEntropyForTransform(cx, cy, acs_rawKXJ, entropy_KXJ_top,
783                              entropy_estimate);
784     }
785     if (entropy_KXJ_bottom < entropy[1][0] + entropy[1][1]) {
786       ac_strategy->Set(bx + cx, by + cy + blocks_half, acs_rawKXJ);
787       SetEntropyForTransform(cx, cy + blocks_half, acs_rawKXJ,
788                              entropy_KXJ_bottom, entropy_estimate);
789     }
790   }
791 }
792
793 void ProcessRectACS(PassesEncoderState* JXL_RESTRICT enc_state,
794                     const ACSConfig& config, const Rect& rect) {
795   // Main philosophy here:
796   // 1. First find best 8x8 transform for each area.
797   // 2. Merging them into larger transforms where possibly, but
798   // starting from the smallest transforms (16x8 and 8x16).
799   // Additional complication: 16x8 and 8x16 are considered
800   // simultaneously and fairly against each other.
801   // We are looking at 64x64 squares since the YtoX and YtoB
802   // maps happen to be at that resolution, and having
803   // integral transforms cross these boundaries leads to
804   // additional complications.
805   const CompressParams& cparams = enc_state->cparams;
806   const float butteraugli_target = cparams.butteraugli_distance;
807   AcStrategyImage* ac_strategy = &enc_state->shared.ac_strategy;
808   const size_t dct_scratch_size =
809       3 * (MaxVectorSize() / sizeof(float)) * AcStrategy::kMaxBlockDim;
810   // TODO(veluca): reuse allocations
811   auto mem = hwy::AllocateAligned<float>(5 * AcStrategy::kMaxCoeffArea +
812                                          dct_scratch_size);
813   auto qmem = hwy::AllocateAligned<uint32_t>(AcStrategy::kMaxCoeffArea);
814   uint32_t* JXL_RESTRICT quantized = qmem.get();
815   float* JXL_RESTRICT block = mem.get();
816   float* JXL_RESTRICT scratch_space = mem.get() + 3 * AcStrategy::kMaxCoeffArea;
817   size_t bx = rect.x0();
818   size_t by = rect.y0();
819   JXL_ASSERT(rect.xsize() <= 8);
820   JXL_ASSERT(rect.ysize() <= 8);
821   size_t tx = bx / kColorTileDimInBlocks;
822   size_t ty = by / kColorTileDimInBlocks;
823   const float cmap_factors[3] = {
824       enc_state->shared.cmap.YtoXRatio(
825           enc_state->shared.cmap.ytox_map.ConstRow(ty)[tx]),
826       0.0f,
827       enc_state->shared.cmap.YtoBRatio(
828           enc_state->shared.cmap.ytob_map.ConstRow(ty)[tx]),
829   };
830   if (cparams.speed_tier > SpeedTier::kHare) return;
831   // First compute the best 8x8 transform for each square. Later, we do not
832   // experiment with different combinations, but only use the best of the 8x8s
833   // when DCT8X8 is specified in the tree search.
834   // 8x8 transforms have 10 variants, but every larger transform is just a DCT.
835   float entropy_estimate[64] = {};
836   // Favor all 8x8 transforms (against 16x8 and larger transforms)) at
837   // low butteraugli_target distances.
838   static const float k8x8mul1 = -0.4;
839   static const float k8x8mul2 = 1.0;
840   static const float k8x8base = 1.4;
841   const float mul8x8 = k8x8mul2 + k8x8mul1 / (butteraugli_target + k8x8base);
842   for (size_t iy = 0; iy < rect.ysize(); iy++) {
843     for (size_t ix = 0; ix < rect.xsize(); ix++) {
844       float entropy = 0.0;
845       const uint8_t best_of_8x8s = FindBest8x8Transform(
846           8 * (bx + ix), 8 * (by + iy), static_cast<int>(cparams.speed_tier),
847           butteraugli_target, config, cmap_factors, ac_strategy, block,
848           scratch_space, quantized, &entropy);
849       ac_strategy->Set(bx + ix, by + iy,
850                        static_cast<AcStrategy::Type>(best_of_8x8s));
851       entropy_estimate[iy * 8 + ix] = entropy * mul8x8;
852     }
853   }
854   // Merge when a larger transform is better than the previously
855   // searched best combination of 8x8 transforms.
856   struct MergeTry {
857     AcStrategy::Type type;
858     uint8_t priority;
859     uint8_t decoding_speed_tier_max_limit;
860     uint8_t encoding_speed_tier_max_limit;
861     float entropy_mul;
862   };
863   // These numbers need to be figured out manually and looking at
864   // ringing next to sky etc. Optimization will find larger numbers
865   // and produce more ringing than is ideal. Larger numbers will
866   // help stop ringing.
867   const float entropy_mul16X8 = 1.25;
868   const float entropy_mul16X16 = 1.35;
869   const float entropy_mul16X32 = 1.5;
870   const float entropy_mul32X32 = 1.5;
871   const float entropy_mul64X32 = 2.26;
872   const float entropy_mul64X64 = 2.26;
873   // TODO(jyrki): Consider this feedback in further changes:
874   // Also effectively when the multipliers for smaller blocks are
875   // below 1, this raises the bar for the bigger blocks even higher
876   // in that sense these constants are not independent (e.g. changing
877   // the constant for DCT16x32 by -5% (making it more likely) also
878   // means that DCT32x32 becomes harder to do when starting from
879   // two DCT16x32s). It might be better to make them more independent,
880   // e.g. by not applying the multiplier when storing the new entropy
881   // estimates in TryMergeToACSCandidate().
882   const MergeTry kTransformsForMerge[9] = {
883       {AcStrategy::Type::DCT16X8, 2, 4, 5, entropy_mul16X8},
884       {AcStrategy::Type::DCT8X16, 2, 4, 5, entropy_mul16X8},
885       // FindBestFirstLevelDivisionForSquare looks for DCT16X16 and its
886       // subdivisions. {AcStrategy::Type::DCT16X16, 3, entropy_mul16X16},
887       {AcStrategy::Type::DCT16X32, 4, 4, 4, entropy_mul16X32},
888       {AcStrategy::Type::DCT32X16, 4, 4, 4, entropy_mul16X32},
889       // FindBestFirstLevelDivisionForSquare looks for DCT32X32 and its
890       // subdivisions. {AcStrategy::Type::DCT32X32, 5, 1, 5,
891       // 0.9822994906548809f},
892       {AcStrategy::Type::DCT64X32, 6, 1, 3, entropy_mul64X32},
893       {AcStrategy::Type::DCT32X64, 6, 1, 3, entropy_mul64X32},
894       // {AcStrategy::Type::DCT64X64, 8, 1, 3, 2.0846542128012948f},
895   };
896   /*
897   These sizes not yet included in merge heuristic:
898   set(AcStrategy::Type::DCT32X8, 0.0f, 2.261390410971102f);
899   set(AcStrategy::Type::DCT8X32, 0.0f, 2.261390410971102f);
900   set(AcStrategy::Type::DCT128X128, 0.0f, 1.0f);
901   set(AcStrategy::Type::DCT128X64, 0.0f, 0.73f);
902   set(AcStrategy::Type::DCT64X128, 0.0f, 0.73f);
903   set(AcStrategy::Type::DCT256X256, 0.0f, 1.0f);
904   set(AcStrategy::Type::DCT256X128, 0.0f, 0.73f);
905   set(AcStrategy::Type::DCT128X256, 0.0f, 0.73f);
906   */
907
908   // Priority is a tricky kludge to avoid collisions so that transforms
909   // don't overlap.
910   uint8_t priority[64] = {};
911   bool enable_32x32 = cparams.decoding_speed_tier < 4;
912   for (auto tx : kTransformsForMerge) {
913     if (tx.decoding_speed_tier_max_limit < cparams.decoding_speed_tier) {
914       continue;
915     }
916     AcStrategy acs = AcStrategy::FromRawStrategy(tx.type);
917
918     for (size_t cy = 0; cy + acs.covered_blocks_y() - 1 < rect.ysize();
919          cy += acs.covered_blocks_y()) {
920       for (size_t cx = 0; cx + acs.covered_blocks_x() - 1 < rect.xsize();
921            cx += acs.covered_blocks_x()) {
922         if (cy + 7 < rect.ysize() && cx + 7 < rect.xsize()) {
923           if (cparams.decoding_speed_tier < 4 &&
924               tx.type == AcStrategy::Type::DCT32X64) {
925             // We handle both DCT8X16 and DCT16X8 at the same time.
926             if ((cy | cx) % 8 == 0) {
927               FindBestFirstLevelDivisionForSquare(
928                   8, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
929                   tx.entropy_mul, entropy_mul64X64, entropy_estimate, block,
930                   scratch_space, quantized);
931             }
932             continue;
933           } else if (tx.type == AcStrategy::Type::DCT32X16) {
934             // We handled both DCT8X16 and DCT16X8 at the same time,
935             // and that is above. The last column and last row,
936             // when the last column or last row is odd numbered,
937             // are still handled by TryMergeAcs.
938             continue;
939           }
940         }
941         if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) ||
942             (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) {
943           // already covered by FindBest32X32
944           continue;
945         }
946
947         if (cy + 3 < rect.ysize() && cx + 3 < rect.xsize()) {
948           if (tx.type == AcStrategy::Type::DCT16X32) {
949             // We handle both DCT8X16 and DCT16X8 at the same time.
950             if ((cy | cx) % 4 == 0) {
951               FindBestFirstLevelDivisionForSquare(
952                   4, enable_32x32, bx, by, cx, cy, config, cmap_factors,
953                   ac_strategy, tx.entropy_mul, entropy_mul32X32,
954                   entropy_estimate, block, scratch_space, quantized);
955             }
956             continue;
957           } else if (tx.type == AcStrategy::Type::DCT32X16) {
958             // We handled both DCT8X16 and DCT16X8 at the same time,
959             // and that is above. The last column and last row,
960             // when the last column or last row is odd numbered,
961             // are still handled by TryMergeAcs.
962             continue;
963           }
964         }
965         if ((tx.type == AcStrategy::Type::DCT16X32 && cy % 4 != 0) ||
966             (tx.type == AcStrategy::Type::DCT32X16 && cx % 4 != 0)) {
967           // already covered by FindBest32X32
968           continue;
969         }
970         if (cy + 1 < rect.ysize() && cx + 1 < rect.xsize()) {
971           if (tx.type == AcStrategy::Type::DCT8X16) {
972             // We handle both DCT8X16 and DCT16X8 at the same time.
973             if ((cy | cx) % 2 == 0) {
974               FindBestFirstLevelDivisionForSquare(
975                   2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
976                   tx.entropy_mul, entropy_mul16X16, entropy_estimate, block,
977                   scratch_space, quantized);
978             }
979             continue;
980           } else if (tx.type == AcStrategy::Type::DCT16X8) {
981             // We handled both DCT8X16 and DCT16X8 at the same time,
982             // and that is above. The last column and last row,
983             // when the last column or last row is odd numbered,
984             // are still handled by TryMergeAcs.
985             continue;
986           }
987         }
988         if ((tx.type == AcStrategy::Type::DCT8X16 && cy % 2 == 1) ||
989             (tx.type == AcStrategy::Type::DCT16X8 && cx % 2 == 1)) {
990           // already covered by FindBestFirstLevelDivisionForSquare
991           continue;
992         }
993         // All other merge sizes are handled here.
994         // Some of the DCT16X8s and DCT8X16s will still leak through here
995         // when there is an odd number of 8x8 blocks, then the last row
996         // and column will get their DCT16X8s and DCT8X16s through the
997         // normal integral transform merging process.
998         TryMergeAcs(tx.type, bx, by, cx, cy, config, cmap_factors, ac_strategy,
999                     tx.entropy_mul, tx.priority, &priority[0], entropy_estimate,
1000                     block, scratch_space, quantized);
1001       }
1002     }
1003   }
1004   if (cparams.speed_tier >= SpeedTier::kHare) {
1005     return;
1006   }
1007   // Here we still try to do some non-aligned matching, find a few more
1008   // 16X8, 8X16 and 16X16s between the non-2-aligned blocks.
1009   for (size_t cy = 0; cy + 1 < rect.ysize(); ++cy) {
1010     for (size_t cx = 0; cx + 1 < rect.xsize(); ++cx) {
1011       if ((cy | cx) % 2 != 0) {
1012         FindBestFirstLevelDivisionForSquare(
1013             2, true, bx, by, cx, cy, config, cmap_factors, ac_strategy,
1014             entropy_mul16X8, entropy_mul16X16, entropy_estimate, block,
1015             scratch_space, quantized);
1016       }
1017     }
1018   }
1019   // Non-aligned matching for 32X32, 16X32 and 32X16.
1020   size_t step = cparams.speed_tier >= SpeedTier::kTortoise ? 2 : 1;
1021   for (size_t cy = 0; cy + 3 < rect.ysize(); cy += step) {
1022     for (size_t cx = 0; cx + 3 < rect.xsize(); cx += step) {
1023       if ((cy | cx) % 4 == 0) {
1024         continue;  // Already tried with loop above (DCT16X32 case).
1025       }
1026       FindBestFirstLevelDivisionForSquare(
1027           4, enable_32x32, bx, by, cx, cy, config, cmap_factors, ac_strategy,
1028           entropy_mul16X32, entropy_mul32X32, entropy_estimate, block,
1029           scratch_space, quantized);
1030     }
1031   }
1032 }
1033
1034 // NOLINTNEXTLINE(google-readability-namespace-comments)
1035 }  // namespace HWY_NAMESPACE
1036 }  // namespace jxl
1037 HWY_AFTER_NAMESPACE();
1038
1039 #if HWY_ONCE
1040 namespace jxl {
1041 HWY_EXPORT(ProcessRectACS);
1042
1043 void AcStrategyHeuristics::Init(const Image3F& src,
1044                                 PassesEncoderState* enc_state) {
1045   this->enc_state = enc_state;
1046   config.dequant = &enc_state->shared.matrices;
1047   const CompressParams& cparams = enc_state->cparams;
1048
1049   if (cparams.speed_tier >= SpeedTier::kCheetah) {
1050     JXL_CHECK(enc_state->shared.matrices.EnsureComputed(1));  // DCT8 only
1051   } else {
1052     uint32_t acs_mask = 0;
1053     // All transforms up to 64x64.
1054     for (size_t i = 0; i < AcStrategy::DCT128X128; i++) {
1055       acs_mask |= (1 << i);
1056     }
1057     JXL_CHECK(enc_state->shared.matrices.EnsureComputed(acs_mask));
1058   }
1059
1060   // Image row pointers and strides.
1061   config.quant_field_row = enc_state->initial_quant_field.Row(0);
1062   config.quant_field_stride = enc_state->initial_quant_field.PixelsPerRow();
1063   auto& mask = enc_state->initial_quant_masking;
1064   auto& mask1x1 = enc_state->initial_quant_masking1x1;
1065   if (mask.xsize() > 0 && mask.ysize() > 0) {
1066     config.masking_field_row = mask.Row(0);
1067     config.masking_field_stride = mask.PixelsPerRow();
1068   }
1069   if (mask1x1.xsize() > 0 && mask1x1.ysize() > 0) {
1070     config.masking1x1_field_row = mask1x1.Row(0);
1071     config.masking1x1_field_stride = mask1x1.PixelsPerRow();
1072   }
1073
1074   config.src_rows[0] = src.ConstPlaneRow(0, 0);
1075   config.src_rows[1] = src.ConstPlaneRow(1, 0);
1076   config.src_rows[2] = src.ConstPlaneRow(2, 0);
1077   config.src_stride = src.PixelsPerRow();
1078
1079   // Entropy estimate is composed of two factors:
1080   //  - estimate of the number of bits that will be used by the block
1081   //  - information loss due to quantization
1082   // The following constant controls the relative weights of these components.
1083   config.info_loss_multiplier = 1.2;
1084   config.zeros_mul = 9.3089059022677905;
1085   config.cost_delta = 10.833273317067883;
1086
1087   static const float kBias = 0.13731742964354549;
1088   const float ratio = (cparams.butteraugli_distance + kBias) / (1.0f + kBias);
1089
1090   static const float kPow1 = 0.33677806662454718;
1091   static const float kPow2 = 0.50990926717963703;
1092   static const float kPow3 = 0.36702940662370243;
1093   config.info_loss_multiplier *= pow(ratio, kPow1);
1094   config.zeros_mul *= pow(ratio, kPow2);
1095   config.cost_delta *= pow(ratio, kPow3);
1096   JXL_ASSERT(enc_state->shared.ac_strategy.xsize() ==
1097              enc_state->shared.frame_dim.xsize_blocks);
1098   JXL_ASSERT(enc_state->shared.ac_strategy.ysize() ==
1099              enc_state->shared.frame_dim.ysize_blocks);
1100 }
1101
1102 void AcStrategyHeuristics::ProcessRect(const Rect& rect) {
1103   const CompressParams& cparams = enc_state->cparams;
1104   // In Falcon mode, use DCT8 everywhere and uniform quantization.
1105   if (cparams.speed_tier >= SpeedTier::kCheetah) {
1106     enc_state->shared.ac_strategy.FillDCT8(rect);
1107     return;
1108   }
1109   HWY_DYNAMIC_DISPATCH(ProcessRectACS)
1110   (enc_state, config, rect);
1111 }
1112
1113 void AcStrategyHeuristics::Finalize(AuxOut* aux_out) {
1114   const auto& ac_strategy = enc_state->shared.ac_strategy;
1115   // Accounting and debug output.
1116   if (aux_out != nullptr) {
1117     aux_out->num_small_blocks =
1118         ac_strategy.CountBlocks(AcStrategy::Type::IDENTITY) +
1119         ac_strategy.CountBlocks(AcStrategy::Type::DCT2X2) +
1120         ac_strategy.CountBlocks(AcStrategy::Type::DCT4X4);
1121     aux_out->num_dct4x8_blocks =
1122         ac_strategy.CountBlocks(AcStrategy::Type::DCT4X8) +
1123         ac_strategy.CountBlocks(AcStrategy::Type::DCT8X4);
1124     aux_out->num_afv_blocks = ac_strategy.CountBlocks(AcStrategy::Type::AFV0) +
1125                               ac_strategy.CountBlocks(AcStrategy::Type::AFV1) +
1126                               ac_strategy.CountBlocks(AcStrategy::Type::AFV2) +
1127                               ac_strategy.CountBlocks(AcStrategy::Type::AFV3);
1128     aux_out->num_dct8_blocks = ac_strategy.CountBlocks(AcStrategy::Type::DCT);
1129     aux_out->num_dct8x16_blocks =
1130         ac_strategy.CountBlocks(AcStrategy::Type::DCT8X16) +
1131         ac_strategy.CountBlocks(AcStrategy::Type::DCT16X8);
1132     aux_out->num_dct8x32_blocks =
1133         ac_strategy.CountBlocks(AcStrategy::Type::DCT8X32) +
1134         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X8);
1135     aux_out->num_dct16_blocks =
1136         ac_strategy.CountBlocks(AcStrategy::Type::DCT16X16);
1137     aux_out->num_dct16x32_blocks =
1138         ac_strategy.CountBlocks(AcStrategy::Type::DCT16X32) +
1139         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X16);
1140     aux_out->num_dct32_blocks =
1141         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X32);
1142     aux_out->num_dct32x64_blocks =
1143         ac_strategy.CountBlocks(AcStrategy::Type::DCT32X64) +
1144         ac_strategy.CountBlocks(AcStrategy::Type::DCT64X32);
1145     aux_out->num_dct64_blocks =
1146         ac_strategy.CountBlocks(AcStrategy::Type::DCT64X64);
1147   }
1148
1149   // if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(aux_out)) {
1150   if (JXL_DEBUG_AC_STRATEGY && WantDebugOutput(enc_state->cparams)) {
1151     DumpAcStrategy(ac_strategy, enc_state->shared.frame_dim.xsize,
1152                    enc_state->shared.frame_dim.ysize, "ac_strategy", aux_out,
1153                    enc_state->cparams);
1154   }
1155 }
1156
1157 }  // namespace jxl
1158 #endif  // HWY_ONCE