mptcp: move first subflow allocation at mpc access time
authorPaolo Abeni <pabeni@redhat.com>
Fri, 14 Apr 2023 14:08:03 +0000 (16:08 +0200)
committerDavid S. Miller <davem@davemloft.net>
Mon, 17 Apr 2023 07:18:34 +0000 (08:18 +0100)
In the long run this will simplify the mptcp code and will
allow for more consistent behavior. Move the first subflow
allocation out of the sock->init ops into the __mptcp_nmpc_socket()
helper.

Since the first subflow creation can now happen after the first
setsockopt() we additionally need to invoke mptcp_sockopt_sync()
on it.

Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Reviewed-by: Matthieu Baerts <matthieu.baerts@tessares.net>
Signed-off-by: Matthieu Baerts <matthieu.baerts@tessares.net>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/mptcp/pm_netlink.c
net/mptcp/protocol.c
net/mptcp/protocol.h
net/mptcp/sockopt.c

index 1c42bebca39ed5fa3d5f5e625c3b48d78ef09457..bc343dab5e3fc1cbd7f376a9931a0e2d35b9818f 100644 (file)
@@ -1035,8 +1035,8 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
        lock_sock(newsk);
        ssock = __mptcp_nmpc_socket(mptcp_sk(newsk));
        release_sock(newsk);
-       if (!ssock)
-               return -EINVAL;
+       if (IS_ERR(ssock))
+               return PTR_ERR(ssock);
 
        mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
index 22e073b373af80c96dad0458b15eb2cee34e560a..a676ac1bb9f1da649ca4858eb51601c6ecec6476 100644 (file)
@@ -49,18 +49,6 @@ static void __mptcp_check_send_data_fin(struct sock *sk);
 DEFINE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions);
 static struct net_device mptcp_napi_dev;
 
-/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
- * completed yet or has failed, return the subflow socket.
- * Otherwise return NULL.
- */
-struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
-{
-       if (!msk->subflow || READ_ONCE(msk->can_ack))
-               return NULL;
-
-       return msk->subflow;
-}
-
 /* Returns end sequence number of the receiver's advertised window */
 static u64 mptcp_wnd_end(const struct mptcp_sock *msk)
 {
@@ -116,6 +104,31 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)
        return 0;
 }
 
+/* If the MPC handshake is not started, returns the first subflow,
+ * eventually allocating it.
+ */
+struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk)
+{
+       struct sock *sk = (struct sock *)msk;
+       int ret;
+
+       if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)))
+               return ERR_PTR(-EINVAL);
+
+       if (!msk->subflow) {
+               if (msk->first)
+                       return ERR_PTR(-EINVAL);
+
+               ret = __mptcp_socket_create(msk);
+               if (ret)
+                       return ERR_PTR(ret);
+
+               mptcp_sockopt_sync(msk, msk->first);
+       }
+
+       return msk->subflow;
+}
+
 static void mptcp_drop(struct sock *sk, struct sk_buff *skb)
 {
        sk_drops_add(sk, skb);
@@ -1667,6 +1680,7 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
 {
        unsigned int saved_flags = msg->msg_flags;
        struct mptcp_sock *msk = mptcp_sk(sk);
+       struct socket *ssock;
        struct sock *ssk;
        int ret;
 
@@ -1676,8 +1690,11 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
         * Since the defer_connect flag is cleared after the first succsful
         * fastopen attempt, no need to check for additional subflow status.
         */
-       if (msg->msg_flags & MSG_FASTOPEN && !__mptcp_nmpc_socket(msk))
-               return -EINVAL;
+       if (msg->msg_flags & MSG_FASTOPEN) {
+               ssock = __mptcp_nmpc_socket(msk);
+               if (IS_ERR(ssock))
+                       return PTR_ERR(ssock);
+       }
        if (!msk->first)
                return -EINVAL;
 
@@ -2740,10 +2757,6 @@ static int mptcp_init_sock(struct sock *sk)
        if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net))
                return -ENOMEM;
 
