xsk: Honor SO_BINDTODEVICE on bind
[platform/kernel/linux-starfive.git] / net / mptcp / protocol.c
index b6dc6e2..4ca61e8 100644 (file)
@@ -53,7 +53,7 @@ enum {
 static struct percpu_counter mptcp_sockets_allocated ____cacheline_aligned_in_smp;
 
 static void __mptcp_destroy_sock(struct sock *sk);
-static void __mptcp_check_send_data_fin(struct sock *sk);
+static void mptcp_check_send_data_fin(struct sock *sk);
 
 DEFINE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions);
 static struct net_device mptcp_napi_dev;
@@ -107,12 +107,12 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)
        struct socket *ssock;
        int err;
 
-       err = mptcp_subflow_create_socket(sk, &ssock);
+       err = mptcp_subflow_create_socket(sk, sk->sk_family, &ssock);
        if (err)
                return err;
 
-       msk->first = ssock->sk;
-       msk->subflow = ssock;
+       WRITE_ONCE(msk->first, ssock->sk);
+       WRITE_ONCE(msk->subflow, ssock);
        subflow = mptcp_subflow_ctx(ssock->sk);
        list_add(&subflow->node, &msk->conn_list);
        sock_hold(ssock->sk);
@@ -420,8 +420,7 @@ static bool mptcp_pending_data_fin_ack(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       return !__mptcp_check_fallback(msk) &&
-              ((1 << sk->sk_state) &
+       return ((1 << sk->sk_state) &
                (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK)) &&
               msk->write_seq == READ_ONCE(msk->snd_una);
 }
@@ -579,9 +578,6 @@ static bool mptcp_check_data_fin(struct sock *sk)
        u64 rcv_data_fin_seq;
        bool ret = false;
 
-       if (__mptcp_check_fallback(msk))
-               return ret;
-
        /* Need to ack a DATA_FIN received from a peer while this side
         * of the connection is in ESTABLISHED, FIN_WAIT1, or FIN_WAIT2.
         * msk->rcv_data_fin was set when parsing the incoming options
@@ -599,7 +595,7 @@ static bool mptcp_check_data_fin(struct sock *sk)
                WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1);
                WRITE_ONCE(msk->rcv_data_fin, 0);
 
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
 
                switch (sk->sk_state) {
@@ -619,7 +615,8 @@ static bool mptcp_check_data_fin(struct sock *sk)
                }
 
                ret = true;
-               mptcp_send_ack(msk);
+               if (!__mptcp_check_fallback(msk))
+                       mptcp_send_ack(msk);
                mptcp_close_wake_up(sk);
        }
        return ret;
@@ -821,6 +818,13 @@ void mptcp_data_ready(struct sock *sk, struct sock *ssk)
        mptcp_data_unlock(sk);
 }
 
+static void mptcp_subflow_joined(struct mptcp_sock *msk, struct sock *ssk)
+{
+       mptcp_subflow_ctx(ssk)->map_seq = READ_ONCE(msk->ack_seq);
+       WRITE_ONCE(msk->allow_infinite_fallback, false);
+       mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC);
+}
+
 static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
 {
        struct sock *sk = (struct sock *)msk;
@@ -834,17 +838,17 @@ static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
        if (sk->sk_socket && !ssk->sk_socket)
                mptcp_sock_graft(ssk, sk->sk_socket);
 
-       mptcp_propagate_sndbuf((struct sock *)msk, ssk);
        mptcp_sockopt_sync_locked(msk, ssk);
+       mptcp_subflow_joined(msk, ssk);
        return true;
 }
 
-static void __mptcp_flush_join_list(struct sock *sk)
+static void __mptcp_flush_join_list(struct sock *sk, struct list_head *join_list)
 {
        struct mptcp_subflow_context *tmp, *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       list_for_each_entry_safe(subflow, tmp, &msk->join_list, node) {
+       list_for_each_entry_safe(subflow, tmp, join_list, node) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                bool slow = lock_sock_fast(ssk);
 
@@ -907,7 +911,7 @@ static void mptcp_check_for_eof(struct mptcp_sock *msk)
                /* hopefully temporary hack: propagate shutdown status
                 * to msk, when all subflows agree on it
                 */
-               sk->sk_shutdown |= RCV_SHUTDOWN;
+               WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN);
 
                smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
                sk->sk_data_ready(sk);
