Imported Upstream version 6.1
[platform/upstream/ffmpeg.git] / libavcodec / cbs_av1.c
index 1229480..1d9ac5a 100644 (file)
 #include "libavutil/opt.h"
 #include "libavutil/pixfmt.h"
 
-#include "avcodec.h"
 #include "cbs.h"
 #include "cbs_internal.h"
 #include "cbs_av1.h"
+#include "defs.h"
+#include "refstruct.h"
 
 
 static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
@@ -31,10 +32,8 @@ static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
                              uint32_t range_min, uint32_t range_max)
 {
     uint32_t zeroes, bits_value, value;
-    int position;
 
-    if (ctx->trace_enable)
-        position = get_bits_count(gbc);
+    CBS_TRACE_READ_START();
 
     zeroes = 0;
     while (1) {
@@ -50,6 +49,9 @@ static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
     }
 
     if (zeroes >= 32) {
+        // Note that the spec allows an arbitrarily large number of
+        // zero bits followed by a one bit in this case, but the
+        // libaom implementation does not support it.
         value = MAX_UINT_BITS(32);
     } else {
         if (get_bits_left(gbc) < zeroes) {
@@ -62,36 +64,7 @@ static int cbs_av1_read_uvlc(CodedBitstreamContext *ctx, GetBitContext *gbc,
         value = bits_value + (UINT32_C(1) << zeroes) - 1;
     }
 
-    if (ctx->trace_enable) {
-        char bits[65];
-        int i, j, k;
-
-        if (zeroes >= 32) {
-            while (zeroes > 32) {
-                k = FFMIN(zeroes - 32, 32);
-                for (i = 0; i < k; i++)
-                    bits[i] = '0';
-                bits[i] = 0;
-                ff_cbs_trace_syntax_element(ctx, position, name,
-                                            NULL, bits, 0);
-                zeroes -= k;
-                position += k;
-            }
-        }
-
-        for (i = 0; i < zeroes; i++)
-            bits[i] = '0';
-        bits[i++] = '1';
-
-        if (zeroes < 32) {
-            for (j = 0; j < zeroes; j++)
-                bits[i++] = (bits_value >> (zeroes - j - 1) & 1) ? '1' : '0';
-        }
-
-        bits[i] = 0;
-        ff_cbs_trace_syntax_element(ctx, position, name,
-                                    NULL, bits, value);
-    }
+    CBS_TRACE_READ_END_NO_SUBSCRIPTS();
 
     if (value < range_min || value > range_max) {
         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
@@ -109,7 +82,9 @@ static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
                               uint32_t range_min, uint32_t range_max)
 {
     uint32_t v;
-    int position, zeroes;
+    int zeroes;
+
+    CBS_TRACE_WRITE_START();
 
     if (value < range_min || value > range_max) {
         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
@@ -118,28 +93,17 @@ static int cbs_av1_write_uvlc(CodedBitstreamContext *ctx, PutBitContext *pbc,
         return AVERROR_INVALIDDATA;
     }
 
-    if (ctx->trace_enable)
-        position = put_bits_count(pbc);
-
     zeroes = av_log2(value + 1);
     v = value - (1U << zeroes) + 1;
+
+    if (put_bits_left(pbc) < 2 * zeroes + 1)
+        return AVERROR(ENOSPC);
+
     put_bits(pbc, zeroes, 0);
     put_bits(pbc, 1, 1);
     put_bits(pbc, zeroes, v);
 
-    if (ctx->trace_enable) {
-        char bits[65];
-        int i, j;
-        i = 0;
-        for (j = 0; j < zeroes; j++)
-            bits[i++] = '0';
-        bits[i++] = '1';
-        for (j = 0; j < zeroes; j++)
-            bits[i++] = (v >> (zeroes - j - 1) & 1) ? '1' : '0';
-        bits[i++] = 0;
-        ff_cbs_trace_syntax_element(ctx, position, name, NULL,
-                                    bits, value);
-    }
+    CBS_TRACE_WRITE_END_NO_SUBSCRIPTS();
 
     return 0;
 }
@@ -148,20 +112,19 @@ static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
                                const char *name, uint64_t *write_to)
 {
     uint64_t value;
-    int position, err, i;
+    uint32_t byte;
+    int i;
 
-    if (ctx->trace_enable)
-        position = get_bits_count(gbc);
+    CBS_TRACE_READ_START();
 
     value = 0;
     for (i = 0; i < 8; i++) {
-        int subscript[2] = { 1, i };
-        uint32_t byte;
-        err = ff_cbs_read_unsigned(ctx, gbc, 8, "leb128_byte[i]", subscript,
-                                   &byte, 0x00, 0xff);
-        if (err < 0)
-            return err;
-
+        if (get_bits_left(gbc) < 8) {
+            av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid leb128 at "
+                   "%s: bitstream ended.\n", name);
+            return AVERROR_INVALIDDATA;
+        }
+        byte = get_bits(gbc, 8);
         value |= (uint64_t)(byte & 0x7f) << (i * 7);
         if (!(byte & 0x80))
             break;
@@ -170,39 +133,44 @@ static int cbs_av1_read_leb128(CodedBitstreamContext *ctx, GetBitContext *gbc,
     if (value > UINT32_MAX)
         return AVERROR_INVALIDDATA;
 
-    if (ctx->trace_enable)
-        ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
+    CBS_TRACE_READ_END_NO_SUBSCRIPTS();
 
     *write_to = value;
     return 0;
 }
 
 static int cbs_av1_write_leb128(CodedBitstreamContext *ctx, PutBitContext *pbc,
-                                const char *name, uint64_t value)
+                                const char *name, uint64_t value, int fixed_length)
 {
-    int position, err, len, i;
+    int len, i;
     uint8_t byte;
 
+    CBS_TRACE_WRITE_START();
+
     len = (av_log2(value) + 7) / 7;
 
-    if (ctx->trace_enable)
-        position = put_bits_count(pbc);
+    if (fixed_length) {
+        if (fixed_length < len) {
+            av_log(ctx->log_ctx, AV_LOG_ERROR, "OBU is too large for "
+                   "fixed length size field (%d > %d).\n",
+                   len, fixed_length);
+            return AVERROR(EINVAL);
+        }
+        len = fixed_length;
+    }
 
     for (i = 0; i < len; i++) {
-        int subscript[2] = { 1, i };
+        if (put_bits_left(pbc) < 8)
+            return AVERROR(ENOSPC);
 
         byte = value >> (7 * i) & 0x7f;
         if (i < len - 1)
             byte |= 0x80;
 
-        err = ff_cbs_write_unsigned(ctx, pbc, 8, "leb128_byte[i]", subscript,
-                                    byte, 0x00, 0xff);
-        if (err < 0)
-            return err;
+        put_bits(pbc, 8, byte);
     }
 
-    if (ctx->trace_enable)
-        ff_cbs_trace_syntax_element(ctx, position, name, NULL, "", value);
+    CBS_TRACE_WRITE_END_NO_SUBSCRIPTS();
 
     return 0;
 }
@@ -212,12 +180,11 @@ static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
                            const int *subscripts, uint32_t *write_to)
 {
     uint32_t m, v, extra_bit, value;
-    int position, w;
+    int w;
 
-    av_assert0(n > 0);
+    CBS_TRACE_READ_START();
 
-    if (ctx->trace_enable)
-        position = get_bits_count(gbc);
+    av_assert0(n > 0);
 
     w = av_log2(n) + 1;
     m = (1 << w) - n;
@@ -240,18 +207,7 @@ static int cbs_av1_read_ns(CodedBitstreamContext *ctx, GetBitContext *gbc,
         value = (v << 1) - m + extra_bit;
     }
 
-    if (ctx->trace_enable) {
-        char bits[33];
-        int i;
-        for (i = 0; i < w - 1; i++)
-            bits[i] = (v >> i & 1) ? '1' : '0';
-        if (v >= m)
-            bits[i++] = extra_bit ? '1' : '0';
-        bits[i] = 0;
-
-        ff_cbs_trace_syntax_element(ctx, position,
-                                    name, subscripts, bits, value);
-    }
+    CBS_TRACE_READ_END();
 
     *write_to = value;
     return 0;
@@ -262,7 +218,8 @@ static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
                             const int *subscripts, uint32_t value)
 {
     uint32_t w, m, v, extra_bit;
-    int position;
+
+    CBS_TRACE_WRITE_START();
 
     if (value > n) {
         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
@@ -271,9 +228,6 @@ static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
         return AVERROR_INVALIDDATA;
     }
 
-    if (ctx->trace_enable)
-        position = put_bits_count(pbc);
-
     w = av_log2(n) + 1;
     m = (1 << w) - n;
 
@@ -290,18 +244,7 @@ static int cbs_av1_write_ns(CodedBitstreamContext *ctx, PutBitContext *pbc,
         put_bits(pbc, 1, extra_bit);
     }
 
-    if (ctx->trace_enable) {
-        char bits[33];
-        int i;
-        for (i = 0; i < w - 1; i++)
-            bits[i] = (v >> i & 1) ? '1' : '0';
-        if (value >= m)
-            bits[i++] = extra_bit ? '1' : '0';
-        bits[i] = 0;
-
-        ff_cbs_trace_syntax_element(ctx, position,
-                                    name, subscripts, bits, value);
-    }
+    CBS_TRACE_WRITE_END();
 
     return 0;
 }
@@ -311,33 +254,24 @@ static int cbs_av1_read_increment(CodedBitstreamContext *ctx, GetBitContext *gbc
                                   const char *name, uint32_t *write_to)
 {
     uint32_t value;
-    int position, i;
-    char bits[33];
 
-    av_assert0(range_min <= range_max && range_max - range_min < sizeof(bits) - 1);
-    if (ctx->trace_enable)
-        position = get_bits_count(gbc);
+    CBS_TRACE_READ_START();
 
-    for (i = 0, value = range_min; value < range_max;) {
+    av_assert0(range_min <= range_max && range_max - range_min < 32);
+
+    for (value = range_min; value < range_max;) {
         if (get_bits_left(gbc) < 1) {
             av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid increment value at "
                    "%s: bitstream ended.\n", name);
             return AVERROR_INVALIDDATA;
         }
-        if (get_bits1(gbc)) {
-            bits[i++] = '1';
+        if (get_bits1(gbc))
             ++value;
-        } else {
-            bits[i++] = '0';
+        else
             break;
-        }
     }
 
-    if (ctx->trace_enable) {
-        bits[i] = 0;
-        ff_cbs_trace_syntax_element(ctx, position,
-                                    name, NULL, bits, value);
-    }
+    CBS_TRACE_READ_END_NO_SUBSCRIPTS();
 
     *write_to = value;
     return 0;
@@ -349,6 +283,8 @@ static int cbs_av1_write_increment(CodedBitstreamContext *ctx, PutBitContext *pb
 {
     int len;
 
+    CBS_TRACE_WRITE_START();
+
     av_assert0(range_min <= range_max && range_max - range_min < 32);
     if (value < range_min || value > range_max) {
         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
@@ -364,23 +300,11 @@ static int cbs_av1_write_increment(CodedBitstreamContext *ctx, PutBitContext *pb
     if (put_bits_left(pbc) < len)
         return AVERROR(ENOSPC);
 
-    if (ctx->trace_enable) {
-        char bits[33];
-        int i;
-        for (i = 0; i < len; i++) {
-            if (range_min + i == value)
-                bits[i] = '0';
-            else
-                bits[i] = '1';
-        }
-        bits[i] = 0;
-        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
-                                    name, NULL, bits, value);
-    }
-
     if (len > 0)
         put_bits(pbc, len, (1 << len) - 1 - (value != range_max));
 
+    CBS_TRACE_WRITE_END_NO_SUBSCRIPTS();
+
     return 0;
 }
 
@@ -388,12 +312,10 @@ static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
                                uint32_t range_max, const char *name,
                                const int *subscripts, uint32_t *write_to)
 {
-    uint32_t value;
-    int position, err;
-    uint32_t max_len, len, range_offset, range_bits;
+    uint32_t value, max_len, len, range_offset, range_bits;
+    int err;
 
-    if (ctx->trace_enable)
-        position = get_bits_count(gbc);
+    CBS_TRACE_READ_START();
 
     av_assert0(range_max > 0);
     max_len = av_log2(range_max - 1) - 3;
@@ -412,9 +334,8 @@ static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
     }
 
     if (len < max_len) {
-        err = ff_cbs_read_unsigned(ctx, gbc, range_bits,
-                                   "subexp_bits", NULL, &value,
-                                   0, MAX_UINT_BITS(range_bits));
+        err = ff_cbs_read_simple_unsigned(ctx, gbc, range_bits,
+                                          "subexp_bits", &value);
         if (err < 0)
             return err;
 
@@ -426,9 +347,7 @@ static int cbs_av1_read_subexp(CodedBitstreamContext *ctx, GetBitContext *gbc,
     }
     value += range_offset;
 
-    if (ctx->trace_enable)
-        ff_cbs_trace_syntax_element(ctx, position,
-                                    name, subscripts, "", value);
+    CBS_TRACE_READ_END_VALUE_ONLY();
 
     *write_to = value;
     return err;
@@ -438,9 +357,11 @@ static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
                                 uint32_t range_max, const char *name,
                                 const int *subscripts, uint32_t value)
 {
-    int position, err;
+    int err;
     uint32_t max_len, len, range_offset, range_bits;
 
+    CBS_TRACE_WRITE_START();
+
     if (value > range_max) {
         av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
                "%"PRIu32", but must be in [0,%"PRIu32"].\n",
@@ -448,9 +369,6 @@ static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
         return AVERROR_INVALIDDATA;
     }
 
-    if (ctx->trace_enable)
-        position = put_bits_count(pbc);
-
     av_assert0(range_max > 0);
     max_len = av_log2(range_max - 1) - 3;
 
@@ -476,10 +394,9 @@ static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
         return err;
 
     if (len < max_len) {
-        err = ff_cbs_write_unsigned(ctx, pbc, range_bits,
-                                    "subexp_bits", NULL,
-                                    value - range_offset,
-                                    0, MAX_UINT_BITS(range_bits));
+        err = ff_cbs_write_simple_unsigned(ctx, pbc, range_bits,
+                                           "subexp_bits",
+                                           value - range_offset);
         if (err < 0)
             return err;
 
@@ -491,9 +408,7 @@ static int cbs_av1_write_subexp(CodedBitstreamContext *ctx, PutBitContext *pbc,
             return err;
     }
 
-    if (ctx->trace_enable)
-        ff_cbs_trace_syntax_element(ctx, position,
-                                    name, subscripts, "", value);
+    CBS_TRACE_WRITE_END_VALUE_ONLY();
 
     return err;
 }
@@ -546,8 +461,6 @@ static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
 
 #define SUBSCRIPTS(subs, ...) (subs > 0 ? ((int[subs + 1]){ subs, __VA_ARGS__ }) : NULL)
 
-#define fb(width, name) \
-        xf(width, name, current->name, 0, MAX_UINT_BITS(width), 0, )
 #define fc(width, name, range_min, range_max) \
         xf(width, name, current->name, range_min, range_max, 0, )
 #define flag(name) fb(1, name)
@@ -573,6 +486,13 @@ static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
 #define READWRITE read
 #define RWContext GetBitContext
 
+#define fb(width, name) do { \
+        uint32_t value; \
+        CHECK(ff_cbs_read_simple_unsigned(ctx, rw, width, \
+                                          #name, &value)); \
+        current->name = value; \
+    } while (0)
+
 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
         uint32_t value; \
         CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
@@ -645,6 +565,7 @@ static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
 #undef READ
 #undef READWRITE
 #undef RWContext
+#undef fb
 #undef xf
 #undef xsu
 #undef uvlc
@@ -661,6 +582,11 @@ static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
 #define READWRITE write
 #define RWContext PutBitContext
 
+#define fb(width, name) do { \
+        CHECK(ff_cbs_write_simple_unsigned(ctx, rw, width, #name, \
+                                           current->name)); \
+    } while (0)
+
 #define xf(width, name, var, range_min, range_max, subs, ...) do { \
         CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
                                     SUBSCRIPTS(subs, __VA_ARGS__), \
@@ -703,7 +629,7 @@ static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
     } while (0)
 
 #define leb128(name) do { \
-        CHECK(cbs_av1_write_leb128(ctx, rw, #name, current->name)); \
+        CHECK(cbs_av1_write_leb128(ctx, rw, #name, current->name, 0)); \
     } while (0)
 
 #define infer(name, value) do { \
@@ -723,6 +649,7 @@ static size_t cbs_av1_get_payload_bytes_left(GetBitContext *gbc)
 #undef WRITE
 #undef READWRITE
 #undef RWContext
+#undef fb
 #undef xf
 #undef xsu
 #undef uvlc
@@ -878,7 +805,7 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
     GetBitContext gbc;
     int err, start_pos, end_pos;
 
-    err = ff_cbs_alloc_unit_content2(ctx, unit);
+    err = ff_cbs_alloc_unit_content(ctx, unit);
     if (err < 0)
         return err;
     obu = unit->content;
@@ -943,12 +870,7 @@ static int cbs_av1_read_unit(CodedBitstreamContext *ctx,
                 priv->operating_point_idc = sequence_header->operating_point_idc[priv->operating_point];
             }
 
-            av_buffer_unref(&priv->sequence_header_ref);
-            priv->sequence_header = NULL;
-
-            priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
-            if (!priv->sequence_header_ref)
-                return AVERROR(ENOMEM);
+            ff_refstruct_replace(&priv->sequence_header_ref, unit->content_ref);
             priv->sequence_header = &obu->obu.sequence_header;
         }
         break;
@@ -1058,21 +980,40 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
     AV1RawTileData *td;
     size_t header_size;
     int err, start_pos, end_pos, data_pos;
+    CodedBitstreamAV1Context av1ctx;
 
     // OBUs in the normal bitstream format must contain a size field
     // in every OBU (in annex B it is optional, but we don't support
     // writing that).
     obu->header.obu_has_size_field = 1;
+    av1ctx = *priv;
+
+    if (priv->sequence_header_ref) {
+        av1ctx.sequence_header_ref = ff_refstruct_ref(priv->sequence_header_ref);
+    }
+
+    if (priv->frame_header_ref) {
+        av1ctx.frame_header_ref = av_buffer_ref(priv->frame_header_ref);
+        if (!av1ctx.frame_header_ref) {
+            err = AVERROR(ENOMEM);
+            goto error;
+        }
+    }
 
     err = cbs_av1_write_obu_header(ctx, pbc, &obu->header);
     if (err < 0)
-        return err;
+        goto error;
 
     if (obu->header.obu_has_size_field) {
         pbc_tmp = *pbc;
-        // Add space for the size field to fill later.
-        put_bits32(pbc, 0);
-        put_bits32(pbc, 0);
+        if (priv->fixed_obu_size_length) {
+            for (int i = 0; i < priv->fixed_obu_size_length; i++)
+                put_bits(pbc, 8, 0);
+        } else {
+            // Add space for the size field to fill later.
+            put_bits32(pbc, 0);
+            put_bits32(pbc, 0);
+        }
     }
 
     td = NULL;
@@ -1084,18 +1025,16 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
             err = cbs_av1_write_sequence_header_obu(ctx, pbc,
                                                     &obu->obu.sequence_header);
             if (err < 0)
-                return err;
+                goto error;
 
-            av_buffer_unref(&priv->sequence_header_ref);
+            ff_refstruct_unref(&priv->sequence_header_ref);
             priv->sequence_header = NULL;
 
             err = ff_cbs_make_unit_refcounted(ctx, unit);
             if (err < 0)
-                return err;
+                goto error;
 
-            priv->sequence_header_ref = av_buffer_ref(unit->content_ref);
-            if (!priv->sequence_header_ref)
-                return AVERROR(ENOMEM);
+            priv->sequence_header_ref = ff_refstruct_ref(unit->content_ref);
             priv->sequence_header = &obu->obu.sequence_header;
         }
         break;
@@ -1103,7 +1042,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         {
             err = cbs_av1_write_temporal_delimiter_obu(ctx, pbc);
             if (err < 0)
-                return err;
+                goto error;
         }
         break;
     case AV1_OBU_FRAME_HEADER:
@@ -1115,7 +1054,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
                                                  AV1_OBU_REDUNDANT_FRAME_HEADER,
                                                  NULL);
             if (err < 0)
-                return err;
+                goto error;
         }
         break;
     case AV1_OBU_TILE_GROUP:
@@ -1123,7 +1062,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
             err = cbs_av1_write_tile_group_obu(ctx, pbc,
                                                &obu->obu.tile_group);
             if (err < 0)
-                return err;
+                goto error;
 
             td = &obu->obu.tile_group.tile_data;
         }
@@ -1132,7 +1071,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         {
             err = cbs_av1_write_frame_obu(ctx, pbc, &obu->obu.frame, NULL);
             if (err < 0)
-                return err;
+                goto error;
 
             td = &obu->obu.frame.tile_group.tile_data;
         }
@@ -1141,7 +1080,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         {
             err = cbs_av1_write_tile_list_obu(ctx, pbc, &obu->obu.tile_list);
             if (err < 0)
-                return err;
+                goto error;
 
             td = &obu->obu.tile_list.tile_data;
         }
@@ -1150,18 +1089,19 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         {
             err = cbs_av1_write_metadata_obu(ctx, pbc, &obu->obu.metadata);
             if (err < 0)
-                return err;
+                goto error;
         }
         break;
     case AV1_OBU_PADDING:
         {
             err = cbs_av1_write_padding_obu(ctx, pbc, &obu->obu.padding);
             if (err < 0)
-                return err;
+                goto error;
         }
         break;
     default:
-        return AVERROR(ENOSYS);
+        err = AVERROR(ENOSYS);
+        goto error;
     }
 
     end_pos = put_bits_count(pbc);
@@ -1172,7 +1112,7 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
         // Add trailing bits and recalculate.
         err = cbs_av1_write_trailing_bits(ctx, pbc, 8 - end_pos % 8);
         if (err < 0)
-            return err;
+            goto error;
         end_pos = put_bits_count(pbc);
         obu->obu_size = header_size = (end_pos - start_pos + 7) / 8;
     } else {
@@ -1188,20 +1128,31 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
     end_pos   /= 8;
 
     *pbc = pbc_tmp;
-    err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size);
+    err = cbs_av1_write_leb128(ctx, pbc, "obu_size", obu->obu_size,
+                               priv->fixed_obu_size_length);
     if (err < 0)
