bpf: sockmap: Move generic sockmap hooks from BPF TCP
authorLorenz Bauer <lmb@cloudflare.com>
Mon, 9 Mar 2020 11:12:36 +0000 (11:12 +0000)
committerDaniel Borkmann <daniel@iogearbox.net>
Mon, 9 Mar 2020 21:34:58 +0000 (22:34 +0100)
The init, close and unhash handlers from TCP sockmap are generic,
and can be reused by UDP sockmap. Move the helpers into the sockmap code
base and expose them. This requires tcp_bpf_get_proto and tcp_bpf_clone to
be conditional on BPF_STREAM_PARSER.

The moved functions are unmodified, except that sk_psock_unlink is
renamed to sock_map_unlink to better match its behaviour.

Signed-off-by: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Reviewed-by: Jakub Sitnicki <jakub@cloudflare.com>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20200309111243.6982-6-lmb@cloudflare.com
include/linux/bpf.h
include/linux/skmsg.h
include/net/tcp.h
net/core/sock_map.c
net/ipv4/tcp_bpf.c

index 40c5392..94a329b 100644 (file)
@@ -1419,6 +1419,8 @@ static inline void bpf_map_offload_map_free(struct bpf_map *map)
 #if defined(CONFIG_BPF_STREAM_PARSER)
 int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, u32 which);
 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog);
+void sock_map_unhash(struct sock *sk);
+void sock_map_close(struct sock *sk, long timeout);
 #else
 static inline int sock_map_prog_update(struct bpf_map *map,
                                       struct bpf_prog *prog, u32 which)
@@ -1431,7 +1433,7 @@ static inline int sock_map_get_from_fd(const union bpf_attr *attr,
 {
        return -EINVAL;
 }
-#endif
+#endif /* CONFIG_BPF_STREAM_PARSER */
 
 #if defined(CONFIG_INET) && defined(CONFIG_BPF_SYSCALL)
 void bpf_sk_reuseport_detach(struct sock *sk);
index 2be51b7..8a709f6 100644 (file)
@@ -323,14 +323,6 @@ static inline void sk_psock_free_link(struct sk_psock_link *link)
 }
 
 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock);
-#if defined(CONFIG_BPF_STREAM_PARSER)
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link);
-#else
-static inline void sk_psock_unlink(struct sock *sk,
-                                  struct sk_psock_link *link)
-{
-}
-#endif
 
 void __sk_psock_purge_ingress_msg(struct sk_psock *psock);
 
@@ -399,26 +391,6 @@ static inline bool sk_psock_test_state(const struct sk_psock *psock,
        return test_bit(bit, &psock->state);
 }
 
-static inline struct sk_psock *sk_psock_get_checked(struct sock *sk)
-{
-       struct sk_psock *psock;
-
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (psock) {
-               if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
-                       psock = ERR_PTR(-EBUSY);
-                       goto out;
-               }
-
-               if (!refcount_inc_not_zero(&psock->refcnt))
-                       psock = ERR_PTR(-EBUSY);
-       }
-out:
-       rcu_read_unlock();
-       return psock;
-}
-
 static inline struct sk_psock *sk_psock_get(struct sock *sk)
 {
        struct sk_psock *psock;
index ad3abea..43fa07a 100644 (file)
@@ -2195,19 +2195,22 @@ void tcp_update_ulp(struct sock *sk, struct proto *p,
 struct sk_msg;
 struct sk_psock;
 
+#ifdef CONFIG_BPF_STREAM_PARSER
+struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
+#else
+static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
+{
+}
+#endif /* CONFIG_BPF_STREAM_PARSER */
+
 #ifdef CONFIG_NET_SOCK_MSG
-int tcp_bpf_init(struct sock *sk);
 int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, u32 bytes,
                          int flags);
 int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                    int nonblock, int flags, int *addr_len);
 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
                      struct msghdr *msg, int len, int flags);
