tcp: allow again tcp_disconnect() when threads are waiting
author Paolo Abeni <pabeni@redhat.com>
Wed, 11 Oct 2023 07:20:55 +0000 (09:20 +0200)
committer Jakub Kicinski <kuba@kernel.org>
Fri, 13 Oct 2023 23:49:32 +0000 (16:49 -0700)
As reported by Tom, .NET and applications built on top of it rely
on connect(AF_UNSPEC) to asynchronously cancel pending I/O operations
on TCP sockets.

The blamed commit below caused a regression, as such cancellation
can now fail.

As suggested by Eric, this change addresses the problem by explicitly
causing blocking I/O operations to terminate immediately (with an error)
when a concurrent disconnect() is executed.

Instead of tracking the number of threads blocked on a given socket,
track the number of disconnect() calls issued on that socket. If that
counter changes while a blocking operation has released and re-acquired
the socket lock, error out the current operation.

Fixes: 4faeee0cf8a5 ("tcp: deny tcp_disconnect() when threads are waiting")
Reported-by: Tom Deseyn <tdeseyn@redhat.com>
Closes: https://bugzilla.redhat.com/show_bug.cgi?id=1886305
Suggested-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://lore.kernel.org/r/f3b95e47e3dbed840960548aebaa8d954372db41.1697008693.git.pabeni@redhat.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
include/net/sock.h
net/core/stream.c
net/ipv4/af_inet.c
net/ipv4/inet_connection_sock.c
net/ipv4/tcp.c
net/ipv4/tcp_bpf.c
net/mptcp/protocol.c
net/tls/tls_main.c
net/tls/tls_sw.c