@@ -1599,7 +1603,7 @@ out:
        if (!mptcp_timer_pending(sk))
                mptcp_reset_timer(sk);
        if (do_check_data_fin)
-               __mptcp_check_send_data_fin(sk);
+               mptcp_check_send_data_fin(sk);
 }
 
 static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk)
@@ -1673,6 +1677,8 @@ static void mptcp_set_nospace(struct sock *sk)
        set_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags);
 }
 
+static int mptcp_disconnect(struct sock *sk, int flags);
+
 static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msghdr *msg,
                                  size_t len, int *copied_syn)
 {
@@ -1682,10 +1688,9 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
 
        lock_sock(ssk);
        msg->msg_flags |= MSG_DONTWAIT;
-       msk->connect_flags = O_NONBLOCK;
-       msk->is_sendmsg = 1;
+       msk->fastopening = 1;
        ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL);
-       msk->is_sendmsg = 0;
+       msk->fastopening = 0;
        msg->msg_flags = saved_flags;
        release_sock(ssk);
 
@@ -1699,6 +1704,14 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
                 */
                if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR)
                        *copied_syn = 0;
+       } else if (ret && ret != -EINPROGRESS) {
+               /* The disconnect() op called by tcp_sendmsg_fastopen()/
+                * __inet_stream_connect() can fail, due to looking check,
+                * see mptcp_disconnect().
+                * Attempt it again outside the problematic scope.
+                */
+               if (!mptcp_disconnect(sk, 0))
+                       sk->sk_socket->state = SS_UNCONNECTED;
        }
 
        return ret;
@@ -2266,7 +2279,7 @@ static void mptcp_dispose_initial_subflow(struct mptcp_sock *msk)
 {
        if (msk->subflow) {
                iput(SOCK_INODE(msk->subflow));
-               msk->subflow = NULL;
+               WRITE_ONCE(msk->subflow, NULL);
        }
 }
 
@@ -2327,7 +2340,26 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
                              unsigned int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
-       bool need_push, dispose_it;
+       bool dispose_it, need_push = false;
+
+       /* If the first subflow moved to a close state before accept, e.g. due
+        * to an incoming reset, mptcp either:
+        * - if either the subflow or the msk are dead, destroy the context
+        *   (the subflow socket is deleted by inet_child_forget) and the msk
+        * - otherwise do nothing at the moment and take action at accept and/or
+        *   listener shutdown - user-space must be able to accept() the closed
+        *   socket.
+        */
+       if (msk->in_accept_queue && msk->first == ssk) {
+               if (!sock_flag(sk, SOCK_DEAD) && !sock_flag(ssk, SOCK_DEAD))
+                       return;
+
+               /* ensure later check in mptcp_worker() will dispose the msk */
+               sock_set_flag(sk, SOCK_DEAD);
+               lock_sock_nested(ssk, SINGLE_DEPTH_NESTING);
+               mptcp_subflow_drop_ctx(ssk);
+               goto out_release;
+       }
 
        dispose_it = !msk->subflow || ssk != msk->subflow->sk;
        if (dispose_it)
@@ -2346,7 +2378,10 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
 
        need_push = (flags & MPTCP_CF_PUSH) && __mptcp_retransmit_pending_data(sk);
        if (!dispose_it) {
-               tcp_disconnect(ssk, 0);
+               /* The MPTCP code never wait on the subflow sockets, TCP-level
+                * disconnect should never fail
+                */
+               WARN_ON_ONCE(tcp_disconnect(ssk, 0));
                msk->subflow->state = SS_UNCONNECTED;
                mptcp_subflow_ctx_reset(subflow);
                release_sock(ssk);
@@ -2354,12 +2389,6 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
                goto out;
        }
 
