private.h: rename to contain dir
[platform/upstream/libwebsockets.git] / lib / tls / mbedtls / lws-genaes.c
1 /*
2  * libwebsockets - generic AES api hiding the backend
3  *
4  * Copyright (C) 2017 - 2018 Andy Green <andy@warmcat.com>
5  *
6  *  This library is free software; you can redistribute it and/or
7  *  modify it under the terms of the GNU Lesser General Public
8  *  License as published by the Free Software Foundation:
9  *  version 2.1 of the License.
10  *
11  *  This library is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14  *  Lesser General Public License for more details.
15  *
16  *  You should have received a copy of the GNU Lesser General Public
17  *  License along with this library; if not, write to the Free Software
18  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
19  *  MA  02110-1301  USA
20  *
21  *  lws_genaes provides an abstraction api for AES in lws that works the
22  *  same whether you are using openssl or mbedtls hash functions underneath.
23  */
24 #include "private-lib-core.h"
25 #include "../../jose/private-lib-jose.h"
26
27 static int operation_map[] = { MBEDTLS_AES_ENCRYPT, MBEDTLS_AES_DECRYPT };
28
29 LWS_VISIBLE int
30 lws_genaes_create(struct lws_genaes_ctx *ctx, enum enum_aes_operation op,
31                   enum enum_aes_modes mode, struct lws_gencrypto_keyelem *el,
32                   enum enum_aes_padding padding, void *engine)
33 {
34         int n = 0;
35
36         ctx->mode = mode;
37         ctx->k = el;
38         ctx->op = operation_map[op];
39         ctx->underway = 0;
40
41         switch (ctx->mode) {
42         case LWS_GAESM_XTS:
43 #if defined(MBEDTLS_CIPHER_MODE_XTS)
44                 mbedtls_aes_xts_init(&ctx->u.ctx_xts);
45                 break;
46 #else
47                 return -1;
48 #endif
49         case LWS_GAESM_GCM:
50                 mbedtls_gcm_init(&ctx->u.ctx_gcm);
51                 n = mbedtls_gcm_setkey(&ctx->u.ctx_gcm, MBEDTLS_CIPHER_ID_AES,
52                                        ctx->k->buf, ctx->k->len * 8);
53                 if (n) {
54                         lwsl_notice("%s: mbedtls_gcm_setkey: -0x%x\n",
55                                     __func__, -n);
56                         return n;
57                 }
58                 return n;
59         default:
60                 mbedtls_aes_init(&ctx->u.ctx);
61                 break;
62         }
63
64         switch (op) {
65         case LWS_GAESO_ENC:
66                 if (ctx->mode == LWS_GAESM_XTS)
67 #if defined(MBEDTLS_CIPHER_MODE_XTS)
68                         n = mbedtls_aes_xts_setkey_enc(&ctx->u.ctx_xts,
69                                                        ctx->k->buf,
70                                                        ctx->k->len * 8);
71 #else
72                         return -1;
73 #endif
74                 else
75                         n = mbedtls_aes_setkey_enc(&ctx->u.ctx, ctx->k->buf,
76                                                    ctx->k->len * 8);
77                 break;
78         case LWS_GAESO_DEC:
79                 switch (ctx->mode) {
80                 case LWS_GAESM_XTS:
81 #if defined(MBEDTLS_CIPHER_MODE_XTS)
82                         n = mbedtls_aes_xts_setkey_dec(&ctx->u.ctx_xts,
83                                                        ctx->k->buf,
84                                                        ctx->k->len * 8);
85                         break;
86 #else
87                         return -1;
88 #endif
89
90                 case LWS_GAESM_CFB128:
91                 case LWS_GAESM_CFB8:
92                 case LWS_GAESM_CTR:
93                 case LWS_GAESM_OFB:
94                         n = mbedtls_aes_setkey_enc(&ctx->u.ctx, ctx->k->buf,
95                                                    ctx->k->len * 8);
96                         break;
97                 default:
98                         n = mbedtls_aes_setkey_dec(&ctx->u.ctx, ctx->k->buf,
99                                                    ctx->k->len * 8);
100                         break;
101                 }
102                 break;
103         }
104
105         if (n)
106                 lwsl_notice("%s: setting key: -0x%x\n", __func__, -n);
107
108         return n;
109 }
110
111 LWS_VISIBLE int
112 lws_genaes_destroy(struct lws_genaes_ctx *ctx, unsigned char *tag, size_t tlen)
113 {
114         int n = 0;
115
116         if (ctx->mode == LWS_GAESM_GCM) {
117                 n = mbedtls_gcm_finish(&ctx->u.ctx_gcm, tag, tlen);
118                 if (n)
119                         lwsl_notice("%s: mbedtls_gcm_finish: -0x%x\n",
120                                     __func__, -n);
121                 if (tag && ctx->op == MBEDTLS_AES_DECRYPT && !n) {
122                         if (lws_timingsafe_bcmp(ctx->tag, tag, ctx->taglen)) {
123                                 lwsl_err("%s: lws_genaes_crypt tag "
124                                          "mismatch (bad first)\n",
125                                                 __func__);
126                                 lwsl_hexdump_notice(tag, tlen);
127                                 lwsl_hexdump_notice(ctx->tag, ctx->taglen);
128                                 n = -1;
129                         }
130                 }
131                 mbedtls_gcm_free(&ctx->u.ctx_gcm);
132                 return n;
133         }
134         if (ctx->mode == LWS_GAESM_XTS)
135 #if defined(MBEDTLS_CIPHER_MODE_XTS)
136                 mbedtls_aes_xts_free(&ctx->u.ctx_xts);
137 #else
138                 return -1;
139 #endif
140         else
141                 mbedtls_aes_free(&ctx->u.ctx);
142
143         return 0;
144 }
145
146 static int
147 lws_genaes_rfc3394_wrap(int wrap, int cek_bits, const uint8_t *kek,
148                         int kek_bits, const uint8_t *in, uint8_t *out)
149 {
150         int n, m, ret = -1, c64 = cek_bits / 64;
151         mbedtls_aes_context ctx;
152         uint8_t a[8], b[16];
153
154         /*
155          * notice the KEK key used to perform the wrapping or unwrapping is
156          * always the size of the AES key used, eg, A128KW == 128 bits.  The
157          * key being wrapped or unwrapped may be larger and is set by the
158          * 'bits' parameter.
159          *
160          * If it's larger than the KEK key size bits, we iterate over it
161          */
162
163         mbedtls_aes_init(&ctx);
164
165         if (wrap) {
166                 /*
167                  * The inputs to the key wrapping process are the KEK and the
168                  * plaintext to be wrapped.  The plaintext consists of n 64-bit
169                  * blocks, containing the key data being wrapped.
170                  *
171                  * Inputs:      Plaintext, n 64-bit values {P1, P2, ..., Pn},
172                  *              and Key, K (the KEK).
173                  * Outputs:     Ciphertext, (n+1) 64-bit values
174                  *              {C0, C1, ..., Cn}.
175                  *
176                  * The default initial value (IV) is defined to be the
177                  * hexadecimal constant:
178                  *
179                  * A[0] = IV = A6A6A6A6A6A6A6A6
180                  */
181                 memset(out, 0xa6, 8);
182                 memcpy(out + 8, in, 8 * c64);
183                 n = mbedtls_aes_setkey_enc(&ctx, kek, kek_bits);
184         } else {
185                 /*
186                  * 2.2.2 Key Unwrap
187                  *
188                  * The inputs to the unwrap process are the KEK and (n+1)
189                  * 64-bit blocks of ciphertext consisting of previously
190                  * wrapped key.  It returns n blocks of plaintext consisting
191                  * of the n 64-bit blocks of the decrypted key data.
192                  *
193                  * Inputs:  Ciphertext, (n+1) 64-bit values {C0, C1, ..., Cn},
194                  * and Key, K (the KEK).
195                  *
196                  * Outputs: Plaintext, n 64-bit values {P1, P2, ..., Pn}.
197                  */
198                 memcpy(a, in, 8);
199                 memcpy(out, in + 8, 8 * c64);
200                 n = mbedtls_aes_setkey_dec(&ctx, kek, kek_bits);
201         }
202
203         if (n < 0) {
204                 lwsl_err("%s: setkey failed\n", __func__);
205                 goto bail;
206         }
207
208         if (wrap) {
209                 for (n = 0; n <= 5; n++) {
210                         uint8_t *r = out + 8;
211                         for (m = 1; m <= c64; m++) {
212                                 memcpy(b, out, 8);
213                                 memcpy(b + 8, r, 8);
214                                 if (mbedtls_internal_aes_encrypt(&ctx, b, b))
215                                         goto bail;
216
217                                 memcpy(out, b, 8);
218                                 out[7] ^= c64 * n + m;
219                                 memcpy(r, b + 8, 8);
220                                 r += 8;
221                         }
222                 }
223                 ret = 0;
224         } else {
225                 /*
226                  *
227                  */
228                 for (n = 5; n >= 0; n--) {
229                         uint8_t *r = out + (c64 - 1) * 8;
230                         for (m = c64; m >= 1; m--) {
231                                 memcpy(b, a, 8);
232                                 b[7] ^= c64 * n + m;
233                                 memcpy(b + 8, r, 8);
234                                 if (mbedtls_internal_aes_decrypt(&ctx, b, b))
235                                         goto bail;
236
237                                 memcpy(a, b, 8);
238                                 memcpy(r, b + 8, 8);
239                                 r -= 8;
240                         }
241                 }
242
243                 ret = 0;
244                 for (n = 0; n < 8; n++)
245                         if (a[n] != 0xa6)
246                                 ret = -1;
247         }
248
249 bail:
250         if (ret)
251                 lwsl_notice("%s: failed\n", __func__);
252         mbedtls_aes_free(&ctx);
253
254         return ret;
255 }
256
257 LWS_VISIBLE int
258 lws_genaes_crypt(struct lws_genaes_ctx *ctx, const uint8_t *in, size_t len,
259                  uint8_t *out, uint8_t *iv_or_nonce_ctr_or_data_unit_16,
260                  uint8_t *stream_block_16, size_t *nc_or_iv_off, int taglen)
261 {
262         uint8_t iv[LWS_JWE_AES_IV_BYTES], sb[16];
263         int n = 0;
264
265         switch (ctx->mode) {
266         case LWS_GAESM_KW:
267                 /* a key of length ctx->k->len is wrapped by a 128-bit KEK */
268                 n = lws_genaes_rfc3394_wrap(ctx->op == MBEDTLS_AES_ENCRYPT,
269                                 ctx->op == MBEDTLS_AES_ENCRYPT ? len * 8 :
270                                                 (len - 8) * 8, ctx->k->buf,
271                                                 ctx->k->len * 8,
272                                 in, out);
273                 break;
274         case LWS_GAESM_CBC:
275                 memcpy(iv, iv_or_nonce_ctr_or_data_unit_16, 16);
276                 n = mbedtls_aes_crypt_cbc(&ctx->u.ctx, ctx->op, len, iv,
277                                           in, out);
278                 break;
279
280         case LWS_GAESM_CFB128:
281                 memcpy(iv, iv_or_nonce_ctr_or_data_unit_16, 16);
282                 n = mbedtls_aes_crypt_cfb128(&ctx->u.ctx, ctx->op, len,
283                                              nc_or_iv_off, iv, in, out);
284                 break;
285
286         case LWS_GAESM_CFB8:
287                 memcpy(iv, iv_or_nonce_ctr_or_data_unit_16, 16);
288                 n = mbedtls_aes_crypt_cfb8(&ctx->u.ctx, ctx->op, len, iv,
289                                            in, out);
290                 break;
291
292         case LWS_GAESM_CTR:
293                 memcpy(iv, iv_or_nonce_ctr_or_data_unit_16, 16);
294                 memcpy(sb, stream_block_16, 16);
295                 n = mbedtls_aes_crypt_ctr(&ctx->u.ctx, len, nc_or_iv_off,
296                                           iv, sb, in, out);
297                 memcpy(iv_or_nonce_ctr_or_data_unit_16, iv, 16);
298                 memcpy(stream_block_16, sb, 16);
299                 break;
300
301         case LWS_GAESM_ECB:
302                 n = mbedtls_aes_crypt_ecb(&ctx->u.ctx, ctx->op, in, out);
303                 break;
304
305         case LWS_GAESM_OFB:
306 #if defined(MBEDTLS_CIPHER_MODE_OFB)
307                 memcpy(iv, iv_or_nonce_ctr_or_data_unit_16, 16);
308                 n = mbedtls_aes_crypt_ofb(&ctx->u.ctx, len, nc_or_iv_off, iv,
309                                           in, out);
310                 break;
311 #else
312                 return -1;
313 #endif
314
315         case LWS_GAESM_XTS:
316 #if defined(MBEDTLS_CIPHER_MODE_XTS)
317                 memcpy(iv, iv_or_nonce_ctr_or_data_unit_16, 16);
318                 n = mbedtls_aes_crypt_xts(&ctx->u.ctx_xts, ctx->op, len, iv,
319                                           in, out);
320                 break;
321 #else
322                 return -1;
323 #endif
324         case LWS_GAESM_GCM:
325                 if (!ctx->underway) {
326                         ctx->underway = 1;
327
328                         memcpy(ctx->tag, stream_block_16, taglen);
329                         ctx->taglen = taglen;
330
331                         /*
332                          * iv:                   iv_or_nonce_ctr_or_data_unit_16
333                          * iv_len:               *nc_or_iv_off
334                          * stream_block_16:      pointer to tag
335                          * additional data:      in
336                          * additional data len:  len
337                          */
338
339                         n = mbedtls_gcm_starts(&ctx->u.ctx_gcm, ctx->op,
340                                                iv_or_nonce_ctr_or_data_unit_16,
341                                                *nc_or_iv_off, in, len);
342                         if (n) {
343                                 lwsl_notice("%s: mbedtls_gcm_starts: -0x%x\n",
344                                             __func__, -n);
345
346                                 return -1;
347                         }
348                         break;
349                 }
350
351                 n = mbedtls_gcm_update(&ctx->u.ctx_gcm, len, in, out);
352                 if (n) {
353                         lwsl_notice("%s: mbedtls_gcm_update: -0x%x\n",
354                                     __func__, -n);
355
356                         return -1;
357                 }
358                 break;
359         }
360
361         if (n) {
362                 lwsl_notice("%s: failed: -0x%x, len %d\n", __func__, -n, (int)len);
363
364                 return -1;
365         }
366
367         return 0;
368 }