af_unix: Put a socket into a per-netns hash table.
authorKuniyuki Iwashima <kuniyu@amazon.com>
Tue, 21 Jun 2022 17:19:12 +0000 (10:19 -0700)
committerDavid S. Miller <davem@davemloft.net>
Wed, 22 Jun 2022 11:59:43 +0000 (12:59 +0100)
This commit replaces the global hash table with a per-netns one and removes
the global one.

We now link a socket in each netns's hash table so we can save some netns
comparisons when iterating through a hash bucket.

Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/af_unix.h
net/unix/af_unix.c
net/unix/diag.c

index acb56e4..b1748c9 100644 (file)
@@ -22,7 +22,6 @@ struct sock *unix_peer_get(struct sock *sk);
 
 extern unsigned int unix_tot_inflight;
 extern spinlock_t unix_table_locks[UNIX_HASH_SIZE];
-extern struct hlist_head unix_socket_table[UNIX_HASH_SIZE];
 
 struct unix_address {
        refcount_t      refcnt;
index 79f8fc5..9d0b072 100644 (file)
 
 spinlock_t unix_table_locks[UNIX_HASH_SIZE];
 EXPORT_SYMBOL_GPL(unix_table_locks);
-struct hlist_head unix_socket_table[UNIX_HASH_SIZE];
-EXPORT_SYMBOL_GPL(unix_socket_table);
 static atomic_long_t unix_nr_socks;
 
 /* SMP locking strategy:
@@ -308,20 +306,20 @@ static void __unix_remove_socket(struct sock *sk)
        sk_del_node_init(sk);
 }
 
-static void __unix_insert_socket(struct sock *sk)
+static void __unix_insert_socket(struct net *net, struct sock *sk)
 {
        DEBUG_NET_WARN_ON_ONCE(!sk_unhashed(sk));
-       sk_add_node(sk, &unix_socket_table[sk->sk_hash]);
+       sk_add_node(sk, &net->unx.table.buckets[sk->sk_hash]);
 }
 
-static void __unix_set_addr_hash(struct sock *sk, struct unix_address *addr,
-                                unsigned int hash)
+static void __unix_set_addr_hash(struct net *net, struct sock *sk,
+                                struct unix_address *addr, unsigned int hash)
 {
        __unix_remove_socket(sk);
        smp_store_release(&unix_sk(sk)->addr, addr);
 
        sk->sk_hash = hash;
-       __unix_insert_socket(sk);
+       __unix_insert_socket(net, sk);
 }
 
 static void unix_remove_socket(struct net *net, struct sock *sk)
@@ -337,7 +335,7 @@ static void unix_insert_unbound_socket(struct net *net, struct sock *sk)
 {
        spin_lock(&unix_table_locks[sk->sk_hash]);
        spin_lock(&net->unx.table.locks[sk->sk_hash]);
-       __unix_insert_socket(sk);
+       __unix_insert_socket(net, sk);
        spin_unlock(&net->unx.table.locks[sk->sk_hash]);
        spin_unlock(&unix_table_locks[sk->sk_hash]);
 }
@@ -348,12 +346,9 @@ static struct sock *__unix_find_socket_byname(struct net *net,
 {
        struct sock *s;
 
-       sk_for_each(s, &unix_socket_table[hash]) {
+       sk_for_each(s, &net->unx.table.buckets[hash]) {
                struct unix_sock *u = unix_sk(s);
 
-               if (!net_eq(sock_net(s), net))
-                       continue;
-
                if (u->addr->len == len &&
                    !memcmp(u->addr->name, sunname, len))
                        return s;
@@ -384,7 +379,7 @@ static struct sock *unix_find_socket_byinode(struct net *net, struct inode *i)
 
        spin_lock(&unix_table_locks[hash]);
        spin_lock(&net->unx.table.locks[hash]);
-       sk_for_each(s, &unix_socket_table[hash]) {
+       sk_for_each(s, &net->unx.table.buckets[hash]) {
                struct dentry *dentry = unix_sk(s)->path.dentry;
 
                if (dentry && d_backing_inode(dentry) == i) {
@@ -1140,7 +1135,7 @@ retry:
                goto retry;
        }
 
-       __unix_set_addr_hash(sk, addr, new_hash);
+       __unix_set_addr_hash(net, sk, addr, new_hash);
        unix_table_double_unlock(net, old_hash, new_hash);
        err = 0;
 
@@ -1199,7 +1194,7 @@ static int unix_bind_bsd(struct sock *sk, struct sockaddr_un *sunaddr,
        unix_table_double_lock(net, old_hash, new_hash);
        u->path.mnt = mntget(parent.mnt);
        u->path.dentry = dget(dentry);
-       __unix_set_addr_hash(sk, addr, new_hash);
+       __unix_set_addr_hash(net, sk, addr, new_hash);
        unix_table_double_unlock(net, old_hash, new_hash);
        mutex_unlock(&u->bindlock);
        done_path_create(&parent, dentry);
@@ -1246,7 +1241,7 @@ static int unix_bind_abstract(struct sock *sk, struct sockaddr_un *sunaddr,
        if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash))
                goto out_spin;
 
-       __unix_set_addr_hash(sk, addr, new_hash);
+       __unix_set_addr_hash(net, sk, addr, new_hash);
        unix_table_double_unlock(net, old_hash, new_hash);
        mutex_unlock(&u->bindlock);
        return 0;
@@ -3239,12 +3234,11 @@ static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos)
 {
        unsigned long offset = get_offset(*pos);
        unsigned long bucket = get_bucket(*pos);
-       struct sock *sk;
        unsigned long count = 0;
+       struct sock *sk;
 
-       for (sk = sk_head(&unix_socket_table[bucket]); sk; sk = sk_next(sk)) {
-               if (sock_net(sk) != seq_file_net(seq))
-                       continue;
+       for (sk = sk_head(&seq_file_net(seq)->unx.table.buckets[bucket]);
+            sk; sk = sk_next(sk)) {
                if (++count == offset)
                        break;
        }
@@ -3279,13 +3273,13 @@ static struct sock *unix_get_next(struct seq_file *seq, struct sock *sk,
                                  loff_t *pos)
 {
        unsigned long bucket = get_bucket(*pos);
-       struct net *net = seq_file_net(seq);
 
-       for (sk = sk_next(sk); sk; sk = sk_next(sk))
-               if (sock_net(sk) == net)
-                       return sk;
+       sk = sk_next(sk);
+       if (sk)
+               return sk;
+
 
-       spin_unlock(&net->unx.table.locks[bucket]);
+       spin_unlock(&seq_file_net(seq)->unx.table.locks[bucket]);
        spin_unlock(&unix_table_locks[bucket]);
 
        *pos = set_bucket_offset(++bucket, 1);
@@ -3406,7 +3400,6 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
 
 {
        struct bpf_unix_iter_state *iter = seq->private;
-       struct net *net = seq_file_net(seq);
        unsigned int expected = 1;
        struct sock *sk;
 
@@ -3414,9 +3407,6 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
        iter->batch[iter->end_sk++] = start_sk;
 
        for (sk = sk_next(start_sk); sk; sk = sk_next(sk)) {
-               if (sock_net(sk) != net)
-                       continue;
-
                if (iter->end_sk < iter->max_sk) {
                        sock_hold(sk);
                        iter->batch[iter->end_sk++] = sk;
@@ -3425,7 +3415,7 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
                expected++;
        }
 
-       spin_unlock(&net->unx.table.locks[start_sk->sk_hash]);
+       spin_unlock(&seq_file_net(seq)->unx.table.locks[start_sk->sk_hash]);
        spin_unlock(&unix_table_locks[start_sk->sk_hash]);
 
        return expected;
index 7fc3774..4d0f0ca 100644 (file)
@@ -210,9 +210,7 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                num = 0;
                spin_lock(&unix_table_locks[slot]);
                spin_lock(&net->unx.table.locks[slot]);
-               sk_for_each(sk, &unix_socket_table[slot]) {
-                       if (!net_eq(sock_net(sk), net))
-                               continue;
+               sk_for_each(sk, &net->unx.table.buckets[slot]) {
                        if (num < s_num)
                                goto next;
                        if (!(req->udiag_states & (1 << sk->sk_state)))
@@ -246,13 +244,14 @@ static struct sock *unix_lookup_by_ino(struct net *net, unsigned int ino)
        for (i = 0; i < UNIX_HASH_SIZE; i++) {
                spin_lock(&unix_table_locks[i]);
                spin_lock(&net->unx.table.locks[i]);
-               sk_for_each(sk, &unix_socket_table[i])
+               sk_for_each(sk, &net->unx.table.buckets[i]) {
                        if (ino == sock_i_ino(sk)) {
                                sock_hold(sk);
                                spin_unlock(&net->unx.table.locks[i]);
                                spin_unlock(&unix_table_locks[i]);
                                return sk;
                        }
+               }
                spin_unlock(&net->unx.table.locks[i]);
                spin_unlock(&unix_table_locks[i]);
        }
@@ -277,8 +276,6 @@ static int unix_diag_get_exact(struct sk_buff *in_skb,
        err = -ENOENT;
        if (sk == NULL)
                goto out_nosk;
-       if (!net_eq(sock_net(sk), net))
-               goto out;
 
        err = sock_diag_check_cookie(sk, req->udiag_cookie);
        if (err)