-       /* if we are invoked by the msk cleanup code, the subflow is
-        * already orphaned
-        */
-       if (ssk->sk_socket)
-               sock_orphan(ssk);
-
        subflow->disposable = 1;
 
        /* if ssk hit tcp_done(), tcp_cleanup_ulp() cleared the related ops
@@ -2367,25 +2396,23 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
         * reference owned by msk;
         */
        if (!inet_csk(ssk)->icsk_ulp_ops) {
+               WARN_ON_ONCE(!sock_flag(ssk, SOCK_DEAD));
                kfree_rcu(subflow, rcu);
        } else {
                /* otherwise tcp will dispose of the ssk and subflow ctx */
-               if (ssk->sk_state == TCP_LISTEN) {
-                       tcp_set_state(ssk, TCP_CLOSE);
-                       mptcp_subflow_queue_clean(ssk);
-                       inet_csk_listen_stop(ssk);
-               }
                __tcp_close(ssk, 0);
 
                /* close acquired an extra ref */
                __sock_put(ssk);
        }
+
+out_release:
        release_sock(ssk);
 
        sock_put(ssk);
 
        if (ssk == msk->first)
-               msk->first = NULL;
+               WRITE_ONCE(msk->first, NULL);
 
 out:
        if (ssk == msk->last_snd)
@@ -2414,9 +2441,10 @@ static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu)
        return 0;
 }
 
-static void __mptcp_close_subflow(struct mptcp_sock *msk)
+static void __mptcp_close_subflow(struct sock *sk)
 {
        struct mptcp_subflow_context *subflow, *tmp;
+       struct mptcp_sock *msk = mptcp_sk(sk);
 
        might_sleep();
 
@@ -2430,16 +2458,17 @@ static void __mptcp_close_subflow(struct mptcp_sock *msk)
                if (!skb_queue_empty_lockless(&ssk->sk_receive_queue))
                        continue;
 
-               mptcp_close_ssk((struct sock *)msk, ssk, subflow);
+               mptcp_close_ssk(sk, ssk, subflow);
        }
+
 }
 
-static bool mptcp_check_close_timeout(const struct sock *sk)
+static bool mptcp_should_close(const struct sock *sk)
 {
        s32 delta = tcp_jiffies32 - inet_csk(sk)->icsk_mtup.probe_timestamp;
        struct mptcp_subflow_context *subflow;
 
-       if (delta >= TCP_TIMEWAIT_LEN)
+       if (delta >= TCP_TIMEWAIT_LEN || mptcp_sk(sk)->in_accept_queue)
                return true;
 
        /* if all subflows are in closed status don't bother with additional
@@ -2490,7 +2519,7 @@ static void mptcp_check_fastclose(struct mptcp_sock *msk)
        }
 
        inet_sk_state_store(sk, TCP_CLOSE);
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
        smp_mb__before_atomic(); /* SHUTDOWN must be visible first */
        set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags);
 
@@ -2624,11 +2653,9 @@ static void mptcp_worker(struct work_struct *work)
 
        lock_sock(sk);
        state = sk->sk_state;
-       if (unlikely(state == TCP_CLOSE))
+       if (unlikely((1 << state) & (TCPF_CLOSE | TCPF_LISTEN)))
                goto unlock;
 
-       mptcp_check_data_fin_ack(sk);
-
        mptcp_check_fastclose(msk);
 
        mptcp_pm_nl_work(msk);
@@ -2636,15 +2663,19 @@ static void mptcp_worker(struct work_struct *work)
        if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
                mptcp_check_for_eof(msk);
 
-       __mptcp_check_send_data_fin(sk);
+       mptcp_check_send_data_fin(sk);
+       mptcp_check_data_fin_ack(sk);
        mptcp_check_data_fin(sk);
 
+       if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags))
+               __mptcp_close_subflow(sk);
+
        /* There is no point in keeping around an orphaned sk timedout or
         * closed, but we need the msk around to reply to incoming DATA_FIN,
         * even if it is orphaned and in FIN_WAIT2 state
         */
        if (sock_flag(sk, SOCK_DEAD)) {
-               if (mptcp_check_close_timeout(sk)) {
+               if (mptcp_should_close(sk)) {
                        inet_sk_state_store(sk, TCP_CLOSE);
                        mptcp_do_fastclose(sk);
                }
@@ -2654,9 +2685,6 @@ static void mptcp_worker(struct work_struct *work)
                }
        }
 
