netfilter: nf_conntrack: use safer way to lock all buckets
[platform/kernel/linux-rpi.git] / net / netfilter / nf_conntrack_core.c
index 3cb3cb8..58882de 100644 (file)
@@ -66,6 +66,21 @@ EXPORT_SYMBOL_GPL(nf_conntrack_locks);
 __cacheline_aligned_in_smp DEFINE_SPINLOCK(nf_conntrack_expect_lock);
 EXPORT_SYMBOL_GPL(nf_conntrack_expect_lock);
 
+static __read_mostly spinlock_t nf_conntrack_locks_all_lock;
+static __read_mostly bool nf_conntrack_locks_all;
+
+void nf_conntrack_lock(spinlock_t *lock) __acquires(lock)
+{
+       spin_lock(lock);
+       while (unlikely(nf_conntrack_locks_all)) {
+               spin_unlock(lock);
+               spin_lock(&nf_conntrack_locks_all_lock);
+               spin_unlock(&nf_conntrack_locks_all_lock);
+               spin_lock(lock);
+       }
+}
+EXPORT_SYMBOL_GPL(nf_conntrack_lock);
+
 static void nf_conntrack_double_unlock(unsigned int h1, unsigned int h2)
 {
        h1 %= CONNTRACK_LOCKS;
@@ -82,12 +97,12 @@ static bool nf_conntrack_double_lock(struct net *net, unsigned int h1,
        h1 %= CONNTRACK_LOCKS;
        h2 %= CONNTRACK_LOCKS;
        if (h1 <= h2) {
-               spin_lock(&nf_conntrack_locks[h1]);
+               nf_conntrack_lock(&nf_conntrack_locks[h1]);
                if (h1 != h2)
                        spin_lock_nested(&nf_conntrack_locks[h2],
                                         SINGLE_DEPTH_NESTING);
        } else {
-               spin_lock(&nf_conntrack_locks[h2]);
+               nf_conntrack_lock(&nf_conntrack_locks[h2]);
                spin_lock_nested(&nf_conntrack_locks[h1],
                                 SINGLE_DEPTH_NESTING);
        }
@@ -102,16 +117,19 @@ static void nf_conntrack_all_lock(void)
 {
        int i;
 
-       for (i = 0; i < CONNTRACK_LOCKS; i++)
-               spin_lock_nested(&nf_conntrack_locks[i], i);
+       spin_lock(&nf_conntrack_locks_all_lock);
+       nf_conntrack_locks_all = true;
+
+       for (i = 0; i < CONNTRACK_LOCKS; i++) {
+               spin_lock(&nf_conntrack_locks[i]);
+               spin_unlock(&nf_conntrack_locks[i]);
+       }
 }
 
 static void nf_conntrack_all_unlock(void)
 {
-       int i;
-
-       for (i = 0; i < CONNTRACK_LOCKS; i++)
-               spin_unlock(&nf_conntrack_locks[i]);
+       nf_conntrack_locks_all = false;
+       spin_unlock(&nf_conntrack_locks_all_lock);
 }
 
 unsigned int nf_conntrack_htable_size __read_mostly;
@@ -757,7 +775,7 @@ restart:
        hash = hash_bucket(_hash, net);
        for (; i < net->ct.htable_size; i++) {
                lockp = &nf_conntrack_locks[hash % CONNTRACK_LOCKS];
-               spin_lock(lockp);
+               nf_conntrack_lock(lockp);
                if (read_seqcount_retry(&net->ct.generation, sequence)) {
                        spin_unlock(lockp);
                        goto restart;
@@ -1382,7 +1400,7 @@ get_next_corpse(struct net *net, int (*iter)(struct nf_conn *i, void *data),
        for (; *bucket < net->ct.htable_size; (*bucket)++) {
                lockp = &nf_conntrack_locks[*bucket % CONNTRACK_LOCKS];
                local_bh_disable();
-               spin_lock(lockp);
+               nf_conntrack_lock(lockp);
                if (*bucket < net->ct.htable_size) {
                        hlist_nulls_for_each_entry(h, n, &net->ct.hash[*bucket], hnnode) {
                                if (NF_CT_DIRECTION(h) != IP_CT_DIR_ORIGINAL)