Imported Upstream version 0.9.0
[platform/upstream/libjxl.git] / lib / jxl / quant_weights.cc
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 #include "lib/jxl/quant_weights.h"
6
7 #include <stdio.h>
8 #include <stdlib.h>
9
10 #include <algorithm>
11 #include <cmath>
12 #include <limits>
13 #include <utility>
14
15 #include "lib/jxl/base/bits.h"
16 #include "lib/jxl/base/status.h"
17 #include "lib/jxl/dct_scales.h"
18 #include "lib/jxl/dec_modular.h"
19 #include "lib/jxl/fields.h"
20 #include "lib/jxl/image.h"
21
22 #undef HWY_TARGET_INCLUDE
23 #define HWY_TARGET_INCLUDE "lib/jxl/quant_weights.cc"
24 #include <hwy/foreach_target.h>
25 #include <hwy/highway.h>
26
27 #include "lib/jxl/base/fast_math-inl.h"
28
29 HWY_BEFORE_NAMESPACE();
30 namespace jxl {
31 namespace HWY_NAMESPACE {
32
33 // These templates are not found via ADL.
34 using hwy::HWY_NAMESPACE::Lt;
35 using hwy::HWY_NAMESPACE::MulAdd;
36 using hwy::HWY_NAMESPACE::Sqrt;
37
38 // kQuantWeights[N * N * c + N * y + x] is the relative weight of the (x, y)
39 // coefficient in component c. Higher weights correspond to finer quantization
40 // intervals and more bits spent in encoding.
41
42 static constexpr const float kAlmostZero = 1e-8f;
43
44 void GetQuantWeightsDCT2(const QuantEncoding::DCT2Weights& dct2weights,
45                          float* weights) {
46   for (size_t c = 0; c < 3; c++) {
47     size_t start = c * 64;
48     weights[start] = 0xBAD;
49     weights[start + 1] = weights[start + 8] = dct2weights[c][0];
50     weights[start + 9] = dct2weights[c][1];
51     for (size_t y = 0; y < 2; y++) {
52       for (size_t x = 0; x < 2; x++) {
53         weights[start + y * 8 + x + 2] = dct2weights[c][2];
54         weights[start + (y + 2) * 8 + x] = dct2weights[c][2];
55       }
56     }
57     for (size_t y = 0; y < 2; y++) {
58       for (size_t x = 0; x < 2; x++) {
59         weights[start + (y + 2) * 8 + x + 2] = dct2weights[c][3];
60       }
61     }
62     for (size_t y = 0; y < 4; y++) {
63       for (size_t x = 0; x < 4; x++) {
64         weights[start + y * 8 + x + 4] = dct2weights[c][4];
65         weights[start + (y + 4) * 8 + x] = dct2weights[c][4];
66       }
67     }
68     for (size_t y = 0; y < 4; y++) {
69       for (size_t x = 0; x < 4; x++) {
70         weights[start + (y + 4) * 8 + x + 4] = dct2weights[c][5];
71       }
72     }
73   }
74 }
75
76 void GetQuantWeightsIdentity(const QuantEncoding::IdWeights& idweights,
77                              float* weights) {
78   for (size_t c = 0; c < 3; c++) {
79     for (int i = 0; i < 64; i++) {
80       weights[64 * c + i] = idweights[c][0];
81     }
82     weights[64 * c + 1] = idweights[c][1];
83     weights[64 * c + 8] = idweights[c][1];
84     weights[64 * c + 9] = idweights[c][2];
85   }
86 }
87
88 float Interpolate(float pos, float max, const float* array, size_t len) {
89   float scaled_pos = pos * (len - 1) / max;
90   size_t idx = scaled_pos;
91   JXL_DASSERT(idx + 1 < len);
92   float a = array[idx];
93   float b = array[idx + 1];
94   return a * FastPowf(b / a, scaled_pos - idx);
95 }
96
// Maps a signed parameter to a positive multiplier: 1 + v for positive v,
// and 1 / (1 - v) otherwise (so Mult(0) == 1 and negative v shrinks).
float Mult(float v) {
  return v > 0.0f ? 1.0f + v : 1.0f / (1.0f - v);
}
101
102 using DF4 = HWY_CAPPED(float, 4);
103
104 hwy::HWY_NAMESPACE::Vec<DF4> InterpolateVec(
105     hwy::HWY_NAMESPACE::Vec<DF4> scaled_pos, const float* array) {
106   HWY_CAPPED(int32_t, 4) di;
107
108   auto idx = ConvertTo(di, scaled_pos);
109
110   auto frac = Sub(scaled_pos, ConvertTo(DF4(), idx));
111
112   // TODO(veluca): in theory, this could be done with 8 TableLookupBytes, but
113   // it's probably slower.
114   auto a = GatherIndex(DF4(), array, idx);
115   auto b = GatherIndex(DF4(), array + 1, idx);
116
117   return Mul(a, FastPowf(DF4(), Div(b, a), frac));
118 }
119
120 // Computes quant weights for a COLS*ROWS-sized transform, using num_bands
121 // eccentricity bands and num_ebands eccentricity bands. If print_mode is 1,
122 // prints the resulting matrix; if print_mode is 2, prints the matrix in a
123 // format suitable for a 3d plot with gnuplot.
124 Status GetQuantWeights(
125     size_t ROWS, size_t COLS,
126     const DctQuantWeightParams::DistanceBandsArray& distance_bands,
127     size_t num_bands, float* out) {
128   for (size_t c = 0; c < 3; c++) {
129     float bands[DctQuantWeightParams::kMaxDistanceBands] = {
130         distance_bands[c][0]};
131     if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
132     for (size_t i = 1; i < num_bands; i++) {
133       bands[i] = bands[i - 1] * Mult(distance_bands[c][i]);
134       if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid distance bands");
135     }
136     float scale = (num_bands - 1) / (kSqrt2 + 1e-6f);
137     float rcpcol = scale / (COLS - 1);
138     float rcprow = scale / (ROWS - 1);
139     JXL_ASSERT(COLS >= Lanes(DF4()));
140     HWY_ALIGN float l0123[4] = {0, 1, 2, 3};
141     for (uint32_t y = 0; y < ROWS; y++) {
142       float dy = y * rcprow;
143       float dy2 = dy * dy;
144       for (uint32_t x = 0; x < COLS; x += Lanes(DF4())) {
145         auto dx =
146             Mul(Add(Set(DF4(), x), Load(DF4(), l0123)), Set(DF4(), rcpcol));
147         auto scaled_distance = Sqrt(MulAdd(dx, dx, Set(DF4(), dy2)));
148         auto weight = num_bands == 1 ? Set(DF4(), bands[0])
149                                      : InterpolateVec(scaled_distance, bands);
150         StoreU(weight, DF4(), out + c * COLS * ROWS + y * COLS + x);
151       }
152     }
153   }
154   return true;
155 }
156
157 // TODO(veluca): SIMD-fy. With 256x256, this is actually slow.
158 Status ComputeQuantTable(const QuantEncoding& encoding,
159                          float* JXL_RESTRICT table,
160                          float* JXL_RESTRICT inv_table, size_t table_num,
161                          DequantMatrices::QuantTable kind, size_t* pos) {
162   constexpr size_t N = kBlockDim;
163   size_t wrows = 8 * DequantMatrices::required_size_x[kind],
164          wcols = 8 * DequantMatrices::required_size_y[kind];
165   size_t num = wrows * wcols;
166
167   std::vector<float> weights(3 * num);
168
169   switch (encoding.mode) {
170     case QuantEncoding::kQuantModeLibrary: {
171       // Library and copy quant encoding should get replaced by the actual
172       // parameters by the caller.
173       JXL_ASSERT(false);
174       break;
175     }
176     case QuantEncoding::kQuantModeID: {
177       JXL_ASSERT(num == kDCTBlockSize);
178       GetQuantWeightsIdentity(encoding.idweights, weights.data());
179       break;
180     }
181     case QuantEncoding::kQuantModeDCT2: {
182       JXL_ASSERT(num == kDCTBlockSize);
183       GetQuantWeightsDCT2(encoding.dct2weights, weights.data());
184       break;
185     }
186     case QuantEncoding::kQuantModeDCT4: {
187       JXL_ASSERT(num == kDCTBlockSize);
188       float weights4x4[3 * 4 * 4];
189       // Always use 4x4 GetQuantWeights for DCT4 quantization tables.
190       JXL_RETURN_IF_ERROR(
191           GetQuantWeights(4, 4, encoding.dct_params.distance_bands,
192                           encoding.dct_params.num_distance_bands, weights4x4));
193       for (size_t c = 0; c < 3; c++) {
194         for (size_t y = 0; y < kBlockDim; y++) {
195           for (size_t x = 0; x < kBlockDim; x++) {
196             weights[c * num + y * kBlockDim + x] =
197                 weights4x4[c * 16 + (y / 2) * 4 + (x / 2)];
198           }
199         }
200         weights[c * num + 1] /= encoding.dct4multipliers[c][0];
201         weights[c * num + N] /= encoding.dct4multipliers[c][0];
202         weights[c * num + N + 1] /= encoding.dct4multipliers[c][1];
203       }
204       break;
205     }
206     case QuantEncoding::kQuantModeDCT4X8: {
207       JXL_ASSERT(num == kDCTBlockSize);
208       float weights4x8[3 * 4 * 8];
209       // Always use 4x8 GetQuantWeights for DCT4X8 quantization tables.
210       JXL_RETURN_IF_ERROR(
211           GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
212                           encoding.dct_params.num_distance_bands, weights4x8));
213       for (size_t c = 0; c < 3; c++) {
214         for (size_t y = 0; y < kBlockDim; y++) {
215           for (size_t x = 0; x < kBlockDim; x++) {
216             weights[c * num + y * kBlockDim + x] =
217                 weights4x8[c * 32 + (y / 2) * 8 + x];
218           }
219         }
220         weights[c * num + N] /= encoding.dct4x8multipliers[c];
221       }
222       break;
223     }
224     case QuantEncoding::kQuantModeDCT: {
225       JXL_RETURN_IF_ERROR(GetQuantWeights(
226           wrows, wcols, encoding.dct_params.distance_bands,
227           encoding.dct_params.num_distance_bands, weights.data()));
228       break;
229     }
230     case QuantEncoding::kQuantModeRAW: {
231       if (!encoding.qraw.qtable || encoding.qraw.qtable->size() != 3 * num) {
232         return JXL_FAILURE("Invalid table encoding");
233       }
234       for (size_t i = 0; i < 3 * num; i++) {
235         weights[i] =
236             1.f / (encoding.qraw.qtable_den * (*encoding.qraw.qtable)[i]);
237       }
238       break;
239     }
240     case QuantEncoding::kQuantModeAFV: {
241       constexpr float kFreqs[] = {
242           0xBAD,
243           0xBAD,
244           0.8517778890324296,
245           5.37778436506804,
246           0xBAD,
247           0xBAD,
248           4.734747904497923,
249           5.449245381693219,
250           1.6598270267479331,
251           4,
252           7.275749096817861,
253           10.423227632456525,
254           2.662932286148962,
255           7.630657783650829,
256           8.962388608184032,
257           12.97166202570235,
258       };
259
260       float weights4x8[3 * 4 * 8];
261       JXL_RETURN_IF_ERROR((
262           GetQuantWeights(4, 8, encoding.dct_params.distance_bands,
263                           encoding.dct_params.num_distance_bands, weights4x8)));
264       float weights4x4[3 * 4 * 4];
265       JXL_RETURN_IF_ERROR((GetQuantWeights(
266           4, 4, encoding.dct_params_afv_4x4.distance_bands,
267           encoding.dct_params_afv_4x4.num_distance_bands, weights4x4)));
268
269       constexpr float lo = 0.8517778890324296;
270       constexpr float hi = 12.97166202570235f - lo + 1e-6f;
271       for (size_t c = 0; c < 3; c++) {
272         float bands[4];
273         bands[0] = encoding.afv_weights[c][5];
274         if (bands[0] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
275         for (size_t i = 1; i < 4; i++) {
276           bands[i] = bands[i - 1] * Mult(encoding.afv_weights[c][i + 5]);
277           if (bands[i] < kAlmostZero) return JXL_FAILURE("Invalid AFV bands");
278         }
279         size_t start = c * 64;
280         auto set_weight = [&start, &weights](size_t x, size_t y, float val) {
281           weights[start + y * 8 + x] = val;
282         };
283         weights[start] = 1;  // Not used, but causes MSAN error otherwise.
284         // Weights for (0, 1) and (1, 0).
285         set_weight(0, 1, encoding.afv_weights[c][0]);
286         set_weight(1, 0, encoding.afv_weights[c][1]);
287         // AFV special weights for 3-pixel corner.
288         set_weight(0, 2, encoding.afv_weights[c][2]);
289         set_weight(2, 0, encoding.afv_weights[c][3]);
290         set_weight(2, 2, encoding.afv_weights[c][4]);
291
292         // All other AFV weights.
293         for (size_t y = 0; y < 4; y++) {
294           for (size_t x = 0; x < 4; x++) {
295             if (x < 2 && y < 2) continue;
296             float val = Interpolate(kFreqs[y * 4 + x] - lo, hi, bands, 4);
297             set_weight(2 * x, 2 * y, val);
298           }
299         }
300
301         // Put 4x8 weights in odd rows, except (1, 0).
302         for (size_t y = 0; y < kBlockDim / 2; y++) {
303           for (size_t x = 0; x < kBlockDim; x++) {
304             if (x == 0 && y == 0) continue;
305             weights[c * num + (2 * y + 1) * kBlockDim + x] =
306                 weights4x8[c * 32 + y * 8 + x];
307           }
308         }
309         // Put 4x4 weights in even rows / odd columns, except (0, 1).
310         for (size_t y = 0; y < kBlockDim / 2; y++) {
311           for (size_t x = 0; x < kBlockDim / 2; x++) {
312             if (x == 0 && y == 0) continue;
313             weights[c * num + (2 * y) * kBlockDim + 2 * x + 1] =
314                 weights4x4[c * 16 + y * 4 + x];
315           }
316         }
317       }
318       break;
319     }
320   }
321   size_t prev_pos = *pos;
322   HWY_CAPPED(float, 64) d;
323   for (size_t i = 0; i < num * 3; i += Lanes(d)) {
324     auto inv_val = LoadU(d, weights.data() + i);
325     if (JXL_UNLIKELY(!AllFalse(d, Ge(inv_val, Set(d, 1.0f / kAlmostZero))) ||
326                      !AllFalse(d, Lt(inv_val, Set(d, kAlmostZero))))) {
327       return JXL_FAILURE("Invalid quantization table");
328     }
329     auto val = Div(Set(d, 1.0f), inv_val);
330     StoreU(val, d, table + *pos + i);
331     StoreU(inv_val, d, inv_table + *pos + i);
332   }
333   (*pos) += 3 * num;
334
335   // Ensure that the lowest frequencies have a 0 inverse table.
336   // This does not affect en/decoding, but allows AC strategy selection to be
337   // slightly simpler.
338   size_t xs = DequantMatrices::required_size_x[kind];
339   size_t ys = DequantMatrices::required_size_y[kind];
340   CoefficientLayout(&ys, &xs);
341   for (size_t c = 0; c < 3; c++) {
342     for (size_t y = 0; y < ys; y++) {
343       for (size_t x = 0; x < xs; x++) {
344         inv_table[prev_pos + c * ys * xs * kDCTBlockSize + y * kBlockDim * xs +
345                   x] = 0;
346       }
347     }
348   }
349   return true;
350 }
351
352 // NOLINTNEXTLINE(google-readability-namespace-comments)
353 }  // namespace HWY_NAMESPACE
354 }  // namespace jxl
355 HWY_AFTER_NAMESPACE();
356
357 #if HWY_ONCE
358
359 namespace jxl {
360 namespace {
361
362 HWY_EXPORT(ComputeQuantTable);
363
364 static constexpr const float kAlmostZero = 1e-8f;
365
366 Status DecodeDctParams(BitReader* br, DctQuantWeightParams* params) {
367   params->num_distance_bands =
368       br->ReadFixedBits<DctQuantWeightParams::kLog2MaxDistanceBands>() + 1;
369   for (size_t c = 0; c < 3; c++) {
370     for (size_t i = 0; i < params->num_distance_bands; i++) {
371       JXL_RETURN_IF_ERROR(F16Coder::Read(br, &params->distance_bands[c][i]));
372     }
373     if (params->distance_bands[c][0] < kAlmostZero) {
374       return JXL_FAILURE("Distance band seed is too small");
375     }
376     params->distance_bands[c][0] *= 64.0f;
377   }
378   return true;
379 }
380
381 Status Decode(BitReader* br, QuantEncoding* encoding, size_t required_size_x,
382               size_t required_size_y, size_t idx,
383               ModularFrameDecoder* modular_frame_decoder) {
384   size_t required_size = required_size_x * required_size_y;
385   required_size_x *= kBlockDim;
386   required_size_y *= kBlockDim;
387   int mode = br->ReadFixedBits<kLog2NumQuantModes>();
388   switch (mode) {
389     case QuantEncoding::kQuantModeLibrary: {
390       encoding->predefined = br->ReadFixedBits<kCeilLog2NumPredefinedTables>();
391       if (encoding->predefined >= kNumPredefinedTables) {
392         return JXL_FAILURE("Invalid predefined table");
393       }
394       break;
395     }
396     case QuantEncoding::kQuantModeID: {
397       if (required_size != 1) return JXL_FAILURE("Invalid mode");
398       for (size_t c = 0; c < 3; c++) {
399         for (size_t i = 0; i < 3; i++) {
400           JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->idweights[c][i]));
401           if (std::abs(encoding->idweights[c][i]) < kAlmostZero) {
402             return JXL_FAILURE("ID Quantizer is too small");
403           }
404           encoding->idweights[c][i] *= 64;
405         }
406       }
407       break;
408     }
409     case QuantEncoding::kQuantModeDCT2: {
410       if (required_size != 1) return JXL_FAILURE("Invalid mode");
411       for (size_t c = 0; c < 3; c++) {
412         for (size_t i = 0; i < 6; i++) {
413           JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->dct2weights[c][i]));
414           if (std::abs(encoding->dct2weights[c][i]) < kAlmostZero) {
415             return JXL_FAILURE("Quantizer is too small");
416           }
417           encoding->dct2weights[c][i] *= 64;
418         }
419       }
420       break;
421     }
422     case QuantEncoding::kQuantModeDCT4X8: {
423       if (required_size != 1) return JXL_FAILURE("Invalid mode");
424       for (size_t c = 0; c < 3; c++) {
425         JXL_RETURN_IF_ERROR(
426             F16Coder::Read(br, &encoding->dct4x8multipliers[c]));
427         if (std::abs(encoding->dct4x8multipliers[c]) < kAlmostZero) {
428           return JXL_FAILURE("DCT4X8 multiplier is too small");
429         }
430       }
431       JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
432       break;
433     }
434     case QuantEncoding::kQuantModeDCT4: {
435       if (required_size != 1) return JXL_FAILURE("Invalid mode");
436       for (size_t c = 0; c < 3; c++) {
437         for (size_t i = 0; i < 2; i++) {
438           JXL_RETURN_IF_ERROR(
439               F16Coder::Read(br, &encoding->dct4multipliers[c][i]));
440           if (std::abs(encoding->dct4multipliers[c][i]) < kAlmostZero) {
441             return JXL_FAILURE("DCT4 multiplier is too small");
442           }
443         }
444       }
445       JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
446       break;
447     }
448     case QuantEncoding::kQuantModeAFV: {
449       if (required_size != 1) return JXL_FAILURE("Invalid mode");
450       for (size_t c = 0; c < 3; c++) {
451         for (size_t i = 0; i < 9; i++) {
452           JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->afv_weights[c][i]));
453         }
454         for (size_t i = 0; i < 6; i++) {
455           encoding->afv_weights[c][i] *= 64;
456         }
457       }
458       JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
459       JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params_afv_4x4));
460       break;
461     }
462     case QuantEncoding::kQuantModeDCT: {
463       JXL_RETURN_IF_ERROR(DecodeDctParams(br, &encoding->dct_params));
464       break;
465     }
466     case QuantEncoding::kQuantModeRAW: {
467       // Set mode early, to avoid mem-leak.
468       encoding->mode = QuantEncoding::kQuantModeRAW;
469       JXL_RETURN_IF_ERROR(ModularFrameDecoder::DecodeQuantTable(
470           required_size_x, required_size_y, br, encoding, idx,
471           modular_frame_decoder));
472       break;
473     }
474     default:
475       return JXL_FAILURE("Invalid quantization table encoding");
476   }
477   encoding->mode = QuantEncoding::Mode(mode);
478   return true;
479 }
480
481 }  // namespace
482
483 // These definitions are needed before C++17.
484 constexpr size_t DequantMatrices::required_size_[];
485 constexpr size_t DequantMatrices::required_size_x[];
486 constexpr size_t DequantMatrices::required_size_y[];
487 constexpr DequantMatrices::QuantTable DequantMatrices::kQuantTable[];
488
489 Status DequantMatrices::Decode(BitReader* br,
490                                ModularFrameDecoder* modular_frame_decoder) {
491   size_t all_default = br->ReadBits(1);
492   size_t num_tables = all_default ? 0 : static_cast<size_t>(kNum);
493   encodings_.clear();
494   encodings_.resize(kNum, QuantEncoding::Library(0));
495   for (size_t i = 0; i < num_tables; i++) {
496     JXL_RETURN_IF_ERROR(
497         jxl::Decode(br, &encodings_[i], required_size_x[i % kNum],
498                     required_size_y[i % kNum], i, modular_frame_decoder));
499   }
500   computed_mask_ = 0;
501   return true;
502 }
503
504 Status DequantMatrices::DecodeDC(BitReader* br) {
505   bool all_default = br->ReadBits(1);
506   if (!br->AllReadsWithinBounds()) return JXL_FAILURE("EOS during DecodeDC");
507   if (!all_default) {
508     for (size_t c = 0; c < 3; c++) {
509       JXL_RETURN_IF_ERROR(F16Coder::Read(br, &dc_quant_[c]));
510       dc_quant_[c] *= 1.0f / 128.0f;
511       // Negative values and nearly zero are invalid values.
512       if (dc_quant_[c] < kAlmostZero) {
513         return JXL_FAILURE("Invalid dc_quant: coefficient is too small.");
514       }
515       inv_dc_quant_[c] = 1.0f / dc_quant_[c];
516     }
517   }
518   return true;
519 }
520
// Converts a (typically double) literal to float at the call site; the
// conversion happens when the argument is bound to the float parameter.
constexpr float V(float v) { return v; }
522
523 namespace {
524 struct DequantMatricesLibraryDef {
525   // DCT8
526   static constexpr QuantEncodingInternal DCT() {
527     return QuantEncodingInternal::DCT(DctQuantWeightParams({{{{
528                                                                  V(3150.0),
529                                                                  V(0.0),
530                                                                  V(-0.4),
531                                                                  V(-0.4),
532                                                                  V(-0.4),
533                                                                  V(-2.0),
534                                                              }},
535                                                              {{
536                                                                  V(560.0),
537                                                                  V(0.0),
538                                                                  V(-0.3),
539                                                                  V(-0.3),
540                                                                  V(-0.3),
541                                                                  V(-0.3),
542                                                              }},
543                                                              {{
544                                                                  V(512.0),
545                                                                  V(-2.0),
546                                                                  V(-1.0),
547                                                                  V(0.0),
548                                                                  V(-1.0),
549                                                                  V(-2.0),
550                                                              }}}},
551                                                            6));
552   }
553
554   // Identity
555   static constexpr QuantEncodingInternal IDENTITY() {
556     return QuantEncodingInternal::Identity({{{{
557                                                  V(280.0),
558                                                  V(3160.0),
559                                                  V(3160.0),
560                                              }},
561                                              {{
562                                                  V(60.0),
563                                                  V(864.0),
564                                                  V(864.0),
565                                              }},
566                                              {{
567                                                  V(18.0),
568                                                  V(200.0),
569                                                  V(200.0),
570                                              }}}});
571   }
572
573   // DCT2
574   static constexpr QuantEncodingInternal DCT2X2() {
575     return QuantEncodingInternal::DCT2({{{{
576                                              V(3840.0),
577                                              V(2560.0),
578                                              V(1280.0),
579                                              V(640.0),
580                                              V(480.0),
581                                              V(300.0),
582                                          }},
583                                          {{
584                                              V(960.0),
585                                              V(640.0),
586                                              V(320.0),
587                                              V(180.0),
588                                              V(140.0),
589                                              V(120.0),
590                                          }},
591                                          {{
592                                              V(640.0),
593                                              V(320.0),
594                                              V(128.0),
595                                              V(64.0),
596                                              V(32.0),
597                                              V(16.0),
598                                          }}}});
599   }
600
601   // DCT4 (quant_kind 3)
602   static constexpr QuantEncodingInternal DCT4X4() {
603     return QuantEncodingInternal::DCT4(DctQuantWeightParams({{{{
604                                                                   V(2200.0),
605                                                                   V(0.0),
606                                                                   V(0.0),
607                                                                   V(0.0),
608                                                               }},
609                                                               {{
610                                                                   V(392.0),
611                                                                   V(0.0),
612                                                                   V(0.0),
613                                                                   V(0.0),
614                                                               }},
615                                                               {{
616                                                                   V(112.0),
617                                                                   V(-0.25),
618                                                                   V(-0.25),
619                                                                   V(-0.5),
620                                                               }}}},
621                                                             4),
622                                        /* kMul */
623                                        {{{{
624                                              V(1.0),
625                                              V(1.0),
626                                          }},
627                                          {{
628                                              V(1.0),
629                                              V(1.0),
630                                          }},
631                                          {{
632                                              V(1.0),
633                                              V(1.0),
634                                          }}}});
635   }
636
637   // DCT16
638   static constexpr QuantEncodingInternal DCT16X16() {
639     return QuantEncodingInternal::DCT(
640         DctQuantWeightParams({{{{
641                                    V(8996.8725711814115328),
642                                    V(-1.3000777393353804),
643                                    V(-0.49424529824571225),
644                                    V(-0.439093774457103443),
645                                    V(-0.6350101832695744),
646                                    V(-0.90177264050827612),
647                                    V(-1.6162099239887414),
648                                }},
649                                {{
650                                    V(3191.48366296844234752),
651                                    V(-0.67424582104194355),
652                                    V(-0.80745813428471001),
653                                    V(-0.44925837484843441),
654                                    V(-0.35865440981033403),
655                                    V(-0.31322389111877305),
656                                    V(-0.37615025315725483),
657                                }},
658                                {{
659                                    V(1157.50408145487200256),
660                                    V(-2.0531423165804414),
661                                    V(-1.4),
662                                    V(-0.50687130033378396),
663                                    V(-0.42708730624733904),
664                                    V(-1.4856834539296244),
665                                    V(-4.9209142884401604),
666                                }}}},
667                              7));
668   }
669
670   // DCT32
671   static constexpr QuantEncodingInternal DCT32X32() {
672     return QuantEncodingInternal::DCT(
673         DctQuantWeightParams({{{{
674                                    V(15718.40830982518931456),
675                                    V(-1.025),
676                                    V(-0.98),
677                                    V(-0.9012),
678                                    V(-0.4),
679                                    V(-0.48819395464),
680                                    V(-0.421064),
681                                    V(-0.27),
682                                }},
683                                {{
684                                    V(7305.7636810695983104),
685                                    V(-0.8041958212306401),
686                                    V(-0.7633036457487539),
687                                    V(-0.55660379990111464),
688                                    V(-0.49785304658857626),
689                                    V(-0.43699592683512467),
690                                    V(-0.40180866526242109),
691                                    V(-0.27321683125358037),
692                                }},
693                                {{
694                                    V(3803.53173721215041536),
695                                    V(-3.060733579805728),
696                                    V(-2.0413270132490346),
697                                    V(-2.0235650159727417),
698                                    V(-0.5495389509954993),
699                                    V(-0.4),
700                                    V(-0.4),
701                                    V(-0.3),
702                                }}}},
703                              8));
704   }
705
706   // DCT16X8
      // Builds the default library encoding for the DCT8X16 table kind by
      // passing three groups of 7 values to QuantEncodingInternal::DCT; the
      // trailing 7 matches the number of values per group.
      // NOTE(review): the heading says 16X8 while the function is named
      // DCT8X16; presumably one table serves both orientations -- confirm.
      // NOTE(review): the three groups presumably correspond to the three
      // components (c = 0..2) that the weight tables are indexed by -- confirm.
707   static constexpr QuantEncodingInternal DCT8X16() {
708     return QuantEncodingInternal::DCT(
709         DctQuantWeightParams({{{{
710                                    V(7240.7734393502),
711                                    V(-0.7),
712                                    V(-0.7),
713                                    V(-0.2),
714                                    V(-0.2),
715                                    V(-0.2),
716                                    V(-0.5),
717                                }},
718                                {{
719                                    V(1448.15468787004),
720                                    V(-0.5),
721                                    V(-0.5),
722                                    V(-0.5),
723                                    V(-0.2),
724                                    V(-0.2),
725                                    V(-0.2),
726                                }},
727                                {{
728                                    V(506.854140754517),
729                                    V(-1.4),
730                                    V(-0.2),
731                                    V(-0.5),
732                                    V(-0.5),
733                                    V(-1.5),
734                                    V(-3.6),
735                                }}}},
736                              7));
737   }
738
739   // DCT32X8
      // Builds the default library encoding for the DCT8X32 table kind by
      // passing three groups of 8 values to QuantEncodingInternal::DCT; the
      // trailing 8 matches the number of values per group.
      // NOTE(review): the heading says 32X8 while the function is named
      // DCT8X32; presumably one table serves both orientations -- confirm.
740   static constexpr QuantEncodingInternal DCT8X32() {
741     return QuantEncodingInternal::DCT(
742         DctQuantWeightParams({{{{
743                                    V(16283.2494710648897),
744                                    V(-1.7812845336559429),
745                                    V(-1.6309059012653515),
746                                    V(-1.0382179034313539),
747                                    V(-0.85),
748                                    V(-0.7),
749                                    V(-0.9),
750                                    V(-1.2360638576849587),
751                                }},
752                                {{
753                                    V(5089.15750884921511936),
754                                    V(-0.320049391452786891),
755                                    V(-0.35362849922161446),
756                                    V(-0.30340000000000003),
757                                    V(-0.61),
758                                    V(-0.5),
759                                    V(-0.5),
760                                    V(-0.6),
761                                }},
762                                {{
763                                    V(3397.77603275308720128),
764                                    V(-0.321327362693153371),
765                                    V(-0.34507619223117997),
766                                    V(-0.70340000000000003),
767                                    V(-0.9),
768                                    V(-1.0),
769                                    V(-1.0),
770                                    V(-1.1754605576265209),
771                                }}}},
772                              8));
773   }
774
775   // DCT32X16
      // Builds the default library encoding for the DCT16X32 table kind by
      // passing three groups of 8 values to QuantEncodingInternal::DCT; the
      // trailing 8 matches the number of values per group.
      // NOTE(review): the heading says 32X16 while the function is named
      // DCT16X32; presumably one table serves both orientations -- confirm.
776   static constexpr QuantEncodingInternal DCT16X32() {
777     return QuantEncodingInternal::DCT(
778         DctQuantWeightParams({{{{
779                                    V(13844.97076442300573),
780                                    V(-0.97113799999999995),
781                                    V(-0.658),
782                                    V(-0.42026),
783                                    V(-0.22712),
784                                    V(-0.2206),
785                                    V(-0.226),
786                                    V(-0.6),
787                                }},
788                                {{
789                                    V(4798.964084220744293),
790                                    V(-0.61125308982767057),
791                                    V(-0.83770786552491361),
792                                    V(-0.79014862079498627),
793                                    V(-0.2692727459704829),
794                                    V(-0.38272769465388551),
795                                    V(-0.22924222653091453),
796                                    V(-0.20719098826199578),
797                                }},
798                                {{
799                                    V(1807.236946760964614),
800                                    V(-1.2),
801                                    V(-1.2),
802                                    V(-0.7),
803                                    V(-0.7),
804                                    V(-0.7),
805                                    V(-0.4),
806                                    V(-0.5),
807                                }}}},
808                              8));
809   }
810
811   // DCT4X8 and 8x4
      // Builds the default library encoding for the DCT4X8 table kind via
      // QuantEncodingInternal::DCT4X8. The first argument carries three groups
      // of 4 values (trailing count 4); the second argument, labelled kMuls
      // below, is a triple that is 1.0 in every position here.
      // NOTE(review): how the callee applies the kMuls triple is not visible
      // from this function -- confirm against QuantEncodingInternal::DCT4X8.
812   static constexpr QuantEncodingInternal DCT4X8() {
813     return QuantEncodingInternal::DCT4X8(
814         DctQuantWeightParams({{
815                                  {{
816                                      V(2198.050556016380522),
817                                      V(-0.96269623020744692),
818                                      V(-0.76194253026666783),
819                                      V(-0.6551140670773547),
820                                  }},
821                                  {{
822                                      V(764.3655248643528689),
823                                      V(-0.92630200888366945),
824                                      V(-0.9675229603596517),
825                                      V(-0.27845290869168118),
826                                  }},
827                                  {{
828                                      V(527.107573587542228),
829                                      V(-1.4594385811273854),
830                                      V(-1.450082094097871593),
831                                      V(-1.5843722511996204),
832                                  }},
833                              }},
834                              4),
835         /* kMuls */
836         {{
837             V(1.0),
838             V(1.0),
839             V(1.0),
840         }});
841   }
842   // AFV
      // Builds the default library encoding for the AFV0 table kind via
      // QuantEncodingInternal::AFV. It reuses the dct_params of the DCT4X8()
      // and DCT4X4() defaults and adds three groups of 9 AFV-specific values.
      // Within each group the existing labels split the values as 2 ("DC
      // tendency") + 3 ("AFV corner") + 4 ("AFV high freqs").
      // Unlike the smaller DCT builders above, this function is not constexpr.
843   static QuantEncodingInternal AFV0() {
844     return QuantEncodingInternal::AFV(DCT4X8().dct_params, DCT4X4().dct_params,
845                                       {{{{
846                                             // 4x4/4x8 DC tendency.
847                                             V(3072.0),
848                                             V(3072.0),
849                                             // AFV corner.
850                                             V(256.0),
851                                             V(256.0),
852                                             V(256.0),
853                                             // AFV high freqs.
854                                             V(414.0),
855                                             V(0.0),
856                                             V(0.0),
857                                             V(0.0),
858                                         }},
859                                         {{
860                                             // 4x4/4x8 DC tendency.
861                                             V(1024.0),
862                                             V(1024.0),
863                                             // AFV corner.
864                                             V(50),
865                                             V(50),
866                                             V(50),
867                                             // AFV high freqs.
868                                             V(58.0),
869                                             V(0.0),
870                                             V(0.0),
871                                             V(0.0),
872                                         }},
873                                         {{
874                                             // 4x4/4x8 DC tendency.
875                                             V(384.0),
876                                             V(384.0),
877                                             // AFV corner.
878                                             V(12.0),
879                                             V(12.0),
880                                             V(12.0),
881                                             // AFV high freqs.
882                                             V(22.0),
883                                             V(-0.25),
884                                             V(-0.25),
885                                             V(-0.25),
886                                         }}}});
887   }
888
889   // DCT64
      // Builds the default library encoding for the DCT64X64 table kind by
      // passing three groups of 8 values to QuantEncodingInternal::DCT
      // (trailing count 8). The leading value of each group is written as
      // 0.9 * <base>; DCT128X128() and DCT256X256() use the same bases and
      // remaining values with multipliers 1.8 and 3.6 respectively.
890   static QuantEncodingInternal DCT64X64() {
891     return QuantEncodingInternal::DCT(
892         DctQuantWeightParams({{{{
893                                    V(0.9 * 26629.073922049845),
894                                    V(-1.025),
895                                    V(-0.78),
896                                    V(-0.65012),
897                                    V(-0.19041574084286472),
898                                    V(-0.20819395464),
899                                    V(-0.421064),
900                                    V(-0.32733845535848671),
901                                }},
902                                {{
903                                    V(0.9 * 9311.3238710010046),
904                                    V(-0.3041958212306401),
905                                    V(-0.3633036457487539),
906                                    V(-0.35660379990111464),
907                                    V(-0.3443074455424403),
908                                    V(-0.33699592683512467),
909                                    V(-0.30180866526242109),
910                                    V(-0.27321683125358037),
911                                }},
912                                {{
913                                    V(0.9 * 4992.2486445538634),
914                                    V(-1.2),
915                                    V(-1.2),
916                                    V(-0.8),
917                                    V(-0.7),
918                                    V(-0.7),
919                                    V(-0.4),
920                                    V(-0.5),
921                                }}}},
922                              8));
923   }
924
925   // DCT64X32
      // Builds the default library encoding for the DCT32X64 table kind by
      // passing three groups of 8 values to QuantEncodingInternal::DCT
      // (trailing count 8). The leading value of each group is written as
      // 0.65 * <base>; DCT64X128() and DCT128X256() use the same bases and
      // remaining values with multipliers 1.3 and 2.6 respectively.
      // NOTE(review): the heading says 64X32 while the function is named
      // DCT32X64; presumably one table serves both orientations -- confirm.
926   static QuantEncodingInternal DCT32X64() {
927     return QuantEncodingInternal::DCT(
928         DctQuantWeightParams({{{{
929                                    V(0.65 * 23629.073922049845),
930                                    V(-1.025),
931                                    V(-0.78),
932                                    V(-0.65012),
933                                    V(-0.19041574084286472),
934                                    V(-0.20819395464),
935                                    V(-0.421064),
936                                    V(-0.32733845535848671),
937                                }},
938                                {{
939                                    V(0.65 * 8611.3238710010046),
940                                    V(-0.3041958212306401),
941                                    V(-0.3633036457487539),
942                                    V(-0.35660379990111464),
943                                    V(-0.3443074455424403),
944                                    V(-0.33699592683512467),
945                                    V(-0.30180866526242109),
946                                    V(-0.27321683125358037),
947                                }},
948                                {{
949                                    V(0.65 * 4492.2486445538634),
950                                    V(-1.2),
951                                    V(-1.2),
952                                    V(-0.8),
953                                    V(-0.7),
954                                    V(-0.7),
955                                    V(-0.4),
956                                    V(-0.5),
957                                }}}},
958                              8));
959   }
960   // DCT128X128
      // Builds the default library encoding for the DCT128X128 table kind:
      // three groups of 8 values passed to QuantEncodingInternal::DCT
      // (trailing count 8). The values are those of DCT64X64() except that
      // the leading value of each group uses the multiplier 1.8 instead of 0.9.
961   static QuantEncodingInternal DCT128X128() {
962     return QuantEncodingInternal::DCT(
963         DctQuantWeightParams({{{{
964                                    V(1.8 * 26629.073922049845),
965                                    V(-1.025),
966                                    V(-0.78),
967                                    V(-0.65012),
968                                    V(-0.19041574084286472),
969                                    V(-0.20819395464),
970                                    V(-0.421064),
971                                    V(-0.32733845535848671),
972                                }},
973                                {{
974                                    V(1.8 * 9311.3238710010046),
975                                    V(-0.3041958212306401),
976                                    V(-0.3633036457487539),
977                                    V(-0.35660379990111464),
978                                    V(-0.3443074455424403),
979                                    V(-0.33699592683512467),
980                                    V(-0.30180866526242109),
981                                    V(-0.27321683125358037),
982                                }},
983                                {{
984                                    V(1.8 * 4992.2486445538634),
985                                    V(-1.2),
986                                    V(-1.2),
987                                    V(-0.8),
988                                    V(-0.7),
989                                    V(-0.7),
990                                    V(-0.4),
991                                    V(-0.5),
992                                }}}},
993                              8));
994   }
995
996   // DCT128X64
      // Builds the default library encoding for the DCT64X128 table kind:
      // three groups of 8 values passed to QuantEncodingInternal::DCT
      // (trailing count 8). The values are those of DCT32X64() except that
      // the leading value of each group uses the multiplier 1.3 instead of 0.65.
      // NOTE(review): the heading says 128X64 while the function is named
      // DCT64X128; presumably one table serves both orientations -- confirm.
997   static QuantEncodingInternal DCT64X128() {
998     return QuantEncodingInternal::DCT(
999         DctQuantWeightParams({{{{
1000                                    V(1.3 * 23629.073922049845),
1001                                    V(-1.025),
1002                                    V(-0.78),
1003                                    V(-0.65012),
1004                                    V(-0.19041574084286472),
1005                                    V(-0.20819395464),
1006                                    V(-0.421064),
1007                                    V(-0.32733845535848671),
1008                                }},
1009                                {{
1010                                    V(1.3 * 8611.3238710010046),
1011                                    V(-0.3041958212306401),
1012                                    V(-0.3633036457487539),
1013                                    V(-0.35660379990111464),
1014                                    V(-0.3443074455424403),
1015                                    V(-0.33699592683512467),
1016                                    V(-0.30180866526242109),
1017                                    V(-0.27321683125358037),
1018                                }},
1019                                {{
1020                                    V(1.3 * 4492.2486445538634),
1021                                    V(-1.2),
1022                                    V(-1.2),
1023                                    V(-0.8),
1024                                    V(-0.7),
1025                                    V(-0.7),
1026                                    V(-0.4),
1027                                    V(-0.5),
1028                                }}}},
1029                              8));
1030   }
1031   // DCT256X256
       // Builds the default library encoding for the DCT256X256 table kind:
       // three groups of 8 values passed to QuantEncodingInternal::DCT
       // (trailing count 8). The values are those of DCT64X64() except that
       // the leading value of each group uses the multiplier 3.6 instead of 0.9.
1032   static QuantEncodingInternal DCT256X256() {
1033     return QuantEncodingInternal::DCT(
1034         DctQuantWeightParams({{{{
1035                                    V(3.6 * 26629.073922049845),
1036                                    V(-1.025),
1037                                    V(-0.78),
1038                                    V(-0.65012),
1039                                    V(-0.19041574084286472),
1040                                    V(-0.20819395464),
1041                                    V(-0.421064),
1042                                    V(-0.32733845535848671),
1043                                }},
1044                                {{
1045                                    V(3.6 * 9311.3238710010046),
1046                                    V(-0.3041958212306401),
1047                                    V(-0.3633036457487539),
1048                                    V(-0.35660379990111464),
1049                                    V(-0.3443074455424403),
1050                                    V(-0.33699592683512467),
1051                                    V(-0.30180866526242109),
1052                                    V(-0.27321683125358037),
1053                                }},
1054                                {{
1055                                    V(3.6 * 4992.2486445538634),
1056                                    V(-1.2),
1057                                    V(-1.2),
1058                                    V(-0.8),
1059                                    V(-0.7),
1060                                    V(-0.7),
1061                                    V(-0.4),
1062                                    V(-0.5),
1063                                }}}},
1064                              8));
1065   }
1066
1067   // DCT256X128
       // Builds the default library encoding for the DCT128X256 table kind:
       // three groups of 8 values passed to QuantEncodingInternal::DCT
       // (trailing count 8). The values are those of DCT32X64() except that
       // the leading value of each group uses the multiplier 2.6 instead of 0.65.
       // NOTE(review): the heading says 256X128 while the function is named
       // DCT128X256; presumably one table serves both orientations -- confirm.
1068   static QuantEncodingInternal DCT128X256() {
1069     return QuantEncodingInternal::DCT(
1070         DctQuantWeightParams({{{{
1071                                    V(2.6 * 23629.073922049845),
1072                                    V(-1.025),
1073                                    V(-0.78),
1074                                    V(-0.65012),
1075                                    V(-0.19041574084286472),
1076                                    V(-0.20819395464),
1077                                    V(-0.421064),
1078                                    V(-0.32733845535848671),
1079                                }},
1080                                {{
1081                                    V(2.6 * 8611.3238710010046),
1082                                    V(-0.3041958212306401),
1083                                    V(-0.3633036457487539),
1084                                    V(-0.35660379990111464),
1085                                    V(-0.3443074455424403),
1086                                    V(-0.33699592683512467),
1087                                    V(-0.30180866526242109),
1088                                    V(-0.27321683125358037),
1089                                }},
1090                                {{
1091                                    V(2.6 * 4492.2486445538634),
1092                                    V(-1.2),
1093                                    V(-1.2),
1094                                    V(-0.8),
1095                                    V(-0.7),
1096                                    V(-0.7),
1097                                    V(-0.4),
1098                                    V(-0.5),
1099                                }}}},
1100                              8));
1101   }
1102 };
1103 }  // namespace
1104
1105 DequantMatrices::DequantLibraryInternal DequantMatrices::LibraryInit() {
       // Assembles the library of default encodings: one entry per quantization
       // table kind, produced by the matching DequantMatricesLibraryDef builder.
       // The array position of each entry must equal the numeric value of its
       // table kind; the static_asserts below pin both the total count and each
       // individual index so that a change to the enum fails to compile here.
1106   static_assert(kNum == 17,
1107                 "Update this function when adding new quantization kinds.");
1108   static_assert(kNumPredefinedTables == 1,
1109                 "Update this function when adding new quantization matrices to "
1110                 "the library.");
1111
1112   // The library and the indices need to be kept in sync manually.
1113   static_assert(0 == DCT, "Update the DequantLibrary array below.");
1114   static_assert(1 == IDENTITY, "Update the DequantLibrary array below.");
1115   static_assert(2 == DCT2X2, "Update the DequantLibrary array below.");
1116   static_assert(3 == DCT4X4, "Update the DequantLibrary array below.");
1117   static_assert(4 == DCT16X16, "Update the DequantLibrary array below.");
1118   static_assert(5 == DCT32X32, "Update the DequantLibrary array below.");
1119   static_assert(6 == DCT8X16, "Update the DequantLibrary array below.");
1120   static_assert(7 == DCT8X32, "Update the DequantLibrary array below.");
1121   static_assert(8 == DCT16X32, "Update the DequantLibrary array below.");
1122   static_assert(9 == DCT4X8, "Update the DequantLibrary array below.");
1123   static_assert(10 == AFV0, "Update the DequantLibrary array below.");
1124   static_assert(11 == DCT64X64, "Update the DequantLibrary array below.");
1125   static_assert(12 == DCT32X64, "Update the DequantLibrary array below.");
1126   static_assert(13 == DCT128X128, "Update the DequantLibrary array below.");
1127   static_assert(14 == DCT64X128, "Update the DequantLibrary array below.");
1128   static_assert(15 == DCT256X256, "Update the DequantLibrary array below.");
1129   static_assert(16 == DCT128X256, "Update the DequantLibrary array below.");
       // Entries are listed in index order 0..16, matching the asserts above.
1130   return DequantMatrices::DequantLibraryInternal{{
1131       DequantMatricesLibraryDef::DCT(),
1132       DequantMatricesLibraryDef::IDENTITY(),
1133       DequantMatricesLibraryDef::DCT2X2(),
1134       DequantMatricesLibraryDef::DCT4X4(),
1135       DequantMatricesLibraryDef::DCT16X16(),
1136       DequantMatricesLibraryDef::DCT32X32(),
1137       DequantMatricesLibraryDef::DCT8X16(),
1138       DequantMatricesLibraryDef::DCT8X32(),
1139       DequantMatricesLibraryDef::DCT16X32(),
1140       DequantMatricesLibraryDef::DCT4X8(),
1141       DequantMatricesLibraryDef::AFV0(),
1142       DequantMatricesLibraryDef::DCT64X64(),
1143       DequantMatricesLibraryDef::DCT32X64(),
1144       // Same default for large transforms (128+) as for 64x* transforms.
           // NOTE(review): the 128+ entries call their own builders, whose
           // leading values use a different multiplier than the 64x* builders
           // (e.g. 1.8 * vs 0.9 *); only the remaining values are the same.
1145       DequantMatricesLibraryDef::DCT128X128(),
1146       DequantMatricesLibraryDef::DCT64X128(),
1147       DequantMatricesLibraryDef::DCT256X256(),
1148       DequantMatricesLibraryDef::DCT128X256(),
1149   }};
1150 }
1151
1152 const QuantEncoding* DequantMatrices::Library() {
       // Returns a pointer to the first element of the default-encoding library.
       // The library is a function-local static filled from LibraryInit() the
       // first time this function runs; later calls return the same storage.
1153   static const DequantMatrices::DequantLibraryInternal kDequantLibrary =
1154       DequantMatrices::LibraryInit();
1155   // Downcast the result to a const QuantEncoding* from QuantEncodingInternal*
1156   // since the subclass (QuantEncoding) doesn't add any new members and users
1157   // will need to upcast to QuantEncodingInternal to access the members of that
1158   // class. This allows to have kDequantLibrary as a constexpr value while still
1159   // allowing to create QuantEncoding::RAW() instances that use std::vector in
1160   // C++11.
       // NOTE(review): kDequantLibrary is declared `static const` and is
       // initialised by a function call here, so the "constexpr value" wording
       // above may be outdated -- confirm.
       // The cast below is only valid while QuantEncoding adds no data members
       // to QuantEncodingInternal, as the comment above states.
1161   return reinterpret_cast<const QuantEncoding*>(kDequantLibrary.data());
1162 }
1163
1164 DequantMatrices::DequantMatrices() {
       // Initialises every table kind to QuantEncoding::Library(0) and
       // precomputes, for each AC strategy and channel, the offset of its
       // weights inside the flat table. No table memory is allocated here.
1165   encodings_.resize(size_t(QuantTable::kNum), QuantEncoding::Library(0));
       // Running offset (in floats) of the next table kind in the flat layout.
1166   size_t pos = 0;
       // offsets[3 * kind + c] = start of channel c for table kind `kind`.
1167   size_t offsets[kNum * 3];
1168   for (size_t i = 0; i < size_t(QuantTable::kNum); i++) {
         // Per-channel size of this kind: required_size_[i] blocks of
         // kDCTBlockSize entries each.
1169     size_t num = required_size_[i] * kDCTBlockSize;
1170     for (size_t c = 0; c < 3; c++) {
1171       offsets[3 * i + c] = pos + c * num;
1172     }
         // The three channels of one kind are stored back to back.
1173     pos += 3 * num;
1174   }
       // Map each AC strategy to the offsets of the table kind it uses
       // (kQuantTable[i]); several strategies may share one kind.
1175   for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
1176     for (size_t c = 0; c < 3; c++) {
1177       table_offsets_[i * 3 + c] = offsets[kQuantTable[i] * 3 + c];
1178     }
1179   }
1180 }
1181
1182 Status DequantMatrices::EnsureComputed(uint32_t acs_mask) {
       // Computes the quantization tables needed by the AC strategies selected
       // in `acs_mask` (bit i selects strategy i), skipping table kinds that
       // were already computed by an earlier call. Returns true on success; a
       // failure while computing a non-library encoding is passed to
       // JXL_RETURN_IF_ERROR.
1183   const QuantEncoding* library = Library();
1184
       // Lazily allocate one buffer of 2 * kTotalTableSize floats on first use:
       // table_ points at the first half and inv_table_ at the second half.
       // NOTE(review): the pointer returned by hwy::AllocateAligned is not
       // checked before use -- confirm whether it can be null on failure.
1185   if (!table_storage_) {
1186     table_storage_ = hwy::AllocateAligned<float>(2 * kTotalTableSize);
1187     table_ = table_storage_.get();
1188     inv_table_ = table_storage_.get() + kTotalTableSize;
1189   }
1190
       // Rebuild the per-kind, per-channel offsets. The extra final slot holds
       // the total size so that offsets[3 * table + 3] is valid for the last
       // table kind as well.
1191   size_t offsets[kNum * 3 + 1];
1192   size_t pos = 0;
1193   for (size_t i = 0; i < kNum; i++) {
1194     size_t num = required_size_[i] * kDCTBlockSize;
1195     for (size_t c = 0; c < 3; c++) {
1196       offsets[3 * i + c] = pos + c * num;
1197     }
1198     pos += 3 * num;
1199   }
1200   offsets[kNum * 3] = pos;
1201   JXL_ASSERT(pos == kTotalTableSize);
1202
       // Translate the requested strategy bits into table-kind bits.
1203   uint32_t kind_mask = 0;
1204   for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
1205     if (acs_mask & (1u << i)) {
1206       kind_mask |= 1u << kQuantTable[i];
1207     }
1208   }
       // Same translation for the strategies recorded as already computed.