-       if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags))
-               __mptcp_close_subflow(msk);
-
        if (test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags))
                __mptcp_retrans(sk);
 
@@ -2684,7 +2712,7 @@ static int __mptcp_init_sock(struct sock *sk)
        WRITE_ONCE(msk->rmem_released, 0);
        msk->timer_ival = TCP_RTO_MIN;
 
-       msk->first = NULL;
+       WRITE_ONCE(msk->first, NULL);
        inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss;
        WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk)));
        WRITE_ONCE(msk->allow_infinite_fallback, true);
@@ -2770,13 +2798,19 @@ void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
                        break;
                fallthrough;
        case TCP_SYN_SENT:
-               tcp_disconnect(ssk, O_NONBLOCK);
+               WARN_ON_ONCE(tcp_disconnect(ssk, O_NONBLOCK));
                break;
        default:
                if (__mptcp_check_fallback(mptcp_sk(sk))) {
                        pr_debug("Fallback");
                        ssk->sk_shutdown |= how;
                        tcp_shutdown(ssk, how);
+
+                       /* simulate the data_fin ack reception to let the state
+                        * machine move forward
+                        */
+                       WRITE_ONCE(mptcp_sk(sk)->snd_una, mptcp_sk(sk)->snd_nxt);
+                       mptcp_schedule_work(sk);
                } else {
                        pr_debug("Sending DATA_FIN on subflow %p", ssk);
                        tcp_send_ack(ssk);
@@ -2816,7 +2850,7 @@ static int mptcp_close_state(struct sock *sk)
        return next & TCP_ACTION_FIN;
 }
 
-static void __mptcp_check_send_data_fin(struct sock *sk)
+static void mptcp_check_send_data_fin(struct sock *sk)
 {
        struct mptcp_subflow_context *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
@@ -2834,19 +2868,6 @@ static void __mptcp_check_send_data_fin(struct sock *sk)
 
        WRITE_ONCE(msk->snd_nxt, msk->write_seq);
 
-       /* fallback socket will not get data_fin/ack, can move to the next
-        * state now
-        */
-       if (__mptcp_check_fallback(msk)) {
-               WRITE_ONCE(msk->snd_una, msk->write_seq);
-               if ((1 << sk->sk_state) & (TCPF_CLOSING | TCPF_LAST_ACK)) {
-                       inet_sk_state_store(sk, TCP_CLOSE);
-                       mptcp_close_wake_up(sk);
-               } else if (sk->sk_state == TCP_FIN_WAIT1) {
-                       inet_sk_state_store(sk, TCP_FIN_WAIT2);
-               }
-       }
-
        mptcp_for_each_subflow(msk, subflow) {
                struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
 
@@ -2866,7 +2887,7 @@ static void __mptcp_wr_shutdown(struct sock *sk)
        WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
        WRITE_ONCE(msk->snd_data_fin_enable, 1);
 
-       __mptcp_check_send_data_fin(sk);
+       mptcp_check_send_data_fin(sk);
 }
 
 static void __mptcp_destroy_sock(struct sock *sk)
@@ -2892,6 +2913,14 @@ static void __mptcp_destroy_sock(struct sock *sk)
        sock_put(sk);
 }
 
