Fix CVE-2017-6891 in minitasn1 code
[platform/upstream/gnutls.git] / lib / gnutls_compress.c
1 /*
2  * Copyright (C) 2000-2012 Free Software Foundation, Inc.
3  *
4  * Author: Nikos Mavrogiannopoulos
5  *
6  * This file is part of GnuTLS.
7  *
8  * The GnuTLS is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public License
10  * as published by the Free Software Foundation; either version 2.1 of
11  * the License, or (at your option) any later version.
12  *
13  * This library is distributed in the hope that it will be useful, but
14  * WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public License
19  * along with this program.  If not, see <http://www.gnu.org/licenses/>
20  *
21  */
22
/* This file contains the functions which convert a TLS plaintext
 * packet to a TLS compressed packet and back, together with the table
 * of known compression methods.
 */
26
27 #include "gnutls_int.h"
28 #include "gnutls_compress.h"
29 #include "gnutls_errors.h"
30 #include "gnutls_constate.h"
31 #include <algorithms.h>
32 #include <gnutls/gnutls.h>
33
34 /* Compression Section */
/* Builds one initializer for a gnutls_compression_entry: the stringified
 * enum identifier, the enum value itself, then the remaining four
 * arguments in the order given.
 * NOTE(review): the lookup code in this file compares p->id with the
 * gnutls algorithm and p->num with the TLS number, so the macro parameter
 * called "id" presumably ends up in the "num" member -- confirm against
 * the struct definition.
 */
