Imported Upstream version 0.9.0
[platform/upstream/libjxl.git] / lib / jxl / fields.cc
1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/fields.h"
7
8 #include <stddef.h>
9
10 #include <algorithm>
11 #include <cmath>
12 #include <hwy/base.h>
13
14 #include "lib/jxl/base/bits.h"
15 #include "lib/jxl/base/printf_macros.h"
16
17 namespace jxl {
18
19 namespace {
20
21 using ::jxl::fields_internal::VisitorBase;
22
23 struct InitVisitor : public VisitorBase {
24   Status Bits(const size_t /*unused*/, const uint32_t default_value,
25               uint32_t* JXL_RESTRICT value) override {
26     *value = default_value;
27     return true;
28   }
29
30   Status U32(const U32Enc /*unused*/, const uint32_t default_value,
31              uint32_t* JXL_RESTRICT value) override {
32     *value = default_value;
33     return true;
34   }
35
36   Status U64(const uint64_t default_value,
37              uint64_t* JXL_RESTRICT value) override {
38     *value = default_value;
39     return true;
40   }
41
42   Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
43     *value = default_value;
44     return true;
45   }
46
47   Status F16(const float default_value, float* JXL_RESTRICT value) override {
48     *value = default_value;
49     return true;
50   }
51
52   // Always visit conditional fields to ensure they are initialized.
53   Status Conditional(bool /*condition*/) override { return true; }
54
55   Status AllDefault(const Fields& /*fields*/,
56                     bool* JXL_RESTRICT all_default) override {
57     // Just initialize this field and don't skip initializing others.
58     JXL_RETURN_IF_ERROR(Bool(true, all_default));
59     return false;
60   }
61
62   Status VisitNested(Fields* /*fields*/) override {
63     // Avoid re-initializing nested bundles (their ctors already called
64     // Bundle::Init for their fields).
65     return true;
66   }
67 };
68
69 // Similar to InitVisitor, but also initializes nested fields.
70 struct SetDefaultVisitor : public VisitorBase {
71   Status Bits(const size_t /*unused*/, const uint32_t default_value,
72               uint32_t* JXL_RESTRICT value) override {
73     *value = default_value;
74     return true;
75   }
76
77   Status U32(const U32Enc /*unused*/, const uint32_t default_value,
78              uint32_t* JXL_RESTRICT value) override {
79     *value = default_value;
80     return true;
81   }
82
83   Status U64(const uint64_t default_value,
84              uint64_t* JXL_RESTRICT value) override {
85     *value = default_value;
86     return true;
87   }
88
89   Status Bool(bool default_value, bool* JXL_RESTRICT value) override {
90     *value = default_value;
91     return true;
92   }
93
94   Status F16(const float default_value, float* JXL_RESTRICT value) override {
95     *value = default_value;
96     return true;
97   }
98
99   // Always visit conditional fields to ensure they are initialized.
100   Status Conditional(bool /*condition*/) override { return true; }
101
102   Status AllDefault(const Fields& /*fields*/,
103                     bool* JXL_RESTRICT all_default) override {
104     // Just initialize this field and don't skip initializing others.
105     JXL_RETURN_IF_ERROR(Bool(true, all_default));
106     return false;
107   }
108 };
109
110 class AllDefaultVisitor : public VisitorBase {
111  public:
112   explicit AllDefaultVisitor() : VisitorBase() {}
113
114   Status Bits(const size_t bits, const uint32_t default_value,
115               uint32_t* JXL_RESTRICT value) override {
116     all_default_ &= *value == default_value;
117     return true;
118   }
119
120   Status U32(const U32Enc /*unused*/, const uint32_t default_value,
121              uint32_t* JXL_RESTRICT value) override {
122     all_default_ &= *value == default_value;
123     return true;
124   }
125
126   Status U64(const uint64_t default_value,
127              uint64_t* JXL_RESTRICT value) override {
128     all_default_ &= *value == default_value;
129     return true;
130   }
131
132   Status F16(const float default_value, float* JXL_RESTRICT value) override {
133     all_default_ &= std::abs(*value - default_value) < 1E-6f;
134     return true;
135   }
136
137   Status AllDefault(const Fields& /*fields*/,
138                     bool* JXL_RESTRICT /*all_default*/) override {
139     // Visit all fields so we can compute the actual all_default_ value.
140     return false;
141   }
142
143   bool AllDefault() const { return all_default_; }
144
145  private:
146   bool all_default_ = true;
147 };
148
149 class ReadVisitor : public VisitorBase {
150  public:
151   explicit ReadVisitor(BitReader* reader) : VisitorBase(), reader_(reader) {}
152
153   Status Bits(const size_t bits, const uint32_t /*default_value*/,
154               uint32_t* JXL_RESTRICT value) override {
155     *value = BitsCoder::Read(bits, reader_);
156     if (!reader_->AllReadsWithinBounds()) {
157       return JXL_STATUS(StatusCode::kNotEnoughBytes,
158                         "Not enough bytes for header");
159     }
160     return true;
161   }
162
163   Status U32(const U32Enc dist, const uint32_t /*default_value*/,
164              uint32_t* JXL_RESTRICT value) override {
165     *value = U32Coder::Read(dist, reader_);
166     if (!reader_->AllReadsWithinBounds()) {
167       return JXL_STATUS(StatusCode::kNotEnoughBytes,
168                         "Not enough bytes for header");
169     }
170     return true;
171   }
172
173   Status U64(const uint64_t /*default_value*/,
174              uint64_t* JXL_RESTRICT value) override {
175     *value = U64Coder::Read(reader_);
176     if (!reader_->AllReadsWithinBounds()) {
177       return JXL_STATUS(StatusCode::kNotEnoughBytes,
178                         "Not enough bytes for header");
179     }
180     return true;
181   }
182
183   Status F16(const float /*default_value*/,
184              float* JXL_RESTRICT value) override {
185     ok_ &= F16Coder::Read(reader_, value);
186     if (!reader_->AllReadsWithinBounds()) {
187       return JXL_STATUS(StatusCode::kNotEnoughBytes,
188                         "Not enough bytes for header");
189     }
190     return true;
191   }
192
193   void SetDefault(Fields* fields) override { Bundle::SetDefault(fields); }
194
195   bool IsReading() const override { return true; }
196
197   // This never fails because visitors are expected to keep reading until
198   // EndExtensions, see comment there.
199   Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
200     JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions));
201     if (*extensions == 0) return true;
202
203     // For each nonzero bit, i.e. extension that is present:
204     for (uint64_t remaining_extensions = *extensions; remaining_extensions != 0;
205          remaining_extensions &= remaining_extensions - 1) {
206       const size_t idx_extension =
207           Num0BitsBelowLS1Bit_Nonzero(remaining_extensions);
208       // Read additional U64 (one per extension) indicating the number of bits
209       // (allows skipping individual extensions).
210       JXL_RETURN_IF_ERROR(U64(0, &extension_bits_[idx_extension]));
211       if (!SafeAdd(total_extension_bits_, extension_bits_[idx_extension],
212                    total_extension_bits_)) {
213         return JXL_FAILURE("Extension bits overflowed, invalid codestream");
214       }
215     }
216     // Used by EndExtensions to skip past any _remaining_ extensions.
217     pos_after_ext_size_ = reader_->TotalBitsConsumed();
218     JXL_ASSERT(pos_after_ext_size_ != 0);
219     return true;
220   }
221
222   Status EndExtensions() override {
223     JXL_QUIET_RETURN_IF_ERROR(VisitorBase::EndExtensions());
224     // Happens if extensions == 0: don't read size, done.
225     if (pos_after_ext_size_ == 0) return true;
226
227     // Not enough bytes as set by BeginExtensions or earlier. Do not return
228     // this as a JXL_FAILURE or false (which can also propagate to error
229     // through e.g. JXL_RETURN_IF_ERROR), since this may be used while
230     // silently checking whether there are enough bytes. If this case must be
231     // treated as an error, reader_>Close() will do this, just like is already
232     // done for non-extension fields.
233     if (!enough_bytes_) return true;
234
235     // Skip new fields this (old?) decoder didn't know about, if any.
236     const size_t bits_read = reader_->TotalBitsConsumed();
237     uint64_t end;
238     if (!SafeAdd(pos_after_ext_size_, total_extension_bits_, end)) {
239       return JXL_FAILURE("Invalid extension size, caused overflow");
240     }
241     if (bits_read > end) {
242       return JXL_FAILURE("Read more extension bits than budgeted");
243     }
244     const size_t remaining_bits = end - bits_read;
245     if (remaining_bits != 0) {
246       JXL_WARNING("Skipping %" PRIuS "-bit extension(s)", remaining_bits);
247       reader_->SkipBits(remaining_bits);
248       if (!reader_->AllReadsWithinBounds()) {
249         return JXL_STATUS(StatusCode::kNotEnoughBytes,
250                           "Not enough bytes for header");
251       }
252     }
253     return true;
254   }
255
256   Status OK() const { return ok_; }
257
258  private:
259   // Whether any error other than not enough bytes occurred.
260   bool ok_ = true;
261
262   // Whether there are enough input bytes to read from.
263   bool enough_bytes_ = true;
264   BitReader* const reader_;
265   // May be 0 even if the corresponding extension is present.
266   uint64_t extension_bits_[Bundle::kMaxExtensions] = {0};
267   uint64_t total_extension_bits_ = 0;
268   size_t pos_after_ext_size_ = 0;  // 0 iff extensions == 0.
269
270   friend Status jxl::CheckHasEnoughBits(Visitor*, size_t);
271 };
272
273 class MaxBitsVisitor : public VisitorBase {
274  public:
275   Status Bits(const size_t bits, const uint32_t /*default_value*/,
276               uint32_t* JXL_RESTRICT /*value*/) override {
277     max_bits_ += BitsCoder::MaxEncodedBits(bits);
278     return true;
279   }
280
281   Status U32(const U32Enc enc, const uint32_t /*default_value*/,
282              uint32_t* JXL_RESTRICT /*value*/) override {
283     max_bits_ += U32Coder::MaxEncodedBits(enc);
284     return true;
285   }
286
287   Status U64(const uint64_t /*default_value*/,
288              uint64_t* JXL_RESTRICT /*value*/) override {
289     max_bits_ += U64Coder::MaxEncodedBits();
290     return true;
291   }
292
293   Status F16(const float /*default_value*/,
294              float* JXL_RESTRICT /*value*/) override {
295     max_bits_ += F16Coder::MaxEncodedBits();
296     return true;
297   }
298
299   Status AllDefault(const Fields& /*fields*/,
300                     bool* JXL_RESTRICT all_default) override {
301     JXL_RETURN_IF_ERROR(Bool(true, all_default));
302     return false;  // For max bits, assume nothing is default
303   }
304
305   // Always visit conditional fields to get a (loose) upper bound.
306   Status Conditional(bool /*condition*/) override { return true; }
307
308   Status BeginExtensions(uint64_t* JXL_RESTRICT /*extensions*/) override {
309     // Skip - extensions are not included in "MaxBits" because their length
310     // is potentially unbounded.
311     return true;
312   }
313
314   Status EndExtensions() override { return true; }
315
316   size_t MaxBits() const { return max_bits_; }
317
318  private:
319   size_t max_bits_ = 0;
320 };
321
322 class CanEncodeVisitor : public VisitorBase {
323  public:
324   explicit CanEncodeVisitor() : VisitorBase() {}
325
326   Status Bits(const size_t bits, const uint32_t /*default_value*/,
327               uint32_t* JXL_RESTRICT value) override {
328     size_t encoded_bits = 0;
329     ok_ &= BitsCoder::CanEncode(bits, *value, &encoded_bits);
330     encoded_bits_ += encoded_bits;
331     return true;
332   }
333
334   Status U32(const U32Enc enc, const uint32_t /*default_value*/,
335              uint32_t* JXL_RESTRICT value) override {
336     size_t encoded_bits = 0;
337     ok_ &= U32Coder::CanEncode(enc, *value, &encoded_bits);
338     encoded_bits_ += encoded_bits;
339     return true;
340   }
341
342   Status U64(const uint64_t /*default_value*/,
343              uint64_t* JXL_RESTRICT value) override {
344     size_t encoded_bits = 0;
345     ok_ &= U64Coder::CanEncode(*value, &encoded_bits);
346     encoded_bits_ += encoded_bits;
347     return true;
348   }
349
350   Status F16(const float /*default_value*/,
351              float* JXL_RESTRICT value) override {
352     size_t encoded_bits = 0;
353     ok_ &= F16Coder::CanEncode(*value, &encoded_bits);
354     encoded_bits_ += encoded_bits;
355     return true;
356   }
357
358   Status AllDefault(const Fields& fields,
359                     bool* JXL_RESTRICT all_default) override {
360     *all_default = Bundle::AllDefault(fields);
361     JXL_RETURN_IF_ERROR(Bool(true, all_default));
362     return *all_default;
363   }
364
365   Status BeginExtensions(uint64_t* JXL_RESTRICT extensions) override {
366     JXL_QUIET_RETURN_IF_ERROR(VisitorBase::BeginExtensions(extensions));
367     extensions_ = *extensions;
368     if (*extensions != 0) {
369       JXL_ASSERT(pos_after_ext_ == 0);
370       pos_after_ext_ = encoded_bits_;
371       JXL_ASSERT(pos_after_ext_ != 0);  // visited "extensions"
372     }
373     return true;
374   }
375   // EndExtensions = default.
376
377   Status GetSizes(size_t* JXL_RESTRICT extension_bits,
378                   size_t* JXL_RESTRICT total_bits) {
379     JXL_RETURN_IF_ERROR(ok_);
380     *extension_bits = 0;
381     *total_bits = encoded_bits_;
382     // Only if extension field was nonzero will we encode their sizes.
383     if (pos_after_ext_ != 0) {
384       JXL_ASSERT(encoded_bits_ >= pos_after_ext_);
385       *extension_bits = encoded_bits_ - pos_after_ext_;
386       // Also need to encode *extension_bits and bill it to *total_bits.
387       size_t encoded_bits = 0;
388       ok_ &= U64Coder::CanEncode(*extension_bits, &encoded_bits);
389       *total_bits += encoded_bits;
390
391       // TODO(janwas): support encoding individual extension sizes. We
392       // currently ascribe all bits to the first and send zeros for the
393       // others.
394       for (size_t i = 1; i < hwy::PopCount(extensions_); ++i) {
395         encoded_bits = 0;
396         ok_ &= U64Coder::CanEncode(0, &encoded_bits);
397         *total_bits += encoded_bits;
398       }
399     }
400     return true;
401   }
402
403  private:
404   bool ok_ = true;
405   size_t encoded_bits_ = 0;
406   uint64_t extensions_ = 0;
407   // Snapshot of encoded_bits_ after visiting the extension field, but NOT
408   // including the hidden extension sizes.
409   uint64_t pos_after_ext_ = 0;
410 };
411 }  // namespace
412
413 void Bundle::Init(Fields* fields) {
414   InitVisitor visitor;
415   if (!visitor.Visit(fields)) {
416     JXL_UNREACHABLE("Init should never fail");
417   }
418 }
419 void Bundle::SetDefault(Fields* fields) {
420   SetDefaultVisitor visitor;
421   if (!visitor.Visit(fields)) {
422     JXL_UNREACHABLE("SetDefault should never fail");
423   }
424 }
425 bool Bundle::AllDefault(const Fields& fields) {
426   AllDefaultVisitor visitor;
427   if (!visitor.VisitConst(fields)) {
428     JXL_UNREACHABLE("AllDefault should never fail");
429   }
430   return visitor.AllDefault();
431 }
432 size_t Bundle::MaxBits(const Fields& fields) {
433   MaxBitsVisitor visitor;
434 #if JXL_ENABLE_ASSERT
435   Status ret =
436 #else
437   (void)
438 #endif  // JXL_ENABLE_ASSERT
439       visitor.VisitConst(fields);
440   JXL_ASSERT(ret);
441   return visitor.MaxBits();
442 }
443 Status Bundle::CanEncode(const Fields& fields, size_t* extension_bits,
444                          size_t* total_bits) {
445   CanEncodeVisitor visitor;
446   JXL_QUIET_RETURN_IF_ERROR(visitor.VisitConst(fields));
447   JXL_QUIET_RETURN_IF_ERROR(visitor.GetSizes(extension_bits, total_bits));
448   return true;
449 }
450 Status Bundle::Read(BitReader* reader, Fields* fields) {
451   ReadVisitor visitor(reader);
452   JXL_RETURN_IF_ERROR(visitor.Visit(fields));
453   return visitor.OK();
454 }
455 bool Bundle::CanRead(BitReader* reader, Fields* fields) {
456   ReadVisitor visitor(reader);
457   Status status = visitor.Visit(fields);
458   // We are only checking here whether there are enough bytes. We still return
459   // true for other errors because it means there are enough bytes to determine
460   // there's an error. Use Read() to determine which error it is.
461   return status.code() != StatusCode::kNotEnoughBytes;
462 }
463
464 size_t BitsCoder::MaxEncodedBits(const size_t bits) { return bits; }
465
466 Status BitsCoder::CanEncode(const size_t bits, const uint32_t value,
467                             size_t* JXL_RESTRICT encoded_bits) {
468   *encoded_bits = bits;
469   if (value >= (1ULL << bits)) {
470     return JXL_FAILURE("Value %u too large for %" PRIu64 " bits", value,
471                        static_cast<uint64_t>(bits));
472   }
473   return true;
474 }
475
476 uint32_t BitsCoder::Read(const size_t bits, BitReader* JXL_RESTRICT reader) {
477   return reader->ReadBits(bits);
478 }
479
480 size_t U32Coder::MaxEncodedBits(const U32Enc enc) {
481   size_t extra_bits = 0;
482   for (uint32_t selector = 0; selector < 4; ++selector) {
483     const U32Distr d = enc.GetDistr(selector);
484     if (d.IsDirect()) {
485       continue;
486     } else {
487       extra_bits = std::max<size_t>(extra_bits, d.ExtraBits());
488     }
489   }
490   return 2 + extra_bits;
491 }
492
493 Status U32Coder::CanEncode(const U32Enc enc, const uint32_t value,
494                            size_t* JXL_RESTRICT encoded_bits) {
495   uint32_t selector;
496   size_t total_bits;
497   const Status ok = ChooseSelector(enc, value, &selector, &total_bits);
498   *encoded_bits = ok ? total_bits : 0;
499   return ok;
500 }
501
502 uint32_t U32Coder::Read(const U32Enc enc, BitReader* JXL_RESTRICT reader) {
503   const uint32_t selector = reader->ReadFixedBits<2>();
504   const U32Distr d = enc.GetDistr(selector);
505   if (d.IsDirect()) {
506     return d.Direct();
507   } else {
508     return reader->ReadBits(d.ExtraBits()) + d.Offset();
509   }
510 }
511
512 Status U32Coder::ChooseSelector(const U32Enc enc, const uint32_t value,
513                                 uint32_t* JXL_RESTRICT selector,
514                                 size_t* JXL_RESTRICT total_bits) {
515 #if JXL_ENABLE_ASSERT
516   const size_t bits_required = 32 - Num0BitsAboveMS1Bit(value);
517 #endif  // JXL_ENABLE_ASSERT
518   JXL_ASSERT(bits_required <= 32);
519
520   *selector = 0;
521   *total_bits = 0;
522
523   // It is difficult to verify whether Dist32Byte are sorted, so check all
524   // selectors and keep the one with the fewest total_bits.
525   *total_bits = 64;  // more than any valid encoding
526   for (uint32_t s = 0; s < 4; ++s) {
527     const U32Distr d = enc.GetDistr(s);
528     if (d.IsDirect()) {
529       if (d.Direct() == value) {
530         *selector = s;
531         *total_bits = 2;
532         return true;  // Done, direct is always the best possible.
533       }
534       continue;
535     }
536     const size_t extra_bits = d.ExtraBits();
537     const uint32_t offset = d.Offset();
538     if (value < offset || value >= offset + (1ULL << extra_bits)) continue;
539
540     // Better than prior encoding, remember it:
541     if (2 + extra_bits < *total_bits) {
542       *selector = s;
543       *total_bits = 2 + extra_bits;
544     }
545   }
546
547   if (*total_bits == 64) {
548     return JXL_FAILURE("No feasible selector for %u", value);
549   }
550
551   return true;
552 }
553
554 uint64_t U64Coder::Read(BitReader* JXL_RESTRICT reader) {
555   uint64_t selector = reader->ReadFixedBits<2>();
556   if (selector == 0) {
557     return 0;
558   }
559   if (selector == 1) {
560     return 1 + reader->ReadFixedBits<4>();
561   }
562   if (selector == 2) {
563     return 17 + reader->ReadFixedBits<8>();
564   }
565
566   // selector 3, varint, groups have first 12, then 8, and last 4 bits.
567   uint64_t result = reader->ReadFixedBits<12>();
568
569   uint64_t shift = 12;
570   while (reader->ReadFixedBits<1>()) {
571     if (shift == 60) {
572       result |= static_cast<uint64_t>(reader->ReadFixedBits<4>()) << shift;
573       break;
574     }
575     result |= static_cast<uint64_t>(reader->ReadFixedBits<8>()) << shift;
576     shift += 8;
577   }
578
579   return result;
580 }
581
582 // Can always encode, but useful because it also returns bit size.
583 Status U64Coder::CanEncode(uint64_t value, size_t* JXL_RESTRICT encoded_bits) {
584   if (value == 0) {
585     *encoded_bits = 2;  // 2 selector bits
586   } else if (value <= 16) {
587     *encoded_bits = 2 + 4;  // 2 selector bits + 4 payload bits
588   } else if (value <= 272) {
589     *encoded_bits = 2 + 8;  // 2 selector bits + 8 payload bits
590   } else {
591     *encoded_bits = 2 + 12;  // 2 selector bits + 12 payload bits
592     value >>= 12;
593     int shift = 12;
594     while (value > 0 && shift < 60) {
595       *encoded_bits += 1 + 8;  // 1 continuation bit + 8 payload bits
596       value >>= 8;
597       shift += 8;
598     }
599     if (value > 0) {
600       // This only could happen if shift == N - 4.
601       *encoded_bits += 1 + 4;  // 1 continuation bit + 4 payload bits
602     } else {
603       *encoded_bits += 1;  // 1 stop bit
604     }
605   }
606
607   return true;
608 }
609
610 Status F16Coder::Read(BitReader* JXL_RESTRICT reader,
611                       float* JXL_RESTRICT value) {
612   const uint32_t bits16 = reader->ReadFixedBits<16>();
613   const uint32_t sign = bits16 >> 15;
614   const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
615   const uint32_t mantissa = bits16 & 0x3FF;
616
617   if (JXL_UNLIKELY(biased_exp == 31)) {
618     return JXL_FAILURE("F16 infinity or NaN are not supported");
619   }
620
621   // Subnormal or zero
622   if (JXL_UNLIKELY(biased_exp == 0)) {
623     *value = (1.0f / 16384) * (mantissa * (1.0f / 1024));
624     if (sign) *value = -*value;
625     return true;
626   }
627
628   // Normalized: convert the representation directly (faster than ldexp/tables).
629   const uint32_t biased_exp32 = biased_exp + (127 - 15);
630   const uint32_t mantissa32 = mantissa << (23 - 10);
631   const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32;
632   memcpy(value, &bits32, sizeof(bits32));
633   return true;
634 }
635
636 Status F16Coder::CanEncode(float value, size_t* JXL_RESTRICT encoded_bits) {
637   *encoded_bits = MaxEncodedBits();
638   if (std::isnan(value) || std::isinf(value)) {
639     return JXL_FAILURE("Should not attempt to store NaN and infinity");
640   }
641   return std::abs(value) <= 65504.0f;
642 }
643
644 Status CheckHasEnoughBits(Visitor* visitor, size_t bits) {
645   if (!visitor->IsReading()) return false;
646   ReadVisitor* rv = static_cast<ReadVisitor*>(visitor);
647   size_t have_bits = rv->reader_->TotalBytes() * kBitsPerByte;
648   size_t want_bits = bits + rv->reader_->TotalBitsConsumed();
649   if (have_bits < want_bits) {
650     return JXL_STATUS(StatusCode::kNotEnoughBytes,
651                       "Not enough bytes for header");
652   }
653   return true;
654 }
655
656 }  // namespace jxl