+void __mptcp_unaccepted_force_close(struct sock *sk)
+{
+       sock_set_flag(sk, SOCK_DEAD);
+       inet_sk_state_store(sk, TCP_CLOSE);
+       mptcp_do_fastclose(sk);
+       __mptcp_destroy_sock(sk);
+}
+
 static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
 {
        /* Concurrent splices from sk_receive_queue into receive_queue will
@@ -2904,15 +2933,35 @@ static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
        return EPOLLIN | EPOLLRDNORM;
 }
 
+static void mptcp_check_listen_stop(struct sock *sk)
+{
+       struct sock *ssk;
+
+       if (inet_sk_state_load(sk) != TCP_LISTEN)
+               return;
+
+       ssk = mptcp_sk(sk)->first;
+       if (WARN_ON_ONCE(!ssk || inet_sk_state_load(ssk) != TCP_LISTEN))
+               return;
+
+       lock_sock_nested(ssk, SINGLE_DEPTH_NESTING);
+       mptcp_subflow_queue_clean(sk, ssk);
+       inet_csk_listen_stop(ssk);
+       tcp_set_state(ssk, TCP_CLOSE);
+       release_sock(ssk);
+}
+
 bool __mptcp_close(struct sock *sk, long timeout)
 {
        struct mptcp_subflow_context *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
        bool do_cancel_work = false;
+       int subflows_alive = 0;
 
-       sk->sk_shutdown = SHUTDOWN_MASK;
+       WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
 
        if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) {
+               mptcp_check_listen_stop(sk);
                inet_sk_state_store(sk, TCP_CLOSE);
                goto cleanup;
        }
@@ -2934,17 +2983,29 @@ cleanup:
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                bool slow = lock_sock_fast_nested(ssk);
 
+               subflows_alive += ssk->sk_state != TCP_CLOSE;
+
                /* since the close timeout takes precedence on the fail one,
                 * cancel the latter
                 */
                if (ssk == msk->first)
                        subflow->fail_tout = 0;
 
-               sock_orphan(ssk);
+               /* detach from the parent socket, but allow data_ready to
+                * push incoming data into the mptcp stack, to properly ack it
+                */
+               ssk->sk_socket = NULL;
+               ssk->sk_wq = NULL;
                unlock_sock_fast(ssk, slow);
        }
        sock_orphan(sk);
 
+       /* all the subflows are closed, only timeout can change the msk
+        * state, let's not keep resources busy for no reasons
+        */
+       if (subflows_alive == 0)
+               inet_sk_state_store(sk, TCP_CLOSE);
+
        sock_hold(sk);
        pr_debug("msk=%p state=%d", sk, sk->sk_state);
        if (mptcp_sk(sk)->token)
@@ -2974,7 +3035,7 @@ static void mptcp_close(struct sock *sk, long timeout)
        sock_put(sk);
 }
 
-void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
+static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
 {
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
        const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
@@ -3001,6 +3062,20 @@ static int mptcp_disconnect(struct sock *sk, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
+       /* Deny disconnect if other threads are blocked in sk_wait_event()
+        * or inet_wait_for_connect().
+        */
+       if (sk->sk_wait_pending)
+               return -EBUSY;
+
+       /* We are on the fastopen error path. We can't call straight into the
+        * subflows cleanup code due to lock nesting (we are already under
+        * msk->firstsocket lock).
+        */
+       if (msk->fastopening)
+               return -EBUSY;
+
+       mptcp_check_listen_stop(sk);
        inet_sk_state_store(sk, TCP_CLOSE);
 
        mptcp_stop_timer(sk);
@@ -3028,7 +3103,7 @@ static int mptcp_disconnect(struct sock *sk, int flags)
        mptcp_pm_data_reset(msk);
        mptcp_ca_reset(sk);
 
-       sk->sk_shutdown = 0;
+       WRITE_ONCE(sk->sk_shutdown, 0);
        sk_error_report(sk);
        return 0;
 }
@@ -3042,9 +3117,10 @@ static struct ipv6_pinfo *mptcp_inet6_sk(const struct sock *sk)
 }
 #endif
 
-struct sock *mptcp_sk_clone(const struct sock *sk,
-                           const struct mptcp_options_received *mp_opt,
-                           struct request_sock *req)
+struct sock *mptcp_sk_clone_init(const struct sock *sk,
+                                const struct mptcp_options_received *mp_opt,
+                                struct sock *ssk,
+                                struct request_sock *req)
 {
        struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req);
        struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC);
@@ -3059,12 +3135,14 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
                inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
 #endif
 
+       nsk->sk_wait_pending = 0;
        __mptcp_init_sock(nsk);
 
        msk = mptcp_sk(nsk);
        msk->local_key = subflow_req->local_key;
        msk->token = subflow_req->token;
-       msk->subflow = NULL;
+       WRITE_ONCE(msk->subflow, NULL);
+       msk->in_accept_queue = 1;
        WRITE_ONCE(msk->fully_established, false);
        if (mp_opt->suboptions & OPTION_MPTCP_CSUMREQD)
                WRITE_ONCE(msk->csum_enabled, true);
