Merge tag 'pull-18-rc1-work.namei' of git://git.kernel.org/pub/scm/linux/kernel/git...
[platform/kernel/linux-rpi.git] / arch / arm64 / crypto / aes-glue.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7
8 #include <asm/neon.h>
9 #include <asm/hwcap.h>
10 #include <asm/simd.h>
11 #include <crypto/aes.h>
12 #include <crypto/ctr.h>
13 #include <crypto/sha2.h>
14 #include <crypto/internal/hash.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/scatterwalk.h>
18 #include <linux/module.h>
19 #include <linux/cpufeature.h>
20 #include <crypto/xts.h>
21
22 #include "aes-ce-setkey.h"
23
/*
 * Build-time backend selection.
 *
 * With USE_V8_CRYPTO_EXTENSIONS defined, the generic aes_* names used in this
 * file are mapped onto the ce_* symbols, driver names get the "ce" suffix and
 * the algorithms register with priority 300.  Otherwise the names map onto the
 * neon_* symbols, with suffix "neon" and priority 200.
 *
 * Only the CE branch redefines aes_expandkey; in the NEON branch that name is
 * left unmapped.
 */
24 #ifdef USE_V8_CRYPTO_EXTENSIONS
25 #define MODE                    "ce"
26 #define PRIO                    300
27 #define aes_expandkey           ce_aes_expandkey
28 #define aes_ecb_encrypt         ce_aes_ecb_encrypt
29 #define aes_ecb_decrypt         ce_aes_ecb_decrypt
30 #define aes_cbc_encrypt         ce_aes_cbc_encrypt
31 #define aes_cbc_decrypt         ce_aes_cbc_decrypt
32 #define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
33 #define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
34 #define aes_essiv_cbc_encrypt   ce_aes_essiv_cbc_encrypt
35 #define aes_essiv_cbc_decrypt   ce_aes_essiv_cbc_decrypt
36 #define aes_ctr_encrypt         ce_aes_ctr_encrypt
37 #define aes_xts_encrypt         ce_aes_xts_encrypt
38 #define aes_xts_decrypt         ce_aes_xts_decrypt
39 #define aes_mac_update          ce_aes_mac_update
40 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
41 #else
42 #define MODE                    "neon"
43 #define PRIO                    200
44 #define aes_ecb_encrypt         neon_aes_ecb_encrypt
45 #define aes_ecb_decrypt         neon_aes_ecb_decrypt
46 #define aes_cbc_encrypt         neon_aes_cbc_encrypt
47 #define aes_cbc_decrypt         neon_aes_cbc_decrypt
48 #define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
49 #define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
50 #define aes_essiv_cbc_encrypt   neon_aes_essiv_cbc_encrypt
51 #define aes_essiv_cbc_decrypt   neon_aes_essiv_cbc_decrypt
52 #define aes_ctr_encrypt         neon_aes_ctr_encrypt
53 #define aes_xts_encrypt         neon_aes_xts_encrypt
54 #define aes_xts_decrypt         neon_aes_xts_decrypt
55 #define aes_mac_update          neon_aes_mac_update
56 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
57 #endif
/*
 * The plain ecb/cbc/ctr/xts aliases are only advertised for the CE build, or
 * when CONFIG_CRYPTO_AES_ARM64_BS is not enabled.  The matching entries in
 * aes_algs[] are guarded by the same condition.
 */
58 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
59 MODULE_ALIAS_CRYPTO("ecb(aes)");
60 MODULE_ALIAS_CRYPTO("cbc(aes)");
61 MODULE_ALIAS_CRYPTO("ctr(aes)");
62 MODULE_ALIAS_CRYPTO("xts(aes)");
63 #endif
/* These aliases are advertised by both backends unconditionally. */
64 MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
65 MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
66 MODULE_ALIAS_CRYPTO("cmac(aes)");
67 MODULE_ALIAS_CRYPTO("xcbc(aes)");
68 MODULE_ALIAS_CRYPTO("cbcmac(aes)");
69
70 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
71 MODULE_LICENSE("GPL v2");
72
/*
 * Prototypes of the assembler routines.  Each aes_* name below is rewritten by
 * the backend macros above into the corresponding ce_* or neon_* symbol.
 *
 * Note the differing length units: the ECB, CBC, ESSIV-CBC and MAC routines
 * take a count of blocks, whereas the CBC-CTS, CTR and XTS routines take a
 * count of bytes.
 */
