sock: optimise UDP sock_wfree() refcounting
author Pavel Begunkov <asml.silence@gmail.com>
Thu, 28 Apr 2022 10:58:18 +0000 (11:58 +0100)
committer David S. Miller <davem@davemloft.net>
Sun, 1 May 2022 11:19:01 +0000 (12:19 +0100)
For non-SOCK_USE_WRITE_QUEUE sockets, sock_wfree() (atomically) puts
->sk_wmem_alloc twice. The second put is needed to keep the socket
alive while calling ->sk_write_space() after the first put.

However, some sockets, such as UDP, are freed by RCU
(i.e. SOCK_RCU_FREE) and use the already RCU-safe sock_def_write_space().
Carve out a fast path for such sockets: put down all refs in one go
before calling sock_def_write_space(), but guard the socket from being
freed with an RCU read-side critical section.
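
The difference between the two paths can be sketched in userspace as
follows. This is an illustration only, not the kernel code: C11 atomics
stand in for refcount_t, every toy_* name is hypothetical, and the RCU
read section that keeps the socket alive in the real fast path is only
marked by a comment, since it cannot be modelled here.

```c
/* Userspace sketch only. All toy_* names are hypothetical stand-ins. */
#include <assert.h>
#include <stdatomic.h>
#include <stdbool.h>

struct toy_sock {
	atomic_uint wmem_alloc;   /* stands in for sk->sk_wmem_alloc */
	unsigned int wakeups;     /* number of write-space notifications */
	unsigned int atomic_ops;  /* number of atomic puts performed */
};

/* Subtract n and report whether the counter dropped to zero. */
static bool toy_sub_and_test(struct toy_sock *sk, unsigned int n)
{
	sk->atomic_ops++;
	return atomic_fetch_sub(&sk->wmem_alloc, n) == n;
}

/* Stand-in for the ->sk_write_space() callback. */
static void toy_write_space(struct toy_sock *sk)
{
	sk->wakeups++;
}

/*
 * Generic path: drop len - 1 first, so one reference keeps the socket
 * alive across the callback, then drop that last reference.
 * Returns true when the caller would have to free the socket.
 */
static bool toy_wfree_slow(struct toy_sock *sk, unsigned int len)
{
	bool hit_zero = toy_sub_and_test(sk, len - 1);

	assert(!hit_zero); /* a reference must remain for the callback */
	(void)hit_zero;
	toy_write_space(sk);
	return toy_sub_and_test(sk, 1);
}

/*
 * Fast path: drop everything in one atomic operation. In the kernel the
 * put and the callback sit inside rcu_read_lock()/rcu_read_unlock(),
 * which is what keeps an RCU-freed socket valid during the callback.
 */
static bool toy_wfree_fast(struct toy_sock *sk, unsigned int len)
{
	bool free = toy_sub_and_test(sk, len);

	toy_write_space(sk);
	return free;
}
```

Both helpers notify the writer once and report the same "free now"
result; the fast variant simply needs one atomic operation instead of
two.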

Note: because TCP sockets are marked with SOCK_USE_WRITE_QUEUE, this
change adds no extra checks to the TCP path.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/core/sock.c

index 6ad22fc2eb810005d60ab327bf2b09a97bf4675a..ab865b04130ba788cc5bd992f429fed245ab2d93 100644
 static DEFINE_MUTEX(proto_list_mutex);
 static LIST_HEAD(proto_list);
 
+static void sock_def_write_space(struct sock *sk);
+
 /**
  * sk_ns_capable - General socket capability test
  * @sk: Socket to use a capability on or through
@@ -2324,8 +2326,20 @@ void sock_wfree(struct sk_buff *skb)
 {
        struct sock *sk = skb->sk;
        unsigned int len = skb->truesize;
+       bool free;
 
        if (!sock_flag(sk, SOCK_USE_WRITE_QUEUE)) {
+               if (sock_flag(sk, SOCK_RCU_FREE) &&
+                   sk->sk_write_space == sock_def_write_space) {
+                       rcu_read_lock();
+                       free = refcount_sub_and_test(len, &sk->sk_wmem_alloc);
+                       sock_def_write_space(sk);
+                       rcu_read_unlock();
+                       if (unlikely(free))
+                               __sk_free(sk);
+                       return;
+               }
+
                /*
                 * Keep a reference on sk_wmem_alloc, this will be released
                 * after sk_write_space() call