mptcp: fix possible list corruption on passive MPJ
[platform/kernel/linux-starfive.git] / net / mptcp / protocol.c
index b6dc6e2..a5b330b 100644 (file)
@@ -107,12 +107,12 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)
        struct socket *ssock;
        int err;
 
-       err = mptcp_subflow_create_socket(sk, &ssock);
+       err = mptcp_subflow_create_socket(sk, sk->sk_family, &ssock);
        if (err)
                return err;
 
-       msk->first = ssock->sk;
-       msk->subflow = ssock;
+       WRITE_ONCE(msk->first, ssock->sk);
+       WRITE_ONCE(msk->subflow, ssock);
        subflow = mptcp_subflow_ctx(ssock->sk);
        list_add(&subflow->node, &msk->conn_list);
        sock_hold(ssock->sk);
@@ -599,7 +599,7 @@ static bool mptcp_check_data_fin(struct sock *sk)
                WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1);
                WRITE_ONCE(msk->rcv_data_fin, 0);
 
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
 
                switch (sk->sk_state) {
@@ -821,6 +821,13 @@ void mptcp_data_ready(struct sock *sk, struct sock *ssk)
        mptcp_data_unlock(sk);
 }
 
+static void mptcp_subflow_joined(struct mptcp_sock *msk, struct sock *ssk)
+{
+       mptcp_subflow_ctx(ssk)->map_seq = READ_ONCE(msk->ack_seq);
+       WRITE_ONCE(msk->allow_infinite_fallback, false);
+       mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC);
+}
+
 static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
 {
        struct sock *sk = (struct sock *)msk;
@@ -834,17 +841,17 @@ static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
        if (sk->sk_socket && !ssk->sk_socket)
                mptcp_sock_graft(ssk, sk->sk_socket);
 
-       mptcp_propagate_sndbuf((struct sock *)msk, ssk);
        mptcp_sockopt_sync_locked(msk, ssk);
+       mptcp_subflow_joined(msk, ssk);
        return true;
 }
 
-static void __mptcp_flush_join_list(struct sock *sk)
+static void __mptcp_flush_join_list(struct sock *sk, struct list_head *join_list)
 {
        struct mptcp_subflow_context *tmp, *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       list_for_each_entry_safe(subflow, tmp, &msk->join_list, node) {
+       list_for_each_entry_safe(subflow, tmp, join_list, node) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                bool slow = lock_sock_fast(ssk);
 
@@ -907,7 +914,7 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
                /* hopefully temporary hack: propagate shutdown status
                 * to msk, when all subflows agree on it
                 */
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
 
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
                sk->sk_data_ready(sk);
@@ -1673,6 +1680,8 @@ static void mptcp_set_nospace(struct sock *sk)
        set_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags);
 }
 
+static int mptcp_disconnect(struct sock *sk, int flags);
+
 static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msghdr *msg,
                                  size_t len, int *copied_syn)
 {
@@ -1682,10 +1691,9 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
 
        lock_sock(ssk);
        msg->msg_flags |= MSG_DONTWAIT;
-       msk->connect_flags = O_NONBLOCK;
-       msk->is_sendmsg = 1;
+       msk->fastopening = 1;
        ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL);
-       msk->is_sendmsg = 0;
+       msk->fastopening = 0;
        msg->msg_flags = saved_flags;
        release_sock(ssk);
 
@@ -1699,6 +1707,14 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
                 */
                if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR)
                        *copied_syn = 0;
+       } else if (ret && ret != -EINPROGRESS) {
+               /* The disconnect() op called by tcp_sendmsg_fastopen()/
+                * __inet_stream_connect() can fail, due to looking check,
+                * see mptcp_disconnect().
+                * Attempt it again outside the problematic scope.
+                */
+               if (!mptcp_disconnect(sk, 0))
+                       sk->sk_socket->state = SS_UNCONNECTED;
        }
 
        return ret;
