1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
6 #include "lib/jxl/enc_ans.h"
15 #include <type_traits>
16 #include <unordered_map>
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/base/bits.h"
22 #include "lib/jxl/base/fast_math-inl.h"
23 #include "lib/jxl/dec_ans.h"
24 #include "lib/jxl/enc_aux_out.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_fields.h"
28 #include "lib/jxl/enc_huffman.h"
29 #include "lib/jxl/fields.h"
#if !JXL_IS_DEBUG_BUILD
// NOTE(review): matching #endif is not visible in this excerpt — confirm.
// When set, histogram/uint-config choices are simplified so fuzzers get
// deterministic, easy-to-decode streams (read in BuildAndStoreEntropyCodes).
bool ans_fuzzer_friendly_ = false;

// Streams with at most this many distinct symbols take the "small code"
// histogram path (see NormalizeCounts / EncodeCounts).
static const int kMaxNumSymbolsForSmallCode = 4;
// Fills `info[s]` (freq_, optionally ifreq_, reverse_map_) for each symbol
// from the normalized `counts` and the alias `table`.
// NOTE(review): several lines of this function are elided in this excerpt
// (e.g. the #ifdef branch bodies and closing braces).
void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
                       size_t alphabet_size, size_t log_alpha_size,
                       ANSEncSymbolInfo* info) {
  // Each alias-table bucket spans 2^(ANS_LOG_TAB_SIZE - log_alpha_size) slots.
  size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
  size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
  // create valid alias table for empty streams.
  for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
    // Empty alphabet: pretend a single symbol owns the whole table.
    const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
    info[s].freq_ = static_cast<uint16_t>(freq);
#ifdef USE_MULT_BY_RECIPROCAL
    // Ceiling reciprocal so multiplication can replace division by freq_.
        ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
    info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
    info[s].reverse_map_.resize(freq);
  // Invert the alias table: for every table slot, record it in the owning
  // symbol's reverse map.
  for (int i = 0; i < ANS_TAB_SIZE; i++) {
    AliasTable::Symbol s =
        AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
    info[s.value].reverse_map_[s.offset] = i;
// Estimates, in bits, the cost of coding data distributed as `histogram`
// using the quantized distribution `counts` (sum expected to be ANS_TAB_SIZE).
// NOTE(review): parts of this function (declarations, accumulation into the
// result, closing braces) are elided in this excerpt.
float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
  int total_histogram = 0;
  for (size_t i = 0; i < len; ++i) {
    total_histogram += histogram[i];
    total_counts += counts[i];
    if (histogram[i] > 0) {
      // A symbol that occurs must have a nonzero quantized count.
      JXL_ASSERT(counts[i] > 0);
      // += histogram[i] * -log(counts[i]/total_counts)
          std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
  if (total_histogram > 0) {
    // Used only in assert.
    JXL_ASSERT(total_counts == ANS_TAB_SIZE);
91 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
92 const float flat_bits = std::max(FastLog2f(len), 0.0f);
93 float total_histogram = 0;
94 for (size_t i = 0; i < len; ++i) {
95 total_histogram += histogram[i];
97 return total_histogram * flat_bits;
// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// marker. NOTE(review): the closing `};` of both initializers is elided in
// this excerpt. Code lengths per logcount value:
static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
// Code words paired 1:1 with kLogCountBitLengths above.
static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
109 // Returns the difference between largest count that can be represented and is
110 // smaller than "count" and smallest representable count larger than "count".
111 static int SmallestIncrement(uint32_t count, uint32_t shift) {
112 int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
113 int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
114 return drop_bits < 0 ? 1 : (1 << drop_bits);
// Quantizes `targets` (ideal fractional counts summing to table_size) into
// integer `counts` that are representable under `shift`, selecting one
// position (`omit_pos`) that absorbs the rounding remainder. Returns false if
// the remainder position would end up non-positive.
// NOTE(review): several declarations, branch bodies and closing braces are
// elided in this excerpt; statement order here is load-bearing.
template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  float sum_nonrounded = 0.0;
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  // First pass: symbols with fractional target < 1 (rounded up elsewhere).
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      sum_nonrounded += targets[n];
  // Scale remaining targets down so the overall sum stays near table_size.
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ASSERT(discount_ratio > 0);
  JXL_ASSERT(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
      if (counts[n] == 0 ||
          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
      // Track the symbol with the largest count; it will absorb the
      // remainder below.
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_log = count_log;
  JXL_ASSERT(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
// Normalizes `counts` so they sum to 1 << precision_bits, recording up to
// kMaxNumSymbolsForSmallCode nonzero symbols in `symbols`/`num_symbols` and
// the remainder position in `omit_pos`. Delegates rounding to
// RebalanceHistogram, retrying with the sum-error-minimizing variant.
// NOTE(review): several lines (totals, early returns, call arguments,
// closing braces) are elided in this excerpt.
Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
                       const int precision_bits, uint32_t shift,
                       int* num_symbols, int* symbols) {
  const int32_t table_size = 1 << precision_bits;  // target sum / table size
  int symbol_count = 0;
  // Collect the indices of nonzero entries (first few only).
  for (int n = 0; n < length; ++n) {
    if (symbol_count < kMaxNumSymbolsForSmallCode) {
      symbols[symbol_count] = n;
  *num_symbols = symbol_count;
  if (symbol_count == 0) {
  if (symbol_count == 1) {
    // Degenerate histogram: the single symbol takes the whole table.
    counts[symbols[0]] = table_size;
  if (symbol_count > table_size)
    return JXL_FAILURE("Too many entries in an ANS histogram");
  const float norm = 1.f * table_size / total;
  std::vector<float> targets(max_symbol);
  for (size_t n = 0; n < targets.size(); ++n) {
    targets[n] = norm * counts[n];
  if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
  // Use an alternative rebalancing mechanism if the one above failed
  // to create a histogram that is positive wherever the original one was.
  if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
    return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
220 void Write(size_t num, size_t bits) { size += num; }
223 template <typename Writer>
224 void StoreVarLenUint8(size_t n, Writer* writer) {
225 JXL_DASSERT(n <= 255);
230 size_t nbits = FloorLog2Nonzero(n);
231 writer->Write(3, nbits);
232 writer->Write(nbits, n - (1ULL << nbits));
236 template <typename Writer>
237 void StoreVarLenUint16(size_t n, Writer* writer) {
238 JXL_DASSERT(n <= 65535);
243 size_t nbits = FloorLog2Nonzero(n);
244 writer->Write(4, nbits);
245 writer->Write(nbits, n - (1ULL << nbits));
// Serializes a normalized histogram (`counts`) to `writer`: a small-tree path
// for <= 2 symbols, otherwise RLE-compressed logcounts with a static Huffman
// code, followed by the residual count bits. Returns success.
// NOTE(review): many lines (markers, else-branches, loop bodies, closing
// braces) are elided in this excerpt.
template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    if (num_symbols == 0) {
      StoreVarLenUint8(0, writer);
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
    if (num_symbols == 2) {
      // Only the first count is stored; the second is implied by the sum.
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    // Mark non-small tree.
    // Mark non-flat histogram.
    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    std::vector<uint32_t> same(alphabet_size, 0);
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
    std::vector<int> logcounts(alphabet_size);
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
      JXL_ASSERT(counts[i] >= 0);
      } else if (counts[i] > 0) {
        // logcount = FloorLog2(count) + 1; 0 is reserved for count == 0.
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
          omit_log = std::max(omit_log, logcounts[i] + 1);
          omit_log = std::max(omit_log, logcounts[i]);
    logcounts[omit_pos] = omit_log;
    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    writer->Write(log, ((1 << log) - 1) & (shift + 1));
    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      StoreVarLenUint8(length - 3, writer);
    // The logcount values are encoded with a static Huffman code.
    static const size_t kMinReps = 4;
    size_t rep = ANS_LOG_TAB_SIZE + 1;
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        i += same[i - 1] - 2;
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    // Second pass: raw residual bits below each logcount's leading bit.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
      if (logcounts[i] > 1 && i != omit_pos) {
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
// Signals a flat (uniform) histogram over `alphabet_size` symbols.
// NOTE(review): the writer->Write calls for the two markers below are elided
// in this excerpt.
void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
  // Mark non-small tree.
  // Mark uniform histogram.
  JXL_ASSERT(alphabet_size > 0);
  // Encode alphabet size.
  StoreVarLenUint8(alphabet_size - 1, writer);
// Estimates the total bits for storing the histogram with quantization
// `method` (0 = flat code, otherwise shift = method - 1) plus the data coded
// with it. Uses a SizeWriter — nothing is actually emitted.
// NOTE(review): parts of the signature, local declarations and the return
// statement are elided in this excerpt.
float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
  if (method == 0) {  // Flat code
    return ANS_LOG_TAB_SIZE + 2 +
           EstimateDataBitsFlat(histogram, alphabet_size);
  // Non-flat: shift = method-1.
  uint32_t shift = method - 1;
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
  int symbols[kMaxNumSymbolsForSmallCode] = {};
  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
  // Ignore the correctness, no real encoding happens at this stage.
  (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
         EstimateDataBits(histogram, counts.data(), alphabet_size);
// Picks the histogram quantization method (0 = flat, otherwise shift + 1)
// with the lowest estimated histogram+data cost, per the given strategy:
// kPrecise tries every shift, kApproximate every other shift, kFast a fixed
// pair. Writes the winning cost to *cost and returns the method.
// NOTE(review): the lambda body, best-tracking and return are elided in this
// excerpt.
uint32_t ComputeBestMethod(
    const ANSHistBin* histogram, size_t alphabet_size, float* cost,
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
  float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
  auto try_shift = [&](size_t shift) {
    float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
  switch (ans_histogram_strategy) {
    case HistogramParams::ANSHistogramStrategy::kPrecise: {
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
    case HistogramParams::ANSHistogramStrategy::kApproximate: {
      for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
    case HistogramParams::ANSHistogramStrategy::kFast: {
      try_shift(ANS_LOG_TAB_SIZE / 2);
      try_shift(ANS_LOG_TAB_SIZE);
// Returns an estimate of the cost of encoding this histogram and the
// corresponding data.
// Builds the encoder tables (Huffman or ANS alias tables) into `info` and,
// when `writer` is non-null, serializes the histogram. With a null writer the
// encoding is only simulated to measure cost.
// NOTE(review): numerous lines (else keywords, declarations, returns,
// closing braces) are elided in this excerpt.
size_t BuildAndStoreANSEncodingData(
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
    const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
    bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
  if (use_prefix_code) {
    // Prefix (Huffman) path.
    if (alphabet_size <= 1) return 0;
    std::vector<uint32_t> histo(alphabet_size);
    for (size_t i = 0; i < alphabet_size; i++) {
      histo[i] = histogram[i];
      JXL_CHECK(histogram[i] >= 0);
    std::vector<uint8_t> depths(alphabet_size);
    std::vector<uint16_t> bits(alphabet_size);
    if (writer == nullptr) {
      // Measure-only: emit into a throwaway writer to learn the header size.
      BitWriter tmp_writer;
      BitWriter::Allotment allotment(
          &tmp_writer, 8 * alphabet_size + 8);  // safe upper bound
      BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
                               bits.data(), &tmp_writer);
      allotment.ReclaimAndCharge(&tmp_writer, 0, /*aux_out=*/nullptr);
      cost = tmp_writer.BitsWritten();
      size_t start = writer->BitsWritten();
      BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
                               bits.data(), writer);
      cost = writer->BitsWritten() - start;
    for (size_t i = 0; i < alphabet_size; i++) {
      info[i].bits = depths[i] == 0 ? 0 : bits[i];
      info[i].depth = depths[i];
    // Estimate data cost.
    for (size_t i = 0; i < alphabet_size; i++) {
      cost += histogram[i] * info[i].depth;
  // ANS path.
  JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
  // Ensure we ignore trailing zeros in the histogram.
  if (alphabet_size != 0) {
    size_t largest_symbol = 0;
    for (size_t i = 0; i < alphabet_size; i++) {
      if (histogram[i] != 0) largest_symbol = i;
    alphabet_size = largest_symbol + 1;
  uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
                                      ans_histogram_strategy);
  JXL_ASSERT(cost >= 0);
  int symbols[kMaxNumSymbolsForSmallCode] = {};
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
  if (!counts.empty()) {
    for (size_t i = 0; i < counts.size(); i++) {
    counts[0] = ANS_TAB_SIZE;
  // method == 0: flat histogram.
    counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
    AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
    InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
    ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
    if (writer != nullptr) {
      EncodeFlatHistogram(alphabet_size, writer);
  uint32_t shift = method - 1;
  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
  InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
  ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
  if (writer != nullptr) {
    bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
                           shift, symbols, writer);
// Estimates the entropy-coded size, in bits, of `data` using the fast
// method-selection strategy.
// NOTE(review): the declaration of the cost accumulator and the return
// statement are elided in this excerpt.
float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
  ComputeBestMethod(data, alphabet_size, &c,
                    HistogramParams::ANSHistogramStrategy::kFast);
533 template <typename Writer>
534 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
535 size_t log_alpha_size) {
536 writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
537 uint_config.split_exponent);
538 if (uint_config.split_exponent == log_alpha_size) {
539 return; // msb/lsb don't matter.
541 size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
542 writer->Write(nbits, uint_config.msb_in_token);
543 nbits = CeilLog2Nonzero(uint_config.split_exponent -
544 uint_config.msb_in_token + 1);
545 writer->Write(nbits, uint_config.lsb_in_token);
547 template <typename Writer>
548 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
549 Writer* writer, size_t log_alpha_size) {
550 // TODO(veluca): RLE?
551 for (size_t i = 0; i < uint_config.size(); i++) {
552 EncodeUintConfig(uint_config[i], writer, log_alpha_size);
555 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
// Chooses a HybridUintConfig per clustered histogram: either a fixed choice
// (kNone/k000/kContextMap) or, for kBest/kFast, a brute-force search over
// candidate configs minimizing histogram + extra-bit cost. Finally rebuilds
// the clustered histograms with the chosen configs and grows *log_alpha_size
// to fit the largest produced token.
// NOTE(review): early returns, `configs = {...}` assignments, some
// declarations and closing braces are elided in this excerpt.
void ChooseUintConfigs(const HistogramParams& params,
                       const std::vector<std::vector<Token>>& tokens,
                       const std::vector<uint8_t>& context_map,
                       std::vector<Histogram>* clustered_histograms,
                       EntropyEncodingData* codes, size_t* log_alpha_size) {
  codes->uint_config.resize(clustered_histograms->size());
  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
  if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
    codes->uint_config.clear();
    codes->uint_config.resize(clustered_histograms->size(),
                              HybridUintConfig(0, 0, 0));
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    codes->uint_config.clear();
    codes->uint_config.resize(clustered_histograms->size(),
                              HybridUintConfig(2, 0, 1));
  // Brute-force method that tries a few options.
  std::vector<HybridUintConfig> configs;
  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 0),  // less precise
        HybridUintConfig(4, 2, 1),  // add sign
        HybridUintConfig(4, 2, 2),  // add sign+parity
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        // Same as above, but more direct coding.
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
        HybridUintConfig(5, 1, 2),
        // Same as above, but less direct coding.
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
        // For near-lossless.
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
        HybridUintConfig(0, 0, 0),   // varlenuint
        HybridUintConfig(2, 0, 1),   // works well for ctx map
        HybridUintConfig(7, 0, 0),   // direct coding
        HybridUintConfig(8, 0, 0),   // direct coding
        HybridUintConfig(9, 0, 0),   // direct coding
        HybridUintConfig(10, 0, 0),  // direct coding
        HybridUintConfig(11, 0, 0),  // direct coding
        HybridUintConfig(12, 0, 0),  // direct coding
  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        HybridUintConfig(0, 0, 0),  // smallest histograms
        HybridUintConfig(2, 0, 1),  // works well for ctx map
  std::vector<float> costs(clustered_histograms->size(),
                           std::numeric_limits<float>::max());
  std::vector<uint32_t> extra_bits(clustered_histograms->size());
  std::vector<uint8_t> is_valid(clustered_histograms->size());
      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
  for (HybridUintConfig cfg : configs) {
    std::fill(is_valid.begin(), is_valid.end(), true);
    std::fill(extra_bits.begin(), extra_bits.end(), 0);
    // Simulate tokenization under `cfg` to accumulate per-histogram costs.
    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      (*clustered_histograms)[i].Clear();
    for (size_t i = 0; i < tokens.size(); ++i) {
      for (size_t j = 0; j < tokens[i].size(); ++j) {
        const Token token = tokens[i][j];
        // TODO(veluca): do not ignore lz77 commands.
        if (token.is_lz77_length) continue;
        size_t histo = context_map[token.context];
        uint32_t tok, nbits, bits;
        cfg.Encode(token.value, &tok, &nbits, &bits);
        // Reject configs whose tokens overflow the alphabet or collide with
        // the LZ77 symbol range.
        if (tok >= max_alpha ||
            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
          is_valid[histo] = false;
        extra_bits[histo] += nbits;
        (*clustered_histograms)[histo].Add(tok);
    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      if (!is_valid[i]) continue;
      float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
      // add signaling cost of the hybriduintconfig itself
      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
      if (cost < costs[i]) {
        codes->uint_config[i] = cfg;
  // Rebuild histograms.
  for (size_t i = 0; i < clustered_histograms->size(); i++) {
    (*clustered_histograms)[i].Clear();
  for (size_t i = 0; i < tokens.size(); ++i) {
    for (size_t j = 0; j < tokens[i].size(); ++j) {
      const Token token = tokens[i][j];
      uint32_t tok, nbits, bits;
      size_t histo = context_map[token.context];
      (token.is_lz77_length ? codes->lz77.length_uint_config
                            : codes->uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
      (*clustered_histograms)[histo].Add(tok);
      // Grow the alphabet-size log until the largest token fits.
      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
#if JXL_ENABLE_ASSERT
  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
  JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
// Accumulates per-context symbol histograms and turns them into clustered,
// serialized entropy codes.
// NOTE(review): access specifiers, some declarations, else-branches and
// closing braces are elided in this excerpt.
class HistogramBuilder {
  explicit HistogramBuilder(const size_t num_contexts)
      : histograms_(num_contexts) {}

  // Records one symbol occurrence in the histogram for `histo_idx`.
  void VisitSymbol(int symbol, size_t histo_idx) {
    JXL_DASSERT(histo_idx < histograms_.size());
    histograms_[histo_idx].Add(symbol);

  // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
  // Clusters the histograms, writes the context map and per-cluster codes to
  // `writer` (or only measures size when writer == nullptr); returns the
  // estimated bit cost.
  size_t BuildAndStoreEntropyCodes(
      const HistogramParams& params,
      const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
      std::vector<uint8_t>* context_map, bool use_prefix_code,
      BitWriter* writer, size_t layer, AuxOut* aux_out) const {
    codes->encoding_info.clear();
    std::vector<Histogram> clustered_histograms(histograms_);
    context_map->resize(histograms_.size());
    if (histograms_.size() > 1) {
      if (!ans_fuzzer_friendly_) {
        std::vector<uint32_t> histogram_symbols;
        ClusterHistograms(params, histograms_, kClustersLimit,
                          &clustered_histograms, &histogram_symbols);
        for (size_t c = 0; c < histograms_.size(); ++c) {
          (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
        // Fuzzer-friendly: one cluster containing every possible symbol.
        fill(context_map->begin(), context_map->end(), 0);
        size_t max_symbol = 0;
        for (const Histogram& h : histograms_) {
          max_symbol = std::max(h.data_.size(), max_symbol);
        size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
        clustered_histograms.resize(1);
        clustered_histograms[0].Clear();
        for (size_t i = 0; i < num_symbols; i++) {
          clustered_histograms[0].Add(i);
    if (writer != nullptr) {
      EncodeContextMap(*context_map, clustered_histograms.size(), writer,
    if (aux_out != nullptr) {
      for (size_t i = 0; i < clustered_histograms.size(); ++i) {
        aux_out->layers[layer].clustered_entropy +=
            clustered_histograms[i].ShannonEntropy();
    codes->use_prefix_code = use_prefix_code;
    size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
    if (ans_fuzzer_friendly_) {
      codes->uint_config.clear();
      codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
      ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
                        codes, &log_alpha_size);
    if (log_alpha_size < 5) log_alpha_size = 5;
    SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
    if (writer) writer->Write(1, use_prefix_code);
    if (use_prefix_code) {
      log_alpha_size = PREFIX_MAX_BITS;
    if (writer == nullptr) {
      EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
      if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
      EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
    if (use_prefix_code) {
      // Prefix codes additionally signal each cluster's alphabet size.
      for (size_t c = 0; c < clustered_histograms.size(); ++c) {
        size_t num_symbol = 1;
        for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
          if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
        StoreVarLenUint16(num_symbol - 1, writer);
          StoreVarLenUint16(num_symbol - 1, &size_writer);
    cost += size_writer.size;
    for (size_t c = 0; c < clustered_histograms.size(); ++c) {
      size_t num_symbol = 1;
      for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
        if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
      codes->encoding_info.emplace_back();
      codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
      BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
      cost += BuildAndStoreANSEncodingData(
          params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
          num_symbol, log_alpha_size, use_prefix_code,
          codes->encoding_info.back().data(), writer);
      allotment.FinishedHistogram(writer);
      allotment.ReclaimAndCharge(writer, layer, aux_out);

  // Read-only access to the histogram of context `i`.
  const Histogram& Histo(size_t i) const { return histograms_[i]; }

  std::vector<Histogram> histograms_;
// Precomputes per-context, per-symbol bit costs from token statistics, used
// by the LZ77 heuristics to compare literal vs. match coding.
// NOTE(review): access specifiers, some declarations and closing braces are
// elided in this excerpt.
class SymbolCostEstimator {
  SymbolCostEstimator(size_t num_contexts, bool force_huffman,
                      const std::vector<std::vector<Token>>& tokens,
                      const LZ77Params& lz77) {
    HistogramBuilder builder(num_contexts);
    // Build histograms for estimating lz77 savings.
    HybridUintConfig uint_config;
    for (size_t i = 0; i < tokens.size(); ++i) {
      for (size_t j = 0; j < tokens[i].size(); ++j) {
        const Token token = tokens[i][j];
        uint32_t tok, nbits, bits;
        (token.is_lz77_length ? lz77.length_uint_config : uint_config)
            .Encode(token.value, &tok, &nbits, &bits);
        tok += token.is_lz77_length ? lz77.min_symbol : 0;
        builder.VisitSymbol(tok, token.context);
    max_alphabet_size_ = 0;
    for (size_t i = 0; i < num_contexts; i++) {
          std::max(max_alphabet_size_, builder.Histo(i).data_.size());
    bits_.resize(num_contexts * max_alphabet_size_);
    // TODO(veluca): SIMD?
    add_symbol_cost_.resize(num_contexts);
    for (size_t i = 0; i < num_contexts; i++) {
      float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
      float total_cost = 0;
      for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
        size_t cnt = builder.Histo(i).data_[j];
        // cost = -log2(P(symbol)); ceil'd when modeling a Huffman coder.
        if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
          cost = -FastLog2f(cnt * inv_total);
          if (force_huffman) cost = std::ceil(cost);
        } else if (cnt == 0) {
          cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
        bits_[i * max_alphabet_size_ + j] = cost;
        total_cost += cost * builder.Histo(i).data_[j];
      // Penalty for adding a lz77 symbol to this contest (only used for static
      // cost model). Higher penalty for contexts that have a very low
      // per-symbol entropy.
      add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);

  // Estimated bits to code `sym` in context `ctx`.
  float Bits(size_t ctx, size_t sym) const {
    return bits_[ctx * max_alphabet_size_ + sym];
  // Estimated bits to code an LZ77 length token of value `len`.
  float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
    uint32_t nbits, bits, tok;
    lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
    tok += lz77.min_symbol;
    return nbits + Bits(ctx, tok);
  // Estimated bits to code an LZ77 distance token of value `len`.
  float DistCost(size_t len, const LZ77Params& lz77) const {
    uint32_t nbits, bits, tok;
    HybridUintConfig().Encode(len, &tok, &nbits, &bits);
    return nbits + Bits(lz77.nonserialized_distance_context, tok);
  float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }

  size_t max_alphabet_size_;
  std::vector<float> bits_;
  std::vector<float> add_symbol_cost_;
// Rewrites `tokens` into `tokens_lz77`, replacing runs of identical values
// with LZ77 length+distance tokens (distance 1, i.e. RLE) when the estimated
// cost is lower; commits the result only if total savings exceed a threshold.
// NOTE(review): this excerpt elides part of the parameter list (an
// `lz77` parameter is referenced below but not visible in the signature),
// plus several branch bodies and closing braces.
void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
                   const std::vector<std::vector<Token>>& tokens,
                   std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  std::vector<float> sym_cost;
  HybridUintConfig uint_config;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    out.reserve(in.size());
    for (size_t i = 0; i < in.size(); i++) {
      size_t num_to_copy = 0;
      size_t distance_symbol = 0;  // 1 for RLE.
      if (distance_multiplier != 0) {
        distance_symbol = 1;  // Special distance 1 if enabled.
        JXL_DASSERT(kSpecialDistances[1][0] == 1);
        JXL_DASSERT(kSpecialDistances[1][1] == 0);
        // Extend the run while values repeat the previous one.
        for (; i + num_to_copy < in.size(); num_to_copy++) {
          if (in[i + num_to_copy].value != in[i - 1].value) {
      if (num_to_copy == 0) {
        out.push_back(in[i]);
      float cost = sym_cost[i + num_to_copy] - sym_cost[i];
      // This subtraction might overflow, but that's OK.
      size_t lz77_len = num_to_copy - lz77.min_length;
      float lz77_cost = num_to_copy >= lz77.min_length
                            ? CeilLog2Nonzero(lz77_len + 1) + 1
      if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
        // Literal copy is cheaper: emit the run verbatim.
        for (size_t j = 0; j < num_to_copy; j++) {
          out.push_back(in[i + j]);
        i += num_to_copy - 1;
        // Output the LZ77 length
        out.emplace_back(in[i].context, lz77_len);
        out.back().is_lz77_length = true;
        i += num_to_copy - 1;
        bit_decrease += cost - lz77_cost;
        // Output the LZ77 copy distance.
        out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
  // Keep the LZ77 version only if the estimated savings are significant.
  if (bit_decrease > total_symbols * 0.2 + 16) {
// Hash chain for LZ77 matching
// NOTE(review): the enclosing struct/class declaration, the member-initializer
// list head, some loop bodies and closing braces are elided in this excerpt.
  // Token values being matched (copied from the Token stream).
  std::vector<uint32_t> data_;
  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;
  // head[hash] -> most recent window position with that hash; chain links
  // positions sharing a hash; val caches each position's hash.
  std::vector<int> head;
  std::vector<uint32_t> chain;
  std::vector<int> val;
  // Speed up repetitions of zero
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  uint32_t numzeros = 0;
  // Map of special distance codes.
  std::unordered_map<int, int> special_dist_table_;
  size_t num_special_distances_ = 0;
  // Cap on chain traversal length in FindMatches.
  uint32_t maxchainlength = 256;  // window_size_ to allow all
  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    // Copy token values into the local buffer used for matching.
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    zeros.resize(window_size_);
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        int xi = kSpecialDistances[i][0];
        int yi = kSpecialDistances[i][1];
        int distance = yi * distance_multiplier + xi;
        // Ensure that we map distance 1 to the lowest symbols.
        if (distance < 1) distance = 1;
        special_dist_table_[distance] = i;
      num_special_distances_ = kNumSpecialDistances;
  // Hashes the 3 values at `pos` (0 when fewer than 3 remain, since a
  // length-2 match is too short to matter). Result fits in hash_mask_.
  // NOTE(review): closing braces of this method are elided in this excerpt.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= (uint32_t)(data_[pos + 0] << 0u);
      result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
      result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
      // No need to compute hash of last 2 bytes, the length 2 is too short.
    return result & hash_mask_;
// Counts the run of zero-valued tokens starting at `pos`, capped at
// window_size_. `prevzeros` is the count from the previous position so a
// long saturated run can be answered incrementally (prevzeros - 1) instead
// of rescanning — this keeps long zero runs O(1) per position.
// NOTE(review): the return statements after the scan loop are elided here.
1031 uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
1032 size_t end = pos + window_size_;
1033 if (end > size_) end = size_;
1034 if (prevzeros > 0) {
1035 if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
1036 end == pos + window_size_) {
// Still saturated: the run merely slides forward by one.
1039 return prevzeros - 1;
1043 while (pos + num < end && data_[pos + num] == 0) num++;
// Inserts position `pos` into both the hash chain (by 3-token hash) and the
// zeros chain (by current zero-run length). `numzeros` is presumably a
// member carrying the running zero count between calls — its
// declaration/reset line is elided in this excerpt; confirm in full source.
1047 void Update(size_t pos) {
1048 uint32_t hashval = GetHash(pos);
1049 uint32_t wpos = pos & window_mask_;
// Link wpos at the front of the chain for hashval.
1051 val[wpos] = (int)hashval;
1052 if (head[hashval] != -1) chain[wpos] = head[hashval];
1053 head[hashval] = wpos;
// A value change breaks any zero run.
1055 if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
1056 numzeros = CountZeros(pos, numzeros);
// Same front-insertion, keyed by zero-run length.
1058 zeros[wpos] = numzeros;
1059 if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
1060 headz[numzeros] = wpos;
// Bulk variant: registers `len` consecutive positions starting at `pos`
// (the per-position Update call inside the loop is elided in this excerpt).
1063 void Update(size_t pos, size_t len) {
1064 for (size_t i = 0; i < len; i++) {
// Walks the hash chain (and, for zero runs, the zeros chain) backwards from
// `pos`, invoking found_match(len, dist_symbol) for each candidate whose
// length is >= min_length_ and within 2 of the best so far (a slightly
// shorter match may still win via a cheaper distance code). Distances found
// in special_dist_table_ use their special code; others map to
// num_special_distances_ + dist - 1.
// NOTE(review): several lines (loop header, numzeros setup, early-exit
// branches) are elided in this excerpt — treat control flow as indicative.
1069 template <typename CB>
1070 void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
1071 uint32_t wpos = pos & window_mask_;
1072 uint32_t hashval = GetHash(pos);
1073 uint32_t hashpos = chain[wpos];
1076 int end = std::min<int>(pos + max_length_, size_);
1077 uint32_t chainlength = 0;
1078 uint32_t best_len = 0;
// Recover the ring-buffer distance; wrap when hashpos is ahead of wpos.
1080 int dist = (hashpos <= wpos) ? (wpos - hashpos)
1081 : (wpos - hashpos + window_mask_ + 1);
1082 if (dist < prev_dist) break;
// Skip ahead over a shared zero prefix using the precomputed run lengths.
1089 int r = std::min<int>(numzeros - 1, zeros[hashpos]);
1090 if (i + r >= end) r = end - i - 1;
1094 while (i < end && data_[i] == data_[j]) {
1099 // This can trigger even if the new length is slightly smaller than the
1100 // best length, because it is possible for a slightly cheaper distance
1102 if (len >= min_length_ && len + 2 >= best_len) {
1103 auto it = special_dist_table_.find(dist);
1104 int dist_symbol = (it == special_dist_table_.end())
1105 ? (num_special_distances_ + dist - 1)
1107 found_match(len, dist_symbol);
1108 if (len > best_len) best_len = len;
// Bail out once the chain walk exceeds the configured budget.
1113 if (chainlength >= maxchainlength) break;
// Prefer the zeros chain while inside a long zero run.
1115 if (numzeros >= 3 && len > numzeros) {
1116 if (hashpos == chainz[hashpos]) break;
1117 hashpos = chainz[hashpos];
1118 if (zeros[hashpos] != numzeros) break;
1120 if (hashpos == chain[hashpos]) break;
1121 hashpos = chain[hashpos];
1122 if (val[hashpos] != (int)hashval) break; // outdated hash value
// Convenience wrapper: returns the single best match at `pos` — longest
// length, ties broken by the smaller distance symbol.
// NOTE(review): the *result_len initialization and the lambda's assignment
// to *result_len are elided in this excerpt.
1126 void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
1127 size_t* result_len) const {
1128 *result_dist_symbol = 0;
1130 FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
1131 if (len > *result_len ||
1132 (len == *result_len && *result_dist_symbol > dist_symbol)) {
1134 *result_dist_symbol = dist_symbol;
// Estimated bit cost of encoding an LZ77 length: a fixed per-token cost from
// an empirically tuned table (indexed by the hybrid-uint token) plus the
// raw extra bits. Tokens beyond the table clamp to the last entry.
1140 float LenCost(size_t len) {
1141 uint32_t nbits, bits, tok;
1142 HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1143 constexpr float kCostTable[] = {
1144 2.797667318563126, 3.213177690381199, 2.5706009246743737,
1145 2.408392498667534, 2.829649191872326, 3.3923087753324577,
1146 4.029267451554331, 4.415576699706408, 4.509357574741465,
1147 9.21481543803004, 10.020590190114898, 11.858671627804766,
1148 12.45853300490526, 11.713105831990857, 12.561996324849314,
1149 13.775477692278367, 13.174027068768641,
1151 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1152 if (tok >= table_size) tok = table_size - 1;
1153 return kCostTable[tok] + nbits;
// Estimated bit cost of encoding an LZ77 distance symbol, analogous to
// LenCost but with a larger tuned table (HybridUintConfig(7, 0, 0)).
1156 // TODO(veluca): this does not take into account usage or non-usage of distance
1158 float DistCost(size_t dist) {
1159 uint32_t nbits, bits, tok;
1160 HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1161 constexpr float kCostTable[] = {
1162 6.368282626312716, 5.680793277090298, 8.347404197105247,
1163 7.641619201599141, 6.914328374119438, 7.959808291537444,
1164 8.70023120759855, 8.71378518934703, 9.379132523982769,
1165 9.110472749092708, 9.159029569270908, 9.430936766731973,
1166 7.278284055315169, 7.8278514904267755, 10.026641158289236,
1167 9.976049229827066, 9.64351607048908, 9.563403863480442,
1168 10.171474111762747, 10.45950155077234, 9.994813912104219,
1169 10.322524683741156, 8.465808729388186, 8.756254166066853,
1170 10.160930174662234, 10.247329273413435, 10.04090403724809,
1171 10.129398517544082, 9.342311691539546, 9.07608009102374,
1172 10.104799540677513, 10.378079384990906, 10.165828974075072,
1173 10.337595322341553, 7.940557464567944, 10.575665823319431,
1174 11.023344321751955, 10.736144698831827, 11.118277044595054,
1175 7.468468230648442, 10.738305230932939, 10.906980780216568,
1176 10.163468216353817, 10.17805759656433, 11.167283670483565,
1177 11.147050200274544, 10.517921919244333, 10.651764778156886,
1178 10.17074446448919, 11.217636876224745, 11.261630721139484,
1179 11.403140815247259, 10.892472096873417, 11.1859607804481,
1180 8.017346947551262, 7.895143720278828, 11.036577113822025,
1181 11.170562110315794, 10.326988722591086, 10.40872184751056,
1182 11.213498225466386, 11.30580635516863, 10.672272515665442,
1183 10.768069466228063, 11.145257364153565, 11.64668307145549,
1184 10.593156194627339, 11.207499484844943, 10.767517766396908,
1185 10.826629811407042, 10.737764794499988, 10.6200448518045,
1186 10.191315385198092, 8.468384171390085, 11.731295299170432,
1187 11.824619886654398, 10.41518844301179, 10.16310536548649,
1188 10.539423685097576, 10.495136599328031, 10.469112847728267,
1189 11.72057686174922, 10.910326337834674, 11.378921834673758,
1190 11.847759036098536, 11.92071647623854, 10.810628276345282,
1191 11.008601085273893, 11.910326337834674, 11.949212023423133,
1192 11.298614839104337, 11.611603659010392, 10.472930394619985,
1193 11.835564720850282, 11.523267392285337, 12.01055816679611,
1194 8.413029688994023, 11.895784139536406, 11.984679534970505,
1195 11.220654278717394, 11.716311684833672, 10.61036646226114,
1196 10.89849965960364, 10.203762898863669, 10.997560826267238,
1197 11.484217379438984, 11.792836176993665, 12.24310468755171,
1198 11.464858097919262, 12.212747017409377, 11.425595666074955,
1199 11.572048533398757, 12.742093965163013, 11.381874288645637,
1200 12.191870445817015, 11.683156920035426, 11.152442115262197,
1201 11.90303691580457, 11.653292787169159, 11.938615382266098,
1202 16.970641701570223, 16.853602280380002, 17.26240782594733,
1203 16.644655390108507, 17.14310889757499, 16.910935455445955,
1204 17.505678976959697, 17.213498225466388, 2.4162310293553024,
1205 3.494587244462329, 3.5258600986408344, 3.4959806589517095,
1206 3.098390886949687, 3.343454654302911, 3.588847442290287,
1207 4.14614790111827, 5.152948641990529, 7.433696808092598,
// Out-of-range tokens clamp to the last table entry.
1210 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1211 if (tok >= table_size) tok = table_size - 1;
1212 return kCostTable[tok] + nbits;
// Greedy LZ77 with one-step lazy matching. Per stream: builds a cumulative
// literal-cost prefix sum, then walks the tokens replacing a match with a
// (length, distance) pair only when the estimated LZ77 cost is no larger
// than the literal cost it replaces. Sets lz77.enabled when the total
// estimated saving exceeds total_symbols * 0.2 + 16 bits.
// NOTE(review): several lines (the lz77 parameter in the signature, lazy-
// match comparison, else branches, closing braces) are elided here.
1215 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1216 const std::vector<std::vector<Token>>& tokens,
1218 std::vector<std::vector<Token>>& tokens_lz77) {
1219 // TODO(veluca): tune heuristics here.
1220 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1221 float bit_decrease = 0;
1222 size_t total_symbols = 0;
1223 tokens_lz77.resize(tokens.size());
1224 HybridUintConfig uint_config;
1225 std::vector<float> sym_cost;
1226 for (size_t stream = 0; stream < tokens.size(); stream++) {
// Image width (if known) lets the matcher use 2-D special distances.
1227 size_t distance_multiplier =
1228 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1229 const auto& in = tokens[stream];
1230 auto& out = tokens_lz77[stream];
1231 total_symbols += in.size();
1232 // Cumulative sum of bit costs.
1233 sym_cost.resize(in.size() + 1);
1234 for (size_t i = 0; i < in.size(); i++) {
1235 uint32_t tok, nbits, unused_bits;
1236 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1237 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1240 out.reserve(in.size());
1241 size_t max_distance = in.size();
1242 size_t min_length = lz77.min_length;
1243 JXL_ASSERT(min_length >= 3);
1244 size_t max_length = in.size();
1246 // Use next power of two as window size.
1247 size_t window_size = 1;
1248 while (window_size < max_distance && window_size < kWindowSize) {
1252 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1253 distance_multiplier);
1254 size_t len, dist_symbol;
1256 const size_t max_lazy_match_len = 256; // 0 to disable lazy matching
1258 // Whether the next symbol was already updated (to test lazy matching)
1259 bool already_updated = false;
1260 for (size_t i = 0; i < in.size(); i++) {
1261 out.push_back(in[i]);
1262 if (!already_updated) chain.Update(i);
1263 already_updated = false;
1264 chain.FindMatch(i, max_distance, &dist_symbol, &len);
1265 if (len >= min_length) {
1266 if (len < max_lazy_match_len && i + 1 < in.size()) {
1267 // Try length at next symbol lazy matching
1268 chain.Update(i + 1);
1269 already_updated = true;
1270 size_t len2, dist_symbol2;
1271 chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1273 // Use the lazy match. Add literal, and use the next length starting
1274 // from the next byte.
1276 already_updated = false;
1278 dist_symbol = dist_symbol2;
1279 out.push_back(in[i]);
// Compare literal cost of the covered span vs the match's cost.
1283 float cost = sym_cost[i + len] - sym_cost[i];
1284 size_t lz77_len = len - lz77.min_length;
1285 float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1286 sce.AddSymbolCost(out.back().context);
1288 if (lz77_cost <= cost) {
// Rewrite the just-pushed literal into a length token + distance token.
1289 out.back().value = len - min_length;
1290 out.back().is_lz77_length = true;
1291 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1292 bit_decrease += cost - lz77_cost;
1294 // LZ77 match ignored, and symbol already pushed. Push all other
1295 // symbols and skip.
1296 for (size_t j = 1; j < len; j++) {
1297 out.push_back(in[i + j]);
// Keep the hash chain in sync over the span the match consumed.
1301 if (already_updated) {
1302 chain.Update(i + 2, len - 2);
1303 already_updated = false;
1305 chain.Update(i + 1, len - 1);
1309 // Literal, already pushed
// Enable LZ77 only when the saving clears a fixed threshold.
1314 if (bit_decrease > total_symbols * 0.2 + 16) {
1315 lz77.enabled = true;
// Optimal-parse LZ77 via shortest-path dynamic programming. First runs the
// greedy pass to (a) decide whether LZ77 helps at all and (b) obtain token
// statistics for the cost model; then computes, per position, the cheapest
// way to encode the first N symbols (literal vs every available match
// length), and finally backtracks to emit tokens, reversing at the end.
// NOTE(review): signature's lz77 parameter, MatchInfo definition, and
// several loop/brace lines are elided in this excerpt.
1319 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1320 const std::vector<std::vector<Token>>& tokens,
1322 std::vector<std::vector<Token>>& tokens_lz77) {
1323 std::vector<std::vector<Token>> tokens_for_cost_estimate;
1324 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1325 // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1326 // run the optimal matching.
1327 if (!lz77.enabled) return;
1328 SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1329 tokens_for_cost_estimate, lz77);
1330 tokens_lz77.resize(tokens.size());
1331 HybridUintConfig uint_config;
1332 std::vector<float> sym_cost;
1333 std::vector<uint32_t> dist_symbols;
1334 for (size_t stream = 0; stream < tokens.size(); stream++) {
1335 size_t distance_multiplier =
1336 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1337 const auto& in = tokens[stream];
1338 auto& out = tokens_lz77[stream];
1339 // Cumulative sum of bit costs.
1340 sym_cost.resize(in.size() + 1);
1341 for (size_t i = 0; i < in.size(); i++) {
1342 uint32_t tok, nbits, unused_bits;
1343 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1344 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1347 out.reserve(in.size());
1348 size_t max_distance = in.size();
1349 size_t min_length = lz77.min_length;
1350 JXL_ASSERT(min_length >= 3);
1351 size_t max_length = in.size();
1353 // Use next power of two as window size.
1354 size_t window_size = 1;
1355 while (window_size < max_distance && window_size < kWindowSize) {
1359 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1360 distance_multiplier);
1364 uint32_t dist_symbol;
1366 float total_cost = std::numeric_limits<float>::max();
1368 // Total cost to encode the first N symbols.
1369 std::vector<MatchInfo> prefix_costs(in.size() + 1);
1370 prefix_costs[0].total_cost = 0;
1372 size_t rle_length = 0;
1373 size_t skip_lz77 = 0;
1374 for (size_t i = 0; i < in.size(); i++) {
// Relax the literal edge: encode position i as a single literal.
1377 prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1378 if (prefix_costs[i + 1].total_cost > lit_cost) {
1379 prefix_costs[i + 1].dist_symbol = 0;
1380 prefix_costs[i + 1].len = 1;
1381 prefix_costs[i + 1].ctx = in[i].context;
1382 prefix_costs[i + 1].total_cost = lit_cost;
1384 if (skip_lz77 > 0) {
// Collect, for each match length, the smallest distance symbol seen.
1388 dist_symbols.clear();
1389 chain.FindMatches(i, max_distance,
1390 [&dist_symbols](size_t len, size_t dist_symbol) {
1391 if (dist_symbols.size() <= len) {
1392 dist_symbols.resize(len + 1, dist_symbol);
1394 if (dist_symbol < dist_symbols[len]) {
1395 dist_symbols[len] = dist_symbol;
1398 if (dist_symbols.size() <= min_length) continue;
// Suffix-minimize so a shorter length never pays a larger symbol.
1400 size_t best_cost = dist_symbols.back();
1401 for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1402 if (dist_symbols[j] < best_cost) {
1403 best_cost = dist_symbols[j];
1405 dist_symbols[j] = best_cost;
// Relax every match edge i -> i + j.
1408 for (size_t j = min_length; j < dist_symbols.size(); j++) {
1409 // Cost model that uses results from lazy LZ77.
1410 float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1411 sce.DistCost(dist_symbols[j], lz77);
1412 float cost = prefix_costs[i].total_cost + lz77_cost;
1413 if (prefix_costs[i + j].total_cost > cost) {
1414 prefix_costs[i + j].len = j;
// dist_symbol is stored +1 so that 0 can mean "literal".
1415 prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1416 prefix_costs[i + j].ctx = in[i].context;
1417 prefix_costs[i + j].total_cost = cost;
1420 // We are in a RLE sequence: skip all the symbols except the first 8 and
1421 // the last 8. This avoid quadratic costs for sequences with long runs of
1423 if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1424 (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1429 if (rle_length >= 8 && dist_symbols.size() > 9) {
1430 skip_lz77 = dist_symbols.size() - 10;
// Backtrack from the end, emitting tokens in reverse order.
1434 size_t pos = in.size();
1436 bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1437 if (is_lz77_length) {
1438 size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1439 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1441 size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1442 : in[pos - 1].value;
1443 out.emplace_back(prefix_costs[pos].ctx, val);
1444 out.back().is_lz77_length = is_lz77_length;
1445 pos -= prefix_costs[pos].len;
1447 std::reverse(out.begin(), out.end());
// Dispatches to the configured LZ77 strategy (none / RLE / greedy / optimal)
// after choosing lz77.min_symbol. The callee sets lz77.enabled if LZ77 pays
// off. NOTE(review): the kNone early-return body, the else line before
// min_symbol = 224, and closing braces are elided in this excerpt.
1451 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1452 const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1453 std::vector<std::vector<Token>>& tokens_lz77) {
1454 lz77.enabled = false;
1455 if (params.force_huffman) {
// Prefix-code alphabets are smaller; keep min_symbol within range.
1456 lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1458 lz77.min_symbol = 224;
1460 if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1462 } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1463 ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1464 } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1465 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1466 } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1467 ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1469 JXL_UNREACHABLE("Not implemented");
// Runs LZ77 preprocessing, writes the LZ77Params header, accumulates
// per-context histograms over all token streams, decides between ANS and
// prefix (Huffman) coding, and stores the clustered entropy codes +
// context map. Returns (presumably) the total bits written — the return
// statement is elided in this excerpt; confirm in full source.
// NOTE(review): the AuxOut parameter line and several else/brace lines are
// elided here.
1474 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1475 size_t num_contexts,
1476 std::vector<std::vector<Token>>& tokens,
1477 EntropyEncodingData* codes,
1478 std::vector<uint8_t>* context_map,
1479 BitWriter* writer, size_t layer,
1481 size_t total_bits = 0;
1482 codes->lz77.nonserialized_distance_context = num_contexts;
1483 std::vector<std::vector<Token>> tokens_lz77;
1484 ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
// Fuzzer-friendly mode pins the uint config so decoding stays simple.
1485 if (ans_fuzzer_friendly_) {
1486 codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1487 codes->lz77.min_symbol = 2048;
1490 const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1491 BitWriter::Allotment allotment(writer,
1492 128 + num_contexts * 40 + max_contexts * 96);
1494 JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1497 JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1500 if (codes->lz77.enabled) {
// Measure the bits spent on the length uint config (writer or size-only).
1502 size_t b = writer->BitsWritten();
1503 EncodeUintConfig(codes->lz77.length_uint_config, writer,
1504 /*log_alpha_size=*/8);
1505 total_bits += writer->BitsWritten() - b;
1507 SizeWriter size_writer;
1508 EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1509 /*log_alpha_size=*/8);
1510 total_bits += size_writer.size;
// From here on, encode the LZ77-transformed streams.
1513 tokens = std::move(tokens_lz77);
1515 size_t total_tokens = 0;
1516 // Build histograms.
1517 HistogramBuilder builder(num_contexts);
1518 HybridUintConfig uint_config; // Default config for clustering.
1519 // Unless we are using the kContextMap histogram option.
1520 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1521 uint_config = HybridUintConfig(2, 0, 1);
1523 if (params.uint_method == HistogramParams::HybridUintMethod::k000) {
1524 uint_config = HybridUintConfig(0, 0, 0);
1526 if (ans_fuzzer_friendly_) {
1527 uint_config = HybridUintConfig(10, 0, 0);
1529 for (size_t i = 0; i < tokens.size(); ++i) {
1530 if (codes->lz77.enabled) {
// Length tokens use the LZ77 uint config and are offset by min_symbol.
1531 for (size_t j = 0; j < tokens[i].size(); ++j) {
1532 const Token& token = tokens[i][j];
1534 uint32_t tok, nbits, bits;
1535 (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1536 .Encode(token.value, &tok, &nbits, &bits);
1537 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1538 builder.VisitSymbol(tok, token.context);
1540 } else if (num_contexts == 1) {
1541 for (size_t j = 0; j < tokens[i].size(); ++j) {
1542 const Token& token = tokens[i][j];
1544 uint32_t tok, nbits, bits;
1545 uint_config.Encode(token.value, &tok, &nbits, &bits);
1546 builder.VisitSymbol(tok, /*token.context=*/0);
1549 for (size_t j = 0; j < tokens[i].size(); ++j) {
1550 const Token& token = tokens[i][j];
1552 uint32_t tok, nbits, bits;
1553 uint_config.Encode(token.value, &tok, &nbits, &bits);
1554 builder.VisitSymbol(tok, token.context);
// Prefer prefix codes for tiny inputs, forced-huffman, or fuzzing.
1559 bool use_prefix_code =
1560 params.force_huffman || total_tokens < 100 ||
1561 params.clustering == HistogramParams::ClusteringType::kFastest ||
1562 ans_fuzzer_friendly_;
1563 if (!use_prefix_code) {
// If every context is (near-)deterministic, ANS buys nothing.
1564 bool all_singleton = true;
1565 for (size_t i = 0; i < num_contexts; i++) {
1566 if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1567 all_singleton = false;
1570 if (all_singleton) {
1571 use_prefix_code = true;
1575 // Encode histograms.
1576 total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1577 context_map, use_prefix_code,
1578 writer, layer, aux_out);
1579 allotment.FinishedHistogram(writer);
1580 allotment.ReclaimAndCharge(writer, layer, aux_out);
1582 if (aux_out != nullptr) {
1583 aux_out->layers[layer].num_clustered_histograms +=
1584 codes->encoding_info.size();
// Serializes a token stream with the previously built entropy codes and
// returns the number of "extra" (raw) bits written. Prefix-code path writes
// forward; the ANS path encodes in reverse (ANS is LIFO), buffering bit
// groups and flushing them back-to-front after the 32-bit final state.
// NOTE(review): the ANS coder's declaration/initialization and several
// brace/else lines are elided in this excerpt.
1589 size_t WriteTokens(const std::vector<Token>& tokens,
1590 const EntropyEncodingData& codes,
1591 const std::vector<uint8_t>& context_map, BitWriter* writer) {
1592 size_t num_extra_bits = 0;
1593 if (codes.use_prefix_code) {
1594 for (size_t i = 0; i < tokens.size(); i++) {
1595 uint32_t tok, nbits, bits;
1596 const Token& token = tokens[i];
1597 size_t histo = context_map[token.context];
1598 (token.is_lz77_length ? codes.lz77.length_uint_config
1599 : codes.uint_config[histo])
1600 .Encode(token.value, &tok, &nbits, &bits);
1601 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1602 // Combine two calls to the BitWriter. Equivalent to:
1603 // writer->Write(codes.encoding_info[histo][tok].depth,
1604 // codes.encoding_info[histo][tok].bits);
1605 // writer->Write(nbits, bits);
1606 uint64_t data = codes.encoding_info[histo][tok].bits;
1607 data |= bits << codes.encoding_info[histo][tok].depth;
1608 writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1609 num_extra_bits += nbits;
1611 return num_extra_bits;
// ANS path: accumulate bits into 64-bit groups, flushed in reverse later.
1613 std::vector<uint64_t> out;
1614 std::vector<uint8_t> out_nbits;
1615 out.reserve(tokens.size());
1616 out_nbits.reserve(tokens.size());
1617 uint64_t allbits = 0;
1618 size_t numallbits = 0;
1619 // Writes in *reversed* order.
1620 auto addbits = [&](size_t bits, size_t nbits) {
1621 if (JXL_UNLIKELY(nbits)) {
1622 JXL_DASSERT(bits >> nbits == 0);
1623 if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1624 out.push_back(allbits);
1625 out_nbits.push_back(numallbits);
1626 numallbits = allbits = 0;
1630 numallbits += nbits;
1633 const int end = tokens.size();
// General case: per-context codes and/or LZ77 length tokens.
1635 if (codes.lz77.enabled || context_map.size() > 1) {
1636 for (int i = end - 1; i >= 0; --i) {
1637 const Token token = tokens[i];
1638 const uint8_t histo = context_map[token.context];
1639 uint32_t tok, nbits, bits;
1640 (token.is_lz77_length ? codes.lz77.length_uint_config
1641 : codes.uint_config[histo])
1642 .Encode(tokens[i].value, &tok, &nbits, &bits);
1643 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1644 const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1645 // Extra bits first as this is reversed.
1646 addbits(bits, nbits);
1647 num_extra_bits += nbits;
1648 uint8_t ans_nbits = 0;
1649 uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1650 addbits(ans_bits, ans_nbits);
// Fast path: a single context and no LZ77.
1653 for (int i = end - 1; i >= 0; --i) {
1654 uint32_t tok, nbits, bits;
1655 codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
1656 const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
1657 // Extra bits first as this is reversed.
1658 addbits(bits, nbits);
1659 num_extra_bits += nbits;
1660 uint8_t ans_nbits = 0;
1661 uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1662 addbits(ans_bits, ans_nbits);
// Emit final ANS state, then the buffered bit groups in reverse.
1665 const uint32_t state = ans.GetState();
1666 writer->Write(32, state);
1667 writer->Write(numallbits, allbits);
1668 for (int i = out.size(); i > 0; --i) {
1669 writer->Write(out_nbits[i - 1], out[i - 1]);
1671 return num_extra_bits;
// Convenience overload: reserves writer space, delegates to the size_t
// WriteTokens overload, reclaims the unused allotment into the layer stats,
// and records the extra-bit count in aux_out (if provided).
1674 void WriteTokens(const std::vector<Token>& tokens,
1675 const EntropyEncodingData& codes,
1676 const std::vector<uint8_t>& context_map, BitWriter* writer,
1677 size_t layer, AuxOut* aux_out) {
// Generous upper bound: 32 bits per token plus fixed slack for headers.
1678 BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1679 size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1680 allotment.ReclaimAndCharge(writer, layer, aux_out);
1681 if (aux_out != nullptr) {
1682 aux_out->layers[layer].extra_bits += num_extra_bits;
1686 void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
1687 #if JXL_IS_DEBUG_BUILD // Guard against accidental / malicious changes.
1688 ans_fuzzer_friendly_ = ans_fuzzer_friendly;