-       ret = __mptcp_socket_create(mptcp_sk(sk));
-       if (ret)
-               return ret;
-
        set_bit(SOCK_CUSTOM_SOCKOPT, &sk->sk_socket->flags);
 
        /* fetch the ca name; do it outside __mptcp_init_sock(), so that clone will
@@ -3563,8 +3576,8 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
        int err = -EINVAL;
 
        ssock = __mptcp_nmpc_socket(msk);
-       if (!ssock)
-               return -EINVAL;
+       if (IS_ERR(ssock))
+               return PTR_ERR(ssock);
 
        mptcp_token_destroy(msk);
        inet_sk_state_store(sk, TCP_SYN_SENT);
@@ -3652,8 +3665,8 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
 
        lock_sock(sock->sk);
        ssock = __mptcp_nmpc_socket(msk);
-       if (!ssock) {
-               err = -EINVAL;
+       if (IS_ERR(ssock)) {
+               err = PTR_ERR(ssock);
                goto unlock;
        }
 
@@ -3689,8 +3702,8 @@ static int mptcp_listen(struct socket *sock, int backlog)
 
        lock_sock(sk);
        ssock = __mptcp_nmpc_socket(msk);
-       if (!ssock) {
-               err = -EINVAL;
+       if (IS_ERR(ssock)) {
+               err = PTR_ERR(ssock);
                goto unlock;
        }
 
index a9eb0e428a6be0cde1b31b2d38a3b4126a7ff623..21eda9cd0c5281ee798c02c6b09e5c98174cd8e8 100644 (file)
@@ -627,7 +627,7 @@ void mptcp_close_ssk(struct sock *sk, struct sock *ssk,
 void __mptcp_subflow_send_ack(struct sock *ssk);
 void mptcp_subflow_reset(struct sock *ssk);
 void mptcp_sock_graft(struct sock *sk, struct socket *parent);
-struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk);
+struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk);
 bool __mptcp_close(struct sock *sk, long timeout);
 void mptcp_cancel_work(struct sock *sk);
 void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk);
index b655cebda0f38b89d4d6988ea2059b1e2ab2d994..d4258869ac48fe97bdf99dc48fbb7d145e353726 100644 (file)
@@ -301,9 +301,9 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
        case SO_BINDTOIFINDEX:
                lock_sock(sk);
                ssock = __mptcp_nmpc_socket(msk);
-               if (!ssock) {
+               if (IS_ERR(ssock)) {
                        release_sock(sk);
-                       return -EINVAL;
+                       return PTR_ERR(ssock);
                }
 
                ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen);
@@ -396,9 +396,9 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
        case IPV6_FREEBIND:
                lock_sock(sk);
                ssock = __mptcp_nmpc_socket(msk);
-               if (!ssock) {
+               if (IS_ERR(ssock)) {
                        release_sock(sk);
-                       return -EINVAL;
+                       return PTR_ERR(ssock);
                }
 
                ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen);
@@ -693,9 +693,9 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o
        lock_sock(sk);
 
        ssock = __mptcp_nmpc_socket(msk);
-       if (!ssock) {
+       if (IS_ERR(ssock)) {
                release_sock(sk);
-               return -EINVAL;
+               return PTR_ERR(ssock);
        }
 
        issk = inet_sk(ssock->sk);
@@ -762,13 +762,15 @@ static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
 {
        struct sock *sk = (struct sock *)msk;
        struct socket *sock;
-       int ret = -EINVAL;
+       int ret;
 
        /* Limit to first subflow, before the connection establishment */
        lock_sock(sk);
        sock = __mptcp_nmpc_socket(msk);
-       if (!sock)
+       if (IS_ERR(sock)) {
+               ret = PTR_ERR(sock);
                goto unlock;
+       }
 
        ret = tcp_setsockopt(sock->sk, level, optname, optval, optlen);
 
@@ -861,7 +863,7 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
 {
        struct sock *sk = (struct sock *)msk;
        struct socket *ssock;
-       int ret = -EINVAL;
+       int ret;
        struct sock *ssk;
 
        lock_sock(sk);
@@ -872,8 +874,10 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
        }
 
        ssock = __mptcp_nmpc_socket(msk);
-       if (!ssock)
+       if (IS_ERR(ssock)) {
+               ret = PTR_ERR(ssock);
                goto out;
+       }
 
        ret = tcp_getsockopt(ssock->sk, level, optname, optval, optlen);