net: sk_msg: Simplify sk_psock initialization
authorLorenz Bauer <lmb@cloudflare.com>
Fri, 21 Aug 2020 10:29:43 +0000 (11:29 +0100)
committerAlexei Starovoitov <ast@kernel.org>
Fri, 21 Aug 2020 22:16:11 +0000 (15:16 -0700)
Initializing psock->sk_proto and other saved callbacks is only
done in sk_psock_update_proto, after sk_psock_init has returned.
The logic for this is difficult to follow, and needlessly complex.

Instead, initialize psock->sk_proto whenever we allocate a new
psock. Additionally, assert the following invariants:

* The SK has no ULP: ULP does it's own finagling of sk->sk_prot
* sk_user_data is unused: we need it to store sk_psock

Protect our access to sk_user_data with sk_callback_lock, which
is what other users like reuseport arrays, etc. do.

The result is that an sk_psock is always fully initialized, and
that psock->sk_proto is always the "original" struct proto.
The latter allows us to use psock->sk_proto when initializing
IPv6 TCP / UDP callbacks for sockmap.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200821102948.21918-2-lmb@cloudflare.com
include/linux/skmsg.h
net/core/skmsg.c
net/core/sock_map.c
net/ipv4/tcp_bpf.c
net/ipv4/udp_bpf.c

index 1e9ed84..3119928 100644 (file)
@@ -340,23 +340,6 @@ static inline void sk_psock_update_proto(struct sock *sk,
                                         struct sk_psock *psock,
                                         struct proto *ops)
 {
-       /* Initialize saved callbacks and original proto only once, since this
-        * function may be called multiple times for a psock, e.g. when
-        * psock->progs.msg_parser is updated.
-        *
-        * Since we've not installed the new proto, psock is not yet in use and
-        * we can initialize it without synchronization.
-        */
-       if (!psock->sk_proto) {
-               struct proto *orig = READ_ONCE(sk->sk_prot);
-
-               psock->saved_unhash = orig->unhash;
-               psock->saved_close = orig->close;
-               psock->saved_write_space = sk->sk_write_space;
-
-               psock->sk_proto = orig;
-       }
-
        /* Pairs with lockless read in sk_clone_lock() */
        WRITE_ONCE(sk->sk_prot, ops);
 }
index 6a32a1f..1c81caf 100644 (file)
@@ -494,14 +494,34 @@ end:
 
 struct sk_psock *sk_psock_init(struct sock *sk, int node)
 {
-       struct sk_psock *psock = kzalloc_node(sizeof(*psock),
-                                             GFP_ATOMIC | __GFP_NOWARN,
-                                             node);
-       if (!psock)
-               return NULL;
+       struct sk_psock *psock;
+       struct proto *prot;
+
+       write_lock_bh(&sk->sk_callback_lock);
+
+       if (inet_csk_has_ulp(sk)) {
+               psock = ERR_PTR(-EINVAL);
+               goto out;
+       }
+
+       if (sk->sk_user_data) {
+               psock = ERR_PTR(-EBUSY);
+               goto out;
+       }
 
+       psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node);
+       if (!psock) {
+               psock = ERR_PTR(-ENOMEM);
+               goto out;
+       }
+
+       prot = READ_ONCE(sk->sk_prot);
        psock->sk = sk;
-       psock->eval =  __SK_NONE;
+       psock->eval = __SK_NONE;
+       psock->sk_proto = prot;
+       psock->saved_unhash = prot->unhash;
+       psock->saved_close = prot->close;
+       psock->saved_write_space = sk->sk_write_space;
 
        INIT_LIST_HEAD(&psock->link);
        spin_lock_init(&psock->link_lock);
@@ -516,6 +536,8 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
        rcu_assign_sk_user_data_nocopy(sk, psock);
        sock_hold(sk);
 
+out:
+       write_unlock_bh(&sk->sk_callback_lock);
        return psock;
 }
 EXPORT_SYMBOL_GPL(sk_psock_init);
index 119f52a..abe4bac 100644 (file)
@@ -184,8 +184,6 @@ static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
        struct proto *prot;
 
-       sock_owned_by_me(sk);
-
        switch (sk->sk_type) {
        case SOCK_STREAM:
                prot = tcp_bpf_get_proto(sk, psock);
@@ -272,8 +270,8 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                }
        } else {
                psock = sk_psock_init(sk, map->numa_node);
-               if (!psock) {
-                       ret = -ENOMEM;
+               if (IS_ERR(psock)) {
+                       ret = PTR_ERR(psock);
                        goto out_progs;
                }
        }
@@ -322,8 +320,8 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
 
        if (!psock) {
                psock = sk_psock_init(sk, map->numa_node);
-               if (!psock)
-                       return -ENOMEM;
+               if (IS_ERR(psock))
+                       return PTR_ERR(psock);
        }
 
        ret = sock_map_init_proto(sk, psock);
@@ -478,8 +476,6 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
                return -EINVAL;
        if (unlikely(idx >= map->max_entries))
                return -E2BIG;
-       if (inet_csk_has_ulp(sk))
-               return -EINVAL;
 
        link = sk_psock_init_link();
        if (!link)
@@ -855,8 +851,6 @@ static int sock_hash_update_common(struct bpf_map *map, void *key,
        WARN_ON_ONCE(!rcu_read_lock_held());
        if (unlikely(flags > BPF_EXIST))
                return -EINVAL;
-       if (inet_csk_has_ulp(sk))
-               return -EINVAL;
 
        link = sk_psock_init_link();
        if (!link)
index 7aa68f4..37f4cb2 100644 (file)
@@ -567,10 +567,9 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
        prot[TCP_BPF_TX].sendpage               = tcp_bpf_sendpage;
 }
 
-static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
 {
-       if (sk->sk_family == AF_INET6 &&
-           unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
+       if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
                spin_lock_bh(&tcpv6_prot_lock);
                if (likely(ops != tcpv6_prot_saved)) {
                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
@@ -603,13 +602,11 @@ struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
-       if (!psock->sk_proto) {
-               struct proto *ops = READ_ONCE(sk->sk_prot);
-
-               if (tcp_bpf_assert_proto_ops(ops))
+       if (sk->sk_family == AF_INET6) {
+               if (tcp_bpf_assert_proto_ops(psock->sk_proto))
                        return ERR_PTR(-EINVAL);
 
-               tcp_bpf_check_v6_needs_rebuild(sk, ops);
+               tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
        }
 
        return &tcp_bpf_prots[family][config];
index eddd973..7a94791 100644 (file)
@@ -22,10 +22,9 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
        prot->close  = sock_map_close;
 }
 
-static void udp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
 {
-       if (sk->sk_family == AF_INET6 &&
-           unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
+       if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) {
                spin_lock_bh(&udpv6_prot_lock);
                if (likely(ops != udpv6_prot_saved)) {
                        udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
@@ -46,8 +45,8 @@ struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
        int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
 
-       if (!psock->sk_proto)
-               udp_bpf_check_v6_needs_rebuild(sk, READ_ONCE(sk->sk_prot));
+       if (sk->sk_family == AF_INET6)
+               udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 
        return &udp_bpf_prots[family];
 }