mptcp: don't return sockets in foreign netns
authorFlorian Westphal <fw@strlen.de>
Fri, 24 Sep 2021 00:04:11 +0000 (17:04 -0700)
committerGreg Kroah-Hartman <gregkh@linuxfoundation.org>
Wed, 6 Oct 2021 13:55:52 +0000 (15:55 +0200)
[ Upstream commit ea1300b9df7c8e8b65695a08b8f6aaf4b25fec9c ]

mptcp_token_get_sock() may return a mptcp socket that is in
a different net namespace than the socket that received the token value.

The mptcp syncookie code path had an explicit check for this,
this moves the test into mptcp_token_get_sock() function.

Eventually token.c should be converted to pernet storage, but
such change is not suitable for net tree.

Fixes: 2c5ebd001d4f0 ("mptcp: refactor token container")
Signed-off-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Mat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Sasha Levin <sashal@kernel.org>
net/mptcp/mptcp_diag.c
net/mptcp/protocol.h
net/mptcp/subflow.c
net/mptcp/syncookies.c
net/mptcp/token.c
net/mptcp/token_test.c

index 5f390a97f556de8a0dcfaf985ac14e09bc66d086..f1af3f44875edef913b9ea78bcc93eb25f87e18a 100644 (file)
@@ -36,7 +36,7 @@ static int mptcp_diag_dump_one(struct netlink_callback *cb,
        struct sock *sk;
 
        net = sock_net(in_skb->sk);
-       msk = mptcp_token_get_sock(req->id.idiag_cookie[0]);
+       msk = mptcp_token_get_sock(net, req->id.idiag_cookie[0]);
        if (!msk)
                goto out_nosk;
 
index 13ab89dc19141de0fb8416c87424c9d87adaceed..3e5af8397434a97f9a06633d633af4dbb7be1fc1 100644 (file)
@@ -424,7 +424,7 @@ int mptcp_token_new_connect(struct sock *sk);
 void mptcp_token_accept(struct mptcp_subflow_request_sock *r,
                        struct mptcp_sock *msk);
 bool mptcp_token_exists(u32 token);
-struct mptcp_sock *mptcp_token_get_sock(u32 token);
+struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token);
 struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot,
                                         long *s_num);
 void mptcp_token_destroy(struct mptcp_sock *msk);