index 5fc64e4..d567e42 100644 (file)
@@ -911,7 +911,7 @@ static int csk_wait_memory(struct chtls_dev *cdev,
                           struct sock *sk, long *timeo_p)
 {
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
-       int err = 0;
+       int ret, err = 0;
        long current_timeo;
        long vm_wait = 0;
        bool noblock;
@@ -942,10 +942,13 @@ static int csk_wait_memory(struct chtls_dev *cdev,
 
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
                sk->sk_write_pending++;
-               sk_wait_event(sk, &current_timeo, sk->sk_err ||
-                             (sk->sk_shutdown & SEND_SHUTDOWN) ||
-                             (csk_mem_free(cdev, sk) && !vm_wait), &wait);
+               ret = sk_wait_event(sk, &current_timeo, sk->sk_err ||
+                                   (sk->sk_shutdown & SEND_SHUTDOWN) ||
+                                   (csk_mem_free(cdev, sk) && !vm_wait),
+                                   &wait);
                sk->sk_write_pending--;
+               if (ret < 0)
+                       goto do_error;
 
                if (vm_wait) {
                        vm_wait -= current_timeo;
@@ -1348,6 +1351,7 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        int copied = 0;
        int target;
        long timeo;
+       int ret;
 
        buffers_freed = 0;
 
@@ -1423,7 +1427,11 @@ static int chtls_pt_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                if (copied >= target)
                        break;
                chtls_cleanup_rbuf(sk, copied);
-               sk_wait_data(sk, &timeo, NULL);
+               ret = sk_wait_data(sk, &timeo, NULL);
+               if (ret < 0) {
+                       copied = copied ? : ret;
+                       goto unlock;
+               }
                continue;
 found_ok_skb:
                if (!skb->len) {
@@ -1518,6 +1526,8 @@ skip_copy:
 
        if (buffers_freed)
                chtls_cleanup_rbuf(sk, copied);
+
+unlock:
        release_sock(sk);
        return copied;
 }
@@ -1534,6 +1544,7 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
        int copied = 0;
        size_t avail;          /* amount of available data in current skb */
        long timeo;
+       int ret;
 
        lock_sock(sk);
        timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
@@ -1585,7 +1596,12 @@ static int peekmsg(struct sock *sk, struct msghdr *msg,
                        release_sock(sk);
                        lock_sock(sk);
                } else {
-                       sk_wait_data(sk, &timeo, NULL);
+                       ret = sk_wait_data(sk, &timeo, NULL);
+                       if (ret < 0) {
+                               /* here 'copied' is 0 due to previous checks */
+                               copied = ret;
+                               break;
+                       }
                }
 
                if (unlikely(peek_seq != tp->copied_seq)) {
@@ -1656,6 +1672,7 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        int copied = 0;
        long timeo;
        int target;             /* Read at least this many bytes */
+       int ret;
 
        buffers_freed = 0;
 
@@ -1747,7 +1764,11 @@ int chtls_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                if (copied >= target)
                        break;
                chtls_cleanup_rbuf(sk, copied);
-               sk_wait_data(sk, &timeo, NULL);
+               ret = sk_wait_data(sk, &timeo, NULL);
+               if (ret < 0) {
+                       copied = copied ? : ret;
+                       goto unlock;
+               }
                continue;
 
 found_ok_skb:
@@ -1816,6 +1837,7 @@ skip_copy:
        if (buffers_freed)
                chtls_cleanup_rbuf(sk, copied);
 
+unlock:
        release_sock(sk);
        return copied;
 }
index b770261..92f7ea6 100644 (file)
@@ -336,7 +336,7 @@ struct sk_filter;
   *    @sk_cgrp_data: cgroup data for this cgroup
   *    @sk_memcg: this socket's memory cgroup association
   *    @sk_write_pending: a write to stream socket waits to start
-  *    @sk_wait_pending: number of threads blocked on this socket
+  *    @sk_disconnects: number of disconnect operations performed on this sock
   *    @sk_state_change: callback to indicate change in the state of the sock
   *    @sk_data_ready: callback to indicate there is data to be processed
   *    @sk_write_space: callback to indicate there is bf sending space available
@@ -429,7 +429,7 @@ struct sock {
        unsigned int            sk_napi_id;
 #endif
        int                     sk_rcvbuf;
-       int                     sk_wait_pending;
+       int                     sk_disconnects;
 
        struct sk_filter __rcu  *sk_filter;
        union {
@@ -1189,8 +1189,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
 }
 
 #define sk_wait_event(__sk, __timeo, __condition, __wait)              \
-       ({      int __rc;                                               \
-               __sk->sk_wait_pending++;                                \
+       ({      int __rc, __dis = __sk->sk_disconnects;                 \
                release_sock(__sk);                                     \
                __rc = __condition;                                     \
                if (!__rc) {                                            \
@@ -1200,8 +1199,7 @@ static inline void sock_rps_reset_rxhash(struct sock *sk)
                }                                                       \
                sched_annotate_sleep();                                 \
                lock_sock(__sk);                                        \
-               __sk->sk_wait_pending--;                                \
-               __rc = __condition;                                     \
+               __rc = __dis == __sk->sk_disconnects ? __condition : -EPIPE; \
                __rc;                                                   \
        })
 
index f5c4e47..96fbcb9 100644 (file)
@@ -117,7 +117,7 @@ EXPORT_SYMBOL(sk_stream_wait_close);
  */
 int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
 {
-       int err = 0;
+       int ret, err = 0;
        long vm_wait = 0;
        long current_timeo = *timeo_p;
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
@@ -142,11 +142,13 @@ int sk_stream_wait_memory(struct sock *sk, long *timeo_p)
 
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
                sk->sk_write_pending++;
-               sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
-                                                 (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
-                                                 (sk_stream_memory_free(sk) &&
-                                                 !vm_wait), &wait);
+               ret = sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) ||
+                                   (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) ||
+                                   (sk_stream_memory_free(sk) && !vm_wait),
+                                   &wait);
                sk->sk_write_pending--;
+               if (ret < 0)
+                       goto do_error;
 
                if (vm_wait) {
                        vm_wait -= current_timeo;
index 3d2e30e..2713c9b 100644 (file)
@@ -597,7 +597,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
 
        add_wait_queue(sk_sleep(sk), &wait);
        sk->sk_write_pending += writebias;
-       sk->sk_wait_pending++;
 
        /* Basic assumption: if someone sets sk->sk_err, he _must_
         * change state of the socket from TCP_SYN_*.
@@ -613,7 +612,6 @@ static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias)
        }
        remove_wait_queue(sk_sleep(sk), &wait);
        sk->sk_write_pending -= writebias;
-       sk->sk_wait_pending--;
        return timeo;
 }
 
@@ -642,6 +640,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                        return -EINVAL;
 
                if (uaddr->sa_family == AF_UNSPEC) {
+                       sk->sk_disconnects++;
                        err = sk->sk_prot->disconnect(sk, flags);
                        sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
                        goto out;
@@ -696,6 +695,7 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                int writebias = (sk->sk_protocol == IPPROTO_TCP) &&
                                tcp_sk(sk)->fastopen_req &&
                                tcp_sk(sk)->fastopen_req->data ? 1 : 0;
+               int dis = sk->sk_disconnects;
 
                /* Error code is set above */
                if (!timeo || !inet_wait_for_connect(sk, timeo, writebias))
@@ -704,6 +704,11 @@ int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                err = sock_intr_errno(timeo);
                if (signal_pending(current))
                        goto out;
+
+               if (dis != sk->sk_disconnects) {
+                       err = -EPIPE;
+                       goto out;
+               }
        }
 
        /* Connection was closed by RST, timeout, ICMP error
@@ -725,6 +730,7 @@ out:
 sock_error:
        err = sock_error(sk) ? : -ECONNABORTED;
        sock->state = SS_UNCONNECTED;
+       sk->sk_disconnects++;
        if (sk->sk_prot->disconnect(sk, flags))
                sock->state = SS_DISCONNECTING;
        goto out;
index aeebe88..394a498 100644 (file)
@@ -1145,7 +1145,6 @@ struct sock *inet_csk_clone_lock(const struct sock *sk,
        if (newsk) {
                struct inet_connection_sock *newicsk = inet_csk(newsk);
 
-               newsk->sk_wait_pending = 0;
                inet_sk_set_state(newsk, TCP_SYN_RECV);
                newicsk->icsk_bind_hash = NULL;
                newicsk->icsk_bind2_hash = NULL;
index 3f66cde..d3456cf 100644 (file)
@@ -831,7 +831,9 @@ ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos,
                         */
                        if (!skb_queue_empty(&sk->sk_receive_queue))
                                break;
-                       sk_wait_data(sk, &timeo, NULL);
+                       ret = sk_wait_data(sk, &timeo, NULL);
+                       if (ret < 0)
+                               break;
                        if (signal_pending(current)) {
                                ret = sock_intr_errno(timeo);
                                break;
@@ -2442,7 +2444,11 @@ static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len,
                        __sk_flush_backlog(sk);
                } else {
                        tcp_cleanup_rbuf(sk, copied);
-                       sk_wait_data(sk, &timeo, last);
+                       err = sk_wait_data(sk, &timeo, last);
+                       if (err < 0) {
+                               err = copied ? : err;
+                               goto out;
+                       }
                }
 
                if ((flags & MSG_PEEK) &&
@@ -2966,12 +2972,6 @@ int tcp_disconnect(struct sock *sk, int flags)
        int old_state = sk->sk_state;
        u32 seq;
 
-       /* Deny disconnect if other threads are blocked in sk_wait_event()
-        * or inet_wait_for_connect().
-        */
-       if (sk->sk_wait_pending)
-               return -EBUSY;
-
        if (old_state != TCP_CLOSE)
                tcp_set_state(sk, TCP_CLOSE);
 
index 3272682..ba2e921 100644 (file)
@@ -307,6 +307,8 @@ msg_bytes_ready:
                }
 
                data = tcp_msg_wait_data(sk, psock, timeo);
+               if (data < 0)
+                       return data;
                if (data && !sk_psock_queue_empty(psock))
                        goto msg_bytes_ready;
                copied = -EAGAIN;
@@ -351,6 +353,8 @@ msg_bytes_ready:
 
                timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
                data = tcp_msg_wait_data(sk, psock, timeo);
+               if (data < 0)
+                       return data;
                if (data) {
                        if (!sk_psock_queue_empty(psock))
                                goto msg_bytes_ready;
index c3b83cb..d190237 100644 (file)
@@ -3098,12 +3098,6 @@ static int mptcp_disconnect(struct sock *sk, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       /* Deny disconnect if other threads are blocked in sk_wait_event()
-        * or inet_wait_for_connect().
-        */
-       if (sk->sk_wait_pending)
-               return -EBUSY;
-
        /* We are on the fastopen error path. We can't call straight into the
         * subflows cleanup code due to lock nesting (we are already under
         * msk->firstsocket lock).
@@ -3173,7 +3167,6 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
                inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
 #endif
 
-       nsk->sk_wait_pending = 0;
        __mptcp_init_sock(nsk);
 
        msk = mptcp_sk(nsk);
index 02f583f..002483e 100644 (file)
@@ -139,8 +139,8 @@ void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 
 int wait_on_pending_writer(struct sock *sk, long *timeo)
 {
-       int rc = 0;
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
+       int ret, rc = 0;
 
        add_wait_queue(sk_sleep(sk), &wait);
        while (1) {
@@ -154,9 +154,13 @@ int wait_on_pending_writer(struct sock *sk, long *timeo)
                        break;
                }
 
-               if (sk_wait_event(sk, timeo,
-                                 !READ_ONCE(sk->sk_write_pending), &wait))
+               ret = sk_wait_event(sk, timeo,
+                                   !READ_ONCE(sk->sk_write_pending), &wait);
+               if (ret) {
+                       if (ret < 0)
+                               rc = ret;
                        break;
+               }
        }
        remove_wait_queue(sk_sleep(sk), &wait);
        return rc;
index d1fc295..e9d1e83 100644 (file)
@@ -1291,6 +1291,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
+       int ret = 0;
        long timeo;
 
        timeo = sock_rcvtimeo(sk, nonblock);
@@ -1302,6 +1303,9 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
                if (sk->sk_err)
                        return sock_error(sk);
 
+               if (ret < 0)
+                       return ret;
+
                if (!skb_queue_empty(&sk->sk_receive_queue)) {
                        tls_strp_check_rcv(&ctx->strp);
                        if (tls_strp_msg_ready(ctx))
@@ -1320,10 +1324,10 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
                released = true;
                add_wait_queue(sk_sleep(sk), &wait);
                sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-               sk_wait_event(sk, &timeo,
-                             tls_strp_msg_ready(ctx) ||
-                             !sk_psock_queue_empty(psock),
-                             &wait);
+               ret = sk_wait_event(sk, &timeo,
+                                   tls_strp_msg_ready(ctx) ||
+                                   !sk_psock_queue_empty(psock),
+                                   &wait);
                sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
                remove_wait_queue(sk_sleep(sk), &wait);
 
@@ -1852,6 +1856,7 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
                                 bool nonblock)
 {
        long timeo;
+       int ret;
 
        timeo = sock_rcvtimeo(sk, nonblock);
 
@@ -1861,14 +1866,16 @@ static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx,
                ctx->reader_contended = 1;
 
                add_wait_queue(&ctx->wq, &wait);
-               sk_wait_event(sk, &timeo,
-                             !READ_ONCE(ctx->reader_present), &wait);
+               ret = sk_wait_event(sk, &timeo,
+                                   !READ_ONCE(ctx->reader_present), &wait);
                remove_wait_queue(&ctx->wq, &wait);
 
                if (timeo <= 0)
                        return -EAGAIN;
                if (signal_pending(current))
                        return sock_intr_errno(timeo);
+               if (ret < 0)
+                       return ret;
        }
 
        WRITE_ONCE(ctx->reader_present, 1);