@@ -3085,14 +3163,33 @@ struct sock *mptcp_sk_clone(const struct sock *sk,
        }
 
        sock_reset_flag(nsk, SOCK_RCU_FREE);
-       /* will be fully established after successful MPC subflow creation */
-       inet_sk_state_store(nsk, TCP_SYN_RECV);
-
        security_inet_csk_clone(nsk, req);
+
+       /* this can't race with mptcp_close(), as the msk is
+        * not yet exposted to user-space
+        */
+       inet_sk_state_store(nsk, TCP_ESTABLISHED);
+
+       /* The msk maintain a ref to each subflow in the connections list */
+       WRITE_ONCE(msk->first, ssk);
+       list_add(&mptcp_subflow_ctx(ssk)->node, &msk->conn_list);
+       sock_hold(ssk);
+
+       /* new mpc subflow takes ownership of the newly
+        * created mptcp socket
+        */
+       mptcp_token_accept(subflow_req, msk);
+
+       /* set msk addresses early to ensure mptcp_pm_get_local_id()
+        * uses the correct data
+        */
+       mptcp_copy_inaddrs(nsk, ssk);
+       mptcp_propagate_sndbuf(nsk, ssk);
+
+       mptcp_rcv_space_init(msk, ssk);
        bh_unlock_sock(nsk);
 
-       /* keep a single reference */
-       __sock_put(nsk);
+       /* note: the newly allocated socket refcount is 2 now */
        return nsk;
 }
 
@@ -3121,7 +3218,7 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
        struct socket *listener;
        struct sock *newsk;
 
-       listener = __mptcp_nmpc_socket(msk);
+       listener = READ_ONCE(msk->subflow);
        if (WARN_ON_ONCE(!listener)) {
                *err = -EINVAL;
                return NULL;
@@ -3148,8 +3245,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
                        goto out;
                }
 
-               /* acquire the 2nd reference for the owning socket */
-               sock_hold(new_mptcp_sock);
                newsk = new_mptcp_sock;
                MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK);
        } else {
@@ -3243,9 +3338,14 @@ static void mptcp_release_cb(struct sock *sk)
        for (;;) {
                unsigned long flags = (msk->cb_flags & MPTCP_FLAGS_PROCESS_CTX_NEED) |
                                      msk->push_pending;
+               struct list_head join_list;
+
                if (!flags)
                        break;
 
+               INIT_LIST_HEAD(&join_list);
+               list_splice_init(&msk->join_list, &join_list);
+
                /* the following actions acquire the subflow socket lock
                 *
                 * 1) can't be invoked in atomic scope
@@ -3256,8 +3356,9 @@ static void mptcp_release_cb(struct sock *sk)
                msk->push_pending = 0;
                msk->cb_flags &= ~flags;
                spin_unlock_bh(&sk->sk_lock.slock);
+
                if (flags & BIT(MPTCP_FLUSH_JOIN_LIST))
-                       __mptcp_flush_join_list(sk);
+                       __mptcp_flush_join_list(sk, &join_list);
                if (flags & BIT(MPTCP_PUSH_PENDING))
                        __mptcp_push_pending(sk, 0);
                if (flags & BIT(MPTCP_RETRANSMIT))
@@ -3349,7 +3450,7 @@ static int mptcp_get_port(struct sock *sk, unsigned short snum)
        struct mptcp_sock *msk = mptcp_sk(sk);
        struct socket *ssock;
 
-       ssock = __mptcp_nmpc_socket(msk);
+       ssock = msk->subflow;
        pr_debug("msk=%p, subflow=%p", msk, ssock);
        if (WARN_ON_ONCE(!ssock))
                return -EINVAL;
@@ -3416,14 +3517,16 @@ bool mptcp_finish_join(struct sock *ssk)
                return false;
        }
 
-       if (!list_empty(&subflow->node))
-               goto out;
+       /* active subflow, already present inside the conn_list */
+       if (!list_empty(&subflow->node)) {
+               mptcp_subflow_joined(msk, ssk);
+               return true;
+       }
 
        if (!mptcp_pm_allow_new_subflow(msk))
                goto err_prohibited;
 
-       /* active connections are already on conn_list.
-        * If we can't acquire msk socket lock here, let the release callback
+       /* If we can't acquire msk socket lock here, let the release callback
         * handle it
         */
        mptcp_data_lock(parent);
@@ -3446,11 +3549,6 @@ err_prohibited:
                return false;
        }
 
-       subflow->map_seq = READ_ONCE(msk->ack_seq);
-       WRITE_ONCE(msk->allow_infinite_fallback, false);
-
-out:
-       mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC);
        return true;
 }
 
