CR_1737: crypto: starfive: Fix AEAD tag generation and verification
author jiajie.ho <jiajie.ho@starfivetech.com>
Thu, 22 Sep 2022 04:31:49 +0000 (12:31 +0800)
committer jiajie.ho <jiajie.ho@starfivetech.com>
Thu, 22 Sep 2022 04:31:49 +0000 (12:31 +0800)
Add support for variable tag lengths in AES GCM and CCM modes, and for
tag verification of text that is not aligned to the AES block size;
decryption of such non-aligned text is delegated to a software fallback.
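
For illustration, a minimal in-kernel sketch (not part of this change;
helper name hypothetical) of exercising ccm(aes) with one of the
now-supported tag lengths:

	#include <crypto/aead.h>
	#include <linux/err.h>

	static int ccm_short_tag_demo(const u8 *key, unsigned int keylen)
	{
		struct crypto_aead *tfm;
		int ret;

		tfm = crypto_alloc_aead("ccm(aes)", 0, 0);
		if (IS_ERR(tfm))
			return PTR_ERR(tfm);

		/* CCM accepts even tag lengths from 4 to 16 bytes; pick 8. */
		ret = crypto_aead_setauthsize(tfm, 8);
		if (!ret)
			ret = crypto_aead_setkey(tfm, key, keylen);

		crypto_free_aead(tfm);
		return ret;
	}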

Signed-off-by: jiajie.ho <jiajie.ho@starfivetech.com>
drivers/crypto/starfive/jh7110/jh7110-aes.c
drivers/crypto/starfive/jh7110/jh7110-pka.c
drivers/crypto/starfive/jh7110/jh7110-sha.c
drivers/crypto/starfive/jh7110/jh7110-str.h

index 1d8fa7b..163e0dd 100755 (executable)
@@ -55,8 +55,6 @@
 /* Misc */
 #define AES_BLOCK_32                           (AES_BLOCK_SIZE / sizeof(u32))
 #define GCM_CTR_INIT                           1
-#define _walked_in                             (cryp->in_walk.offset - cryp->in_sg->offset)
-#define _walked_out                            (cryp->out_walk.offset - cryp->out_sg->offset)
 #define CRYP_AUTOSUSPEND_DELAY                 50
 
 static inline int jh7110_aes_wait_busy(struct jh7110_sec_ctx *ctx)
@@ -832,7 +830,7 @@ static int jh7110_cryp_write_out_cpu(struct jh7110_sec_ctx *ctx)
        struct jh7110_sec_dev *sdev = ctx->sdev;
        struct jh7110_sec_request_ctx *rctx = ctx->rctx;
        unsigned int  *buffer, *out;
-       int total_len, mlen, loop, count;
+       int total_len, loop, count;
 
        total_len = rctx->bufcnt;
        buffer = (unsigned int *)sdev->aes_data;
@@ -911,8 +909,6 @@ static int jh7110_cryp_write_data(struct jh7110_sec_ctx *ctx)
                        rctx->bufcnt = rctx->bufcnt - rctx->authsize;
 
                if (rctx->bufcnt) {
-                       memcpy((void *)rctx->msg_end, (void *)sdev->aes_data + rctx->bufcnt - JH7110_AES_IV_LEN,
-                              JH7110_AES_IV_LEN);
                        if (sdev->use_dma)
                                ret = jh7110_cryp_write_out_dma(ctx);
                        else
@@ -948,6 +944,7 @@ static int jh7110_cryp_gcm_write_aad(struct jh7110_sec_ctx *ctx)
                buffer++;
                jh7110_sec_write(sdev, JH7110_AES_NONCE3, *buffer);
                buffer++;
+               udelay(2);
        }
 
        if (jh7110_aes_wait_gcmdone(ctx))
@@ -1023,8 +1020,6 @@ static int jh7110_cryp_xcm_write_data(struct jh7110_sec_ctx *ctx)
        struct jh7110_sec_dev *sdev = ctx->sdev;
        struct jh7110_sec_request_ctx *rctx = ctx->rctx;
        size_t data_len, total, count, data_buf_len, offset, auths;
-       unsigned int *out;
-       int loop;
        int ret;
        bool fragmented = false;
 
@@ -1105,11 +1100,6 @@ static int jh7110_cryp_xcm_write_data(struct jh7110_sec_ctx *ctx)
                total += data_len;
 
                if (rctx->bufcnt) {
-                       memcpy((void *)rctx->msg_end, (void *)sdev->aes_data + rctx->bufcnt - JH7110_AES_IV_LEN,
-                              JH7110_AES_IV_LEN);
-                       out = (unsigned int *)sdev->aes_data;
-                       for (loop = 0; loop < rctx->bufcnt / 4; loop++)
-                               dev_dbg(sdev->dev, "aes_data[%d] = %x\n", loop, out[loop]);
                        if (sdev->use_dma)
                                ret = jh7110_cryp_write_out_dma(ctx);
                        else
@@ -1205,6 +1195,8 @@ static int jh7110_cryp_prepare_aead_req(struct crypto_engine *engine,
 static int jh7110_cryp_aes_aead_init(struct crypto_aead *tfm)
 {
        struct jh7110_sec_ctx *ctx = crypto_aead_ctx(tfm);
+       struct crypto_tfm *aead = crypto_aead_tfm(tfm);
+       struct crypto_alg *alg = aead->__crt_alg;
 
        ctx->sdev = jh7110_sec_find_dev(ctx);
 
@@ -1213,6 +1205,18 @@ static int jh7110_cryp_aes_aead_init(struct crypto_aead *tfm)
 
        crypto_aead_set_reqsize(tfm, sizeof(struct jh7110_sec_request_ctx));
 
+       if (alg->cra_flags & CRYPTO_ALG_NEED_FALLBACK) {
+               ctx->fallback.aead =
+                       crypto_alloc_aead(alg->cra_name, 0,
+                                       CRYPTO_ALG_ASYNC |
+                                       CRYPTO_ALG_NEED_FALLBACK);
+               if (IS_ERR(ctx->fallback.aead)) {
+                       pr_err("%s() failed to allocate fallback for %s\n",
+                                       __func__, alg->cra_name);
+                       return PTR_ERR(ctx->fallback.aead);
+               }
+       }
+
        ctx->enginectx.op.do_one_request = jh7110_cryp_aead_one_req;
        ctx->enginectx.op.prepare_request = jh7110_cryp_prepare_aead_req;
        ctx->enginectx.op.unprepare_request = NULL;
@@ -1224,6 +1228,11 @@ static void jh7110_cryp_aes_aead_exit(struct crypto_aead *tfm)
 {
        struct jh7110_sec_ctx *ctx = crypto_aead_ctx(tfm);
 
+       if (ctx->fallback.aead) {
+               crypto_free_aead(ctx->fallback.aead);
+               ctx->fallback.aead = NULL;
+       }
+
        ctx->enginectx.op.do_one_request = NULL;
        ctx->enginectx.op.prepare_request = NULL;
        ctx->enginectx.op.unprepare_request = NULL;
@@ -1245,6 +1254,22 @@ static int jh7110_cryp_crypt(struct skcipher_request *req, unsigned long flags)
        return crypto_transfer_skcipher_request_to_engine(sdev->engine, req);
 }
 
+static int aead_do_fallback(struct aead_request *req)
+{
+       struct aead_request *subreq = aead_request_ctx(req);
+       struct crypto_aead *aead = crypto_aead_reqtfm(req);
+       struct jh7110_sec_ctx *ctx = crypto_aead_ctx(aead);
+
+       aead_request_set_tfm(subreq, ctx->fallback.aead);
+       aead_request_set_callback(subreq, req->base.flags,
+                       req->base.complete, req->base.data);
+       aead_request_set_crypt(subreq, req->src,
+                       req->dst, req->cryptlen, req->iv);
+       aead_request_set_ad(subreq, req->assoclen);
+
+       return crypto_aead_decrypt(subreq);
+}
+
 static int jh7110_cryp_aead_crypt(struct aead_request *req, unsigned long flags)
 {
        struct jh7110_sec_ctx *ctx = crypto_aead_ctx(crypto_aead_reqtfm(req));
@@ -1257,6 +1282,12 @@ static int jh7110_cryp_aead_crypt(struct aead_request *req, unsigned long flags)
        rctx->flags = flags;
        rctx->req_type = JH7110_AEAD_REQ;
 
+       /* HW engine cannot perform tag verification on ciphertext that
+        * is not block-size aligned, so use the fallback algorithm instead
+        */
+       if (ctx->fallback.aead && is_decrypt(rctx))
+               return aead_do_fallback(req);
+
        return crypto_transfer_aead_request_to_engine(sdev->engine, req);
 }
 
@@ -1290,6 +1321,7 @@ static int jh7110_cryp_aes_aead_setkey(struct crypto_aead *tfm, const u8 *key,
                                      unsigned int keylen)
 {
        struct jh7110_sec_ctx *ctx = crypto_aead_ctx(tfm);
+       int ret = 0;
 
        if (keylen != AES_KEYSIZE_128 && keylen != AES_KEYSIZE_192 &&
            keylen != AES_KEYSIZE_256) {
@@ -1298,13 +1330,11 @@ static int jh7110_cryp_aes_aead_setkey(struct crypto_aead *tfm, const u8 *key,
 
        memcpy(ctx->key, key, keylen);
        ctx->keylen = keylen;
-       {
-               int loop;
 
-               for (loop = 0; loop < keylen; loop++)
-                       pr_debug("key[%d] = %x ctx->key[%d] = %x\n", loop, key[loop], loop, ctx->key[loop]);
-       }
-       return 0;
+       if (ctx->fallback.aead)
+               ret = crypto_aead_setkey(ctx->fallback.aead, key, keylen);
+
+       return ret;
 }
 
 static int jh7110_cryp_aes_gcm_setauthsize(struct crypto_aead *tfm,
@@ -1316,6 +1346,9 @@ static int jh7110_cryp_aes_gcm_setauthsize(struct crypto_aead *tfm,
 static int jh7110_cryp_aes_ccm_setauthsize(struct crypto_aead *tfm,
                                          unsigned int authsize)
 {
+       struct jh7110_sec_ctx *ctx = crypto_aead_ctx(tfm);
+       int ret = 0;
+
        switch (authsize) {
        case 4:
        case 6:
@@ -1329,7 +1362,12 @@ static int jh7110_cryp_aes_ccm_setauthsize(struct crypto_aead *tfm,
                return -EINVAL;
        }
 
-       return 0;
+       tfm->authsize = authsize;
+
+       if (ctx->fallback.aead)
+               ret = crypto_aead_setauthsize(ctx->fallback.aead, authsize);
+
+       return ret;
 }
 
 static int jh7110_cryp_aes_ecb_encrypt(struct skcipher_request *req)
@@ -1594,7 +1632,7 @@ static struct skcipher_alg crypto_algs[] = {
        .base.cra_name                  = "cbc(aes)",
        .base.cra_driver_name           = "jh7110-cbc-aes",
        .base.cra_priority              = 200,
-       .base.cra_flags                 =  CRYPTO_ALG_ASYNC,
+       .base.cra_flags                 = CRYPTO_ALG_ASYNC,
        .base.cra_blocksize             = AES_BLOCK_SIZE,
        .base.cra_ctxsize               = sizeof(struct jh7110_sec_ctx),
        .base.cra_alignmask             = 0xf,
@@ -1696,7 +1734,7 @@ static struct aead_alg aead_algs[] = {
                .cra_name               = "ccm(aes)",
                .cra_driver_name        = "jh7110-ccm-aes",
                .cra_priority           = 200,
-               .cra_flags              = CRYPTO_ALG_ASYNC,
+               .cra_flags              = CRYPTO_ALG_ASYNC | CRYPTO_ALG_NEED_FALLBACK,
                .cra_blocksize          = 1,
                .cra_ctxsize            = sizeof(struct jh7110_sec_ctx),
                .cra_alignmask          = 0xf,
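
A note on the fallback plumbing above: aead_do_fallback() builds its
subrequest inside the parent request context, so the reqsize set in
jh7110_cryp_aes_aead_init() must also cover a full aead_request for the
fallback. A hedged sketch of that sizing (not part of this patch),
placed after the fallback tfm has been allocated:

	crypto_aead_set_reqsize(tfm,
			max(sizeof(struct jh7110_sec_request_ctx),
			    sizeof(struct aead_request) +
			    crypto_aead_reqsize(ctx->fallback.aead)));
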
index 1831c30..f01313a 100644 (file)
@@ -387,7 +387,7 @@ static int jh7110_rsa_enc(struct akcipher_request *req)
        int ret = 0;
 
        if (key->key_sz > JH7110_RSA_MAX_KEYSZ) {
-               akcipher_request_set_tfm(req, ctx->soft_tfm);
+               akcipher_request_set_tfm(req, ctx->fallback.akcipher);
                ret = crypto_akcipher_encrypt(req);
                akcipher_request_set_tfm(req, tfm);
                return ret;
@@ -427,7 +427,7 @@ static int jh7110_rsa_dec(struct akcipher_request *req)
        int ret = 0;
 
        if (key->key_sz > JH7110_RSA_MAX_KEYSZ) {
-               akcipher_request_set_tfm(req, ctx->soft_tfm);
+               akcipher_request_set_tfm(req, ctx->fallback.akcipher);
                ret = crypto_akcipher_decrypt(req);
                akcipher_request_set_tfm(req, tfm);
                return ret;
@@ -627,7 +627,7 @@ static int jh7110_rsa_set_pub_key(struct crypto_akcipher *tfm, const void *key,
        struct jh7110_sec_ctx *ctx = akcipher_tfm_ctx(tfm);
        int ret;
 
-       ret = crypto_akcipher_set_pub_key(ctx->soft_tfm, key, keylen);
+       ret = crypto_akcipher_set_pub_key(ctx->fallback.akcipher, key, keylen);
        if (ret)
                return ret;
 
@@ -640,7 +640,7 @@ static int jh7110_rsa_set_priv_key(struct crypto_akcipher *tfm, const void *key,
        struct jh7110_sec_ctx *ctx = akcipher_tfm_ctx(tfm);
        int ret;
 
-       ret = crypto_akcipher_set_priv_key(ctx->soft_tfm, key, keylen);
+       ret = crypto_akcipher_set_priv_key(ctx->fallback.akcipher, key, keylen);
        if (ret)
                return ret;
 
@@ -653,7 +653,7 @@ static unsigned int jh7110_rsa_max_size(struct crypto_akcipher *tfm)
 
        /* For key sizes > 2048 bits, use the software fallback */
        if (ctx->rsa_key.key_sz > JH7110_RSA_MAX_KEYSZ)
-               return crypto_akcipher_maxsize(ctx->soft_tfm);
+               return crypto_akcipher_maxsize(ctx->fallback.akcipher);
 
        return ctx->rsa_key.key_sz;
 }
@@ -663,15 +663,15 @@ static int jh7110_rsa_init_tfm(struct crypto_akcipher *tfm)
 {
        struct jh7110_sec_ctx *ctx = akcipher_tfm_ctx(tfm);
 
-       ctx->soft_tfm = crypto_alloc_akcipher("rsa-generic", 0, 0);
-       if (IS_ERR(ctx->soft_tfm)) {
+       ctx->fallback.akcipher = crypto_alloc_akcipher("rsa-generic", 0, 0);
+       if (IS_ERR(ctx->fallback.akcipher)) {
                pr_err("Can not alloc_akcipher!\n");
-               return PTR_ERR(ctx->soft_tfm);
+               return PTR_ERR(ctx->fallback.akcipher);
        }
 
        ctx->sdev = jh7110_sec_find_dev(ctx);
        if (!ctx->sdev) {
-               crypto_free_akcipher(ctx->soft_tfm);
+               crypto_free_akcipher(ctx->fallback.akcipher);
                return -ENODEV;
        }
 
@@ -686,7 +686,7 @@ static void jh7110_rsa_exit_tfm(struct crypto_akcipher *tfm)
        struct jh7110_sec_ctx *ctx = akcipher_tfm_ctx(tfm);
        struct jh7110_rsa_key *key = (struct jh7110_rsa_key *)&ctx->rsa_key;
 
-       crypto_free_akcipher(ctx->soft_tfm);
+       crypto_free_akcipher(ctx->fallback.akcipher);
        jh7110_rsa_free_key(key);
 }
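
The encrypt/decrypt paths above temporarily retarget the caller's
request at the software tfm and restore it afterwards; a condensed
sketch of that borrow-and-restore pattern (helper name hypothetical):

	#include <crypto/akcipher.h>

	static int jh7110_rsa_fallback_op(struct akcipher_request *req,
					  struct crypto_akcipher *fallback,
					  bool enc)
	{
		struct crypto_akcipher *tfm = crypto_akcipher_reqtfm(req);
		int ret;

		/* Borrow the caller's request for the software tfm... */
		akcipher_request_set_tfm(req, fallback);
		ret = enc ? crypto_akcipher_encrypt(req) :
			    crypto_akcipher_decrypt(req);
		/* ...then hand it back untouched. */
		akcipher_request_set_tfm(req, tfm);
		return ret;
	}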
 
index 0be29a6..d995efc 100644 (file)
@@ -573,10 +573,10 @@ static int jh7110_hash_final(struct ahash_request *req)
 
        if (ctx->fallback_available && (rctx->bufcnt < JH7110_HASH_THRES)) {
                if (ctx->sha_mode & JH7110_SHA_HMAC_FLAGS)
-                       crypto_shash_setkey(ctx->fallback, ctx->key,
+                       crypto_shash_setkey(ctx->fallback.shash, ctx->key,
                                        ctx->keylen);
 
-               return crypto_shash_tfm_digest(ctx->fallback, ctx->buffer,
+               return crypto_shash_tfm_digest(ctx->fallback.shash, ctx->buffer,
                                rctx->bufcnt, req->result);
        }
 
@@ -643,10 +643,10 @@ static int jh7110_hash_cra_init_algs(struct crypto_tfm *tfm,
        if (!ctx->sdev)
                return -ENODEV;
 
-       ctx->fallback = crypto_alloc_shash(alg_name, 0,
+       ctx->fallback.shash = crypto_alloc_shash(alg_name, 0,
                        CRYPTO_ALG_NEED_FALLBACK);
-       
-       if (IS_ERR(ctx->fallback)) {
+
+       if (IS_ERR(ctx->fallback.shash)) {
                pr_err("fallback unavailable for '%s'\n", alg_name);
                ctx->fallback_available = false;
        }
@@ -673,9 +673,9 @@ static void jh7110_hash_cra_exit(struct crypto_tfm *tfm)
 {
        struct jh7110_sec_ctx *ctx = crypto_tfm_ctx(tfm);
 
-       crypto_free_shash(ctx->fallback);
+       crypto_free_shash(ctx->fallback.shash);
 
-       ctx->fallback = NULL;
+       ctx->fallback.shash = NULL;
        ctx->enginectx.op.do_one_request = NULL;
        ctx->enginectx.op.prepare_request = NULL;
        ctx->enginectx.op.unprepare_request = NULL;
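
For reference, the short-message path in jh7110_hash_final() above
reduces to keying the shash fallback (HMAC modes only) and computing a
one-shot digest; a standalone sketch under that assumption (helper name
hypothetical):

	#include <crypto/hash.h>

	static int short_msg_digest(struct crypto_shash *fb, bool hmac,
				    const u8 *key, unsigned int keylen,
				    const u8 *data, unsigned int len, u8 *out)
	{
		int ret;

		if (hmac) {
			ret = crypto_shash_setkey(fb, key, keylen);
			if (ret)
				return ret;
		}

		/* One-shot digest over the buffered message. */
		return crypto_shash_tfm_digest(fb, data, len, out);
	}
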
index 27ffac2..06b210e 100644 (file)
@@ -60,8 +60,11 @@ struct jh7110_sec_ctx {
        struct jh7110_rsa_key                   rsa_key;
        size_t                                  sha_len_total;
        u8                                      *buffer;
-       struct crypto_akcipher                  *soft_tfm;
-       struct crypto_shash                     *fallback;
+       union {
+               struct crypto_akcipher                  *akcipher;
+               struct crypto_aead                      *aead;
+               struct crypto_shash                     *shash;
+       } fallback;
        bool                                    fallback_available;
 };
 
@@ -132,7 +135,6 @@ struct jh7110_sec_request_ctx {
                struct skcipher_request         *sreq;
                struct aead_request             *areq;
        } req;
-
 #define JH7110_AHASH_REQ                       0
 #define JH7110_ABLK_REQ                                1
 #define JH7110_AEAD_REQ                                2
@@ -176,8 +178,6 @@ struct jh7110_sec_request_ctx {
        size_t                                  assoclen;
        size_t                                  ctr_over_count;
 
-       u32                                     msg_end[4];
-       u32                                     dec_end[4];
        u32                                     last_ctr[4];
        u32                                     aes_iv[4];
        u32                                     tag_out[4];
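
A closing note on the union introduced above: the three fallback
pointers share storage, which is safe because a given jh7110_sec_ctx
only ever backs one transform type. A sketch of the teardown invariant
this implies (not part of this patch; AEAD case shown, the other types
mirror it):

	static void jh7110_sec_put_aead_fallback(struct jh7110_sec_ctx *ctx)
	{
		if (!ctx->fallback.aead)
			return;
		crypto_free_aead(ctx->fallback.aead);
		ctx->fallback.aead = NULL;
	}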