73 /* defined in aes-modes.S */
74 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
75                                 int rounds, int blocks);
76 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
77                                 int rounds, int blocks);
78
79 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
80                                 int rounds, int blocks, u8 iv[]);
81 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
82                                 int rounds, int blocks, u8 iv[]);
83
84 asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
85                                 int rounds, int bytes, u8 const iv[]);
86 asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
87                                 int rounds, int bytes, u8 const iv[]);
88
89 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
90                                 int rounds, int bytes, u8 ctr[]);
91
/* rk1 and rk2 are two separate round-key arrays; see the XTS glue callers. */
92 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
93                                 int rounds, int bytes, u32 const rk2[], u8 iv[],
94                                 int first);
95 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
96                                 int rounds, int bytes, u32 const rk2[], u8 iv[],
97                                 int first);
98
99 asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
100                                       int rounds, int blocks, u8 iv[],
101                                       u32 const rk2[]);
102 asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
103                                       int rounds, int blocks, u8 iv[],
104                                       u32 const rk2[]);
105
/*
 * Unlike the routines above this one returns an int.  mac_do_update() treats
 * the return value as the number of input blocks left unprocessed and calls
 * the routine again for the remainder.
 */
106 asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
107                               int blocks, u8 dg[], int enc_before,
108                               int enc_after);
109
/*
 * Transform context for xts(aes): two expanded AES keys.  xts_set_key()
 * expands the first half of the supplied key into key1 and the second half
 * into key2.  The XTS handlers pass key1 (key_enc or key_dec) as rk1 and
 * key2.key_enc as rk2 to the assembler routines.
 */
110 struct crypto_aes_xts_ctx {
111         struct crypto_aes_ctx key1;
112         struct crypto_aes_ctx __aligned(8) key2;
113 };
114
/*
 * Transform context for essiv(cbc(aes),sha256).
 *
 * key1 is expanded from the caller's key.  key2 is expanded from the digest of
 * that key computed with the 'hash' transform (see essiv_cbc_set_key()).
 * 'hash' is allocated as "sha256" in essiv_cbc_init_tfm() and released in
 * essiv_cbc_exit_tfm().
 */
115 struct crypto_aes_essiv_cbc_ctx {
116         struct crypto_aes_ctx key1;
117         struct crypto_aes_ctx __aligned(8) key2;
118         struct crypto_shash *hash;
119 };
120
/*
 * Per-transform context of the MAC algorithms: the expanded AES key followed
 * by a variable-length area for mode-specific constants.  cmac_setkey() stores
 * two be128 values there and xcbc_setkey() stores two encrypted blocks;
 * cbcmac_setkey() does not touch it.  The size of the trailing area is decided
 * wherever the context size is declared, which is outside this view.
 */
121 struct mac_tfm_ctx {
122         struct crypto_aes_ctx key;
123         u8 __aligned(8) consts[];
124 };
125
/*
 * Per-request MAC state.  mac_init() clears dg and sets len to 0.
 * NOTE(review): presumably dg is the running MAC block and len counts the
 * bytes buffered in it; the update/final code that would confirm this is not
 * visible here.
 */