@@ -3567,10 +3665,10 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        /* if reaching here via the fastopen/sendmsg path, the caller already
         * acquired the subflow socket lock, too.
         */
-       if (msk->is_sendmsg)
-               err = __inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags, 1);
+       if (msk->fastopening)
+               err = __inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK, 1);
        else
-               err = inet_stream_connect(ssock, uaddr, addr_len, msk->connect_flags);
+               err = inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK);
        inet_sk(sk)->defer_connect = inet_sk(ssock->sk)->defer_connect;
 
        /* on successful connect, the msk state will be moved to established by
@@ -3583,12 +3681,10 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 
        mptcp_copy_inaddrs(sk, ssock->sk);
 
-       /* unblocking connect, mptcp-level inet_stream_connect will error out
-        * without changing the socket state, update it here.
+       /* silence EINPROGRESS and let the caller inet_stream_connect
+        * handle the connection in progress
         */
-       if (err == -EINPROGRESS)
-               sk->sk_socket->state = ssock->state;
-       return err;
+       return 0;
 }
 
 static struct proto mptcp_prot = {
@@ -3647,18 +3743,6 @@ unlock:
        return err;
 }
 
-static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
-                               int addr_len, int flags)
-{
-       int ret;
-
-       lock_sock(sock->sk);
-       mptcp_sk(sock->sk)->connect_flags = flags;
-       ret = __inet_stream_connect(sock, uaddr, addr_len, flags, 0);
-       release_sock(sock->sk);
-       return ret;
-}
-
 static int mptcp_listen(struct socket *sock, int backlog)
 {
        struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -3697,7 +3781,10 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
 
        pr_debug("msk=%p", msk);
 
-       ssock = __mptcp_nmpc_socket(msk);
+       /* Buggy applications can call accept on socket states other then LISTEN
+        * but no need to allocate the first subflow just to error out.
+        */
+       ssock = READ_ONCE(msk->subflow);
        if (!ssock)
                return -EINVAL;
 
@@ -3707,23 +3794,9 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                struct mptcp_subflow_context *subflow;
                struct sock *newsk = newsock->sk;
 
-               lock_sock(newsk);
-
-               /* PM/worker can now acquire the first subflow socket
-                * lock without racing with listener queue cleanup,
-                * we can notify it, if needed.
-                *
-                * Even if remote has reset the initial subflow by now
-                * the refcnt is still at least one.
-                */
-               subflow = mptcp_subflow_ctx(msk->first);
-               list_add(&subflow->node, &msk->conn_list);
-               sock_hold(msk->first);
-               if (mptcp_is_fully_established(newsk))
-                       mptcp_pm_fully_established(msk, msk->first, GFP_KERNEL);
+               msk->in_accept_queue = 0;
 
-               mptcp_rcv_space_init(msk, msk->first);
-               mptcp_propagate_sndbuf(newsk, msk->first);
+               lock_sock(newsk);
 
                /* set ssk->sk_socket of accept()ed flows to mptcp socket.
                 * This is needed so NOSPACE flag can be set from tcp stack.
@@ -3734,6 +3807,18 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
                        if (!ssk->sk_socket)
                                mptcp_sock_graft(ssk, newsock);
                }
+
+               /* Do late cleanup for the first subflow as necessary. Also
+                * deal with bad peers not doing a complete shutdown.
+                */
+               if (msk->first &&
+                   unlikely(inet_sk_state_load(msk->first) == TCP_CLOSE)) {
+                       __mptcp_close_ssk(newsk, msk->first,
+                                         mptcp_subflow_ctx(msk->first), 0);
+                       if (unlikely(list_empty(&msk->conn_list)))
+                               inet_sk_state_store(newsk, TCP_CLOSE);
+               }
+
                release_sock(newsk);
        }
 
