ARM: 9286/1: crypto: Implement fused AES-CTR/GHASH version of GCM
authorArd Biesheuvel <ardb@kernel.org>
Mon, 16 Jan 2023 11:01:48 +0000 (12:01 +0100)
committerRussell King (Oracle) <rmk+kernel@armlinux.org.uk>
Wed, 18 Jan 2023 15:04:51 +0000 (15:04 +0000)
On 32-bit ARM, AES in GCM mode takes full advantage of the ARMv8 Crypto
Extensions when available, resulting in a performance of 6-7 cycles per
byte for typical IPsec frames on cores such as Cortex-A53, using the
generic GCM template encapsulating the accelerated AES-CTR and GHASH
implementations.

At such high rates, any time spent copying data or doing other poorly
optimized work in the generic layer hurts disproportionately, and we can
get a significant performance improvement by combining the optimized
AES-CTR and GHASH implementations into a single GCM driver.

On Cortex-A53, this results in a performance improvement of around 75%,
and AES-256-GCM-128 with RFC4106 encapsulation runs in 4 cycles per
byte.

Note that this code takes advantage of the fact that kernel mode NEON is
now supported in softirq context as well, and therefore does not provide
a non-NEON fallback path at all. (AEADs are only callable in process or
softirq context)

Acked-by: Herbert Xu <herbert@gondor.apana.org.au>
Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Russell King (Oracle) <rmk+kernel@armlinux.org.uk>
arch/arm/crypto/Kconfig
arch/arm/crypto/ghash-ce-core.S
arch/arm/crypto/ghash-ce-glue.c

index 7b2b7d0..847b7a0 100644 (file)
@@ -16,8 +16,10 @@ config CRYPTO_CURVE25519_NEON
 config CRYPTO_GHASH_ARM_CE
        tristate "Hash functions: GHASH (PMULL/NEON/ARMv8 Crypto Extensions)"
        depends on KERNEL_MODE_NEON
+       select CRYPTO_AEAD
        select CRYPTO_HASH
        select CRYPTO_CRYPTD
+       select CRYPTO_LIB_AES
        select CRYPTO_LIB_GF128MUL
        help
          GCM GHASH function (NIST SP800-38D)
index 9f51e3f..858c0d6 100644 (file)
@@ -2,7 +2,8 @@
 /*
  * Accelerated GHASH implementation with NEON/ARMv8 vmull.p8/64 instructions.
  *
- * Copyright (C) 2015 - 2017 Linaro Ltd. <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2015 - 2017 Linaro Ltd.
+ * Copyright (C) 2023 Google LLC. <ardb@google.com>
  */
 
 #include <linux/linkage.h>
@@ -44,7 +45,7 @@
        t2q             .req    q7
        t3q             .req    q8
        t4q             .req    q9
-       T2              .req    q9
+       XH2             .req    q9
 
        s1l             .req    d20
        s1h             .req    d21
@@ -80,7 +81,7 @@
 
        XL2             .req    q5
        XM2             .req    q6
-       XH2             .req    q7
+       T2              .req    q7
        T3              .req    q8
 
        XL2_L           .req    d10
        vshr.u64        XL, XL, #1
        .endm
 
-       .macro          ghash_update, pn
+       .macro          ghash_update, pn, enc, aggregate=1, head=1
        vld1.64         {XL}, [r1]
 
+       .if             \head
        /* do the head block first, if supplied */
        ldr             ip, [sp]
        teq             ip, #0
        vld1.64         {T1}, [ip]
        teq             r0, #0
        b               3f
+       .endif
 
 0:     .ifc            \pn, p64
+       .if             \aggregate
        tst             r0, #3                  // skip until #blocks is a
        bne             2f                      // round multiple of 4
 
        vld1.8          {XL2-XM2}, [r2]!
-1:     vld1.8          {T3-T2}, [r2]!
+1:     vld1.8          {T2-T3}, [r2]!
+
+       .ifnb           \enc
+       \enc\()_4x      XL2, XM2, T2, T3
+
+       add             ip, r3, #16
+       vld1.64         {HH}, [ip, :128]!
+       vld1.64         {HH3-HH4}, [ip, :128]
+
+       veor            SHASH2_p64, SHASH_L, SHASH_H
+       veor            SHASH2_H, HH_L, HH_H
+       veor            HH34_L, HH3_L, HH3_H
+       veor            HH34_H, HH4_L, HH4_H
+
+       vmov.i8         MASK, #0xe1
+       vshl.u64        MASK, MASK, #57
+       .endif
+
        vrev64.8        XL2, XL2
        vrev64.8        XM2, XM2
 
        veor            XL2_H, XL2_H, XL_L
        veor            XL, XL, T1
 