index bba5696fee36d5c9873ee9719bec224dacbaec3b..2e923849092412e3d5bb425cd3fe3acba4549310 100644 (file)
@@ -69,7 +69,7 @@ static struct mptcp_sock *subflow_token_join_request(struct request_sock *req,
        struct mptcp_sock *msk;
        int local_id;
 
-       msk = mptcp_token_get_sock(subflow_req->token);
+       msk = mptcp_token_get_sock(sock_net(req_to_sk(req)), subflow_req->token);
        if (!msk) {
                SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINNOTOKEN);
                return NULL;
index 37127781aee987055acf380f8696a2326e50971f..7f22526346a7e6667b73599ce27771fc18824cfd 100644 (file)
@@ -108,18 +108,12 @@ bool mptcp_token_join_cookie_init_state(struct mptcp_subflow_request_sock *subfl
 
        e->valid = 0;
 
-       msk = mptcp_token_get_sock(e->token);
+       msk = mptcp_token_get_sock(net, e->token);
        if (!msk) {
                spin_unlock_bh(&join_entry_locks[i]);
                return false;
        }
 
-       /* If this fails, the token got re-used in the mean time by another
-        * mptcp socket in a different netns, i.e. entry is outdated.
-        */
-       if (!net_eq(sock_net((struct sock *)msk), net))
-               goto err_put;
-
        subflow_req->remote_nonce = e->remote_nonce;
        subflow_req->local_nonce = e->local_nonce;
        subflow_req->backup = e->backup;
@@ -128,11 +122,6 @@ bool mptcp_token_join_cookie_init_state(struct mptcp_subflow_request_sock *subfl
        subflow_req->msk = msk;
        spin_unlock_bh(&join_entry_locks[i]);
        return true;
-
-err_put:
-       spin_unlock_bh(&join_entry_locks[i]);
-       sock_put((struct sock *)msk);
-       return false;
 }
 
 void __init mptcp_join_cookie_init(void)
index 0691a4883f3ab13d24992602b9645cd1ab0f3616..f0d656bf27ada007a6f7306d88f0870fc337e7a0 100644 (file)
@@ -232,6 +232,7 @@ found:
 
 /**
  * mptcp_token_get_sock - retrieve mptcp connection sock using its token
+ * @net: restrict to this namespace
  * @token: token of the mptcp connection to retrieve
  *
  * This function returns the mptcp connection structure with the given token.
@@ -239,7 +240,7 @@ found:
  *
  * returns NULL if no connection with the given token value exists.
  */
-struct mptcp_sock *mptcp_token_get_sock(u32 token)
+struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token)
 {
        struct hlist_nulls_node *pos;
        struct token_bucket *bucket;
@@ -252,11 +253,15 @@ struct mptcp_sock *mptcp_token_get_sock(u32 token)
 again:
        sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) {
                msk = mptcp_sk(sk);
-               if (READ_ONCE(msk->token) != token)
+               if (READ_ONCE(msk->token) != token ||
+                   !net_eq(sock_net(sk), net))
                        continue;
+
                if (!refcount_inc_not_zero(&sk->sk_refcnt))
                        goto not_found;
-               if (READ_ONCE(msk->token) != token) {
+
+               if (READ_ONCE(msk->token) != token ||
+                   !net_eq(sock_net(sk), net)) {
                        sock_put(sk);
                        goto again;
                }
index e1bd6f0a0676fa1f6ecec42dc9e24f97247b33fe..5d984bec1cd865b78ee19adfe248bac57dc1d24e 100644 (file)
@@ -11,6 +11,7 @@ static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test)
                            GFP_USER);
        KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req);
        mptcp_token_init_request((struct request_sock *)req);
+       sock_net_set((struct sock *)req, &init_net);
        return req;
 }
 
@@ -22,7 +23,7 @@ static void mptcp_token_test_req_basic(struct kunit *test)
        KUNIT_ASSERT_EQ(test, 0,
                        mptcp_token_new_request((struct request_sock *)req));
        KUNIT_EXPECT_NE(test, 0, (int)req->token);
-       KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(req->token));
+       KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, req->token));
 
        /* cleanup */
        mptcp_token_destroy_request((struct request_sock *)req);
@@ -55,6 +56,7 @@ static struct mptcp_sock *build_msk(struct kunit *test)
        msk = kunit_kzalloc(test, sizeof(struct mptcp_sock), GFP_USER);
        KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk);
        refcount_set(&((struct sock *)msk)->sk_refcnt, 1);
+       sock_net_set((struct sock *)msk, &init_net);
        return msk;
 }
 
@@ -74,11 +76,11 @@ static void mptcp_token_test_msk_basic(struct kunit *test)
                        mptcp_token_new_connect((struct sock *)icsk));
        KUNIT_EXPECT_NE(test, 0, (int)ctx->token);
        KUNIT_EXPECT_EQ(test, ctx->token, msk->token);
-       KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(ctx->token));
+       KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, ctx->token));
        KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt));
 
        mptcp_token_destroy(msk);
-       KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(ctx->token));
+       KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, ctx->token));
 }
 
 static void mptcp_token_test_accept(struct kunit *test)
@@ -90,11 +92,11 @@ static void mptcp_token_test_accept(struct kunit *test)
                        mptcp_token_new_request((struct request_sock *)req));
        msk->token = req->token;
        mptcp_token_accept(req, msk);
-       KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));
+       KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token));
 
        /* this is now a no-op */
        mptcp_token_destroy_request((struct request_sock *)req);
-       KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(msk->token));
+       KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token));
 
        /* cleanup */
        mptcp_token_destroy(msk);
@@ -116,7 +118,7 @@ static void mptcp_token_test_destroyed(struct kunit *test)
 
        /* simulate race on removal */
        refcount_set(&sk->sk_refcnt, 0);
-       KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(msk->token));
+       KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, msk->token));
 
        /* cleanup */
        mptcp_token_destroy(msk);