-        return err;
+        goto error;
 
     data_pos = put_bits_count(pbc) / 8;
     flush_put_bits(pbc);
     av_assert0(data_pos <= start_pos);
 
-    if (8 * obu->obu_size > put_bits_left(pbc))
+    if (8 * obu->obu_size > put_bits_left(pbc)) {
+        ff_refstruct_unref(&priv->sequence_header_ref);
+        av_buffer_unref(&priv->frame_header_ref);
+        *priv = av1ctx;
+
         return AVERROR(ENOSPC);
+    }
 
     if (obu->obu_size > 0) {
-        memmove(pbc->buf + data_pos,
-                pbc->buf + start_pos, header_size);
+        if (!priv->fixed_obu_size_length) {
+            memmove(pbc->buf + data_pos,
+                    pbc->buf + start_pos, header_size);
+        } else {
+            // The size was fixed so the following data was
+            // already written in the correct place.
+        }
         skip_put_bytes(pbc, header_size);
 
         if (td) {
@@ -1213,8 +1164,13 @@ static int cbs_av1_write_obu(CodedBitstreamContext *ctx,
 
     // OBU data must be byte-aligned.
     av_assert0(put_bits_count(pbc) % 8 == 0);
+    err = 0;
 
-    return 0;
+error:
+    ff_refstruct_unref(&av1ctx.sequence_header_ref);
+    av_buffer_unref(&av1ctx.frame_header_ref);
+
+    return err;
 }
 
 static int cbs_av1_assemble_fragment(CodedBitstreamContext *ctx,
@@ -1263,24 +1219,30 @@ static void cbs_av1_close(CodedBitstreamContext *ctx)
 {
     CodedBitstreamAV1Context *priv = ctx->priv_data;
 
-    av_buffer_unref(&priv->sequence_header_ref);
+    ff_refstruct_unref(&priv->sequence_header_ref);
     av_buffer_unref(&priv->frame_header_ref);
 }
 
-static void cbs_av1_free_metadata(void *unit, uint8_t *content)
+static void cbs_av1_free_metadata(FFRefStructOpaque unused, void *content)
 {
-    AV1RawOBU *obu = (AV1RawOBU*)content;
+    AV1RawOBU *obu = content;
     AV1RawMetadata *md;
 
     av_assert0(obu->header.obu_type == AV1_OBU_METADATA);
     md = &obu->obu.metadata;
 
     switch (md->metadata_type) {
+    case AV1_METADATA_TYPE_HDR_CLL:
+    case AV1_METADATA_TYPE_HDR_MDCV:
+    case AV1_METADATA_TYPE_SCALABILITY:
+    case AV1_METADATA_TYPE_TIMECODE:
+        break;
     case AV1_METADATA_TYPE_ITUT_T35:
         av_buffer_unref(&md->metadata.itut_t35.payload_ref);
         break;
+    default:
+        av_buffer_unref(&md->metadata.unknown.payload_ref);
     }
-    av_free(content);
 }
 
 static const CodedBitstreamUnitTypeDescriptor cbs_av1_unit_types[] = {
@@ -1308,6 +1270,8 @@ static const CodedBitstreamUnitTypeDescriptor cbs_av1_unit_types[] = {
 static const AVOption cbs_av1_options[] = {
     { "operating_point",  "Set operating point to select layers to parse from a scalable bitstream",
                           OFFSET(operating_point), AV_OPT_TYPE_INT, { .i64 = -1 }, -1, AV1_MAX_OPERATING_POINTS - 1, 0 },
+    { "fixed_obu_size_length", "Set fixed length of the obu_size field",
+      OFFSET(fixed_obu_size_length), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 8, 0 },
     { NULL }
 };