Merge branch 'Add SO_REUSEPORT support for TC bpf_sk_assign'
author Martin KaFai Lau <martin.lau@kernel.org>
Mon, 24 Jul 2023 20:28:41 +0000 (13:28 -0700)
committer Martin KaFai Lau <martin.lau@kernel.org>
Tue, 25 Jul 2023 21:07:08 +0000 (14:07 -0700)
Lorenz Bauer says:

====================
We want to replace iptables TPROXY with a BPF program at TC ingress.
To make this work in all cases we need to assign a SO_REUSEPORT socket
to an skb, which is currently prohibited. This series adds support for
such sockets to bpf_sk_assign.
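
A minimal TC ingress sketch of the intended usage (illustrative only, not
part of this series; the program name, address and port are made up, and
the selftest below drives the helper through a sockmap instead):

  // SPDX-License-Identifier: GPL-2.0
  #include <linux/bpf.h>
  #include <linux/pkt_cls.h>
  #include <bpf/bpf_endian.h>
  #include <bpf/bpf_helpers.h>

  SEC("tc")
  int steer_to_proxy(struct __sk_buff *skb)
  {
          /* Local proxy listener; with this series it may be part of a
           * SO_REUSEPORT group.
           */
          struct bpf_sock_tuple tuple = {
                  .ipv4.daddr = bpf_htonl(0x7f000001),   /* 127.0.0.1 */
                  .ipv4.dport = bpf_htons(4444),         /* example port */
          };
          struct bpf_sock *sk;
          long err;

          sk = bpf_sk_lookup_tcp(skb, &tuple, sizeof(tuple.ipv4),
                                 BPF_F_CURRENT_NETNS, 0);
          if (!sk)
                  return TC_ACT_OK;

          err = bpf_sk_assign(skb, sk, 0);
          bpf_sk_release(sk);
          return err ? TC_ACT_SHOT : TC_ACT_OK;
  }

  char _license[] SEC("license") = "GPL";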

I did some refactoring to cut down on the amount of duplicate code. The
key to this is to use INDIRECT_CALL in the reuseport helpers. To show
that this approach is not just beneficial to TC sk_assign, I removed
duplicate code for bpf_sk_lookup as well.
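
Concretely, the shared helpers now take the ehashfn as a parameter, and the
indirect call collapses into direct calls to the two known implementations
when retpolines are enabled, e.g. in inet_lookup_reuseport():

  phash = INDIRECT_CALL_2(ehashfn, udp_ehashfn, inet_ehashfn,
                          net, daddr, hnum, saddr, sport);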

Joint work with Daniel Borkmann.

Signed-off-by: Lorenz Bauer <lmb@isovalent.com>
---
Changes in v6:
- Reject unhashed UDP sockets in bpf_sk_assign to avoid ref leak
- Link to v5: https://lore.kernel.org/r/20230613-so-reuseport-v5-0-f6686a0dbce0@isovalent.com

Changes in v5:
- Drop reuse_sk == sk check in inet[6]_steal_sock (Kuniyuki)
- Link to v4: https://lore.kernel.org/r/20230613-so-reuseport-v4-0-4ece76708bba@isovalent.com

Changes in v4:
- WARN_ON_ONCE if reuseport socket is refcounted (Kuniyuki)
- Use inet[6]_ehashfn_t to shorten function declarations (Kuniyuki)
- Shuffle documentation patch around (Kuniyuki)
- Update commit message to explain why IPv6 needs EXPORT_SYMBOL
- Link to v3: https://lore.kernel.org/r/20230613-so-reuseport-v3-0-907b4cbb7b99@isovalent.com

Changes in v3:
- Fix warning re udp_ehashfn and udp6_ehashfn (Simon)
- Return higher scoring connected UDP reuseport sockets (Kuniyuki)
- Fix ipv6 module builds
- Link to v2: https://lore.kernel.org/r/20230613-so-reuseport-v2-0-b7c69a342613@isovalent.com

Changes in v2:
- Correct commit abbrev length (Kuniyuki)
- Reduce duplication (Kuniyuki)
- Add checks on sk_state (Martin)
- Split exporting inet[6]_lookup_reuseport into separate patch (Eric)

---
Daniel Borkmann (1):
      selftests/bpf: Test that SO_REUSEPORT can be used with sk_assign helper
====================

Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
13 files changed:
include/net/inet6_hashtables.h
include/net/inet_hashtables.h
include/net/sock.h
include/uapi/linux/bpf.h
net/core/filter.c
net/ipv4/inet_hashtables.c
net/ipv4/udp.c
net/ipv6/inet6_hashtables.c
net/ipv6/udp.c
tools/include/uapi/linux/bpf.h
tools/testing/selftests/bpf/network_helpers.c
tools/testing/selftests/bpf/prog_tests/assign_reuse.c [new file with mode: 0644]
tools/testing/selftests/bpf/progs/test_assign_reuse.c [new file with mode: 0644]

diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
index 56f1286..284b5ce 100644
@@ -48,6 +48,22 @@ struct sock *__inet6_lookup_established(struct net *net,
                                        const u16 hnum, const int dif,
                                        const int sdif);
 
+typedef u32 (inet6_ehashfn_t)(const struct net *net,
+                              const struct in6_addr *laddr, const u16 lport,
+                              const struct in6_addr *faddr, const __be16 fport);
+
+inet6_ehashfn_t inet6_ehashfn;
+
+INDIRECT_CALLABLE_DECLARE(inet6_ehashfn_t udp6_ehashfn);
+
+struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
+                                   struct sk_buff *skb, int doff,
+                                   const struct in6_addr *saddr,
+                                   __be16 sport,
+                                   const struct in6_addr *daddr,
+                                   unsigned short hnum,
+                                   inet6_ehashfn_t *ehashfn);
+
 struct sock *inet6_lookup_listener(struct net *net,
                                   struct inet_hashinfo *hashinfo,
                                   struct sk_buff *skb, int doff,
@@ -57,6 +73,15 @@ struct sock *inet6_lookup_listener(struct net *net,
                                   const unsigned short hnum,
                                   const int dif, const int sdif);
 
+struct sock *inet6_lookup_run_sk_lookup(struct net *net,
+                                       int protocol,
+                                       struct sk_buff *skb, int doff,
+                                       const struct in6_addr *saddr,
+                                       const __be16 sport,
+                                       const struct in6_addr *daddr,
+                                       const u16 hnum, const int dif,
+                                       inet6_ehashfn_t *ehashfn);
+
 static inline struct sock *__inet6_lookup(struct net *net,
                                          struct inet_hashinfo *hashinfo,
                                          struct sk_buff *skb, int doff,
@@ -78,6 +103,46 @@ static inline struct sock *__inet6_lookup(struct net *net,
                                     daddr, hnum, dif, sdif);
 }
 
+static inline
+struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
+                             const struct in6_addr *saddr, const __be16 sport,
+                             const struct in6_addr *daddr, const __be16 dport,
+                             bool *refcounted, inet6_ehashfn_t *ehashfn)
+{
+       struct sock *sk, *reuse_sk;
+       bool prefetched;
+
+       sk = skb_steal_sock(skb, refcounted, &prefetched);
+       if (!sk)
+               return NULL;
+
+       if (!prefetched)
+               return sk;
+
+       if (sk->sk_protocol == IPPROTO_TCP) {
+               if (sk->sk_state != TCP_LISTEN)
+                       return sk;
+       } else if (sk->sk_protocol == IPPROTO_UDP) {
+               if (sk->sk_state != TCP_CLOSE)
+                       return sk;
+       } else {
+               return sk;
+       }
+
+       reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
+                                         saddr, sport, daddr, ntohs(dport),
+                                         ehashfn);
+       if (!reuse_sk)
+               return sk;
+
+       /* We've chosen a new reuseport sock which is never refcounted. This
+        * implies that sk also isn't refcounted.
+        */
+       WARN_ON_ONCE(*refcounted);
+
+       return reuse_sk;
+}
+
 static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
                                              struct sk_buff *skb, int doff,
                                              const __be16 sport,
@@ -85,14 +150,20 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
                                              int iif, int sdif,
                                              bool *refcounted)
 {
-       struct sock *sk = skb_steal_sock(skb, refcounted);
-
+       struct net *net = dev_net(skb_dst(skb)->dev);
+       const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+       struct sock *sk;
+
+       sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
+                             refcounted, inet6_ehashfn);
+       if (IS_ERR(sk))
+               return NULL;
        if (sk)
                return sk;
 
-       return __inet6_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
-                             doff, &ipv6_hdr(skb)->saddr, sport,
-                             &ipv6_hdr(skb)->daddr, ntohs(dport),
+       return __inet6_lookup(net, hashinfo, skb,
+                             doff, &ip6h->saddr, sport,
+                             &ip6h->daddr, ntohs(dport),
                              iif, sdif, refcounted);
 }
 
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index 99bd823..1177eff 100644
@@ -379,6 +379,27 @@ struct sock *__inet_lookup_established(struct net *net,
                                       const __be32 daddr, const u16 hnum,
                                       const int dif, const int sdif);
 
+typedef u32 (inet_ehashfn_t)(const struct net *net,
+                             const __be32 laddr, const __u16 lport,
+                             const __be32 faddr, const __be16 fport);
+
+inet_ehashfn_t inet_ehashfn;
+
+INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);
+
+struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
+                                  struct sk_buff *skb, int doff,
+                                  __be32 saddr, __be16 sport,
+                                  __be32 daddr, unsigned short hnum,
+                                  inet_ehashfn_t *ehashfn);
+
+struct sock *inet_lookup_run_sk_lookup(struct net *net,
+                                      int protocol,
+                                      struct sk_buff *skb, int doff,
+                                      __be32 saddr, __be16 sport,
+                                      __be32 daddr, u16 hnum, const int dif,
+                                      inet_ehashfn_t *ehashfn);
+
 static inline struct sock *
        inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
                                const __be32 saddr, const __be16 sport,
@@ -428,6 +449,46 @@ static inline struct sock *inet_lookup(struct net *net,
        return sk;
 }
 
+static inline
+struct sock *inet_steal_sock(struct net *net, struct sk_buff *skb, int doff,
+                            const __be32 saddr, const __be16 sport,
+                            const __be32 daddr, const __be16 dport,
+                            bool *refcounted, inet_ehashfn_t *ehashfn)
+{
+       struct sock *sk, *reuse_sk;
+       bool prefetched;
+
+       sk = skb_steal_sock(skb, refcounted, &prefetched);
+       if (!sk)
+               return NULL;
+
+       if (!prefetched)
+               return sk;
+
+       if (sk->sk_protocol == IPPROTO_TCP) {
+               if (sk->sk_state != TCP_LISTEN)
+                       return sk;
+       } else if (sk->sk_protocol == IPPROTO_UDP) {
+               if (sk->sk_state != TCP_CLOSE)
+                       return sk;
+       } else {
+               return sk;
+       }
+
+       reuse_sk = inet_lookup_reuseport(net, sk, skb, doff,
+                                        saddr, sport, daddr, ntohs(dport),
+                                        ehashfn);
+       if (!reuse_sk)
+               return sk;
+
+       /* We've chosen a new reuseport sock which is never refcounted. This
+        * implies that sk also isn't refcounted.
+        */
+       WARN_ON_ONCE(*refcounted);
+
+       return reuse_sk;
+}
+
 static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
                                             struct sk_buff *skb,
                                             int doff,
@@ -436,22 +497,23 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
                                             const int sdif,
                                             bool *refcounted)
 {
-       struct sock *sk = skb_steal_sock(skb, refcounted);
+       struct net *net = dev_net(skb_dst(skb)->dev);
        const struct iphdr *iph = ip_hdr(skb);
+       struct sock *sk;
 
+       sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
+                            refcounted, inet_ehashfn);
+       if (IS_ERR(sk))
+               return NULL;
        if (sk)
                return sk;
 
-       return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
+       return __inet_lookup(net, hashinfo, skb,
                             doff, iph->saddr, sport,
                             iph->daddr, dport, inet_iif(skb), sdif,
                             refcounted);
 }
 
-u32 inet6_ehashfn(const struct net *net,
-                 const struct in6_addr *laddr, const u16 lport,
-                 const struct in6_addr *faddr, const __be16 fport);
-
 static inline void sk_daddr_set(struct sock *sk, __be32 addr)
 {
        sk->sk_daddr = addr; /* alias of inet_daddr */
diff --git a/include/net/sock.h b/include/net/sock.h
index 7ae44bf..74cbfb1 100644
@@ -2815,20 +2815,23 @@ sk_is_refcounted(struct sock *sk)
  * skb_steal_sock - steal a socket from an sk_buff
  * @skb: sk_buff to steal the socket from
  * @refcounted: is set to true if the socket is reference-counted
+ * @prefetched: is set to true if the socket was assigned from bpf
  */
 static inline struct sock *
-skb_steal_sock(struct sk_buff *skb, bool *refcounted)
+skb_steal_sock(struct sk_buff *skb, bool *refcounted, bool *prefetched)
 {
        if (skb->sk) {
                struct sock *sk = skb->sk;
 
                *refcounted = true;
-               if (skb_sk_is_prefetched(skb))
+               *prefetched = skb_sk_is_prefetched(skb);
+               if (*prefetched)
                        *refcounted = sk_is_refcounted(sk);
                skb->destructor = NULL;
                skb->sk = NULL;
                return sk;
        }
+       *prefetched = false;
        *refcounted = false;
        return NULL;
 }
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 739c159..7fc98f4 100644
@@ -4198,9 +4198,6 @@ union bpf_attr {
  *             **-EOPNOTSUPP** if the operation is not supported, for example
  *             a call from outside of TC ingress.
  *
- *             **-ESOCKTNOSUPPORT** if the socket type is not supported
- *             (reuseport).
- *
  * long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
  *     Description
  *             Helper is overloaded depending on BPF program type. This
diff --git a/net/core/filter.c b/net/core/filter.c
index 797e8f0..7c37f46 100644
@@ -7351,8 +7351,8 @@ BPF_CALL_3(bpf_sk_assign, struct sk_buff *, skb, struct sock *, sk, u64, flags)
                return -EOPNOTSUPP;
        if (unlikely(dev_net(skb->dev) != sock_net(sk)))
                return -ENETUNREACH;
-       if (unlikely(sk_fullsock(sk) && sk->sk_reuseport))
-               return -ESOCKTNOSUPPORT;
+       if (sk_unhashed(sk))
+               return -EOPNOTSUPP;
        if (sk_is_refcounted(sk) &&
            unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
                return -ENOENT;
diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c
index 0819d60..6a872b8 100644
@@ -28,9 +28,9 @@
 #include <net/tcp.h>
 #include <net/sock_reuseport.h>
 
-static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
-                       const __u16 lport, const __be32 faddr,
-                       const __be16 fport)
+u32 inet_ehashfn(const struct net *net, const __be32 laddr,
+                const __u16 lport, const __be32 faddr,
+                const __be16 fport)
 {
        static u32 inet_ehash_secret __read_mostly;
 
@@ -39,6 +39,7 @@ static u32 inet_ehashfn(const struct net *net, const __be32 laddr,
        return __inet_ehashfn(laddr, lport, faddr, fport,
                              inet_ehash_secret + net_hash_mix(net));
 }
+EXPORT_SYMBOL_GPL(inet_ehashfn);
 
 /* This function handles inet_sock, but also timewait and request sockets
  * for IPv4/IPv6.
@@ -332,20 +333,40 @@ static inline int compute_score(struct sock *sk, struct net *net,
        return score;
 }
 
-static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-                                           struct sk_buff *skb, int doff,
-                                           __be32 saddr, __be16 sport,
-                                           __be32 daddr, unsigned short hnum)
+INDIRECT_CALLABLE_DECLARE(inet_ehashfn_t udp_ehashfn);
+
+/**
+ * inet_lookup_reuseport() - execute reuseport logic on AF_INET socket if necessary.
+ * @net: network namespace.
+ * @sk: AF_INET socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
+ * @skb: context for a potential SK_REUSEPORT program.
+ * @doff: header offset.
+ * @saddr: source address.
+ * @sport: source port.
+ * @daddr: destination address.
+ * @hnum: destination port in host byte order.
+ * @ehashfn: hash function used to generate the fallback hash.
+ *
+ * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
+ *         the selected sock or an error.
+ */
+struct sock *inet_lookup_reuseport(struct net *net, struct sock *sk,
+                                  struct sk_buff *skb, int doff,
+                                  __be32 saddr, __be16 sport,
+                                  __be32 daddr, unsigned short hnum,
+                                  inet_ehashfn_t *ehashfn)
 {
        struct sock *reuse_sk = NULL;
        u32 phash;
 
        if (sk->sk_reuseport) {
-               phash = inet_ehashfn(net, daddr, hnum, saddr, sport);
+               phash = INDIRECT_CALL_2(ehashfn, udp_ehashfn, inet_ehashfn,
+                                       net, daddr, hnum, saddr, sport);
                reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
        }
        return reuse_sk;
 }
+EXPORT_SYMBOL_GPL(inet_lookup_reuseport);
 
 /*
  * Here are some nice properties to exploit here. The BSD API
@@ -369,8 +390,8 @@ static struct sock *inet_lhash2_lookup(struct net *net,
        sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
                score = compute_score(sk, net, hnum, daddr, dif, sdif);
                if (score > hiscore) {
-                       result = lookup_reuseport(net, sk, skb, doff,
-                                                 saddr, sport, daddr, hnum);
+                       result = inet_lookup_reuseport(net, sk, skb, doff,
+                                                      saddr, sport, daddr, hnum, inet_ehashfn);
                        if (result)
                                return result;
 
@@ -382,24 +403,23 @@ static struct sock *inet_lhash2_lookup(struct net *net,
        return result;
 }
 
-static inline struct sock *inet_lookup_run_bpf(struct net *net,
-                                              struct inet_hashinfo *hashinfo,
-                                              struct sk_buff *skb, int doff,
-                                              __be32 saddr, __be16 sport,
-                                              __be32 daddr, u16 hnum, const int dif)
+struct sock *inet_lookup_run_sk_lookup(struct net *net,
+                                      int protocol,
+                                      struct sk_buff *skb, int doff,
+                                      __be32 saddr, __be16 sport,
+                                      __be32 daddr, u16 hnum, const int dif,
+                                      inet_ehashfn_t *ehashfn)
 {
        struct sock *sk, *reuse_sk;
        bool no_reuseport;
 
-       if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
-               return NULL; /* only TCP is supported */
-
-       no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport,
+       no_reuseport = bpf_sk_lookup_run_v4(net, protocol, saddr, sport,
                                            daddr, hnum, dif, &sk);
        if (no_reuseport || IS_ERR_OR_NULL(sk))
                return sk;
 
-       reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
+       reuse_sk = inet_lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum,
+                                        ehashfn);
        if (reuse_sk)
                sk = reuse_sk;
        return sk;
@@ -417,9 +437,11 @@ struct sock *__inet_lookup_listener(struct net *net,
        unsigned int hash2;
 
        /* Lookup redirect from BPF */
-       if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-               result = inet_lookup_run_bpf(net, hashinfo, skb, doff,
-                                            saddr, sport, daddr, hnum, dif);
+       if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+           hashinfo == net->ipv4.tcp_death_row.hashinfo) {
+               result = inet_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
+                                                  saddr, sport, daddr, hnum, dif,
+                                                  inet_ehashfn);
                if (result)
                        goto done;
        }
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 8c3ebd9..d89c4a3 100644
@@ -406,9 +406,9 @@ static int compute_score(struct sock *sk, struct net *net,
        return score;
 }
 
-static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
-                      const __u16 lport, const __be32 faddr,
-                      const __be16 fport)
+INDIRECT_CALLABLE_SCOPE
+u32 udp_ehashfn(const struct net *net, const __be32 laddr, const __u16 lport,
+               const __be32 faddr, const __be16 fport)
 {
        static u32 udp_ehash_secret __read_mostly;
 
@@ -418,22 +418,6 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
                              udp_ehash_secret + net_hash_mix(net));
 }
 
-static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-                                    struct sk_buff *skb,
-                                    __be32 saddr, __be16 sport,
-                                    __be32 daddr, unsigned short hnum)
-{
-       struct sock *reuse_sk = NULL;
-       u32 hash;
-
-       if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
-               hash = udp_ehashfn(net, daddr, hnum, saddr, sport);
-               reuse_sk = reuseport_select_sock(sk, hash, skb,
-                                                sizeof(struct udphdr));
-       }
-       return reuse_sk;
-}
-
 /* called with rcu_read_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
                                     __be32 saddr, __be16 sport,
@@ -451,42 +435,36 @@ static struct sock *udp4_lib_lookup2(struct net *net,
                score = compute_score(sk, net, saddr, sport,
                                      daddr, hnum, dif, sdif);
                if (score > badness) {
-                       result = lookup_reuseport(net, sk, skb,
-                                                 saddr, sport, daddr, hnum);
+                       badness = score;
+
+                       if (sk->sk_state == TCP_ESTABLISHED) {
+                               result = sk;
+                               continue;
+                       }
+
+                       result = inet_lookup_reuseport(net, sk, skb, sizeof(struct udphdr),
+                                                      saddr, sport, daddr, hnum, udp_ehashfn);
+                       if (!result) {
+                               result = sk;
+                               continue;
+                       }
+
                        /* Fall back to scoring if group has connections */
-                       if (result && !reuseport_has_conns(sk))
+                       if (!reuseport_has_conns(sk))
                                return result;
 
-                       result = result ? : sk;
-                       badness = score;
+                       /* Reuseport logic returned an error, keep original score. */
+                       if (IS_ERR(result))
+                               continue;
+
+                       badness = compute_score(result, net, saddr, sport,
+                                               daddr, hnum, dif, sdif);
+
                }
        }
        return result;
 }
 
-static struct sock *udp4_lookup_run_bpf(struct net *net,
-                                       struct udp_table *udptable,
-                                       struct sk_buff *skb,
-                                       __be32 saddr, __be16 sport,
-                                       __be32 daddr, u16 hnum, const int dif)
-{
-       struct sock *sk, *reuse_sk;
-       bool no_reuseport;
-
-       if (udptable != net->ipv4.udp_table)
-               return NULL; /* only UDP is supported */
-
-       no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_UDP, saddr, sport,
-                                           daddr, hnum, dif, &sk);
-       if (no_reuseport || IS_ERR_OR_NULL(sk))
-               return sk;
-
-       reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
-       if (reuse_sk)
-               sk = reuse_sk;
-       return sk;
-}
-
 /* UDP is nearly always wildcards out the wazoo, it makes no sense to try
  * harder than this. -DaveM
  */
@@ -511,9 +489,11 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
                goto done;
 
        /* Lookup redirect from BPF */
-       if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-               sk = udp4_lookup_run_bpf(net, udptable, skb,
-                                        saddr, sport, daddr, hnum, dif);
+       if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+           udptable == net->ipv4.udp_table) {
+               sk = inet_lookup_run_sk_lookup(net, IPPROTO_UDP, skb, sizeof(struct udphdr),
+                                              saddr, sport, daddr, hnum, dif,
+                                              udp_ehashfn);
                if (sk) {
                        result = sk;
                        goto done;
@@ -2408,7 +2388,11 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
-       sk = skb_steal_sock(skb, &refcounted);
+       sk = inet_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+                            &refcounted, udp_ehashfn);
+       if (IS_ERR(sk))
+               goto no_sk;
+
        if (sk) {
                struct dst_entry *dst = skb_dst(skb);
                int ret;
@@ -2429,7 +2413,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
        if (sk)
                return udp_unicast_rcv_skb(sk, skb, uh);
-
+no_sk:
        if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
                goto drop;
        nf_reset_ct(skb);
diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c
index b64b490..7c9700c 100644
@@ -39,6 +39,7 @@ u32 inet6_ehashfn(const struct net *net,
        return __inet6_ehashfn(lhash, lport, fhash, fport,
                               inet6_ehash_secret + net_hash_mix(net));
 }
+EXPORT_SYMBOL_GPL(inet6_ehashfn);
 
 /*
  * Sockets in TCP_CLOSE state are _always_ taken out of the hash, so
@@ -111,22 +112,42 @@ static inline int compute_score(struct sock *sk, struct net *net,
        return score;
 }
 
-static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-                                           struct sk_buff *skb, int doff,
-                                           const struct in6_addr *saddr,
-                                           __be16 sport,
-                                           const struct in6_addr *daddr,
-                                           unsigned short hnum)
+INDIRECT_CALLABLE_DECLARE(inet6_ehashfn_t udp6_ehashfn);
+
+/**
+ * inet6_lookup_reuseport() - execute reuseport logic on AF_INET6 socket if necessary.
+ * @net: network namespace.
+ * @sk: AF_INET6 socket, must be in TCP_LISTEN state for TCP or TCP_CLOSE for UDP.
+ * @skb: context for a potential SK_REUSEPORT program.
+ * @doff: header offset.
+ * @saddr: source address.
+ * @sport: source port.
+ * @daddr: destination address.
+ * @hnum: destination port in host byte order.
+ * @ehashfn: hash function used to generate the fallback hash.
+ *
+ * Return: NULL if sk doesn't have SO_REUSEPORT set, otherwise a pointer to
+ *         the selected sock or an error.
+ */
+struct sock *inet6_lookup_reuseport(struct net *net, struct sock *sk,
+                                   struct sk_buff *skb, int doff,
+                                   const struct in6_addr *saddr,
+                                   __be16 sport,
+                                   const struct in6_addr *daddr,
+                                   unsigned short hnum,
+                                   inet6_ehashfn_t *ehashfn)
 {
        struct sock *reuse_sk = NULL;
        u32 phash;
 
        if (sk->sk_reuseport) {
-               phash = inet6_ehashfn(net, daddr, hnum, saddr, sport);
+               phash = INDIRECT_CALL_INET(ehashfn, udp6_ehashfn, inet6_ehashfn,
+                                          net, daddr, hnum, saddr, sport);
                reuse_sk = reuseport_select_sock(sk, phash, skb, doff);
        }
        return reuse_sk;
 }
+EXPORT_SYMBOL_GPL(inet6_lookup_reuseport);
 
 /* called with rcu_read_lock() */
 static struct sock *inet6_lhash2_lookup(struct net *net,
@@ -143,8 +164,8 @@ static struct sock *inet6_lhash2_lookup(struct net *net,
        sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) {
                score = compute_score(sk, net, hnum, daddr, dif, sdif);
                if (score > hiscore) {
-                       result = lookup_reuseport(net, sk, skb, doff,
-                                                 saddr, sport, daddr, hnum);
+                       result = inet6_lookup_reuseport(net, sk, skb, doff,
+                                                       saddr, sport, daddr, hnum, inet6_ehashfn);
                        if (result)
                                return result;
 
@@ -156,30 +177,30 @@ static struct sock *inet6_lhash2_lookup(struct net *net,
        return result;
 }
 
-static inline struct sock *inet6_lookup_run_bpf(struct net *net,
-                                               struct inet_hashinfo *hashinfo,
-                                               struct sk_buff *skb, int doff,
-                                               const struct in6_addr *saddr,
-                                               const __be16 sport,
-                                               const struct in6_addr *daddr,
-                                               const u16 hnum, const int dif)
+struct sock *inet6_lookup_run_sk_lookup(struct net *net,
+                                       int protocol,
+                                       struct sk_buff *skb, int doff,
+                                       const struct in6_addr *saddr,
+                                       const __be16 sport,
+                                       const struct in6_addr *daddr,
+                                       const u16 hnum, const int dif,
+                                       inet6_ehashfn_t *ehashfn)
 {
        struct sock *sk, *reuse_sk;
        bool no_reuseport;
 
-       if (hashinfo != net->ipv4.tcp_death_row.hashinfo)
-               return NULL; /* only TCP is supported */
-
-       no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_TCP, saddr, sport,
+       no_reuseport = bpf_sk_lookup_run_v6(net, protocol, saddr, sport,
                                            daddr, hnum, dif, &sk);
        if (no_reuseport || IS_ERR_OR_NULL(sk))
                return sk;
 
-       reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum);
+       reuse_sk = inet6_lookup_reuseport(net, sk, skb, doff,
+                                         saddr, sport, daddr, hnum, ehashfn);
        if (reuse_sk)
                sk = reuse_sk;
        return sk;
 }
+EXPORT_SYMBOL_GPL(inet6_lookup_run_sk_lookup);
 
 struct sock *inet6_lookup_listener(struct net *net,
                struct inet_hashinfo *hashinfo,
@@ -193,9 +214,11 @@ struct sock *inet6_lookup_listener(struct net *net,
        unsigned int hash2;
 
        /* Lookup redirect from BPF */
-       if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-               result = inet6_lookup_run_bpf(net, hashinfo, skb, doff,
-                                             saddr, sport, daddr, hnum, dif);
+       if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+           hashinfo == net->ipv4.tcp_death_row.hashinfo) {
+               result = inet6_lookup_run_sk_lookup(net, IPPROTO_TCP, skb, doff,
+                                                   saddr, sport, daddr, hnum, dif,
+                                                   inet6_ehashfn);
                if (result)
                        goto done;
        }
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 95c75d8..00996f0 100644
@@ -71,11 +71,12 @@ int udpv6_init_sock(struct sock *sk)
        return 0;
 }
 
-static u32 udp6_ehashfn(const struct net *net,
-                       const struct in6_addr *laddr,
-                       const u16 lport,
-                       const struct in6_addr *faddr,
-                       const __be16 fport)
+INDIRECT_CALLABLE_SCOPE
+u32 udp6_ehashfn(const struct net *net,
+                const struct in6_addr *laddr,
+                const u16 lport,
+                const struct in6_addr *faddr,
+                const __be16 fport)
 {
        static u32 udp6_ehash_secret __read_mostly;
        static u32 udp_ipv6_hash_secret __read_mostly;
@@ -160,24 +161,6 @@ static int compute_score(struct sock *sk, struct net *net,
        return score;
 }
 
-static struct sock *lookup_reuseport(struct net *net, struct sock *sk,
-                                    struct sk_buff *skb,
-                                    const struct in6_addr *saddr,
-                                    __be16 sport,
-                                    const struct in6_addr *daddr,
-                                    unsigned int hnum)
-{
-       struct sock *reuse_sk = NULL;
-       u32 hash;
-
-       if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) {
-               hash = udp6_ehashfn(net, daddr, hnum, saddr, sport);
-               reuse_sk = reuseport_select_sock(sk, hash, skb,
-                                                sizeof(struct udphdr));
-       }
-       return reuse_sk;
-}
-
 /* called with rcu_read_lock() */
 static struct sock *udp6_lib_lookup2(struct net *net,
                const struct in6_addr *saddr, __be16 sport,
@@ -194,44 +177,35 @@ static struct sock *udp6_lib_lookup2(struct net *net,
                score = compute_score(sk, net, saddr, sport,
                                      daddr, hnum, dif, sdif);
                if (score > badness) {
-                       result = lookup_reuseport(net, sk, skb,
-                                                 saddr, sport, daddr, hnum);
+                       badness = score;
+
+                       if (sk->sk_state == TCP_ESTABLISHED) {
+                               result = sk;
+                               continue;
+                       }
+
+                       result = inet6_lookup_reuseport(net, sk, skb, sizeof(struct udphdr),
+                                                       saddr, sport, daddr, hnum, udp6_ehashfn);
+                       if (!result) {
+                               result = sk;
+                               continue;
+                       }
+
                        /* Fall back to scoring if group has connections */
-                       if (result && !reuseport_has_conns(sk))
+                       if (!reuseport_has_conns(sk))
                                return result;
 
-                       result = result ? : sk;
-                       badness = score;
+                       /* Reuseport logic returned an error, keep original score. */
+                       if (IS_ERR(result))
+                               continue;
+
+                       badness = compute_score(sk, net, saddr, sport,
+                                               daddr, hnum, dif, sdif);
                }
        }
        return result;
 }
 
-static inline struct sock *udp6_lookup_run_bpf(struct net *net,
-                                              struct udp_table *udptable,
-                                              struct sk_buff *skb,
-                                              const struct in6_addr *saddr,
-                                              __be16 sport,
-                                              const struct in6_addr *daddr,
-                                              u16 hnum, const int dif)
-{
-       struct sock *sk, *reuse_sk;
-       bool no_reuseport;
-
-       if (udptable != net->ipv4.udp_table)
-               return NULL; /* only UDP is supported */
-
-       no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_UDP, saddr, sport,
-                                           daddr, hnum, dif, &sk);
-       if (no_reuseport || IS_ERR_OR_NULL(sk))
-               return sk;
-
-       reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum);
-       if (reuse_sk)
-               sk = reuse_sk;
-       return sk;
-}
-
 /* rcu_read_lock() must be held */
 struct sock *__udp6_lib_lookup(struct net *net,
                               const struct in6_addr *saddr, __be16 sport,
@@ -256,9 +230,11 @@ struct sock *__udp6_lib_lookup(struct net *net,
                goto done;
 
        /* Lookup redirect from BPF */
-       if (static_branch_unlikely(&bpf_sk_lookup_enabled)) {
-               sk = udp6_lookup_run_bpf(net, udptable, skb,
-                                        saddr, sport, daddr, hnum, dif);
+       if (static_branch_unlikely(&bpf_sk_lookup_enabled) &&
+           udptable == net->ipv4.udp_table) {
+               sk = inet6_lookup_run_sk_lookup(net, IPPROTO_UDP, skb, sizeof(struct udphdr),
+                                               saddr, sport, daddr, hnum, dif,
+                                               udp6_ehashfn);
                if (sk) {
                        result = sk;
                        goto done;
@@ -988,7 +964,11 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                goto csum_error;
 
        /* Check if the socket is already available, e.g. due to early demux */
-       sk = skb_steal_sock(skb, &refcounted);
+       sk = inet6_steal_sock(net, skb, sizeof(struct udphdr), saddr, uh->source, daddr, uh->dest,
+                             &refcounted, udp6_ehashfn);
+       if (IS_ERR(sk))
+               goto no_sk;
+
        if (sk) {
                struct dst_entry *dst = skb_dst(skb);
                int ret;
@@ -1022,7 +1002,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                        goto report_csum_error;
                return udp6_unicast_rcv_skb(sk, skb, uh);
        }
-
+no_sk:
        reason = SKB_DROP_REASON_NO_SOCKET;
 
        if (!uh->check)
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 739c159..7fc98f4 100644
@@ -4198,9 +4198,6 @@ union bpf_attr {
  *             **-EOPNOTSUPP** if the operation is not supported, for example
  *             a call from outside of TC ingress.
  *
- *             **-ESOCKTNOSUPPORT** if the socket type is not supported
- *             (reuseport).
- *
  * long bpf_sk_assign(struct bpf_sk_lookup *ctx, struct bpf_sock *sk, u64 flags)
  *     Description
  *             Helper is overloaded depending on BPF program type. This
diff --git a/tools/testing/selftests/bpf/network_helpers.c b/tools/testing/selftests/bpf/network_helpers.c
index a105c0c..8a33bce 100644
@@ -423,6 +423,9 @@ fail:
 
 void close_netns(struct nstoken *token)
 {
+       if (!token)
+               return;
+
        ASSERT_OK(setns(token->orig_netns_fd, CLONE_NEWNET), "setns");
        close(token->orig_netns_fd);
        free(token);
diff --git a/tools/testing/selftests/bpf/prog_tests/assign_reuse.c b/tools/testing/selftests/bpf/prog_tests/assign_reuse.c
new file mode 100644
index 0000000..989ee4d
--- /dev/null
@@ -0,0 +1,199 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2023 Isovalent */
+#include <uapi/linux/if_link.h>
+#include <test_progs.h>
+
+#include <netinet/tcp.h>
+#include <netinet/udp.h>
+
+#include "network_helpers.h"
+#include "test_assign_reuse.skel.h"
+
+#define NS_TEST "assign_reuse"
+#define LOOPBACK 1
+#define PORT 4443
+
+static int attach_reuseport(int sock_fd, int prog_fd)
+{
+       return setsockopt(sock_fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF,
+                         &prog_fd, sizeof(prog_fd));
+}
+
+static __u64 cookie(int fd)
+{
+       __u64 cookie = 0;
+       socklen_t cookie_len = sizeof(cookie);
+       int ret;
+
+       ret = getsockopt(fd, SOL_SOCKET, SO_COOKIE, &cookie, &cookie_len);
+       ASSERT_OK(ret, "cookie");
+       ASSERT_GT(cookie, 0, "cookie_invalid");
+
+       return cookie;
+}
+
+static int echo_test_udp(int fd_sv)
+{
+       struct sockaddr_storage addr = {};
+       socklen_t len = sizeof(addr);
+       char buff[1] = {};
+       int fd_cl = -1, ret;
+
+       fd_cl = connect_to_fd(fd_sv, 100);
+       ASSERT_GT(fd_cl, 0, "create_client");
+       ASSERT_EQ(getsockname(fd_cl, (void *)&addr, &len), 0, "getsockname");
+
+       ASSERT_EQ(send(fd_cl, buff, sizeof(buff), 0), 1, "send_client");
+
+       ret = recv(fd_sv, buff, sizeof(buff), 0);
+       if (ret < 0) {
+               close(fd_cl);
+               return errno;
+       }
+
+       ASSERT_EQ(ret, 1, "recv_server");
+       ASSERT_EQ(sendto(fd_sv, buff, sizeof(buff), 0, (void *)&addr, len), 1, "send_server");
+       ASSERT_EQ(recv(fd_cl, buff, sizeof(buff), 0), 1, "recv_client");
+       close(fd_cl);
+       return 0;
+}
+
+static int echo_test_tcp(int fd_sv)
+{
+       char buff[1] = {};
+       int fd_cl = -1, fd_sv_cl = -1;
+
+       fd_cl = connect_to_fd(fd_sv, 100);
+       if (fd_cl < 0)
+               return errno;
+
+       fd_sv_cl = accept(fd_sv, NULL, NULL);
+       ASSERT_GE(fd_sv_cl, 0, "accept_fd");
+
+       ASSERT_EQ(send(fd_cl, buff, sizeof(buff), 0), 1, "send_client");
+       ASSERT_EQ(recv(fd_sv_cl, buff, sizeof(buff), 0), 1, "recv_server");
+       ASSERT_EQ(send(fd_sv_cl, buff, sizeof(buff), 0), 1, "send_server");
+       ASSERT_EQ(recv(fd_cl, buff, sizeof(buff), 0), 1, "recv_client");
+       close(fd_sv_cl);
+       close(fd_cl);
+       return 0;
+}
+
+void run_assign_reuse(int family, int sotype, const char *ip, __u16 port)
+{
+       DECLARE_LIBBPF_OPTS(bpf_tc_hook, tc_hook,
+               .ifindex = LOOPBACK,
+               .attach_point = BPF_TC_INGRESS,
+       );
+       DECLARE_LIBBPF_OPTS(bpf_tc_opts, tc_opts,
+               .handle = 1,
+               .priority = 1,
+       );
+       bool hook_created = false, tc_attached = false;
+       int ret, fd_tc, fd_accept, fd_drop, fd_map;
+       int *fd_sv = NULL;
+       __u64 fd_val;
+       struct test_assign_reuse *skel;
+       const int zero = 0;
+
+       skel = test_assign_reuse__open();
+       if (!ASSERT_OK_PTR(skel, "skel_open"))
+               goto cleanup;
+
+       skel->rodata->dest_port = port;
+
+       ret = test_assign_reuse__load(skel);
+       if (!ASSERT_OK(ret, "skel_load"))
+               goto cleanup;
+
+       ASSERT_EQ(skel->bss->sk_cookie_seen, 0, "cookie_init");
+
+       fd_tc = bpf_program__fd(skel->progs.tc_main);
+       fd_accept = bpf_program__fd(skel->progs.reuse_accept);
+       fd_drop = bpf_program__fd(skel->progs.reuse_drop);
+       fd_map = bpf_map__fd(skel->maps.sk_map);
+
+       fd_sv = start_reuseport_server(family, sotype, ip, port, 100, 1);
+       if (!ASSERT_NEQ(fd_sv, NULL, "start_reuseport_server"))
+               goto cleanup;
+
+       ret = attach_reuseport(*fd_sv, fd_drop);
+       if (!ASSERT_OK(ret, "attach_reuseport"))
+               goto cleanup;
+
+       fd_val = *fd_sv;
+       ret = bpf_map_update_elem(fd_map, &zero, &fd_val, BPF_NOEXIST);
+       if (!ASSERT_OK(ret, "bpf_sk_map"))
+               goto cleanup;
+
+       ret = bpf_tc_hook_create(&tc_hook);
+       if (ret == 0)
+               hook_created = true;
+       ret = ret == -EEXIST ? 0 : ret;
+       if (!ASSERT_OK(ret, "bpf_tc_hook_create"))
+               goto cleanup;
+
+       tc_opts.prog_fd = fd_tc;
+       ret = bpf_tc_attach(&tc_hook, &tc_opts);
+       if (!ASSERT_OK(ret, "bpf_tc_attach"))
+               goto cleanup;
+       tc_attached = true;
+
+       if (sotype == SOCK_STREAM)
+               ASSERT_EQ(echo_test_tcp(*fd_sv), ECONNREFUSED, "drop_tcp");
+       else
+               ASSERT_EQ(echo_test_udp(*fd_sv), EAGAIN, "drop_udp");
+       ASSERT_EQ(skel->bss->reuseport_executed, 1, "program executed once");
+
+       skel->bss->sk_cookie_seen = 0;
+       skel->bss->reuseport_executed = 0;
+       ASSERT_OK(attach_reuseport(*fd_sv, fd_accept), "attach_reuseport(accept)");
+
+       if (sotype == SOCK_STREAM)
+               ASSERT_EQ(echo_test_tcp(*fd_sv), 0, "echo_tcp");
+       else
+               ASSERT_EQ(echo_test_udp(*fd_sv), 0, "echo_udp");
+
+       ASSERT_EQ(skel->bss->sk_cookie_seen, cookie(*fd_sv),
+                 "cookie_mismatch");
+       ASSERT_EQ(skel->bss->reuseport_executed, 1, "program executed once");
+cleanup:
+       if (tc_attached) {
+               tc_opts.flags = tc_opts.prog_fd = tc_opts.prog_id = 0;
+               ret = bpf_tc_detach(&tc_hook, &tc_opts);
+               ASSERT_OK(ret, "bpf_tc_detach");
+       }
+       if (hook_created) {
+               tc_hook.attach_point = BPF_TC_INGRESS | BPF_TC_EGRESS;
+               bpf_tc_hook_destroy(&tc_hook);
+       }
+       test_assign_reuse__destroy(skel);
+       free_fds(fd_sv, 1);
+}
+
+void test_assign_reuse(void)
+{
+       struct nstoken *tok = NULL;
+
+       SYS(out, "ip netns add %s", NS_TEST);
+       SYS(cleanup, "ip -net %s link set dev lo up", NS_TEST);
+
+       tok = open_netns(NS_TEST);
+       if (!ASSERT_OK_PTR(tok, "netns token"))
+               return;
+
+       if (test__start_subtest("tcpv4"))
+               run_assign_reuse(AF_INET, SOCK_STREAM, "127.0.0.1", PORT);
+       if (test__start_subtest("tcpv6"))
+               run_assign_reuse(AF_INET6, SOCK_STREAM, "::1", PORT);
+       if (test__start_subtest("udpv4"))
+               run_assign_reuse(AF_INET, SOCK_DGRAM, "127.0.0.1", PORT);
+       if (test__start_subtest("udpv6"))
+               run_assign_reuse(AF_INET6, SOCK_DGRAM, "::1", PORT);
+
+cleanup:
+       close_netns(tok);
+       SYS_NOFAIL("ip netns delete %s", NS_TEST);
+out:
+       return;
+}
diff --git a/tools/testing/selftests/bpf/progs/test_assign_reuse.c b/tools/testing/selftests/bpf/progs/test_assign_reuse.c
new file mode 100644
index 0000000..4f2e232
--- /dev/null
@@ -0,0 +1,142 @@
+// SPDX-License-Identifier: GPL-2.0
+/* Copyright (c) 2023 Isovalent */
+#include <stdbool.h>
+#include <linux/bpf.h>
+#include <linux/if_ether.h>
+#include <linux/in.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/tcp.h>
+#include <linux/udp.h>
+#include <bpf/bpf_endian.h>
+#include <bpf/bpf_helpers.h>
+#include <linux/pkt_cls.h>
+
+char LICENSE[] SEC("license") = "GPL";
+
+__u64 sk_cookie_seen;
+__u64 reuseport_executed;
+union {
+       struct tcphdr tcp;
+       struct udphdr udp;
+} headers;
+
+const volatile __u16 dest_port;
+
+struct {
+       __uint(type, BPF_MAP_TYPE_SOCKMAP);
+       __uint(max_entries, 1);
+       __type(key, __u32);
+       __type(value, __u64);
+} sk_map SEC(".maps");
+
+SEC("sk_reuseport")
+int reuse_accept(struct sk_reuseport_md *ctx)
+{
+       reuseport_executed++;
+
+       if (ctx->ip_protocol == IPPROTO_TCP) {
+               if (ctx->data + sizeof(headers.tcp) > ctx->data_end)
+                       return SK_DROP;
+
+               if (__builtin_memcmp(&headers.tcp, ctx->data, sizeof(headers.tcp)) != 0)
+                       return SK_DROP;
+       } else if (ctx->ip_protocol == IPPROTO_UDP) {
+               if (ctx->data + sizeof(headers.udp) > ctx->data_end)
+                       return SK_DROP;
+
+               if (__builtin_memcmp(&headers.udp, ctx->data, sizeof(headers.udp)) != 0)
+                       return SK_DROP;
+       } else {
+               return SK_DROP;
+       }
+
+       sk_cookie_seen = bpf_get_socket_cookie(ctx->sk);
+       return SK_PASS;
+}
+
+SEC("sk_reuseport")
+int reuse_drop(struct sk_reuseport_md *ctx)
+{
+       reuseport_executed++;
+       sk_cookie_seen = 0;
+       return SK_DROP;
+}
+
+static int
+assign_sk(struct __sk_buff *skb)
+{
+       int zero = 0, ret = 0;
+       struct bpf_sock *sk;
+
+       sk = bpf_map_lookup_elem(&sk_map, &zero);
+       if (!sk)
+               return TC_ACT_SHOT;
+       ret = bpf_sk_assign(skb, sk, 0);
+       bpf_sk_release(sk);
+       return ret ? TC_ACT_SHOT : TC_ACT_OK;
+}
+
+static bool
+maybe_assign_tcp(struct __sk_buff *skb, struct tcphdr *th)
+{
+       if (th + 1 > (void *)(long)(skb->data_end))
+               return TC_ACT_SHOT;
+
+       if (!th->syn || th->ack || th->dest != bpf_htons(dest_port))
+               return TC_ACT_OK;
+
+       __builtin_memcpy(&headers.tcp, th, sizeof(headers.tcp));
+       return assign_sk(skb);
+}
+
+static bool
+maybe_assign_udp(struct __sk_buff *skb, struct udphdr *uh)
+{
+       if (uh + 1 > (void *)(long)(skb->data_end))
+               return TC_ACT_SHOT;
+
+       if (uh->dest != bpf_htons(dest_port))
+               return TC_ACT_OK;
+
+       __builtin_memcpy(&headers.udp, uh, sizeof(headers.udp));
+       return assign_sk(skb);
+}
+
+SEC("tc")
+int tc_main(struct __sk_buff *skb)
+{
+       void *data_end = (void *)(long)skb->data_end;
+       void *data = (void *)(long)skb->data;
+       struct ethhdr *eth;
+
+       eth = (struct ethhdr *)(data);
+       if (eth + 1 > data_end)
+               return TC_ACT_SHOT;
+
+       if (eth->h_proto == bpf_htons(ETH_P_IP)) {
+               struct iphdr *iph = (struct iphdr *)(data + sizeof(*eth));
+
+               if (iph + 1 > data_end)
+                       return TC_ACT_SHOT;
+
+               if (iph->protocol == IPPROTO_TCP)
+                       return maybe_assign_tcp(skb, (struct tcphdr *)(iph + 1));
+               else if (iph->protocol == IPPROTO_UDP)
+                       return maybe_assign_udp(skb, (struct udphdr *)(iph + 1));
+               else
+                       return TC_ACT_SHOT;
+       } else {
+               struct ipv6hdr *ip6h = (struct ipv6hdr *)(data + sizeof(*eth));
+
+               if (ip6h + 1 > data_end)
+                       return TC_ACT_SHOT;
+
+               if (ip6h->nexthdr == IPPROTO_TCP)
+                       return maybe_assign_tcp(skb, (struct tcphdr *)(ip6h + 1));
+               else if (ip6h->nexthdr == IPPROTO_UDP)
+                       return maybe_assign_udp(skb, (struct udphdr *)(ip6h + 1));
+               else
+                       return TC_ACT_SHOT;
+       }
+}