Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[platform/kernel/linux-starfive.git] / net / tls / tls_sw.c
index e30649f..f1777d6 100644 (file)
@@ -47,6 +47,7 @@
 struct tls_decrypt_arg {
        bool zc;
        bool async;
+       u8 tail;
 };
 
 noinline void tls_err_abort(struct sock *sk, int err)
@@ -133,7 +134,8 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
         return __skb_nsg(skb, offset, len, 0);
 }
 
-static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
+static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
+                             struct tls_decrypt_arg *darg)
 {
        struct strp_msg *rxm = strp_msg(skb);
        struct tls_msg *tlm = tls_msg(skb);
@@ -142,7 +144,7 @@ static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
        /* Determine zero-padding length */
        if (prot->version == TLS_1_3_VERSION) {
                int offset = rxm->full_len - TLS_TAG_SIZE - 1;
-               char content_type = 0;
+               char content_type = darg->zc ? darg->tail : 0;
                int err;
 
                while (content_type == 0) {
@@ -1415,18 +1417,18 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        struct strp_msg *rxm = strp_msg(skb);
        struct tls_msg *tlm = tls_msg(skb);
        int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
+       u8 *aad, *iv, *tail, *mem = NULL;
        struct aead_request *aead_req;
        struct sk_buff *unused;
-       u8 *aad, *iv, *mem = NULL;
        struct scatterlist *sgin = NULL;
        struct scatterlist *sgout = NULL;
-       const int data_len = rxm->full_len - prot->overhead_size +
-                            prot->tail_size;
+       const int data_len = rxm->full_len - prot->overhead_size;
+       int tail_pages = !!prot->tail_size;
        int iv_offset = 0;
 
        if (darg->zc && (out_iov || out_sg)) {
                if (out_iov)
-                       n_sgout = 1 +
+                       n_sgout = 1 + tail_pages +
                                iov_iter_npages_cap(out_iov, INT_MAX, data_len);
                else
                        n_sgout = sg_nents(out_sg);
@@ -1450,9 +1452,10 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        mem_size = aead_size + (nsg * sizeof(struct scatterlist));
        mem_size = mem_size + prot->aad_size;
        mem_size = mem_size + MAX_IV_SIZE;
+       mem_size = mem_size + prot->tail_size;
 
        /* Allocate a single block of memory which contains
-        * aead_req || sgin[] || sgout[] || aad || iv.
+        * aead_req || sgin[] || sgout[] || aad || iv || tail.
         * This order achieves correct alignment for aead_req, sgin, sgout.
         */
        mem = kmalloc(mem_size, sk->sk_allocation);
@@ -1465,6 +1468,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        sgout = sgin + n_sgin;
        aad = (u8 *)(sgout + n_sgout);
        iv = aad + prot->aad_size;
+       tail = iv + MAX_IV_SIZE;
 
        /* For CCM based ciphers, first byte of nonce+iv is a constant */
        switch (prot->cipher_type) {
@@ -1518,9 +1522,16 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 
                        err = tls_setup_from_iter(out_iov, data_len,
                                                  &pages, &sgout[1],
-                                                 (n_sgout - 1));
+                                                 (n_sgout - 1 - tail_pages));
                        if (err < 0)
                                goto fallback_to_reg_recv;
+
+                       if (prot->tail_size) {
+                               sg_unmark_end(&sgout[pages]);
+                               sg_set_buf(&sgout[pages + 1], tail,
+                                          prot->tail_size);
+                               sg_mark_end(&sgout[pages + 1]);
+                       }
                } else if (out_sg) {
                        memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
                } else {
@@ -1535,10 +1546,13 @@ fallback_to_reg_recv:
 
        /* Prepare and submit AEAD request */
        err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-                               data_len, aead_req, darg);
+                               data_len + prot->tail_size, aead_req, darg);
        if (darg->async)
                return 0;
 
+       if (prot->tail_size)
+               darg->tail = *tail;
+
        /* Release the pages in case iov was mapped to pages */
        for (; pages > 0; pages--)
                put_page(sg_page(&sgout[pages]));
@@ -1583,9 +1597,16 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
        }
        if (darg->async)
                goto decrypt_next;
+       /* If opportunistic TLS 1.3 ZC failed retry without ZC */
+       if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
+                    darg->tail != TLS_RECORD_TYPE_DATA)) {
+               darg->zc = false;
+               TLS_INC_STATS(sock_net(sk), LINUX_MIN_TLSDECRYPTRETRY);
+               return decrypt_skb_update(sk, skb, dest, darg);
+       }
 
 decrypt_done:
-       pad = padding_length(prot, skb);
+       pad = tls_padding_length(prot, skb, darg);
        if (pad < 0)
                return pad;
 
@@ -1717,6 +1738,24 @@ out:
        return copied ? : err;
 }
 
+static void
+tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
+                      size_t len_left, size_t decrypted, ssize_t done,
+                      size_t *flushed_at)
+{
+       size_t max_rec;
+
+       if (len_left <= decrypted)
+               return;
+
+       max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
+       if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
+               return;
+
+       *flushed_at = done;
+       sk_flush_backlog(sk);
+}
+
 int tls_sw_recvmsg(struct sock *sk,
                   struct msghdr *msg,
                   size_t len,
@@ -1729,6 +1768,7 @@ int tls_sw_recvmsg(struct sock *sk,
        struct sk_psock *psock;
        unsigned char control = 0;
        ssize_t decrypted = 0;
+       size_t flushed_at = 0;
        struct strp_msg *rxm;
        struct tls_msg *tlm;
        struct sk_buff *skb;
@@ -1767,7 +1807,7 @@ int tls_sw_recvmsg(struct sock *sk,
        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
        zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
-                    prot->version != TLS_1_3_VERSION;
+               ctx->zc_capable;
        decrypted = 0;
        while (len && (decrypted + copied < target || ctx->recv_pkt)) {
                struct tls_decrypt_arg darg = {};
@@ -1818,6 +1858,10 @@ int tls_sw_recvmsg(struct sock *sk,
                if (err <= 0)
                        goto recv_end;
 
+               /* periodically flush backlog, and feed strparser */
+               tls_read_flush_backlog(sk, prot, len, to_decrypt,
+                                      decrypted + copied, &flushed_at);
+
                ctx->recv_pkt = NULL;
                __strp_unpause(&ctx->strp);
                __skb_queue_tail(&ctx->rx_list, skb);
@@ -2249,6 +2293,14 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
        strp_check_rcv(&rx_ctx->strp);
 }
 
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
+{
+       struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+       rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
+               tls_ctx->prot_info.version != TLS_1_3_VERSION;
+}
+
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -2484,12 +2536,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
        if (sw_ctx_rx) {
                tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
 
-               if (crypto_info->version == TLS_1_3_VERSION)
-                       sw_ctx_rx->async_capable = 0;
-               else
-                       sw_ctx_rx->async_capable =
-                               !!(tfm->__crt_alg->cra_flags &
-                                  CRYPTO_ALG_ASYNC);
+               tls_update_rx_zc_capable(ctx);
+               sw_ctx_rx->async_capable =
+                       crypto_info->version != TLS_1_3_VERSION &&
+                       !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
 
                /* Set up strparser */
                memset(&cb, 0, sizeof(cb));