@@ -2266,7 +2282,7 @@ static void mptcp_dispose_initial_subflow(struct mptcp_sock *msk)
 {
        if (msk->subflow) {
                iput(SOCK_INODE(msk->subflow));
-               msk->subflow = NULL;
+               WRITE_ONCE(msk->subflow, NULL);
        }
 }
 
@@ -2327,7 +2343,26 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
                              unsigned int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
-       bool need_push, dispose_it;
+       bool dispose_it, need_push = false;
+
+       /* If the first subflow moved to a close state before accept, e.g. due
+        * to an incoming reset, mptcp either:
+        * - if either the subflow or the msk are dead, destroy the context
+        *   (the subflow socket is deleted by inet_child_forget) and the msk
+        * - otherwise do nothing at the moment and take action at accept and/or
+        *   listener shutdown - user-space must be able to accept() the closed
+        *   socket.
+        */
+       if (msk->in_accept_queue && msk->first == ssk) {
+               if (!sock_flag(sk, SOCK_DEAD) && !sock_flag(ssk, SOCK_DEAD))
+                       return;
+
+               /* ensure later check in mptcp_worker() will dispose the msk */
+               sock_set_flag(sk, SOCK_DEAD);
+               lock_sock_nested(ssk, SINGLE_DEPTH_NESTING);
+               mptcp_subflow_drop_ctx(ssk);
+               goto out_release;
+       }
 
        dispose_it = !msk->subflow || ssk != msk->subflow->sk;
        if (dispose_it)
@@ -2346,7 +2381,10 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
 
        need_push = (flags & MPTCP_CF_PUSH) && __mptcp_retransmit_pending_data(sk);
        if (!dispose_it) {
-               tcp_disconnect(ssk, 0);
+               /* The MPTCP code never wait on the subflow sockets, TCP-level
+                * disconnect should never fail
+                */
+               WARN_ON_ONCE(tcp_disconnect(ssk, 0));
                msk->subflow->state = SS_UNCONNECTED;
                mptcp_subflow_ctx_reset(subflow);
                release_sock(ssk);
@@ -2354,12 +2392,6 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
                goto out;
        }
 
-       /* if we are invoked by the msk cleanup code, the subflow is
-        * already orphaned
-        */
-       if (ssk->sk_socket)
-               sock_orphan(ssk);
-
        subflow->disposable = 1;
 
        /* if ssk hit tcp_done(), tcp_cleanup_ulp() cleared the related ops
@@ -2367,25 +2399,29 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
         * reference owned by msk;
         */
        if (!inet_csk(ssk)->icsk_ulp_ops) {
+               WARN_ON_ONCE(!sock_flag(ssk, SOCK_DEAD));
                kfree_rcu(subflow, rcu);
        } else {
                /* otherwise tcp will dispose of the ssk and subflow ctx */
                if (ssk->sk_state == TCP_LISTEN) {
                        tcp_set_state(ssk, TCP_CLOSE);
-                       mptcp_subflow_queue_clean(ssk);
+                       mptcp_subflow_queue_clean(sk, ssk);
                        inet_csk_listen_stop(ssk);
                }
+
                __tcp_close(ssk, 0);
 
                /* close acquired an extra ref */
                __sock_put(ssk);
        }
+
+out_release:
        release_sock(ssk);
 
        sock_put(ssk);
 
        if (ssk == msk->first)
-               msk->first = NULL;
+               WRITE_ONCE(msk->first, NULL);
 
 out:
        if (ssk == msk->last_snd)
@@ -2414,9 +2450,10 @@ static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu)
        return 0;
 }
 
