Merge git://git.kernel.org/pub/scm/linux/kernel/git/netdev/net
[platform/kernel/linux-starfive.git] / net / mptcp / pm_netlink.c
index 3de83e2..5bdb559 100644 (file)
@@ -717,9 +717,10 @@ void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
        }
 }
 
-static int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
-                                       struct mptcp_addr_info *addr,
-                                       u8 bkup)
+int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
+                                struct mptcp_addr_info *addr,
+                                struct mptcp_addr_info *rem,
+                                u8 bkup)
 {
        struct mptcp_subflow_context *subflow;
 
@@ -727,24 +728,29 @@ static int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
 
        mptcp_for_each_subflow(msk, subflow) {
                struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-               struct sock *sk = (struct sock *)msk;
-               struct mptcp_addr_info local;
+               struct mptcp_addr_info local, remote;
+               bool slow;
 
                local_address((struct sock_common *)ssk, &local);
                if (!mptcp_addresses_equal(&local, addr, addr->port))
                        continue;
 
+               if (rem && rem->family != AF_UNSPEC) {
+                       remote_address((struct sock_common *)ssk, &remote);
+                       if (!mptcp_addresses_equal(&remote, rem, rem->port))
+                               continue;
+               }
+
+               slow = lock_sock_fast(ssk);
                if (subflow->backup != bkup)
                        msk->last_snd = NULL;
                subflow->backup = bkup;
                subflow->send_mp_prio = 1;
                subflow->request_bkup = bkup;
-               __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);
 
-               spin_unlock_bh(&msk->pm.lock);
                pr_debug("send ack for mp_prio");
-               mptcp_subflow_send_ack(ssk);
-               spin_lock_bh(&msk->pm.lock);
+               __mptcp_subflow_send_ack(ssk);
+               unlock_sock_fast(ssk, slow);
 
                return 0;
        }
@@ -801,7 +807,8 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
                        removed = true;
                        __MPTCP_INC_STATS(sock_net(sk), rm_type);
                }
-               __set_bit(rm_list->ids[i], msk->pm.id_avail_bitmap);
+               if (rm_type == MPTCP_MIB_RMSUBFLOW)
+                       __set_bit(rm_list->ids[i], msk->pm.id_avail_bitmap);
                if (!removed)
                        continue;
 
@@ -1816,8 +1823,10 @@ static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk,
 
        list.ids[list.nr++] = addr->id;
 
+       spin_lock_bh(&msk->pm.lock);
        mptcp_pm_nl_rm_subflow_received(msk, &list);
        mptcp_pm_create_subflow_or_signal_addr(msk);
+       spin_unlock_bh(&msk->pm.lock);
 }
 
 static int mptcp_nl_set_flags(struct net *net,
@@ -1835,12 +1844,10 @@ static int mptcp_nl_set_flags(struct net *net,
                        goto next;
 
                lock_sock(sk);
-               spin_lock_bh(&msk->pm.lock);
                if (changed & MPTCP_PM_ADDR_FLAG_BACKUP)
-                       ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
+                       ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, NULL, bkup);
                if (changed & MPTCP_PM_ADDR_FLAG_FULLMESH)
                        mptcp_pm_nl_fullmesh(msk, addr);
-               spin_unlock_bh(&msk->pm.lock);
                release_sock(sk);
 
 next:
@@ -1854,6 +1861,9 @@ next:
 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
 {
        struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry;
+       struct mptcp_pm_addr_entry remote = { .addr = { .family = AF_UNSPEC }, };
+       struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
+       struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
        struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
        u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP |
@@ -1866,6 +1876,12 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
        if (ret < 0)
                return ret;
 
+       if (attr_rem) {
+               ret = mptcp_pm_parse_entry(attr_rem, info, false, &remote);
+               if (ret < 0)
+                       return ret;
+       }
+
        if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
                bkup = 1;
        if (addr.addr.family == AF_UNSPEC) {
@@ -1874,6 +1890,10 @@ static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
                        return -EOPNOTSUPP;
        }
 
+       if (token)
+               return mptcp_userspace_pm_set_flags(sock_net(skb->sk),
+                                                   token, &addr, &remote, bkup);
+
        spin_lock_bh(&pernet->lock);
        entry = __lookup_addr(pernet, &addr.addr, lookup_by_id);
        if (!entry) {