#define GNUTLS_COMPRESSION_ENTRY(name, id, wb, ml, cl)  \
  { #name, name, id, wb, ml, cl}


/* Capacity of the table below; also exported through
 * _gnutls_comp_algorithms_size.
 */
#define MAX_COMP_METHODS 5
const int _gnutls_comp_algorithms_size = MAX_COMP_METHODS;

/* Table of known compression methods.  The all-zero entry (NULL name)
 * marks the end; the lookup loops in this file stop when they reach it.
 * DEFLATE is present only when built with HAVE_LIBZ.
 * NOTE(review): 15, 8, 3 are presumably window bits, memory level and
 * compression level, as read back by get_wbits() and friends -- confirm
 * against the struct field order.
 */
gnutls_compression_entry _gnutls_compression_algorithms[MAX_COMP_METHODS] = {
        GNUTLS_COMPRESSION_ENTRY(GNUTLS_COMP_NULL, 0x00, 0, 0, 0),
#ifdef HAVE_LIBZ
        /* draft-ietf-tls-compression-02 */
        GNUTLS_COMPRESSION_ENTRY(GNUTLS_COMP_DEFLATE, 0x01, 15, 8, 3),
#endif
        {0, 0, 0, 0, 0, 0}
};
50
/* Zero-terminated list returned by gnutls_compression_list().
 * DEFLATE is listed (first) only when built with HAVE_LIBZ.
 */
static const gnutls_compression_method_t supported_compressions[] = {
#ifdef HAVE_LIBZ
        GNUTLS_COMP_DEFLATE,
#endif
        GNUTLS_COMP_NULL,
        0
};
58
/* Walks _gnutls_compression_algorithms up to the NULL-named terminator,
 * running statement(s) b for every entry.  The macro declares the cursor
 * "p" in the enclosing scope, so b may refer to p and the macro can be
 * expanded only once per block.
 */
#define GNUTLS_COMPRESSION_LOOP(b)         \
  const gnutls_compression_entry *p;                                    \
  for(p = _gnutls_compression_algorithms; p->name != NULL; p++) { b ; }
/* Runs a for the first entry whose id equals the variable "algorithm",
 * which must be in scope where the macro is expanded, then stops.
 */
#define GNUTLS_COMPRESSION_ALG_LOOP(a)                                  \
  GNUTLS_COMPRESSION_LOOP( if(p->id == algorithm) { a; break; } )
/* Same as above, but matches p->num against the variable "num". */
#define GNUTLS_COMPRESSION_ALG_LOOP_NUM(a)                              \
  GNUTLS_COMPRESSION_LOOP( if(p->num == num) { a; break; } )
66
67 /* Compression Functions */
68
69 /**
70  * gnutls_compression_get_name:
71  * @algorithm: is a Compression algorithm
72  *
73  * Convert a #gnutls_compression_method_t value to a string.
74  *
75  * Returns: a pointer to a string that contains the name of the
76  *   specified compression algorithm, or %NULL.
77  **/
78 const char *gnutls_compression_get_name(gnutls_compression_method_t
79                                         algorithm)
80 {
81         const char *ret = NULL;
82
83         /* avoid prefix */
84         GNUTLS_COMPRESSION_ALG_LOOP(ret =
85                                     p->name + sizeof("GNUTLS_COMP_") - 1);
86
87         return ret;
88 }
89
90 /**
91  * gnutls_compression_get_id:
92  * @name: is a compression method name
93  *
94  * The names are compared in a case insensitive way.
95  *
96  * Returns: an id of the specified in a string compression method, or
97  *   %GNUTLS_COMP_UNKNOWN on error.
98  **/
99 gnutls_compression_method_t gnutls_compression_get_id(const char *name)
100 {
101         gnutls_compression_method_t ret = GNUTLS_COMP_UNKNOWN;
102
103         GNUTLS_COMPRESSION_LOOP(if
104                                 (strcasecmp
105                                  (p->name + sizeof("GNUTLS_COMP_") - 1,
106                                   name) == 0) ret = p->id);
107
108         return ret;
109 }
110
111 /**
112  * gnutls_compression_list:
113  *
114  * Get a list of compression methods.  
115  *
116  * Returns: a zero-terminated list of #gnutls_compression_method_t
117  *   integers indicating the available compression methods.
118  **/
119 const gnutls_compression_method_t *gnutls_compression_list(void)
120 {
121         return supported_compressions;
122 }
123
124 /* return the tls number of the specified algorithm */
125 int _gnutls_compression_get_num(gnutls_compression_method_t algorithm)
126 {
127         int ret = -1;
128
129         /* avoid prefix */
130         GNUTLS_COMPRESSION_ALG_LOOP(ret = p->num);
131
132         return ret;
133 }
134
135 #ifdef HAVE_LIBZ
136
137 static int get_wbits(gnutls_compression_method_t algorithm)
138 {
139         int ret = -1;
140         /* avoid prefix */
141         GNUTLS_COMPRESSION_ALG_LOOP(ret = p->window_bits);
142         return ret;
143 }
144
145 static int get_mem_level(gnutls_compression_method_t algorithm)
146 {
147         int ret = -1;
148         /* avoid prefix */
149         GNUTLS_COMPRESSION_ALG_LOOP(ret = p->mem_level);
150         return ret;
151 }
152
153 static int get_comp_level(gnutls_compression_method_t algorithm)
154 {
155         int ret = -1;
156         /* avoid prefix */
157         GNUTLS_COMPRESSION_ALG_LOOP(ret = p->comp_level);
158         return ret;
159 }
160
161 #endif
162
163 /* returns the gnutls internal ID of the TLS compression
164  * method num
165  */
166 gnutls_compression_method_t _gnutls_compression_get_id(int num)
167 {
168         gnutls_compression_method_t ret = -1;
169
170         /* avoid prefix */
171         GNUTLS_COMPRESSION_ALG_LOOP_NUM(ret = p->id);
172
173         return ret;
174 }
175
176 int _gnutls_compression_is_ok(gnutls_compression_method_t algorithm)
177 {
178         ssize_t ret = -1;
179         GNUTLS_COMPRESSION_ALG_LOOP(ret = p->id);
180         if (ret >= 0)
181                 ret = 0;
182         else
183                 ret = 1;
184         return ret;
185 }
186
187
188
189 /* For compression  */
190
191 #define MIN_PRIVATE_COMP_ALGO 0xEF
192
193 /* returns the TLS numbers of the compression methods we support
194  */
195 #define SUPPORTED_COMPRESSION_METHODS session->internals.priorities.compression.algorithms
196 int
197 _gnutls_supported_compression_methods(gnutls_session_t session,
198                                       uint8_t * comp, size_t comp_size)
199 {
200         unsigned int i, j;
201         int tmp;
202
203         if (comp_size < SUPPORTED_COMPRESSION_METHODS)
204                 return gnutls_assert_val(GNUTLS_E_INTERNAL_ERROR);
205
206         for (i = j = 0; i < SUPPORTED_COMPRESSION_METHODS; i++) {
207                 if (IS_DTLS(session) && session->internals.priorities.compression.priority[i] != GNUTLS_COMP_NULL) {
208                         gnutls_assert();
209                         continue;
210                 }
211
212                 tmp =
213                     _gnutls_compression_get_num(session->
214                                                 internals.priorities.
215                                                 compression.priority[i]);
216
217                 /* remove private compression algorithms, if requested.
218                  */
219                 if (tmp == -1 || (tmp >= MIN_PRIVATE_COMP_ALGO &&
220                                   session->internals.enable_private == 0))
221                 {
222                         gnutls_assert();
223                         continue;
224                 }
225
226                 comp[j] = (uint8_t) tmp;
227                 j++;
228         }
229
230         if (j == 0) {
231                 gnutls_assert();
232                 return GNUTLS_E_NO_COMPRESSION_ALGORITHMS;
233         }
234         return j;
235 }
236
237
238 /* The flag d is the direction (compress, decompress). Non zero is
239  * decompress.
240  */
241 int _gnutls_comp_init(comp_hd_st * handle,
242                       gnutls_compression_method_t method, int d)
243 {
244         handle->algo = method;
245         handle->handle = NULL;
246
247         switch (method) {
248         case GNUTLS_COMP_DEFLATE:
249 #ifdef HAVE_LIBZ
250                 {
251                         int window_bits, mem_level;
252                         int comp_level;
253                         z_stream *zhandle;
254                         int err;
255
256                         window_bits = get_wbits(method);
257                         mem_level = get_mem_level(method);
258                         comp_level = get_comp_level(method);
259
260                         handle->handle = gnutls_malloc(sizeof(z_stream));
261                         if (handle->handle == NULL)
262                                 return
263                                     gnutls_assert_val
264                                     (GNUTLS_E_MEMORY_ERROR);
265
266                         zhandle = handle->handle;
267
268                         zhandle->zalloc = (alloc_func) 0;
269                         zhandle->zfree = (free_func) 0;
270                         zhandle->opaque = (voidpf) 0;
271
272                         if (d)
273                                 err = inflateInit2(zhandle, window_bits);
274                         else {
275                                 err = deflateInit2(zhandle,
276                                                    comp_level, Z_DEFLATED,
277                                                    window_bits, mem_level,
278                                                    Z_DEFAULT_STRATEGY);
279                         }
280                         if (err != Z_OK) {
281                                 gnutls_assert();
282                                 gnutls_free(handle->handle);
283                                 return GNUTLS_E_COMPRESSION_FAILED;
284                         }
285                 }
286                 break;
287 #endif
288         case GNUTLS_COMP_NULL:
289         case GNUTLS_COMP_UNKNOWN:
290                 break;
291         default:
292                 return GNUTLS_E_UNKNOWN_COMPRESSION_ALGORITHM;
293         }
294
295         return 0;
296 }
297
298 /* The flag d is the direction (compress, decompress). Non zero is
299  * decompress.
300  */
301 void _gnutls_comp_deinit(comp_hd_st * handle, int d)
302 {
303         if (handle != NULL) {
304                 switch (handle->algo) {
305 #ifdef HAVE_LIBZ
306                 case GNUTLS_COMP_DEFLATE:
307                         {
308                                 if (d)
309                                         inflateEnd(handle->handle);
310                                 else
311                                         deflateEnd(handle->handle);
312                                 break;
313                         }
314 #endif
315                 default:
316                         break;
317                 }
318                 gnutls_free(handle->handle);
319                 handle->handle = NULL;
320         }
321 }
322
323 /* These functions are memory consuming 
324  */
325
326 int
327 _gnutls_compress(comp_hd_st * handle, const uint8_t * plain,
328                  size_t plain_size, uint8_t * compressed,
329                  size_t max_comp_size, unsigned int stateless)
330 {
331         int compressed_size = GNUTLS_E_COMPRESSION_FAILED;
332
333         /* NULL compression is not handled here
334          */
335         if (handle == NULL) {
336                 gnutls_assert();
337                 return GNUTLS_E_INTERNAL_ERROR;
338         }
339
340         switch (handle->algo) {
341 #ifdef HAVE_LIBZ
342         case GNUTLS_COMP_DEFLATE:
343                 {
344                         z_stream *zhandle;
345                         int err;
346                         int type;
347
348                         if (stateless) {
349                                 type = Z_FULL_FLUSH;
350                         } else
351                                 type = Z_SYNC_FLUSH;
352
353                         zhandle = handle->handle;
354
355                         zhandle->next_in = (Bytef *) plain;
356                         zhandle->avail_in = plain_size;
357                         zhandle->next_out = (Bytef *) compressed;
358                         zhandle->avail_out = max_comp_size;
359
360                         err = deflate(zhandle, type);
361                         if (err != Z_OK || zhandle->avail_in != 0)
362                                 return
363                                     gnutls_assert_val
364                                     (GNUTLS_E_COMPRESSION_FAILED);
365
366
367                         compressed_size =
368                             max_comp_size - zhandle->avail_out;
369                         break;
370                 }
371 #endif
372         default:
373                 gnutls_assert();
374                 return GNUTLS_E_INTERNAL_ERROR;
375         }                       /* switch */
376
377 #ifdef COMPRESSION_DEBUG
378         _gnutls_debug_log("Compression ratio: %f\n",
379                           (float) ((float) compressed_size /
380                                    (float) plain_size));
381 #endif
382
383         return compressed_size;
384 }
385
386
387
388 int
389 _gnutls_decompress(comp_hd_st * handle, uint8_t * compressed,
390                    size_t compressed_size, uint8_t * plain,
391                    size_t max_plain_size)
392 {
393         int plain_size = GNUTLS_E_DECOMPRESSION_FAILED;
394
395         if (compressed_size > max_plain_size + EXTRA_COMP_SIZE) {
396                 gnutls_assert();
397                 return GNUTLS_E_DECOMPRESSION_FAILED;
398         }
399
400         /* NULL compression is not handled here
401          */
402
403         if (handle == NULL) {
404                 gnutls_assert();
405                 return GNUTLS_E_INTERNAL_ERROR;
406         }
407
408         switch (handle->algo) {
409 #ifdef HAVE_LIBZ
410         case GNUTLS_COMP_DEFLATE:
411                 {
412                         z_stream *zhandle;
413                         int err;
414
415                         zhandle = handle->handle;
416
417                         zhandle->next_in = (Bytef *) compressed;
418                         zhandle->avail_in = compressed_size;
419
420                         zhandle->next_out = (Bytef *) plain;
421                         zhandle->avail_out = max_plain_size;
422                         err = inflate(zhandle, Z_SYNC_FLUSH);
423
424                         if (err != Z_OK)
425                                 return
426                                     gnutls_assert_val
427                                     (GNUTLS_E_DECOMPRESSION_FAILED);
428
429                         plain_size = max_plain_size - zhandle->avail_out;
430                         break;
431                 }
432 #endif
433         default:
434                 gnutls_assert();
435                 return GNUTLS_E_INTERNAL_ERROR;
436         }                       /* switch */
437
438         return plain_size;
439 }