-static void __mptcp_close_subflow(struct mptcp_sock *msk)
+static void __mptcp_close_subflow(struct sock *sk)
 {
        struct mptcp_subflow_context *subflow, *tmp;
+       struct mptcp_sock *msk = mptcp_sk(sk);
 
        might_sleep();
 
@@ -2430,16 +2467,17 @@ static void __mptcp_close_subflow(struct mptcp_sock *msk)
                if (!skb_queue_empty_lockless(&ssk->sk_receive_queue))
                        continue;
 
-               mptcp_close_ssk((struct sock *)msk, ssk, subflow);
+               mptcp_close_ssk(sk, ssk, subflow);
        }
+
 }
 
-static bool mptcp_check_close_timeout(const struct sock *sk)
+static bool mptcp_should_close(const struct sock *sk)
 {
        s32 delta = tcp_jiffies32 - inet_csk(sk)->icsk_mtup.probe_timestamp;
        struct mptcp_subflow_context *subflow;
 
-       if (delta >= TCP_TIMEWAIT_LEN)
+       if (delta >= TCP_TIMEWAIT_LEN || mptcp_sk(sk)->in_accept_queue)
                return true;
 
        /* if all subflows are in closed status don't bother with additional
@@ -2490,7 +2528,7 @@ static void mptcp_check_fastclose(struct mptcp_sock *msk)
        }
 
        inet_sk_state_store(sk, TCP_CLOSE);
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
        smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
        set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags);
 
@@ -2624,7 +2662,7 @@ static void mptcp_worker(struct work_struct *work)
 
        lock_sock(sk);
        state = sk->sk_state;
-       if (unlikely(state == TCP_CLOSE))
+       if (unlikely((1 << state) & (TCPF_CLOSE | TCPF_LISTEN)))
                goto unlock;
 
        mptcp_check_data_fin_ack(sk);
@@ -2639,12 +2677,15 @@ static void mptcp_worker(struct work_struct *work)
        __mptcp_check_send_data_fin(sk);
        mptcp_check_data_fin(sk);
 
+       if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags))
+               __mptcp_close_subflow(sk);
+
        /* There is no point in keeping around an orphaned sk timedout or
         * closed, but we need the msk around to reply to incoming DATA_FIN,
         * even if it is orphaned and in FIN_WAIT2 state
         */
        if (sock_flag(sk, SOCK_DEAD)) {
-               if (mptcp_check_close_timeout(sk)) {
+               if (mptcp_should_close(sk)) {
                        inet_sk_state_store(sk, TCP_CLOSE);
                        mptcp_do_fastclose(sk);
                }
@@ -2654,9 +2695,6 @@ static void mptcp_worker(struct work_struct *work)
                }
        }
 
-       if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags))
-               __mptcp_close_subflow(msk);
-
        if (test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags))
                __mptcp_retrans(sk);
 
@@ -2684,7 +2722,7 @@ static int __mptcp_init_sock(struct sock *sk)
        WRITE_ONCE(msk->rmem_released, 0);
        msk->timer_ival = TCP_RTO_MIN;
 
-       msk->first = NULL;
+       WRITE_ONCE(msk->first, NULL);
        inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss;
        WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk)));
        WRITE_ONCE(msk->allow_infinite_fallback, true);
@@ -2770,7 +2808,7 @@ void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
                        break;
                fallthrough;
        case TCP_SYN_SENT:
-               tcp_disconnect(ssk, O_NONBLOCK);
+               WARN_ON_ONCE(tcp_disconnect(ssk, O_NONBLOCK));
                break;
        default:
                if (__mptcp_check_fallback(mptcp_sk(sk))) {
@@ -2892,6 +2930,14 @@ static void __mptcp_destroy_sock(struct sock *sk)
        sock_put(sk);
 }
 
