1d9ac5ab4496bde15d7a574b5faa330584195323
[platform/upstream/ffmpeg.git] / libavcodec / cbs_av1.c
1 /*
2  * This file is part of FFmpeg.
3  *
4  * FFmpeg is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * FFmpeg is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with FFmpeg; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17  */
18
19 #include "libavutil/avassert.h"
20 #include "libavutil/opt.h"
21 #include "libavutil/pixfmt.h"
22
23 #include "cbs.h"
24 #include "cbs_internal.h"
25 #include "cbs_av1.h"
26 #include "defs.h"
27 #include "refstruct.h"
28
29
30 static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
31                              const char *name, uint32_t *write_to,
32                              uint32_t range_min, uint32_t range_max)
33 {
34     uint32_t zeroes, bits_value, value;
35
36     CBS_TRACE_READ_START();
37
38     zeroes = 0;
39     while (1) {
40         if (get_bits_left(gbc) < 1) {
41             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
42                    "%s: bitstream ended.\n", name);
43             return AVERROR_INVALIDDATA;
44         }
45
46         if (get_bits1(gbc))
47             break;
48         ++zeroes;
49     }
50
51     if (zeroes >= 32) {
52         // Note that the spec allows an arbitrarily large number of
53         // zero bits followed by a one bit in this case, but the
54         // libaom implementation does not support it.
55         value = MAX_UINT_BITS(32);
56     } else {
57         if (get_bits_left(gbc) < zeroes) {
58             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid uvlc code at "
59                    "%s: bitstream ended.\n", name);
60             return AVERROR_INVALIDDATA;
61         }
62
63         bits_value = get_bits_long(gbc, zeroes);
64         value = bits_value + (UINT32_C(1) << zeroes) - 1;
65     }
66
67     CBS_TRACE_READ_END_NO_SUBSCRIPTS();
68
69     if (value < range_min || value > range_max) {
70         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
71                "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
72                name, value, range_min, range_max);
73         return AVERROR_INVALIDDATA;
74     }
75
76     *write_to = value;
77     return 0;
78 }
79
80 static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
81                               const char *name, uint32_t value,
82                               uint32_t range_min, uint32_t range_max)
83 {
84     uint32_t v;
85     int zeroes;
86
87     CBS_TRACE_WRITE_START();
88
89     if (value < range_min || value > range_max) {
90         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
91                "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
92                name, value, range_min, range_max);
93         return AVERROR_INVALIDDATA;
94     }
95
96     zeroes = av_log2(value + 1);
97     v = value - (1U << zeroes) + 1;
98
99     if (put_bits_left(pbc) < 2 * zeroes + 1)
100         return AVERROR(ENOSPC);
101
102     put_bits(pbc, zeroes, 0);
103     put_bits(pbc, 1, 1);
104     put_bits(pbc, zeroes, v);
105
106     CBS_TRACE_WRITE_END_NO_SUBSCRIPTS();
107
108     return 0;
109 }
110
111 static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
112                                const char *name, uint64_t *write_to)
113 {
114     uint64_t value;
115     uint32_t byte;
116     int i;
117
118     CBS_TRACE_READ_START();
119
120     value = 0;
121     for (i = 0; i < 8; i++) {
122         if (get_bits_left(gbc) < 8) {
123             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid leb128 at "
124                    "%s: bitstream ended.\n", name);
125             return AVERROR_INVALIDDATA;
126         }
127         byte = get_bits(gbc, 8);
128         value |= (uint64_t)(byte & 0x7f) << (i * 7);
129         if (!(byte & 0x80))
130             break;
131     }
132
133     if (value > UINT32_MAX)
134         return AVERROR_INVALIDDATA;
135
136     CBS_TRACE_READ_END_NO_SUBSCRIPTS();
137
138     *write_to = value;
139     return 0;
140 }
141
142 static int cbs_av1_write_leb128(CodedBitstreamContext *ctx, PutBitContext *pbc,
143                                 const char *name, uint64_t value, int fixed_length)
144 {
145     int len, i;
146     uint8_t byte;
147
148     CBS_TRACE_WRITE_START();
149
150     len = (av_log2(value) + 7) / 7;
151
152     if (fixed_length) {
153         if (fixed_length < len) {
154             av_log(ctx->log_ctx, AV_LOG_ERROR, "OBU is too large for "
155                    "fixed length size field (%d > %d).\n",
156                    len, fixed_length);
157             return AVERROR(EINVAL);
158         }
159         len = fixed_length;
160     }
161
162     for (i = 0; i < len; i++) {
163         if (put_bits_left(pbc) < 8)
164             return AVERROR(ENOSPC);
165
166         byte = value >> (7 * i) & 0x7f;
167         if (i < len - 1)
168             byte |= 0x80;
169
170         put_bits(pbc, 8, byte);
171     }
172
173     CBS_TRACE_WRITE_END_NO_SUBSCRIPTS();
174
175     return 0;
176 }
177
178 static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
179                            uint32_t n, const char *name,
180                            const int *subscripts, uint32_t *write_to)
181 {
182     uint32_t m, v, extra_bit, value;
183     int w;
184
185     CBS_TRACE_READ_START();
186
187     av_assert0(n > 0);
188
189     w = av_log2(n) + 1;
190     m = (1 << w) - n;
191
192     if (get_bits_left(gbc) < w) {
193         av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid non-symmetric value at "
194                "%s: bitstream ended.\n", name);
195         return AVERROR_INVALIDDATA;
196     }
197
198     if (w - 1 > 0)
199         v = get_bits(gbc, w - 1);
200     else
201         v = 0;
202
203     if (v < m) {
204         value = v;
205     } else {
206         extra_bit = get_bits1(gbc);
207         value = (v << 1) - m + extra_bit;
208     }
209
210     CBS_TRACE_READ_END();
211
212     *write_to = value;
213     return 0;
214 }
215
216 static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
217                             uint32_t n, const char *name,
218                             const int *subscripts, uint32_t value)
219 {
220     uint32_t w, m, v, extra_bit;
221
222     CBS_TRACE_WRITE_START();
223
224     if (value > n) {
225         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
226                "%"PRIu32", but must be in [0,%"PRIu32"].\n",
227                name, value, n);
228         return AVERROR_INVALIDDATA;
229     }
230
231     w = av_log2(n) + 1;
232     m = (1 << w) - n;
233
234     if (put_bits_left(pbc) < w)
235         return AVERROR(ENOSPC);
236
237     if (value < m) {
238         v = value;
239         put_bits(pbc, w - 1, v);
240     } else {
241         v = m + ((value - m) >> 1);
242         extra_bit = (value - m) & 1;
243         put_bits(pbc, w - 1, v);
244         put_bits(pbc, 1, extra_bit);
245     }
246
247     CBS_TRACE_WRITE_END();
248
249     return 0;
250 }
251
252 static int cbs_av1_read_increment(CodedBitstreamContext *ctx, GetBitContext *gbc,
253                                   uint32_t range_min, uint32_t range_max,
254                                   const char *name, uint32_t *write_to)
255 {
256     uint32_t value;
257
258     CBS_TRACE_READ_START();
259
260     av_assert0(range_min <= range_max && range_max - range_min < 32);
261
262     for (value = range_min; value < range_max;) {
263         if (get_bits_left(gbc) < 1) {
264             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid increment value at "
265                    "%s: bitstream ended.\n", name);
266             return AVERROR_INVALIDDATA;
267         }
268         if (get_bits1(gbc))
269             ++value;
270         else
271             break;
272     }
273
274     CBS_TRACE_READ_END_NO_SUBSCRIPTS();
275
276     *write_to = value;
277     return 0;
278 }
279
280 static int cbs_av1_write_increment(CodedBitstreamContext *ctx, PutBitContext *pbc,
281                                    uint32_t range_min, uint32_t range_max,
282                                    const char *name, uint32_t value)
283 {
284     int len;
285
286     CBS_TRACE_WRITE_START();
287
288     av_assert0(range_min <= range_max && range_max - range_min < 32);
289     if (value < range_min || value > range_max) {
290         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
291                "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
292                name, value, range_min, range_max);
293         return AVERROR_INVALIDDATA;
294     }
295
296     if (value == range_max)
297         len = range_max - range_min;
298     else
299         len = value - range_min + 1;
300     if (put_bits_left(pbc) < len)
301         return AVERROR(ENOSPC);
302
303     if (len > 0)
304         put_bits(pbc, len, (1 << len) - 1 - (value != range_max));
305
306     CBS_TRACE_WRITE_END_NO_SUBSCRIPTS();
307
308     return 0;
309 }
310
311 static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
312                                uint32_t range_max, const char *name,
313                                const int *subscripts, uint32_t *write_to)
314 {
315     uint32_t value, max_len, len, range_offset, range_bits;
316     int err;
317
318     CBS_TRACE_READ_START();
319
320     av_assert0(range_max > 0);
321     max_len = av_log2(range_max - 1) - 3;
322
323     err = cbs_av1_read_increment(ctx, gbc, 0, max_len,
324                                  "subexp_more_bits", &len);
325     if (err < 0)
326         return err;
327
328     if (len) {
329         range_bits   = 2 + len;
330         range_offset = 1 << range_bits;
331     } else {
332         range_bits   = 3;
333         range_offset = 0;
334     }
335
336     if (len < max_len) {
337         err = ff_cbs_read_simple_unsigned(ctx, gbc, range_bits,
338                                           "subexp_bits", &value);
339         if (err < 0)
340             return err;
341
342     } else {
343         err = cbs_av1_read_ns(ctx, gbc, range_max - range_offset,
344                               "subexp_final_bits", NULL, &value);
345         if (err < 0)
346             return err;
347     }
348     value += range_offset;
349
350     CBS_TRACE_READ_END_VALUE_ONLY();
351
352     *write_to = value;
353     return err;
354 }
355
356 static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
357                                 uint32_t range_max, const char *name,
358                                 const int *subscripts, uint32_t value)
359 {
360     int err;
361     uint32_t max_len, len, range_offset, range_bits;
362
363     CBS_TRACE_WRITE_START();
364
365     if (value > range_max) {
366         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
367                "%"PRIu32", but must be in [0,%"PRIu32"].\n",
368                name, value, range_max);
369         return AVERROR_INVALIDDATA;
370     }
371
372     av_assert0(range_max > 0);
373     max_len = av_log2(range_max - 1) - 3;
374
375     if (value < 8) {
376         range_bits   = 3;
377         range_offset = 0;
378         len = 0;
379     } else {
380         range_bits = av_log2(value);
381         len = range_bits - 2;
382         if (len > max_len) {
383             // The top bin is combined with the one below it.
384             av_assert0(len == max_len + 1);
385             --range_bits;
386             len = max_len;
387         }
388         range_offset = 1 << range_bits;
389     }
390
391     err = cbs_av1_write_increment(ctx, pbc, 0, max_len,
392                                   "subexp_more_bits", len);
393     if (err < 0)
394         return err;
395
396     if (len < max_len) {
397         err = ff_cbs_write_simple_unsigned(ctx, pbc, range_bits,
398                                            "subexp_bits",
399                                            value - range_offset);
400         if (err < 0)
401             return err;
402
403     } else {
404         err = cbs_av1_write_ns(ctx, pbc, range_max - range_offset,
405                                "subexp_final_bits", NULL,
406                                value - range_offset);
407         if (err < 0)
408             return err;
409     }
410
411     CBS_TRACE_WRITE_END_VALUE_ONLY();
412
413     return err;
414 }
415
416
417 static int cbs_av1_tile_log2(int blksize, int target)
418 {
419     int k;
420     for (k = 0; (blksize << k) < target; k++);
421     return k;
422 }
423
424 static int cbs_av1_get_relative_dist(const AV1RawSequenceHeader *seq,
425                                      unsigned int a, unsigned int b)
426 {
427     unsigned int diff, m;
428     if (!seq->enable_order_hint)
429         return 0;
430     diff = a - b;
431     m = 1 << seq->order_hint_bits_minus_1;
432     diff = (diff & (m - 1)) - (diff & m);
433     return diff;
434 }
435
436 static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
437 {
438     GetBitContext tmp = *gbc;
439     size_t size = 0;
440     for (int i = 0; get_bits_left(&tmp) >= 8; i++) {
441         if (get_bits(&tmp, 8))
442             size = i;
443     }
444     return size;
445 }
446
447
448 #define HEADER(name) do { \
449         ff_cbs_trace_header(ctx, name); \
450     } while (0)
451
452 #define CHECK(call) do { \
453         err = (call); \
454         if (err < 0) \
455             return err; \
456     } while (0)
457
458 #define FUNC_NAME(rw, codec, name) cbs_ ## codec ## _ ## rw ## _ ## name
459 #define FUNC_AV1(rw, name) FUNC_NAME(rw, av1, name)
460 #define FUNC(name) FUNC_AV1(READWRITE, name)
461
462 #define SUBSCRIPTS(subs, ...) (subs > 0 ? ((int[subs + 1]){ subs, __VA_ARGS__ }) : NULL)
463
464 #define fc(width, name, range_min, range_max) \
465         xf(width, name, current->name, range_min, range_max, 0, )
466 #define flag(name) fb(1, name)
467 #define su(width, name) \
468         xsu(width, name, current->name, 0, )
469
470 #define fbs(width, name, subs, ...) \
471         xf(width, name, current->name, 0, MAX_UINT_BITS(width), subs, __VA_ARGS__)
472 #define fcs(width, name, range_min, range_max, subs, ...) \
473         xf(width, name, current->name, range_min, range_max, subs, __VA_ARGS__)
474 #define flags(name, subs, ...) \
475         xf(1, name, current->name, 0, 1, subs, __VA_ARGS__)
476 #define sus(width, name, subs, ...) \
477         xsu(width, name, current->name, subs, __VA_ARGS__)
478
479 #define fixed(width, name, value) do { \
480         av_unused uint32_t fixed_value = value; \
481         xf(width, name, fixed_value, value, value, 0, ); \
482     } while (0)
483
484
485 #define READ
486 #define READWRITE read
487 #define RWContext GetBitContext
488
489 #define fb(width, name) do { \
490         uint32_t value; \
491         CHECK(ff_cbs_read_simple_unsigned(ctx, rw, width, \
492                                           #name, &value)); \
493         current->name = value; \
494     } while (0)
495
496 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
497         uint32_t value; \
498         CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
499                                    SUBSCRIPTS(subs, __VA_ARGS__), \
500                                    &value, range_min, range_max)); \
501         var = value; \
502     } while (0)
503
504 #define xsu(width, name, var, subs, ...) do { \
505         int32_t value; \
506         CHECK(ff_cbs_read_signed(ctx, rw, width, #name, \
507                                  SUBSCRIPTS(subs, __VA_ARGS__), &value, \
508                                  MIN_INT_BITS(width), \
509                                  MAX_INT_BITS(width))); \
510         var = value; \
511     } while (0)
512
513 #define uvlc(name, range_min, range_max) do { \
514         uint32_t value; \
515         CHECK(cbs_av1_read_uvlc(ctx, rw, #name, \
516                                 &value, range_min, range_max)); \
517         current->name = value; \
518     } while (0)
519
520 #define ns(max_value, name, subs, ...) do { \
521         uint32_t value; \
522         CHECK(cbs_av1_read_ns(ctx, rw, max_value, #name, \
523                               SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
524         current->name = value; \
525     } while (0)
526
527 #define increment(name, min, max) do { \
528         uint32_t value; \
529         CHECK(cbs_av1_read_increment(ctx, rw, min, max, #name, &value)); \
530         current->name = value; \
531     } while (0)
532
533 #define subexp(name, max, subs, ...) do { \
534         uint32_t value; \
535         CHECK(cbs_av1_read_subexp(ctx, rw, max, #name, \
536                                   SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
537         current->name = value; \
538     } while (0)
539
540 #define delta_q(name) do { \
541         uint8_t delta_coded; \
542         int8_t delta_q; \
543         xf(1, name.delta_coded, delta_coded, 0, 1, 0, ); \
544         if (delta_coded) \
545             xsu(1 + 6, name.delta_q, delta_q, 0, ); \
546         else \
547             delta_q = 0; \
548         current->name = delta_q; \
549     } while (0)
550
551 #define leb128(name) do { \
552         uint64_t value; \
553         CHECK(cbs_av1_read_leb128(ctx, rw, #name, &value)); \
554         current->name = value; \
555     } while (0)
556
557 #define infer(name, value) do { \
558         current->name = value; \
559     } while (0)
560
561 #define byte_alignment(rw) (get_bits_count(rw) % 8)
562
563 #include "cbs_av1_syntax_template.c"
564
565 #undef READ
566 #undef READWRITE
567 #undef RWContext
568 #undef fb
569 #undef xf
570 #undef xsu
571 #undef uvlc
572 #undef ns
573 #undef increment
574 #undef subexp
575 #undef delta_q
576 #undef leb128
577 #undef infer
578 #undef byte_alignment
579
580
581 #define WRITE
582 #define READWRITE write
583 #define RWContext PutBitContext
584
585 #define fb(width, name) do { \
586         CHECK(ff_cbs_write_simple_unsigned(ctx, rw, width, #name, \
587                                            current->name)); \
588     } while (0)
589
590 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
591         CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
592                                     SUBSCRIPTS(subs, __VA_ARGS__), \
593                                     var, range_min, range_max)); \
594     } while (0)
595
596 #define xsu(width, name, var, subs, ...) do { \
597         CHECK(ff_cbs_write_signed(ctx, rw, width, #name, \
598                                   SUBSCRIPTS(subs, __VA_ARGS__), var, \
599                                   MIN_INT_BITS(width), \
600                                   MAX_INT_BITS(width))); \
601     } while (0)
602
603 #define uvlc(name, range_min, range_max) do { \
604         CHECK(cbs_av1_write_uvlc(ctx, rw, #name, current->name, \
605                                  range_min, range_max)); \
606     } while (0)
607
608 #define ns(max_value, name, subs, ...) do { \
609         CHECK(cbs_av1_write_ns(ctx, rw, max_value, #name, \
610                                SUBSCRIPTS(subs, __VA_ARGS__), \
611                                current->name)); \
612     } while (0)
613
614 #define increment(name, min, max) do { \
615         CHECK(cbs_av1_write_increment(ctx, rw, min, max, #name, \
616                                       current->name)); \
617     } while (0)
618
619 #define subexp(name, max, subs, ...) do { \
620         CHECK(cbs_av1_write_subexp(ctx, rw, max, #name, \
621                                    SUBSCRIPTS(subs, __VA_ARGS__), \
622                                    current->name)); \
623     } while (0)
624
625 #define delta_q(name) do { \
626         xf(1, name.delta_coded, current->name != 0, 0, 1, 0, ); \
627         if (current->name) \
628             xsu(1 + 6, name.delta_q, current->name, 0, ); \
629     } while (0)
630
631 #define leb128(name) do { \
632         CHECK(cbs_av1_write_leb128(ctx, rw, #name, current->name, 0)); \
633     } while (0)
634
635 #define infer(name, value) do { \
636         if (current->name != (value)) { \
637             av_log(ctx->log_ctx, AV_LOG_ERROR, \
638                    "%s does not match inferred value: " \
639                    "%"PRId64", but should be %"PRId64".\n", \
640                    #name, (int64_t)current->name, (int64_t)(value)); \
641             return AVERROR_INVALIDDATA; \
642         } \
643     } while (0)
644
645 #define byte_alignment(rw) (put_bits_count(rw) % 8)
646
647 #include "cbs_av1_syntax_template.c"
648
649 #undef WRITE
650 #undef READWRITE
651 #undef RWContext
652 #undef fb
653 #undef xf
654 #undef xsu
655 #undef uvlc
656 #undef ns
657 #undef increment
658 #undef subexp
659 #undef delta_q
660 #undef leb128
661 #undef infer
662 #undef byte_alignment
663
664
665 static int cbs_av1_split_fragment(CodedBitstreamContext *ctx,
666                                   CodedBitstreamFragment *frag,
667                                   int header)
668 {
669     GetBitContext gbc;
670     uint8_t *data;
671     size_t size;
672     uint64_t obu_length;
673     int pos, err, trace;
674
675     // Don't include this parsing in trace output.
676     trace = ctx->trace_enable;
677     ctx->trace_enable = 0;
678
679     data = frag->data;
680     size = frag->data_size;
681
682     if (INT_MAX / 8 < size) {
683         av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid fragment: "
684                "too large (%"SIZE_SPECIFIER" bytes).\n", size);
685         err = AVERROR_INVALIDDATA;
686         goto fail;
687     }
688
689     if (header && size && data[0] & 0x80) {
690         // first bit is nonzero, the extradata does not consist purely of
691         // OBUs. Expect MP4/Matroska AV1CodecConfigurationRecord
692         int config_record_version = data[0] & 0x7f;
693
694         if (config_record_version != 1) {
695             av_log(ctx->log_ctx, AV_LOG_ERROR,
696                    "Unknown version %d of AV1CodecConfigurationRecord "
697                    "found!\n",
698                    config_record_version);
699             err = AVERROR_INVALIDDATA;
700             goto fail;
701         }
702
703         if (size <= 4) {
704             if (size < 4) {
705                 av_log(ctx->log_ctx, AV_LOG_WARNING,
706                        "Undersized AV1CodecConfigurationRecord v%d found!\n",
707                        config_record_version);
708                 err = AVERROR_INVALIDDATA;
709                 goto fail;
710             }
711
712             goto success;
713         }
714
715         // In AV1CodecConfigurationRecord v1, actual OBUs start after
716         // four bytes. Thus set the offset as required for properly
717         // parsing them.
718         data += 4;
719         size -= 4;
720     }
721
722     while (size > 0) {
723         AV1RawOBUHeader header;
724         uint64_t obu_size;
725
726         init_get_bits(&gbc, data, 8 * size);
727
728         err = cbs_av1_read_obu_header(ctx, &gbc, &header);
729         if (err < 0)
730             goto fail;
731
732         if (header.obu_has_size_field) {
733             if (get_bits_left(&gbc) < 8) {
734                 av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU: fragment "
735                        "too short (%"SIZE_SPECIFIER" bytes).\n", size);
736                 err = AVERROR_INVALIDDATA;
737                 goto fail;
738             }
739             err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
740             if (err < 0)
741                 goto fail;
742         } else
743             obu_size = size - 1 - header.obu_extension_flag;
744
745         pos = get_bits_count(&gbc);
746         av_assert0(pos % 8 == 0 && pos / 8 <= size);
747
748         obu_length = pos / 8 + obu_size;
749
750         if (size < obu_length) {
751             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
752                    "%"PRIu64", but only %"SIZE_SPECIFIER" bytes remaining in fragment.\n",
753                    obu_length, size);
754             err = AVERROR_INVALIDDATA;
755             goto fail;
756         }
757
758         err = ff_cbs_append_unit_data(frag, header.obu_type,
759                                       data, obu_length, frag->data_ref);
760         if (err < 0)
761             goto fail;
762
763         data += obu_length;
764         size -= obu_length;
765     }
766
767 success:
768     err = 0;
769 fail:
770     ctx->trace_enable = trace;
771     return err;
772 }
773
774 static int cbs_av1_ref_tile_data(CodedBitstreamContext *ctx,
775                                  CodedBitstreamUnit *unit,
776                                  GetBitContext *gbc,
777                                  AV1RawTileData *td)
778 {
779     int pos;
780
781     pos = get_bits_count(gbc);
782     if (pos >= 8 * unit->data_size) {
783         av_log(ctx->log_ctx, AV_LOG_ERROR, "Bitstream ended before "
784                "any data in tile group (%d bits read).\n", pos);
785         return AVERROR_INVALIDDATA;
786     }
787     // Must be byte-aligned at this point.
788     av_assert0(pos % 8 == 0);
789
790     td->data_ref = av_buffer_ref(unit->data_ref);
791     if (!td->data_ref)
792         return AVERROR(ENOMEM);
793
794     td->data      = unit->data      + pos / 8;
795     td->data_size = unit->data_size - pos / 8;
796
797     return 0;
798 }
799
800 static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
801                              CodedBitstreamUnit *unit)
802 {
803     CodedBitstreamAV1Context *priv = ctx->priv_data;
804     AV1RawOBU *obu;
805     GetBitContext gbc;
806     int err, start_pos, end_pos;
807
808     err = ff_cbs_alloc_unit_content(ctx, unit);
809     if (err < 0)
810         return err;
811     obu = unit->content;
812
813     err = init_get_bits(&gbc, unit->data, 8 * unit->data_size);
814     if (err < 0)
815         return err;
816
817     err = cbs_av1_read_obu_header(ctx, &gbc, &obu->header);
818     if (err < 0)
819         return err;
820     av_assert0(obu->header.obu_type == unit->type);
821
822     if (obu->header.obu_has_size_field) {
823         uint64_t obu_size;
824         err = cbs_av1_read_leb128(ctx, &gbc, "obu_size", &obu_size);
825         if (err < 0)
826             return err;
827         obu->obu_size = obu_size;
828     } else {
829         if (unit->data_size < 1 + obu->header.obu_extension_flag) {
830             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid OBU length: "
831                    "unit too short (%"SIZE_SPECIFIER").\n", unit->data_size);
832             return AVERROR_INVALIDDATA;
833         }
834         obu->obu_size = unit->data_size - 1 - obu->header.obu_extension_flag;
835     }
836
837     start_pos = get_bits_count(&gbc);
838
839     if (obu->header.obu_extension_flag) {
840         if (obu->header.obu_type != AV1_OBU_SEQUENCE_HEADER &&
841             obu->header.obu_type != AV1_OBU_TEMPORAL_DELIMITER &&
842             priv->operating_point_idc) {
843             int in_temporal_layer =
844                 (priv->operating_point_idc >>  priv->temporal_id    ) & 1;
845             int in_spatial_layer  =
846                 (priv->operating_point_idc >> (priv->spatial_id + 8)) & 1;
847             if (!in_temporal_layer || !in_spatial_layer) {
848                 return AVERROR(EAGAIN); // drop_obu()
849             }
850         }
851     }
852
853     switch (obu->header.obu_type) {
854     case AV1_OBU_SEQUENCE_HEADER:
855         {
856             err = cbs_av1_read_sequence_header_obu(ctx, &gbc,
857                                                    &obu->obu.sequence_header);
858             if (err < 0)
859                 return err;
860
861             if (priv->operating_point >= 0) {
862                 AV1RawSequenceHeader *sequence_header = &obu->obu.sequence_header;
863
864                 if (priv->operating_point > sequence_header->operating_points_cnt_minus_1) {
865                     av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid Operating Point %d requested. "
866                                                        "Must not be higher than %u.\n",
867                            priv->operating_point, sequence_header->operating_points_cnt_minus_1);
868                     return AVERROR(EINVAL);
869                 }
870                 priv->operating_point_idc = sequence_header->operating_point_idc[priv->operating_point];
871             }
872
873             ff_refstruct_replace(&priv->sequence_header_ref, unit->content_ref);
874             priv->sequence_header = &obu->obu.sequence_header;
875         }
876         break;
877     case AV1_OBU_TEMPORAL_DELIMITER:
878         {
879             err = cbs_av1_read_temporal_delimiter_obu(ctx, &gbc);
880             if (err < 0)
881                 return err;
882         }
883         break;
884     case AV1_OBU_FRAME_HEADER:
885     case AV1_OBU_REDUNDANT_FRAME_HEADER:
886         {
887             err = cbs_av1_read_frame_header_obu(ctx, &gbc,
888                                                 &obu->obu.frame_header,
889                                                 obu->header.obu_type ==
890                                                 AV1_OBU_REDUNDANT_FRAME_HEADER,
891                                                 unit->data_ref);
892             if (err < 0)
893                 return err;
894         }
895         break;
896     case AV1_OBU_TILE_GROUP:
897         {
898             err = cbs_av1_read_tile_group_obu(ctx, &gbc,
899                                               &obu->obu.tile_group);
900             if (err < 0)
901                 return err;
902
903             err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
904                                         &obu->obu.tile_group.tile_data);
905             if (err < 0)
906                 return err;
907         }
908         break;
909     case AV1_OBU_FRAME:
910         {
911             err = cbs_av1_read_frame_obu(ctx, &gbc, &obu->obu.frame,
912                                          unit->data_ref);
913             if (err < 0)
914                 return err;
915
916             err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
917                                         &obu->obu.frame.tile_group.tile_data);
918             if (err < 0)
919                 return err;
920         }
921         break;
922     case AV1_OBU_TILE_LIST:
923         {
924             err = cbs_av1_read_tile_list_obu(ctx, &gbc,
925                                              &obu->obu.tile_list);
926             if (err < 0)
927                 return err;
928
929             err = cbs_av1_ref_tile_data(ctx, unit, &gbc,
930                                         &obu->obu.tile_list.tile_data);
931             if (err < 0)
932                 return err;
933         }
934         break;
935     case AV1_OBU_METADATA:
936         {
937             err = cbs_av1_read_metadata_obu(ctx, &gbc, &obu->obu.metadata);
938             if (err < 0)
939                 return err;
940         }
941         break;
942     case AV1_OBU_PADDING:
943         {
944             err = cbs_av1_read_padding_obu(ctx, &gbc, &obu->obu.padding);
945             if (err < 0)
946                 return err;
947         }
948         break;
949     default:
950         return AVERROR(ENOSYS);
951     }
952
953     end_pos = get_bits_count(&gbc);
954     av_assert0(end_pos <= unit->data_size * 8);
955
956     if (obu->obu_size > 0 &&
957         obu->header.obu_type != AV1_OBU_TILE_GROUP &&
958         obu->header.obu_type != AV1_OBU_TILE_LIST &&
959         obu->header.obu_type != AV1_OBU_FRAME) {
960         int nb_bits = obu->obu_size * 8 + start_pos - end_pos;
961
962         if (nb_bits <= 0)
963             return AVERROR_INVALIDDATA;
964
965         err = cbs_av1_read_trailing_bits(ctx, &gbc, nb_bits);
966         if (err < 0)
967             return err;
968     }
969
970     return 0;
971 }
972
973 static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
974                              CodedBitstreamUnit *unit,
975                              PutBitContext *pbc)
976 {
977     CodedBitstreamAV1Context *priv = ctx->priv_data;
978     AV1RawOBU *obu = unit->content;
979     PutBitContext pbc_tmp;
980     AV1RawTileData *td;
981     size_t header_size;
982     int err, start_pos, end_pos, data_pos;
983     CodedBitstreamAV1Context av1ctx;
984
985     // OBUs in the normal bitstream format must contain a size field
986     // in every OBU (in annex B it is optional, but we don't support
987     // writing that).
988     obu->header.obu_has_size_field = 1;
989     av1ctx = *priv;
990
991     if (priv->sequence_header_ref) {
992         av1ctx.sequence_header_ref = ff_refstruct_ref(priv->sequence_header_ref);
993     }
994
995     if (priv->frame_header_ref) {
996         av1ctx.frame_header_ref = av_buffer_ref(priv->frame_header_ref);
997         if (!av1ctx.frame_header_ref) {
998             err = AVERROR(ENOMEM);
999             goto error;
1000         }
1001     }
1002
1003     err = cbs_av1_write_obu_header(ctx, pbc, &obu->header);
1004     if (err < 0)
1005         goto error;
1006
1007     if (obu->header.obu_has_size_field) {
1008         pbc_tmp = *pbc;
1009         if (priv->fixed_obu_size_length) {
1010             for (int i = 0; i < priv->fixed_obu_size_length; i++)
1011                 put_bits(pbc, 8, 0);
1012         } else {
1013             // Add space for the size field to fill later.
1014             put_bits32(pbc, 0);
1015             put_bits32(pbc, 0);
1016         }
1017     }
1018
1019     td = NULL;
1020     start_pos = put_bits_count(pbc);
1021
1022     switch (obu->header.obu_type) {
1023     case AV1_OBU_SEQUENCE_HEADER:
1024         {
1025             err = cbs_av1_write_sequence_header_obu(ctx, pbc,
1026                                                     &obu->obu.sequence_header);
1027             if (err < 0)
1028                 goto error;
1029
1030             ff_refstruct_unref(&priv->sequence_header_ref);
1031             priv->sequence_header = NULL;
1032
1033             err = ff_cbs_make_unit_refcounted(ctx, unit);
1034             if (err < 0)
1035                 goto error;
1036
1037             priv->sequence_header_ref = ff_refstruct_ref(unit->content_ref);
1038             priv->sequence_header = &obu->obu.sequence_header;
1039         }
1040         break;
1041     case AV1_OBU_TEMPORAL_DELIMITER:
1042         {
1043             err = cbs_av1_write_temporal_delimiter_obu(ctx, pbc);
1044             if (err < 0)
1045                 goto error;
1046         }
1047         break;
1048     case AV1_OBU_FRAME_HEADER:
1049     case AV1_OBU_REDUNDANT_FRAME_HEADER:
1050         {
1051             err = cbs_av1_write_frame_header_obu(ctx, pbc,
1052                                                  &obu->obu.frame_header,
1053                                                  obu->header.obu_type ==
1054                                                  AV1_OBU_REDUNDANT_FRAME_HEADER,
1055                                                  NULL);
1056             if (err < 0)
1057                 goto error;
1058         }
1059         break;
1060     case AV1_OBU_TILE_GROUP:
1061         {
1062             err = cbs_av1_write_tile_group_obu(ctx, pbc,
1063                                                &obu->obu.tile_group);
1064             if (err < 0)
1065                 goto error;
1066
1067             td = &obu->obu.tile_group.tile_data;
1068         }
1069         break;
1070     case AV1_OBU_FRAME:
1071         {
1072             err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame, NULL);
1073             if (err < 0)
1074                 goto error;
1075
1076             td = &obu->obu.frame.tile_group.tile_data;
1077         }
1078         break;
1079     case AV1_OBU_TILE_LIST:
1080         {
1081             err = cbs_av1_write_tile_list_obu(ctx, pbc, &obu->obu.tile_list);
1082             if (err < 0)
1083                 goto error;
1084
1085             td = &obu->obu.tile_list.tile_data;
1086         }
1087         break;
1088     case AV1_OBU_METADATA:
1089         {
1090             err = cbs_av1_write_metadata_obu(ctx, pbc, &obu->obu.metadata);
1091             if (err < 0)
1092                 goto error;
1093         }
1094         break;
1095     case AV1_OBU_PADDING:
1096         {
1097             err = cbs_av1_write_padding_obu(ctx, pbc, &obu->obu.padding);
1098             if (err < 0)
1099                 goto error;
1100         }
1101         break;
1102     default:
1103         err = AVERROR(ENOSYS);
1104         goto error;
1105     }
1106
1107     end_pos = put_bits_count(pbc);
1108     header_size = (end_pos - start_pos + 7) / 8;
1109     if (td) {
1110         obu->obu_size = header_size + td->data_size;
1111     } else if (header_size > 0) {
1112         // Add trailing bits and recalculate.
1113         err = cbs_av1_write_trailing_bits(ctx, pbc, 8 - end_pos % 8);
1114         if (err < 0)
1115             goto error;
1116         end_pos = put_bits_count(pbc);
1117         obu->obu_size = header_size = (end_pos - start_pos + 7) / 8;
1118     } else {
1119         // Empty OBU.
1120         obu->obu_size = 0;
1121     }
1122
1123     end_pos = put_bits_count(pbc);
1124     // Must now be byte-aligned.
1125     av_assert0(end_pos % 8 == 0);
1126     flush_put_bits(pbc);
1127     start_pos /= 8;
1128     end_pos   /= 8;
1129
1130     *pbc = pbc_tmp;
1131     err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size,
1132                                priv->fixed_obu_size_length);
1133     if (err < 0)
1134         goto error;
1135
1136     data_pos = put_bits_count(pbc) / 8;
1137     flush_put_bits(pbc);
1138     av_assert0(data_pos <= start_pos);
1139
1140     if (8 * obu->obu_size > put_bits_left(pbc)) {
1141         ff_refstruct_unref(&priv->sequence_header_ref);
1142         av_buffer_unref(&priv->frame_header_ref);
1143         *priv = av1ctx;
1144
1145         return AVERROR(ENOSPC);
1146     }
1147
1148     if (obu->obu_size > 0) {
1149         if (!priv->fixed_obu_size_length) {
1150             memmove(pbc->buf + data_pos,
1151                     pbc->buf + start_pos, header_size);
1152         } else {
1153             // The size was fixed so the following data was
1154             // already written in the correct place.
1155         }
1156         skip_put_bytes(pbc, header_size);
1157
1158         if (td) {
1159             memcpy(pbc->buf + data_pos + header_size,
1160                    td->data, td->data_size);
1161             skip_put_bytes(pbc, td->data_size);
1162         }
1163     }
1164
1165     // OBU data must be byte-aligned.
1166     av_assert0(put_bits_count(pbc) % 8 == 0);
1167     err = 0;
1168
1169 error:
1170     ff_refstruct_unref(&av1ctx.sequence_header_ref);
1171     av_buffer_unref(&av1ctx.frame_header_ref);
1172
1173     return err;
1174 }
1175
1176 static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,
1177                                      CodedBitstreamFragment *frag)
1178 {
1179     size_t size, pos;
1180     int i;
1181
1182     size = 0;
1183     for (i = 0; i < frag->nb_units; i++)
1184         size += frag->units[i].data_size;
1185
1186     frag->data_ref = av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
1187     if (!frag->data_ref)
1188         return AVERROR(ENOMEM);
1189     frag->data = frag->data_ref->data;
1190     memset(frag->data + size, 0, AV_INPUT_BUFFER_PADDING_SIZE);
1191
1192     pos = 0;
1193     for (i = 0; i < frag->nb_units; i++) {
1194         memcpy(frag->data + pos, frag->units[i].data,
1195                frag->units[i].data_size);
1196         pos += frag->units[i].data_size;
1197     }
1198     av_assert0(pos == size);
1199     frag->data_size = size;
1200
1201     return 0;
1202 }
1203
1204 static void cbs_av1_flush(CodedBitstreamContext *ctx)
1205 {
1206     CodedBitstreamAV1Context *priv = ctx->priv_data;
1207
1208     av_buffer_unref(&priv->frame_header_ref);
1209     priv->sequence_header = NULL;
1210     priv->frame_header = NULL;
1211
1212     memset(priv->ref, 0, sizeof(priv->ref));
1213     priv->operating_point_idc = 0;
1214     priv->seen_frame_header = 0;
1215     priv->tile_num = 0;
1216 }
1217
1218 static void cbs_av1_close(CodedBitstreamContext *ctx)
1219 {
1220     CodedBitstreamAV1Context *priv = ctx->priv_data;
1221
1222     ff_refstruct_unref(&priv->sequence_header_ref);
1223     av_buffer_unref(&priv->frame_header_ref);
1224 }
1225
1226 static void cbs_av1_free_metadata(FFRefStructOpaque unused, void *content)
1227 {
1228     AV1RawOBU *obu = content;
1229     AV1RawMetadata *md;
1230
1231     av_assert0(obu->header.obu_type == AV1_OBU_METADATA);
1232     md = &obu->obu.metadata;
1233
1234     switch (md->metadata_type) {
1235     case AV1_METADATA_TYPE_HDR_CLL:
1236     case AV1_METADATA_TYPE_HDR_MDCV:
1237     case AV1_METADATA_TYPE_SCALABILITY:
1238     case AV1_METADATA_TYPE_TIMECODE:
1239         break;
1240     case AV1_METADATA_TYPE_ITUT_T35:
1241         av_buffer_unref(&md->metadata.itut_t35.payload_ref);
1242         break;
1243     default:
1244         av_buffer_unref(&md->metadata.unknown.payload_ref);
1245     }
1246 }
1247
1248 static const CodedBitstreamUnitTypeDescriptor cbs_av1_unit_types[] = {
1249     CBS_UNIT_TYPE_POD(AV1_OBU_SEQUENCE_HEADER,        AV1RawOBU),
1250     CBS_UNIT_TYPE_POD(AV1_OBU_TEMPORAL_DELIMITER,     AV1RawOBU),
1251     CBS_UNIT_TYPE_POD(AV1_OBU_FRAME_HEADER,           AV1RawOBU),
1252     CBS_UNIT_TYPE_POD(AV1_OBU_REDUNDANT_FRAME_HEADER, AV1RawOBU),
1253
1254     CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_TILE_GROUP, AV1RawOBU,
1255                                obu.tile_group.tile_data.data),
1256     CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_FRAME,      AV1RawOBU,
1257                                obu.frame.tile_group.tile_data.data),
1258     CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_TILE_LIST,  AV1RawOBU,
1259                                obu.tile_list.tile_data.data),
1260     CBS_UNIT_TYPE_INTERNAL_REF(AV1_OBU_PADDING,    AV1RawOBU,
1261                                obu.padding.payload),
1262
1263     CBS_UNIT_TYPE_COMPLEX(AV1_OBU_METADATA, AV1RawOBU,
1264                           &cbs_av1_free_metadata),
1265
1266     CBS_UNIT_TYPE_END_OF_LIST
1267 };
1268
1269 #define OFFSET(x) offsetof(CodedBitstreamAV1Context, x)
1270 static const AVOption cbs_av1_options[] = {
1271     { "operating_point",  "Set operating point to select layers to parse from a scalable bitstream",
1272                           OFFSET(operating_point), AV_OPT_TYPE_INT, { .i64 = -1 }, -1, AV1_MAX_OPERATING_POINTS - 1, 0 },
1273     { "fixed_obu_size_length", "Set fixed length of the obu_size field",
1274       OFFSET(fixed_obu_size_length), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 8, 0 },
1275     { NULL }
1276 };
1277
1278 static const AVClass cbs_av1_class = {
1279     .class_name = "cbs_av1",
1280     .item_name  = av_default_item_name,
1281     .option     = cbs_av1_options,
1282     .version    = LIBAVUTIL_VERSION_INT,
1283 };
1284
1285 const CodedBitstreamType ff_cbs_type_av1 = {
1286     .codec_id          = AV_CODEC_ID_AV1,
1287
1288     .priv_class        = &cbs_av1_class,
1289     .priv_data_size    = sizeof(CodedBitstreamAV1Context),
1290
1291     .unit_types        = cbs_av1_unit_types,
1292
1293     .split_fragment    = &cbs_av1_split_fragment,
1294     .read_unit         = &cbs_av1_read_unit,
1295     .write_unit        = &cbs_av1_write_obu,
1296     .assemble_fragment = &cbs_av1_assemble_fragment,
1297
1298     .flush             = &cbs_av1_flush,
1299     .close             = &cbs_av1_close,
1300 };