Imported Upstream version 0.9.0
[platform/upstream/libjxl.git] / lib / jxl / enc_ans.cc
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_ans.h"
7
8 #include <stdint.h>
9
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/base/bits.h"
22 #include "lib/jxl/base/fast_math-inl.h"
23 #include "lib/jxl/dec_ans.h"
24 #include "lib/jxl/enc_aux_out.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_fields.h"
28 #include "lib/jxl/enc_huffman.h"
29 #include "lib/jxl/fields.h"
30
31 namespace jxl {
32
33 namespace {
34
35 #if !JXL_IS_DEBUG_BUILD
36 constexpr
37 #endif
38     bool ans_fuzzer_friendly_ = false;
39
40 static const int kMaxNumSymbolsForSmallCode = 4;
41
42 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
43                        size_t alphabet_size, size_t log_alpha_size,
44                        ANSEncSymbolInfo* info) {
45   size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
46   size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
47   // create valid alias table for empty streams.
48   for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
49     const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
50     info[s].freq_ = static_cast<uint16_t>(freq);
51 #ifdef USE_MULT_BY_RECIPROCAL
52     if (freq != 0) {
53       info[s].ifreq_ =
54           ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
55     } else {
56       info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
57     }
58 #endif
59     info[s].reverse_map_.resize(freq);
60   }
61   for (int i = 0; i < ANS_TAB_SIZE; i++) {
62     AliasTable::Symbol s =
63         AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
64     info[s.value].reverse_map_[s.offset] = i;
65   }
66 }
67
68 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
69                        size_t len) {
70   float sum = 0.0f;
71   int total_histogram = 0;
72   int total_counts = 0;
73   for (size_t i = 0; i < len; ++i) {
74     total_histogram += histogram[i];
75     total_counts += counts[i];
76     if (histogram[i] > 0) {
77       JXL_ASSERT(counts[i] > 0);
78       // += histogram[i] * -log(counts[i]/total_counts)
79       sum += histogram[i] *
80              std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
81     }
82   }
83   if (total_histogram > 0) {
84     // Used only in assert.
85     (void)total_counts;
86     JXL_ASSERT(total_counts == ANS_TAB_SIZE);
87   }
88   return sum;
89 }
90
91 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
92   const float flat_bits = std::max(FastLog2f(len), 0.0f);
93   float total_histogram = 0;
94   for (size_t i = 0; i < len; ++i) {
95     total_histogram += histogram[i];
96   }
97   return total_histogram * flat_bits;
98 }
99
100 // Static Huffman code for encoding logcounts. The last symbol is used as RLE
101 // sequence.
102 static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
103     5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
104 };
105 static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
106     17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
107 };
108
109 // Returns the difference between largest count that can be represented and is
110 // smaller than "count" and smallest representable count larger than "count".
111 static int SmallestIncrement(uint32_t count, uint32_t shift) {
112   int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
113   int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
114   return drop_bits < 0 ? 1 : (1 << drop_bits);
115 }
116
117 template <bool minimize_error_of_sum>
118 bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
119                         uint32_t shift, int* omit_pos, ANSHistBin* counts) {
120   int sum = 0;
121   float sum_nonrounded = 0.0;
122   int remainder_pos = 0;  // if all of them are handled in first loop
123   int remainder_log = -1;
124   for (int n = 0; n < max_symbol; ++n) {
125     if (targets[n] > 0 && targets[n] < 1.0f) {
126       counts[n] = 1;
127       sum_nonrounded += targets[n];
128       sum += counts[n];
129     }
130   }
131   const float discount_ratio =
132       (table_size - sum) / (table_size - sum_nonrounded);
133   JXL_ASSERT(discount_ratio > 0);
134   JXL_ASSERT(discount_ratio <= 1.0f);
135   // Invariant for minimize_error_of_sum == true:
136   // abs(sum - sum_nonrounded)
137   //   <= SmallestIncrement(max(targets[])) + max_symbol
138   for (int n = 0; n < max_symbol; ++n) {
139     if (targets[n] >= 1.0f) {
140       sum_nonrounded += targets[n];
141       counts[n] =
142           static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
143       if (counts[n] == 0) counts[n] = 1;
144       if (counts[n] == table_size) counts[n] = table_size - 1;
145       // Round the count to the closest nonzero multiple of SmallestIncrement
146       // (when minimize_error_of_sum is false) or one of two closest so as to
147       // keep the sum as close as possible to sum_nonrounded.
148       int inc = SmallestIncrement(counts[n], shift);
149       counts[n] -= counts[n] & (inc - 1);
150       // TODO(robryk): Should we rescale targets[n]?
151       const float target =
152           minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
153       if (counts[n] == 0 ||
154           (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
155         counts[n] += inc;
156       }
157       sum += counts[n];
158       const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
159       if (count_log > remainder_log) {
160         remainder_pos = n;
161         remainder_log = count_log;
162       }
163     }
164   }
165   JXL_ASSERT(remainder_pos != -1);
166   // NOTE: This is the only place where counts could go negative. We could
167   // detect that, return false and make ANSHistBin uint32_t.
168   counts[remainder_pos] -= sum - table_size;
169   *omit_pos = remainder_pos;
170   return counts[remainder_pos] > 0;
171 }
172
173 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
174                        const int precision_bits, uint32_t shift,
175                        int* num_symbols, int* symbols) {
176   const int32_t table_size = 1 << precision_bits;  // target sum / table size
177   uint64_t total = 0;
178   int max_symbol = 0;
179   int symbol_count = 0;
180   for (int n = 0; n < length; ++n) {
181     total += counts[n];
182     if (counts[n] > 0) {
183       if (symbol_count < kMaxNumSymbolsForSmallCode) {
184         symbols[symbol_count] = n;
185       }
186       ++symbol_count;
187       max_symbol = n + 1;
188     }
189   }
190   *num_symbols = symbol_count;
191   if (symbol_count == 0) {
192     return true;
193   }
194   if (symbol_count == 1) {
195     counts[symbols[0]] = table_size;
196     return true;
197   }
198   if (symbol_count > table_size)
199     return JXL_FAILURE("Too many entries in an ANS histogram");
200
201   const float norm = 1.f * table_size / total;
202   std::vector<float> targets(max_symbol);
203   for (size_t n = 0; n < targets.size(); ++n) {
204     targets[n] = norm * counts[n];
205   }
206   if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
207                                  omit_pos, counts)) {
208     // Use an alternative rebalancing mechanism if the one above failed
209     // to create a histogram that is positive wherever the original one was.
210     if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
211                                   omit_pos, counts)) {
212       return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
213     }
214   }
215   return true;
216 }
217
218 struct SizeWriter {
219   size_t size = 0;
220   void Write(size_t num, size_t bits) { size += num; }
221 };
222
223 template <typename Writer>
224 void StoreVarLenUint8(size_t n, Writer* writer) {
225   JXL_DASSERT(n <= 255);
226   if (n == 0) {
227     writer->Write(1, 0);
228   } else {
229     writer->Write(1, 1);
230     size_t nbits = FloorLog2Nonzero(n);
231     writer->Write(3, nbits);
232     writer->Write(nbits, n - (1ULL << nbits));
233   }
234 }
235
236 template <typename Writer>
237 void StoreVarLenUint16(size_t n, Writer* writer) {
238   JXL_DASSERT(n <= 65535);
239   if (n == 0) {
240     writer->Write(1, 0);
241   } else {
242     writer->Write(1, 1);
243     size_t nbits = FloorLog2Nonzero(n);
244     writer->Write(4, nbits);
245     writer->Write(nbits, n - (1ULL << nbits));
246   }
247 }
248
249 template <typename Writer>
250 bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
251                   const int omit_pos, const int num_symbols, uint32_t shift,
252                   const int* symbols, Writer* writer) {
253   bool ok = true;
254   if (num_symbols <= 2) {
255     // Small tree marker to encode 1-2 symbols.
256     writer->Write(1, 1);
257     if (num_symbols == 0) {
258       writer->Write(1, 0);
259       StoreVarLenUint8(0, writer);
260     } else {
261       writer->Write(1, num_symbols - 1);
262       for (int i = 0; i < num_symbols; ++i) {
263         StoreVarLenUint8(symbols[i], writer);
264       }
265     }
266     if (num_symbols == 2) {
267       writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
268     }
269   } else {
270     // Mark non-small tree.
271     writer->Write(1, 0);
272     // Mark non-flat histogram.
273     writer->Write(1, 0);
274
275     // Precompute sequences for RLE encoding. Contains the number of identical
276     // values starting at a given index. Only contains the value at the first
277     // element of the series.
278     std::vector<uint32_t> same(alphabet_size, 0);
279     int last = 0;
280     for (int i = 1; i < alphabet_size; i++) {
281       // Store the sequence length once different symbol reached, or we're at
282       // the end, or the length is longer than we can encode, or we are at
283       // the omit_pos. We don't support including the omit_pos in an RLE
284       // sequence because this value may use a different amount of log2 bits
285       // than standard, it is too complex to handle in the decoder.
286       if (counts[i] != counts[last] || i + 1 == alphabet_size ||
287           (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
288         same[last] = (i - last);
289         last = i + 1;
290       }
291     }
292
293     int length = 0;
294     std::vector<int> logcounts(alphabet_size);
295     int omit_log = 0;
296     for (int i = 0; i < alphabet_size; ++i) {
297       JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
298       JXL_ASSERT(counts[i] >= 0);
299       if (i == omit_pos) {
300         length = i + 1;
301       } else if (counts[i] > 0) {
302         logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
303         length = i + 1;
304         if (i < omit_pos) {
305           omit_log = std::max(omit_log, logcounts[i] + 1);
306         } else {
307           omit_log = std::max(omit_log, logcounts[i]);
308         }
309       }
310     }
311     logcounts[omit_pos] = omit_log;
312
313     // Elias gamma-like code for shift. Only difference is that if the number
314     // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
315     // the terminating 0 in unary coding.
316     int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
317     int log = FloorLog2Nonzero(shift + 1);
318     writer->Write(log, (1 << log) - 1);
319     if (log != upper_bound_log) writer->Write(1, 0);
320     writer->Write(log, ((1 << log) - 1) & (shift + 1));
321
322     // Since num_symbols >= 3, we know that length >= 3, therefore we encode
323     // length - 3.
324     if (length - 3 > 255) {
325       // Pretend that everything is OK, but complain about correctness later.
326       StoreVarLenUint8(255, writer);
327       ok = false;
328     } else {
329       StoreVarLenUint8(length - 3, writer);
330     }
331
332     // The logcount values are encoded with a static Huffman code.
333     static const size_t kMinReps = 4;
334     size_t rep = ANS_LOG_TAB_SIZE + 1;
335     for (int i = 0; i < length; ++i) {
336       if (i > 0 && same[i - 1] > kMinReps) {
337         // Encode the RLE symbol and skip the repeated ones.
338         writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
339         StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
340         i += same[i - 1] - 2;
341         continue;
342       }
343       writer->Write(kLogCountBitLengths[logcounts[i]],
344                     kLogCountSymbols[logcounts[i]]);
345     }
346     for (int i = 0; i < length; ++i) {
347       if (i > 0 && same[i - 1] > kMinReps) {
348         // Skip symbols encoded by RLE.
349         i += same[i - 1] - 2;
350         continue;
351       }
352       if (logcounts[i] > 1 && i != omit_pos) {
353         int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
354         int drop_bits = logcounts[i] - 1 - bitcount;
355         JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
356         writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
357       }
358     }
359   }
360   return ok;
361 }
362
363 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
364   // Mark non-small tree.
365   writer->Write(1, 0);
366   // Mark uniform histogram.
367   writer->Write(1, 1);
368   JXL_ASSERT(alphabet_size > 0);
369   // Encode alphabet size.
370   StoreVarLenUint8(alphabet_size - 1, writer);
371 }
372
373 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
374                               uint32_t method) {
375   if (method == 0) {  // Flat code
376     return ANS_LOG_TAB_SIZE + 2 +
377            EstimateDataBitsFlat(histogram, alphabet_size);
378   }
379   // Non-flat: shift = method-1.
380   uint32_t shift = method - 1;
381   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
382   int omit_pos = 0;
383   int num_symbols;
384   int symbols[kMaxNumSymbolsForSmallCode] = {};
385   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
386                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
387   SizeWriter writer;
388   // Ignore the correctness, no real encoding happens at this stage.
389   (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
390                      symbols, &writer);
391   return writer.size +
392          EstimateDataBits(histogram, counts.data(), alphabet_size);
393 }
394
395 uint32_t ComputeBestMethod(
396     const ANSHistBin* histogram, size_t alphabet_size, float* cost,
397     HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
398   size_t method = 0;
399   float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
400   auto try_shift = [&](size_t shift) {
401     float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
402     if (c < fcost) {
403       method = shift + 1;
404       fcost = c;
405     }
406   };
407   switch (ans_histogram_strategy) {
408     case HistogramParams::ANSHistogramStrategy::kPrecise: {
409       for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
410         try_shift(shift);
411       }
412       break;
413     }
414     case HistogramParams::ANSHistogramStrategy::kApproximate: {
415       for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
416         try_shift(shift);
417       }
418       break;
419     }
420     case HistogramParams::ANSHistogramStrategy::kFast: {
421       try_shift(0);
422       try_shift(ANS_LOG_TAB_SIZE / 2);
423       try_shift(ANS_LOG_TAB_SIZE);
424       break;
425     }
426   };
427   *cost = fcost;
428   return method;
429 }
430
431 }  // namespace
432
433 // Returns an estimate of the cost of encoding this histogram and the
434 // corresponding data.
435 size_t BuildAndStoreANSEncodingData(
436     HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
437     const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
438     bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
439   if (use_prefix_code) {
440     if (alphabet_size <= 1) return 0;
441     std::vector<uint32_t> histo(alphabet_size);
442     for (size_t i = 0; i < alphabet_size; i++) {
443       histo[i] = histogram[i];
444       JXL_CHECK(histogram[i] >= 0);
445     }
446     size_t cost = 0;
447     {
448       std::vector<uint8_t> depths(alphabet_size);
449       std::vector<uint16_t> bits(alphabet_size);
450       if (writer == nullptr) {
451         BitWriter tmp_writer;
452         BitWriter::Allotment allotment(
453             &tmp_writer, 8 * alphabet_size + 8);  // safe upper bound
454         BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
455                                  bits.data(), &tmp_writer);
456         allotment.ReclaimAndCharge(&tmp_writer, 0, /*aux_out=*/nullptr);
457         cost = tmp_writer.BitsWritten();
458       } else {
459         size_t start = writer->BitsWritten();
460         BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
461                                  bits.data(), writer);
462         cost = writer->BitsWritten() - start;
463       }
464       for (size_t i = 0; i < alphabet_size; i++) {
465         info[i].bits = depths[i] == 0 ? 0 : bits[i];
466         info[i].depth = depths[i];
467       }
468     }
469     // Estimate data cost.
470     for (size_t i = 0; i < alphabet_size; i++) {
471       cost += histogram[i] * info[i].depth;
472     }
473     return cost;
474   }
475   JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
476   // Ensure we ignore trailing zeros in the histogram.
477   if (alphabet_size != 0) {
478     size_t largest_symbol = 0;
479     for (size_t i = 0; i < alphabet_size; i++) {
480       if (histogram[i] != 0) largest_symbol = i;
481     }
482     alphabet_size = largest_symbol + 1;
483   }
484   float cost;
485   uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
486                                       ans_histogram_strategy);
487   JXL_ASSERT(cost >= 0);
488   int num_symbols;
489   int symbols[kMaxNumSymbolsForSmallCode] = {};
490   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
491   if (!counts.empty()) {
492     size_t sum = 0;
493     for (size_t i = 0; i < counts.size(); i++) {
494       sum += counts[i];
495     }
496     if (sum == 0) {
497       counts[0] = ANS_TAB_SIZE;
498     }
499   }
500   if (method == 0) {
501     counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
502     AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
503     InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
504     ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
505     if (writer != nullptr) {
506       EncodeFlatHistogram(alphabet_size, writer);
507     }
508     return cost;
509   }
510   int omit_pos = 0;
511   uint32_t shift = method - 1;
512   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
513                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
514   AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
515   InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
516   ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
517   if (writer != nullptr) {
518     bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
519                            shift, symbols, writer);
520     (void)ok;
521     JXL_DASSERT(ok);
522   }
523   return cost;
524 }
525
526 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
527   float c;
528   ComputeBestMethod(data, alphabet_size, &c,
529                     HistogramParams::ANSHistogramStrategy::kFast);
530   return c;
531 }
532
533 template <typename Writer>
534 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
535                       size_t log_alpha_size) {
536   writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
537                 uint_config.split_exponent);
538   if (uint_config.split_exponent == log_alpha_size) {
539     return;  // msb/lsb don't matter.
540   }
541   size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
542   writer->Write(nbits, uint_config.msb_in_token);
543   nbits = CeilLog2Nonzero(uint_config.split_exponent -
544                           uint_config.msb_in_token + 1);
545   writer->Write(nbits, uint_config.lsb_in_token);
546 }
547 template <typename Writer>
548 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
549                        Writer* writer, size_t log_alpha_size) {
550   // TODO(veluca): RLE?
551   for (size_t i = 0; i < uint_config.size(); i++) {
552     EncodeUintConfig(uint_config[i], writer, log_alpha_size);
553   }
554 }
555 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
556                                 BitWriter*, size_t);
557
558 namespace {
559
560 void ChooseUintConfigs(const HistogramParams& params,
561                        const std::vector<std::vector<Token>>& tokens,
562                        const std::vector<uint8_t>& context_map,
563                        std::vector<Histogram>* clustered_histograms,
564                        EntropyEncodingData* codes, size_t* log_alpha_size) {
565   codes->uint_config.resize(clustered_histograms->size());
566
567   if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
568   if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
569     codes->uint_config.clear();
570     codes->uint_config.resize(clustered_histograms->size(),
571                               HybridUintConfig(0, 0, 0));
572     return;
573   }
574   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
575     codes->uint_config.clear();
576     codes->uint_config.resize(clustered_histograms->size(),
577                               HybridUintConfig(2, 0, 1));
578     return;
579   }
580
581   // Brute-force method that tries a few options.
582   std::vector<HybridUintConfig> configs;
583   if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
584     configs = {
585         HybridUintConfig(4, 2, 0),  // default
586         HybridUintConfig(4, 1, 0),  // less precise
587         HybridUintConfig(4, 2, 1),  // add sign
588         HybridUintConfig(4, 2, 2),  // add sign+parity
589         HybridUintConfig(4, 1, 2),  // add parity but less msb
590         // Same as above, but more direct coding.
591         HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
592         HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
593         HybridUintConfig(5, 1, 2),
594         // Same as above, but less direct coding.
595         HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
596         HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
597         // For near-lossless.
598         HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
599         HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
600         HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
601         // Other
602         HybridUintConfig(0, 0, 0),   // varlenuint
603         HybridUintConfig(2, 0, 1),   // works well for ctx map
604         HybridUintConfig(7, 0, 0),   // direct coding
605         HybridUintConfig(8, 0, 0),   // direct coding
606         HybridUintConfig(9, 0, 0),   // direct coding
607         HybridUintConfig(10, 0, 0),  // direct coding
608         HybridUintConfig(11, 0, 0),  // direct coding
609         HybridUintConfig(12, 0, 0),  // direct coding
610     };
611   } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
612     configs = {
613         HybridUintConfig(4, 2, 0),  // default
614         HybridUintConfig(4, 1, 2),  // add parity but less msb
615         HybridUintConfig(0, 0, 0),  // smallest histograms
616         HybridUintConfig(2, 0, 1),  // works well for ctx map
617     };
618   }
619
620   std::vector<float> costs(clustered_histograms->size(),
621                            std::numeric_limits<float>::max());
622   std::vector<uint32_t> extra_bits(clustered_histograms->size());
623   std::vector<uint8_t> is_valid(clustered_histograms->size());
624   size_t max_alpha =
625       codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
626   for (HybridUintConfig cfg : configs) {
627     std::fill(is_valid.begin(), is_valid.end(), true);
628     std::fill(extra_bits.begin(), extra_bits.end(), 0);
629
630     for (size_t i = 0; i < clustered_histograms->size(); i++) {
631       (*clustered_histograms)[i].Clear();
632     }
633     for (size_t i = 0; i < tokens.size(); ++i) {
634       for (size_t j = 0; j < tokens[i].size(); ++j) {
635         const Token token = tokens[i][j];
636         // TODO(veluca): do not ignore lz77 commands.
637         if (token.is_lz77_length) continue;
638         size_t histo = context_map[token.context];
639         uint32_t tok, nbits, bits;
640         cfg.Encode(token.value, &tok, &nbits, &bits);
641         if (tok >= max_alpha ||
642             (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
643           is_valid[histo] = false;
644           continue;
645         }
646         extra_bits[histo] += nbits;
647         (*clustered_histograms)[histo].Add(tok);
648       }
649     }
650
651     for (size_t i = 0; i < clustered_histograms->size(); i++) {
652       if (!is_valid[i]) continue;
653       float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
654       // add signaling cost of the hybriduintconfig itself
655       cost += CeilLog2Nonzero(cfg.split_exponent + 1);
656       cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
657       if (cost < costs[i]) {
658         codes->uint_config[i] = cfg;
659         costs[i] = cost;
660       }
661     }
662   }
663
664   // Rebuild histograms.
665   for (size_t i = 0; i < clustered_histograms->size(); i++) {
666     (*clustered_histograms)[i].Clear();
667   }
668   *log_alpha_size = 4;
669   for (size_t i = 0; i < tokens.size(); ++i) {
670     for (size_t j = 0; j < tokens[i].size(); ++j) {
671       const Token token = tokens[i][j];
672       uint32_t tok, nbits, bits;
673       size_t histo = context_map[token.context];
674       (token.is_lz77_length ? codes->lz77.length_uint_config
675                             : codes->uint_config[histo])
676           .Encode(token.value, &tok, &nbits, &bits);
677       tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
678       (*clustered_histograms)[histo].Add(tok);
679       while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
680     }
681   }
682 #if JXL_ENABLE_ASSERT
683   size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
684   JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
685 #endif
686 }
687
688 class HistogramBuilder {
689  public:
690   explicit HistogramBuilder(const size_t num_contexts)
691       : histograms_(num_contexts) {}
692
693   void VisitSymbol(int symbol, size_t histo_idx) {
694     JXL_DASSERT(histo_idx < histograms_.size());
695     histograms_[histo_idx].Add(symbol);
696   }
697
698   // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
699   size_t BuildAndStoreEntropyCodes(
700       const HistogramParams& params,
701       const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
702       std::vector<uint8_t>* context_map, bool use_prefix_code,
703       BitWriter* writer, size_t layer, AuxOut* aux_out) const {
704     size_t cost = 0;
705     codes->encoding_info.clear();
706     std::vector<Histogram> clustered_histograms(histograms_);
707     context_map->resize(histograms_.size());
708     if (histograms_.size() > 1) {
709       if (!ans_fuzzer_friendly_) {
710         std::vector<uint32_t> histogram_symbols;
711         ClusterHistograms(params, histograms_, kClustersLimit,
712                           &clustered_histograms, &histogram_symbols);
713         for (size_t c = 0; c < histograms_.size(); ++c) {
714           (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
715         }
716       } else {
717         fill(context_map->begin(), context_map->end(), 0);
718         size_t max_symbol = 0;
719         for (const Histogram& h : histograms_) {
720           max_symbol = std::max(h.data_.size(), max_symbol);
721         }
722         size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
723         clustered_histograms.resize(1);
724         clustered_histograms[0].Clear();
725         for (size_t i = 0; i < num_symbols; i++) {
726           clustered_histograms[0].Add(i);
727         }
728       }
729       if (writer != nullptr) {
730         EncodeContextMap(*context_map, clustered_histograms.size(), writer,
731                          layer, aux_out);
732       }
733     }
734     if (aux_out != nullptr) {
735       for (size_t i = 0; i < clustered_histograms.size(); ++i) {
736         aux_out->layers[layer].clustered_entropy +=
737             clustered_histograms[i].ShannonEntropy();
738       }
739     }
740     codes->use_prefix_code = use_prefix_code;
741     size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
742     if (ans_fuzzer_friendly_) {
743       codes->uint_config.clear();
744       codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
745     } else {
746       ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
747                         codes, &log_alpha_size);
748     }
749     if (log_alpha_size < 5) log_alpha_size = 5;
750     SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
751     cost += 1;
752     if (writer) writer->Write(1, use_prefix_code);
753
754     if (use_prefix_code) {
755       log_alpha_size = PREFIX_MAX_BITS;
756     } else {
757       cost += 2;
758     }
759     if (writer == nullptr) {
760       EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
761     } else {
762       if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
763       EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
764     }
765     if (use_prefix_code) {
766       for (size_t c = 0; c < clustered_histograms.size(); ++c) {
767         size_t num_symbol = 1;
768         for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
769           if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
770         }
771         if (writer) {
772           StoreVarLenUint16(num_symbol - 1, writer);
773         } else {
774           StoreVarLenUint16(num_symbol - 1, &size_writer);
775         }
776       }
777     }
778     cost += size_writer.size;
779     for (size_t c = 0; c < clustered_histograms.size(); ++c) {
780       size_t num_symbol = 1;
781       for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
782         if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
783       }
784       codes->encoding_info.emplace_back();
785       codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
786
787       BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
788       cost += BuildAndStoreANSEncodingData(
789           params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
790           num_symbol, log_alpha_size, use_prefix_code,
791           codes->encoding_info.back().data(), writer);
792       allotment.FinishedHistogram(writer);
793       allotment.ReclaimAndCharge(writer, layer, aux_out);
794     }
795     return cost;
796   }
797
798   const Histogram& Histo(size_t i) const { return histograms_[i]; }
799
800  private:
801   std::vector<Histogram> histograms_;
802 };
803
804 class SymbolCostEstimator {
805  public:
806   SymbolCostEstimator(size_t num_contexts, bool force_huffman,
807                       const std::vector<std::vector<Token>>& tokens,
808                       const LZ77Params& lz77) {
809     HistogramBuilder builder(num_contexts);
810     // Build histograms for estimating lz77 savings.
811     HybridUintConfig uint_config;
812     for (size_t i = 0; i < tokens.size(); ++i) {
813       for (size_t j = 0; j < tokens[i].size(); ++j) {
814         const Token token = tokens[i][j];
815         uint32_t tok, nbits, bits;
816         (token.is_lz77_length ? lz77.length_uint_config : uint_config)
817             .Encode(token.value, &tok, &nbits, &bits);
818         tok += token.is_lz77_length ? lz77.min_symbol : 0;
819         builder.VisitSymbol(tok, token.context);
820       }
821     }
822     max_alphabet_size_ = 0;
823     for (size_t i = 0; i < num_contexts; i++) {
824       max_alphabet_size_ =
825           std::max(max_alphabet_size_, builder.Histo(i).data_.size());
826     }
827     bits_.resize(num_contexts * max_alphabet_size_);
828     // TODO(veluca): SIMD?
829     add_symbol_cost_.resize(num_contexts);
830     for (size_t i = 0; i < num_contexts; i++) {
831       float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
832       float total_cost = 0;
833       for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
834         size_t cnt = builder.Histo(i).data_[j];
835         float cost = 0;
836         if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
837           cost = -FastLog2f(cnt * inv_total);
838           if (force_huffman) cost = std::ceil(cost);
839         } else if (cnt == 0) {
840           cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
841         }
842         bits_[i * max_alphabet_size_ + j] = cost;
843         total_cost += cost * builder.Histo(i).data_[j];
844       }
845       // Penalty for adding a lz77 symbol to this contest (only used for static
846       // cost model). Higher penalty for contexts that have a very low
847       // per-symbol entropy.
848       add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
849     }
850   }
851   float Bits(size_t ctx, size_t sym) const {
852     return bits_[ctx * max_alphabet_size_ + sym];
853   }
854   float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
855     uint32_t nbits, bits, tok;
856     lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
857     tok += lz77.min_symbol;
858     return nbits + Bits(ctx, tok);
859   }
860   float DistCost(size_t len, const LZ77Params& lz77) const {
861     uint32_t nbits, bits, tok;
862     HybridUintConfig().Encode(len, &tok, &nbits, &bits);
863     return nbits + Bits(lz77.nonserialized_distance_context, tok);
864   }
865   float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
866
867  private:
868   size_t max_alphabet_size_;
869   std::vector<float> bits_;
870   std::vector<float> add_symbol_cost_;
871 };
872
873 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
874                    const std::vector<std::vector<Token>>& tokens,
875                    LZ77Params& lz77,
876                    std::vector<std::vector<Token>>& tokens_lz77) {
877   // TODO(veluca): tune heuristics here.
878   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
879   float bit_decrease = 0;
880   size_t total_symbols = 0;
881   tokens_lz77.resize(tokens.size());
882   std::vector<float> sym_cost;
883   HybridUintConfig uint_config;
884   for (size_t stream = 0; stream < tokens.size(); stream++) {
885     size_t distance_multiplier =
886         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
887     const auto& in = tokens[stream];
888     auto& out = tokens_lz77[stream];
889     total_symbols += in.size();
890     // Cumulative sum of bit costs.
891     sym_cost.resize(in.size() + 1);
892     for (size_t i = 0; i < in.size(); i++) {
893       uint32_t tok, nbits, unused_bits;
894       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
895       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
896     }
897     out.reserve(in.size());
898     for (size_t i = 0; i < in.size(); i++) {
899       size_t num_to_copy = 0;
900       size_t distance_symbol = 0;  // 1 for RLE.
901       if (distance_multiplier != 0) {
902         distance_symbol = 1;  // Special distance 1 if enabled.
903         JXL_DASSERT(kSpecialDistances[1][0] == 1);
904         JXL_DASSERT(kSpecialDistances[1][1] == 0);
905       }
906       if (i > 0) {
907         for (; i + num_to_copy < in.size(); num_to_copy++) {
908           if (in[i + num_to_copy].value != in[i - 1].value) {
909             break;
910           }
911         }
912       }
913       if (num_to_copy == 0) {
914         out.push_back(in[i]);
915         continue;
916       }
917       float cost = sym_cost[i + num_to_copy] - sym_cost[i];
918       // This subtraction might overflow, but that's OK.
919       size_t lz77_len = num_to_copy - lz77.min_length;
920       float lz77_cost = num_to_copy >= lz77.min_length
921                             ? CeilLog2Nonzero(lz77_len + 1) + 1
922                             : 0;
923       if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
924         for (size_t j = 0; j < num_to_copy; j++) {
925           out.push_back(in[i + j]);
926         }
927         i += num_to_copy - 1;
928         continue;
929       }
930       // Output the LZ77 length
931       out.emplace_back(in[i].context, lz77_len);
932       out.back().is_lz77_length = true;
933       i += num_to_copy - 1;
934       bit_decrease += cost - lz77_cost;
935       // Output the LZ77 copy distance.
936       out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
937     }
938   }
939
940   if (bit_decrease > total_symbols * 0.2 + 16) {
941     lz77.enabled = true;
942   }
943 }
944
945 // Hash chain for LZ77 matching
946 struct HashChain {
947   size_t size_;
948   std::vector<uint32_t> data_;
949
950   unsigned hash_num_values_ = 32768;
951   unsigned hash_mask_ = hash_num_values_ - 1;
952   unsigned hash_shift_ = 5;
953
954   std::vector<int> head;
955   std::vector<uint32_t> chain;
956   std::vector<int> val;
957
958   // Speed up repetitions of zero
959   std::vector<int> headz;
960   std::vector<uint32_t> chainz;
961   std::vector<uint32_t> zeros;
962   uint32_t numzeros = 0;
963
964   size_t window_size_;
965   size_t window_mask_;
966   size_t min_length_;
967   size_t max_length_;
968
969   // Map of special distance codes.
970   std::unordered_map<int, int> special_dist_table_;
971   size_t num_special_distances_ = 0;
972
973   uint32_t maxchainlength = 256;  // window_size_ to allow all
974
975   HashChain(const Token* data, size_t size, size_t window_size,
976             size_t min_length, size_t max_length, size_t distance_multiplier)
977       : size_(size),
978         window_size_(window_size),
979         window_mask_(window_size - 1),
980         min_length_(min_length),
981         max_length_(max_length) {
982     data_.resize(size);
983     for (size_t i = 0; i < size; i++) {
984       data_[i] = data[i].value;
985     }
986
987     head.resize(hash_num_values_, -1);
988     val.resize(window_size_, -1);
989     chain.resize(window_size_);
990     for (uint32_t i = 0; i < window_size_; ++i) {
991       chain[i] = i;  // same value as index indicates uninitialized
992     }
993
994     zeros.resize(window_size_);
995     headz.resize(window_size_ + 1, -1);
996     chainz.resize(window_size_);
997     for (uint32_t i = 0; i < window_size_; ++i) {
998       chainz[i] = i;
999     }
1000     // Translate distance to special distance code.
1001     if (distance_multiplier) {
1002       // Count down, so if due to small distance multiplier multiple distances
1003       // map to the same code, the smallest code will be used in the end.
1004       for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
1005         int xi = kSpecialDistances[i][0];
1006         int yi = kSpecialDistances[i][1];
1007         int distance = yi * distance_multiplier + xi;
1008         // Ensure that we map distance 1 to the lowest symbols.
1009         if (distance < 1) distance = 1;
1010         special_dist_table_[distance] = i;
1011       }
1012       num_special_distances_ = kNumSpecialDistances;
1013     }
1014   }
1015
1016   uint32_t GetHash(size_t pos) const {
1017     uint32_t result = 0;
1018     if (pos + 2 < size_) {
1019       // TODO(lode): take the MSB's of the uint32_t values into account as well,
1020       // given that the hash code itself is less than 32 bits.
1021       result ^= (uint32_t)(data_[pos + 0] << 0u);
1022       result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
1023       result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
1024     } else {
1025       // No need to compute hash of last 2 bytes, the length 2 is too short.
1026       return 0;
1027     }
1028     return result & hash_mask_;
1029   }
1030
1031   uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
1032     size_t end = pos + window_size_;
1033     if (end > size_) end = size_;
1034     if (prevzeros > 0) {
1035       if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
1036           end == pos + window_size_) {
1037         return prevzeros;
1038       } else {
1039         return prevzeros - 1;
1040       }
1041     }
1042     uint32_t num = 0;
1043     while (pos + num < end && data_[pos + num] == 0) num++;
1044     return num;
1045   }
1046
1047   void Update(size_t pos) {
1048     uint32_t hashval = GetHash(pos);
1049     uint32_t wpos = pos & window_mask_;
1050
1051     val[wpos] = (int)hashval;
1052     if (head[hashval] != -1) chain[wpos] = head[hashval];
1053     head[hashval] = wpos;
1054
1055     if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
1056     numzeros = CountZeros(pos, numzeros);
1057
1058     zeros[wpos] = numzeros;
1059     if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
1060     headz[numzeros] = wpos;
1061   }
1062
1063   void Update(size_t pos, size_t len) {
1064     for (size_t i = 0; i < len; i++) {
1065       Update(pos + i);
1066     }
1067   }
1068
1069   template <typename CB>
1070   void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
1071     uint32_t wpos = pos & window_mask_;
1072     uint32_t hashval = GetHash(pos);
1073     uint32_t hashpos = chain[wpos];
1074
1075     int prev_dist = 0;
1076     int end = std::min<int>(pos + max_length_, size_);
1077     uint32_t chainlength = 0;
1078     uint32_t best_len = 0;
1079     for (;;) {
1080       int dist = (hashpos <= wpos) ? (wpos - hashpos)
1081                                    : (wpos - hashpos + window_mask_ + 1);
1082       if (dist < prev_dist) break;
1083       prev_dist = dist;
1084       uint32_t len = 0;
1085       if (dist > 0) {
1086         int i = pos;
1087         int j = pos - dist;
1088         if (numzeros > 3) {
1089           int r = std::min<int>(numzeros - 1, zeros[hashpos]);
1090           if (i + r >= end) r = end - i - 1;
1091           i += r;
1092           j += r;
1093         }
1094         while (i < end && data_[i] == data_[j]) {
1095           i++;
1096           j++;
1097         }
1098         len = i - pos;
1099         // This can trigger even if the new length is slightly smaller than the
1100         // best length, because it is possible for a slightly cheaper distance
1101         // symbol to occur.
1102         if (len >= min_length_ && len + 2 >= best_len) {
1103           auto it = special_dist_table_.find(dist);
1104           int dist_symbol = (it == special_dist_table_.end())
1105                                 ? (num_special_distances_ + dist - 1)
1106                                 : it->second;
1107           found_match(len, dist_symbol);
1108           if (len > best_len) best_len = len;
1109         }
1110       }
1111
1112       chainlength++;
1113       if (chainlength >= maxchainlength) break;
1114
1115       if (numzeros >= 3 && len > numzeros) {
1116         if (hashpos == chainz[hashpos]) break;
1117         hashpos = chainz[hashpos];
1118         if (zeros[hashpos] != numzeros) break;
1119       } else {
1120         if (hashpos == chain[hashpos]) break;
1121         hashpos = chain[hashpos];
1122         if (val[hashpos] != (int)hashval) break;  // outdated hash value
1123       }
1124     }
1125   }
1126   void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
1127                  size_t* result_len) const {
1128     *result_dist_symbol = 0;
1129     *result_len = 1;
1130     FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
1131       if (len > *result_len ||
1132           (len == *result_len && *result_dist_symbol > dist_symbol)) {
1133         *result_len = len;
1134         *result_dist_symbol = dist_symbol;
1135       }
1136     });
1137   }
1138 };
1139
1140 float LenCost(size_t len) {
1141   uint32_t nbits, bits, tok;
1142   HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1143   constexpr float kCostTable[] = {
1144       2.797667318563126,  3.213177690381199,  2.5706009246743737,
1145       2.408392498667534,  2.829649191872326,  3.3923087753324577,
1146       4.029267451554331,  4.415576699706408,  4.509357574741465,
1147       9.21481543803004,   10.020590190114898, 11.858671627804766,
1148       12.45853300490526,  11.713105831990857, 12.561996324849314,
1149       13.775477692278367, 13.174027068768641,
1150   };
1151   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1152   if (tok >= table_size) tok = table_size - 1;
1153   return kCostTable[tok] + nbits;
1154 }
1155
1156 // TODO(veluca): this does not take into account usage or non-usage of distance
1157 // multipliers.
1158 float DistCost(size_t dist) {
1159   uint32_t nbits, bits, tok;
1160   HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1161   constexpr float kCostTable[] = {
1162       6.368282626312716,  5.680793277090298,  8.347404197105247,
1163       7.641619201599141,  6.914328374119438,  7.959808291537444,
1164       8.70023120759855,   8.71378518934703,   9.379132523982769,
1165       9.110472749092708,  9.159029569270908,  9.430936766731973,
1166       7.278284055315169,  7.8278514904267755, 10.026641158289236,
1167       9.976049229827066,  9.64351607048908,   9.563403863480442,
1168       10.171474111762747, 10.45950155077234,  9.994813912104219,
1169       10.322524683741156, 8.465808729388186,  8.756254166066853,
1170       10.160930174662234, 10.247329273413435, 10.04090403724809,
1171       10.129398517544082, 9.342311691539546,  9.07608009102374,
1172       10.104799540677513, 10.378079384990906, 10.165828974075072,
1173       10.337595322341553, 7.940557464567944,  10.575665823319431,
1174       11.023344321751955, 10.736144698831827, 11.118277044595054,
1175       7.468468230648442,  10.738305230932939, 10.906980780216568,
1176       10.163468216353817, 10.17805759656433,  11.167283670483565,
1177       11.147050200274544, 10.517921919244333, 10.651764778156886,
1178       10.17074446448919,  11.217636876224745, 11.261630721139484,
1179       11.403140815247259, 10.892472096873417, 11.1859607804481,
1180       8.017346947551262,  7.895143720278828,  11.036577113822025,
1181       11.170562110315794, 10.326988722591086, 10.40872184751056,
1182       11.213498225466386, 11.30580635516863,  10.672272515665442,
1183       10.768069466228063, 11.145257364153565, 11.64668307145549,
1184       10.593156194627339, 11.207499484844943, 10.767517766396908,
1185       10.826629811407042, 10.737764794499988, 10.6200448518045,
1186       10.191315385198092, 8.468384171390085,  11.731295299170432,
1187       11.824619886654398, 10.41518844301179,  10.16310536548649,
1188       10.539423685097576, 10.495136599328031, 10.469112847728267,
1189       11.72057686174922,  10.910326337834674, 11.378921834673758,
1190       11.847759036098536, 11.92071647623854,  10.810628276345282,
1191       11.008601085273893, 11.910326337834674, 11.949212023423133,
1192       11.298614839104337, 11.611603659010392, 10.472930394619985,
1193       11.835564720850282, 11.523267392285337, 12.01055816679611,
1194       8.413029688994023,  11.895784139536406, 11.984679534970505,
1195       11.220654278717394, 11.716311684833672, 10.61036646226114,
1196       10.89849965960364,  10.203762898863669, 10.997560826267238,
1197       11.484217379438984, 11.792836176993665, 12.24310468755171,
1198       11.464858097919262, 12.212747017409377, 11.425595666074955,
1199       11.572048533398757, 12.742093965163013, 11.381874288645637,
1200       12.191870445817015, 11.683156920035426, 11.152442115262197,
1201       11.90303691580457,  11.653292787169159, 11.938615382266098,
1202       16.970641701570223, 16.853602280380002, 17.26240782594733,
1203       16.644655390108507, 17.14310889757499,  16.910935455445955,
1204       17.505678976959697, 17.213498225466388, 2.4162310293553024,
1205       3.494587244462329,  3.5258600986408344, 3.4959806589517095,
1206       3.098390886949687,  3.343454654302911,  3.588847442290287,
1207       4.14614790111827,   5.152948641990529,  7.433696808092598,
1208       9.716311684833672,
1209   };
1210   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1211   if (tok >= table_size) tok = table_size - 1;
1212   return kCostTable[tok] + nbits;
1213 }
1214
1215 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1216                     const std::vector<std::vector<Token>>& tokens,
1217                     LZ77Params& lz77,
1218                     std::vector<std::vector<Token>>& tokens_lz77) {
1219   // TODO(veluca): tune heuristics here.
1220   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1221   float bit_decrease = 0;
1222   size_t total_symbols = 0;
1223   tokens_lz77.resize(tokens.size());
1224   HybridUintConfig uint_config;
1225   std::vector<float> sym_cost;
1226   for (size_t stream = 0; stream < tokens.size(); stream++) {
1227     size_t distance_multiplier =
1228         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1229     const auto& in = tokens[stream];
1230     auto& out = tokens_lz77[stream];
1231     total_symbols += in.size();
1232     // Cumulative sum of bit costs.
1233     sym_cost.resize(in.size() + 1);
1234     for (size_t i = 0; i < in.size(); i++) {
1235       uint32_t tok, nbits, unused_bits;
1236       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1237       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1238     }
1239
1240     out.reserve(in.size());
1241     size_t max_distance = in.size();
1242     size_t min_length = lz77.min_length;
1243     JXL_ASSERT(min_length >= 3);
1244     size_t max_length = in.size();
1245
1246     // Use next power of two as window size.
1247     size_t window_size = 1;
1248     while (window_size < max_distance && window_size < kWindowSize) {
1249       window_size <<= 1;
1250     }
1251
1252     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1253                     distance_multiplier);
1254     size_t len, dist_symbol;
1255
1256     const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
1257
1258     // Whether the next symbol was already updated (to test lazy matching)
1259     bool already_updated = false;
1260     for (size_t i = 0; i < in.size(); i++) {
1261       out.push_back(in[i]);
1262       if (!already_updated) chain.Update(i);
1263       already_updated = false;
1264       chain.FindMatch(i, max_distance, &dist_symbol, &len);
1265       if (len >= min_length) {
1266         if (len < max_lazy_match_len && i + 1 < in.size()) {
1267           // Try length at next symbol lazy matching
1268           chain.Update(i + 1);
1269           already_updated = true;
1270           size_t len2, dist_symbol2;
1271           chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1272           if (len2 > len) {
1273             // Use the lazy match. Add literal, and use the next length starting
1274             // from the next byte.
1275             ++i;
1276             already_updated = false;
1277             len = len2;
1278             dist_symbol = dist_symbol2;
1279             out.push_back(in[i]);
1280           }
1281         }
1282
1283         float cost = sym_cost[i + len] - sym_cost[i];
1284         size_t lz77_len = len - lz77.min_length;
1285         float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1286                           sce.AddSymbolCost(out.back().context);
1287
1288         if (lz77_cost <= cost) {
1289           out.back().value = len - min_length;
1290           out.back().is_lz77_length = true;
1291           out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1292           bit_decrease += cost - lz77_cost;
1293         } else {
1294           // LZ77 match ignored, and symbol already pushed. Push all other
1295           // symbols and skip.
1296           for (size_t j = 1; j < len; j++) {
1297             out.push_back(in[i + j]);
1298           }
1299         }
1300
1301         if (already_updated) {
1302           chain.Update(i + 2, len - 2);
1303           already_updated = false;
1304         } else {
1305           chain.Update(i + 1, len - 1);
1306         }
1307         i += len - 1;
1308       } else {
1309         // Literal, already pushed
1310       }
1311     }
1312   }
1313
1314   if (bit_decrease > total_symbols * 0.2 + 16) {
1315     lz77.enabled = true;
1316   }
1317 }
1318
1319 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1320                        const std::vector<std::vector<Token>>& tokens,
1321                        LZ77Params& lz77,
1322                        std::vector<std::vector<Token>>& tokens_lz77) {
1323   std::vector<std::vector<Token>> tokens_for_cost_estimate;
1324   ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1325   // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1326   // run the optimal matching.
1327   if (!lz77.enabled) return;
1328   SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1329                           tokens_for_cost_estimate, lz77);
1330   tokens_lz77.resize(tokens.size());
1331   HybridUintConfig uint_config;
1332   std::vector<float> sym_cost;
1333   std::vector<uint32_t> dist_symbols;
1334   for (size_t stream = 0; stream < tokens.size(); stream++) {
1335     size_t distance_multiplier =
1336         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1337     const auto& in = tokens[stream];
1338     auto& out = tokens_lz77[stream];
1339     // Cumulative sum of bit costs.
1340     sym_cost.resize(in.size() + 1);
1341     for (size_t i = 0; i < in.size(); i++) {
1342       uint32_t tok, nbits, unused_bits;
1343       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1344       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1345     }
1346
1347     out.reserve(in.size());
1348     size_t max_distance = in.size();
1349     size_t min_length = lz77.min_length;
1350     JXL_ASSERT(min_length >= 3);
1351     size_t max_length = in.size();
1352
1353     // Use next power of two as window size.
1354     size_t window_size = 1;
1355     while (window_size < max_distance && window_size < kWindowSize) {
1356       window_size <<= 1;
1357     }
1358
1359     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1360                     distance_multiplier);
1361
1362     struct MatchInfo {
1363       uint32_t len;
1364       uint32_t dist_symbol;
1365       uint32_t ctx;
1366       float total_cost = std::numeric_limits<float>::max();
1367     };
1368     // Total cost to encode the first N symbols.
1369     std::vector<MatchInfo> prefix_costs(in.size() + 1);
1370     prefix_costs[0].total_cost = 0;
1371
1372     size_t rle_length = 0;
1373     size_t skip_lz77 = 0;
1374     for (size_t i = 0; i < in.size(); i++) {
1375       chain.Update(i);
1376       float lit_cost =
1377           prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1378       if (prefix_costs[i + 1].total_cost > lit_cost) {
1379         prefix_costs[i + 1].dist_symbol = 0;
1380         prefix_costs[i + 1].len = 1;
1381         prefix_costs[i + 1].ctx = in[i].context;
1382         prefix_costs[i + 1].total_cost = lit_cost;
1383       }
1384       if (skip_lz77 > 0) {
1385         skip_lz77--;
1386         continue;
1387       }
1388       dist_symbols.clear();
1389       chain.FindMatches(i, max_distance,
1390                         [&dist_symbols](size_t len, size_t dist_symbol) {
1391                           if (dist_symbols.size() <= len) {
1392                             dist_symbols.resize(len + 1, dist_symbol);
1393                           }
1394                           if (dist_symbol < dist_symbols[len]) {
1395                             dist_symbols[len] = dist_symbol;
1396                           }
1397                         });
1398       if (dist_symbols.size() <= min_length) continue;
1399       {
1400         size_t best_cost = dist_symbols.back();
1401         for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1402           if (dist_symbols[j] < best_cost) {
1403             best_cost = dist_symbols[j];
1404           }
1405           dist_symbols[j] = best_cost;
1406         }
1407       }
1408       for (size_t j = min_length; j < dist_symbols.size(); j++) {
1409         // Cost model that uses results from lazy LZ77.
1410         float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1411                           sce.DistCost(dist_symbols[j], lz77);
1412         float cost = prefix_costs[i].total_cost + lz77_cost;
1413         if (prefix_costs[i + j].total_cost > cost) {
1414           prefix_costs[i + j].len = j;
1415           prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1416           prefix_costs[i + j].ctx = in[i].context;
1417           prefix_costs[i + j].total_cost = cost;
1418         }
1419       }
1420       // We are in a RLE sequence: skip all the symbols except the first 8 and
1421       // the last 8. This avoid quadratic costs for sequences with long runs of
1422       // the same symbol.
1423       if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1424           (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1425         rle_length++;
1426       } else {
1427         rle_length = 0;
1428       }
1429       if (rle_length >= 8 && dist_symbols.size() > 9) {
1430         skip_lz77 = dist_symbols.size() - 10;
1431         rle_length = 0;
1432       }
1433     }
1434     size_t pos = in.size();
1435     while (pos > 0) {
1436       bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1437       if (is_lz77_length) {
1438         size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1439         out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1440       }
1441       size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1442                                   : in[pos - 1].value;
1443       out.emplace_back(prefix_costs[pos].ctx, val);
1444       out.back().is_lz77_length = is_lz77_length;
1445       pos -= prefix_costs[pos].len;
1446     }
1447     std::reverse(out.begin(), out.end());
1448   }
1449 }
1450
1451 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1452                const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1453                std::vector<std::vector<Token>>& tokens_lz77) {
1454   lz77.enabled = false;
1455   if (params.force_huffman) {
1456     lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1457   } else {
1458     lz77.min_symbol = 224;
1459   }
1460   if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1461     return;
1462   } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1463     ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1464   } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1465     ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1466   } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1467     ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1468   } else {
1469     JXL_UNREACHABLE("Not implemented");
1470   }
1471 }
1472 }  // namespace
1473
1474 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1475                                 size_t num_contexts,
1476                                 std::vector<std::vector<Token>>& tokens,
1477                                 EntropyEncodingData* codes,
1478                                 std::vector<uint8_t>* context_map,
1479                                 BitWriter* writer, size_t layer,
1480                                 AuxOut* aux_out) {
1481   size_t total_bits = 0;
1482   codes->lz77.nonserialized_distance_context = num_contexts;
1483   std::vector<std::vector<Token>> tokens_lz77;
1484   ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1485   if (ans_fuzzer_friendly_) {
1486     codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1487     codes->lz77.min_symbol = 2048;
1488   }
1489
1490   const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1491   BitWriter::Allotment allotment(writer,
1492                                  128 + num_contexts * 40 + max_contexts * 96);
1493   if (writer) {
1494     JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1495   } else {
1496     size_t ebits, bits;
1497     JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1498     total_bits += bits;
1499   }
1500   if (codes->lz77.enabled) {
1501     if (writer) {
1502       size_t b = writer->BitsWritten();
1503       EncodeUintConfig(codes->lz77.length_uint_config, writer,
1504                        /*log_alpha_size=*/8);
1505       total_bits += writer->BitsWritten() - b;
1506     } else {
1507       SizeWriter size_writer;
1508       EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1509                        /*log_alpha_size=*/8);
1510       total_bits += size_writer.size;
1511     }
1512     num_contexts += 1;
1513     tokens = std::move(tokens_lz77);
1514   }
1515   size_t total_tokens = 0;
1516   // Build histograms.
1517   HistogramBuilder builder(num_contexts);
1518   HybridUintConfig uint_config;  //  Default config for clustering.
1519   // Unless we are using the kContextMap histogram option.
1520   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1521     uint_config = HybridUintConfig(2, 0, 1);
1522   }
1523   if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
1524     uint_config = HybridUintConfig(0, 0, 0);
1525   }
1526   if (ans_fuzzer_friendly_) {
1527     uint_config = HybridUintConfig(10, 0, 0);
1528   }
1529   for (size_t i = 0; i < tokens.size(); ++i) {
1530     if (codes->lz77.enabled) {
1531       for (size_t j = 0; j < tokens[i].size(); ++j) {
1532         const Token& token = tokens[i][j];
1533         total_tokens++;
1534         uint32_t tok, nbits, bits;
1535         (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1536             .Encode(token.value, &tok, &nbits, &bits);
1537         tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1538         builder.VisitSymbol(tok, token.context);
1539       }
1540     } else if (num_contexts == 1) {
1541       for (size_t j = 0; j < tokens[i].size(); ++j) {
1542         const Token& token = tokens[i][j];
1543         total_tokens++;
1544         uint32_t tok, nbits, bits;
1545         uint_config.Encode(token.value, &tok, &nbits, &bits);
1546         builder.VisitSymbol(tok, /*token.context=*/0);
1547       }
1548     } else {
1549       for (size_t j = 0; j < tokens[i].size(); ++j) {
1550         const Token& token = tokens[i][j];
1551         total_tokens++;
1552         uint32_t tok, nbits, bits;
1553         uint_config.Encode(token.value, &tok, &nbits, &bits);
1554         builder.VisitSymbol(tok, token.context);
1555       }
1556     }
1557   }
1558
1559   bool use_prefix_code =
1560       params.force_huffman || total_tokens < 100 ||
1561       params.clustering == HistogramParams::ClusteringType::kFastest ||
1562       ans_fuzzer_friendly_;
1563   if (!use_prefix_code) {
1564     bool all_singleton = true;
1565     for (size_t i = 0; i < num_contexts; i++) {
1566       if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1567         all_singleton = false;
1568       }
1569     }
1570     if (all_singleton) {
1571       use_prefix_code = true;
1572     }
1573   }
1574
1575   // Encode histograms.
1576   total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1577                                                   context_map, use_prefix_code,
1578                                                   writer, layer, aux_out);
1579   allotment.FinishedHistogram(writer);
1580   allotment.ReclaimAndCharge(writer, layer, aux_out);
1581
1582   if (aux_out != nullptr) {
1583     aux_out->layers[layer].num_clustered_histograms +=
1584         codes->encoding_info.size();
1585   }
1586   return total_bits;
1587 }
1588
1589 size_t WriteTokens(const std::vector<Token>& tokens,
1590                    const EntropyEncodingData& codes,
1591                    const std::vector<uint8_t>& context_map, BitWriter* writer) {
1592   size_t num_extra_bits = 0;
1593   if (codes.use_prefix_code) {
1594     for (size_t i = 0; i < tokens.size(); i++) {
1595       uint32_t tok, nbits, bits;
1596       const Token& token = tokens[i];
1597       size_t histo = context_map[token.context];
1598       (token.is_lz77_length ? codes.lz77.length_uint_config
1599                             : codes.uint_config[histo])
1600           .Encode(token.value, &tok, &nbits, &bits);
1601       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1602       // Combine two calls to the BitWriter. Equivalent to:
1603       // writer->Write(codes.encoding_info[histo][tok].depth,
1604       //               codes.encoding_info[histo][tok].bits);
1605       // writer->Write(nbits, bits);
1606       uint64_t data = codes.encoding_info[histo][tok].bits;
1607       data |= bits << codes.encoding_info[histo][tok].depth;
1608       writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1609       num_extra_bits += nbits;
1610     }
1611     return num_extra_bits;
1612   }
1613   std::vector<uint64_t> out;
1614   std::vector<uint8_t> out_nbits;
1615   out.reserve(tokens.size());
1616   out_nbits.reserve(tokens.size());
1617   uint64_t allbits = 0;
1618   size_t numallbits = 0;
1619   // Writes in *reversed* order.
1620   auto addbits = [&](size_t bits, size_t nbits) {
1621     if (JXL_UNLIKELY(nbits)) {
1622       JXL_DASSERT(bits >> nbits == 0);
1623       if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1624         out.push_back(allbits);
1625         out_nbits.push_back(numallbits);
1626         numallbits = allbits = 0;
1627       }
1628       allbits <<= nbits;
1629       allbits |= bits;
1630       numallbits += nbits;
1631     }
1632   };
1633   const int end = tokens.size();
1634   ANSCoder ans;
1635   if (codes.lz77.enabled || context_map.size() > 1) {
1636     for (int i = end - 1; i >= 0; --i) {
1637       const Token token = tokens[i];
1638       const uint8_t histo = context_map[token.context];
1639       uint32_t tok, nbits, bits;
1640       (token.is_lz77_length ? codes.lz77.length_uint_config
1641                             : codes.uint_config[histo])
1642           .Encode(tokens[i].value, &tok, &nbits, &bits);
1643       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1644       const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1645       // Extra bits first as this is reversed.
1646       addbits(bits, nbits);
1647       num_extra_bits += nbits;
1648       uint8_t ans_nbits = 0;
1649       uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1650       addbits(ans_bits, ans_nbits);
1651     }
1652   } else {
1653     for (int i = end - 1; i >= 0; --i) {
1654       uint32_t tok, nbits, bits;
1655       codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
1656       const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
1657       // Extra bits first as this is reversed.
1658       addbits(bits, nbits);
1659       num_extra_bits += nbits;
1660       uint8_t ans_nbits = 0;
1661       uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1662       addbits(ans_bits, ans_nbits);
1663     }
1664   }
1665   const uint32_t state = ans.GetState();
1666   writer->Write(32, state);
1667   writer->Write(numallbits, allbits);
1668   for (int i = out.size(); i > 0; --i) {
1669     writer->Write(out_nbits[i - 1], out[i - 1]);
1670   }
1671   return num_extra_bits;
1672 }
1673
1674 void WriteTokens(const std::vector<Token>& tokens,
1675                  const EntropyEncodingData& codes,
1676                  const std::vector<uint8_t>& context_map, BitWriter* writer,
1677                  size_t layer, AuxOut* aux_out) {
1678   BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1679   size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1680   allotment.ReclaimAndCharge(writer, layer, aux_out);
1681   if (aux_out != nullptr) {
1682     aux_out->layers[layer].extra_bits += num_extra_bits;
1683   }
1684 }
1685
1686 void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
1687 #if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
1688   ans_fuzzer_friendly_ = ans_fuzzer_friendly;
1689 #endif
1690 }
1691 }  // namespace jxl