ping: convert to RCU lookups, get rid of rwlock
authorEric Dumazet <edumazet@google.com>
Sat, 18 Jun 2022 04:04:15 +0000 (21:04 -0700)
committerDavid S. Miller <davem@davemloft.net>
Sat, 18 Jun 2022 09:54:29 +0000 (10:54 +0100)
Using rwlock in networking code is extremely risky.
writers can starve if enough readers are constantly
grabing the rwlock.

I thought rwlock were at fault and sent this patch:

https://lkml.org/lkml/2022/6/17/272

But Peter and Linus essentially told me rwlock had to be unfair.

We need to get rid of rwlock in networking code.

Fixes: c319b4d76b9e ("net: ipv4: add IPPROTO_ICMP socket kind")
Signed-off-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/ipv4/ping.c

index 1a43ca73f94d68800f7a849a01238f3d089ff8b5..3d6fc6d841b8e72671e807e7b1fc7a6686a78fb6 100644 (file)
@@ -50,7 +50,7 @@
 
 struct ping_table {
        struct hlist_nulls_head hash[PING_HTABLE_SIZE];
-       rwlock_t                lock;
+       spinlock_t              lock;
 };
 
 static struct ping_table ping_table;
@@ -82,7 +82,7 @@ int ping_get_port(struct sock *sk, unsigned short ident)
        struct sock *sk2 = NULL;
 
        isk = inet_sk(sk);
-       write_lock_bh(&ping_table.lock);
+       spin_lock(&ping_table.lock);
        if (ident == 0) {
                u32 i;
                u16 result = ping_port_rover + 1;
@@ -128,14 +128,15 @@ next_port:
        if (sk_unhashed(sk)) {
                pr_debug("was not hashed\n");
                sock_hold(sk);
-               hlist_nulls_add_head(&sk->sk_nulls_node, hlist);
+               sock_set_flag(sk, SOCK_RCU_FREE);
+               hlist_nulls_add_head_rcu(&sk->sk_nulls_node, hlist);
                sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
        }
-       write_unlock_bh(&ping_table.lock);
+       spin_unlock(&ping_table.lock);
        return 0;
 
 fail:
-       write_unlock_bh(&ping_table.lock);
+       spin_unlock(&ping_table.lock);
        return 1;
 }
 EXPORT_SYMBOL_GPL(ping_get_port);
@@ -153,19 +154,19 @@ void ping_unhash(struct sock *sk)
        struct inet_sock *isk = inet_sk(sk);
 
        pr_debug("ping_unhash(isk=%p,isk->num=%u)\n", isk, isk->inet_num);
-       write_lock_bh(&ping_table.lock);
+       spin_lock(&ping_table.lock);
        if (sk_hashed(sk)) {
-               hlist_nulls_del(&sk->sk_nulls_node);
-               sk_nulls_node_init(&sk->sk_nulls_node);
+               hlist_nulls_del_init_rcu(&sk->sk_nulls_node);
                sock_put(sk);
                isk->inet_num = 0;
                isk->inet_sport = 0;
                sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
        }
-       write_unlock_bh(&ping_table.lock);
+       spin_unlock(&ping_table.lock);
 }
 EXPORT_SYMBOL_GPL(ping_unhash);
 
+/* Called under rcu_read_lock() */
 static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident)
 {
        struct hlist_nulls_head *hslot = ping_hashslot(&ping_table, net, ident);
@@ -190,8 +191,6 @@ static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident)
                return NULL;
        }
 
-       read_lock_bh(&ping_table.lock);
-
        ping_portaddr_for_each_entry(sk, hnode, hslot) {
                isk = inet_sk(sk);
 
@@ -230,13 +229,11 @@ static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident)
                    sk->sk_bound_dev_if != sdif)
                        continue;
 
-               sock_hold(sk);
                goto exit;
        }
 
        sk = NULL;
 exit:
-       read_unlock_bh(&ping_table.lock);
 
        return sk;
 }
@@ -588,7 +585,7 @@ void ping_err(struct sk_buff *skb, int offset, u32 info)
        sk->sk_err = err;
        sk_error_report(sk);
 out:
-       sock_put(sk);
+       return;
 }
 EXPORT_SYMBOL_GPL(ping_err);
 
@@ -994,7 +991,6 @@ enum skb_drop_reason ping_rcv(struct sk_buff *skb)
                        reason = __ping_queue_rcv_skb(sk, skb2);
                else
                        reason = SKB_DROP_REASON_NOMEM;
-               sock_put(sk);
        }
 
        if (reason)
@@ -1080,13 +1076,13 @@ static struct sock *ping_get_idx(struct seq_file *seq, loff_t pos)
 }
 
 void *ping_seq_start(struct seq_file *seq, loff_t *pos, sa_family_t family)
-       __acquires(ping_table.lock)
+       __acquires(RCU)
 {
        struct ping_iter_state *state = seq->private;
        state->bucket = 0;
        state->family = family;
 
-       read_lock_bh(&ping_table.lock);
+       rcu_read_lock();
 
        return *pos ? ping_get_idx(seq, *pos-1) : SEQ_START_TOKEN;
 }
@@ -1112,9 +1108,9 @@ void *ping_seq_next(struct seq_file *seq, void *v, loff_t *pos)
 EXPORT_SYMBOL_GPL(ping_seq_next);
 
 void ping_seq_stop(struct seq_file *seq, void *v)
-       __releases(ping_table.lock)
+       __releases(RCU)
 {
-       read_unlock_bh(&ping_table.lock);
+       rcu_read_unlock();
 }
 EXPORT_SYMBOL_GPL(ping_seq_stop);
 
@@ -1198,5 +1194,5 @@ void __init ping_init(void)
 
        for (i = 0; i < PING_HTABLE_SIZE; i++)
                INIT_HLIST_NULLS_HEAD(&ping_table.hash[i], i);
-       rwlock_init(&ping_table.lock);
+       spin_lock_init(&ping_table.lock);
 }