net/tls: tls_sw.c fixes: error-code sign in tls_err_abort(), CCM-mode IV offset, sockmap redirect ingress flag, splice_read of already-decrypted records
[platform/kernel/linux-rpi.git]
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 4feb95e..c0fea67 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -35,12 +35,21 @@
  * SOFTWARE.
  */
 
+#include <linux/bug.h>
 #include <linux/sched/signal.h>
 #include <linux/module.h>
 #include <linux/splice.h>
 #include <net/strparser.h>
 #include <net/tls.h>
 
+noinline void tls_err_abort(struct sock *sk, int err)
+{
+       WARN_ON_ONCE(err >= 0);
+       /* sk->sk_err should contain a positive error code. */
+       sk->sk_err = -err;
+       sk_error_report(sk);
+}
+
 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
                      unsigned int recursion_level)
 {
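The new helper centralizes the error-reporting convention that the hunks below enforce: callers now pass a negative errno (e.g. -EBADMSG), and tls_err_abort() stores the negated, positive value in sk->sk_err, with the WARN_ON_ONCE catching any caller that still passes a positive code. The same convention explains the async-wait hunk further down: ctx->async_wait.err keeps the kernel-internal negative errno, so copying from the always-positive sk->sk_err needs a negation. A minimal userspace sketch of the convention (fake_sock and sketch_err_abort are made-up stand-ins, not kernel code):

    #include <assert.h>
    #include <errno.h>
    #include <stdio.h>

    struct fake_sock { int sk_err; };       /* stand-in for struct sock */

    static void sketch_err_abort(struct fake_sock *sk, int err)
    {
            assert(err < 0);                /* mirrors WARN_ON_ONCE(err >= 0) */
            sk->sk_err = -err;              /* sk_err holds the positive code */
    }

    int main(void)
    {
            struct fake_sock sk = { 0 };

            sketch_err_abort(&sk, -EBADMSG);        /* callers pass -EBADMSG now */
            printf("sk_err = %d\n", sk.sk_err);     /* prints positive EBADMSG */
            return 0;
    }
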
@@ -419,7 +428,7 @@ int tls_tx_records(struct sock *sk, int flags)
 
 tx_err:
        if (rc < 0 && rc != -EAGAIN)
-               tls_err_abort(sk, EBADMSG);
+               tls_err_abort(sk, -EBADMSG);
 
        return rc;
 }
@@ -450,7 +459,7 @@ static void tls_encrypt_done(struct crypto_async_request *req, int err)
 
                /* If err is already set on socket, return the same code */
                if (sk->sk_err) {
-                       ctx->async_wait.err = sk->sk_err;
+                       ctx->async_wait.err = -sk->sk_err;
                } else {
                        ctx->async_wait.err = err;
                        tls_err_abort(sk, err);
@@ -506,7 +515,7 @@ static int tls_do_encryption(struct sock *sk,
        memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
               prot->iv_size + prot->salt_size);
 
-       xor_iv_with_seq(prot, rec->iv_data, tls_ctx->tx.rec_seq);
+       xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
 
        sge->offset += prot->prepend_size;
        sge->length -= prot->prepend_size;
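This TX fix matters for AES-CCM: when iv_offset is non-zero (byte 0 of the IV buffer carries CCM's B0/length flags byte), the salt and IV are copied to rec->iv_data + iv_offset, so the record-sequence XOR must be applied at the same offset. XORing at rec->iv_data scrambled the reserved byte and left the real nonce untouched, which shows up as authentication failures in CCM mode. A self-contained sketch of the layout, assuming the 4-byte salt and 8-byte explicit IV of the AES-128 ciphers; note the kernel applies the XOR only for TLS 1.3 and ChaCha20-Poly1305, while this sketch folds it in unconditionally:

    #include <stddef.h>
    #include <string.h>

    #define SALT_SIZE 4
    #define IV_SIZE   8

    /* Build the per-record nonce at buf + iv_offset: copy salt+IV there,
     * then XOR the 8-byte record sequence over the IV part. Both steps
     * must use the same offset or the nonce is built over the wrong bytes. */
    static void build_nonce(unsigned char *buf, size_t iv_offset,
                            const unsigned char *salt_iv,
                            const unsigned char *rec_seq)
    {
            memcpy(buf + iv_offset, salt_iv, SALT_SIZE + IV_SIZE);
            for (size_t i = 0; i < IV_SIZE; i++)
                    buf[iv_offset + SALT_SIZE + i] ^= rec_seq[i];
    }
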
@@ -763,7 +772,7 @@ static int tls_push_record(struct sock *sk, int flags,
                               msg_pl->sg.size + prot->tail_size, i);
        if (rc < 0) {
                if (rc != -EINPROGRESS) {
-                       tls_err_abort(sk, EBADMSG);
+                       tls_err_abort(sk, -EBADMSG);
                        if (split) {
                                tls_ctx->pending_open_record_frags = true;
                                tls_merge_open_record(sk, rec, tmp, orig_end);
@@ -792,7 +801,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
        struct sk_psock *psock;
        struct sock *sk_redir;
        struct tls_rec *rec;
-       bool enospc, policy;
+       bool enospc, policy, redir_ingress;
        int err = 0, send;
        u32 delta = 0;
 
@@ -837,6 +846,7 @@ more_data:
                }
                break;
        case __SK_REDIRECT:
+               redir_ingress = psock->redir_ingress;
                sk_redir = psock->sk_redir;
                memcpy(&msg_redir, msg, sizeof(*msg));
                if (msg->apply_bytes < send)
@@ -846,7 +856,8 @@ more_data:
                sk_msg_return_zero(sk, msg, send);
                msg->sg.size -= send;
                release_sock(sk);
-               err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+               err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
+                                           &msg_redir, send, flags);
                lock_sock(sk);
                if (err < 0) {
                        *copied -= sk_msg_free_nocharge(sk, &msg_redir);
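In the redirect hunks above, everything the redirect call needs is now captured from the psock while the socket lock is still held: sk_redir and, with this change, the ingress flag, which is passed to tcp_bpf_sendmsg_redir() explicitly instead of being rederived later. That way the flag cannot be lost once the lock is dropped or when msg_redir is a partial copy under apply_bytes. A sketch of the snapshot-before-unlock pattern (types and names below are stand-ins, not the kernel's):

    #include <stdbool.h>

    struct sock;                            /* opaque stand-in */

    struct psock_state {                    /* fields mirrored from the diff */
            struct sock *sk_redir;
            bool redir_ingress;
    };

    struct redirect_args {
            struct sock *target;
            bool ingress;
    };

    /* Must run before the lock is released: a parallel verdict may
     * repoint sk_redir or flip redir_ingress afterwards, so the
     * redirect call works only from this snapshot. */
    static void snapshot_redirect(const struct psock_state *psock,
                                  struct redirect_args *args)
    {
            args->target  = psock->sk_redir;
            args->ingress = psock->redir_ingress;
    }
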
@@ -1474,11 +1485,11 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
        if (prot->version == TLS_1_3_VERSION ||
            prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305)
                memcpy(iv + iv_offset, tls_ctx->rx.iv,
-                      crypto_aead_ivsize(ctx->aead_recv));
+                      prot->iv_size + prot->salt_size);
        else
                memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
 
-       xor_iv_with_seq(prot, iv, tls_ctx->rx.rec_seq);
+       xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
 
        /* Prepare AAD */
        tls_make_aad(aad, rxm->full_len - prot->overhead_size +
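On the RX side the copy length changes for the same CCM reason: crypto_aead_ivsize() reports the AEAD algorithm's IV size, which for CCM (16 bytes) is larger than the 12 bytes of salt+IV actually stored in tls_ctx->rx.iv, so the old memcpy over-read the stored IV; prot->iv_size + prot->salt_size copies exactly what is there. A runnable illustration of the per-cipher sizes assumed here (taken from the TLS record-layer definitions as I understand them, not authoritative):

    #include <stdio.h>

    int main(void)
    {
            /* stored = salt_size + iv_size kept by the record layer;
             * aead = what crypto_aead_ivsize() would report */
            struct { const char *name; int stored, aead; } c[] = {
                    { "AES-GCM-128",       4 + 8,  12 },
                    { "AES-CCM-128",       4 + 8,  16 },
                    { "CHACHA20-POLY1305", 0 + 12, 12 },
            };

            for (int i = 0; i < 3; i++)
                    printf("%-18s stored=%2d aead=%2d%s\n", c[i].name,
                           c[i].stored, c[i].aead,
                           c[i].stored != c[i].aead ? "  <- over-read" : "");
            return 0;
    }
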
@@ -1827,7 +1838,7 @@ int tls_sw_recvmsg(struct sock *sk,
                err = decrypt_skb_update(sk, skb, &msg->msg_iter,
                                         &chunk, &zc, async_capable);
                if (err < 0 && err != -EINPROGRESS) {
-                       tls_err_abort(sk, EBADMSG);
+                       tls_err_abort(sk, -EBADMSG);
                        goto recv_end;
                }
 
@@ -1984,6 +1995,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        ssize_t copied = 0;
+       bool from_queue;
        int err = 0;
        long timeo;
        int chunk;
@@ -1993,25 +2005,28 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
 
        timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
 
-       skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, &err);
-       if (!skb)
-               goto splice_read_end;
-
-       if (!ctx->decrypted) {
-               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
-
-               /* splice does not support reading control messages */
-               if (ctx->control != TLS_RECORD_TYPE_DATA) {
-                       err = -EINVAL;
+       from_queue = !skb_queue_empty(&ctx->rx_list);
+       if (from_queue) {
+               skb = __skb_dequeue(&ctx->rx_list);
+       } else {
+               skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
+                                   &err);
+               if (!skb)
                        goto splice_read_end;
-               }
 
+               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
                if (err < 0) {
-                       tls_err_abort(sk, EBADMSG);
+                       tls_err_abort(sk, -EBADMSG);
                        goto splice_read_end;
                }
-               ctx->decrypted = 1;
        }
+
+       /* splice does not support reading control messages */
+       if (ctx->control != TLS_RECORD_TYPE_DATA) {
+               err = -EINVAL;
+               goto splice_read_end;
+       }
+
        rxm = strp_msg(skb);
 
        chunk = min_t(unsigned int, rxm->full_len, len);
@@ -2019,14 +2034,24 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        if (copied < 0)
                goto splice_read_end;
 
-       tls_sw_advance_skb(sk, skb, copied);
+       if (!from_queue) {
+               ctx->recv_pkt = NULL;
+               __strp_unpause(&ctx->strp);
+       }
+       if (chunk < rxm->full_len) {
+               __skb_queue_head(&ctx->rx_list, skb);
+               rxm->offset += len;
+               rxm->full_len -= len;
+       } else {
+               consume_skb(skb);
+       }
 
 splice_read_end:
        release_sock(sk);
        return copied ? : err;
 }
 
-bool tls_sw_stream_read(const struct sock *sk)
+bool tls_sw_sock_is_readable(struct sock *sk)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
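The splice_read rework teaches the splice path about records that earlier reads already decrypted and parked on ctx->rx_list: those are served straight from the queue, only records freshly pulled from the strparser are decrypted, and the control-message check now runs in both cases. A partially spliced record is no longer advanced in place via tls_sw_advance_skb(); it is re-queued at the head of rx_list with its offset and remaining length adjusted, so the next read resumes where splice stopped. The final hunk also renames tls_sw_stream_read() to tls_sw_sock_is_readable(), dropping the const qualifier, matching the generic ->sock_is_readable() callback. A made-up singly-linked-list sketch of the consume-or-requeue step (not the kernel's skb queue API):

    #include <stddef.h>

    struct record {
            struct record *next;
            size_t offset;          /* first unread byte in the record */
            size_t full_len;        /* unread bytes remaining */
    };

    /* Consume up to want bytes from rec; if anything is left over, push
     * the record back on the head of the list so the next reader resumes
     * exactly where this one stopped. Returns the bytes consumed. */
    static size_t splice_one(struct record **head, struct record *rec,
                             size_t want)
    {
            size_t chunk = want < rec->full_len ? want : rec->full_len;

            if (chunk < rec->full_len) {
                    rec->offset   += chunk;
                    rec->full_len -= chunk;
                    rec->next = *head;
                    *head = rec;    /* requeue partial record at head */
            }
            /* else: fully consumed; the caller frees it (consume_skb) */
            return chunk;
    }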