126 struct mac_desc_ctx {
127         unsigned int len;
128         u8 dg[AES_BLOCK_SIZE];
129 };
130
131 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
132                                unsigned int key_len)
133 {
134         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
135
136         return aes_expandkey(ctx, in_key, key_len);
137 }
138
139 static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
140                                       const u8 *in_key, unsigned int key_len)
141 {
142         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
143         int ret;
144
145         ret = xts_verify_key(tfm, in_key, key_len);
146         if (ret)
147                 return ret;
148
149         ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
150         if (!ret)
151                 ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
152                                     key_len / 2);
153         return ret;
154 }
155
156 static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
157                                             const u8 *in_key,
158                                             unsigned int key_len)
159 {
160         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
161         u8 digest[SHA256_DIGEST_SIZE];
162         int ret;
163
164         ret = aes_expandkey(&ctx->key1, in_key, key_len);
165         if (ret)
166                 return ret;
167
168         crypto_shash_tfm_digest(ctx->hash, in_key, key_len, digest);
169
170         return aes_expandkey(&ctx->key2, digest, sizeof(digest));
171 }
172
173 static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
174 {
175         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
176         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
177         int err, rounds = 6 + ctx->key_length / 4;
178         struct skcipher_walk walk;
179         unsigned int blocks;
180
181         err = skcipher_walk_virt(&walk, req, false);
182
183         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
184                 kernel_neon_begin();
185                 aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
186                                 ctx->key_enc, rounds, blocks);
187                 kernel_neon_end();
188                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
189         }
190         return err;
191 }
192
193 static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
194 {
195         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
196         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
197         int err, rounds = 6 + ctx->key_length / 4;
198         struct skcipher_walk walk;
199         unsigned int blocks;
200
201         err = skcipher_walk_virt(&walk, req, false);
202
203         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
204                 kernel_neon_begin();
205                 aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
206                                 ctx->key_dec, rounds, blocks);
207                 kernel_neon_end();
208                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
209         }
210         return err;
211 }
212
213 static int cbc_encrypt_walk(struct skcipher_request *req,
214                             struct skcipher_walk *walk)
215 {
216         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
217         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
218         int err = 0, rounds = 6 + ctx->key_length / 4;
219         unsigned int blocks;
220
221         while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
222                 kernel_neon_begin();
223                 aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
224                                 ctx->key_enc, rounds, blocks, walk->iv);
225                 kernel_neon_end();
226                 err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
227         }
228         return err;
229 }
230
231 static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
232 {
233         struct skcipher_walk walk;
234         int err;
235
236         err = skcipher_walk_virt(&walk, req, false);
237         if (err)
238                 return err;
239         return cbc_encrypt_walk(req, &walk);
240 }
241
242 static int cbc_decrypt_walk(struct skcipher_request *req,
243                             struct skcipher_walk *walk)
244 {
245         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
246         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
247         int err = 0, rounds = 6 + ctx->key_length / 4;
248         unsigned int blocks;
249
250         while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
251                 kernel_neon_begin();
252                 aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
253                                 ctx->key_dec, rounds, blocks, walk->iv);
254                 kernel_neon_end();
255                 err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
256         }
257         return err;
258 }
259
260 static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
261 {
262         struct skcipher_walk walk;
263         int err;
264
265         err = skcipher_walk_virt(&walk, req, false);
266         if (err)
267                 return err;
268         return cbc_decrypt_walk(req, &walk);
269 }
270
271 static int cts_cbc_encrypt(struct skcipher_request *req)
272 {
273         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
274         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
275         int err, rounds = 6 + ctx->key_length / 4;
276         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
277         struct scatterlist *src = req->src, *dst = req->dst;
278         struct scatterlist sg_src[2], sg_dst[2];
279         struct skcipher_request subreq;
280         struct skcipher_walk walk;
281
282         skcipher_request_set_tfm(&subreq, tfm);
283         skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
284                                       NULL, NULL);
285
286         if (req->cryptlen <= AES_BLOCK_SIZE) {
287                 if (req->cryptlen < AES_BLOCK_SIZE)
288                         return -EINVAL;
289                 cbc_blocks = 1;
290         }
291
292         if (cbc_blocks > 0) {
293                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
294                                            cbc_blocks * AES_BLOCK_SIZE,
295                                            req->iv);
296
297                 err = skcipher_walk_virt(&walk, &subreq, false) ?:
298                       cbc_encrypt_walk(&subreq, &walk);
299                 if (err)
300                         return err;
301
302                 if (req->cryptlen == AES_BLOCK_SIZE)
303                         return 0;
304
305                 dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
306                 if (req->dst != req->src)
307                         dst = scatterwalk_ffwd(sg_dst, req->dst,
308                                                subreq.cryptlen);
309         }
310
311         /* handle ciphertext stealing */
312         skcipher_request_set_crypt(&subreq, src, dst,
313                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
314                                    req->iv);
315
316         err = skcipher_walk_virt(&walk, &subreq, false);
317         if (err)
318                 return err;
319
320         kernel_neon_begin();
321         aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
322                             ctx->key_enc, rounds, walk.nbytes, walk.iv);
323         kernel_neon_end();
324
325         return skcipher_walk_done(&walk, 0);
326 }
327
328 static int cts_cbc_decrypt(struct skcipher_request *req)
329 {
330         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
331         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
332         int err, rounds = 6 + ctx->key_length / 4;
333         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
334         struct scatterlist *src = req->src, *dst = req->dst;
335         struct scatterlist sg_src[2], sg_dst[2];
336         struct skcipher_request subreq;
337         struct skcipher_walk walk;
338
339         skcipher_request_set_tfm(&subreq, tfm);
340         skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
341                                       NULL, NULL);
342
343         if (req->cryptlen <= AES_BLOCK_SIZE) {
344                 if (req->cryptlen < AES_BLOCK_SIZE)
345                         return -EINVAL;
346                 cbc_blocks = 1;
347         }
348
349         if (cbc_blocks > 0) {
350                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
351                                            cbc_blocks * AES_BLOCK_SIZE,
352                                            req->iv);
353
354                 err = skcipher_walk_virt(&walk, &subreq, false) ?:
355                       cbc_decrypt_walk(&subreq, &walk);
356                 if (err)
357                         return err;
358
359                 if (req->cryptlen == AES_BLOCK_SIZE)
360                         return 0;
361
362                 dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
363                 if (req->dst != req->src)
364                         dst = scatterwalk_ffwd(sg_dst, req->dst,
365                                                subreq.cryptlen);
366         }
367
368         /* handle ciphertext stealing */
369         skcipher_request_set_crypt(&subreq, src, dst,
370                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
371                                    req->iv);
372
373         err = skcipher_walk_virt(&walk, &subreq, false);
374         if (err)
375                 return err;
376
377         kernel_neon_begin();
378         aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
379                             ctx->key_dec, rounds, walk.nbytes, walk.iv);
380         kernel_neon_end();
381
382         return skcipher_walk_done(&walk, 0);
383 }
384
385 static int __maybe_unused essiv_cbc_init_tfm(struct crypto_skcipher *tfm)
386 {
387         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
388
389         ctx->hash = crypto_alloc_shash("sha256", 0, 0);
390
391         return PTR_ERR_OR_ZERO(ctx->hash);
392 }
393
394 static void __maybe_unused essiv_cbc_exit_tfm(struct crypto_skcipher *tfm)
395 {
396         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
397
398         crypto_free_shash(ctx->hash);
399 }
400
401 static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
402 {
403         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
404         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
405         int err, rounds = 6 + ctx->key1.key_length / 4;
406         struct skcipher_walk walk;
407         unsigned int blocks;
408
409         err = skcipher_walk_virt(&walk, req, false);
410
411         blocks = walk.nbytes / AES_BLOCK_SIZE;
412         if (blocks) {
413                 kernel_neon_begin();
414                 aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
415                                       ctx->key1.key_enc, rounds, blocks,
416                                       req->iv, ctx->key2.key_enc);
417                 kernel_neon_end();
418                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
419         }
420         return err ?: cbc_encrypt_walk(req, &walk);
421 }
422
423 static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
424 {
425         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
426         struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
427         int err, rounds = 6 + ctx->key1.key_length / 4;
428         struct skcipher_walk walk;
429         unsigned int blocks;
430
431         err = skcipher_walk_virt(&walk, req, false);
432
433         blocks = walk.nbytes / AES_BLOCK_SIZE;
434         if (blocks) {
435                 kernel_neon_begin();
436                 aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
437                                       ctx->key1.key_dec, rounds, blocks,
438                                       req->iv, ctx->key2.key_enc);
439                 kernel_neon_end();
440                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
441         }
442         return err ?: cbc_decrypt_walk(req, &walk);
443 }
444
445 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
446 {
447         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
448         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
449         int err, rounds = 6 + ctx->key_length / 4;
450         struct skcipher_walk walk;
451
452         err = skcipher_walk_virt(&walk, req, false);
453
454         while (walk.nbytes > 0) {
455                 const u8 *src = walk.src.virt.addr;
456                 unsigned int nbytes = walk.nbytes;
457                 u8 *dst = walk.dst.virt.addr;
458                 u8 buf[AES_BLOCK_SIZE];
459
460                 if (unlikely(nbytes < AES_BLOCK_SIZE))
461                         src = dst = memcpy(buf + sizeof(buf) - nbytes,
462                                            src, nbytes);
463                 else if (nbytes < walk.total)
464                         nbytes &= ~(AES_BLOCK_SIZE - 1);
465
466                 kernel_neon_begin();
467                 aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
468                                 walk.iv);
469                 kernel_neon_end();
470
471                 if (unlikely(nbytes < AES_BLOCK_SIZE))
472                         memcpy(walk.dst.virt.addr,
473                                buf + sizeof(buf) - nbytes, nbytes);
474
475                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
476         }
477
478         return err;
479 }
480
/*
 * skcipher .encrypt handler for xts(aes).
 *
 * Requests shorter than one AES block are rejected with -EINVAL.
 *
 * If the length is not a multiple of the block size AND the first walk step
 * does not cover the whole request, the work is split in two: a sub-request
 * covering everything except the last two blocks (length rounded up to whole
 * blocks) is processed by the main loop, and the remaining full block plus
 * the partial tail is passed to the assembler routine in one final call.
 * In every other case the main loop processes the entire request.
 *
 * Returns 0 on success or the error reported by the skcipher walk.
 */
