xsk: Honor SO_BINDTODEVICE on bind
[platform/kernel/linux-starfive.git] / net / mptcp / protocol.c
index 5918699..4ca61e8 100644 (file)
@@ -53,7 +53,7 @@ enum {
 static struct percpu_counter mptcp_sockets_allocated ____cacheline_aligned_in_smp;
 
 static void __mptcp_destroy_sock(struct sock *sk);
-static void __mptcp_check_send_data_fin(struct sock *sk);
+static void mptcp_check_send_data_fin(struct sock *sk);
 
 DEFINE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions);
 static struct net_device mptcp_napi_dev;
@@ -420,8 +420,7 @@ static bool mptcp_pending_data_fin_ack(struct sock *sk)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       return !__mptcp_check_fallback(msk) &&
-              ((1 << sk->sk_state) &
+       return ((1 << sk->sk_state) &
                (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK)) &&
               msk->write_seq == READ_ONCE(msk->snd_una);
 }
@@ -579,9 +578,6 @@ static bool mptcp_check_data_fin(struct sock *sk)
        u64 rcv_data_fin_seq;
        bool ret = false;
 
-       if (__mptcp_check_fallback(msk))
-               return ret;
-
        /* Need to ack a DATA_FIN received from a peer while this side
         * of the connection is in ESTABLISHED, FIN_WAIT1, or FIN_WAIT2.
         * msk->rcv_data_fin was set when parsing the incoming options
@@ -619,7 +615,8 @@ static bool mptcp_check_data_fin(struct sock *sk)
                }
 
                ret = true;
-               mptcp_send_ack(msk);
+               if (!__mptcp_check_fallback(msk))
+                       mptcp_send_ack(msk);
                mptcp_close_wake_up(sk);
        }
        return ret;
@@ -846,12 +843,12 @@ static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk)
        return true;
 }
 
-static void __mptcp_flush_join_list(struct sock *sk)
+static void __mptcp_flush_join_list(struct sock *sk, struct list_head *join_list)
 {
        struct mptcp_subflow_context *tmp, *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
 
-       list_for_each_entry_safe(subflow, tmp, &msk->join_list, node) {
+       list_for_each_entry_safe(subflow, tmp, join_list, node) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
                bool slow = lock_sock_fast(ssk);
 
@@ -1606,7 +1603,7 @@ out:
        if (!mptcp_timer_pending(sk))
                mptcp_reset_timer(sk);
        if (do_check_data_fin)
-               __mptcp_check_send_data_fin(sk);
+               mptcp_check_send_data_fin(sk);
 }
 
 static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk)
@@ -1708,7 +1705,13 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msgh
                if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR)
                        *copied_syn = 0;
        } else if (ret && ret != -EINPROGRESS) {
-               mptcp_disconnect(sk, 0);
+               /* The disconnect() op called by tcp_sendmsg_fastopen()/
+                * __inet_stream_connect() can fail, due to looking check,
+                * see mptcp_disconnect().
+                * Attempt it again outside the problematic scope.
+                */
+               if (!mptcp_disconnect(sk, 0))
+                       sk->sk_socket->state = SS_UNCONNECTED;
        }
 
        return ret;
@@ -2375,7 +2378,10 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
 
        need_push = (flags & MPTCP_CF_PUSH) && __mptcp_retransmit_pending_data(sk);
        if (!dispose_it) {
-               tcp_disconnect(ssk, 0);
+               /* The MPTCP code never wait on the subflow sockets, TCP-level
+                * disconnect should never fail
+                */
+               WARN_ON_ONCE(tcp_disconnect(ssk, 0));
                msk->subflow->state = SS_UNCONNECTED;
                mptcp_subflow_ctx_reset(subflow);
                release_sock(ssk);
@@ -2394,12 +2400,6 @@ static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
                kfree_rcu(subflow, rcu);
        } else {
                /* otherwise tcp will dispose of the ssk and subflow ctx */
-               if (ssk->sk_state == TCP_LISTEN) {
-                       tcp_set_state(ssk, TCP_CLOSE);
-                       mptcp_subflow_queue_clean(sk, ssk);
-                       inet_csk_listen_stop(ssk);
-               }
-
                __tcp_close(ssk, 0);
 
                /* close acquired an extra ref */
@@ -2656,8 +2656,6 @@ static void mptcp_worker(struct work_struct *work)
        if (unlikely((1 << state) & (TCPF_CLOSE | TCPF_LISTEN)))
                goto unlock;
 
-       mptcp_check_data_fin_ack(sk);
-
        mptcp_check_fastclose(msk);
 
        mptcp_pm_nl_work(msk);
@@ -2665,7 +2663,8 @@ static void mptcp_worker(struct work_struct *work)
        if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags))
                mptcp_check_for_eof(msk);
 
-       __mptcp_check_send_data_fin(sk);
+       mptcp_check_send_data_fin(sk);
+       mptcp_check_data_fin_ack(sk);
        mptcp_check_data_fin(sk);
 
        if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags))
@@ -2799,13 +2798,19 @@ void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how)
                        break;
                fallthrough;
        case TCP_SYN_SENT:
-               tcp_disconnect(ssk, O_NONBLOCK);
+               WARN_ON_ONCE(tcp_disconnect(ssk, O_NONBLOCK));
                break;
        default:
                if (__mptcp_check_fallback(mptcp_sk(sk))) {
                        pr_debug("Fallback");
                        ssk->sk_shutdown |= how;
                        tcp_shutdown(ssk, how);
+
+                       /* simulate the data_fin ack reception to let the state
+                        * machine move forward
+                        */
+                       WRITE_ONCE(mptcp_sk(sk)->snd_una, mptcp_sk(sk)->snd_nxt);
+                       mptcp_schedule_work(sk);
                } else {
                        pr_debug("Sending DATA_FIN on subflow %p", ssk);
                        tcp_send_ack(ssk);
@@ -2845,7 +2850,7 @@ static int mptcp_close_state(struct sock *sk)
        return next & TCP_ACTION_FIN;
 }
 
