mptcp: fix possible list corruption on passive MPJ
[platform/kernel/linux-starfive.git] / net / mptcp / protocol.c
index f420600..a5b330b 100644 (file)
@@ -599,7 +599,7 @@ static bool mptcp_check_data_fin(struct sock *sk)
                WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1);
                WRITE_ONCE(msk->rcv_data_fin, 0);
 
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
 
                switch (sk->sk_state) {
@@ -821,6 +821,13 @@ void mptcp_data_ready(struct sock *sk, struct sock *ssk)
        mptcp_data_unlock(sk);
 }
 
+static void mptcp_subflow_joined(struct mptcp_sock *msk, struct sock *ssk)
+{
+       mptcp_subflow_ctx(ssk)->map_seq = READ_ONCE(msk->ack_seq);
+       WRITE_ONCE(msk->allow_infinite_fallback, false);
+       mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC);
+}
+
 static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
 {
        struct sock *sk = (struct sock *)msk;
@@ -835,15 +842,16 @@ static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
                mptcp_sock_graft(ssk, sk->sk_socket);
 
        mptcp_sockopt_sync_locked(msk, ssk);
+       mptcp_subflow_joined(msk, ssk);
        return true;
 }
 
-static void __mptcp_flush_join_list(struct sock *sk)
+static void __mptcp_flush_join_list(struct sock *sk, struct list_head *join_list)
 {
        struct mptcp_subflow_context *tmp, *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       list_for_each_entry_safe(subflow, tmp, &msk->join_list, node) {
+       list_for_each_entry_safe(subflow, tmp, join_list, node) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                bool slow = lock_sock_fast(ssk);
 
@@ -906,7 +914,7 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
                /* hopefully temporary hack: propagate shutdown status
                 * to msk, when all subflows agree on it
                 */
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
 
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
                sk->sk_data_ready(sk);
@@ -1683,7 +1691,6 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
 
        lock_sock(ssk);
        msg->msg_flags |= MSG_DONTWAIT;
-       msk->connect_flags = O_NONBLOCK;
        msk->fastopening = 1;
        ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL);
        msk->fastopening = 0;
@@ -1701,7 +1708,13 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
                if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR)
                        *copied_syn = 0;
        } else if (ret && ret != -EINPROGRESS) {
-               mptcp_disconnect(sk, 0);
+               /* The disconnect() op called by tcp_sendmsg_fastopen()/
+                * __inet_stream_connect() can fail, due to looking check,
+                * see mptcp_disconnect().
+                * Attempt it again outside the problematic scope.
+                */
+               if (!mptcp_disconnect(sk, 0))
+                       sk->sk_socket->state = SS_UNCONNECTED;
        }
 
        return ret;
@@ -2368,7 +2381,10 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
 
        need_push = (flags & MPTCP_CF_PUSH) && __mptcp_retransmit_pending_data(sk);
        if (!dispose_it) {
-               tcp_disconnect(ssk, 0);
+               /* The MPTCP code never wait on the subflow sockets, TCP-level
+                * disconnect should never fail
+                */
+               WARN_ON_ONCE(tcp_disconnect(ssk, 0));
                msk->subflow->state = SS_UNCONNECTED;
                mptcp_subflow_ctx_reset(subflow);
                release_sock(ssk);
@@ -2512,7 +2528,7 @@ static void mptcp_check_fastclose(struct mptcp_sock *msk)
        }
 
        inet_sk_state_store(sk, TCP_CLOSE);
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
        smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
        set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags);
 
@@ -2792,7 +2808,7 @@ void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
                        break;
                fallthrough;
        case TCP_SYN_SENT:
