crypto: arm64/aes-xctr - Add accelerated implementation of XCTR
author     Nathan Huckleberry <nhuck@google.com>
           Fri, 20 May 2022 18:14:57 +0000 (18:14 +0000)
committer  Herbert Xu <herbert@gondor.apana.org.au>
           Fri, 10 Jun 2022 08:40:17 +0000 (16:40 +0800)
Add a hardware-accelerated implementation of XCTR for arm64 CPUs with
ARMv8 Crypto Extensions support.  This XCTR implementation is based on
the CTR implementation in aes-modes.S.

More information on XCTR can be found in the HCTR2 paper,
"Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdf
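
For reference, XCTR replaces CTR's big-endian counter increment (which
needs carry handling) with XORing a little-endian block counter into the
IV.  A minimal, unoptimized C sketch of how one keystream block is
derived, assuming a generic single-block AES helper (aes_encrypt_block()
here is hypothetical, not part of this patch):

    #include <stdint.h>
    #include <string.h>

    #define AES_BLOCK_SIZE 16

    /* Hypothetical single-block AES primitive; stands in for the real cipher. */
    void aes_encrypt_block(const void *key, uint8_t dst[AES_BLOCK_SIZE],
                           const uint8_t src[AES_BLOCK_SIZE]);

    /*
     * Keystream block i (1-based) in XCTR, per the HCTR2 paper:
     * E_K(IV ^ le128(i)).  The little-endian counter is XORed into the
     * IV, so unlike CTR there is no carry to propagate across the block.
     */
    static void xctr_keystream_block(const void *key,
                                     const uint8_t iv[AES_BLOCK_SIZE],
                                     uint64_t i,
                                     uint8_t out[AES_BLOCK_SIZE])
    {
            uint8_t ctrblk[AES_BLOCK_SIZE];
            int j;

            memcpy(ctrblk, iv, AES_BLOCK_SIZE);
            for (j = 0; j < 8; j++)         /* low 64 bits of le128(i) */
                    ctrblk[j] ^= (uint8_t)(i >> (8 * j));
            aes_encrypt_block(key, out, ctrblk);
    }

In the patch itself, the byte_ctr argument carries the byte offset
across skcipher_walk steps so the assembly can derive the block counter
from it instead of reading it back from the IV buffer.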

Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Ard Biesheuvel <ardb@kernel.org>
Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
arch/arm64/crypto/Kconfig
arch/arm64/crypto/aes-glue.c
arch/arm64/crypto/aes-modes.S

diff --git a/arch/arm64/crypto/Kconfig b/arch/arm64/crypto/Kconfig
index ac85682..74d5bed 100644
@@ -96,13 +96,13 @@ config CRYPTO_AES_ARM64_CE_CCM
        select CRYPTO_LIB_AES
 
 config CRYPTO_AES_ARM64_CE_BLK
-       tristate "AES in ECB/CBC/CTR/XTS modes using ARMv8 Crypto Extensions"
+       tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using ARMv8 Crypto Extensions"
        depends on KERNEL_MODE_NEON
        select CRYPTO_SKCIPHER
        select CRYPTO_AES_ARM64_CE
 
 config CRYPTO_AES_ARM64_NEON_BLK
-       tristate "AES in ECB/CBC/CTR/XTS modes using NEON instructions"
+       tristate "AES in ECB/CBC/CTR/XTS/XCTR modes using NEON instructions"
        depends on KERNEL_MODE_NEON
        select CRYPTO_SKCIPHER
        select CRYPTO_LIB_AES
diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index 561dd23..b688328 100644
 #define aes_essiv_cbc_encrypt  ce_aes_essiv_cbc_encrypt
 #define aes_essiv_cbc_decrypt  ce_aes_essiv_cbc_decrypt
 #define aes_ctr_encrypt                ce_aes_ctr_encrypt
+#define aes_xctr_encrypt       ce_aes_xctr_encrypt
 #define aes_xts_encrypt                ce_aes_xts_encrypt
 #define aes_xts_decrypt                ce_aes_xts_decrypt
 #define aes_mac_update         ce_aes_mac_update
-MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
+MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
 #else
 #define MODE                   "neon"
 #define PRIO                   200
@@ -50,16 +51,18 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_essiv_cbc_encrypt  neon_aes_essiv_cbc_encrypt
 #define aes_essiv_cbc_decrypt  neon_aes_essiv_cbc_decrypt
 #define aes_ctr_encrypt                neon_aes_ctr_encrypt
+#define aes_xctr_encrypt       neon_aes_xctr_encrypt
 #define aes_xts_encrypt                neon_aes_xts_encrypt
 #define aes_xts_decrypt                neon_aes_xts_decrypt
 #define aes_mac_update         neon_aes_mac_update
-MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
+MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
 #endif
 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
 MODULE_ALIAS_CRYPTO("ecb(aes)");
 MODULE_ALIAS_CRYPTO("cbc(aes)");
 MODULE_ALIAS_CRYPTO("ctr(aes)");
 MODULE_ALIAS_CRYPTO("xts(aes)");
+MODULE_ALIAS_CRYPTO("xctr(aes)");
 #endif
 MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
 MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
@@ -89,6 +92,9 @@ asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                int rounds, int bytes, u8 ctr[]);
 
+asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
+                                int rounds, int bytes, u8 ctr[], int byte_ctr);
+
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
                                int rounds, int bytes, u32 const rk2[], u8 iv[],
                                int first);
@@ -442,6 +448,44 @@ static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
        return err ?: cbc_decrypt_walk(req, &walk);
 }
 
+static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
+{
+       struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+       struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+       int err, rounds = 6 + ctx->key_length / 4;
+       struct skcipher_walk walk;
+       unsigned int byte_ctr = 0;
+
+       err = skcipher_walk_virt(&walk, req, false);
+
+       while (walk.nbytes > 0) {
+               const u8 *src = walk.src.virt.addr;
+               unsigned int nbytes = walk.nbytes;
+               u8 *dst = walk.dst.virt.addr;
+               u8 buf[AES_BLOCK_SIZE];
+
+               if (unlikely(nbytes < AES_BLOCK_SIZE))
+                       src = dst = memcpy(buf + sizeof(buf) - nbytes,
+                                          src, nbytes);
+               else if (nbytes < walk.total)
+                       nbytes &= ~(AES_BLOCK_SIZE - 1);
+
+               kernel_neon_begin();
+               aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
+                                                walk.iv, byte_ctr);
+               kernel_neon_end();
+
+               if (unlikely(nbytes < AES_BLOCK_SIZE))
+                       memcpy(walk.dst.virt.addr,
+                              buf + sizeof(buf) - nbytes, nbytes);
+               byte_ctr += nbytes;
+
+               err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+       }
+
+       return err;
+}
+
 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
 {
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -671,6 +715,22 @@ static struct skcipher_alg aes_algs[] = { {
        .decrypt        = ctr_encrypt,
 }, {
        .base = {
+               .cra_name               = "xctr(aes)",
+               .cra_driver_name        = "xctr-aes-" MODE,
+               .cra_priority           = PRIO,
+               .cra_blocksize          = 1,
+               .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
+               .cra_module             = THIS_MODULE,
+       },
+       .min_keysize    = AES_MIN_KEY_SIZE,
+       .max_keysize    = AES_MAX_KEY_SIZE,
+       .ivsize         = AES_BLOCK_SIZE,
+       .chunksize      = AES_BLOCK_SIZE,
+       .setkey         = skcipher_aes_setkey,
+       .encrypt        = xctr_encrypt,
+       .decrypt        = xctr_encrypt,
+}, {
+       .base = {
                .cra_name               = "xts(aes)",
                .cra_driver_name        = "xts-aes-" MODE,
                .cra_priority           = PRIO,
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index dc35eb0..6c36a3b 100644
@@ -318,79 +318,102 @@ AES_FUNC_END(aes_cbc_cts_decrypt)
        .byte           0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
        .previous
 
-
        /*
-        * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
-        *                 int bytes, u8 ctr[])
+        * This macro generates the code for CTR and XCTR mode.
         */
-
-AES_FUNC_START(aes_ctr_encrypt)
+.macro ctr_encrypt xctr
        stp             x29, x30, [sp, #-16]!
        mov             x29, sp
 
        enc_prepare     w3, x2, x12
        ld1             {vctr.16b}, [x5]
 
-       umov            x12, vctr.d[1]          /* keep swabbed ctr in reg */
-       rev             x12, x12
+       .if \xctr
+               umov            x12, vctr.d[0]
+               lsr             w11, w6, #4
+       .else
+               umov            x12, vctr.d[1] /* keep swabbed ctr in reg */
+               rev             x12, x12
+       .endif
 
-.LctrloopNx:
+.LctrloopNx\xctr:
        add             w7, w4, #15
        sub             w4, w4, #MAX_STRIDE << 4
        lsr             w7, w7, #4
        mov             w8, #MAX_STRIDE
        cmp             w7, w8
        csel            w7, w7, w8, lt
-       adds            x12, x12, x7
 
+       .if \xctr
+               add             x11, x11, x7
+       .else
+               adds            x12, x12, x7
+       .endif
        mov             v0.16b, vctr.16b
        mov             v1.16b, vctr.16b
        mov             v2.16b, vctr.16b
        mov             v3.16b, vctr.16b
 ST5(   mov             v4.16b, vctr.16b                )
-       bcs             0f
-
-       .subsection     1
-       /* apply carry to outgoing counter */
-0:     umov            x8, vctr.d[0]
-       rev             x8, x8
-       add             x8, x8, #1
-       rev             x8, x8
-       ins             vctr.d[0], x8
-
-       /* apply carry to N counter blocks for N := x12 */
-       cbz             x12, 2f
-       adr             x16, 1f
-       sub             x16, x16, x12, lsl #3
-       br              x16
-       bti             c
-       mov             v0.d[0], vctr.d[0]
-       bti             c
-       mov             v1.d[0], vctr.d[0]
-       bti             c
-       mov             v2.d[0], vctr.d[0]
-       bti             c
-       mov             v3.d[0], vctr.d[0]
-ST5(   bti             c                               )
-ST5(   mov             v4.d[0], vctr.d[0]              )
-1:     b               2f
-       .previous
-
-2:     rev             x7, x12
-       ins             vctr.d[1], x7
-       sub             x7, x12, #MAX_STRIDE - 1
-       sub             x8, x12, #MAX_STRIDE - 2
-       sub             x9, x12, #MAX_STRIDE - 3
-       rev             x7, x7
-       rev             x8, x8
-       mov             v1.d[1], x7
-       rev             x9, x9
-ST5(   sub             x10, x12, #MAX_STRIDE - 4       )
-       mov             v2.d[1], x8
-ST5(   rev             x10, x10                        )
-       mov             v3.d[1], x9
-ST5(   mov             v4.d[1], x10                    )
-       tbnz            w4, #31, .Lctrtail
+       .if \xctr
+               sub             x6, x11, #MAX_STRIDE - 1
+               sub             x7, x11, #MAX_STRIDE - 2
+               sub             x8, x11, #MAX_STRIDE - 3
+               sub             x9, x11, #MAX_STRIDE - 4
+ST5(           sub             x10, x11, #MAX_STRIDE - 5       )
+               eor             x6, x6, x12
+               eor             x7, x7, x12
+               eor             x8, x8, x12
+               eor             x9, x9, x12
+ST5(           eor             x10, x10, x12                   )
+               mov             v0.d[0], x6
+               mov             v1.d[0], x7
+               mov             v2.d[0], x8
+               mov             v3.d[0], x9
+ST5(           mov             v4.d[0], x10                    )
+       .else
+               bcs             0f
+               .subsection     1
+               /* apply carry to outgoing counter */
+0:             umov            x8, vctr.d[0]
+               rev             x8, x8
+               add             x8, x8, #1
+               rev             x8, x8
+               ins             vctr.d[0], x8
+
+               /* apply carry to N counter blocks for N := x12 */
+               cbz             x12, 2f
+               adr             x16, 1f
+               sub             x16, x16, x12, lsl #3
+               br              x16
+               bti             c
+               mov             v0.d[0], vctr.d[0]
+               bti             c
+               mov             v1.d[0], vctr.d[0]
+               bti             c
+               mov             v2.d[0], vctr.d[0]
+               bti             c
+               mov             v3.d[0], vctr.d[0]
+ST5(           bti             c                               )
+ST5(           mov             v4.d[0], vctr.d[0]              )
+1:             b               2f
+               .previous
+
+2:             rev             x7, x12
+               ins             vctr.d[1], x7
+               sub             x7, x12, #MAX_STRIDE - 1
+               sub             x8, x12, #MAX_STRIDE - 2
+               sub             x9, x12, #MAX_STRIDE - 3
+               rev             x7, x7
+               rev             x8, x8
+               mov             v1.d[1], x7
+               rev             x9, x9
+ST5(           sub             x10, x12, #MAX_STRIDE - 4       )
+               mov             v2.d[1], x8
+ST5(           rev             x10, x10                        )
+               mov             v3.d[1], x9
+ST5(           mov             v4.d[1], x10                    )
+       .endif
+       tbnz            w4, #31, .Lctrtail\xctr
        ld1             {v5.16b-v7.16b}, [x1], #48
 ST4(   bl              aes_encrypt_block4x             )
 ST5(   bl              aes_encrypt_block5x             )
@@ -403,16 +426,17 @@ ST5(      ld1             {v5.16b-v6.16b}, [x1], #32      )
 ST5(   eor             v4.16b, v6.16b, v4.16b          )
        st1             {v0.16b-v3.16b}, [x0], #64
 ST5(   st1             {v4.16b}, [x0], #16             )
-       cbz             w4, .Lctrout
-       b               .LctrloopNx
+       cbz             w4, .Lctrout\xctr
+       b               .LctrloopNx\xctr
 
-.Lctrout:
-       st1             {vctr.16b}, [x5]        /* return next CTR value */
+.Lctrout\xctr:
+       .if !\xctr
+               st1             {vctr.16b}, [x5] /* return next CTR value */
+       .endif
        ldp             x29, x30, [sp], #16
        ret
 
-.Lctrtail:
-       /* XOR up to MAX_STRIDE * 16 - 1 bytes of in/output with v0 ... v3/v4 */
+.Lctrtail\xctr:
        mov             x16, #16
        ands            x6, x4, #0xf
        csel            x13, x6, x16, ne
@@ -427,7 +451,7 @@ ST5(        csel            x14, x16, xzr, gt               )
 
        adr_l           x12, .Lcts_permute_table
        add             x12, x12, x13
-       ble             .Lctrtail1x
+       ble             .Lctrtail1x\xctr
 
 ST5(   ld1             {v5.16b}, [x1], x14             )
        ld1             {v6.16b}, [x1], x15
@@ -459,9 +483,9 @@ ST5(        st1             {v5.16b}, [x0], x14             )
        add             x13, x13, x0
        st1             {v9.16b}, [x13]         // overlapping stores
        st1             {v8.16b}, [x0]
-       b               .Lctrout
+       b               .Lctrout\xctr
 
-.Lctrtail1x:
+.Lctrtail1x\xctr:
        sub             x7, x6, #16
        csel            x6, x6, x7, eq
        add             x1, x1, x6
@@ -476,9 +500,27 @@ ST5(       mov             v3.16b, v4.16b                  )
        eor             v5.16b, v5.16b, v3.16b
        bif             v5.16b, v6.16b, v11.16b
        st1             {v5.16b}, [x0]
-       b               .Lctrout
+       b               .Lctrout\xctr
+.endm
+
+       /*
+        * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        *                 int bytes, u8 ctr[])
+        */
+
+AES_FUNC_START(aes_ctr_encrypt)
+       ctr_encrypt 0
 AES_FUNC_END(aes_ctr_encrypt)
 
+       /*
+        * aes_xctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
+        *                 int bytes, u8 const iv[], int byte_ctr)
+        */
+
+AES_FUNC_START(aes_xctr_encrypt)
+       ctr_encrypt 1
+AES_FUNC_END(aes_xctr_encrypt)
+
 
        /*
         * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,