+void __mptcp_unaccepted_force_close(struct sock *sk)
+{
+       sock_set_flag(sk, SOCK_DEAD);
+       inet_sk_state_store(sk, TCP_CLOSE);
+       mptcp_do_fastclose(sk);
+       __mptcp_destroy_sock(sk);
+}
+
 static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
 {
        /* Concurrent splices from sk_receive_queue into receive_queue will
@@ -2909,8 +2955,9 @@ bool __mptcp_close(struct sock *sk, long timeout)
        struct mptcp_subflow_context *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
        bool do_cancel_work = false;
+       int subflows_alive = 0;
 
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
 
        if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) {
                inet_sk_state_store(sk, TCP_CLOSE);
@@ -2934,17 +2981,29 @@ cleanup:
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                bool slow = lock_sock_fast_nested(ssk);
 
+               subflows_alive += ssk->sk_state != TCP_CLOSE;
+
                /* since the close timeout takes precedence on the fail one,
                 * cancel the latter
                 */
                if (ssk == msk->first)
                        subflow->fail_tout = 0;
 
-               sock_orphan(ssk);
+               /* detach from the parent socket, but allow data_ready to
+                * push incoming data into the mptcp stack, to properly ack it
+                */
+               ssk->sk_socket = NULL;
+               ssk->sk_wq = NULL;
                unlock_sock_fast(ssk, slow);
        }
        sock_orphan(sk);
 
+       /* all the subflows are closed, only timeout can change the msk
+        * state, let's not keep resources busy for no reasons
+        */
+       if (subflows_alive == 0)
+               inet_sk_state_store(sk, TCP_CLOSE);
+
        sock_hold(sk);
        pr_debug("msk=%p state=%d", sk, sk->sk_state);
        if (mptcp_sk(sk)->token)
@@ -2974,7 +3033,7 @@ static void mptcp_close(struct sock *sk, long timeout)
        sock_put(sk);
 }
 
-void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
+static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
 {
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
        const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
@@ -3001,6 +3060,19 @@ static int mptcp_disconnect(struct sock *sk, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
+       /* Deny disconnect if other threads are blocked in sk_wait_event()
+        * or inet_wait_for_connect().
+        */
+       if (sk->sk_wait_pending)
+               return -EBUSY;
+
+       /* We are on the fastopen error path. We can't call straight into the
+        * subflows cleanup code due to lock nesting (we are already under
+        * msk->firstsocket lock).
+        */
+       if (msk->fastopening)
+               return -EBUSY;
+
        inet_sk_state_store(sk, TCP_CLOSE);
 
        mptcp_stop_timer(sk);
@@ -3028,7 +3100,7 @@ static int mptcp_disconnect(struct sock *sk, int flags)
        mptcp_pm_data_reset(msk);
        mptcp_ca_reset(sk);
 
-       sk->sk_shutdown = 0;
+       WRITE_ONCE(sk->sk_shutdown, 0);
        sk_error_report(sk);
        return 0;
 }
@@ -3042,9 +3114,10 @@ static struct ipv6_pinfo *mptcp_inet6_sk(const struct sock *sk)
 }
 #endif
 
-struct sock *mptcp_sk_clone(const struct sock *sk,
-                           const struct mptcp_options_received *mp_opt,
-                           struct request_sock *req)
+struct sock *mptcp_sk_clone_init(const struct sock *sk,
+                                const struct mptcp_options_received *mp_opt,
+                                struct sock *ssk,
+                                struct request_sock *req)
 {
        struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
        struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC);
@@ -3059,12 +3132,14 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
                inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
 #endif
 
+       nsk->sk_wait_pending = 0;
        __mptcp_init_sock(nsk);
 
        msk = mptcp_sk(nsk);
        msk->local_key = subflow_req->local_key;
        msk->token = subflow_req->token;
-       msk->subflow = NULL;
+       WRITE_ONCE(msk->subflow, NULL);
+       msk->in_accept_queue = 1;
        WRITE_ONCE(msk->fully_established, false);
        if (mp_opt->suboptions & OPTION_MPTCP_CSUMREQD)
                WRITE_ONCE(msk->csum_enabled, true);
@@ -3085,14 +3160,33 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
        }
 
        sock_reset_flag(nsk, SOCK_RCU_FREE);