481 static int __maybe_unused xts_encrypt(struct skcipher_request *req)
482 {
483         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
484         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
485         int err, first, rounds = 6 + ctx->key1.key_length / 4;
            /* bytes beyond the last whole block */
486         int tail = req->cryptlen % AES_BLOCK_SIZE;
487         struct scatterlist sg_src[2], sg_dst[2];
488         struct skcipher_request subreq;
489         struct scatterlist *src, *dst;
490         struct skcipher_walk walk;
491
492         if (req->cryptlen < AES_BLOCK_SIZE)
493                 return -EINVAL;
494
495         err = skcipher_walk_virt(&walk, req, false);
496
497         if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
498                 int xts_blocks = DIV_ROUND_UP(req->cryptlen,
499                                               AES_BLOCK_SIZE) - 2;
500
501                 skcipher_walk_abort(&walk);
502
503                 skcipher_request_set_tfm(&subreq, tfm);
504                 skcipher_request_set_callback(&subreq,
505                                               skcipher_request_flags(req),
506                                               NULL, NULL);
507                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
508                                            xts_blocks * AES_BLOCK_SIZE,
509                                            req->iv);
                    /* from here on 'req' refers to the truncated sub-request */
510                 req = &subreq;
511                 err = skcipher_walk_virt(&walk, req, false);
512         } else {
                    /* no separate tail pass: the loop below covers everything */
513                 tail = 0;
514         }
515
            /* 'first' is 1 for the first assembler call only, 0 afterwards */