-               tcp_disconnect(ssk, O_NONBLOCK);
+               WARN_ON_ONCE(tcp_disconnect(ssk, O_NONBLOCK));
                break;
        default:
                if (__mptcp_check_fallback(mptcp_sk(sk))) {
@@ -2941,7 +2957,7 @@ bool __mptcp_close(struct sock *sk, long timeout)
        bool do_cancel_work = false;
        int subflows_alive = 0;
 
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
 
        if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) {
                inet_sk_state_store(sk, TCP_CLOSE);
@@ -3044,13 +3060,18 @@ static int mptcp_disconnect(struct sock *sk, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
+       /* Deny disconnect if other threads are blocked in sk_wait_event()
+        * or inet_wait_for_connect().
+        */
+       if (sk->sk_wait_pending)
+               return -EBUSY;
+
        /* We are on the fastopen error path. We can't call straight into the
         * subflows cleanup code due to lock nesting (we are already under
-        * msk->firstsocket lock). Do nothing and leave the cleanup to the
-        * caller.
+        * msk->firstsocket lock).
         */
        if (msk->fastopening)
-               return 0;
+               return -EBUSY;
 
        inet_sk_state_store(sk, TCP_CLOSE);
 
@@ -3079,7 +3100,7 @@ static int mptcp_disconnect(struct sock *sk, int flags)
        mptcp_pm_data_reset(msk);
        mptcp_ca_reset(sk);
 
-       sk->sk_shutdown = 0;
+       WRITE_ONCE(sk->sk_shutdown, 0);
        sk_error_report(sk);
        return 0;
 }
@@ -3111,6 +3132,7 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
                inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
 #endif
 
+       nsk->sk_wait_pending = 0;
        __mptcp_init_sock(nsk);
 
        msk = mptcp_sk(nsk);
@@ -3313,9 +3335,14 @@ static void mptcp_release_cb(struct sock *sk)
        for (;;) {
                unsigned long flags = (msk->cb_flags & MPTCP_FLAGS_PROCESS_CTX_NEED) |
                                      msk->push_pending;
+               struct list_head join_list;
+
                if (!flags)
                        break;
 
+               INIT_LIST_HEAD(&join_list);
+               list_splice_init(&msk->join_list, &join_list);
+
                /* the following actions acquire the subflow socket lock
                 *
                 * 1) can't be invoked in atomic scope
@@ -3326,8 +3353,9 @@ static void mptcp_release_cb(struct sock *sk)
                msk->push_pending = 0;
                msk->cb_flags &= ~flags;
                spin_unlock_bh(&sk->sk_lock.slock);
+
                if (flags & BIT(MPTCP_FLUSH_JOIN_LIST))
-                       __mptcp_flush_join_list(sk);
+                       __mptcp_flush_join_list(sk, &join_list);
                if (flags & BIT(MPTCP_PUSH_PENDING))
                        __mptcp_push_pending(sk, 0);
                if (flags & BIT(MPTCP_RETRANSMIT))
@@ -3486,14 +3514,16 @@ bool mptcp_finish_join(struct sock *ssk)
                return false;
        }
 
-       if (!list_empty(&subflow->node))
-               goto out;
+       /* active subflow, already present inside the conn_list */
+       if (!list_empty(&subflow->node)) {
+               mptcp_subflow_joined(msk, ssk);
+               return true;
+       }
 
        if (!mptcp_pm_allow_new_subflow(msk))
                goto err_prohibited;
 
-       /* active connections are already on conn_list.
-        * If we can't acquire msk socket lock here, let the release callback
+       /* If we can't acquire msk socket lock here, let the release callback
         * handle it
         */
        mptcp_data_lock(parent);
@@ -3516,11 +3546,6 @@ err_prohibited:
                return false;
        }
 
-       subflow->map_seq = READ_ONCE(msk->ack_seq);
-       WRITE_ONCE(msk->allow_infinite_fallback, false);
-
-out:
-       mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC);
        return true;
 }
 
@@ -3638,9 +3663,9 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
         * acquired the subflow socket lock, too.
         */
        if (msk->fastopening)
-               err = __inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags, 1);
+               err = __inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK, 1);
        else
-               err = inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags);
+               err = inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK);
        inet_sk(sk)->defer_connect = inet_sk(ssock->sk)->defer_connect;
 
        /* on successful connect, the msk state will be moved to established by
@@ -3653,12 +3678,10 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 
        mptcp_copy_inaddrs(sk, ssock->sk);
 
-       /* unblocking connect, mptcp-level inet_stream_connect will error out
-        * without changing the socket state, update it here.
+       /* silence EINPROGRESS and let the caller inet_stream_connect
+        * handle the connection in progress
         */
-       if (err == -EINPROGRESS)
-               sk->sk_socket->state = ssock->state;
-       return err;
+       return 0;
 }
 
 static struct proto mptcp_prot = {
@@ -3717,18 +3740,6 @@ unlock:
        return err;
 }
 
-static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
-                               int addr_len, int flags)
-{
-       int ret;
-
-       lock_sock(sock->sk);
-       mptcp_sk(sock->sk)->connect_flags = flags;
-       ret = __inet_stream_connect(sock, uaddr, addr_len, flags, 0);
-       release_sock(sock->sk);
-       return ret;
-}
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -3815,9 +3826,6 @@ static __poll_t mptcp_check_writeable(struct mptcp_sock *msk)
 {
        struct sock *sk = (struct sock *)msk;
 
-       if (unlikely(sk->sk_shutdown & SEND_SHUTDOWN))
-               return EPOLLOUT | EPOLLWRNORM;
-
        if (sk_stream_is_writeable(sk))
                return EPOLLOUT | EPOLLWRNORM;
 
@@ -3835,6 +3843,7 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        struct sock *sk = sock->sk;
        struct mptcp_sock *msk;
        __poll_t mask = 0;
+       u8 shutdown;
        int state;
 
        msk = mptcp_sk(sk);
@@ -3851,17 +3860,22 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
                return inet_csk_listen_poll(ssock->sk);
        }
 
+       shutdown = READ_ONCE(sk->sk_shutdown);
+       if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
+               mask |= EPOLLHUP;
+       if (shutdown & RCV_SHUTDOWN)
+               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
+
        if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
                mask |= mptcp_check_readable(msk);
-               mask |= mptcp_check_writeable(msk);
+               if (shutdown & SEND_SHUTDOWN)
+                       mask |= EPOLLOUT | EPOLLWRNORM;
+               else
+                       mask |= mptcp_check_writeable(msk);
        } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) {
                /* cf tcp_poll() note about TFO */
                mask |= EPOLLOUT | EPOLLWRNORM;
        }
-       if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
-               mask |= EPOLLHUP;
-       if (sk->sk_shutdown & RCV_SHUTDOWN)
-               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
 
        /* This barrier is coupled with smp_wmb() in __mptcp_error_report() */
        smp_rmb();
@@ -3876,7 +3890,7 @@ static const struct proto_ops mptcp_stream_ops = {
        .owner             = THIS_MODULE,
        .release           = inet_release,
        .bind              = mptcp_bind,
-       .connect           = mptcp_stream_connect,
+       .connect           = inet_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
        .getname           = inet_getname,
@@ -3971,7 +3985,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .owner             = THIS_MODULE,
        .release           = inet6_release,
        .bind              = mptcp_bind,
-       .connect           = mptcp_stream_connect,
+       .connect           = inet_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
        .getname           = inet6_getname,