Fix ROS detection in combination with incomplete RSI
[platform/upstream/libaec.git] / src / encode.c
1 /**
2  * @file encode.c
3  *
4  * @section LICENSE
5  * Copyright 2012 - 2016
6  *
7  * Mathis Rosenhauer, Moritz Hanke, Joerg Behrens
8  * Deutsches Klimarechenzentrum GmbH
9  * Bundesstr. 45a
10  * 20146 Hamburg Germany
11  *
12  * Luis Kornblueh
13  * Max-Planck-Institut fuer Meteorologie
14  * Bundesstr. 53
15  * 20146 Hamburg
16  * Germany
17  *
18  * All rights reserved.
19  *
20  * Redistribution and use in source and binary forms, with or without
21  * modification, are permitted provided that the following conditions
22  * are met:
23  *
24  * 1. Redistributions of source code must retain the above copyright
25  *    notice, this list of conditions and the following disclaimer.
26  * 2. Redistributions in binary form must reproduce the above
27  *    copyright notice, this list of conditions and the following
28  *    disclaimer in the documentation and/or other materials provided
29  *    with the distribution.
30  *
31  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
32  * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
33  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
34  * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
35  * COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
36  * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
37  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
38  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
39  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
40  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
41  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
42  * OF THE POSSIBILITY OF SUCH DAMAGE.
43  *
44  * @section DESCRIPTION
45  *
46  * Adaptive Entropy Encoder
47  * Based on CCSDS documents 121.0-B-2 and 120.0-G-3
48  *
49  */
50
51 #include <stdio.h>
52 #include <stdlib.h>
53 #include <string.h>
54
55 #include "libaec.h"
56 #include "encode.h"
57 #include "encode_accessors.h"
58
59 static int m_get_block(struct aec_stream *strm);
60
61 static inline void emit(struct internal_state *state,
62                         uint32_t data, int bits)
63 {
64     /**
65        Emit sequence of bits.
66      */
67
68     if (bits <= state->bits) {
69         state->bits -= bits;
70         *state->cds += (uint8_t)(data << state->bits);
71     } else {
72         bits -= state->bits;
73         *state->cds++ += (uint8_t)((uint64_t)data >> bits);
74
75         while (bits > 8) {
76             bits -= 8;
77             *state->cds++ = (uint8_t)(data >> bits);
78         }
79
80         state->bits = 8 - bits;
81         *state->cds = (uint8_t)(data << state->bits);
82     }
83 }
84
85 static inline void emitfs(struct internal_state *state, int fs)
86 {
87     /**
88        Emits a fundamental sequence.
89
90        fs zero bits followed by one 1 bit.
91      */
92
93     for(;;) {
94         if (fs < state->bits) {
95             state->bits -= fs + 1;
96             *state->cds += 1U << state->bits;
97             break;
98         } else {
99             fs -= state->bits;
100             *++state->cds = 0;
101             state->bits = 8;
102         }
103     }
104 }
105
106 static inline void copy64(uint8_t *dst, uint64_t src)
107 {
108     dst[0] = (uint8_t)(src >> 56);
109     dst[1] = (uint8_t)(src >> 48);
110     dst[2] = (uint8_t)(src >> 40);
111     dst[3] = (uint8_t)(src >> 32);
112     dst[4] = (uint8_t)(src >> 24);
113     dst[5] = (uint8_t)(src >> 16);
114     dst[6] = (uint8_t)(src >> 8);
115     dst[7] = (uint8_t)src;
116 }
117
118 static inline void emitblock_fs(struct aec_stream *strm, int k, int ref)
119 {
120     size_t i;
121     uint32_t used; /* used bits in 64 bit accumulator */
122     uint64_t acc; /* accumulator */
123     struct internal_state *state = strm->state;
124
125     acc = (uint64_t)*state->cds << 56;
126     used = 7 - state->bits;
127
128     for (i = ref; i < strm->block_size; i++) {
129         used += (state->block[i] >> k) + 1;
130         while (used > 63) {
131             copy64(state->cds, acc);
132             state->cds += 8;
133             acc = 0;
134             used -= 64;
135         }
136         acc |= UINT64_C(1) << (63 - used);
137     }
138
139     copy64(state->cds, acc);
140     state->cds += used >> 3;
141     state->bits = 7 - (used & 7);
142 }
143
144 static inline void emitblock(struct aec_stream *strm, int k, int ref)
145 {
146     /**
147        Emit the k LSB of a whole block of input data.
148     */
149
150     uint64_t a;
151     struct internal_state *state = strm->state;
152     uint32_t *in = state->block + ref;
153     uint32_t *in_end = state->block + strm->block_size;
154     uint64_t mask = (UINT64_C(1) << k) - 1;
155     uint8_t *o = state->cds;
156     int p = state->bits;
157
158     a = *o;
159
160     while(in < in_end) {
161         a <<= 56;
162         p = (p % 8) + 56;
163
164         while (p > k && in < in_end) {
165             p -= k;
166             a += ((uint64_t)(*in++) & mask) << p;
167         }
168
169         switch (p & ~7) {
170         case 0:
171             o[0] = (uint8_t)(a >> 56);
172             o[1] = (uint8_t)(a >> 48);
173             o[2] = (uint8_t)(a >> 40);
174             o[3] = (uint8_t)(a >> 32);
175             o[4] = (uint8_t)(a >> 24);
176             o[5] = (uint8_t)(a >> 16);
177             o[6] = (uint8_t)(a >> 8);
178             o += 7;
179             break;
180         case 8:
181             o[0] = (uint8_t)(a >> 56);
182             o[1] = (uint8_t)(a >> 48);
183             o[2] = (uint8_t)(a >> 40);
184             o[3] = (uint8_t)(a >> 32);
185             o[4] = (uint8_t)(a >> 24);
186             o[5] = (uint8_t)(a >> 16);
187             a >>= 8;
188             o += 6;
189             break;
190         case 16:
191             o[0] = (uint8_t)(a >> 56);
192             o[1] = (uint8_t)(a >> 48);
193             o[2] = (uint8_t)(a >> 40);
194             o[3] = (uint8_t)(a >> 32);
195             o[4] = (uint8_t)(a >> 24);
196             a >>= 16;
197             o += 5;
198             break;
199         case 24:
200             o[0] = (uint8_t)(a >> 56);
201             o[1] = (uint8_t)(a >> 48);
202             o[2] = (uint8_t)(a >> 40);
203             o[3] = (uint8_t)(a >> 32);
204             a >>= 24;
205             o += 4;
206             break;
207         case 32:
208             o[0] = (uint8_t)(a >> 56);
209             o[1] = (uint8_t)(a >> 48);
210             o[2] = (uint8_t)(a >> 40);
211             a >>= 32;
212             o += 3;
213             break;
214         case 40:
215             o[0] = (uint8_t)(a >> 56);
216             o[1] = (uint8_t)(a >> 48);
217             a >>= 40;
218             o += 2;
219             break;
220         case 48:
221             *o++ = (uint8_t)(a >> 56);
222             a >>= 48;
223             break;
224         default:
225             a >>= 56;
226             break;
227         }
228     }
229
230     *o = (uint8_t)a;
231     state->cds = o;
232     state->bits = p % 8;
233 }
234
235 static void preprocess_unsigned(struct aec_stream *strm)
236 {
237     /**
238        Preprocess RSI of unsigned samples.
239
240        Combining preprocessing and converting to uint32_t in one loop
241        is slower due to the data dependance on x_i-1.
242     */
243
244     uint32_t D;
245     struct internal_state *state = strm->state;
246     const uint32_t *restrict x = state->data_raw;
247     uint32_t *restrict d = state->data_pp;
248     uint32_t xmax = state->xmax;
249     uint32_t rsi = strm->rsi * strm->block_size - 1;
250     size_t i;
251
252     state->ref = 1;
253     state->ref_sample = x[0];
254     d[0] = 0;
255     for (i = 0; i < rsi; i++) {
256         if (x[i + 1] >= x[i]) {
257             D = x[i + 1] - x[i];
258             if (D <= x[i])
259                 d[i + 1] = 2 * D;
260             else
261                 d[i + 1] = x[i + 1];
262         } else {
263             D = x[i] - x[i + 1];
264             if (D <= xmax - x[i])
265                 d[i + 1] = 2 * D - 1;
266             else
267                 d[i + 1] = xmax - x[i + 1];
268         }
269     }
270     state->uncomp_len = (strm->block_size - 1) * strm->bits_per_sample;
271 }
272
273 static void preprocess_signed(struct aec_stream *strm)
274 {
275     /**
276        Preprocess RSI of signed samples.
277     */
278
279     uint32_t D;
280     struct internal_state *state = strm->state;
281     int32_t *restrict x = (int32_t *)state->data_raw;
282     uint32_t *restrict d = state->data_pp;
283     int32_t xmax = (int32_t)state->xmax;
284     int32_t xmin = (int32_t)state->xmin;
285     uint32_t rsi = strm->rsi * strm->block_size - 1;
286     uint32_t m = UINT64_C(1) << (strm->bits_per_sample - 1);
287     size_t i;
288
289     state->ref = 1;
290     state->ref_sample = x[0];
291     d[0] = 0;
292     x[0] = (x[0] ^ m) - m;
293
294     for (i = 0; i < rsi; i++) {
295         x[i + 1] = (x[i + 1] ^ m) - m;
296         if (x[i + 1] < x[i]) {
297             D = (uint32_t)(x[i] - x[i + 1]);
298             if (D <= (uint32_t)(xmax - x[i]))
299                 d[i + 1] = 2 * D - 1;
300             else
301                 d[i + 1] = xmax - x[i + 1];
302         } else {
303             D = (uint32_t)(x[i + 1] - x[i]);
304             if (D <= (uint32_t)(x[i] - xmin))
305                 d[i + 1] = 2 * D;
306             else
307                 d[i + 1] = x[i + 1] - xmin;
308         }
309     }
310     state->uncomp_len = (strm->block_size - 1) * strm->bits_per_sample;
311 }
312
313 static inline uint64_t block_fs(struct aec_stream *strm, int k)
314 {
315     /**
316        Sum FS of all samples in block for given splitting position.
317     */
318
319     size_t i;
320     uint64_t fs = 0;
321     struct internal_state *state = strm->state;
322
323     for (i = 0; i < strm->block_size; i++)
324         fs += (uint64_t)(state->block[i] >> k);
325
326     return fs;
327 }
328
329 static uint32_t assess_splitting_option(struct aec_stream *strm)
330 {
331     /**
332        Length of CDS encoded with splitting option and optimal k.
333
334        In Rice coding each sample in a block of samples is split at
335        the same position into k LSB and bits_per_sample - k MSB. The
336        LSB part is left binary and the MSB part is coded as a
337        fundamental sequence a.k.a. unary (see CCSDS 121.0-B-2). The
338        function of the length of the Coded Data Set (CDS) depending on
339        k has exactly one minimum (see A. Kiely, IPN Progress Report
340        42-159).
341
342        To find that minimum with only a few costly evaluations of the
343        CDS length, we start with the k of the previous CDS. K is
344        increased and the CDS length evaluated. If the CDS length gets
345        smaller, then we are moving towards the minimum. If the length
346        increases, then the minimum will be found with smaller k.
347
348        For increasing k we know that we will gain block_size bits in
349        length through the larger binary part. If the FS lenth is less
350        than the block size then a reduced FS part can't compensate the
351        larger binary part. So we know that the CDS for k+1 will be
352        larger than for k without actually computing the length. An
353        analogue check can be done for decreasing k.
354      */
355
356     int k;
357     int k_min;
358     int this_bs; /* Block size of current block */
359     int no_turn; /* 1 if we shouldn't reverse */
360     int dir; /* Direction, 1 means increasing k, 0 decreasing k */
361     uint64_t len; /* CDS length for current k */
362     uint64_t len_min; /* CDS length minimum so far */
363     uint64_t fs_len; /* Length of FS part (not including 1s) */
364
365     struct internal_state *state = strm->state;
366
367     this_bs = strm->block_size - state->ref;
368     len_min = UINT64_MAX;
369     k = k_min = state->k;
370     no_turn = k == 0;
371     dir = 1;
372
373     for (;;) {
374         fs_len = block_fs(strm, k);
375         len = fs_len + this_bs * (k + 1);
376
377         if (len < len_min) {
378             if (len_min < UINT64_MAX)
379                 no_turn = 1;
380
381             len_min = len;
382             k_min = k;
383
384             if (dir) {
385                 if (fs_len < this_bs || k >= state->kmax) {
386                     if (no_turn)
387                         break;
388                     k = state->k - 1;
389                     dir = 0;
390                     no_turn = 1;
391                 } else {
392                     k++;
393                 }
394             } else {
395                 if (fs_len >= this_bs || k == 0)
396                     break;
397                 k--;
398             }
399         } else {
400             if (no_turn)
401                 break;
402             k = state->k - 1;
403             dir = 0;
404             no_turn = 1;
405         }
406     }
407     state->k = k_min;
408
409     return (uint32_t)len_min;
410 }
411
412 static uint32_t assess_se_option(struct aec_stream *strm)
413 {
414     /**
415        Length of CDS encoded with Second Extension option.
416
417        If length is above limit just return UINT32_MAX.
418     */
419
420     size_t i;
421     uint64_t len, d;
422     struct internal_state *state = strm->state;
423     uint32_t *block = state->block;
424
425     len = 1;
426
427     for (i = 0; i < strm->block_size; i += 2) {
428         d = (uint64_t)block[i] + (uint64_t)block[i + 1];
429         len += d * (d + 1) / 2 + block[i + 1] + 1;
430         if (len > state->uncomp_len)
431             return UINT32_MAX;
432     }
433     return (uint32_t)len;
434 }
435
436 static void init_output(struct aec_stream *strm)
437 {
438     /**
439        Direct output to next_out if next_out can hold a Coded Data
440        Set, use internal buffer otherwise.
441     */
442
443     struct internal_state *state = strm->state;
444
445     if (strm->avail_out > CDSLEN) {
446         if (!state->direct_out) {
447             state->direct_out = 1;
448             *strm->next_out = *state->cds;
449             state->cds = strm->next_out;
450         }
451     } else {
452         if (state->zero_blocks == 0 || state->direct_out) {
453             /* copy leftover from last block */
454             *state->cds_buf = *state->cds;
455             state->cds = state->cds_buf;
456         }
457         state->direct_out = 0;
458     }
459 }
460
461 /*
462  *
463  * FSM functions
464  *
465  */
466
467 static int m_flush_block_resumable(struct aec_stream *strm)
468 {
469     /**
470        Slow and restartable flushing
471     */
472     struct internal_state *state = strm->state;
473
474     int n = (int)MIN((size_t)(state->cds - state->cds_buf - state->i),
475                      strm->avail_out);
476     memcpy(strm->next_out, state->cds_buf + state->i, n);
477     strm->next_out += n;
478     strm->avail_out -= n;
479     state->i += n;
480
481     if (strm->avail_out == 0) {
482         return M_EXIT;
483     } else {
484         state->mode = m_get_block;
485         return M_CONTINUE;
486     }
487 }
488
489 static int m_flush_block(struct aec_stream *strm)
490 {
491     /**
492        Flush block in direct_out mode by updating counters.
493
494        Fall back to slow flushing if in buffered mode.
495     */
496     int n;
497     struct internal_state *state = strm->state;
498
499 #ifdef ENABLE_RSI_PADDING
500     if (state->blocks_avail == 0
501         && strm->flags & AEC_PAD_RSI
502         && state->block_nonzero == 0
503         )
504         emit(state, 0, state->bits % 8);
505 #endif
506
507     if (state->direct_out) {
508         n = (int)(state->cds - strm->next_out);
509         strm->next_out += n;
510         strm->avail_out -= n;
511         state->mode = m_get_block;
512         return M_CONTINUE;
513     }
514
515     state->i = 0;
516     state->mode = m_flush_block_resumable;
517     return M_CONTINUE;
518 }
519
520 static int m_encode_splitting(struct aec_stream *strm)
521 {
522     struct internal_state *state = strm->state;
523     int k = state->k;
524
525     emit(state, k + 1, state->id_len);
526     if (state->ref)
527         emit(state, state->ref_sample, strm->bits_per_sample);
528
529     emitblock_fs(strm, k, state->ref);
530     if (k)
531         emitblock(strm, k, state->ref);
532
533     return m_flush_block(strm);
534 }
535
536 static int m_encode_uncomp(struct aec_stream *strm)
537 {
538     struct internal_state *state = strm->state;
539
540     emit(state, (1U << state->id_len) - 1, state->id_len);
541     if (state->ref)
542         state->block[0] = state->ref_sample;
543     emitblock(strm, strm->bits_per_sample, 0);
544     return m_flush_block(strm);
545 }
546
547 static int m_encode_se(struct aec_stream *strm)
548 {
549     size_t i;
550     uint32_t d;
551     struct internal_state *state = strm->state;
552
553     emit(state, 1, state->id_len + 1);
554     if (state->ref)
555         emit(state, state->ref_sample, strm->bits_per_sample);
556
557     for (i = 0; i < strm->block_size; i+= 2) {
558         d = state->block[i] + state->block[i + 1];
559         emitfs(state, d * (d + 1) / 2 + state->block[i + 1]);
560     }
561
562     return m_flush_block(strm);
563 }
564
565 static int m_encode_zero(struct aec_stream *strm)
566 {
567     struct internal_state *state = strm->state;
568
569     emit(state, 0, state->id_len + 1);
570
571     if (state->zero_ref)
572         emit(state, state->zero_ref_sample, strm->bits_per_sample);
573
574     if (state->zero_blocks == ROS)
575         emitfs(state, 4);
576     else if (state->zero_blocks >= 5)
577         emitfs(state, state->zero_blocks);
578     else
579         emitfs(state, state->zero_blocks - 1);
580
581     state->zero_blocks = 0;
582     return m_flush_block(strm);
583 }
584
585 static int m_select_code_option(struct aec_stream *strm)
586 {
587     /**
588        Decide which code option to use.
589     */
590
591     uint32_t split_len;
592     uint32_t se_len;
593     struct internal_state *state = strm->state;
594
595     if (state->id_len > 1)
596         split_len = assess_splitting_option(strm);
597     else
598         split_len = UINT32_MAX;
599     se_len = assess_se_option(strm);
600
601     if (split_len < state->uncomp_len) {
602         if (split_len < se_len)
603             return m_encode_splitting(strm);
604         else
605             return m_encode_se(strm);
606     } else {
607         if (state->uncomp_len <= se_len)
608             return m_encode_uncomp(strm);
609         else
610             return m_encode_se(strm);
611     }
612 }
613
614 static int m_check_zero_block(struct aec_stream *strm)
615 {
616     /**
617        Check if input block is all zero.
618
619        Aggregate consecutive zero blocks until we find !0 or reach the
620        end of a segment or RSI.
621     */
622
623     size_t i;
624     struct internal_state *state = strm->state;
625     uint32_t *p = state->block;
626
627     for (i = 0; i < strm->block_size; i++)
628         if (p[i] != 0)
629             break;
630
631     if (i < strm->block_size) {
632         if (state->zero_blocks) {
633             /* The current block isn't zero but we have to emit a
634              * previous zero block first. The current block will be
635              * flagged and handled later.
636              */
637             state->block_nonzero = 1;
638             state->mode = m_encode_zero;
639             return M_CONTINUE;
640         }
641         state->mode = m_select_code_option;
642         return M_CONTINUE;
643     } else {
644         state->zero_blocks++;
645         if (state->zero_blocks == 1) {
646             state->zero_ref = state->ref;
647             state->zero_ref_sample = state->ref_sample;
648         }
649         if (state->blocks_avail == 0 || state->blocks_dispensed % 64 == 0) {
650             if (state->zero_blocks > 4)
651                 state->zero_blocks = ROS;
652
653             state->mode = m_encode_zero;
654             return M_CONTINUE;
655         }
656         state->mode = m_get_block;
657         return M_CONTINUE;
658     }
659 }
660
661 static int m_get_rsi_resumable(struct aec_stream *strm)
662 {
663     /**
664        Get RSI while input buffer is short.
665
666        Let user provide more input. Once we got all input pad buffer
667        to full RSI.
668     */
669
670     struct internal_state *state = strm->state;
671
672     do {
673         if (strm->avail_in >= state->bytes_per_sample) {
674             state->data_raw[state->i] = state->get_sample(strm);
675         } else {
676             if (state->flush == AEC_FLUSH) {
677                 if (state->i > 0) {
678                     state->blocks_avail = state->i / strm->block_size - 1;
679                     if (state->i % strm->block_size)
680                         state->blocks_avail++;
681                     do
682                         state->data_raw[state->i] =
683                             state->data_raw[state->i - 1];
684                     while(++state->i < strm->rsi * strm->block_size);
685                 } else {
686                     /* Finish encoding by padding the last byte with
687                      * zero bits. */
688                     emit(state, 0, state->bits);
689                     if (strm->avail_out > 0) {
690                         if (!state->direct_out)
691                             *strm->next_out++ = *state->cds;
692                         strm->avail_out--;
693                         state->flushed = 1;
694                     }
695                     return M_EXIT;
696                 }
697             } else {
698                 return M_EXIT;
699             }
700         }
701     } while (++state->i < strm->rsi * strm->block_size);
702
703     if (strm->flags & AEC_DATA_PREPROCESS)
704         state->preprocess(strm);
705
706     return m_check_zero_block(strm);
707 }
708
709 static int m_get_block(struct aec_stream *strm)
710 {
711     /**
712        Provide the next block of preprocessed input data.
713
714        Pull in a whole Reference Sample Interval (RSI) of data if
715        block buffer is empty.
716     */
717
718     struct internal_state *state = strm->state;
719
720     init_output(strm);
721
722     if (state->block_nonzero) {
723         state->block_nonzero = 0;
724         state->mode = m_select_code_option;
725         return M_CONTINUE;
726     }
727
728     if (state->blocks_avail == 0) {
729         state->blocks_avail = strm->rsi - 1;
730         state->block = state->data_pp;
731         state->blocks_dispensed = 1;
732
733         if (strm->avail_in >= state->rsi_len) {
734             state->get_rsi(strm);
735             if (strm->flags & AEC_DATA_PREPROCESS)
736                 state->preprocess(strm);
737
738             return m_check_zero_block(strm);
739         } else {
740             state->i = 0;
741             state->mode = m_get_rsi_resumable;
742         }
743     } else {
744         if (state->ref) {
745             state->ref = 0;
746             state->uncomp_len = strm->block_size * strm->bits_per_sample;
747         }
748         state->block += strm->block_size;
749         state->blocks_dispensed++;
750         state->blocks_avail--;
751         return m_check_zero_block(strm);
752     }
753     return M_CONTINUE;
754 }
755
756 static void cleanup(struct aec_stream *strm)
757 {
758     struct internal_state *state = strm->state;
759
760     if (strm->flags & AEC_DATA_PREPROCESS && state->data_raw)
761         free(state->data_raw);
762     if (state->data_pp)
763         free(state->data_pp);
764     free(state);
765 }
766
767 /*
768  *
769  * API functions
770  *
771  */
772
773 int aec_encode_init(struct aec_stream *strm)
774 {
775     struct internal_state *state;
776
777     if (strm->bits_per_sample > 32 || strm->bits_per_sample == 0)
778         return AEC_CONF_ERROR;
779
780     if (strm->flags & AEC_NOT_ENFORCE) {
781         /* All even block sizes are allowed. */
782         if (strm->block_size & 1)
783             return AEC_CONF_ERROR;
784     } else {
785         /* Only allow standard conforming block sizes */
786         if (strm->block_size != 8
787             && strm->block_size != 16
788             && strm->block_size != 32
789             && strm->block_size != 64)
790             return AEC_CONF_ERROR;
791     }
792
793     if (strm->rsi > 4096)
794         return AEC_CONF_ERROR;
795
796     state = malloc(sizeof(struct internal_state));
797     if (state == NULL)
798         return AEC_MEM_ERROR;
799
800     memset(state, 0, sizeof(struct internal_state));
801     strm->state = state;
802     state->uncomp_len = strm->block_size * strm->bits_per_sample;
803
804     if (strm->bits_per_sample > 16) {
805         /* 24/32 input bit settings */
806         state->id_len = 5;
807
808         if (strm->bits_per_sample <= 24
809             && strm->flags & AEC_DATA_3BYTE) {
810             state->bytes_per_sample = 3;
811             if (strm->flags & AEC_DATA_MSB) {
812                 state->get_sample = aec_get_msb_24;
813                 state->get_rsi = aec_get_rsi_msb_24;
814             } else {
815                 state->get_sample = aec_get_lsb_24;
816                 state->get_rsi = aec_get_rsi_lsb_24;
817             }
818         } else {
819             state->bytes_per_sample = 4;
820             if (strm->flags & AEC_DATA_MSB) {
821                 state->get_sample = aec_get_msb_32;
822                 state->get_rsi = aec_get_rsi_msb_32;
823             } else {
824                 state->get_sample = aec_get_lsb_32;
825                 state->get_rsi = aec_get_rsi_lsb_32;
826             }
827         }
828     }
829     else if (strm->bits_per_sample > 8) {
830         /* 16 bit settings */
831         state->id_len = 4;
832         state->bytes_per_sample = 2;
833
834         if (strm->flags & AEC_DATA_MSB) {
835             state->get_sample = aec_get_msb_16;
836             state->get_rsi = aec_get_rsi_msb_16;
837         } else {
838             state->get_sample = aec_get_lsb_16;
839             state->get_rsi = aec_get_rsi_lsb_16;
840         }
841     } else {
842         /* 8 bit settings */
843         if (strm->flags & AEC_RESTRICTED) {
844             if (strm->bits_per_sample <= 4) {
845                 if (strm->bits_per_sample <= 2)
846                     state->id_len = 1;
847                 else
848                     state->id_len = 2;
849             } else {
850                 return AEC_CONF_ERROR;
851             }
852         } else {
853             state->id_len = 3;
854         }
855         state->bytes_per_sample = 1;
856
857         state->get_sample = aec_get_8;
858         state->get_rsi = aec_get_rsi_8;
859     }
860     state->rsi_len = strm->rsi * strm->block_size * state->bytes_per_sample;
861
862     if (strm->flags & AEC_DATA_SIGNED) {
863         state->xmax = UINT32_MAX >> (32 - strm->bits_per_sample + 1);
864         state->xmin = ~state->xmax;
865         state->preprocess = preprocess_signed;
866     } else {
867         state->xmin = 0;
868         state->xmax = UINT32_MAX >> (32 - strm->bits_per_sample);
869         state->preprocess = preprocess_unsigned;
870     }
871
872     state->kmax = (1U << state->id_len) - 3;
873
874     state->data_pp = malloc(strm->rsi
875                             * strm->block_size
876                             * sizeof(uint32_t));
877     if (state->data_pp == NULL) {
878         cleanup(strm);
879         return AEC_MEM_ERROR;
880     }
881
882     if (strm->flags & AEC_DATA_PREPROCESS) {
883         state->data_raw = malloc(strm->rsi
884                                  * strm->block_size
885                                  * sizeof(uint32_t));
886         if (state->data_raw == NULL) {
887             cleanup(strm);
888             return AEC_MEM_ERROR;
889         }
890     } else {
891         state->data_raw = state->data_pp;
892     }
893
894     state->block = state->data_pp;
895
896     state->ref = 0;
897     strm->total_in = 0;
898     strm->total_out = 0;
899     state->flushed = 0;
900
901     state->cds = state->cds_buf;
902     *state->cds = 0;
903     state->bits = 8;
904     state->mode = m_get_block;
905
906     return AEC_OK;
907 }
908
909 int aec_encode(struct aec_stream *strm, int flush)
910 {
911     /**
912        Finite-state machine implementation of the adaptive entropy
913        encoder.
914     */
915     int n;
916     struct internal_state *state = strm->state;
917
918     state->flush = flush;
919     strm->total_in += strm->avail_in;
920     strm->total_out += strm->avail_out;
921
922     while (state->mode(strm) == M_CONTINUE);
923
924     if (state->direct_out) {
925         n = (int)(state->cds - strm->next_out);
926         strm->next_out += n;
927         strm->avail_out -= n;
928
929         *state->cds_buf = *state->cds;
930         state->cds = state->cds_buf;
931         state->direct_out = 0;
932     }
933     strm->total_in -= strm->avail_in;
934     strm->total_out -= strm->avail_out;
935     return AEC_OK;
936 }
937
938 int aec_encode_end(struct aec_stream *strm)
939 {
940     struct internal_state *state = strm->state;
941     int status;
942
943     status = AEC_OK;
944     if (state->flush == AEC_FLUSH && state->flushed == 0)
945         status = AEC_STREAM_ERROR;
946     cleanup(strm);
947     return status;
948 }
949
950 int aec_buffer_encode(struct aec_stream *strm)
951 {
952     int status;
953
954     status = aec_encode_init(strm);
955     if (status != AEC_OK)
956         return status;
957     status = aec_encode(strm, AEC_FLUSH);
958     if (status != AEC_OK) {
959         cleanup(strm);
960         return status;
961     }
962     return aec_encode_end(strm);
963 }