1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
6 #include "lib/jxl/fields.h"
14 #include "lib/jxl/base/bits.h"
15 #include "lib/jxl/base/printf_macros.h"
21 using ::jxl::fields_internal::VisitorBase;
// Visitor that overwrites every visited field with its default_value.
// Conditional fields are visited regardless of their condition so they are
// initialized too; nested bundles are deliberately NOT re-initialized (see
// VisitNested). Presumably the visitor behind Bundle::Init — confirm.
23 struct InitVisitor : public VisitorBase {
24 Status Bits(const size_t /*unused*/, const uint32_t default_value,
25 uint32_t* JXL_RESTRICT value) override {
26 *value = default_value;
30 Status U32(const U32Enc /*unused*/, const uint32_t default_value,
31 uint32_t* JXL_RESTRICT value) override {
32 *value = default_value;
36 Status U64(const uint64_t default_value,
37 uint64_t* JXL_RESTRICT value) override {
38 *value = default_value;
42 Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
43 *value = default_value;
47 Status F16(const float default_value, float* JXL_RESTRICT value) override {
48 *value = default_value;
52 // Always visit conditional fields to ensure they are initialized.
53 Status Conditional(bool /*condition*/) override { return true; }
// Sets the bundle's own all_default flag to true via Bool().
55 Status AllDefault(const Fields& /*fields*/,
56 bool* JXL_RESTRICT all_default) override {
57 // Just initialize this field and don't skip initializing others.
58 JXL_RETURN_IF_ERROR(Bool(true, all_default));
62 Status VisitNested(Fields* /*fields*/) override {
63 // Avoid re-initializing nested bundles (their ctors already called
64 // Bundle::Init for their fields).
69 // Similar to InitVisitor, but also initializes nested fields.
// Used by Bundle::SetDefault, which ReadVisitor::SetDefault forwards to.
// Every primitive setter stores default_value into *value.
70 struct SetDefaultVisitor : public VisitorBase {
71 Status Bits(const size_t /*unused*/, const uint32_t default_value,
72 uint32_t* JXL_RESTRICT value) override {
73 *value = default_value;
77 Status U32(const U32Enc /*unused*/, const uint32_t default_value,
78 uint32_t* JXL_RESTRICT value) override {
79 *value = default_value;
83 Status U64(const uint64_t default_value,
84 uint64_t* JXL_RESTRICT value) override {
85 *value = default_value;
89 Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
90 *value = default_value;
94 Status F16(const float default_value, float* JXL_RESTRICT value) override {
95 *value = default_value;
99 // Always visit conditional fields to ensure they are initialized.
100 Status Conditional(bool /*condition*/) override { return true; }
// Sets the bundle's own all_default flag to true via Bool().
102 Status AllDefault(const Fields& /*fields*/,
103 bool* JXL_RESTRICT all_default) override {
104 // Just initialize this field and don't skip initializing others.
105 JXL_RETURN_IF_ERROR(Bool(true, all_default));
// Determines whether every visited field currently equals its default value.
// all_default_ starts true and is cleared by the first field that differs;
// read the result with AllDefault() after visiting. Used by
// Bundle::AllDefault.
// NOTE(review): no Bool override appears here; presumably VisitorBase routes
// Bool through Bits — confirm.
110 class AllDefaultVisitor : public VisitorBase {
112 explicit AllDefaultVisitor() : VisitorBase() {}
// `bits` is not needed: only the value is compared.
114 Status Bits(const size_t bits, const uint32_t default_value,
115 uint32_t* JXL_RESTRICT value) override {
116 all_default_ &= *value == default_value;
120 Status U32(const U32Enc /*unused*/, const uint32_t default_value,
121 uint32_t* JXL_RESTRICT value) override {
122 all_default_ &= *value == default_value;
126 Status U64(const uint64_t default_value,
127 uint64_t* JXL_RESTRICT value) override {
128 all_default_ &= *value == default_value;
// Floats use an absolute tolerance of 1e-6 instead of exact equality.
132 Status F16(const float default_value, float* JXL_RESTRICT value) override {
133 all_default_ &= std::abs(*value - default_value) < 1E-6f;
137 Status AllDefault(const Fields& /*fields*/,
138 bool* JXL_RESTRICT /*all_default*/) override {
139 // Visit all fields so we can compute the actual all_default_ value.
// Result of the visit: true iff no visited field differed from its default.
143 bool AllDefault() const { return all_default_; }
146 bool all_default_ = true;
// Visitor that deserializes each field from reader_. Each primitive read is
// followed by an AllReadsWithinBounds() check, and truncated input is reported
// as StatusCode::kNotEnoughBytes. Bundle::CanRead relies on that status code to
// tell truncation apart from other decoding errors.
149 class ReadVisitor : public VisitorBase {
151 explicit ReadVisitor(BitReader* reader) : VisitorBase(), reader_(reader) {}
153 Status Bits(const size_t bits, const uint32_t /*default_value*/,
154 uint32_t* JXL_RESTRICT value) override {
155 *value = BitsCoder::Read(bits, reader_);
156 if (!reader_->AllReadsWithinBounds()) {
157 return JXL_STATUS(StatusCode::kNotEnoughBytes,
158 "Not enough bytes for header");
163 Status U32(const U32Enc dist, const uint32_t /*default_value*/,
164 uint32_t* JXL_RESTRICT value) override {
165 *value = U32Coder::Read(dist, reader_);
166 if (!reader_->AllReadsWithinBounds()) {
167 return JXL_STATUS(StatusCode::kNotEnoughBytes,
168 "Not enough bytes for header");
173 Status U64(const uint64_t /*default_value*/,
174 uint64_t* JXL_RESTRICT value) override {
175 *value = U64Coder::Read(reader_);
176 if (!reader_->AllReadsWithinBounds()) {
177 return JXL_STATUS(StatusCode::kNotEnoughBytes,
178 "Not enough bytes for header");
// F16 decoding errors (e.g. infinity/NaN) are accumulated into ok_ instead of
// being returned immediately. Only truncation returns early.
183 Status F16(const float /*default_value*/,
184 float* JXL_RESTRICT value) override {
185 ok_ &= F16Coder::Read(reader_, value);
186 if (!reader_->AllReadsWithinBounds()) {
187 return JXL_STATUS(StatusCode::kNotEnoughBytes,
188 "Not enough bytes for header");
193 void SetDefault(Fields* fields) override { Bundle::SetDefault(fields); }
195 bool IsReading() const override { return true; }
197 // This never fails because visitors are expected to keep reading until
198 // EndExtensions, see comment there.
// NOTE(review): despite the note above, this can return an error from U64()
// on truncated input, and when total_extension_bits_ overflows.
199 Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
200 JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions));
201 if (*extensions == 0) return true;
203 // For each nonzero bit, i.e. extension that is present:
// (`x &= x - 1` clears the lowest set bit, so each present extension is
// visited exactly once, from the lowest bit upward.)
204 for (uint64_t remaining_extensions = *extensions; remaining_extensions != 0;
205 remaining_extensions &= remaining_extensions - 1) {
206 const size_t idx_extension =
207 Num0BitsBelowLS1Bit_Nonzero(remaining_extensions);
208 // Read additional U64 (one per extension) indicating the number of bits
209 // (allows skipping individual extensions).
210 JXL_RETURN_IF_ERROR(U64(0, &extension_bits_[idx_extension]));
211 if (!SafeAdd(total_extension_bits_, extension_bits_[idx_extension],
212 total_extension_bits_)) {
213 return JXL_FAILURE("Extension bits overflowed, invalid codestream");
216 // Used by EndExtensions to skip past any _remaining_ extensions.
217 pos_after_ext_size_ = reader_->TotalBitsConsumed();
218 JXL_ASSERT(pos_after_ext_size_ != 0);
// The extension budget ends at pos_after_ext_size_ + total_extension_bits_.
// Fails if more bits than that were read; otherwise skips whatever the
// known fields did not consume.
222 Status EndExtensions() override {
223 JXL_QUIET_RETURN_IF_ERROR(VisitorBase::EndExtensions());
224 // Happens if extensions == 0: don't read size, done.
225 if (pos_after_ext_size_ == 0) return true;
227 // Not enough bytes as set by BeginExtensions or earlier. Do not return
228 // this as a JXL_FAILURE or false (which can also propagate to error
229 // through e.g. JXL_RETURN_IF_ERROR), since this may be used while
230 // silently checking whether there are enough bytes. If this case must be
231 // treated as an error, reader_->Close() will do this, just as is already
232 // done for non-extension fields.
233 if (!enough_bytes_) return true;
235 // Skip new fields this (old?) decoder didn't know about, if any.
236 const size_t bits_read = reader_->TotalBitsConsumed();
238 if (!SafeAdd(pos_after_ext_size_, total_extension_bits_, end)) {
239 return JXL_FAILURE("Invalid extension size, caused overflow");
241 if (bits_read > end) {
242 return JXL_FAILURE("Read more extension bits than budgeted");
244 const size_t remaining_bits = end - bits_read;
245 if (remaining_bits != 0) {
246 JXL_WARNING("Skipping %" PRIuS "-bit extension(s)", remaining_bits);
247 reader_->SkipBits(remaining_bits);
248 if (!reader_->AllReadsWithinBounds()) {
249 return JXL_STATUS(StatusCode::kNotEnoughBytes,
250 "Not enough bytes for header");
256 Status OK() const { return ok_; }
259 // Whether any error other than not enough bytes occurred.
262 // Whether there are enough input bytes to read from.
// NOTE(review): confirm which code path sets this to false; if none does, the
// `if (!enough_bytes_)` early return in EndExtensions is dead code.
263 bool enough_bytes_ = true;
264 BitReader* const reader_;
265 // May be 0 even if the corresponding extension is present.
266 uint64_t extension_bits_[Bundle::kMaxExtensions] = {0};
// Sum of all extension_bits_ read in BeginExtensions (overflow-checked).
267 uint64_t total_extension_bits_ = 0;
268 size_t pos_after_ext_size_ = 0; // 0 iff extensions == 0.
270 friend Status jxl::CheckHasEnoughBits(Visitor*, size_t);
// Computes a loose upper bound on the encoded size of a bundle. It adds each
// coder's MaxEncodedBits(), visits conditional fields regardless of their
// condition, and excludes extensions. Used by Bundle::MaxBits; read the
// result with MaxBits().
273 class MaxBitsVisitor : public VisitorBase {
275 Status Bits(const size_t bits, const uint32_t /*default_value*/,
276 uint32_t* JXL_RESTRICT /*value*/) override {
277 max_bits_ += BitsCoder::MaxEncodedBits(bits);
281 Status U32(const U32Enc enc, const uint32_t /*default_value*/,
282 uint32_t* JXL_RESTRICT /*value*/) override {
283 max_bits_ += U32Coder::MaxEncodedBits(enc);
287 Status U64(const uint64_t /*default_value*/,
288 uint64_t* JXL_RESTRICT /*value*/) override {
289 max_bits_ += U64Coder::MaxEncodedBits();
293 Status F16(const float /*default_value*/,
294 float* JXL_RESTRICT /*value*/) override {
295 max_bits_ += F16Coder::MaxEncodedBits();
// Counts the all_default flag itself. Returning false makes the visit continue
// into every field, so all of them contribute to the bound.
299 Status AllDefault(const Fields& /*fields*/,
300 bool* JXL_RESTRICT all_default) override {
301 JXL_RETURN_IF_ERROR(Bool(true, all_default));
302 return false; // For max bits, assume nothing is default
305 // Always visit conditional fields to get a (loose) upper bound.
306 Status Conditional(bool /*condition*/) override { return true; }
308 Status BeginExtensions(uint64_t* JXL_RESTRICT /*extensions*/) override {
309 // Skip - extensions are not included in "MaxBits" because their length
310 // is potentially unbounded.
314 Status EndExtensions() override { return true; }
316 size_t MaxBits() const { return max_bits_; }
319 size_t max_bits_ = 0;
// Checks that every field value is encodable and adds up the encoded size.
// Per-field failures are accumulated into ok_, so visiting continues after an
// error. GetSizes() fails if ok_ was cleared; otherwise it reports the totals,
// including the U64 size fields written for each present extension.
// Used by Bundle::CanEncode.
322 class CanEncodeVisitor : public VisitorBase {
324 explicit CanEncodeVisitor() : VisitorBase() {}
326 Status Bits(const size_t bits, const uint32_t /*default_value*/,
327 uint32_t* JXL_RESTRICT value) override {
328 size_t encoded_bits = 0;
329 ok_ &= BitsCoder::CanEncode(bits, *value, &encoded_bits);
330 encoded_bits_ += encoded_bits;
334 Status U32(const U32Enc enc, const uint32_t /*default_value*/,
335 uint32_t* JXL_RESTRICT value) override {
336 size_t encoded_bits = 0;
337 ok_ &= U32Coder::CanEncode(enc, *value, &encoded_bits);
338 encoded_bits_ += encoded_bits;
342 Status U64(const uint64_t /*default_value*/,
343 uint64_t* JXL_RESTRICT value) override {
344 size_t encoded_bits = 0;
345 ok_ &= U64Coder::CanEncode(*value, &encoded_bits);
346 encoded_bits_ += encoded_bits;
350 Status F16(const float /*default_value*/,
351 float* JXL_RESTRICT value) override {
352 size_t encoded_bits = 0;
353 ok_ &= F16Coder::CanEncode(*value, &encoded_bits);
354 encoded_bits_ += encoded_bits;
// Unlike the other visitors, this sets *all_default from the actual field
// values (Bundle::AllDefault), then visits the flag itself via Bool() so its
// bit is counted.
358 Status AllDefault(const Fields& fields,
359 bool* JXL_RESTRICT all_default) override {
360 *all_default = Bundle::AllDefault(fields);
361 JXL_RETURN_IF_ERROR(Bool(true, all_default));
// Remembers which extensions are present and, if any are, the bit position
// just after the extensions field. GetSizes() measures extension payloads from
// that position.
365 Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
366 JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions));
367 extensions_ = *extensions;
368 if (*extensions != 0) {
369 JXL_ASSERT(pos_after_ext_ == 0);
370 pos_after_ext_ = encoded_bits_;
371 JXL_ASSERT(pos_after_ext_ != 0); // visited "extensions"
375 // EndExtensions = default.
// Reports the payload size of the extensions and the total encoded size.
// When extensions are present, all extension payload bits go to the first
// present extension, and each remaining one is billed as a zero-valued U64.
377 Status GetSizes(size_t* JXL_RESTRICT extension_bits,
378 size_t* JXL_RESTRICT total_bits) {
379 JXL_RETURN_IF_ERROR(ok_);
381 *total_bits = encoded_bits_;
382 // Only if extension field was nonzero will we encode their sizes.
383 if (pos_after_ext_ != 0) {
384 JXL_ASSERT(encoded_bits_ >= pos_after_ext_);
385 *extension_bits = encoded_bits_ - pos_after_ext_;
386 // Also need to encode *extension_bits and bill it to *total_bits.
387 size_t encoded_bits = 0;
388 ok_ &= U64Coder::CanEncode(*extension_bits, &encoded_bits);
389 *total_bits += encoded_bits;
391 // TODO(janwas): support encoding individual extension sizes. We
392 // currently ascribe all bits to the first and send zeros for the
394 for (size_t i = 1; i < hwy::PopCount(extensions_); ++i) {
396 ok_ &= U64Coder::CanEncode(0, &encoded_bits);
397 *total_bits += encoded_bits;
405 size_t encoded_bits_ = 0;
406 uint64_t extensions_ = 0;
407 // Snapshot of encoded_bits_ after visiting the extension field, but NOT
408 // including the hidden extension sizes.
409 uint64_t pos_after_ext_ = 0;
// Initializes every field of *fields to its default by visiting it. Nested
// bundles are skipped, presumably via InitVisitor. A visitor failure is a
// programming error, hence JXL_UNREACHABLE.
413 void Bundle::Init(Fields* fields) {
415 if (!visitor.Visit(fields)) {
416 JXL_UNREACHABLE("Init should never fail");
// Resets every field of *fields, including nested bundles, to its default
// value using SetDefaultVisitor. Failure is treated as unreachable.
419 void Bundle::SetDefault(Fields* fields) {
420 SetDefaultVisitor visitor;
421 if (!visitor.Visit(fields)) {
422 JXL_UNREACHABLE("SetDefault should never fail");
// Returns true iff every field of `fields` equals its default value.
// F16 fields are compared with a 1e-6 absolute tolerance by AllDefaultVisitor.
425 bool Bundle::AllDefault(const Fields& fields) {
426 AllDefaultVisitor visitor;
427 if (!visitor.VisitConst(fields)) {
428 JXL_UNREACHABLE("AllDefault should never fail");
430 return visitor.AllDefault();
// Returns an upper bound on the encoded size of `fields` in bits, excluding
// extensions (see MaxBitsVisitor).
// NOTE(review): the VisitConst status is not consumed on its own line below;
// presumably the JXL_ENABLE_ASSERT branch captures and asserts it — confirm.
432 size_t Bundle::MaxBits(const Fields& fields) {
433 MaxBitsVisitor visitor;
434 #if JXL_ENABLE_ASSERT
438 #endif // JXL_ENABLE_ASSERT
439 visitor.VisitConst(fields);
441 return visitor.MaxBits();
// Succeeds iff every field of `fields` is encodable. On success it stores the
// extension payload size and the total encoded size in bits, as computed by
// CanEncodeVisitor::GetSizes.
443 Status Bundle::CanEncode(const Fields& fields, size_t* extension_bits,
444 size_t* total_bits) {
445 CanEncodeVisitor visitor;
446 JXL_QUIET_RETURN_IF_ERROR(visitor.VisitConst(fields));
447 JXL_QUIET_RETURN_IF_ERROR(visitor.GetSizes(extension_bits, total_bits));
// Deserializes *fields from `reader` with ReadVisitor. Truncated input is
// reported as StatusCode::kNotEnoughBytes.
450 Status Bundle::Read(BitReader* reader, Fields* fields) {
451 ReadVisitor visitor(reader);
452 JXL_RETURN_IF_ERROR(visitor.Visit(fields));
// Returns false only when `reader` lacks enough bytes to decode *fields.
// This performs a real read, so *fields and the reader position are modified.
455 bool Bundle::CanRead(BitReader* reader, Fields* fields) {
456 ReadVisitor visitor(reader);
457 Status status = visitor.Visit(fields);
458 // We are only checking here whether there are enough bytes. We still return
459 // true for other errors because it means there are enough bytes to determine
460 // there's an error. Use Read() to determine which error it is.
461 return status.code() != StatusCode::kNotEnoughBytes;
// A fixed-width field always occupies exactly `bits` bits.
464 size_t BitsCoder::MaxEncodedBits(const size_t bits) { return bits; }
// Succeeds iff `value` fits in `bits` bits. *encoded_bits is set to `bits`
// whether or not it fits.
// NOTE(review): `1ULL << bits` is undefined for bits >= 64; callers presumably
// pass at most 32 — confirm.
466 Status BitsCoder::CanEncode(const size_t bits, const uint32_t value,
467 size_t* JXL_RESTRICT encoded_bits) {
468 *encoded_bits = bits;
469 if (value >= (1ULL << bits)) {
470 return JXL_FAILURE("Value %u too large for %" PRIu64 " bits", value,
471 static_cast<uint64_t>(bits));
// Reads an unsigned field that is `bits` wide.
476 uint32_t BitsCoder::Read(const size_t bits, BitReader* JXL_RESTRICT reader) {
477 return reader->ReadBits(bits);
// Upper bound for a U32 field: 2 selector bits plus the largest ExtraBits()
// among the four distributions of `enc`.
480 size_t U32Coder::MaxEncodedBits(const U32Enc enc) {
481 size_t extra_bits = 0;
482 for (uint32_t selector = 0; selector < 4; ++selector) {
483 const U32Distr d = enc.GetDistr(selector);
487 extra_bits = std::max<size_t>(extra_bits, d.ExtraBits());
490 return 2 + extra_bits;
// Reports whether `value` is representable by one of the distributions of
// `enc`. *encoded_bits receives the size of the cheapest encoding chosen by
// ChooseSelector, or 0 if no distribution fits.
493 Status U32Coder::CanEncode(const U32Enc enc, const uint32_t value,
494 size_t* JXL_RESTRICT encoded_bits) {
497 const Status ok = ChooseSelector(enc, value, &selector, &total_bits);
498 *encoded_bits = ok ? total_bits : 0;
// Reads a 2-bit selector to choose one of the distributions of `enc`.
// Non-direct distributions decode as ReadBits(ExtraBits()) + Offset().
502 uint32_t U32Coder::Read(const U32Enc enc, BitReader* JXL_RESTRICT reader) {
503 const uint32_t selector = reader->ReadFixedBits<2>();
504 const U32Distr d = enc.GetDistr(selector);
508 return reader->ReadBits(d.ExtraBits()) + d.Offset();
// Chooses the selector that encodes `value` in the fewest bits:
//  - a distribution whose Direct() value equals `value` wins immediately;
//  - otherwise, the cheapest distribution whose range
//    [Offset(), Offset() + 2^ExtraBits()) contains `value` is chosen.
// Fails if no distribution can represent `value`.
512 Status U32Coder::ChooseSelector(const U32Enc enc, const uint32_t value,
513 uint32_t* JXL_RESTRICT selector,
514 size_t* JXL_RESTRICT total_bits) {
515 #if JXL_ENABLE_ASSERT
516 const size_t bits_required = 32 - Num0BitsAboveMS1Bit(value);
517 #endif // JXL_ENABLE_ASSERT
518 JXL_ASSERT(bits_required <= 32);
523 // It is difficult to verify whether Dist32Byte are sorted, so check all
524 // selectors and keep the one with the fewest total_bits.
525 *total_bits = 64; // more than any valid encoding
526 for (uint32_t s = 0; s < 4; ++s) {
527 const U32Distr d = enc.GetDistr(s);
529 if (d.Direct() == value) {
532 return true; // Done, direct is always the best possible.
536 const size_t extra_bits = d.ExtraBits();
537 const uint32_t offset = d.Offset();
// The upper bound is computed in 64 bits so offset + 2^extra_bits cannot wrap.
538 if (value < offset || value >= offset + (1ULL << extra_bits)) continue;
540 // Better than prior encoding, remember it:
541 if (2 + extra_bits < *total_bits) {
543 *total_bits = 2 + extra_bits;
// The 64 sentinel is still in place, so no distribution matched.
547 if (*total_bits == 64) {
548 return JXL_FAILURE("No feasible selector for %u", value);
// Decodes a U64 field. A 2-bit selector picks the format:
//  - 4 payload bits encode values 1..16 (1 + bits);
//  - 8 payload bits encode values 17..272 (17 + bits);
//  - selector 3 is a varint, described below.
554 uint64_t U64Coder::Read(BitReader* JXL_RESTRICT reader) {
555 uint64_t selector = reader->ReadFixedBits<2>();
560 return 1 + reader->ReadFixedBits<4>();
563 return 17 + reader->ReadFixedBits<8>();
566 // selector 3, varint, groups have first 12, then 8, and last 4 bits.
567 uint64_t result = reader->ReadFixedBits<12>();
// Each further group is preceded by a 1-bit continuation flag.
570 while (reader->ReadFixedBits<1>()) {
572 result |= static_cast<uint64_t>(reader->ReadFixedBits<4>()) << shift;
575 result |= static_cast<uint64_t>(reader->ReadFixedBits<8>()) << shift;
582 // Can always encode, but useful because it also returns bit size.
// Mirrors the format decoded by U64Coder::Read: selector, optional fixed
// payload, then varint groups, each preceded by a continuation bit and
// terminated by a stop bit.
583 Status U64Coder::CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits) {
585 *encoded_bits = 2; // 2 selector bits
586 } else if (value <= 16) {
587 *encoded_bits = 2 + 4; // 2 selector bits + 4 payload bits
588 } else if (value <= 272) {
589 *encoded_bits = 2 + 8; // 2 selector bits + 8 payload bits
591 *encoded_bits = 2 + 12; // 2 selector bits + 12 payload bits
594 while (value > 0 && shift < 60) {
595 *encoded_bits += 1 + 8; // 1 continuation bit + 8 payload bits
600 // This only could happen if shift == N - 4.
601 *encoded_bits += 1 + 4; // 1 continuation bit + 4 payload bits
603 *encoded_bits += 1; // 1 stop bit
// Decodes a 16-bit half-precision float into *value. The layout is 1 sign
// bit, 5 exponent bits (bias 15) and 10 mantissa bits. Infinity and NaN
// (exponent 31) are rejected.
610 Status F16Coder::Read(BitReader* JXL_RESTRICT reader,
611 float* JXL_RESTRICT value) {
612 const uint32_t bits16 = reader->ReadFixedBits<16>();
613 const uint32_t sign = bits16 >> 15;
614 const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
615 const uint32_t mantissa = bits16 & 0x3FF;
617 if (JXL_UNLIKELY(biased_exp == 31)) {
618 return JXL_FAILURE("F16 infinity or NaN are not supported");
// Subnormal (or zero): value = 2^-14 * (mantissa / 1024).
622 if (JXL_UNLIKELY(biased_exp == 0)) {
623 *value = (1.0f / 16384) * (mantissa * (1.0f / 1024));
624 if (sign) *value = -*value;
628 // Normalized: convert the representation directly (faster than ldexp/tables).
// Rebias the exponent from 15 to 127 and widen the mantissa from 10 to 23 bits.
629 const uint32_t biased_exp32 = biased_exp + (127 - 15);
630 const uint32_t mantissa32 = mantissa << (23 - 10);
631 const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32;
632 memcpy(value, &bits32, sizeof(bits32));
// Succeeds iff `value` is finite and its magnitude is at most 65504, the
// largest finite half-precision value. *encoded_bits is always
// MaxEncodedBits(). NaN and infinity fail with a message. Out-of-range
// finite values fail silently through the bool-to-Status conversion.
636 Status F16Coder::CanEncode(float value, size_t* JXL_RESTRICT encoded_bits) {
637 *encoded_bits = MaxEncodedBits();
638 if (std::isnan(value) || std::isinf(value)) {
639 return JXL_FAILURE("Should not attempt to store NaN and infinity");
641 return std::abs(value) <= 65504.0f;
644 Status CheckHasEnoughBits(Visitor* visitor, size_t bits) {
645 if (!visitor->IsReading()) return false;
646 ReadVisitor* rv = static_cast<ReadVisitor*>(visitor);
647 size_t have_bits = rv->reader_->TotalBytes() * kBitsPerByte;
648 size_t want_bits = bits + rv->reader_->TotalBitsConsumed();
649 if (have_bits < want_bits) {
650 return JXL_STATUS(StatusCode::kNotEnoughBytes,
651 "Not enough bytes for header");