@@ -3744,9 +3829,6 @@ static __poll_t mptcp_check_writeable(struct mptcp_sock *msk)
 {
        struct sock *sk = (struct sock *)msk;
 
-       if (unlikely(sk->sk_shutdown & SEND_SHUTDOWN))
-               return EPOLLOUT | EPOLLWRNORM;
-
        if (sk_stream_is_writeable(sk))
                return EPOLLOUT | EPOLLWRNORM;
 
@@ -3764,6 +3846,7 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        struct sock *sk = sock->sk;
        struct mptcp_sock *msk;
        __poll_t mask = 0;
+       u8 shutdown;
        int state;
 
        msk = mptcp_sk(sk);
@@ -3772,23 +3855,30 @@ static __poll_t mptcp_poll(struct file *file, struct socket *sock,
        state = inet_sk_state_load(sk);
        pr_debug("msk=%p state=%d flags=%lx", msk, state, msk->flags);
        if (state == TCP_LISTEN) {
-               if (WARN_ON_ONCE(!msk->subflow || !msk->subflow->sk))
+               struct socket *ssock = READ_ONCE(msk->subflow);
+
+               if (WARN_ON_ONCE(!ssock || !ssock->sk))
                        return 0;
 
-               return inet_csk_listen_poll(msk->subflow->sk);
+               return inet_csk_listen_poll(ssock->sk);
        }
 
+       shutdown = READ_ONCE(sk->sk_shutdown);
+       if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
+               mask |= EPOLLHUP;
+       if (shutdown & RCV_SHUTDOWN)
+               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
+
        if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
                mask |= mptcp_check_readable(msk);
-               mask |= mptcp_check_writeable(msk);
+               if (shutdown & SEND_SHUTDOWN)
+                       mask |= EPOLLOUT | EPOLLWRNORM;
+               else
+                       mask |= mptcp_check_writeable(msk);
        } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) {
                /* cf tcp_poll() note about TFO */
                mask |= EPOLLOUT | EPOLLWRNORM;
        }
-       if (sk->sk_shutdown == SHUTDOWN_MASK || state == TCP_CLOSE)
-               mask |= EPOLLHUP;
-       if (sk->sk_shutdown & RCV_SHUTDOWN)
-               mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
 
        /* This barrier is coupled with smp_wmb() in __mptcp_error_report() */
        smp_rmb();
@@ -3803,7 +3893,7 @@ static const struct proto_ops mptcp_stream_ops = {
        .owner             = THIS_MODULE,
        .release           = inet_release,
        .bind              = mptcp_bind,
-       .connect           = mptcp_stream_connect,
+       .connect           = inet_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
        .getname           = inet_getname,
@@ -3898,7 +3988,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
        .owner             = THIS_MODULE,
        .release           = inet6_release,
        .bind              = mptcp_bind,
-       .connect           = mptcp_stream_connect,
+       .connect           = inet_stream_connect,
        .socketpair        = sock_no_socketpair,
        .accept            = mptcp_stream_accept,
        .getname           = inet6_getname,
@@ -3920,12 +4010,6 @@ static const struct proto_ops mptcp_v6_stream_ops = {
 
 static struct proto mptcp_v6_prot;
 
-static void mptcp_v6_destroy(struct sock *sk)
-{
-       mptcp_destroy(sk);
-       inet6_destroy_sock(sk);
-}
-
 static struct inet_protosw mptcp_v6_protosw = {
        .type           = SOCK_STREAM,
        .protocol       = IPPROTO_MPTCP,
@@ -3941,7 +4025,6 @@ int __init mptcp_proto_v6_init(void)
        mptcp_v6_prot = mptcp_prot;
        strcpy(mptcp_v6_prot.name, "MPTCPv6");
        mptcp_v6_prot.slab = NULL;
-       mptcp_v6_prot.destroy = mptcp_v6_destroy;
        mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock);
 
        err = proto_register(&mptcp_v6_prot, 1);