Merge tag 'mm-hotfixes-stable-2023-10-24-09-40' of git://git.kernel.org/pub/scm/linux...
[platform/kernel/linux-rpi.git] / fs / ntfs3 / lznt.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  *
4  * Copyright (C) 2019-2021 Paragon Software GmbH, All rights reserved.
5  *
6  */
7
8 #include <linux/kernel.h>
9 #include <linux/slab.h>
10 #include <linux/stddef.h>
11 #include <linux/string.h>
12 #include <linux/types.h>
13
14 #include "debug.h"
15 #include "ntfs_fs.h"
16
17 // clang-format off
18 /* Src buffer is zero. */
19 #define LZNT_ERROR_ALL_ZEROS    1
20 #define LZNT_CHUNK_SIZE         0x1000
21 // clang-format on
22
23 struct lznt_hash {
24         const u8 *p1;
25         const u8 *p2;
26 };
27
28 struct lznt {
29         const u8 *unc;
30         const u8 *unc_end;
31         const u8 *best_match;
32         size_t max_len;
33         bool std;
34
35         struct lznt_hash hash[LZNT_CHUNK_SIZE];
36 };
37
38 static inline size_t get_match_len(const u8 *ptr, const u8 *end, const u8 *prev,
39                                    size_t max_len)
40 {
41         size_t len = 0;
42
43         while (ptr + len < end && ptr[len] == prev[len] && ++len < max_len)
44                 ;
45         return len;
46 }
47
48 static size_t longest_match_std(const u8 *src, struct lznt *ctx)
49 {
50         size_t hash_index;
51         size_t len1 = 0, len2 = 0;
52         const u8 **hash;
53
54         hash_index =
55                 ((40543U * ((((src[0] << 4) ^ src[1]) << 4) ^ src[2])) >> 4) &
56                 (LZNT_CHUNK_SIZE - 1);
57
58         hash = &(ctx->hash[hash_index].p1);
59
60         if (hash[0] >= ctx->unc && hash[0] < src && hash[0][0] == src[0] &&
61             hash[0][1] == src[1] && hash[0][2] == src[2]) {
62                 len1 = 3;
63                 if (ctx->max_len > 3)
64                         len1 += get_match_len(src + 3, ctx->unc_end,
65                                               hash[0] + 3, ctx->max_len - 3);
66         }
67
68         if (hash[1] >= ctx->unc && hash[1] < src && hash[1][0] == src[0] &&
69             hash[1][1] == src[1] && hash[1][2] == src[2]) {
70                 len2 = 3;
71                 if (ctx->max_len > 3)
72                         len2 += get_match_len(src + 3, ctx->unc_end,
73                                               hash[1] + 3, ctx->max_len - 3);
74         }
75
76         /* Compare two matches and select the best one. */
77         if (len1 < len2) {
78                 ctx->best_match = hash[1];
79                 len1 = len2;
80         } else {
81                 ctx->best_match = hash[0];
82         }
83
84         hash[1] = hash[0];
85         hash[0] = src;
86         return len1;
87 }
88
89 static size_t longest_match_best(const u8 *src, struct lznt *ctx)
90 {
91         size_t max_len;
92         const u8 *ptr;
93
94         if (ctx->unc >= src || !ctx->max_len)
95                 return 0;
96
97         max_len = 0;
98         for (ptr = ctx->unc; ptr < src; ++ptr) {
99                 size_t len =
100                         get_match_len(src, ctx->unc_end, ptr, ctx->max_len);
101                 if (len >= max_len) {
102                         max_len = len;
103                         ctx->best_match = ptr;
104                 }
105         }
106
107         return max_len >= 3 ? max_len : 0;
108 }
109
110 static const size_t s_max_len[] = {
111         0x1002, 0x802, 0x402, 0x202, 0x102, 0x82, 0x42, 0x22, 0x12,
112 };
113
114 static const size_t s_max_off[] = {
115         0x10, 0x20, 0x40, 0x80, 0x100, 0x200, 0x400, 0x800, 0x1000,
116 };
117
118 static inline u16 make_pair(size_t offset, size_t len, size_t index)
119 {
120         return ((offset - 1) << (12 - index)) |
121                ((len - 3) & (((1 << (12 - index)) - 1)));
122 }
123
124 static inline size_t parse_pair(u16 pair, size_t *offset, size_t index)
125 {
126         *offset = 1 + (pair >> (12 - index));
127         return 3 + (pair & ((1 << (12 - index)) - 1));
128 }
129
130 /*
131  * compress_chunk
132  *
133  * Return:
134  * * 0  - Ok, @cmpr contains @cmpr_chunk_size bytes of compressed data.
135  * * 1  - Input buffer is full zero.
136  * * -2 - The compressed buffer is too small to hold the compressed data.
137  */
138 static inline int compress_chunk(size_t (*match)(const u8 *, struct lznt *),
139                                  const u8 *unc, const u8 *unc_end, u8 *cmpr,
140                                  u8 *cmpr_end, size_t *cmpr_chunk_size,
141                                  struct lznt *ctx)
142 {
143         size_t cnt = 0;
144         size_t idx = 0;
145         const u8 *up = unc;
146         u8 *cp = cmpr + 3;
147         u8 *cp2 = cmpr + 2;
148         u8 not_zero = 0;
149         /* Control byte of 8-bit values: ( 0 - means byte as is, 1 - short pair ). */
150         u8 ohdr = 0;
151         u8 *last;
152         u16 t16;
153
154         if (unc + LZNT_CHUNK_SIZE < unc_end)
155                 unc_end = unc + LZNT_CHUNK_SIZE;
156
157         last = min(cmpr + LZNT_CHUNK_SIZE + sizeof(short), cmpr_end);
158
159         ctx->unc = unc;
160         ctx->unc_end = unc_end;
161         ctx->max_len = s_max_len[0];
162
163         while (up < unc_end) {
164                 size_t max_len;
165
166                 while (unc + s_max_off[idx] < up)
167                         ctx->max_len = s_max_len[++idx];
168
169                 /* Find match. */
170                 max_len = up + 3 <= unc_end ? (*match)(up, ctx) : 0;
171
172                 if (!max_len) {
173                         if (cp >= last)
174                                 goto NotCompressed;
175                         not_zero |= *cp++ = *up++;
176                 } else if (cp + 1 >= last) {
177                         goto NotCompressed;
178                 } else {
179                         t16 = make_pair(up - ctx->best_match, max_len, idx);
180                         *cp++ = t16;
181                         *cp++ = t16 >> 8;
182
183                         ohdr |= 1 << cnt;
184                         up += max_len;
185                 }
186
187                 cnt = (cnt + 1) & 7;
188                 if (!cnt) {
189                         *cp2 = ohdr;
190                         ohdr = 0;
191                         cp2 = cp;
192                         cp += 1;
193                 }
194         }
195
196         if (cp2 < last)
197                 *cp2 = ohdr;
198         else
199                 cp -= 1;
200
201         *cmpr_chunk_size = cp - cmpr;
202
203         t16 = (*cmpr_chunk_size - 3) | 0xB000;
204         cmpr[0] = t16;
205         cmpr[1] = t16 >> 8;
206
207         return not_zero ? 0 : LZNT_ERROR_ALL_ZEROS;
208
209 NotCompressed:
210
211         if ((cmpr + LZNT_CHUNK_SIZE + sizeof(short)) > last)
212                 return -2;
213
214         /*
215          * Copy non cmpr data.
216          * 0x3FFF == ((LZNT_CHUNK_SIZE + 2 - 3) | 0x3000)
217          */
218         cmpr[0] = 0xff;
219         cmpr[1] = 0x3f;
220
221         memcpy(cmpr + sizeof(short), unc, LZNT_CHUNK_SIZE);
222         *cmpr_chunk_size = LZNT_CHUNK_SIZE + sizeof(short);
223
224         return 0;
225 }
226
227 static inline ssize_t decompress_chunk(u8 *unc, u8 *unc_end, const u8 *cmpr,
228                                        const u8 *cmpr_end)
229 {
230         u8 *up = unc;
231         u8 ch = *cmpr++;
232         size_t bit = 0;
233         size_t index = 0;
234         u16 pair;
235         size_t offset, length;
236
237         /* Do decompression until pointers are inside range. */
238         while (up < unc_end && cmpr < cmpr_end) {
239                 /* Correct index */
240                 while (unc + s_max_off[index] < up)
241                         index += 1;
242
243                 /* Check the current flag for zero. */
244                 if (!(ch & (1 << bit))) {
245                         /* Just copy byte. */
246                         *up++ = *cmpr++;
247                         goto next;
248                 }
249
250                 /* Check for boundary. */
251                 if (cmpr + 1 >= cmpr_end)
252                         return -EINVAL;
253
254                 /* Read a short from little endian stream. */
255                 pair = cmpr[1];
256                 pair <<= 8;
257                 pair |= cmpr[0];
258
259                 cmpr += 2;
260
261                 /* Translate packed information into offset and length. */
262                 length = parse_pair(pair, &offset, index);
263
264                 /* Check offset for boundary. */
265                 if (unc + offset > up)
266                         return -EINVAL;
267
268                 /* Truncate the length if necessary. */
269                 if (up + length >= unc_end)
270                         length = unc_end - up;
271
272                 /* Now we copy bytes. This is the heart of LZ algorithm. */
273                 for (; length > 0; length--, up++)
274                         *up = *(up - offset);
275
276 next:
277                 /* Advance flag bit value. */
278                 bit = (bit + 1) & 7;
279
280                 if (!bit) {
281                         if (cmpr >= cmpr_end)
282                                 break;
283
284                         ch = *cmpr++;
285                 }
286         }
287
288         /* Return the size of uncompressed data. */
289         return up - unc;
290 }
291
292 /*
293  * get_lznt_ctx
294  * @level: 0 - Standard compression.
295  *         !0 - Best compression, requires a lot of cpu.
296  */
297 struct lznt *get_lznt_ctx(int level)
298 {
299         struct lznt *r = kzalloc(level ? offsetof(struct lznt, hash) :
300                                          sizeof(struct lznt),
301                                  GFP_NOFS);
302
303         if (r)
304                 r->std = !level;
305         return r;
306 }
307
308 /*
309  * compress_lznt - Compresses @unc into @cmpr
310  *
311  * Return:
312  * * +x - Ok, @cmpr contains 'final_compressed_size' bytes of compressed data.
313  * * 0 - Input buffer is full zero.
314  */
315 size_t compress_lznt(const void *unc, size_t unc_size, void *cmpr,
316                      size_t cmpr_size, struct lznt *ctx)
317 {
318         int err;
319         size_t (*match)(const u8 *src, struct lznt *ctx);
320         u8 *p = cmpr;
321         u8 *end = p + cmpr_size;
322         const u8 *unc_chunk = unc;
323         const u8 *unc_end = unc_chunk + unc_size;
324         bool is_zero = true;
325
326         if (ctx->std) {
327                 match = &longest_match_std;
328                 memset(ctx->hash, 0, sizeof(ctx->hash));
329         } else {
330                 match = &longest_match_best;
331         }
332
333         /* Compression cycle. */
334         for (; unc_chunk < unc_end; unc_chunk += LZNT_CHUNK_SIZE) {
335                 cmpr_size = 0;
336                 err = compress_chunk(match, unc_chunk, unc_end, p, end,
337                                      &cmpr_size, ctx);
338                 if (err < 0)
339                         return unc_size;
340
341                 if (is_zero && err != LZNT_ERROR_ALL_ZEROS)
342                         is_zero = false;
343
344                 p += cmpr_size;
345         }
346
347         if (p <= end - 2)
348                 p[0] = p[1] = 0;
349
350         return is_zero ? 0 : PtrOffset(cmpr, p);
351 }
352
353 /*
354  * decompress_lznt - Decompress @cmpr into @unc.
355  */
356 ssize_t decompress_lznt(const void *cmpr, size_t cmpr_size, void *unc,
357                         size_t unc_size)
358 {
359         const u8 *cmpr_chunk = cmpr;
360         const u8 *cmpr_end = cmpr_chunk + cmpr_size;
361         u8 *unc_chunk = unc;
362         u8 *unc_end = unc_chunk + unc_size;
363         u16 chunk_hdr;
364
365         if (cmpr_size < sizeof(short))
366                 return -EINVAL;
367
368         /* Read chunk header. */
369         chunk_hdr = cmpr_chunk[1];
370         chunk_hdr <<= 8;
371         chunk_hdr |= cmpr_chunk[0];
372
373         /* Loop through decompressing chunks. */
374         for (;;) {
375                 size_t chunk_size_saved;
376                 size_t unc_use;
377                 size_t cmpr_use = 3 + (chunk_hdr & (LZNT_CHUNK_SIZE - 1));
378
379                 /* Check that the chunk actually fits the supplied buffer. */
380                 if (cmpr_chunk + cmpr_use > cmpr_end)
381                         return -EINVAL;
382
383                 /* First make sure the chunk contains compressed data. */
384                 if (chunk_hdr & 0x8000) {
385                         /* Decompress a chunk and return if we get an error. */
386                         ssize_t err =
387                                 decompress_chunk(unc_chunk, unc_end,
388                                                  cmpr_chunk + sizeof(chunk_hdr),
389                                                  cmpr_chunk + cmpr_use);
390                         if (err < 0)
391                                 return err;
392                         unc_use = err;
393                 } else {
394                         /* This chunk does not contain compressed data. */
395                         unc_use = unc_chunk + LZNT_CHUNK_SIZE > unc_end ?
396                                           unc_end - unc_chunk :
397                                           LZNT_CHUNK_SIZE;
398
399                         if (cmpr_chunk + sizeof(chunk_hdr) + unc_use >
400                             cmpr_end) {
401                                 return -EINVAL;
402                         }
403
404                         memcpy(unc_chunk, cmpr_chunk + sizeof(chunk_hdr),
405                                unc_use);
406                 }
407
408                 /* Advance pointers. */
409                 cmpr_chunk += cmpr_use;
410                 unc_chunk += unc_use;
411
412                 /* Check for the end of unc buffer. */
413                 if (unc_chunk >= unc_end)
414                         break;
415
416                 /* Proceed the next chunk. */
417                 if (cmpr_chunk > cmpr_end - 2)
418                         break;
419
420                 chunk_size_saved = LZNT_CHUNK_SIZE;
421
422                 /* Read chunk header. */
423                 chunk_hdr = cmpr_chunk[1];
424                 chunk_hdr <<= 8;
425                 chunk_hdr |= cmpr_chunk[0];
426
427                 if (!chunk_hdr)
428                         break;
429
430                 /* Check the size of unc buffer. */
431                 if (unc_use < chunk_size_saved) {
432                         size_t t1 = chunk_size_saved - unc_use;
433                         u8 *t2 = unc_chunk + t1;
434
435                         /* 'Zero' memory. */
436                         if (t2 >= unc_end)
437                                 break;
438
439                         memset(unc_chunk, 0, t1);
440                         unc_chunk = t2;
441                 }
442         }
443
444         /* Check compression boundary. */
445         if (cmpr_chunk > cmpr_end)
446                 return -EINVAL;
447
448         /*
449          * The unc size is just a difference between current
450          * pointer and original one.
451          */
452         return PtrOffset(unc, unc_chunk);
453 }