Fix another zero block bug with check
[platform/upstream/libaec.git] / src / encode.c
1 /**
2  * @file encode.c
3  * @author Mathis Rosenhauer, Deutsches Klimarechenzentrum
4  * @section DESCRIPTION
5  *
6  * Adaptive Entropy Encoder
7  * Based on CCSDS documents 121.0-B-2 and 120.0-G-2
8  *
9  */
10
11 #include <config.h>
12
13 #if HAVE_STDINT_H
14 # include <stdint.h>
15 #endif
16
17 #include <stdio.h>
18 #include <stdlib.h>
19 #include <unistd.h>
20 #include <string.h>
21
22 #include "libaec.h"
23 #include "encode.h"
24 #include "encode_accessors.h"
25
26 /* Marker for Remainder Of Segment condition in zero block encoding */
27 #define ROS -1
28
29 static int m_get_block(struct aec_stream *strm);
30 static int m_get_block_cautious(struct aec_stream *strm);
31 static int m_check_zero_block(struct aec_stream *strm);
32 static int m_select_code_option(struct aec_stream *strm);
33 static int m_flush_block(struct aec_stream *strm);
34 static int m_flush_block_cautious(struct aec_stream *strm);
35 static int m_encode_splitting(struct aec_stream *strm);
36 static int m_encode_uncomp(struct aec_stream *strm);
37 static int m_encode_se(struct aec_stream *strm);
38 static int m_encode_zero(struct aec_stream *strm);
39
40 static inline void emit(struct internal_state *state,
41                         uint32_t data, int bits)
42 {
43     /**
44        Emit sequence of bits.
45      */
46
47     if (bits <= state->bit_p) {
48         state->bit_p -= bits;
49         *state->cds_p += data << state->bit_p;
50     } else {
51         bits -= state->bit_p;
52         *state->cds_p++ += (uint64_t)data >> bits;
53
54         while (bits & ~7) {
55             bits -= 8;
56             *state->cds_p++ = data >> bits;
57         }
58
59         state->bit_p = 8 - bits;
60         *state->cds_p = data << state->bit_p;
61     }
62 }
63
64 static inline void emitfs(struct internal_state *state, int fs)
65 {
66     /**
67        Emits a fundamental sequence.
68
69        fs zero bits followed by one 1 bit.
70      */
71
72     for(;;) {
73         if (fs < state->bit_p) {
74             state->bit_p -= fs + 1;
75             *state->cds_p += 1 << state->bit_p;
76             break;
77         } else {
78             fs -= state->bit_p;
79             *++state->cds_p = 0;
80             state->bit_p = 8;
81         }
82     }
83 }
84
85 #define EMITBLOCK(ref)                                          \
86     static inline void emitblock_##ref(struct aec_stream *strm, \
87                                        int k)                   \
88     {                                                           \
89         int b;                                                  \
90         uint64_t a;                                             \
91         struct internal_state *state = strm->state;             \
92         uint32_t *in = state->block_p + ref;                    \
93         uint32_t *in_end = state->block_p + strm->block_size;   \
94         uint64_t mask = (1ULL << k) - 1;                        \
95         uint8_t *o = state->cds_p;                              \
96         int p = state->bit_p;                                   \
97                                                                 \
98         a = *o;                                                 \
99                                                                 \
100         while(in < in_end) {                                    \
101             a <<= 56;                                           \
102             p = (p % 8) + 56;                                   \
103                                                                 \
104             while (p > k && in < in_end) {                      \
105                 p -= k;                                         \
106                 a += ((uint64_t)(*in++) & mask) << p;           \
107             }                                                   \
108                                                                 \
109             for (b = 56; b > (p & ~7); b -= 8)                  \
110                 *o++ = a >> b;                                  \
111             a >>= b;                                            \
112         }                                                       \
113                                                                 \
114         *o = a;                                                 \
115         state->cds_p = o;                                       \
116         state->bit_p = p % 8;                                   \
117     }
118
119 EMITBLOCK(0);
120 EMITBLOCK(1);
121
122 static void preprocess_unsigned(struct aec_stream *strm)
123 {
124     int64_t d;
125     struct internal_state *state = strm->state;
126     uint32_t *x = state->block_buf;
127     int64_t x1 = *x++;
128     uint32_t xmax = state->xmax;
129     uint32_t rsi = strm->rsi * strm->block_size - 1;
130
131     while (rsi--) {
132         if (*x >= x1) {
133             d = *x - x1;
134             if (d <= x1) {
135                 x1 = *x;
136                 *x = 2 * d;
137             } else {
138                 x1 = *x;
139             }
140         } else {
141             d = x1 - *x;
142             if (d <= xmax - x1) {
143                 x1 = *x;
144                 *x = 2 * d - 1;
145             } else {
146                 x1 = *x;
147                 *x = xmax - *x;
148             }
149         }
150         x++;
151     }
152 }
153
154 static void preprocess_signed(struct aec_stream *strm)
155 {
156     int64_t d;
157     int64_t x;
158     struct internal_state *state = strm->state;
159     uint32_t *buf = state->block_buf;
160     uint32_t m = 1ULL << (strm->bit_per_sample - 1);
161     int64_t x1 = (((int64_t)*buf++) ^ m) - m;
162     int64_t xmax = state->xmax;
163     int64_t xmin = state->xmin;
164     uint32_t rsi = strm->rsi * strm->block_size - 1;
165
166     while (rsi--) {
167         x = (((int64_t)*buf) ^ m) - m;
168         if (x < x1) {
169             d = x1 - x;
170             if (d <= xmax - x1)
171                 *buf = 2 * d - 1;
172             else
173                 *buf = xmax - x;
174         } else {
175             d = x - x1;
176             if (d <= x1 - xmin)
177                 *buf = 2 * d;
178             else
179                 *buf = x - xmin;
180         }
181         x1 = x;
182         buf++;
183     }
184 }
185
186 /*
187  *
188  * FSM functions
189  *
190  */
191
192 static int m_get_block(struct aec_stream *strm)
193 {
194     struct internal_state *state = strm->state;
195
196     if (strm->avail_out > state->cds_len) {
197         if (!state->direct_out) {
198             state->direct_out = 1;
199             *strm->next_out = *state->cds_p;
200             state->cds_p = strm->next_out;
201         }
202     } else {
203         if (state->zero_blocks == 0 || state->direct_out) {
204             /* copy leftover from last block */
205             *state->cds_buf = *state->cds_p;
206             state->cds_p = state->cds_buf;
207         }
208         state->direct_out = 0;
209     }
210
211     if (state->blocks_avail == 0) {
212         state->ref = 1;
213         state->block_p = state->block_buf;
214
215         if (strm->avail_in >= state->block_len * strm->rsi) {
216             state->get_rsi(strm);
217             state->blocks_avail = strm->rsi - 1;
218
219             if (strm->flags & AEC_DATA_PREPROCESS)
220                 state->preprocess(strm);
221
222             return m_check_zero_block(strm);
223         } else {
224             state->i = 0;
225             state->mode = m_get_block_cautious;
226         }
227     } else {
228         state->ref = 0;
229         state->block_p += strm->block_size;
230         state->blocks_avail--;
231         return m_check_zero_block(strm);
232     }
233     return M_CONTINUE;
234 }
235
236 static int m_get_block_cautious(struct aec_stream *strm)
237 {
238     int j;
239     struct internal_state *state = strm->state;
240
241     do {
242         if (strm->avail_in > 0) {
243             state->block_buf[state->i] = state->get_sample(strm);
244         } else {
245             if (state->flush == AEC_FLUSH) {
246                 if (state->i > 0) {
247                     for (j = state->i; j < strm->rsi * strm->block_size; j++)
248                         state->block_buf[j] = state->block_buf[state->i - 1];
249                     state->i = strm->rsi * strm->block_size;
250                 } else {
251                     if (state->zero_blocks) {
252                         state->mode = m_encode_zero;
253                         return M_CONTINUE;
254                     }
255
256                     emit(state, 0, state->bit_p);
257                     if (state->direct_out == 0)
258                         *strm->next_out++ = *state->cds_p;
259                     strm->avail_out--;
260                     strm->total_out++;
261
262                     return M_EXIT;
263                 }
264             } else {
265                 return M_EXIT;
266             }
267         }
268     } while (++state->i < strm->rsi * strm->block_size);
269
270     state->blocks_avail = strm->rsi - 1;
271     if (strm->flags & AEC_DATA_PREPROCESS)
272         state->preprocess(strm);
273
274     return m_check_zero_block(strm);
275 }
276
277 static int m_check_zero_block(struct aec_stream *strm)
278 {
279     struct internal_state *state = strm->state;
280     uint32_t *p = state->block_p + state->ref;
281     uint32_t *end = state->block_p + strm->block_size;
282
283     while(*p == 0 && p < end)
284         p++;
285
286     if (p < end) {
287         if (state->zero_blocks) {
288             /* The current block isn't zero but we have to emit a
289              * previous zero block first. The current block will be
290              * handled later.
291              */
292             state->block_p -= strm->block_size;
293             state->blocks_avail++;
294             state->mode = m_encode_zero;
295             return M_CONTINUE;
296         }
297         state->mode = m_select_code_option;
298         return M_CONTINUE;
299     } else {
300         state->zero_blocks++;
301         if (state->zero_blocks == 1) {
302             state->zero_ref = state->ref;
303             state->zero_ref_sample = state->block_p[0];
304         }
305         if (state->blocks_avail == 0
306             || (strm->rsi - state->blocks_avail) % 64 == 0) {
307             if (state->zero_blocks > 4)
308                 state->zero_blocks = ROS;
309             state->mode = m_encode_zero;
310             return M_CONTINUE;
311         }
312         state->mode = m_get_block;
313         return M_CONTINUE;
314     }
315 }
316
317 static uint64_t block_fs(struct aec_stream *strm, int k)
318 {
319     int j;
320     uint64_t fs;
321     struct internal_state *state = strm->state;
322
323     fs = (uint64_t)(state->block_p[1] >> k)
324         + (uint64_t)(state->block_p[2] >> k)
325         + (uint64_t)(state->block_p[3] >> k)
326         + (uint64_t)(state->block_p[4] >> k)
327         + (uint64_t)(state->block_p[5] >> k)
328         + (uint64_t)(state->block_p[6] >> k)
329         + (uint64_t)(state->block_p[7] >> k);
330
331     if (strm->block_size > 8)
332         for (j = 8; j < strm->block_size; j += 8)
333             fs +=
334                 (uint64_t)(state->block_p[j + 0] >> k)
335                 + (uint64_t)(state->block_p[j + 1] >> k)
336                 + (uint64_t)(state->block_p[j + 2] >> k)
337                 + (uint64_t)(state->block_p[j + 3] >> k)
338                 + (uint64_t)(state->block_p[j + 4] >> k)
339                 + (uint64_t)(state->block_p[j + 5] >> k)
340                 + (uint64_t)(state->block_p[j + 6] >> k)
341                 + (uint64_t)(state->block_p[j + 7] >> k);
342
343     if (!state->ref)
344         fs += (uint64_t)(state->block_p[0] >> k);
345
346     return fs;
347 }
348
349 static int count_splitting_option(struct aec_stream *strm)
350 {
351     /**
352        Find the best point for splitting samples in a block.
353
354        In Rice coding each sample in a block of samples is split at
355        the same position into k LSB and bit_per_sample - k MSB. The
356        LSB part is left binary and the MSB part is coded as a
357        fundamental sequence a.k.a. unary (see CCSDS 121.0-B-2). The
358        function of the length of the Coded Data Set (CDS) depending on
359        k has exactly one minimum (see A. Kiely, IPN Progress Report
360        42-159).
361
362        To find that minimum with only a few costly evaluations of the
363        CDS length, we start with the k of the previous CDS. K is
364        increased and the CDS length evaluated. If the CDS length gets
365        smaller, then we are moving towards the minimum. If the length
366        increases, then the minimum will be found with smaller k.
367
368        For increasing k we know that we will gain block_size bits in
369        length through the larger binary part. If the FS lenth is less
370        than the block size then a reduced FS part can't compensate the
371        larger binary part. So we know that the CDS for k+1 will be
372        larger than for k without actually computing the length. An
373        analogue check can be done for decreasing k.
374      */
375
376     int k, k_min;
377     int this_bs; /* Block size of current block */
378     int no_turn; /* 1 if we shouldn't reverse */
379     int dir; /* Direction, 1 means increasing k, 0 decreasing k */
380     uint64_t len; /* CDS length for current k */
381     uint64_t len_min; /* CDS length minimum so far */
382     uint64_t fs_len; /* Length of FS part (not including 1s) */
383
384     struct internal_state *state = strm->state;
385
386     this_bs = strm->block_size - state->ref;
387     len_min = UINT64_MAX;
388     k = k_min = state->k;
389     no_turn = (k == 0) ? 1 : 0;
390     dir = 1;
391
392     for (;;) {
393         fs_len = block_fs(strm, k);
394         len = fs_len + this_bs * (k + 1);
395
396         if (len < len_min) {
397             if (len_min < UINT64_MAX)
398                 no_turn = 1;
399
400             len_min = len;
401             k_min = k;
402
403             if (dir) {
404                 if (fs_len < this_bs || k >= state->kmax) {
405                     if (no_turn)
406                         break;
407                     k = state->k - 1;
408                     dir = 0;
409                     no_turn = 1;
410                 } else {
411                     k++;
412                 }
413             } else {
414                 if (fs_len >= this_bs || k == 0)
415                     break;
416                 k--;
417             }
418         } else {
419             if (no_turn)
420                 break;
421             k = state->k - 1;
422             dir = 0;
423             no_turn = 1;
424         }
425     }
426     state->k = k_min;
427
428     return len_min;
429 }
430
431 static int count_se_option(uint64_t limit, struct aec_stream *strm)
432 {
433     int i;
434     uint64_t d, len;
435     struct internal_state *state = strm->state;
436
437     len = 1;
438
439     for (i = 0; i < strm->block_size; i+= 2) {
440         d = (uint64_t)state->block_p[i]
441             + (uint64_t)state->block_p[i + 1];
442         /* we have to worry about overflow here */
443         if (d > limit) {
444             len = UINT64_MAX;
445             break;
446         } else {
447             len += d * (d + 1) / 2
448                 + (uint64_t)state->block_p[i + 1];
449         }
450     }
451     return len;
452 }
453
454 static int m_select_code_option(struct aec_stream *strm)
455 {
456     uint64_t uncomp_len, split_len, se_len;
457     struct internal_state *state = strm->state;
458
459     uncomp_len = (strm->block_size - state->ref)
460         * strm->bit_per_sample;
461     split_len = count_splitting_option(strm);
462     se_len = count_se_option(split_len, strm);
463
464     if (split_len < uncomp_len) {
465         if (split_len < se_len)
466             return m_encode_splitting(strm);
467         else
468             return m_encode_se(strm);
469     } else {
470         if (uncomp_len <= se_len)
471             return m_encode_uncomp(strm);
472         else
473             return m_encode_se(strm);
474     }
475 }
476
477 static int m_encode_splitting(struct aec_stream *strm)
478 {
479     int i;
480     struct internal_state *state = strm->state;
481     int k = state->k;
482
483     emit(state, k + 1, state->id_len);
484
485     if (state->ref)
486     {
487         emit(state, state->block_p[0], strm->bit_per_sample);
488         for (i = 1; i < strm->block_size; i++)
489             emitfs(state, state->block_p[i] >> k);
490         if (k)
491             emitblock_1(strm, k);
492     }
493     else
494     {
495         for (i = 0; i < strm->block_size; i++)
496             emitfs(state, state->block_p[i] >> k);
497         if (k)
498             emitblock_0(strm, k);
499     }
500
501     return m_flush_block(strm);
502 }
503
504 static int m_encode_uncomp(struct aec_stream *strm)
505 {
506     struct internal_state *state = strm->state;
507
508     emit(state, (1U << state->id_len) - 1, state->id_len);
509     emitblock_0(strm, strm->bit_per_sample);
510
511     return m_flush_block(strm);
512 }
513
514 static int m_encode_se(struct aec_stream *strm)
515 {
516     int i;
517     uint32_t d;
518     struct internal_state *state = strm->state;
519
520     emit(state, 1, state->id_len + 1);
521     if (state->ref)
522         emit(state, state->block_p[0], strm->bit_per_sample);
523
524     for (i = 0; i < strm->block_size; i+= 2) {
525         d = state->block_p[i] + state->block_p[i + 1];
526         emitfs(state, d * (d + 1) / 2 + state->block_p[i + 1]);
527     }
528
529     return m_flush_block(strm);
530 }
531
532 static int m_encode_zero(struct aec_stream *strm)
533 {
534     struct internal_state *state = strm->state;
535
536     emit(state, 0, state->id_len + 1);
537
538     if (state->zero_ref)
539         emit(state, state->zero_ref_sample, strm->bit_per_sample);
540
541     if (state->zero_blocks == ROS)
542         emitfs(state, 4);
543     else if (state->zero_blocks >= 5)
544         emitfs(state, state->zero_blocks);
545     else
546         emitfs(state, state->zero_blocks - 1);
547
548     state->zero_blocks = 0;
549     return m_flush_block(strm);
550 }
551
552 static int m_flush_block(struct aec_stream *strm)
553 {
554     /**
555        Flush block in direct_out mode by updating counters.
556
557        Fall back to slow flushing if in buffered mode.
558     */
559     int n;
560     struct internal_state *state = strm->state;
561
562     if (state->direct_out) {
563         n = state->cds_p - strm->next_out;
564         strm->next_out += n;
565         strm->avail_out -= n;
566         strm->total_out += n;
567         state->mode = m_get_block;
568         return M_CONTINUE;
569     }
570
571     state->i = 0;
572     state->mode = m_flush_block_cautious;
573     return M_CONTINUE;
574 }
575
576 static int m_flush_block_cautious(struct aec_stream *strm)
577 {
578     /**
579        Slow and restartable flushing
580     */
581     struct internal_state *state = strm->state;
582
583     while(state->cds_buf + state->i < state->cds_p) {
584         if (strm->avail_out == 0)
585             return M_EXIT;
586
587         *strm->next_out++ = state->cds_buf[state->i];
588         strm->avail_out--;
589         strm->total_out++;
590         state->i++;
591     }
592     state->mode = m_get_block;
593     return M_CONTINUE;
594 }
595
596 /*
597  *
598  * API functions
599  *
600  */
601
602 int aec_encode_init(struct aec_stream *strm)
603 {
604     struct internal_state *state;
605
606     if (strm->bit_per_sample > 32 || strm->bit_per_sample == 0)
607         return AEC_CONF_ERROR;
608
609     if (strm->block_size != 8
610         && strm->block_size != 16
611         && strm->block_size != 32
612         && strm->block_size != 64)
613         return AEC_CONF_ERROR;
614
615     if (strm->rsi > 4096)
616         return AEC_CONF_ERROR;
617
618     state = (struct internal_state *)malloc(sizeof(struct internal_state));
619     if (state == NULL)
620         return AEC_MEM_ERROR;
621
622     memset(state, 0, sizeof(struct internal_state));
623     strm->state = state;
624
625     if (strm->bit_per_sample > 16) {
626         /* 24/32 input bit settings */
627         state->id_len = 5;
628
629         if (strm->bit_per_sample <= 24
630             && strm->flags & AEC_DATA_3BYTE) {
631             state->block_len = 3 * strm->block_size;
632             if (strm->flags & AEC_DATA_MSB) {
633                 state->get_sample = get_msb_24;
634                 state->get_rsi = get_rsi_msb_24;
635             } else {
636                 state->get_sample = get_lsb_24;
637                 state->get_rsi = get_rsi_lsb_24;
638             }
639         } else {
640             state->block_len = 4 * strm->block_size;
641             if (strm->flags & AEC_DATA_MSB) {
642                 state->get_sample = get_msb_32;
643                 state->get_rsi = get_rsi_msb_32;
644             } else {
645                 state->get_sample = get_lsb_32;
646                 state->get_rsi = get_rsi_lsb_32;
647             }
648         }
649     }
650     else if (strm->bit_per_sample > 8) {
651         /* 16 bit settings */
652         state->id_len = 4;
653         state->block_len = 2 * strm->block_size;
654
655         if (strm->flags & AEC_DATA_MSB) {
656             state->get_sample = get_msb_16;
657             state->get_rsi = get_rsi_msb_16;
658         } else {
659             state->get_sample = get_lsb_16;
660             state->get_rsi = get_rsi_lsb_16;
661         }
662     } else {
663         /* 8 bit settings */
664         state->id_len = 3;
665         state->block_len = strm->block_size;
666
667         state->get_sample = get_8;
668         state->get_rsi = get_rsi_8;
669     }
670
671     if (strm->flags & AEC_DATA_SIGNED) {
672         state->xmin = -(1ULL << (strm->bit_per_sample - 1));
673         state->xmax = (1ULL << (strm->bit_per_sample - 1)) - 1;
674         state->preprocess = preprocess_signed;
675     } else {
676         state->xmin = 0;
677         state->xmax = (1ULL << strm->bit_per_sample) - 1;
678         state->preprocess = preprocess_unsigned;
679     }
680
681     state->kmax = (1U << state->id_len) - 3;
682
683     state->block_buf = (uint32_t *)malloc(strm->rsi
684                                          * strm->block_size
685                                          * sizeof(uint32_t));
686     if (state->block_buf == NULL)
687         return AEC_MEM_ERROR;
688
689     state->block_p = state->block_buf;
690
691     /* Largest possible CDS according to specs */
692     state->cds_len = (5 + 64 * 32) / 8 + 3;
693     state->cds_buf = (uint8_t *)malloc(state->cds_len);
694     if (state->cds_buf == NULL)
695         return AEC_MEM_ERROR;
696
697     strm->total_in = 0;
698     strm->total_out = 0;
699
700     state->cds_p = state->cds_buf;
701     *state->cds_p = 0;
702     state->bit_p = 8;
703     state->mode = m_get_block;
704
705     return AEC_OK;
706 }
707
708 int aec_encode(struct aec_stream *strm, int flush)
709 {
710     /**
711        Finite-state machine implementation of the adaptive entropy
712        encoder.
713     */
714     int n;
715     struct internal_state *state;
716     state = strm->state;
717     state->flush = flush;
718
719     while (state->mode(strm) == M_CONTINUE);
720
721     if (state->direct_out) {
722         n = state->cds_p - strm->next_out;
723         strm->next_out += n;
724         strm->avail_out -= n;
725         strm->total_out += n;
726
727         *state->cds_buf = *state->cds_p;
728         state->cds_p = state->cds_buf;
729         state->direct_out = 0;
730     }
731     return AEC_OK;
732 }
733
734 int aec_encode_end(struct aec_stream *strm)
735 {
736     struct internal_state *state = strm->state;
737
738     free(state->block_buf);
739     free(state->cds_buf);
740     free(state);
741     return AEC_OK;
742 }
743
744 int aec_buffer_encode(struct aec_stream *strm)
745 {
746     int status;
747
748     status = aec_encode_init(strm);
749     if (status != AEC_OK)
750         return status;
751     status = aec_encode(strm, AEC_FLUSH);
752     if (strm->avail_in > 0 || strm->avail_out == 0)
753         status = AEC_DATA_ERROR;
754
755     aec_encode_end(strm);
756     return status;
757 }