516         for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
517                 int nbytes = walk.nbytes;
518
                    /* not the final step: pass whole blocks only */
519                 if (walk.nbytes < walk.total)
520                         nbytes &= ~(AES_BLOCK_SIZE - 1);
521
522                 kernel_neon_begin();
523                 aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
524                                 ctx->key1.key_enc, rounds, nbytes,
525                                 ctx->key2.key_enc, walk.iv, first);
526                 kernel_neon_end();
527                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
528         }
529
530         if (err || likely(!tail))
531                 return err;
532
            /*
             * tail != 0 implies req == &subreq here, so req->cryptlen is the
             * number of bytes the loop above has already processed.
             */
533         dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
534         if (req->dst != req->src)
535                 dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
536
            /* last full block plus the partial tail, in a single step */
537         skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
538                                    req->iv);
539
540         err = skcipher_walk_virt(&walk, &subreq, false);
541         if (err)
542                 return err;
543
            /*
             * 'first' is still 1 here if the loop above never ran (the
             * sub-request was empty), otherwise it is 0.
             */
544         kernel_neon_begin();
545         aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
546                         ctx->key1.key_enc, rounds, walk.nbytes,
547                         ctx->key2.key_enc, walk.iv, first);
548         kernel_neon_end();
549
550         return skcipher_walk_done(&walk, 0);
551 }
552
/*
 * skcipher .decrypt handler for xts(aes).
 *
 * Same structure as xts_encrypt(), using the decryption schedule of key1
 * (key1.key_dec) for the data while still passing the encryption schedule of
 * key2 as rk2.
 *
 * Requests shorter than one AES block are rejected with -EINVAL.  If the
 * length is not a multiple of the block size and the first walk step does not
 * cover the whole request, everything except the last two blocks (length
 * rounded up to whole blocks) is processed by the main loop through a
 * sub-request, and the remaining full block plus partial tail is handled in
 * one final call.  Otherwise the main loop processes the entire request.
 *
 * Returns 0 on success or the error reported by the skcipher walk.
 */
553 static int __maybe_unused xts_decrypt(struct skcipher_request *req)
554 {
555         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
556         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
557         int err, first, rounds = 6 + ctx->key1.key_length / 4;
            /* bytes beyond the last whole block */
558         int tail = req->cryptlen % AES_BLOCK_SIZE;
559         struct scatterlist sg_src[2], sg_dst[2];
560         struct skcipher_request subreq;
561         struct scatterlist *src, *dst;
562         struct skcipher_walk walk;
563
564         if (req->cryptlen < AES_BLOCK_SIZE)
565                 return -EINVAL;
566
567         err = skcipher_walk_virt(&walk, req, false);
568
569         if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
570                 int xts_blocks = DIV_ROUND_UP(req->cryptlen,
571                                               AES_BLOCK_SIZE) - 2;
572
573                 skcipher_walk_abort(&walk);
574
575                 skcipher_request_set_tfm(&subreq, tfm);
576                 skcipher_request_set_callback(&subreq,
577                                               skcipher_request_flags(req),
578                                               NULL, NULL);
579                 skcipher_request_set_crypt(&subreq, req->src, req->dst,
580                                            xts_blocks * AES_BLOCK_SIZE,
581                                            req->iv);
                    /* from here on 'req' refers to the truncated sub-request */
582                 req = &subreq;
583                 err = skcipher_walk_virt(&walk, req, false);
584         } else {
                    /* no separate tail pass: the loop below covers everything */
585                 tail = 0;
586         }
587
            /* 'first' is 1 for the first assembler call only, 0 afterwards */
