udp: Implement udp_bpf_recvmsg() for sockmap
authorCong Wang <cong.wang@bytedance.com>
Wed, 31 Mar 2021 02:32:34 +0000 (19:32 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Thu, 1 Apr 2021 17:56:14 +0000 (10:56 -0700)
We have to implement udp_bpf_recvmsg() to replace the ->recvmsg()
to retrieve skmsg from ingress_msg.

Signed-off-by: Cong Wang <cong.wang@bytedance.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20210331023237.41094-14-xiyou.wangcong@gmail.com
net/ipv4/udp_bpf.c

index 6001f93..7d5c4eb 100644 (file)
@@ -4,6 +4,68 @@
 #include <linux/skmsg.h>
 #include <net/sock.h>
 #include <net/udp.h>
+#include <net/inet_common.h>
+
+#include "udp_impl.h"
+
+static struct proto *udpv6_prot_saved __read_mostly;
+
+static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+                         int noblock, int flags, int *addr_len)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+       if (sk->sk_family == AF_INET6)
+               return udpv6_prot_saved->recvmsg(sk, msg, len, noblock, flags,
+                                                addr_len);
+#endif
+       return udp_prot.recvmsg(sk, msg, len, noblock, flags, addr_len);
+}
+
+static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+                          int nonblock, int flags, int *addr_len)
+{
+       struct sk_psock *psock;
+       int copied, ret;
+
+       if (unlikely(flags & MSG_ERRQUEUE))
+               return inet_recv_error(sk, msg, len, addr_len);
+
+       psock = sk_psock_get(sk);
+       if (unlikely(!psock))
+               return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+
+       lock_sock(sk);
+       if (sk_psock_queue_empty(psock)) {
+               ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+               goto out;
+       }
+
+msg_bytes_ready:
+       copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
+       if (!copied) {
+               int data, err = 0;
+               long timeo;
+
+               timeo = sock_rcvtimeo(sk, nonblock);
+               data = sk_msg_wait_data(sk, psock, flags, timeo, &err);
+               if (data) {
+                       if (!sk_psock_queue_empty(psock))
+                               goto msg_bytes_ready;
+                       ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+                       goto out;
+               }
+               if (err) {
+                       ret = err;
+                       goto out;
+               }
+               copied = -EAGAIN;
+       }
+       ret = copied;
+out:
+       release_sock(sk);
+       sk_psock_put(sk, psock);
+       return ret;
+}
 
 enum {
        UDP_BPF_IPV4,
@@ -11,7 +73,6 @@ enum {
        UDP_BPF_NUM_PROTS,
 };
 
-static struct proto *udpv6_prot_saved __read_mostly;
 static DEFINE_SPINLOCK(udpv6_prot_lock);
 static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];
 
@@ -20,6 +81,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
        *prot        = *base;
        prot->unhash = sock_map_unhash;
        prot->close  = sock_map_close;
+       prot->recvmsg = udp_bpf_recvmsg;
 }
 
 static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)