tls: rx: async: adjust record geometry immediately
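
Async decrypt completions used to finish the record bookkeeping in the callback: on success, tls_decrypt_done() stripped the TLS header and trailer by advancing rxm->offset and shrinking rxm->full_len. Adjust the record geometry immediately in tls_rx_one_record() instead, by letting the async path fall through to the same decrypt_done handling as the sync path. Since the completion handler then no longer needs the skb, pass the socket as the crypto request's callback data, which also removes the skb->sk borrow that previously had to be cleared before the skb could be freed.
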
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 6a98754..09fe2cf 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -184,39 +184,22 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
        struct scatterlist *sgin = aead_req->src;
        struct tls_sw_context_rx *ctx;
        struct tls_context *tls_ctx;
-       struct tls_prot_info *prot;
        struct scatterlist *sg;
-       struct sk_buff *skb;
        unsigned int pages;
+       struct sock *sk;
 
-       skb = (struct sk_buff *)req->data;
-       tls_ctx = tls_get_ctx(skb->sk);
+       sk = (struct sock *)req->data;
+       tls_ctx = tls_get_ctx(sk);
        ctx = tls_sw_ctx_rx(tls_ctx);
-       prot = &tls_ctx->prot_info;
 
        /* Propagate if there was an err */
        if (err) {
                if (err == -EBADMSG)
-                       TLS_INC_STATS(sock_net(skb->sk),
-                                     LINUX_MIB_TLSDECRYPTERROR);
+                       TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
                ctx->async_wait.err = err;
-               tls_err_abort(skb->sk, err);
-       } else {
-               struct strp_msg *rxm = strp_msg(skb);
-
-               /* No TLS 1.3 support with async crypto */
-               WARN_ON(prot->tail_size);
-
-               rxm->offset += prot->prepend_size;
-               rxm->full_len -= prot->overhead_size;
+               tls_err_abort(sk, err);
        }
 
-       /* After using skb->sk to propagate sk through crypto async callback
-        * we need to NULL it again.
-        */
-       skb->sk = NULL;
-
-
        /* Free the destination pages if skb was not decrypted inplace */
        if (sgout != sgin) {
                /* Skip the first S/G entry as it points to AAD */
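
A sketch of how the top of tls_decrypt_done() reads once the hunk above is applied, reconstructed only from the lines shown there (anything the diff does not show is left as an elided comment and assumed unchanged): the success branch disappears, since the record geometry is now fixed up by the caller, and the socket comes straight from the request's callback data instead of a borrowed skb->sk.

static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
        struct tls_sw_context_rx *ctx;
        struct tls_context *tls_ctx;
        struct sock *sk;
        /* ... remaining declarations as in the context above, elided ... */

        sk = (struct sock *)req->data;
        tls_ctx = tls_get_ctx(sk);
        ctx = tls_sw_ctx_rx(tls_ctx);

        /* Propagate if there was an err */
        if (err) {
                if (err == -EBADMSG)
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
                ctx->async_wait.err = err;
                tls_err_abort(sk, err);
        }

        /* Destination page cleanup and the rest of the handler continue
         * unchanged from the context lines above.
         */
}
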
@@ -236,7 +219,6 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 }
 
 static int tls_do_decryption(struct sock *sk,
-                            struct sk_buff *skb,
                             struct scatterlist *sgin,
                             struct scatterlist *sgout,
                             char *iv_recv,
@@ -256,16 +238,9 @@ static int tls_do_decryption(struct sock *sk,
                               (u8 *)iv_recv);
 
        if (darg->async) {
-               /* Using skb->sk to push sk through to crypto async callback
-                * handler. This allows propagating errors up to the socket
-                * if needed. It _must_ be cleared in the async handler
-                * before consume_skb is called. We _know_ skb->sk is NULL
-                * because it is a clone from strparser.
-                */
-               skb->sk = sk;
                aead_request_set_callback(aead_req,
                                          CRYPTO_TFM_REQ_MAY_BACKLOG,
-                                         tls_decrypt_done, skb);
+                                         tls_decrypt_done, sk);
                atomic_inc(&ctx->decrypt_pending);
        } else {
                aead_request_set_callback(aead_req,
@@ -1554,7 +1529,7 @@ fallback_to_reg_recv:
        }
 
        /* Prepare and submit AEAD request */
-       err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
+       err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
                                data_len + prot->tail_size, aead_req, darg);
        if (err)
                goto exit_free_pages;
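
On the submission side, the two hunks above drop the skb argument from tls_do_decryption() and stop borrowing skb->sk to smuggle the socket into the async completion; the socket pointer itself becomes the callback data. A minimal before/after sketch of just the callback registration, condensed from those hunks:

        /* Before: the skb carried the socket into the async completion,
         * and skb->sk had to be cleared again before the skb could be
         * consumed.
         */
        skb->sk = sk;
        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  tls_decrypt_done, skb);

        /* After: the socket itself is the callback data, nothing to undo. */
        aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
                                  tls_decrypt_done, sk);
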
@@ -1617,11 +1592,8 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
                return err;
        }
-       if (darg->async) {
-               if (darg->skb == ctx->recv_pkt)
-                       ctx->recv_pkt = NULL;
-               goto decrypt_next;
-       }
+       if (darg->async)
+               goto decrypt_done;
        /* If opportunistic TLS 1.3 ZC failed retry without ZC */
        if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
                     darg->tail != TLS_RECORD_TYPE_DATA)) {
@@ -1632,10 +1604,10 @@ static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest,
                return tls_rx_one_record(sk, dest, darg);
        }
 
+decrypt_done:
        if (darg->skb == ctx->recv_pkt)
                ctx->recv_pkt = NULL;
 
-decrypt_done:
        pad = tls_padding_length(prot, darg->skb, darg);
        if (pad < 0) {
                consume_skb(darg->skb);
@@ -1646,7 +1618,6 @@ decrypt_done:
        rxm->full_len -= pad;
        rxm->offset += prot->prepend_size;
        rxm->full_len -= prot->overhead_size;
-decrypt_next:
        tls_advance_record_sn(sk, prot, &tls_ctx->rx);
 
        return 0;
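
Taken together, the last three hunks make the async case share the post-decrypt fix-up with the sync case: the recv_pkt hand-off, padding removal, and header/overhead adjustment all happen before tls_rx_one_record() returns, so nothing is left for the completion handler to do to the record. Reassembled from those hunks and their context lines (elided regions assumed unchanged), the tail of the function reads roughly as:

        if (darg->async)
                goto decrypt_done;

        /* If opportunistic TLS 1.3 ZC failed retry without ZC */
        if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
                     darg->tail != TLS_RECORD_TYPE_DATA)) {
                /* ... retry without zero-copy, unchanged, elided ... */
                return tls_rx_one_record(sk, dest, darg);
        }

decrypt_done:
        if (darg->skb == ctx->recv_pkt)
                ctx->recv_pkt = NULL;

        pad = tls_padding_length(prot, darg->skb, darg);
        if (pad < 0) {
                consume_skb(darg->skb);
                /* ... error return, unchanged, elided ... */
        }

        rxm->full_len -= pad;
        rxm->offset += prot->prepend_size;
        rxm->full_len -= prot->overhead_size;
        tls_advance_record_sn(sk, prot, &tls_ctx->rx);

        return 0;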