-       /* will be fully established after successful MPC subflow creation */
-       inet_sk_state_store(nsk, TCP_SYN_RECV);
-
        security_inet_csk_clone(nsk, req);
+
+       /* this can't race with mptcp_close(), as the msk is
+        * not yet exposted to user-space
+        */
+       inet_sk_state_store(nsk, TCP_ESTABLISHED);
+
+       /* The msk maintain a ref to each subflow in the connections list */
+       WRITE_ONCE(msk->first, ssk);
+       list_add(&mptcp_subflow_ctx(ssk)->node, &msk->conn_list);
+       sock_hold(ssk);
+
+       /* new mpc subflow takes ownership of the newly
+        * created mptcp socket
+        */
+       mptcp_token_accept(subflow_req, msk);
+
+       /* set msk addresses early to ensure mptcp_pm_get_local_id()
+        * uses the correct data
+        */
+       mptcp_copy_inaddrs(nsk, ssk);
+       mptcp_propagate_sndbuf(nsk, ssk);
+
+       mptcp_rcv_space_init(msk, ssk);
        bh_unlock_sock(nsk);
 
-       /* keep a single reference */
-       __sock_put(nsk);
+       /* note: the newly allocated socket refcount is 2 now */
        return nsk;
 }
 
@@ -3121,7 +3215,7 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
        struct socket *listener;
        struct sock *newsk;
 
-       listener = __mptcp_nmpc_socket(msk);
+       listener = READ_ONCE(msk->subflow);
        if (WARN_ON_ONCE(!listener)) {
                *err = -EINVAL;
                return NULL;
@@ -3148,8 +3242,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
                        goto out;
                }
 
-               /* acquire the 2nd reference for the owning socket */
-               sock_hold(new_mptcp_sock);
                newsk = new_mptcp_sock;
                MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
        } else {
@@ -3243,9 +3335,14 @@ static void mptcp_release_cb(struct sock *sk)
        for (;;) {
                unsigned long flags = (msk->cb_flags & MPTCP_FLAGS_PROCESS_CTX_NEED) |
                                      msk->push_pending;
+               struct list_head join_list;
+
                if (!flags)
                        break;
 
+               INIT_LIST_HEAD(&join_list);
+               list_splice_init(&msk->join_list, &join_list);
+
                /* the following actions acquire the subflow socket lock
                 *
                 * 1) can't be invoked in atomic scope
@@ -3256,8 +3353,9 @@ static void mptcp_release_cb(struct sock *sk)
                msk->push_pending = 0;
                msk->cb_flags &= ~flags;
                spin_unlock_bh(&sk->sk_lock.slock);
+
                if (flags & BIT(MPTCP_FLUSH_JOIN_LIST))
-                       __mptcp_flush_join_list(sk);
+                       __mptcp_flush_join_list(sk, &join_list);
                if (flags & BIT(MPTCP_PUSH_PENDING))
                        __mptcp_push_pending(sk, 0);
                if (flags & BIT(MPTCP_RETRANSMIT))
@@ -3349,7 +3447,7 @@ static int mptcp_get_port(struct sock *sk, unsigned short snum)
        struct mptcp_sock *msk = mptcp_sk(sk);
        struct socket *ssock;
 
-       ssock = __mptcp_nmpc_socket(msk);
+       ssock = msk->subflow;
        pr_debug("msk=%p, subflow=%p", msk, ssock);
        if (WARN_ON_ONCE(!ssock))
                return -EINVAL;
@@ -3416,14 +3514,16 @@ bool mptcp_finish_join(struct sock *ssk)
                return false;
        }
 
-       if (!list_empty(&subflow->node))
-               goto out;
+       /* active subflow, already present inside the conn_list */
+       if (!list_empty(&subflow->node)) {
+               mptcp_subflow_joined(msk, ssk);
+               return true;
+       }
 
        if (!mptcp_pm_allow_new_subflow(msk))
                goto err_prohibited;
 
-       /* active connections are already on conn_list.
-        * If we can't acquire msk socket lock here, let the release callback
+       /* If we can't acquire msk socket lock here, let the release callback
         * handle it
         */
        mptcp_data_lock(parent);
@@ -3446,11 +3546,6 @@ err_prohibited:
                return false;
        }
 
-       subflow->map_seq = READ_ONCE(msk->ack_seq);
-       WRITE_ONCE(msk->allow_infinite_fallback, false);
-
-out:
-       mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC);
        return true;
 }
 
