netfilter: nft_set_rbtree: .deactivate fails if element has expired
[platform/kernel/linux-starfive.git] / net / netfilter / nft_set_rbtree.c
index 487572d..e34662f 100644 (file)
@@ -233,10 +233,9 @@ static void nft_rbtree_gc_remove(struct net *net, struct nft_set *set,
        rb_erase(&rbe->node, &priv->root);
 }
 
-static int nft_rbtree_gc_elem(const struct nft_set *__set,
-                             struct nft_rbtree *priv,
-                             struct nft_rbtree_elem *rbe,
-                             u8 genmask)
+static const struct nft_rbtree_elem *
+nft_rbtree_gc_elem(const struct nft_set *__set, struct nft_rbtree *priv,
+                  struct nft_rbtree_elem *rbe, u8 genmask)
 {
        struct nft_set *set = (struct nft_set *)__set;
        struct rb_node *prev = rb_prev(&rbe->node);
@@ -246,7 +245,7 @@ static int nft_rbtree_gc_elem(const struct nft_set *__set,
 
        gc = nft_trans_gc_alloc(set, 0, GFP_ATOMIC);
        if (!gc)
-               return -ENOMEM;
+               return ERR_PTR(-ENOMEM);
 
        /* search for end interval coming before this element.
         * end intervals don't carry a timeout extension, they
@@ -261,6 +260,7 @@ static int nft_rbtree_gc_elem(const struct nft_set *__set,
                prev = rb_prev(prev);
        }
 
+       rbe_prev = NULL;
        if (prev) {
                rbe_prev = rb_entry(prev, struct nft_rbtree_elem, node);
                nft_rbtree_gc_remove(net, set, priv, rbe_prev);
@@ -272,7 +272,7 @@ static int nft_rbtree_gc_elem(const struct nft_set *__set,
                 */
                gc = nft_trans_gc_queue_sync(gc, GFP_ATOMIC);
                if (WARN_ON_ONCE(!gc))
-                       return -ENOMEM;
+                       return ERR_PTR(-ENOMEM);
 
                nft_trans_gc_elem_add(gc, rbe_prev);
        }
@@ -280,13 +280,13 @@ static int nft_rbtree_gc_elem(const struct nft_set *__set,
        nft_rbtree_gc_remove(net, set, priv, rbe);
        gc = nft_trans_gc_queue_sync(gc, GFP_ATOMIC);
        if (WARN_ON_ONCE(!gc))
-               return -ENOMEM;
+               return ERR_PTR(-ENOMEM);
 
        nft_trans_gc_elem_add(gc, rbe);
 
        nft_trans_gc_queue_sync_done(gc);
 
-       return 0;
+       return rbe_prev;
 }
 
 static bool nft_rbtree_update_first(const struct nft_set *set,
@@ -314,7 +314,7 @@ static int __nft_rbtree_insert(const struct net *net, const struct nft_set *set,
        struct nft_rbtree *priv = nft_set_priv(set);
        u8 cur_genmask = nft_genmask_cur(net);
        u8 genmask = nft_genmask_next(net);
-       int d, err;
+       int d;
 
        /* Descend the tree to search for an existing element greater than the
         * key value to insert that is greater than the new element. This is the
@@ -363,9 +363,14 @@ static int __nft_rbtree_insert(const struct net *net, const struct nft_set *set,
                 */
                if (nft_set_elem_expired(&rbe->ext) &&
                    nft_set_elem_active(&rbe->ext, cur_genmask)) {
-                       err = nft_rbtree_gc_elem(set, priv, rbe, genmask);
-                       if (err < 0)
-                               return err;
+                       const struct nft_rbtree_elem *removed_end;
+
+                       removed_end = nft_rbtree_gc_elem(set, priv, rbe, genmask);
+                       if (IS_ERR(removed_end))
+                               return PTR_ERR(removed_end);
+
+                       if (removed_end == rbe_le || removed_end == rbe_ge)
+                               return -EAGAIN;
 
                        continue;
                }
@@ -486,11 +491,18 @@ static int nft_rbtree_insert(const struct net *net, const struct nft_set *set,
        struct nft_rbtree_elem *rbe = elem->priv;
        int err;
 
-       write_lock_bh(&priv->lock);
-       write_seqcount_begin(&priv->count);
-       err = __nft_rbtree_insert(net, set, rbe, ext);
-       write_seqcount_end(&priv->count);
-       write_unlock_bh(&priv->lock);
+       do {
+               if (fatal_signal_pending(current))
+                       return -EINTR;
+
+               cond_resched();
+
+               write_lock_bh(&priv->lock);
+               write_seqcount_begin(&priv->count);
+               err = __nft_rbtree_insert(net, set, rbe, ext);
+               write_seqcount_end(&priv->count);
+               write_unlock_bh(&priv->lock);
+       } while (err == -EAGAIN);
 
        return err;
 }
@@ -556,6 +568,8 @@ static void *nft_rbtree_deactivate(const struct net *net,
                                   nft_rbtree_interval_end(this)) {
                                parent = parent->rb_right;
                                continue;
+                       } else if (nft_set_elem_expired(&rbe->ext)) {
+                               break;
                        } else if (!nft_set_elem_active(&rbe->ext, genmask)) {
                                parent = parent->rb_left;
                                continue;