rhashtable: Allow rhashtable to be used from irq-safe contexts
authorTejun Heo <tj@kernel.org>
Tue, 6 Dec 2022 21:36:32 +0000 (11:36 -1000)
committerDavid S. Miller <davem@davemloft.net>
Fri, 9 Dec 2022 10:42:56 +0000 (10:42 +0000)
rhashtable currently only does bh-safe synchronization making it impossible
to use from irq-safe contexts. Switch it to use irq-safe synchronization to
remove the restriction.

v2: Update the lock functions to return the ulong flags value and unlock
    functions to take the value directly instead of passing around the
    pointer. Suggested by Linus.

Signed-off-by: Tejun Heo <tj@kernel.org>
Reviewed-by: David Vernet <dvernet@meta.com>
Acked-by: Josh Don <joshdon@google.com>
Acked-by: Hao Luo <haoluo@google.com>
Acked-by: Barret Rhoden <brho@google.com>
Cc: Linus Torvalds <torvalds@linux-foundation.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/linux/rhashtable.h
lib/rhashtable.c

index 68dab3e..5b5357c 100644 (file)
@@ -323,29 +323,36 @@ static inline struct rhash_lock_head __rcu **rht_bucket_insert(
  * When we write to a bucket without unlocking, we use rht_assign_locked().
  */
 
-static inline void rht_lock(struct bucket_table *tbl,
-                           struct rhash_lock_head __rcu **bkt)
+static inline unsigned long rht_lock(struct bucket_table *tbl,
+                                    struct rhash_lock_head __rcu **bkt)
 {
-       local_bh_disable();
+       unsigned long flags;
+
+       local_irq_save(flags);
        bit_spin_lock(0, (unsigned long *)bkt);
        lock_map_acquire(&tbl->dep_map);
+       return flags;
 }
 
-static inline void rht_lock_nested(struct bucket_table *tbl,
-                                  struct rhash_lock_head __rcu **bucket,
-                                  unsigned int subclass)
+static inline unsigned long rht_lock_nested(struct bucket_table *tbl,
+                                       struct rhash_lock_head __rcu **bucket,
+                                       unsigned int subclass)
 {
-       local_bh_disable();
+       unsigned long flags;
+
+       local_irq_save(flags);
        bit_spin_lock(0, (unsigned long *)bucket);
        lock_acquire_exclusive(&tbl->dep_map, subclass, 0, NULL, _THIS_IP_);
+       return flags;
 }
 
 static inline void rht_unlock(struct bucket_table *tbl,
-                             struct rhash_lock_head __rcu **bkt)
+                             struct rhash_lock_head __rcu **bkt,
+                             unsigned long flags)
 {
        lock_map_release(&tbl->dep_map);
        bit_spin_unlock(0, (unsigned long *)bkt);
-       local_bh_enable();
+       local_irq_restore(flags);
 }
 
 static inline struct rhash_head *__rht_ptr(
@@ -393,7 +400,8 @@ static inline void rht_assign_locked(struct rhash_lock_head __rcu **bkt,
 
 static inline void rht_assign_unlock(struct bucket_table *tbl,
                                     struct rhash_lock_head __rcu **bkt,
-                                    struct rhash_head *obj)
+                                    struct rhash_head *obj,
+                                    unsigned long flags)
 {
        if (rht_is_a_nulls(obj))
                obj = NULL;
@@ -401,7 +409,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
        rcu_assign_pointer(*bkt, (void *)obj);
        preempt_enable();
        __release(bitlock);
-       local_bh_enable();
+       local_irq_restore(flags);
 }
 
 /**
@@ -706,6 +714,7 @@ static inline void *__rhashtable_insert_fast(
        struct rhash_head __rcu **pprev;
        struct bucket_table *tbl;
        struct rhash_head *head;
+       unsigned long flags;
        unsigned int hash;
        int elasticity;
        void *data;
@@ -720,11 +729,11 @@ static inline void *__rhashtable_insert_fast(
        if (!bkt)
                goto out;
        pprev = NULL;
-       rht_lock(tbl, bkt);
+       flags = rht_lock(tbl, bkt);
 
        if (unlikely(rcu_access_pointer(tbl->future_tbl))) {
 slow_path:
-               rht_unlock(tbl, bkt);
+               rht_unlock(tbl, bkt, flags);
                rcu_read_unlock();
                return rhashtable_insert_slow(ht, key, obj);
        }
@@ -756,9 +765,9 @@ slow_path:
                RCU_INIT_POINTER(list->rhead.next, head);
                if (pprev) {
                        rcu_assign_pointer(*pprev, obj);
-                       rht_unlock(tbl, bkt);
+                       rht_unlock(tbl, bkt, flags);
                } else
-                       rht_assign_unlock(tbl, bkt, obj);
+                       rht_assign_unlock(tbl, bkt, obj, flags);
                data = NULL;
                goto out;
        }
@@ -785,7 +794,7 @@ slow_path:
        }
 
        atomic_inc(&ht->nelems);
-       rht_assign_unlock(tbl, bkt, obj);
+       rht_assign_unlock(tbl, bkt, obj, flags);
 
        if (rht_grow_above_75(ht, tbl))
                schedule_work(&ht->run_work);
@@ -797,7 +806,7 @@ out:
        return data;
 
 out_unlock:
-       rht_unlock(tbl, bkt);
+       rht_unlock(tbl, bkt, flags);
        goto out;
 }
 
@@ -991,6 +1000,7 @@ static inline int __rhashtable_remove_fast_one(
        struct rhash_lock_head __rcu **bkt;
        struct rhash_head __rcu **pprev;
        struct rhash_head *he;
+       unsigned long flags;
        unsigned int hash;
        int err = -ENOENT;
 
@@ -999,7 +1009,7 @@ static inline int __rhashtable_remove_fast_one(
        if (!bkt)
                return -ENOENT;
        pprev = NULL;
-       rht_lock(tbl, bkt);
+       flags = rht_lock(tbl, bkt);
 
        rht_for_each_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
                struct rhlist_head *list;
@@ -1043,14 +1053,14 @@ static inline int __rhashtable_remove_fast_one(
 
                if (pprev) {
                        rcu_assign_pointer(*pprev, obj);
-                       rht_unlock(tbl, bkt);
+                       rht_unlock(tbl, bkt, flags);
                } else {
-                       rht_assign_unlock(tbl, bkt, obj);
+                       rht_assign_unlock(tbl, bkt, obj, flags);
                }
                goto unlocked;
        }
 
-       rht_unlock(tbl, bkt);
+       rht_unlock(tbl, bkt, flags);
 unlocked:
        if (err > 0) {
                atomic_dec(&ht->nelems);
@@ -1143,6 +1153,7 @@ static inline int __rhashtable_replace_fast(
        struct rhash_lock_head __rcu **bkt;
        struct rhash_head __rcu **pprev;
        struct rhash_head *he;
+       unsigned long flags;
        unsigned int hash;
        int err = -ENOENT;
 
@@ -1158,7 +1169,7 @@ static inline int __rhashtable_replace_fast(
                return -ENOENT;
 
        pprev = NULL;
-       rht_lock(tbl, bkt);
+       flags = rht_lock(tbl, bkt);
 
        rht_for_each_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
                if (he != obj_old) {
@@ -1169,15 +1180,15 @@ static inline int __rhashtable_replace_fast(
                rcu_assign_pointer(obj_new->next, obj_old->next);
                if (pprev) {
                        rcu_assign_pointer(*pprev, obj_new);
-                       rht_unlock(tbl, bkt);
+                       rht_unlock(tbl, bkt, flags);
                } else {
-                       rht_assign_unlock(tbl, bkt, obj_new);
+                       rht_assign_unlock(tbl, bkt, obj_new, flags);
                }
                err = 0;
                goto unlocked;
        }
 
-       rht_unlock(tbl, bkt);
+       rht_unlock(tbl, bkt, flags);
 
 unlocked:
        return err;
index e12bbfb..6ae2ba8 100644 (file)
@@ -231,6 +231,7 @@ static int rhashtable_rehash_one(struct rhashtable *ht,
        struct rhash_head *head, *next, *entry;
        struct rhash_head __rcu **pprev = NULL;
        unsigned int new_hash;
+       unsigned long flags;
 
        if (new_tbl->nest)
                goto out;
@@ -253,13 +254,14 @@ static int rhashtable_rehash_one(struct rhashtable *ht,
 
        new_hash = head_hashfn(ht, new_tbl, entry);
 
-       rht_lock_nested(new_tbl, &new_tbl->buckets[new_hash], SINGLE_DEPTH_NESTING);
+       flags = rht_lock_nested(new_tbl, &new_tbl->buckets[new_hash],
+                               SINGLE_DEPTH_NESTING);
 
        head = rht_ptr(new_tbl->buckets + new_hash, new_tbl, new_hash);
 
        RCU_INIT_POINTER(entry->next, head);
 
-       rht_assign_unlock(new_tbl, &new_tbl->buckets[new_hash], entry);
+       rht_assign_unlock(new_tbl, &new_tbl->buckets[new_hash], entry, flags);
 
        if (pprev)
                rcu_assign_pointer(*pprev, next);
@@ -276,18 +278,19 @@ static int rhashtable_rehash_chain(struct rhashtable *ht,
 {
        struct bucket_table *old_tbl = rht_dereference(ht->tbl, ht);
        struct rhash_lock_head __rcu **bkt = rht_bucket_var(old_tbl, old_hash);
+       unsigned long flags;
        int err;
 
        if (!bkt)
                return 0;
-       rht_lock(old_tbl, bkt);
+       flags = rht_lock(old_tbl, bkt);
 
        while (!(err = rhashtable_rehash_one(ht, bkt, old_hash)))
                ;
 
        if (err == -ENOENT)
                err = 0;
-       rht_unlock(old_tbl, bkt);
+       rht_unlock(old_tbl, bkt, flags);
 
        return err;
 }
@@ -590,6 +593,7 @@ static void *rhashtable_try_insert(struct rhashtable *ht, const void *key,
        struct bucket_table *new_tbl;
        struct bucket_table *tbl;
        struct rhash_lock_head __rcu **bkt;
+       unsigned long flags;
        unsigned int hash;
        void *data;
 
@@ -607,7 +611,7 @@ static void *rhashtable_try_insert(struct rhashtable *ht, const void *key,
                        new_tbl = rht_dereference_rcu(tbl->future_tbl, ht);
                        data = ERR_PTR(-EAGAIN);
                } else {
-                       rht_lock(tbl, bkt);
+                       flags = rht_lock(tbl, bkt);
                        data = rhashtable_lookup_one(ht, bkt, tbl,
                                                     hash, key, obj);
                        new_tbl = rhashtable_insert_one(ht, bkt, tbl,
@@ -615,7 +619,7 @@ static void *rhashtable_try_insert(struct rhashtable *ht, const void *key,
                        if (PTR_ERR(new_tbl) != -EEXIST)
                                data = ERR_CAST(new_tbl);
 
-                       rht_unlock(tbl, bkt);
+                       rht_unlock(tbl, bkt, flags);
                }
        } while (!IS_ERR_OR_NULL(new_tbl));