1209   uint32_t computed_kind_mask = 0;
1210   for (size_t i = 0; i < AcStrategy::kNumValidStrategies; i++) {
1211     if (computed_mask_ & (1u << i)) {
1212       computed_kind_mask |= 1u << kQuantTable[i];
1213     }
1214   }
1215   for (size_t table = 0; table < kNum; table++) {
         // Skip kinds that are already computed or were not requested.
         // NOTE(review): these tests shift a signed `1` while the masks above
         // are built with `1u`; harmless for table < 17, but inconsistent.
1216     if ((1 << table) & computed_kind_mask) continue;
1217     if ((1 << table) & ~kind_mask) continue;
         // Write cursor for this kind; this `pos` shadows the outer `pos`.
1218     size_t pos = offsets[table * 3];
         // Library-mode kinds are computed from the built-in default and the
         // result goes through JXL_CHECK; any other mode uses the stored
         // encoding and hands a failure to JXL_RETURN_IF_ERROR.
         // NOTE(review): presumably a library failure is treated as fatal and
         // a custom-encoding failure is returned to the caller -- confirm
         // against the macro definitions.
1219     if (encodings_[table].mode == QuantEncoding::kQuantModeLibrary) {
1220       JXL_CHECK(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)(
1221           library[table], table_storage_.get(),
1222           table_storage_.get() + kTotalTableSize, table, QuantTable(table),
1223           &pos));
1224     } else {
1225       JXL_RETURN_IF_ERROR(HWY_DYNAMIC_DISPATCH(ComputeQuantTable)(
1226           encodings_[table], table_storage_.get(),
1227           table_storage_.get() + kTotalTableSize, table, QuantTable(table),
1228           &pos));
1229     }
         // The cursor must have advanced exactly to the start of the next kind.
1230     JXL_ASSERT(pos == offsets[table * 3 + 3]);
1231   }
       // Record the requested strategies as computed for later calls.
1232   computed_mask_ |= acs_mask;
1233
1234   return true;
1235 }
1236
1237 }  // namespace jxl
1238 #endif