@@ -3567,10 +3662,10 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        /* if reaching here via the fastopen/sendmsg path, the caller already
         * acquired the subflow socket lock, too.
         */
-       if (msk->is_sendmsg)
-               err = __inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags, 1);
+       if (msk->fastopening)
+               err = __inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK, 1);
        else
-               err = inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags);
+               err = inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK);
        inet_sk(sk)->defer_connect = inet_sk(ssock->sk)->defer_connect;
 
        /* on successful connect, the msk state will be moved to established by
@@ -3583,12 +3678,10 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 
        mptcp_copy_inaddrs(sk, ssock->sk);
 
-       /* unblocking connect, mptcp-level inet_stream_connect will error out
-        * without changing the socket state, update it here.
+       /* silence EINPROGRESS and let the caller inet_stream_connect
+        * handle the connection in progress
         */
-       if (err == -EINPROGRESS)
-               sk->sk_socket->state = ssock->state;
-       return err;
+       return 0;
 }
 
 static struct proto mptcp_prot = {
@@ -3647,18 +3740,6 @@ unlock:
        return err;
 }
 
-static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
-                               int addr_len, int flags)
-{
-       int ret;
-
-       lock_sock(sock->sk);
-       mptcp_sk(sock->sk)->connect_flags = flags;
-       ret = __inet_stream_connect(sock, uaddr, addr_len, flags, 0);
-       release_sock(sock->sk);
-       return ret;
-}
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -3697,7 +3778,10 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 
        pr_debug("msk=%p", msk);
 
-       ssock = __mptcp_nmpc_socket(msk);
+       /* Buggy applications can call accept on socket states other then LISTEN
+        * but no need to allocate the first subflow just to error out.
+        */
+       ssock = READ_ONCE(msk->subflow);
        if (!ssock)
                return -EINVAL;
 
@@ -3707,23 +3791,9 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                struct mptcp_subflow_context *subflow;
                struct sock *newsk = newsock->sk;
 
-               lock_sock(newsk);
-
-               /* PM/worker can now acquire the first subflow socket
-                * lock without racing with listener queue cleanup,
-                * we can notify it, if needed.
-                *
-                * Even if remote has reset the initial subflow by now
-                * the refcnt is still at least one.
-                */
-               subflow = mptcp_subflow_ctx(msk->first);
-               list_add(&subflow->node, &msk->conn_list);
-               sock_hold(msk->first);
-               if (mptcp_is_fully_established(newsk))
-                       mptcp_pm_fully_established(msk, msk->first, GFP_KERNEL);
+               msk->in_accept_queue = 0;
 
-               mptcp_rcv_space_init(msk, msk->first);
-               mptcp_propagate_sndbuf(newsk, msk->first);
+               lock_sock(newsk);
 
                /* set ssk->sk_socket of accept()ed flows to mptcp socket.
                 * This is needed so NOSPACE flag can be set from tcp stack.
@@ -3734,6 +3804,18 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                        if (!ssk->sk_socket)
                                mptcp_sock_graft(ssk, newsock);
                }
+
+               /* Do late cleanup for the first subflow as necessary. Also
+                * deal with bad peers not doing a complete shutdown.
+                */
+               if (msk->first &&
+                   unlikely(inet_sk_state_load(msk->first) == TCP_CLOSE)) {
+                       __mptcp_close_ssk(newsk, msk->first,
+                                         mptcp_subflow_ctx(msk->first), 0);
+                       if (unlikely(list_empty(&msk->conn_list)))
+                               inet_sk_state_store(newsk, TCP_CLOSE);
+               }
+
                release_sock(newsk);
        }
 
