mptcp: remove addr and subflow in PM netlink
[platform/kernel/linux-rpi.git] / net / mptcp / pm_netlink.c
index f6f96bc..97f9280 100644 (file)
@@ -177,6 +177,50 @@ static void check_work_pending(struct mptcp_sock *msk)
                WRITE_ONCE(msk->pm.work_pending, false);
 }
 
+static bool lookup_anno_list_by_saddr(struct mptcp_sock *msk,
+                                     struct mptcp_addr_info *addr)
+{
+       struct mptcp_pm_addr_entry *entry;
+
+       list_for_each_entry(entry, &msk->pm.anno_list, list) {
+               if (addresses_equal(&entry->addr, addr, false))
+                       return true;
+       }
+
+       return false;
+}
+
+static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
+                                    struct mptcp_pm_addr_entry *entry)
+{
+       struct mptcp_pm_addr_entry *clone = NULL;
+
+       if (lookup_anno_list_by_saddr(msk, &entry->addr))
+               return false;
+
+       clone = kmemdup(entry, sizeof(*entry), GFP_ATOMIC);
+       if (!clone)
+               return false;
+
+       list_add(&clone->list, &msk->pm.anno_list);
+
+       return true;
+}
+
+void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
+{
+       struct mptcp_pm_addr_entry *entry, *tmp;
+
+       pr_debug("msk=%p", msk);
+
+       spin_lock_bh(&msk->pm.lock);
+       list_for_each_entry_safe(entry, tmp, &msk->pm.anno_list, list) {
+               list_del(&entry->list);
+               kfree(entry);
+       }
+       spin_unlock_bh(&msk->pm.lock);
+}
+
 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
 {
        struct mptcp_addr_info remote = { 0 };
@@ -197,8 +241,10 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
                                              msk->pm.add_addr_signaled);
 
                if (local) {
-                       msk->pm.add_addr_signaled++;
-                       mptcp_pm_announce_addr(msk, &local->addr, false);
+                       if (mptcp_pm_alloc_anno_list(msk, local)) {
+                               msk->pm.add_addr_signaled++;
+                               mptcp_pm_announce_addr(msk, &local->addr, false);
+                       }
                } else {
                        /* pick failed, avoid fourther attempts later */
                        msk->pm.local_addr_used = msk->pm.add_addr_signal_max;
@@ -567,6 +613,68 @@ __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
        return NULL;
 }
 
+static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
+                                     struct mptcp_addr_info *addr)
+{
+       struct mptcp_pm_addr_entry *entry, *tmp;
+
+       list_for_each_entry_safe(entry, tmp, &msk->pm.anno_list, list) {
+               if (addresses_equal(&entry->addr, addr, false)) {
+                       list_del(&entry->list);
+                       kfree(entry);
+                       return true;
+               }
+       }
+
+       return false;
+}
+
+static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
+                                     struct mptcp_addr_info *addr,
+                                     bool force)
+{
+       bool ret;
+
+       spin_lock_bh(&msk->pm.lock);
+       ret = remove_anno_list_by_saddr(msk, addr);
+       if (ret || force)
+               mptcp_pm_remove_addr(msk, addr->id);
+       spin_unlock_bh(&msk->pm.lock);
+       return ret;
+}
+
+static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
+                                                  struct mptcp_addr_info *addr)
+{
+       struct mptcp_sock *msk;
+       long s_slot = 0, s_num = 0;
+
+       pr_debug("remove_id=%d", addr->id);
+
+       while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
+               struct sock *sk = (struct sock *)msk;
+               bool remove_subflow;
+
+               if (list_empty(&msk->conn_list)) {
+                       mptcp_pm_remove_anno_addr(msk, addr, false);
+                       goto next;
+               }
+
+               lock_sock(sk);
+               remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
+               mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
+               if (remove_subflow)
+                       mptcp_pm_remove_subflow(msk, addr->id);
+               release_sock(sk);
+
+next:
+               sock_put(sk);
+               cond_resched();
+       }
+
+       return 0;
+}
+
 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
 {
        struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
@@ -582,8 +690,8 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
        entry = __lookup_addr_by_id(pernet, addr.addr.id);
        if (!entry) {
                GENL_SET_ERR_MSG(info, "address not found");
-               ret = -EINVAL;
-               goto out;
+               spin_unlock_bh(&pernet->lock);
+               return -EINVAL;
        }
        if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
                pernet->add_addr_signal_max--;
@@ -592,9 +700,11 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
 
        pernet->addrs--;
        list_del_rcu(&entry->list);
-       kfree_rcu(entry, rcu);
-out:
        spin_unlock_bh(&pernet->lock);
+
+       mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
+       kfree_rcu(entry, rcu);
+
        return ret;
 }