mptcp: add annotations around sk->sk_shutdown accesses
[platform/kernel/linux-starfive.git] / net / mptcp / pm_netlink.c
index 9813ed0..1c69e47 100644 (file)
@@ -987,13 +987,17 @@ out:
        return ret;
 }
 
+static struct lock_class_key mptcp_slock_keys[2];
+static struct lock_class_key mptcp_keys[2];
+
 static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
                                            struct mptcp_pm_addr_entry *entry)
 {
+       bool is_ipv6 = sk->sk_family == AF_INET6;
        int addrlen = sizeof(struct sockaddr_in);
        struct sockaddr_storage addr;
-       struct mptcp_sock *msk;
        struct socket *ssock;
+       struct sock *newsk;
        int backlog = 1024;
        int err;
 
@@ -1002,17 +1006,27 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
        if (err)
                return err;
 
-       msk = mptcp_sk(entry->lsk->sk);
-       if (!msk) {
-               err = -EINVAL;
-               goto out;
-       }
+       newsk = entry->lsk->sk;
+       if (!newsk)
+               return -EINVAL;
 
-       ssock = __mptcp_nmpc_socket(msk);
-       if (!ssock) {
-               err = -EINVAL;
-               goto out;
-       }
+       /* The subflow socket lock is acquired in a nested to the msk one
+        * in several places, even by the TCP stack, and this msk is a kernel
+        * socket: lockdep complains. Instead of propagating the _nested
+        * modifiers in several places, re-init the lock class for the msk
+        * socket to an mptcp specific one.
+        */
+       sock_lock_init_class_and_name(newsk,
+                                     is_ipv6 ? "mlock-AF_INET6" : "mlock-AF_INET",
+                                     &mptcp_slock_keys[is_ipv6],
+                                     is_ipv6 ? "msk_lock-AF_INET6" : "msk_lock-AF_INET",
+                                     &mptcp_keys[is_ipv6]);
+
+       lock_sock(newsk);
+       ssock = __mptcp_nmpc_socket(mptcp_sk(newsk));
+       release_sock(newsk);
+       if (!ssock)
+               return -EINVAL;
 
        mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
@@ -1022,20 +1036,16 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
        err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen);
        if (err) {
                pr_warn("kernel_bind error, err=%d", err);
-               goto out;
+               return err;
        }
 
        err = kernel_listen(ssock, backlog);
        if (err) {
                pr_warn("kernel_listen error, err=%d", err);
-               goto out;
+               return err;
        }
 
        return 0;
-
-out:
-       sock_release(entry->lsk);
-       return err;
 }
 
 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
@@ -1327,7 +1337,7 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
                return -EINVAL;
        }
 
-       entry = kmalloc(sizeof(*entry), GFP_KERNEL_ACCOUNT);
+       entry = kzalloc(sizeof(*entry), GFP_KERNEL_ACCOUNT);
        if (!entry) {
                GENL_SET_ERR_MSG(info, "can't allocate addr");
                return -ENOMEM;
@@ -1338,22 +1348,21 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
                ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
                if (ret) {
                        GENL_SET_ERR_MSG(info, "create listen socket error");
-                       kfree(entry);
-                       return ret;
+                       goto out_free;
                }
        }
        ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
        if (ret < 0) {
                GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
-               if (entry->lsk)
-                       sock_release(entry->lsk);
-               kfree(entry);
-               return ret;
+               goto out_free;
        }
 
        mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
-
        return 0;
+
+out_free:
+       __mptcp_pm_release_addr_entry(entry);
+       return ret;
 }
 
 int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id,