btrfs: fix race between quota disable and quota assign ioctls
[platform/kernel/linux-rpi.git] / net / tls / tls_sw.c
index 1b08b87..c0fea67 100644 (file)
@@ -515,7 +515,7 @@ static int tls_do_encryption(struct sock *sk,
        memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
               prot->iv_size + prot->salt_size);
 
-       xor_iv_with_seq(prot, rec->iv_data, tls_ctx->tx.rec_seq);
+       xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
 
        sge->offset += prot->prepend_size;
        sge->length -= prot->prepend_size;
@@ -801,7 +801,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
        struct sk_psock *psock;
        struct sock *sk_redir;
        struct tls_rec *rec;
-       bool enospc, policy;
+       bool enospc, policy, redir_ingress;
        int err = 0, send;
        u32 delta = 0;
 
@@ -846,6 +846,7 @@ more_data:
                }
                break;
        case __SK_REDIRECT:
+               redir_ingress = psock->redir_ingress;
                sk_redir = psock->sk_redir;
                memcpy(&msg_redir, msg, sizeof(*msg));
                if (msg->apply_bytes < send)
@@ -855,7 +856,8 @@ more_data:
                sk_msg_return_zero(sk, msg, send);
                msg->sg.size -= send;
                release_sock(sk);
-               err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+               err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
+                                           &msg_redir, send, flags);
                lock_sock(sk);
                if (err < 0) {
                        *copied -= sk_msg_free_nocharge(sk, &msg_redir);
@@ -1483,11 +1485,11 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        if (prot->version == TLS_1_3_VERSION ||
            prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305)
                memcpy(iv + iv_offset, tls_ctx->rx.iv,
-                      crypto_aead_ivsize(ctx->aead_recv));
+                      prot->iv_size + prot->salt_size);
        else
                memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
 
-       xor_iv_with_seq(prot, iv, tls_ctx->rx.rec_seq);
+       xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
 
        /* Prepare AAD */
        tls_make_aad(aad, rxm->full_len - prot->overhead_size +
@@ -1993,6 +1995,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        ssize_t copied = 0;
+       bool from_queue;
        int err = 0;
        long timeo;
        int chunk;
@@ -2002,25 +2005,28 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 
        timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
 
-       skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
-       if (!skb)
-               goto splice_read_end;
-
-       if (!ctx->decrypted) {
-               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
-
-               /* splice does not support reading control messages */
-               if (ctx->control != TLS_RECORD_TYPE_DATA) {
-                       err = -EINVAL;
+       from_queue = !skb_queue_empty(&ctx->rx_list);
+       if (from_queue) {
+               skb = __skb_dequeue(&ctx->rx_list);
+       } else {
+               skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
+                                   &err);
+               if (!skb)
                        goto splice_read_end;
-               }
 
+               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
                if (err < 0) {
                        tls_err_abort(sk, -EBADMSG);
                        goto splice_read_end;
                }
-               ctx->decrypted = 1;
        }
+
+       /* splice does not support reading control messages */
+       if (ctx->control != TLS_RECORD_TYPE_DATA) {
+               err = -EINVAL;
+               goto splice_read_end;
+       }
+
        rxm = strp_msg(skb);
 
        chunk = min_t(unsigned int, rxm->full_len, len);
@@ -2028,7 +2034,17 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        if (copied < 0)
                goto splice_read_end;
 
-       tls_sw_advance_skb(sk, skb, copied);
+       if (!from_queue) {
+               ctx->recv_pkt = NULL;
+               __strp_unpause(&ctx->strp);
+       }
+       if (chunk < rxm->full_len) {
+               __skb_queue_head(&ctx->rx_list, skb);
+               rxm->offset += len;
+               rxm->full_len -= len;
+       } else {
+               consume_skb(skb);
+       }
 
 splice_read_end:
        release_sock(sk);