@@ -3744,9 +3826,6 @@ static __poll_t mptcp_check_writeable(struct mptcp_sock *msk)
 {
        struct sock *sk = (struct sock *)msk;
 
-       if (unlikely(sk->sk_shutdown & SEND_SHUTDOWN))
-               return EPOLLOUT | EPOLLWRNORM;
-
        if (sk_stream_is_writeable(sk))
                return EPOLLOUT | EPOLLWRNORM;
 
@@ -3764,6 +3843,7 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        struct sock *sk = sock->sk;
        struct mptcp_sock *msk;
        __poll_t mask = 0;
+       u8 shutdown;
        int state;
 
        msk = mptcp_sk(sk);
@@ -3772,23 +3852,30 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        state = inet_sk_state_load(sk);
        pr_debug("msk=%p state=%d flags=%lx", msk, state, msk->flags);
        if (state == TCP_LISTEN) {
-               if (WARN_ON_ONCE(!msk->subflow || !msk->subflow->sk))
+               struct socket *ssock = READ_ONCE(msk->subflow);
+
+               if (WARN_ON_ONCE(!ssock || !ssock->sk))
                        return 0;
 
-               return inet_csk_listen_poll(msk->subflow->sk);
+               return inet_csk_listen_poll(ssock->sk);
        }
 
+       shutdown = READ_ONCE(sk->sk_shutdown);
+       if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
+               mask |= EPOLLHUP;
+       if (shutdown & RCV_SHUTDOWN)
+               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
+
        if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
                mask |= mptcp_check_readable(msk);
-               mask |= mptcp_check_writeable(msk);
+               if (shutdown & SEND_SHUTDOWN)
+                       mask |= EPOLLOUT | EPOLLWRNORM;
+               else
+                       mask |= mptcp_check_writeable(msk);
        } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) {
                /* cf tcp_poll() note about TFO */
                mask |= EPOLLOUT | EPOLLWRNORM;
        }
-       if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
-               mask |= EPOLLHUP;
-       if (sk->sk_shutdown & RCV_SHUTDOWN)
-               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
 
        /* This barrier is coupled with smp_wmb() in __mptcp_error_report() */
        smp_rmb();
@@ -3803,7 +3890,7 @@ static const struct proto_ops mptcp_stream_ops = {
        .owner             = THIS_MODULE,
        .release           = inet_release,
        .bind              = mptcp_bind,
-       .connect           = mptcp_stream_connect,
+       .connect           = inet_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
        .getname           = inet_getname,
@@ -3898,7 +3985,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .owner             = THIS_MODULE,
        .release           = inet6_release,
        .bind              = mptcp_bind,
-       .connect           = mptcp_stream_connect,
+       .connect           = inet_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
        .getname           = inet6_getname,
@@ -3920,12 +4007,6 @@ static const struct proto_ops mptcp_v6_stream_ops = {
 
 static struct proto mptcp_v6_prot;
 
-static void mptcp_v6_destroy(struct sock *sk)
-{
-       mptcp_destroy(sk);
-       inet6_destroy_sock(sk);
-}
-
 static struct inet_protosw mptcp_v6_protosw = {
        .type           = SOCK_STREAM,
        .protocol       = IPPROTO_MPTCP,
@@ -3941,7 +4022,6 @@ int __init mptcp_proto_v6_init(void)
        mptcp_v6_prot = mptcp_prot;
        strcpy(mptcp_v6_prot.name, "MPTCPv6");
        mptcp_v6_prot.slab = NULL;
-       mptcp_v6_prot.destroy = mptcp_v6_destroy;
        mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock);
 
        err = proto_register(&mptcp_v6_prot, 1);