-void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
-#else
-static inline void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
-{
-}
 #endif /* CONFIG_NET_SOCK_MSG */
 
 /* Call BPF_SOCK_OPS program that returns an int. If the return value
index fafcbd2..cb240d8 100644 (file)
@@ -141,6 +141,51 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
        }
 }
 
+static int sock_map_init_proto(struct sock *sk)
+{
+       struct sk_psock *psock;
+       struct proto *prot;
+
+       sock_owned_by_me(sk);
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               return -EINVAL;
+       }
+
+       prot = tcp_bpf_get_proto(sk, psock);
+       if (IS_ERR(prot)) {
+               rcu_read_unlock();
+               return PTR_ERR(prot);
+       }
+
+       sk_psock_update_proto(sk, psock, prot);
+       rcu_read_unlock();
+       return 0;
+}
+
+static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
+{
+       struct sk_psock *psock;
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (psock) {
+               if (sk->sk_prot->recvmsg != tcp_bpf_recvmsg) {
+                       psock = ERR_PTR(-EBUSY);
+                       goto out;
+               }
+
+               if (!refcount_inc_not_zero(&psock->refcnt))
+                       psock = ERR_PTR(-EBUSY);
+       }
+out:
+       rcu_read_unlock();
+       return psock;
+}
+
 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                         struct sock *sk)
 {
@@ -172,7 +217,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
                }
        }
 
-       psock = sk_psock_get_checked(sk);
+       psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock)) {
                ret = PTR_ERR(psock);
                goto out_progs;
@@ -196,7 +241,7 @@ static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
        if (msg_parser)
                psock_set_prog(&psock->progs.msg_parser, msg_parser);
 
-       ret = tcp_bpf_init(sk);
+       ret = sock_map_init_proto(sk);
        if (ret < 0)
                goto out_drop;
 
@@ -231,7 +276,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
        struct sk_psock *psock;
        int ret;
 
-       psock = sk_psock_get_checked(sk);
+       psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock))
                return PTR_ERR(psock);
 
@@ -241,7 +286,7 @@ static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
                        return -ENOMEM;
        }
 
-       ret = tcp_bpf_init(sk);
+       ret = sock_map_init_proto(sk);
        if (ret < 0)
                sk_psock_put(sk, psock);
        return ret;
@@ -1120,7 +1165,7 @@ int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
        return 0;
 }
 
-void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
+static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
 {
        switch (link->map->map_type) {
        case BPF_MAP_TYPE_SOCKMAP:
@@ -1133,3 +1178,54 @@ void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
                break;
        }
 }
+
+static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
+{
+       struct sk_psock_link *link;
+
+       while ((link = sk_psock_link_pop(psock))) {
+               sock_map_unlink(sk, link);
+               sk_psock_free_link(link);
+       }
+}
+
+void sock_map_unhash(struct sock *sk)
+{
+       void (*saved_unhash)(struct sock *sk);
+       struct sk_psock *psock;
+
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               if (sk->sk_prot->unhash)
+                       sk->sk_prot->unhash(sk);
+               return;
+       }
+
+       saved_unhash = psock->saved_unhash;
+       sock_map_remove_links(sk, psock);
+       rcu_read_unlock();
+       saved_unhash(sk);
+}
+
+void sock_map_close(struct sock *sk, long timeout)
+{
+       void (*saved_close)(struct sock *sk, long timeout);
+       struct sk_psock *psock;
+
+       lock_sock(sk);
+       rcu_read_lock();
+       psock = sk_psock(sk);
+       if (unlikely(!psock)) {
+               rcu_read_unlock();
+               release_sock(sk);
+               return sk->sk_prot->close(sk, timeout);
+       }
+
+       saved_close = psock->saved_close;
+       sock_map_remove_links(sk, psock);
+       rcu_read_unlock();
+       release_sock(sk);
+       saved_close(sk, timeout);
+}
index ed8a8f3..fe7b4fb 100644 (file)
@@ -528,57 +528,7 @@ out_err:
        return copied ? copied : err;
 }
 
-static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
-{
-       struct sk_psock_link *link;
-
-       while ((link = sk_psock_link_pop(psock))) {
-               sk_psock_unlink(sk, link);
-               sk_psock_free_link(link);
-       }
-}
-
-static void tcp_bpf_unhash(struct sock *sk)
-{
-       void (*saved_unhash)(struct sock *sk);
-       struct sk_psock *psock;
-
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               if (sk->sk_prot->unhash)
-                       sk->sk_prot->unhash(sk);
-               return;
-       }
-
-       saved_unhash = psock->saved_unhash;
-       tcp_bpf_remove(sk, psock);
-       rcu_read_unlock();
-       saved_unhash(sk);
-}
-
-static void tcp_bpf_close(struct sock *sk, long timeout)
-{
-       void (*saved_close)(struct sock *sk, long timeout);
-       struct sk_psock *psock;
-
-       lock_sock(sk);
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               release_sock(sk);
-               return sk->sk_prot->close(sk, timeout);
-       }
-
-       saved_close = psock->saved_close;
-       tcp_bpf_remove(sk, psock);
-       rcu_read_unlock();
-       release_sock(sk);
-       saved_close(sk, timeout);
-}
-
+#ifdef CONFIG_BPF_STREAM_PARSER
 enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
@@ -599,8 +549,8 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                                   struct proto *base)
 {
        prot[TCP_BPF_BASE]                      = *base;
-       prot[TCP_BPF_BASE].unhash               = tcp_bpf_unhash;
-       prot[TCP_BPF_BASE].close                = tcp_bpf_close;
+       prot[TCP_BPF_BASE].unhash               = sock_map_unhash;
+       prot[TCP_BPF_BASE].close                = sock_map_close;
        prot[TCP_BPF_BASE].recvmsg              = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].stream_memory_read   = tcp_bpf_stream_read;
 
@@ -640,7 +590,7 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
@@ -657,31 +607,6 @@ static struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
        return &tcp_bpf_prots[family][config];
 }
 
-int tcp_bpf_init(struct sock *sk)
-{
-       struct sk_psock *psock;
-       struct proto *prot;
-
-       sock_owned_by_me(sk);
-
-       rcu_read_lock();
-       psock = sk_psock(sk);
-       if (unlikely(!psock)) {
-               rcu_read_unlock();
-               return -EINVAL;
-       }
-
-       prot = tcp_bpf_get_proto(sk, psock);
-       if (IS_ERR(prot)) {
-               rcu_read_unlock();
-               return PTR_ERR(prot);
-       }
-
-       sk_psock_update_proto(sk, psock, prot);
-       rcu_read_unlock();
-       return 0;
-}
-
 /* If a child got cloned from a listening socket that had tcp_bpf
  * protocol callbacks installed, we need to restore the callbacks to
  * the default ones because the child does not inherit the psock state
@@ -695,3 +620,4 @@ void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
        if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
                newsk->sk_prot = sk->sk_prot_creator;
 }
+#endif /* CONFIG_BPF_STREAM_PARSER */