588         for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
589                 int nbytes = walk.nbytes;
590
                    /* not the final step: pass whole blocks only */
591                 if (walk.nbytes < walk.total)
592                         nbytes &= ~(AES_BLOCK_SIZE - 1);
593
594                 kernel_neon_begin();
595                 aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
596                                 ctx->key1.key_dec, rounds, nbytes,
597                                 ctx->key2.key_enc, walk.iv, first);
598                 kernel_neon_end();
599                 err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
600         }
601
602         if (err || likely(!tail))
603                 return err;
604
            /*
             * tail != 0 implies req == &subreq here, so req->cryptlen is the
             * number of bytes the loop above has already processed.
             */
605         dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
606         if (req->dst != req->src)
607                 dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
608
            /* last full block plus the partial tail, in a single step */
609         skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
610                                    req->iv);
611
612         err = skcipher_walk_virt(&walk, &subreq, false);
613         if (err)
614                 return err;
615
616
            /*
             * 'first' is still 1 here if the loop above never ran (the
             * sub-request was empty), otherwise it is 0.
             */
617         kernel_neon_begin();
618         aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
619                         ctx->key1.key_dec, rounds, walk.nbytes,
620                         ctx->key2.key_enc, walk.iv, first);
621         kernel_neon_end();
622
623         return skcipher_walk_done(&walk, 0);
624 }
625
/*
 * Table of the skcipher algorithms implemented in this file.
 *
 * Driver names carry the MODE suffix ("ce" or "neon") and priorities are
 * derived from PRIO.  The ecb, cbc, ctr and xts entries are compiled in only
 * for the CE build or when CONFIG_CRYPTO_AES_ARM64_BS is not enabled; the
 * cts(cbc(aes)) and essiv(cbc(aes),sha256) entries are always present.
 */