-       vrev64.8        T3, T3
-       vrev64.8        T1, T2
+       vrev64.8        T1, T3
+       vrev64.8        T3, T2
 
        vmull.p64       XH, HH4_H, XL_H                 // a1 * b1
        veor            XL2_H, XL2_H, XL_H
 
        b               1b
        .endif
+       .endif
+
+2:     vld1.8          {T1}, [r2]!
+
+       .ifnb           \enc
+       \enc\()_1x      T1
+       veor            SHASH2_p64, SHASH_L, SHASH_H
+       vmov.i8         MASK, #0xe1
+       vshl.u64        MASK, MASK, #57
+       .endif
 
-2:     vld1.64         {T1}, [r2]!
        subs            r0, r0, #1
 
 3:     /* multiply XL by SHASH in GF(2^128) */
-#ifndef CONFIG_CPU_BIG_ENDIAN
        vrev64.8        T1, T1
-#endif
+
        vext.8          IN1, T1, T1, #8
        veor            T1_L, T1_L, XL_H
        veor            XL, XL, IN1
        veor            XL, XL, T1
 
        bne             0b
-
-       vst1.64         {XL}, [r1]
-       bx              lr
        .endm
 
        /*
@@ -316,6 +342,9 @@ ENTRY(pmull_ghash_update_p64)
        vshl.u64        MASK, MASK, #57
 
        ghash_update    p64
+       vst1.64         {XL}, [r1]
+
+       bx              lr
 ENDPROC(pmull_ghash_update_p64)
 
 ENTRY(pmull_ghash_update_p8)
@@ -336,4 +365,331 @@ ENTRY(pmull_ghash_update_p8)
        vmov.i64        k48, #0xffffffffffff
 
        ghash_update    p8
+       vst1.64         {XL}, [r1]
+
+       bx              lr
 ENDPROC(pmull_ghash_update_p8)
+
+       e0              .req    q9
+       e1              .req    q10
+       e2              .req    q11
+       e3              .req    q12
+       e0l             .req    d18
+       e0h             .req    d19
+       e2l             .req    d22
+       e2h             .req    d23
+       e3l             .req    d24
+       e3h             .req    d25
+       ctr             .req    q13
+       ctr0            .req    d26
+       ctr1            .req    d27
+
+       ek0             .req    q14
+       ek1             .req    q15
+
+       .macro          round, rk:req, regs:vararg
+       .irp            r, \regs
+       aese.8          \r, \rk
+       aesmc.8         \r, \r
+       .endr
+       .endm
+
+       .macro          aes_encrypt, rkp, rounds, regs:vararg
+       vld1.8          {ek0-ek1}, [\rkp, :128]!
+       cmp             \rounds, #12
+       blt             .L\@                    // AES-128
+
+       round           ek0, \regs
+       vld1.8          {ek0}, [\rkp, :128]!
+       round           ek1, \regs
+       vld1.8          {ek1}, [\rkp, :128]!
+
+       beq             .L\@                    // AES-192
+
+       round           ek0, \regs
+       vld1.8          {ek0}, [\rkp, :128]!
+       round           ek1, \regs
+       vld1.8          {ek1}, [\rkp, :128]!
+
+.L\@:  .rept           4
+       round           ek0, \regs
+       vld1.8          {ek0}, [\rkp, :128]!
+       round           ek1, \regs
+       vld1.8          {ek1}, [\rkp, :128]!
+       .endr
+
+       round           ek0, \regs
+       vld1.8          {ek0}, [\rkp, :128]
+
+       .irp            r, \regs
+       aese.8          \r, ek1
+       .endr
+       .irp            r, \regs
+       veor            \r, \r, ek0
+       .endr
+       .endm
+
+pmull_aes_encrypt:
+       add             ip, r5, #4
+       vld1.8          {ctr0}, [r5]            // load 12 byte IV
+       vld1.8          {ctr1}, [ip]
+       rev             r8, r7
+       vext.8          ctr1, ctr1, ctr1, #4
+       add             r7, r7, #1
+       vmov.32         ctr1[1], r8
+       vmov            e0, ctr
+
+       add             ip, r3, #64
+       aes_encrypt     ip, r6, e0
+       bx              lr
+ENDPROC(pmull_aes_encrypt)
+
+pmull_aes_encrypt_4x:
+       add             ip, r5, #4
+       vld1.8          {ctr0}, [r5]
+       vld1.8          {ctr1}, [ip]
+       rev             r8, r7
+       vext.8          ctr1, ctr1, ctr1, #4
+       add             r7, r7, #1
+       vmov.32         ctr1[1], r8
+       rev             ip, r7
+       vmov            e0, ctr
+       add             r7, r7, #1
+       vmov.32         ctr1[1], ip
+       rev             r8, r7
+       vmov            e1, ctr
+       add             r7, r7, #1
+       vmov.32         ctr1[1], r8
+       rev             ip, r7
+       vmov            e2, ctr
+       add             r7, r7, #1
+       vmov.32         ctr1[1], ip
+       vmov            e3, ctr
+
+       add             ip, r3, #64
+       aes_encrypt     ip, r6, e0, e1, e2, e3
+       bx              lr
+ENDPROC(pmull_aes_encrypt_4x)
+
+pmull_aes_encrypt_final:
+       add             ip, r5, #4
+       vld1.8          {ctr0}, [r5]
+       vld1.8          {ctr1}, [ip]
+       rev             r8, r7
+       vext.8          ctr1, ctr1, ctr1, #4
+       mov             r7, #1 << 24            // BE #1 for the tag
+       vmov.32         ctr1[1], r8
+       vmov            e0, ctr
+       vmov.32         ctr1[1], r7
+       vmov            e1, ctr
+
+       add             ip, r3, #64
+       aes_encrypt     ip, r6, e0, e1
+       bx              lr
+ENDPROC(pmull_aes_encrypt_final)
+
+       .macro          enc_1x, in0
+       bl              pmull_aes_encrypt
+       veor            \in0, \in0, e0
+       vst1.8          {\in0}, [r4]!
+       .endm
+
+       .macro          dec_1x, in0
+       bl              pmull_aes_encrypt
+       veor            e0, e0, \in0
+       vst1.8          {e0}, [r4]!
+       .endm
+
+       .macro          enc_4x, in0, in1, in2, in3
+       bl              pmull_aes_encrypt_4x
+
+       veor            \in0, \in0, e0
+       veor            \in1, \in1, e1
+       veor            \in2, \in2, e2
+       veor            \in3, \in3, e3
+
+       vst1.8          {\in0-\in1}, [r4]!
+       vst1.8          {\in2-\in3}, [r4]!
+       .endm
+
+       .macro          dec_4x, in0, in1, in2, in3
+       bl              pmull_aes_encrypt_4x
+
+       veor            e0, e0, \in0
+       veor            e1, e1, \in1
+       veor            e2, e2, \in2
+       veor            e3, e3, \in3
+
+       vst1.8          {e0-e1}, [r4]!
+       vst1.8          {e2-e3}, [r4]!
+       .endm
+
+       /*
+        * void pmull_gcm_encrypt(int blocks, u64 dg[], const char *src,
+        *                        struct gcm_key const *k, char *dst,
+        *                        char *iv, int rounds, u32 counter)
+        */
+ENTRY(pmull_gcm_encrypt)
+       push            {r4-r8, lr}
+       ldrd            r4, r5, [sp, #24]
+       ldrd            r6, r7, [sp, #32]
+
+       vld1.64         {SHASH}, [r3]
+
+       ghash_update    p64, enc, head=0
+       vst1.64         {XL}, [r1]
+
+       pop             {r4-r8, pc}
+ENDPROC(pmull_gcm_encrypt)
+
+       /*
+        * void pmull_gcm_decrypt(int blocks, u64 dg[], const char *src,
+        *                        struct gcm_key const *k, char *dst,
+        *                        char *iv, int rounds, u32 counter)
+        */
+ENTRY(pmull_gcm_decrypt)
+       push            {r4-r8, lr}
+       ldrd            r4, r5, [sp, #24]
+       ldrd            r6, r7, [sp, #32]
+
+       vld1.64         {SHASH}, [r3]
+
+       ghash_update    p64, dec, head=0
+       vst1.64         {XL}, [r1]
+
+       pop             {r4-r8, pc}
+ENDPROC(pmull_gcm_decrypt)
+
+       /*
+        * void pmull_gcm_enc_final(int bytes, u64 dg[], char *tag,
+        *                          struct gcm_key const *k, char *head,
+        *                          char *iv, int rounds, u32 counter)
+        */
+ENTRY(pmull_gcm_enc_final)
+       push            {r4-r8, lr}
+       ldrd            r4, r5, [sp, #24]
+       ldrd            r6, r7, [sp, #32]
+
+       bl              pmull_aes_encrypt_final
+
+       cmp             r0, #0
+       beq             .Lenc_final
+
+       mov_l           ip, .Lpermute
+       sub             r4, r4, #16
+       add             r8, ip, r0
+       add             ip, ip, #32
+       add             r4, r4, r0
+       sub             ip, ip, r0
+
+       vld1.8          {e3}, [r8]              // permute vector for key stream
+       vld1.8          {e2}, [ip]              // permute vector for ghash input
+
+       vtbl.8          e3l, {e0}, e3l
+       vtbl.8          e3h, {e0}, e3h
+
+       vld1.8          {e0}, [r4]              // encrypt tail block
+       veor            e0, e0, e3
+       vst1.8          {e0}, [r4]
+
+       vtbl.8          T1_L, {e0}, e2l
+       vtbl.8          T1_H, {e0}, e2h
+
+       vld1.64         {XL}, [r1]
+.Lenc_final:
+       vld1.64         {SHASH}, [r3, :128]
+       vmov.i8         MASK, #0xe1
+       veor            SHASH2_p64, SHASH_L, SHASH_H
+       vshl.u64        MASK, MASK, #57
+       mov             r0, #1
+       bne             3f                      // process head block first
+       ghash_update    p64, aggregate=0, head=0
+
+       vrev64.8        XL, XL
+       vext.8          XL, XL, XL, #8
+       veor            XL, XL, e1
+
+       sub             r2, r2, #16             // rewind src pointer
+       vst1.8          {XL}, [r2]              // store tag
+
+       pop             {r4-r8, pc}
+ENDPROC(pmull_gcm_enc_final)
+
+       /*
+        * int pmull_gcm_dec_final(int bytes, u64 dg[], char *tag,
+        *                         struct gcm_key const *k, char *head,
+        *                         char *iv, int rounds, u32 counter,
+        *                         const char *otag, int authsize)
+        */
+ENTRY(pmull_gcm_dec_final)
+       push            {r4-r8, lr}
+       ldrd            r4, r5, [sp, #24]
+       ldrd            r6, r7, [sp, #32]
+
+       bl              pmull_aes_encrypt_final
+
+       cmp             r0, #0
+       beq             .Ldec_final
+
+       mov_l           ip, .Lpermute
+       sub             r4, r4, #16
+       add             r8, ip, r0
+       add             ip, ip, #32
+       add             r4, r4, r0
+       sub             ip, ip, r0
+
+       vld1.8          {e3}, [r8]              // permute vector for key stream
+       vld1.8          {e2}, [ip]              // permute vector for ghash input
+
+       vtbl.8          e3l, {e0}, e3l
+       vtbl.8          e3h, {e0}, e3h
+
+       vld1.8          {e0}, [r4]
+
+       vtbl.8          T1_L, {e0}, e2l
+       vtbl.8          T1_H, {e0}, e2h
+
+       veor            e0, e0, e3
+       vst1.8          {e0}, [r4]
+
+       vld1.64         {XL}, [r1]
+.Ldec_final:
+       vld1.64         {SHASH}, [r3]
+       vmov.i8         MASK, #0xe1
+       veor            SHASH2_p64, SHASH_L, SHASH_H
+       vshl.u64        MASK, MASK, #57
+       mov             r0, #1
+       bne             3f                      // process head block first
+       ghash_update    p64, aggregate=0, head=0
+
+       vrev64.8        XL, XL
+       vext.8          XL, XL, XL, #8
+       veor            XL, XL, e1
+
+       mov_l           ip, .Lpermute
+       ldrd            r2, r3, [sp, #40]       // otag and authsize
+       vld1.8          {T1}, [r2]
+       add             ip, ip, r3
+       vceq.i8         T1, T1, XL              // compare tags
+       vmvn            T1, T1                  // 0 for eq, -1 for ne
+
+       vld1.8          {e0}, [ip]
+       vtbl.8          XL_L, {T1}, e0l         // keep authsize bytes only
+       vtbl.8          XL_H, {T1}, e0h
+
+       vpmin.s8        XL_L, XL_L, XL_H        // take the minimum s8 across the vector
+       vpmin.s8        XL_L, XL_L, XL_L
+       vmov.32         r0, XL_L[0]             // fail if != 0x0
+
+       pop             {r4-r8, pc}
+ENDPROC(pmull_gcm_dec_final)
+
+       .section        ".rodata", "a", %progbits
+       .align          5
+.Lpermute:
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte           0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
+       .byte           0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+       .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
index f13401f..3ddf05b 100644 (file)
@@ -2,36 +2,53 @@
 /*
  * Accelerated GHASH implementation with ARMv8 vmull.p64 instructions.
  *
- * Copyright (C) 2015 - 2018 Linaro Ltd. <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2015 - 2018 Linaro Ltd.
+ * Copyright (C) 2023 Google LLC.
  */
 
 #include <asm/hwcap.h>
 #include <asm/neon.h>
 #include <asm/simd.h>
 #include <asm/unaligned.h>
+#include <crypto/aes.h>
+#include <crypto/gcm.h>
 #include <crypto/b128ops.h>
 #include <crypto/cryptd.h>
+#include <crypto/internal/aead.h>
 #include <crypto/internal/hash.h>
 #include <crypto/internal/simd.h>
+#include <crypto/internal/skcipher.h>
 #include <crypto/gf128mul.h>
+#include <crypto/scatterwalk.h>
 #include <linux/cpufeature.h>
 #include <linux/crypto.h>
 #include <linux/jump_label.h>
 #include <linux/module.h>
 
 MODULE_DESCRIPTION("GHASH hash function using ARMv8 Crypto Extensions");
-MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
-MODULE_LICENSE("GPL v2");
+MODULE_AUTHOR("Ard Biesheuvel <ardb@kernel.org>");
+MODULE_LICENSE("GPL");
 MODULE_ALIAS_CRYPTO("ghash");
+MODULE_ALIAS_CRYPTO("gcm(aes)");
+MODULE_ALIAS_CRYPTO("rfc4106(gcm(aes))");
 
 #define GHASH_BLOCK_SIZE       16
 #define GHASH_DIGEST_SIZE      16
 
+#define RFC4106_NONCE_SIZE     4
+
 struct ghash_key {
        be128   k;
        u64     h[][2];
 };
 
+struct gcm_key {
+       u64     h[4][2];
+       u32     rk[AES_MAX_KEYLENGTH_U32];
+       int     rounds;
+       u8      nonce[];        // for RFC4106 nonce
+};
+
 struct ghash_desc_ctx {
        u64 digest[GHASH_DIGEST_SIZE/sizeof(u64)];
        u8 buf[GHASH_BLOCK_SIZE];
@@ -344,6 +361,393 @@ static struct ahash_alg ghash_async_alg = {
        },
 };
 
+
+void pmull_gcm_encrypt(int blocks, u64 dg[], const char *src,
+                      struct gcm_key const *k, char *dst,
+                      const char *iv, int rounds, u32 counter);
+
+void pmull_gcm_enc_final(int blocks, u64 dg[], char *tag,
+                        struct gcm_key const *k, char *head,
+                        const char *iv, int rounds, u32 counter);
+
+void pmull_gcm_decrypt(int bytes, u64 dg[], const char *src,
+                      struct gcm_key const *k, char *dst,
+                      const char *iv, int rounds, u32 counter);
+
+int pmull_gcm_dec_final(int bytes, u64 dg[], char *tag,
+                       struct gcm_key const *k, char *head,
+                       const char *iv, int rounds, u32 counter,
+                       const char *otag, int authsize);
+
+static int gcm_aes_setkey(struct crypto_aead *tfm, const u8 *inkey,
+                         unsigned int keylen)
+{
+       struct gcm_key *ctx = crypto_aead_ctx(tfm);
+       struct crypto_aes_ctx aes_ctx;
+       be128 h, k;
+       int ret;
+
+       ret = aes_expandkey(&aes_ctx, inkey, keylen);
+       if (ret)
+               return -EINVAL;
+
+       aes_encrypt(&aes_ctx, (u8 *)&k, (u8[AES_BLOCK_SIZE]){});
+
+       memcpy(ctx->rk, aes_ctx.key_enc, sizeof(ctx->rk));
+       ctx->rounds = 6 + keylen / 4;
+
+       memzero_explicit(&aes_ctx, sizeof(aes_ctx));
+
+       ghash_reflect(ctx->h[0], &k);
+
+       h = k;
+       gf128mul_lle(&h, &k);
+       ghash_reflect(ctx->h[1], &h);
+
+       gf128mul_lle(&h, &k);
+       ghash_reflect(ctx->h[2], &h);
+
+       gf128mul_lle(&h, &k);
+       ghash_reflect(ctx->h[3], &h);
+
+       return 0;
+}
+
+static int gcm_aes_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
+{
+       return crypto_gcm_check_authsize(authsize);
+}
+
+static void gcm_update_mac(u64 dg[], const u8 *src, int count, u8 buf[],
+                          int *buf_count, struct gcm_key *ctx)
+{
+       if (*buf_count > 0) {
+               int buf_added = min(count, GHASH_BLOCK_SIZE - *buf_count);
+
+               memcpy(&buf[*buf_count], src, buf_added);
+
+               *buf_count += buf_added;
+               src += buf_added;
+               count -= buf_added;
+       }
+
+       if (count >= GHASH_BLOCK_SIZE || *buf_count == GHASH_BLOCK_SIZE) {
+               int blocks = count / GHASH_BLOCK_SIZE;
+
+               pmull_ghash_update_p64(blocks, dg, src, ctx->h,
+                                      *buf_count ? buf : NULL);
+
+               src += blocks * GHASH_BLOCK_SIZE;
+               count %= GHASH_BLOCK_SIZE;
+               *buf_count = 0;
+       }
+
+       if (count > 0) {
+               memcpy(buf, src, count);
+               *buf_count = count;
+       }
+}
+
+static void gcm_calculate_auth_mac(struct aead_request *req, u64 dg[], u32 len)
+{
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+       struct gcm_key *ctx = crypto_aead_ctx(aead);
+       u8 buf[GHASH_BLOCK_SIZE];
+       struct scatter_walk walk;
+       int buf_count = 0;
+
+       scatterwalk_start(&walk, req->src);
+
+       do {
+               u32 n = scatterwalk_clamp(&walk, len);
+               u8 *p;
+
+               if (!n) {
+                       scatterwalk_start(&walk, sg_next(walk.sg));
+                       n = scatterwalk_clamp(&walk, len);
+               }
+
+               p = scatterwalk_map(&walk);
+               gcm_update_mac(dg, p, n, buf, &buf_count, ctx);
+               scatterwalk_unmap(p);
+
+               if (unlikely(len / SZ_4K > (len - n) / SZ_4K)) {
+                       kernel_neon_end();
+                       kernel_neon_begin();
+               }
+
+               len -= n;
+               scatterwalk_advance(&walk, n);
+               scatterwalk_done(&walk, 0, len);
+       } while (len);
+
+       if (buf_count) {
+               memset(&buf[buf_count], 0, GHASH_BLOCK_SIZE - buf_count);
+               pmull_ghash_update_p64(1, dg, buf, ctx->h, NULL);
+       }
+}
+
+static int gcm_encrypt(struct aead_request *req, const u8 *iv, u32 assoclen)
+{
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+       struct gcm_key *ctx = crypto_aead_ctx(aead);
+       struct skcipher_walk walk;
+       u8 buf[AES_BLOCK_SIZE];
+       u32 counter = 2;
+       u64 dg[2] = {};
+       be128 lengths;
+       const u8 *src;
+       u8 *tag, *dst;
+       int tail, err;
+
+       if (WARN_ON_ONCE(!may_use_simd()))
+               return -EBUSY;
+
+       err = skcipher_walk_aead_encrypt(&walk, req, false);
+
+       kernel_neon_begin();
+
+       if (assoclen)
+               gcm_calculate_auth_mac(req, dg, assoclen);
+
+       src = walk.src.virt.addr;
+       dst = walk.dst.virt.addr;
+
+       while (walk.nbytes >= AES_BLOCK_SIZE) {
+               int nblocks = walk.nbytes / AES_BLOCK_SIZE;
+
+               pmull_gcm_encrypt(nblocks, dg, src, ctx, dst, iv,
+                                 ctx->rounds, counter);
+               counter += nblocks;
+
+               if (walk.nbytes == walk.total) {
+                       src += nblocks * AES_BLOCK_SIZE;
+                       dst += nblocks * AES_BLOCK_SIZE;
+                       break;
+               }
+
+               kernel_neon_end();
+
+               err = skcipher_walk_done(&walk,
+                                        walk.nbytes % AES_BLOCK_SIZE);
+               if (err)
+                       return err;
+
+               src = walk.src.virt.addr;
+               dst = walk.dst.virt.addr;
+
+               kernel_neon_begin();
+       }
+
+
+       lengths.a = cpu_to_be64(assoclen * 8);
+       lengths.b = cpu_to_be64(req->cryptlen * 8);
+
+       tag = (u8 *)&lengths;
+       tail = walk.nbytes % AES_BLOCK_SIZE;
+
+       /*
+        * Bounce via a buffer unless we are encrypting in place and src/dst
+        * are not pointing to the start of the walk buffer. In that case, we
+        * can do a NEON load/xor/store sequence in place as long as we move
+        * the plain/ciphertext and keystream to the start of the register. If
+        * not, do a memcpy() to the end of the buffer so we can reuse the same
+        * logic.
+        */
+       if (unlikely(tail && (tail == walk.nbytes || src != dst)))
+               src = memcpy(buf + sizeof(buf) - tail, src, tail);
+
+       pmull_gcm_enc_final(tail, dg, tag, ctx, (u8 *)src, iv,
+                           ctx->rounds, counter);
+       kernel_neon_end();
+
+       if (unlikely(tail && src != dst))
+               memcpy(dst, src, tail);
+
+       if (walk.nbytes) {
+               err = skcipher_walk_done(&walk, 0);
+               if (err)
+                       return err;
+       }
+
+       /* copy authtag to end of dst */
+       scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen,
+                                crypto_aead_authsize(aead), 1);
+
+       return 0;
+}
+
+static int gcm_decrypt(struct aead_request *req, const u8 *iv, u32 assoclen)
+{
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+       struct gcm_key *ctx = crypto_aead_ctx(aead);
+       int authsize = crypto_aead_authsize(aead);
+       struct skcipher_walk walk;
+       u8 otag[AES_BLOCK_SIZE];
+       u8 buf[AES_BLOCK_SIZE];
+       u32 counter = 2;
+       u64 dg[2] = {};
+       be128 lengths;
+       const u8 *src;
+       u8 *tag, *dst;
+       int tail, err, ret;
+
+       if (WARN_ON_ONCE(!may_use_simd()))
+               return -EBUSY;
+
+       scatterwalk_map_and_copy(otag, req->src,
+                                req->assoclen + req->cryptlen - authsize,
+                                authsize, 0);
+
+       err = skcipher_walk_aead_decrypt(&walk, req, false);
+
+       kernel_neon_begin();
+
+       if (assoclen)
+               gcm_calculate_auth_mac(req, dg, assoclen);
+
+       src = walk.src.virt.addr;
+       dst = walk.dst.virt.addr;
+
+       while (walk.nbytes >= AES_BLOCK_SIZE) {
+               int nblocks = walk.nbytes / AES_BLOCK_SIZE;
+
+               pmull_gcm_decrypt(nblocks, dg, src, ctx, dst, iv,
+                                 ctx->rounds, counter);
+               counter += nblocks;
+
+               if (walk.nbytes == walk.total) {
+                       src += nblocks * AES_BLOCK_SIZE;
+                       dst += nblocks * AES_BLOCK_SIZE;
+                       break;
+               }
+
+               kernel_neon_end();
+
+               err = skcipher_walk_done(&walk,
+                                        walk.nbytes % AES_BLOCK_SIZE);
+               if (err)
+                       return err;
+
+               src = walk.src.virt.addr;
+               dst = walk.dst.virt.addr;
+
+               kernel_neon_begin();
+       }
+
+       lengths.a = cpu_to_be64(assoclen * 8);
+       lengths.b = cpu_to_be64((req->cryptlen - authsize) * 8);
+
+       tag = (u8 *)&lengths;
+       tail = walk.nbytes % AES_BLOCK_SIZE;
+
+       if (unlikely(tail && (tail == walk.nbytes || src != dst)))
+               src = memcpy(buf + sizeof(buf) - tail, src, tail);
+
+       ret = pmull_gcm_dec_final(tail, dg, tag, ctx, (u8 *)src, iv,
+                                 ctx->rounds, counter, otag, authsize);
+       kernel_neon_end();
+
+       if (unlikely(tail && src != dst))
+               memcpy(dst, src, tail);
+
+       if (walk.nbytes) {
+               err = skcipher_walk_done(&walk, 0);
+               if (err)
+                       return err;
+       }
+
+       return ret ? -EBADMSG : 0;
+}
+
+static int gcm_aes_encrypt(struct aead_request *req)
+{
+       return gcm_encrypt(req, req->iv, req->assoclen);
+}
+
+static int gcm_aes_decrypt(struct aead_request *req)
+{
+       return gcm_decrypt(req, req->iv, req->assoclen);
+}
+
+static int rfc4106_setkey(struct crypto_aead *tfm, const u8 *inkey,
+                         unsigned int keylen)
+{
+       struct gcm_key *ctx = crypto_aead_ctx(tfm);
+       int err;
+
+       keylen -= RFC4106_NONCE_SIZE;
+       err = gcm_aes_setkey(tfm, inkey, keylen);
+       if (err)
+               return err;
+
+       memcpy(ctx->nonce, inkey + keylen, RFC4106_NONCE_SIZE);
+       return 0;
+}
+
+static int rfc4106_setauthsize(struct crypto_aead *tfm, unsigned int authsize)
+{
+       return crypto_rfc4106_check_authsize(authsize);
+}
+
+static int rfc4106_encrypt(struct aead_request *req)
+{
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+       struct gcm_key *ctx = crypto_aead_ctx(aead);
+       u8 iv[GCM_AES_IV_SIZE];
+
+       memcpy(iv, ctx->nonce, RFC4106_NONCE_SIZE);
+       memcpy(iv + RFC4106_NONCE_SIZE, req->iv, GCM_RFC4106_IV_SIZE);
+
+       return crypto_ipsec_check_assoclen(req->assoclen) ?:
+              gcm_encrypt(req, iv, req->assoclen - GCM_RFC4106_IV_SIZE);
+}
+
+static int rfc4106_decrypt(struct aead_request *req)
+{
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+       struct gcm_key *ctx = crypto_aead_ctx(aead);
+       u8 iv[GCM_AES_IV_SIZE];
+
+       memcpy(iv, ctx->nonce, RFC4106_NONCE_SIZE);
+       memcpy(iv + RFC4106_NONCE_SIZE, req->iv, GCM_RFC4106_IV_SIZE);
+
+       return crypto_ipsec_check_assoclen(req->assoclen) ?:
+              gcm_decrypt(req, iv, req->assoclen - GCM_RFC4106_IV_SIZE);
+}
+
+static struct aead_alg gcm_aes_algs[] = {{
+       .ivsize                 = GCM_AES_IV_SIZE,
+       .chunksize              = AES_BLOCK_SIZE,
+       .maxauthsize            = AES_BLOCK_SIZE,
+       .setkey                 = gcm_aes_setkey,
+       .setauthsize            = gcm_aes_setauthsize,
+       .encrypt                = gcm_aes_encrypt,
+       .decrypt                = gcm_aes_decrypt,
+
+       .base.cra_name          = "gcm(aes)",
+       .base.cra_driver_name   = "gcm-aes-ce",
+       .base.cra_priority      = 400,
+       .base.cra_blocksize     = 1,
+       .base.cra_ctxsize       = sizeof(struct gcm_key),
+       .base.cra_module        = THIS_MODULE,
+}, {
+       .ivsize                 = GCM_RFC4106_IV_SIZE,
+       .chunksize              = AES_BLOCK_SIZE,
+       .maxauthsize            = AES_BLOCK_SIZE,
+       .setkey                 = rfc4106_setkey,
+       .setauthsize            = rfc4106_setauthsize,
+       .encrypt                = rfc4106_encrypt,
+       .decrypt                = rfc4106_decrypt,
+
+       .base.cra_name          = "rfc4106(gcm(aes))",
+       .base.cra_driver_name   = "rfc4106-gcm-aes-ce",
+       .base.cra_priority      = 400,
+       .base.cra_blocksize     = 1,
+       .base.cra_ctxsize       = sizeof(struct gcm_key) + RFC4106_NONCE_SIZE,
+       .base.cra_module        = THIS_MODULE,
+}};
+
 static int __init ghash_ce_mod_init(void)
 {
        int err;
@@ -352,13 +756,17 @@ static int __init ghash_ce_mod_init(void)
                return -ENODEV;
 
        if (elf_hwcap2 & HWCAP2_PMULL) {
+               err = crypto_register_aeads(gcm_aes_algs,
+                                           ARRAY_SIZE(gcm_aes_algs));
+               if (err)
+                       return err;
                ghash_alg.base.cra_ctxsize += 3 * sizeof(u64[2]);
                static_branch_enable(&use_p64);
        }
 
        err = crypto_register_shash(&ghash_alg);
        if (err)
-               return err;
+               goto err_aead;
        err = crypto_register_ahash(&ghash_async_alg);
        if (err)
                goto err_shash;
@@ -367,6 +775,10 @@ static int __init ghash_ce_mod_init(void)
 
 err_shash:
        crypto_unregister_shash(&ghash_alg);
+err_aead:
+       if (elf_hwcap2 & HWCAP2_PMULL)
+               crypto_unregister_aeads(gcm_aes_algs,
+                                       ARRAY_SIZE(gcm_aes_algs));
        return err;
 }
 
@@ -374,6 +786,9 @@ static void __exit ghash_ce_mod_exit(void)
 {
        crypto_unregister_ahash(&ghash_async_alg);
        crypto_unregister_shash(&ghash_alg);
+       if (elf_hwcap2 & HWCAP2_PMULL)
+               crypto_unregister_aeads(gcm_aes_algs,
+                                       ARRAY_SIZE(gcm_aes_algs));
 }
 
 module_init(ghash_ce_mod_init);