-static void __mptcp_check_send_data_fin(struct sock *sk)
+static void mptcp_check_send_data_fin(struct sock *sk)
 {
        struct mptcp_subflow_context *subflow;
        struct mptcp_sock *msk = mptcp_sk(sk);
@@ -2863,19 +2868,6 @@ static void __mptcp_check_send_data_fin(struct sock *sk)
 
        WRITE_ONCE(msk->snd_nxt, msk->write_seq);
 
-       /* fallback socket will not get data_fin/ack, can move to the next
-        * state now
-        */
-       if (__mptcp_check_fallback(msk)) {
-               WRITE_ONCE(msk->snd_una, msk->write_seq);
-               if ((1 << sk->sk_state) & (TCPF_CLOSING | TCPF_LAST_ACK)) {
-                       inet_sk_state_store(sk, TCP_CLOSE);
-                       mptcp_close_wake_up(sk);
-               } else if (sk->sk_state == TCP_FIN_WAIT1) {
-                       inet_sk_state_store(sk, TCP_FIN_WAIT2);
-               }
-       }
-
        mptcp_for_each_subflow(msk, subflow) {
                struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);
 
@@ -2895,7 +2887,7 @@ static void __mptcp_wr_shutdown(struct sock *sk)
        WRITE_ONCE(msk->write_seq, msk->write_seq + 1);
        WRITE_ONCE(msk->snd_data_fin_enable, 1);
 
-       __mptcp_check_send_data_fin(sk);
+       mptcp_check_send_data_fin(sk);
 }
 
 static void __mptcp_destroy_sock(struct sock *sk)
@@ -2941,6 +2933,24 @@ static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
        return EPOLLIN | EPOLLRDNORM;
 }
 
+static void mptcp_check_listen_stop(struct sock *sk)
+{
+       struct sock *ssk;
+
+       if (inet_sk_state_load(sk) != TCP_LISTEN)
+               return;
+
+       ssk = mptcp_sk(sk)->first;
+       if (WARN_ON_ONCE(!ssk || inet_sk_state_load(ssk) != TCP_LISTEN))
+               return;
+
+       lock_sock_nested(ssk, SINGLE_DEPTH_NESTING);
+       mptcp_subflow_queue_clean(sk, ssk);
+       inet_csk_listen_stop(ssk);
+       tcp_set_state(ssk, TCP_CLOSE);
+       release_sock(ssk);
+}
+
 bool __mptcp_close(struct sock *sk, long timeout)
 {
        struct mptcp_subflow_context *subflow;
@@ -2951,6 +2961,7 @@ bool __mptcp_close(struct sock *sk, long timeout)
        WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK);
 
        if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) {
+               mptcp_check_listen_stop(sk);
                inet_sk_state_store(sk, TCP_CLOSE);
                goto cleanup;
        }
@@ -3051,14 +3062,20 @@ static int mptcp_disconnect(struct sock *sk, int flags)
 {
        struct mptcp_sock *msk = mptcp_sk(sk);
 
+       /* Deny disconnect if other threads are blocked in sk_wait_event()
+        * or inet_wait_for_connect().
+        */
+       if (sk->sk_wait_pending)
+               return -EBUSY;
+
        /* We are on the fastopen error path. We can't call straight into the
         * subflows cleanup code due to lock nesting (we are already under
-        * msk->firstsocket lock). Do nothing and leave the cleanup to the
-        * caller.
+        * msk->firstsocket lock).
         */
        if (msk->fastopening)
-               return 0;
+               return -EBUSY;
 
+       mptcp_check_listen_stop(sk);
        inet_sk_state_store(sk, TCP_CLOSE);
 
        mptcp_stop_timer(sk);
@@ -3118,6 +3135,7 @@ struct sock *mptcp_sk_clone_init(const struct sock *sk,
                inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk);
 #endif
 
+       nsk->sk_wait_pending = 0;
        __mptcp_init_sock(nsk);
 
        msk = mptcp_sk(nsk);
@@ -3320,9 +3338,14 @@ static void mptcp_release_cb(struct sock *sk)
        for (;;) {
                unsigned long flags = (msk->cb_flags & MPTCP_FLAGS_PROCESS_CTX_NEED) |
                                      msk->push_pending;
+               struct list_head join_list;
+
                if (!flags)
                        break;
 
+               INIT_LIST_HEAD(&join_list);
+               list_splice_init(&msk->join_list, &join_list);
+
                /* the following actions acquire the subflow socket lock
                 *
                 * 1) can't be invoked in atomic scope
@@ -3333,8 +3356,9 @@ static void mptcp_release_cb(struct sock *sk)
                msk->push_pending = 0;
                msk->cb_flags &= ~flags;
                spin_unlock_bh(&sk->sk_lock.slock);
+
                if (flags & BIT(MPTCP_FLUSH_JOIN_LIST))
-                       __mptcp_flush_join_list(sk);
+                       __mptcp_flush_join_list(sk, &join_list);
                if (flags & BIT(MPTCP_PUSH_PENDING))
                        __mptcp_push_pending(sk, 0);
                if (flags & BIT(MPTCP_RETRANSMIT))