626 static struct skcipher_alg aes_algs[] = { {
627 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
628         .base = {
629                 .cra_name               = "ecb(aes)",
630                 .cra_driver_name        = "ecb-aes-" MODE,
631                 .cra_priority           = PRIO,
632                 .cra_blocksize          = AES_BLOCK_SIZE,
633                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
634                 .cra_module             = THIS_MODULE,
635         },
636         .min_keysize    = AES_MIN_KEY_SIZE,
637         .max_keysize    = AES_MAX_KEY_SIZE,
638         .setkey         = skcipher_aes_setkey,
639         .encrypt        = ecb_encrypt,
640         .decrypt        = ecb_decrypt,
641 }, {
642         .base = {
643                 .cra_name               = "cbc(aes)",
644                 .cra_driver_name        = "cbc-aes-" MODE,
645                 .cra_priority           = PRIO,
646                 .cra_blocksize          = AES_BLOCK_SIZE,
647                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
648                 .cra_module             = THIS_MODULE,
649         },
650         .min_keysize    = AES_MIN_KEY_SIZE,
651         .max_keysize    = AES_MAX_KEY_SIZE,
652         .ivsize         = AES_BLOCK_SIZE,
653         .setkey         = skcipher_aes_setkey,
654         .encrypt        = cbc_encrypt,
655         .decrypt        = cbc_decrypt,
656 }, {
            /*
             * CTR: declared with a block size of 1 and a chunk size of one AES
             * block; the same handler serves both directions.
             */
657         .base = {
658                 .cra_name               = "ctr(aes)",
659                 .cra_driver_name        = "ctr-aes-" MODE,
660                 .cra_priority           = PRIO,
661                 .cra_blocksize          = 1,
662                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
663                 .cra_module             = THIS_MODULE,
664         },
665         .min_keysize    = AES_MIN_KEY_SIZE,
666         .max_keysize    = AES_MAX_KEY_SIZE,
667         .ivsize         = AES_BLOCK_SIZE,
668         .chunksize      = AES_BLOCK_SIZE,
669         .setkey         = skcipher_aes_setkey,
670         .encrypt        = ctr_encrypt,
671         .decrypt        = ctr_encrypt,
672 }, {
            /* XTS: two AES keys, hence doubled key sizes and the larger context */
673         .base = {
674                 .cra_name               = "xts(aes)",
675                 .cra_driver_name        = "xts-aes-" MODE,
676                 .cra_priority           = PRIO,
677                 .cra_blocksize          = AES_BLOCK_SIZE,
678                 .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
679                 .cra_module             = THIS_MODULE,
680         },
681         .min_keysize    = 2 * AES_MIN_KEY_SIZE,
682         .max_keysize    = 2 * AES_MAX_KEY_SIZE,
683         .ivsize         = AES_BLOCK_SIZE,
684         .walksize       = 2 * AES_BLOCK_SIZE,
685         .setkey         = xts_set_key,
686         .encrypt        = xts_encrypt,
687         .decrypt        = xts_decrypt,
688 }, {
689 #endif
690         .base = {
691                 .cra_name               = "cts(cbc(aes))",
692                 .cra_driver_name        = "cts-cbc-aes-" MODE,
693                 .cra_priority           = PRIO,
694                 .cra_blocksize          = AES_BLOCK_SIZE,
695                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
696                 .cra_module             = THIS_MODULE,
697         },
698         .min_keysize    = AES_MIN_KEY_SIZE,
699         .max_keysize    = AES_MAX_KEY_SIZE,
700         .ivsize         = AES_BLOCK_SIZE,
701         .walksize       = 2 * AES_BLOCK_SIZE,
702         .setkey         = skcipher_aes_setkey,
703         .encrypt        = cts_cbc_encrypt,
704         .decrypt        = cts_cbc_decrypt,
705 }, {
            /*
             * ESSIV: the only entry with init/exit hooks (they manage the
             * sha256 transform) and the only one registered at PRIO + 1.
             */
706         .base = {
707                 .cra_name               = "essiv(cbc(aes),sha256)",
708                 .cra_driver_name        = "essiv-cbc-aes-sha256-" MODE,
709                 .cra_priority           = PRIO + 1,
710                 .cra_blocksize          = AES_BLOCK_SIZE,
711                 .cra_ctxsize            = sizeof(struct crypto_aes_essiv_cbc_ctx),
712                 .cra_module             = THIS_MODULE,
713         },
714         .min_keysize    = AES_MIN_KEY_SIZE,
715         .max_keysize    = AES_MAX_KEY_SIZE,
716         .ivsize         = AES_BLOCK_SIZE,
717         .setkey         = essiv_cbc_set_key,
718         .encrypt        = essiv_cbc_encrypt,
719         .decrypt        = essiv_cbc_decrypt,
720         .init           = essiv_cbc_init_tfm,
721         .exit           = essiv_cbc_exit_tfm,
722 } };
723
724 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
725                          unsigned int key_len)
726 {
727         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
728
729         return aes_expandkey(&ctx->key, in_key, key_len);
730 }
731
732 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
733 {
734         u64 a = be64_to_cpu(x->a);
735         u64 b = be64_to_cpu(x->b);
736
737         y->a = cpu_to_be64((a << 1) | (b >> 63));
738         y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
739 }
740
741 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
742                        unsigned int key_len)
743 {
744         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
745         be128 *consts = (be128 *)ctx->consts;
746         int rounds = 6 + key_len / 4;
747         int err;
748
749         err = cbcmac_setkey(tfm, in_key, key_len);
750         if (err)
751                 return err;
752
753         /* encrypt the zero vector */
754         kernel_neon_begin();
755         aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
756                         rounds, 1);
757         kernel_neon_end();
758
759         cmac_gf128_mul_by_x(consts, consts);
760         cmac_gf128_mul_by_x(consts + 1, consts);
761
762         return 0;
763 }
764
765 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
766                        unsigned int key_len)
767 {
768         static u8 const ks[3][AES_BLOCK_SIZE] = {
769                 { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
770                 { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
771                 { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
772         };
773
774         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
775         int rounds = 6 + key_len / 4;
776         u8 key[AES_BLOCK_SIZE];
777         int err;
778
779         err = cbcmac_setkey(tfm, in_key, key_len);
780         if (err)
781                 return err;
782
783         kernel_neon_begin();
784         aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
785         aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
786         kernel_neon_end();
787
788         return cbcmac_setkey(tfm, key, sizeof(key));
789 }
790
791 static int mac_init(struct shash_desc *desc)
792 {
793         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
794
795         memset(ctx->dg, 0, AES_BLOCK_SIZE);
796         ctx->len = 0;
797
798         return 0;
799 }
800
801 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
802                           u8 dg[], int enc_before, int enc_after)
803 {
804         int rounds = 6 + ctx->key_length / 4;
805
806         if (crypto_simd_usable()) {
807                 int rem;
808
809                 do {
810                         kernel_neon_begin();
811                         rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
812                                              dg, enc_before, enc_after);
813                         kernel_neon_end();
814                         in += (blocks - rem) * AES_BLOCK_SIZE;
815                         blocks = rem;
816                         enc_before = 0;
817                 } while (blocks);
818         } else {
819                 if (enc_before)
820                         aes_encrypt(ctx, dg, dg);
821
822                 while (blocks--) {
823                         crypto_xor(dg, in, AES_BLOCK_SIZE);
824                         in += AES_BLOCK_SIZE;
825
826                         if (blocks || enc_after)
827                                 aes_encrypt(ctx, dg, dg);
828                 }
829         }
830 }
831
832 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
833 {
834         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
835         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
836
837         while (len > 0) {
838                 unsigned int l;
839
840                 if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
841                     (ctx->len + len) > AES_BLOCK_SIZE) {
842
843                         int blocks = len / AES_BLOCK_SIZE;
844
845                         len %= AES_BLOCK_SIZE;
846
847                         mac_do_update(&tctx->key, p, blocks, ctx->dg,
848                                       (ctx->len != 0), (len != 0));
849
850                         p += blocks * AES_BLOCK_SIZE;
851
852                         if (!len) {
853                                 ctx->len = AES_BLOCK_SIZE;
854                                 break;
855                         }
856                         ctx->len = 0;
857                 }
858
859                 l = min(len, AES_BLOCK_SIZE - ctx->len);
860
861                 if (l <= AES_BLOCK_SIZE) {
862                         crypto_xor(ctx->dg + ctx->len, p, l);
863                         ctx->len += l;
864                         len -= l;
865                         p += l;
866                 }
867         }
868
869         return 0;
870 }
871
872 static int cbcmac_final(struct shash_desc *desc, u8 *out)
873 {
874         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
875         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
876
877         mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
878
879         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
880
881         return 0;
882 }
883
884 static int cmac_final(struct shash_desc *desc, u8 *out)
885 {
886         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
887         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
888         u8 *consts = tctx->consts;
889
890         if (ctx->len != AES_BLOCK_SIZE) {
891                 ctx->dg[ctx->len] ^= 0x80;
892                 consts += AES_BLOCK_SIZE;
893         }
894
895         mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
896
897         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
898
899         return 0;
900 }
901
902 static struct shash_alg mac_algs[] = { {
903         .base.cra_name          = "cmac(aes)",
904         .base.cra_driver_name   = "cmac-aes-" MODE,
905         .base.cra_priority      = PRIO,
906         .base.cra_blocksize     = AES_BLOCK_SIZE,
907         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
908                                   2 * AES_BLOCK_SIZE,
909         .base.cra_module        = THIS_MODULE,
910
911         .digestsize             = AES_BLOCK_SIZE,
912         .init                   = mac_init,
913         .update                 = mac_update,
914         .final                  = cmac_final,
915         .setkey                 = cmac_setkey,
916         .descsize               = sizeof(struct mac_desc_ctx),
917 }, {
918         .base.cra_name          = "xcbc(aes)",
919         .base.cra_driver_name   = "xcbc-aes-" MODE,
920         .base.cra_priority      = PRIO,
921         .base.cra_blocksize     = AES_BLOCK_SIZE,
922         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
923                                   2 * AES_BLOCK_SIZE,
924         .base.cra_module        = THIS_MODULE,
925
926         .digestsize             = AES_BLOCK_SIZE,
927         .init                   = mac_init,
928         .update                 = mac_update,
929         .final                  = cmac_final,
930         .setkey                 = xcbc_setkey,
931         .descsize               = sizeof(struct mac_desc_ctx),
932 }, {
933         .base.cra_name          = "cbcmac(aes)",
934         .base.cra_driver_name   = "cbcmac-aes-" MODE,
935         .base.cra_priority      = PRIO,
936         .base.cra_blocksize     = 1,
937         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
938         .base.cra_module        = THIS_MODULE,
939
940         .digestsize             = AES_BLOCK_SIZE,
941         .init                   = mac_init,
942         .update                 = mac_update,
943         .final                  = cbcmac_final,
944         .setkey                 = cbcmac_setkey,
945         .descsize               = sizeof(struct mac_desc_ctx),
946 } };
947
948 static void aes_exit(void)
949 {
950         crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
951         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
952 }
953
954 static int __init aes_init(void)
955 {
956         int err;
957
958         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
959         if (err)
960                 return err;
961
962         err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
963         if (err)
964                 goto unregister_ciphers;
965
966         return 0;
967
968 unregister_ciphers:
969         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
970         return err;
971 }
972
/*
 * Hook aes_init()/aes_exit() up as the module entry points.
 *
 * The Crypto Extensions build registers aes_init() through
 * module_cpu_feature_match() with the AES feature (presumably so the module
 * is only initialised on CPUs that have the AES instructions - confirm
 * against <linux/cpufeature.h>). The NEON build uses a plain module_init()
 * and additionally exports some of its NEON routines.
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
#else
module_init(aes_init);
/* NEON implementations exported for use outside this file */
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);