bpf: add helper to check for a valid SYN cookie
diff --git a/net/core/filter.c b/net/core/filter.c
index 7559d68..d2511fe 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -73,6 +73,7 @@
 #include <linux/seg6_local.h>
 #include <net/seg6.h>
 #include <net/seg6_local.h>
+#include <net/lwtunnel.h>
 
 /**
  *     sk_filter_trim_cap - run a packet through a socket filter
@@ -1793,6 +1794,18 @@ static const struct bpf_func_proto bpf_skb_pull_data_proto = {
        .arg2_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_1(bpf_sk_fullsock, struct sock *, sk)
+{
+       return sk_fullsock(sk) ? (unsigned long)sk : (unsigned long)NULL;
+}
+
+static const struct bpf_func_proto bpf_sk_fullsock_proto = {
+       .func           = bpf_sk_fullsock,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+};
+
 static inline int sk_skb_try_make_writable(struct sk_buff *skb,
                                           unsigned int write_len)
 {
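
Usage note: bpf_sk_fullsock() narrows a sock_common pointer (such as the
__sk_buff->sk field added later in this patch) to a full socket, returning
NULL for request and timewait sockets. A minimal cgroup/skb sketch; the
section name and libbpf-style includes are assumptions, not part of the
patch itself:

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

SEC("cgroup_skb/ingress")
int narrow_to_fullsock(struct __sk_buff *skb)
{
        struct bpf_sock *sk = skb->sk;  /* PTR_TO_SOCK_COMMON_OR_NULL */

        if (!sk)
                return 1;               /* no socket attached: allow */

        sk = bpf_sk_fullsock(sk);       /* NULL for request/timewait socks */
        if (!sk)
                return 1;

        /* sk is now a full socket; bpf_sock fields may be read directly */
        return 1;
}

char _license[] SEC("license") = "GPL";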
@@ -2789,8 +2802,7 @@ static int bpf_skb_proto_4_to_6(struct sk_buff *skb)
        u32 off = skb_mac_header_len(skb);
        int ret;
 
-       /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */
-       if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb)))
+       if (skb_is_gso(skb) && !skb_is_gso_tcp(skb))
                return -ENOTSUPP;
 
        ret = skb_cow(skb, len_diff);
@@ -2831,8 +2843,7 @@ static int bpf_skb_proto_6_to_4(struct sk_buff *skb)
        u32 off = skb_mac_header_len(skb);
        int ret;
 
-       /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */
-       if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb)))
+       if (skb_is_gso(skb) && !skb_is_gso_tcp(skb))
                return -ENOTSUPP;
 
        ret = skb_unclone(skb, GFP_ATOMIC);
@@ -2957,8 +2968,7 @@ static int bpf_skb_net_grow(struct sk_buff *skb, u32 len_diff)
        u32 off = skb_mac_header_len(skb) + bpf_skb_net_base_len(skb);
        int ret;
 
-       /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */
-       if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb)))
+       if (skb_is_gso(skb) && !skb_is_gso_tcp(skb))
                return -ENOTSUPP;
 
        ret = skb_cow(skb, len_diff);
@@ -2987,8 +2997,7 @@ static int bpf_skb_net_shrink(struct sk_buff *skb, u32 len_diff)
        u32 off = skb_mac_header_len(skb) + bpf_skb_net_base_len(skb);
        int ret;
 
-       /* SCTP uses GSO_BY_FRAGS, thus cannot adjust it. */
-       if (skb_is_gso(skb) && unlikely(skb_is_gso_sctp(skb)))
+       if (skb_is_gso(skb) && !skb_is_gso_tcp(skb))
                return -ENOTSUPP;
 
        ret = skb_unclone(skb, GFP_ATOMIC);
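
Note: all four header-resizing paths above now share one tighter guard: a
GSO skb must be TCP GSO to have its headers grown or shrunk, instead of
merely excluding SCTP. Restated as a standalone predicate for illustration
(the wrapper name is hypothetical; skb_is_gso() and skb_is_gso_tcp() come
from linux/skbuff.h):

#include <linux/skbuff.h>

/* Hypothetical restatement of the guard: gso_size can only be adjusted
 * for TCP GSO packets; SCTP (GSO_BY_FRAGS), UDP and anything else makes
 * the callers return -ENOTSUPP. */
static inline bool bpf_can_adjust_gso(const struct sk_buff *skb)
{
        return !skb_is_gso(skb) || skb_is_gso_tcp(skb);
}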
@@ -4112,10 +4121,12 @@ BPF_CALL_5(bpf_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
                /* Only some socketops are supported */
                switch (optname) {
                case SO_RCVBUF:
+                       val = min_t(u32, val, sysctl_rmem_max);
                        sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
                        sk->sk_rcvbuf = max_t(int, val * 2, SOCK_MIN_RCVBUF);
                        break;
                case SO_SNDBUF:
+                       val = min_t(u32, val, sysctl_wmem_max);
                        sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
                        sk->sk_sndbuf = max_t(int, val * 2, SOCK_MIN_SNDBUF);
                        break;
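
Note: the two min_t() clamps mirror what sock_setsockopt() enforces for
userspace, so a BPF program can no longer push a socket past
net.core.rmem_max or net.core.wmem_max. A sockops sketch of the observable
effect (the requested value and section name are placeholders):

#include <linux/bpf.h>
#include <sys/socket.h>
#include <bpf/bpf_helpers.h>

SEC("sockops")
int cap_rcvbuf(struct bpf_sock_ops *skops)
{
        int val = 4 * 1024 * 1024;      /* ask for 4 MiB */

        /* With rmem_max at its usual 212992, the kernel now stores
         * min(val, rmem_max) * 2 instead of val * 2. */
        bpf_setsockopt(skops, SOL_SOCKET, SO_RCVBUF, &val, sizeof(val));
        return 1;
}

char _license[] SEC("license") = "GPL";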
@@ -4801,7 +4812,15 @@ static int bpf_push_seg6_encap(struct sk_buff *skb, u32 type, void *hdr, u32 len
 }
 #endif /* CONFIG_IPV6_SEG6_BPF */
 
-BPF_CALL_4(bpf_lwt_push_encap, struct sk_buff *, skb, u32, type, void *, hdr,
+#if IS_ENABLED(CONFIG_LWTUNNEL_BPF)
+static int bpf_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len,
+                            bool ingress)
+{
+       return bpf_lwt_push_ip_encap(skb, hdr, len, ingress);
+}
+#endif
+
+BPF_CALL_4(bpf_lwt_in_push_encap, struct sk_buff *, skb, u32, type, void *, hdr,
           u32, len)
 {
        switch (type) {
@@ -4810,13 +4829,40 @@ BPF_CALL_4(bpf_lwt_push_encap, struct sk_buff *, skb, u32, type, void *, hdr,
        case BPF_LWT_ENCAP_SEG6_INLINE:
                return bpf_push_seg6_encap(skb, type, hdr, len);
 #endif
+#if IS_ENABLED(CONFIG_LWTUNNEL_BPF)
+       case BPF_LWT_ENCAP_IP:
+               return bpf_push_ip_encap(skb, hdr, len, true /* ingress */);
+#endif
+       default:
+               return -EINVAL;
+       }
+}
+
+BPF_CALL_4(bpf_lwt_xmit_push_encap, struct sk_buff *, skb, u32, type,
+          void *, hdr, u32, len)
+{
+       switch (type) {
+#if IS_ENABLED(CONFIG_LWTUNNEL_BPF)
+       case BPF_LWT_ENCAP_IP:
+               return bpf_push_ip_encap(skb, hdr, len, false /* egress */);
+#endif
        default:
                return -EINVAL;
        }
 }
 
-static const struct bpf_func_proto bpf_lwt_push_encap_proto = {
-       .func           = bpf_lwt_push_encap,
+static const struct bpf_func_proto bpf_lwt_in_push_encap_proto = {
+       .func           = bpf_lwt_in_push_encap,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_PTR_TO_MEM,
+       .arg4_type      = ARG_CONST_SIZE
+};
+
+static const struct bpf_func_proto bpf_lwt_xmit_push_encap_proto = {
+       .func           = bpf_lwt_xmit_push_encap,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
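
Usage note: with the helper split, BPF_LWT_ENCAP_IP can reroute from
lwt_in while lwt_xmit keeps the packet on its egress path. A lwt_in
sketch pushing an IPIP header, modeled loosely on the selftests for this
series; addresses, section name and includes are placeholder assumptions:

#include <linux/bpf.h>
#include <linux/ip.h>
#include <linux/in.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("lwt_in")
int encap_ipip(struct __sk_buff *skb)
{
        struct iphdr hdr = {
                .version  = 4,
                .ihl      = 5,
                .ttl      = 64,
                .protocol = IPPROTO_IPIP,
                .saddr    = bpf_htonl(0x0a000001),      /* 10.0.0.1 */
                .daddr    = bpf_htonl(0x0a000002),      /* 10.0.0.2 */
        };

        hdr.tot_len = bpf_htons(skb->len + sizeof(hdr));

        if (bpf_lwt_push_encap(skb, BPF_LWT_ENCAP_IP, &hdr, sizeof(hdr)))
                return BPF_DROP;

        /* ask the LWT infrastructure to re-run routing on the new header */
        return BPF_LWT_REROUTE;
}

char _license[] SEC("license") = "GPL";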
@@ -5016,6 +5062,54 @@ static const struct bpf_func_proto bpf_lwt_seg6_adjust_srh_proto = {
 };
 #endif /* CONFIG_IPV6_SEG6_BPF */
 
+#define CONVERT_COMMON_TCP_SOCK_FIELDS(md_type, CONVERT)               \
+do {                                                                   \
+       switch (si->off) {                                              \
+       case offsetof(md_type, snd_cwnd):                               \
+               CONVERT(snd_cwnd); break;                               \
+       case offsetof(md_type, srtt_us):                                \
+               CONVERT(srtt_us); break;                                \
+       case offsetof(md_type, snd_ssthresh):                           \
+               CONVERT(snd_ssthresh); break;                           \
+       case offsetof(md_type, rcv_nxt):                                \
+               CONVERT(rcv_nxt); break;                                \
+       case offsetof(md_type, snd_nxt):                                \
+               CONVERT(snd_nxt); break;                                \
+       case offsetof(md_type, snd_una):                                \
+               CONVERT(snd_una); break;                                \
+       case offsetof(md_type, mss_cache):                              \
+               CONVERT(mss_cache); break;                              \
+       case offsetof(md_type, ecn_flags):                              \
+               CONVERT(ecn_flags); break;                              \
+       case offsetof(md_type, rate_delivered):                         \
+               CONVERT(rate_delivered); break;                         \
+       case offsetof(md_type, rate_interval_us):                       \
+               CONVERT(rate_interval_us); break;                       \
+       case offsetof(md_type, packets_out):                            \
+               CONVERT(packets_out); break;                            \
+       case offsetof(md_type, retrans_out):                            \
+               CONVERT(retrans_out); break;                            \
+       case offsetof(md_type, total_retrans):                          \
+               CONVERT(total_retrans); break;                          \
+       case offsetof(md_type, segs_in):                                \
+               CONVERT(segs_in); break;                                \
+       case offsetof(md_type, data_segs_in):                           \
+               CONVERT(data_segs_in); break;                           \
+       case offsetof(md_type, segs_out):                               \
+               CONVERT(segs_out); break;                               \
+       case offsetof(md_type, data_segs_out):                          \
+               CONVERT(data_segs_out); break;                          \
+       case offsetof(md_type, lost_out):                               \
+               CONVERT(lost_out); break;                               \
+       case offsetof(md_type, sacked_out):                             \
+               CONVERT(sacked_out); break;                             \
+       case offsetof(md_type, bytes_received):                         \
+               CONVERT(bytes_received); break;                         \
+       case offsetof(md_type, bytes_acked):                            \
+               CONVERT(bytes_acked); break;                            \
+       }                                                               \
+} while (0)
+
 #ifdef CONFIG_INET
 static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
                              int dif, int sdif, u8 family, u8 proto)
@@ -5062,15 +5156,15 @@ static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
        return sk;
 }
 
-/* bpf_sk_lookup performs the core lookup for different types of sockets,
+/* bpf_skc_lookup performs the core lookup for different types of sockets,
  * taking a reference on the socket if it doesn't have the flag SOCK_RCU_FREE.
  * Returns a 'struct sock *'; the BPF_CALL wrappers in the callers cast
  * the result to 'unsigned long' to satisfy their declarations.
  */
-static unsigned long
-__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
-               u64 flags)
+static struct sock *
+__bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+                struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+                u64 flags)
 {
        struct sock *sk = NULL;
        u8 family = AF_UNSPEC;
@@ -5098,15 +5192,27 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                put_net(net);
        }
 
+out:
+       return sk;
+}
+
+static struct sock *
+__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+               u64 flags)
+{
+       struct sock *sk = __bpf_skc_lookup(skb, tuple, len, caller_net,
+                                          ifindex, proto, netns_id, flags);
+
        if (sk)
                sk = sk_to_full_sk(sk);
-out:
-       return (unsigned long) sk;
+
+       return sk;
 }
 
-static unsigned long
-bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-             u8 proto, u64 netns_id, u64 flags)
+static struct sock *
+bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+              u8 proto, u64 netns_id, u64 flags)
 {
        struct net *caller_net;
        int ifindex;
@@ -5119,14 +5225,47 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                ifindex = 0;
        }
 
-       return __bpf_sk_lookup(skb, tuple, len, caller_net, ifindex,
-                             proto, netns_id, flags);
+       return __bpf_skc_lookup(skb, tuple, len, caller_net, ifindex, proto,
+                               netns_id, flags);
 }
 
+static struct sock *
+bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+             u8 proto, u64 netns_id, u64 flags)
+{
+       struct sock *sk = bpf_skc_lookup(skb, tuple, len, proto, netns_id,
+                                        flags);
+
+       if (sk)
+               sk = sk_to_full_sk(sk);
+
+       return sk;
+}
+
+BPF_CALL_5(bpf_skc_lookup_tcp, struct sk_buff *, skb,
+          struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+       return (unsigned long)bpf_skc_lookup(skb, tuple, len, IPPROTO_TCP,
+                                            netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_skc_lookup_tcp_proto = {
+       .func           = bpf_skc_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP, netns_id, flags);
+       return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP,
+                                           netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = {
@@ -5144,7 +5283,8 @@ static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = {
 BPF_CALL_5(bpf_sk_lookup_udp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP, netns_id, flags);
+       return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP,
+                                           netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sk_lookup_udp_proto = {
@@ -5170,7 +5310,7 @@ static const struct bpf_func_proto bpf_sk_release_proto = {
        .func           = bpf_sk_release,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
-       .arg1_type      = ARG_PTR_TO_SOCKET,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
 };
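
Usage note: bpf_sk_release() now accepts the sock_common pointers that
bpf_skc_lookup_tcp() returns, so the usual acquire/release pairing carries
over unchanged. A tc classifier sketch with placeholder tuple values:

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("classifier")
int find_tcp_sock(struct __sk_buff *skb)
{
        struct bpf_sock_tuple tup = {
                .ipv4 = {
                        .saddr = bpf_htonl(0x0a000001), /* 10.0.0.1 */
                        .daddr = bpf_htonl(0x0a000002), /* 10.0.0.2 */
                        .sport = bpf_htons(40000),
                        .dport = bpf_htons(80),
                },
        };
        struct bpf_sock *sk;

        /* unlike bpf_sk_lookup_tcp(), this also returns listeners,
         * request socks and timewait socks */
        sk = bpf_skc_lookup_tcp(skb, &tup, sizeof(tup.ipv4),
                                BPF_F_CURRENT_NETNS, 0);
        if (sk)
                bpf_sk_release(sk);
        return TC_ACT_OK;
}

char _license[] SEC("license") = "GPL";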
 
 BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx,
@@ -5179,8 +5319,9 @@ BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx,
        struct net *caller_net = dev_net(ctx->rxq->dev);
        int ifindex = ctx->rxq->dev->ifindex;
 
-       return __bpf_sk_lookup(NULL, tuple, len, caller_net, ifindex,
-                             IPPROTO_UDP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, caller_net,
+                                             ifindex, IPPROTO_UDP, netns_id,
+                                             flags);
 }
 
 static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = {
@@ -5195,14 +5336,38 @@ static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_5(bpf_xdp_skc_lookup_tcp, struct xdp_buff *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
+{
+       struct net *caller_net = dev_net(ctx->rxq->dev);
+       int ifindex = ctx->rxq->dev->ifindex;
+
+       return (unsigned long)__bpf_skc_lookup(NULL, tuple, len, caller_net,
+                                              ifindex, IPPROTO_TCP, netns_id,
+                                              flags);
+}
+
+static const struct bpf_func_proto bpf_xdp_skc_lookup_tcp_proto = {
+       .func           = bpf_xdp_skc_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_xdp_sk_lookup_tcp, struct xdp_buff *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
 {
        struct net *caller_net = dev_net(ctx->rxq->dev);
        int ifindex = ctx->rxq->dev->ifindex;
 
-       return __bpf_sk_lookup(NULL, tuple, len, caller_net, ifindex,
-                             IPPROTO_TCP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, caller_net,
+                                             ifindex, IPPROTO_TCP, netns_id,
+                                             flags);
 }
 
 static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = {
@@ -5217,11 +5382,31 @@ static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_5(bpf_sock_addr_skc_lookup_tcp, struct bpf_sock_addr_kern *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+       return (unsigned long)__bpf_skc_lookup(NULL, tuple, len,
+                                              sock_net(ctx->sk), 0,
+                                              IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_sock_addr_skc_lookup_tcp_proto = {
+       .func           = bpf_sock_addr_skc_lookup_tcp,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_sock_addr_sk_lookup_tcp, struct bpf_sock_addr_kern *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return __bpf_sk_lookup(NULL, tuple, len, sock_net(ctx->sk), 0,
-                              IPPROTO_TCP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len,
+                                             sock_net(ctx->sk), 0, IPPROTO_TCP,
+                                             netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sock_addr_sk_lookup_tcp_proto = {
@@ -5238,8 +5423,9 @@ static const struct bpf_func_proto bpf_sock_addr_sk_lookup_tcp_proto = {
 BPF_CALL_5(bpf_sock_addr_sk_lookup_udp, struct bpf_sock_addr_kern *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return __bpf_sk_lookup(NULL, tuple, len, sock_net(ctx->sk), 0,
-                              IPPROTO_UDP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len,
+                                             sock_net(ctx->sk), 0, IPPROTO_UDP,
+                                             netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sock_addr_sk_lookup_udp_proto = {
@@ -5253,6 +5439,188 @@ static const struct bpf_func_proto bpf_sock_addr_sk_lookup_udp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+bool bpf_tcp_sock_is_valid_access(int off, int size, enum bpf_access_type type,
+                                 struct bpf_insn_access_aux *info)
+{
+       if (off < 0 || off >= offsetofend(struct bpf_tcp_sock, bytes_acked))
+               return false;
+
+       if (off % size != 0)
+               return false;
+
+       switch (off) {
+       case offsetof(struct bpf_tcp_sock, bytes_received):
+       case offsetof(struct bpf_tcp_sock, bytes_acked):
+               return size == sizeof(__u64);
+       default:
+               return size == sizeof(__u32);
+       }
+}
+
+u32 bpf_tcp_sock_convert_ctx_access(enum bpf_access_type type,
+                                   const struct bpf_insn *si,
+                                   struct bpf_insn *insn_buf,
+                                   struct bpf_prog *prog, u32 *target_size)
+{
+       struct bpf_insn *insn = insn_buf;
+
+#define BPF_TCP_SOCK_GET_COMMON(FIELD)                                 \
+       do {                                                            \
+               BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, FIELD) >     \
+                            FIELD_SIZEOF(struct bpf_tcp_sock, FIELD)); \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct tcp_sock, FIELD),\
+                                     si->dst_reg, si->src_reg,         \
+                                     offsetof(struct tcp_sock, FIELD)); \
+       } while (0)
+
+       CONVERT_COMMON_TCP_SOCK_FIELDS(struct bpf_tcp_sock,
+                                      BPF_TCP_SOCK_GET_COMMON);
+
+       if (insn > insn_buf)
+               return insn - insn_buf;
+
+       switch (si->off) {
+       case offsetof(struct bpf_tcp_sock, rtt_min):
+               BUILD_BUG_ON(FIELD_SIZEOF(struct tcp_sock, rtt_min) !=
+                            sizeof(struct minmax));
+               BUILD_BUG_ON(sizeof(struct minmax) <
+                            sizeof(struct minmax_sample));
+
+               *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg,
+                                     offsetof(struct tcp_sock, rtt_min) +
+                                     offsetof(struct minmax_sample, v));
+               break;
+       }
+
+       return insn - insn_buf;
+}
+
+BPF_CALL_1(bpf_tcp_sock, struct sock *, sk)
+{
+       if (sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP)
+               return (unsigned long)sk;
+
+       return (unsigned long)NULL;
+}
+
+static const struct bpf_func_proto bpf_tcp_sock_proto = {
+       .func           = bpf_tcp_sock,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_TCP_SOCK_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+};
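
Usage note: a program typically reaches a bpf_tcp_sock by narrowing
skb->sk first. A cgroup/skb sketch reading two of the mirrored tcp_sock
fields (section name and libbpf-style includes are assumptions):

#include <linux/bpf.h>
#include <linux/in.h>
#include <bpf/bpf_helpers.h>

SEC("cgroup_skb/egress")
int log_cwnd(struct __sk_buff *skb)
{
        struct bpf_sock *sk = skb->sk;
        struct bpf_tcp_sock *tp;

        if (!sk)
                return 1;

        sk = bpf_sk_fullsock(sk);
        if (!sk || sk->protocol != IPPROTO_TCP)
                return 1;

        tp = bpf_tcp_sock(sk);  /* NULL if sk is not a TCP socket */
        if (!tp)
                return 1;

        bpf_printk("cwnd=%u srtt_us=%u\n", tp->snd_cwnd, tp->srtt_us);
        return 1;
}

char _license[] SEC("license") = "GPL";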
+
+BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
+{
+       sk = sk_to_full_sk(sk);
+
+       if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+               return (unsigned long)sk;
+
+       return (unsigned long)NULL;
+}
+
+static const struct bpf_func_proto bpf_get_listener_sock_proto = {
+       .func           = bpf_get_listener_sock,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+};
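
Usage note: for a packet hitting a request socket (e.g. the ACK of a
SYN-ACK), bpf_get_listener_sock() walks back to the RCU-protected
listener, which is why no reference needs to be taken. A cgroup/skb
sketch (assumptions as above):

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

SEC("cgroup_skb/ingress")
int log_listener(struct __sk_buff *skb)
{
        struct bpf_sock *sk = skb->sk;

        if (!sk)
                return 1;

        sk = bpf_get_listener_sock(sk); /* NULL unless an RCU-freed listener */
        if (sk)
                bpf_printk("listener port=%u\n", sk->src_port);
        return 1;
}

char _license[] SEC("license") = "GPL";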
+
+BPF_CALL_1(bpf_skb_ecn_set_ce, struct sk_buff *, skb)
+{
+       unsigned int iphdr_len;
+
+       if (skb->protocol == cpu_to_be16(ETH_P_IP))
+               iphdr_len = sizeof(struct iphdr);
+       else if (skb->protocol == cpu_to_be16(ETH_P_IPV6))
+               iphdr_len = sizeof(struct ipv6hdr);
+       else
+               return 0;
+
+       if (skb_headlen(skb) < iphdr_len)
+               return 0;
+
+       if (skb_cloned(skb) && !skb_clone_writable(skb, iphdr_len))
+               return 0;
+
+       return INET_ECN_set_ce(skb);
+}
+
+static const struct bpf_func_proto bpf_skb_ecn_set_ce_proto = {
+       .func           = bpf_skb_ecn_set_ce,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+};
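
Usage note: the helper lets congestion policies signal CE instead of
dropping. A minimal egress sketch:

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

SEC("cgroup_skb/egress")
int mark_ce(struct __sk_buff *skb)
{
        /* returns 0 when the packet is not ECN-capable, is too short
         * for its IP header, or cannot be made writable */
        bpf_skb_ecn_set_ce(skb);
        return 1;       /* always allow: marking replaces dropping */
}

char _license[] SEC("license") = "GPL";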
+
+BPF_CALL_5(bpf_tcp_check_syncookie, struct sock *, sk, void *, iph, u32, iph_len,
+          struct tcphdr *, th, u32, th_len)
+{
+#ifdef CONFIG_SYN_COOKIES
+       u32 cookie;
+       int ret;
+
+       if (unlikely(th_len < sizeof(*th)))
+               return -EINVAL;
+
+       /* sk_listener() allows TCP_NEW_SYN_RECV, which makes no sense here. */
+       if (sk->sk_protocol != IPPROTO_TCP || sk->sk_state != TCP_LISTEN)
+               return -EINVAL;
+
+       if (!sock_net(sk)->ipv4.sysctl_tcp_syncookies)
+               return -EINVAL;
+
+       if (!th->ack || th->rst || th->syn)
+               return -ENOENT;
+
+       if (tcp_synq_no_recent_overflow(sk))
+               return -ENOENT;
+
+       cookie = ntohl(th->ack_seq) - 1;
+
+       switch (sk->sk_family) {
+       case AF_INET:
+               if (unlikely(iph_len < sizeof(struct iphdr)))
+                       return -EINVAL;
+
+               ret = __cookie_v4_check((struct iphdr *)iph, th, cookie);
+               break;
+
+#if IS_BUILTIN(CONFIG_IPV6)
+       case AF_INET6:
+               if (unlikely(iph_len < sizeof(struct ipv6hdr)))
+                       return -EINVAL;
+
+               ret = __cookie_v6_check((struct ipv6hdr *)iph, th, cookie);
+               break;
+#endif /* CONFIG_IPV6 */
+
+       default:
+               return -EPROTONOSUPPORT;
+       }
+
+       if (ret > 0)
+               return 0;
+
+       return -ENOENT;
+#else
+       return -ENOTSUPP;
+#endif
+}
+
+static const struct bpf_func_proto bpf_tcp_check_syncookie_proto = {
+       .func           = bpf_tcp_check_syncookie,
+       .gpl_only       = true,
+       .pkt_access     = true,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_PTR_TO_MEM,
+       .arg5_type      = ARG_CONST_SIZE,
+};
+
 #endif /* CONFIG_INET */
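
Usage note: the new helper is the point of this patch: an XDP or tc
program can verify that an ACK carries a valid SYN cookie for a local
listener without duplicating the kernel's cookie math. An IPv4-only XDP
sketch, modeled loosely on the selftest that accompanies this series
(section name, includes and the pass-through policy are assumptions):

#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/tcp.h>
#include <linux/in.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("xdp")
int syncookie_check(struct xdp_md *ctx)
{
        void *data = (void *)(long)ctx->data;
        void *data_end = (void *)(long)ctx->data_end;
        struct bpf_sock_tuple tup = {};
        struct ethhdr *eth = data;
        struct bpf_sock *sk;
        struct iphdr *iph;
        struct tcphdr *th;

        if ((void *)(eth + 1) > data_end ||
            eth->h_proto != bpf_htons(ETH_P_IP))
                return XDP_PASS;

        iph = (void *)(eth + 1);
        if ((void *)(iph + 1) > data_end ||
            iph->ihl != 5 || iph->protocol != IPPROTO_TCP)
                return XDP_PASS;

        th = (void *)(iph + 1);
        if ((void *)(th + 1) > data_end)
                return XDP_PASS;

        tup.ipv4.saddr = iph->saddr;
        tup.ipv4.daddr = iph->daddr;
        tup.ipv4.sport = th->source;
        tup.ipv4.dport = th->dest;

        sk = bpf_skc_lookup_tcp(ctx, &tup, sizeof(tup.ipv4),
                                BPF_F_CURRENT_NETNS, 0);
        if (!sk)
                return XDP_PASS;

        if (sk->state == BPF_TCP_LISTEN &&
            bpf_tcp_check_syncookie(sk, iph, sizeof(*iph),
                                    th, sizeof(*th)) == 0) {
                /* the ACK completes a valid SYN cookie handshake */
        }

        bpf_sk_release(sk);
        return XDP_PASS;
}

char _license[] SEC("license") = "GPL";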
 
 bool bpf_helper_changes_pkt_data(void *func)
@@ -5282,7 +5650,8 @@ bool bpf_helper_changes_pkt_data(void *func)
            func == bpf_lwt_seg6_adjust_srh ||
            func == bpf_lwt_seg6_action ||
 #endif
-           func == bpf_lwt_push_encap)
+           func == bpf_lwt_in_push_encap ||
+           func == bpf_lwt_xmit_push_encap)
                return true;
 
        return false;
@@ -5314,10 +5683,20 @@ bpf_base_func_proto(enum bpf_func_id func_id)
                return &bpf_tail_call_proto;
        case BPF_FUNC_ktime_get_ns:
                return &bpf_ktime_get_ns_proto;
+       default:
+               break;
+       }
+
+       if (!capable(CAP_SYS_ADMIN))
+               return NULL;
+
+       switch (func_id) {
+       case BPF_FUNC_spin_lock:
+               return &bpf_spin_lock_proto;
+       case BPF_FUNC_spin_unlock:
+               return &bpf_spin_unlock_proto;
        case BPF_FUNC_trace_printk:
-               if (capable(CAP_SYS_ADMIN))
-                       return bpf_get_trace_printk_proto();
-               /* else, fall through */
+               return bpf_get_trace_printk_proto();
        default:
                return NULL;
        }
@@ -5367,6 +5746,8 @@ sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sock_addr_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_sock_addr_skc_lookup_tcp_proto;
 #endif /* CONFIG_INET */
        default:
                return bpf_base_func_proto(func_id);
@@ -5396,6 +5777,16 @@ cg_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
        switch (func_id) {
        case BPF_FUNC_get_local_storage:
                return &bpf_get_local_storage_proto;
+       case BPF_FUNC_sk_fullsock:
+               return &bpf_sk_fullsock_proto;
+#ifdef CONFIG_INET
+       case BPF_FUNC_tcp_sock:
+               return &bpf_tcp_sock_proto;
+       case BPF_FUNC_get_listener_sock:
+               return &bpf_get_listener_sock_proto;
+       case BPF_FUNC_skb_ecn_set_ce:
+               return &bpf_skb_ecn_set_ce_proto;
+#endif
        default:
                return sk_filter_func_proto(func_id, prog);
        }
@@ -5467,6 +5858,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_get_socket_uid_proto;
        case BPF_FUNC_fib_lookup:
                return &bpf_skb_fib_lookup_proto;
+       case BPF_FUNC_sk_fullsock:
+               return &bpf_sk_fullsock_proto;
 #ifdef CONFIG_XFRM
        case BPF_FUNC_skb_get_xfrm_state:
                return &bpf_skb_get_xfrm_state_proto;
@@ -5484,6 +5877,14 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_tcp_sock:
+               return &bpf_tcp_sock_proto;
+       case BPF_FUNC_get_listener_sock:
+               return &bpf_get_listener_sock_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_skc_lookup_tcp_proto;
+       case BPF_FUNC_tcp_check_syncookie:
+               return &bpf_tcp_check_syncookie_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5519,6 +5920,10 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_xdp_sk_lookup_tcp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_xdp_skc_lookup_tcp_proto;
+       case BPF_FUNC_tcp_check_syncookie:
+               return &bpf_tcp_check_syncookie_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5611,6 +6016,8 @@ sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_skc_lookup_tcp_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5660,7 +6067,7 @@ lwt_in_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 {
        switch (func_id) {
        case BPF_FUNC_lwt_push_encap:
-               return &bpf_lwt_push_encap_proto;
+               return &bpf_lwt_in_push_encap_proto;
        default:
                return lwt_out_func_proto(func_id, prog);
        }
@@ -5696,6 +6103,8 @@ lwt_xmit_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_l4_csum_replace_proto;
        case BPF_FUNC_set_hash_invalid:
                return &bpf_set_hash_invalid_proto;
+       case BPF_FUNC_lwt_push_encap:
+               return &bpf_lwt_xmit_push_encap_proto;
        default:
                return lwt_out_func_proto(func_id, prog);
        }
@@ -5754,6 +6163,11 @@ static bool bpf_skb_is_valid_access(int off, int size, enum bpf_access_type type
                if (size != sizeof(__u64))
                        return false;
                break;
+       case offsetof(struct __sk_buff, sk):
+               if (type == BPF_WRITE || size != sizeof(__u64))
+                       return false;
+               info->reg_type = PTR_TO_SOCK_COMMON_OR_NULL;
+               break;
        default:
                /* Only narrow read access allowed for now. */
                if (type == BPF_WRITE) {
@@ -5925,31 +6339,44 @@ full_access:
        return true;
 }
 
-static bool __sock_filter_check_size(int off, int size,
+bool bpf_sock_common_is_valid_access(int off, int size,
+                                    enum bpf_access_type type,
                                     struct bpf_insn_access_aux *info)
 {
-       const int size_default = sizeof(__u32);
-
        switch (off) {
-       case bpf_ctx_range(struct bpf_sock, src_ip4):
-       case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]):
-               bpf_ctx_record_field_size(info, size_default);
-               return bpf_ctx_narrow_access_ok(off, size, size_default);
+       case bpf_ctx_range_till(struct bpf_sock, type, priority):
+               return false;
+       default:
+               return bpf_sock_is_valid_access(off, size, type, info);
        }
-
-       return size == size_default;
 }
 
 bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type,
                              struct bpf_insn_access_aux *info)
 {
+       const int size_default = sizeof(__u32);
+
        if (off < 0 || off >= sizeof(struct bpf_sock))
                return false;
        if (off % size != 0)
                return false;
-       if (!__sock_filter_check_size(off, size, info))
-               return false;
-       return true;
+
+       switch (off) {
+       case offsetof(struct bpf_sock, state):
+       case offsetof(struct bpf_sock, family):
+       case offsetof(struct bpf_sock, type):
+       case offsetof(struct bpf_sock, protocol):
+       case offsetof(struct bpf_sock, dst_port):
+       case offsetof(struct bpf_sock, src_port):
+       case bpf_ctx_range(struct bpf_sock, src_ip4):
+       case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]):
+       case bpf_ctx_range(struct bpf_sock, dst_ip4):
+       case bpf_ctx_range_till(struct bpf_sock, dst_ip6[0], dst_ip6[3]):
+               bpf_ctx_record_field_size(info, size_default);
+               return bpf_ctx_narrow_access_ok(off, size, size_default);
+       }
+
+       return size == size_default;
 }
 
 static bool sock_filter_is_valid_access(int off, int size,
@@ -6065,6 +6492,7 @@ static bool tc_cls_act_is_valid_access(int off, int size,
                case bpf_ctx_range(struct __sk_buff, tc_classid):
                case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]):
                case bpf_ctx_range(struct __sk_buff, tstamp):
+               case bpf_ctx_range(struct __sk_buff, queue_mapping):
                        break;
                default:
                        return false;
@@ -6469,9 +6897,18 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                break;
 
        case offsetof(struct __sk_buff, queue_mapping):
-               *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
-                                     bpf_target_off(struct sk_buff, queue_mapping, 2,
-                                                    target_size));
+               if (type == BPF_WRITE) {
+                       *insn++ = BPF_JMP_IMM(BPF_JGE, si->src_reg, NO_QUEUE_MAPPING, 1);
+                       *insn++ = BPF_STX_MEM(BPF_H, si->dst_reg, si->src_reg,
+                                             bpf_target_off(struct sk_buff,
+                                                            queue_mapping,
+                                                            2, target_size));
+               } else {
+                       *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
+                                             bpf_target_off(struct sk_buff,
+                                                            queue_mapping,
+                                                            2, target_size));
+               }
                break;
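
Usage note: writes are guarded by the BPF_JGE instruction generated above,
so values >= NO_QUEUE_MAPPING are skipped rather than truncated into the
16-bit field. A tc sketch (the queue index is a placeholder):

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

SEC("classifier")
int steer_queue(struct __sk_buff *skb)
{
        skb->queue_mapping = 3; /* steer to TX queue 3, if it exists */
        return TC_ACT_OK;
}

char _license[] SEC("license") = "GPL";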
 
        case offsetof(struct __sk_buff, vlan_present):
@@ -6708,6 +7145,27 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                                                             target_size));
                break;
 
+       case offsetof(struct __sk_buff, gso_segs):
+               /* si->dst_reg = skb_shinfo(SKB); */
+#ifdef NET_SKBUFF_DATA_USES_OFFSET
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, head),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, head));
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end),
+                                     BPF_REG_AX, si->src_reg,
+                                     offsetof(struct sk_buff, end));
+               *insn++ = BPF_ALU64_REG(BPF_ADD, si->dst_reg, BPF_REG_AX);
+#else
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, end));
+#endif
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct skb_shared_info, gso_segs),
+                                     si->dst_reg, si->dst_reg,
+                                     bpf_target_off(struct skb_shared_info,
+                                                    gso_segs, 2,
+                                                    target_size));
+               break;
        case offsetof(struct __sk_buff, wire_len):
                BUILD_BUG_ON(FIELD_SIZEOF(struct qdisc_skb_cb, pkt_len) != 4);
 
@@ -6717,6 +7175,13 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                off += offsetof(struct qdisc_skb_cb, pkt_len);
                *target_size = 4;
                *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, off);
+               break;
+
+       case offsetof(struct __sk_buff, sk):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               break;
        }
 
        return insn - insn_buf;
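
Usage note: a read of skb->gso_segs is rewritten by the gso_segs case
above into loads that walk skb->head (plus skb->end when data pointers
are offsets) to skb_shared_info. From the program's side it is just a
field read; a tc sketch:

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

SEC("classifier")
int log_gso(struct __sk_buff *skb)
{
        if (skb->gso_segs > 1)
                bpf_printk("gso segs: %u\n", skb->gso_segs);
        return TC_ACT_OK;
}

char _license[] SEC("license") = "GPL";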
@@ -6765,24 +7230,32 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
                break;
 
        case offsetof(struct bpf_sock, family):
-               BUILD_BUG_ON(FIELD_SIZEOF(struct sock, sk_family) != 2);
-
-               *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg,
-                                     offsetof(struct sock, sk_family));
+               *insn++ = BPF_LDX_MEM(
+                       BPF_FIELD_SIZEOF(struct sock_common, skc_family),
+                       si->dst_reg, si->src_reg,
+                       bpf_target_off(struct sock_common,
+                                      skc_family,
+                                      FIELD_SIZEOF(struct sock_common,
+                                                   skc_family),
+                                      target_size));
                break;
 
        case offsetof(struct bpf_sock, type):
+               BUILD_BUG_ON(HWEIGHT32(SK_FL_TYPE_MASK) != BITS_PER_BYTE * 2);
                *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg,
                                      offsetof(struct sock, __sk_flags_offset));
                *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_TYPE_MASK);
                *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, SK_FL_TYPE_SHIFT);
+               *target_size = 2;
                break;
 
        case offsetof(struct bpf_sock, protocol):
+               BUILD_BUG_ON(HWEIGHT32(SK_FL_PROTO_MASK) != BITS_PER_BYTE);
                *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg,
                                      offsetof(struct sock, __sk_flags_offset));
                *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, SK_FL_PROTO_MASK);
                *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, SK_FL_PROTO_SHIFT);
+               *target_size = 1;
                break;
 
        case offsetof(struct bpf_sock, src_ip4):
@@ -6794,6 +7267,15 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
                                       target_size));
                break;
 
+       case offsetof(struct bpf_sock, dst_ip4):
+               *insn++ = BPF_LDX_MEM(
+                       BPF_SIZE(si->code), si->dst_reg, si->src_reg,
+                       bpf_target_off(struct sock_common, skc_daddr,
+                                      FIELD_SIZEOF(struct sock_common,
+                                                   skc_daddr),
+                                      target_size));
+               break;
+
        case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]):
 #if IS_ENABLED(CONFIG_IPV6)
                off = si->off;
@@ -6812,6 +7294,23 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
 #endif
                break;
 
+       case bpf_ctx_range_till(struct bpf_sock, dst_ip6[0], dst_ip6[3]):
+#if IS_ENABLED(CONFIG_IPV6)
+               off = si->off;
+               off -= offsetof(struct bpf_sock, dst_ip6[0]);
+               *insn++ = BPF_LDX_MEM(
+                       BPF_SIZE(si->code), si->dst_reg, si->src_reg,
+                       bpf_target_off(struct sock_common,
+                                      skc_v6_daddr.s6_addr32[0],
+                                      FIELD_SIZEOF(struct sock_common,
+                                                   skc_v6_daddr.s6_addr32[0]),
+                                      target_size) + off);
+#else
+               *insn++ = BPF_MOV32_IMM(si->dst_reg, 0);
+               *target_size = 4;
+#endif
+               break;
+
        case offsetof(struct bpf_sock, src_port):
                *insn++ = BPF_LDX_MEM(
                        BPF_FIELD_SIZEOF(struct sock_common, skc_num),
@@ -6821,6 +7320,26 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
                                                    skc_num),
                                       target_size));
                break;
+
+       case offsetof(struct bpf_sock, dst_port):
+               *insn++ = BPF_LDX_MEM(
+                       BPF_FIELD_SIZEOF(struct sock_common, skc_dport),
+                       si->dst_reg, si->src_reg,
+                       bpf_target_off(struct sock_common, skc_dport,
+                                      FIELD_SIZEOF(struct sock_common,
+                                                   skc_dport),
+                                      target_size));
+               break;
+
+       case offsetof(struct bpf_sock, state):
+               *insn++ = BPF_LDX_MEM(
+                       BPF_FIELD_SIZEOF(struct sock_common, skc_state),
+                       si->dst_reg, si->src_reg,
+                       bpf_target_off(struct sock_common, skc_state,
+                                      FIELD_SIZEOF(struct sock_common,
+                                                   skc_state),
+                                      target_size));
+               break;
        }
 
        return insn - insn_buf;
@@ -7068,6 +7587,85 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type,
        struct bpf_insn *insn = insn_buf;
        int off;
 
+/* Helper macro for adding read access to tcp_sock or sock fields. */
+#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)                        \
+       do {                                                                  \
+               BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >                   \
+                            FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern,     \
+                                               is_fullsock),                 \
+                                     si->dst_reg, si->src_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              is_fullsock));                 \
+               *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 2);            \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern, sk),\
+                                     si->dst_reg, si->src_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern, sk));\
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ,                   \
+                                                      OBJ_FIELD),            \
+                                     si->dst_reg, si->dst_reg,               \
+                                     offsetof(OBJ, OBJ_FIELD));              \
+       } while (0)
+
+#define SOCK_OPS_GET_TCP_SOCK_FIELD(FIELD) \
+               SOCK_OPS_GET_FIELD(FIELD, FIELD, struct tcp_sock)
+
+/* Helper macro for adding write access to tcp_sock or sock fields.
+ * The macro is called with two registers, dst_reg which contains a pointer
+ * to ctx (context) and src_reg which contains the value that should be
+ * stored. However, we need an additional register since we cannot overwrite
+ * dst_reg because it may be used later in the program.
+ * Instead we "borrow" one of the other register. We first save its value
+ * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore
+ * it at the end of the macro.
+ */
+#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)                        \
+       do {                                                                  \
+               int reg = BPF_REG_9;                                          \
+               BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >                   \
+                            FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
+               if (si->dst_reg == reg || si->src_reg == reg)                 \
+                       reg--;                                                \
+               if (si->dst_reg == reg || si->src_reg == reg)                 \
+                       reg--;                                                \
+               *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg,               \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              temp));                        \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern,     \
+                                               is_fullsock),                 \
+                                     reg, si->dst_reg,                       \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              is_fullsock));                 \
+               *insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2);                    \
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
+                                               struct bpf_sock_ops_kern, sk),\
+                                     reg, si->dst_reg,                       \
+                                     offsetof(struct bpf_sock_ops_kern, sk));\
+               *insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD),       \
+                                     reg, si->src_reg,                       \
+                                     offsetof(OBJ, OBJ_FIELD));              \
+               *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg,               \
+                                     offsetof(struct bpf_sock_ops_kern,      \
+                                              temp));                        \
+       } while (0)
+
+#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE)           \
+       do {                                                                  \
+               if (TYPE == BPF_WRITE)                                        \
+                       SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);        \
+               else                                                          \
+                       SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);        \
+       } while (0)
+
+       CONVERT_COMMON_TCP_SOCK_FIELDS(struct bpf_sock_ops,
+                                      SOCK_OPS_GET_TCP_SOCK_FIELD);
+
+       if (insn > insn_buf)
+               return insn - insn_buf;
+
        switch (si->off) {
        case offsetof(struct bpf_sock_ops, op) ...
             offsetof(struct bpf_sock_ops, replylong[3]):
@@ -7225,175 +7823,15 @@ static u32 sock_ops_convert_ctx_access(enum bpf_access_type type,
                                      FIELD_SIZEOF(struct minmax_sample, t));
                break;
 
-/* Helper macro for adding read access to tcp_sock or sock fields. */
-#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)                        \
-       do {                                                                  \
-               BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >                   \
-                            FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
-                                               struct bpf_sock_ops_kern,     \
-                                               is_fullsock),                 \
-                                     si->dst_reg, si->src_reg,               \
-                                     offsetof(struct bpf_sock_ops_kern,      \
-                                              is_fullsock));                 \
-               *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 2);            \
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
-                                               struct bpf_sock_ops_kern, sk),\
-                                     si->dst_reg, si->src_reg,               \
-                                     offsetof(struct bpf_sock_ops_kern, sk));\
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ,                   \
-                                                      OBJ_FIELD),            \
-                                     si->dst_reg, si->dst_reg,               \
-                                     offsetof(OBJ, OBJ_FIELD));              \
-       } while (0)
-
-/* Helper macro for adding write access to tcp_sock or sock fields.
- * The macro is called with two registers, dst_reg which contains a pointer
- * to ctx (context) and src_reg which contains the value that should be
- * stored. However, we need an additional register since we cannot overwrite
- * dst_reg because it may be used later in the program.
- * Instead we "borrow" one of the other register. We first save its value
- * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore
- * it at the end of the macro.
- */
-#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ)                        \
-       do {                                                                  \
-               int reg = BPF_REG_9;                                          \
-               BUILD_BUG_ON(FIELD_SIZEOF(OBJ, OBJ_FIELD) >                   \
-                            FIELD_SIZEOF(struct bpf_sock_ops, BPF_FIELD));   \
-               if (si->dst_reg == reg || si->src_reg == reg)                 \
-                       reg--;                                                \
-               if (si->dst_reg == reg || si->src_reg == reg)                 \
-                       reg--;                                                \
-               *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg,               \
-                                     offsetof(struct bpf_sock_ops_kern,      \
-                                              temp));                        \
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
-                                               struct bpf_sock_ops_kern,     \
-                                               is_fullsock),                 \
-                                     reg, si->dst_reg,                       \
-                                     offsetof(struct bpf_sock_ops_kern,      \
-                                              is_fullsock));                 \
-               *insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2);                    \
-               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(                       \
-                                               struct bpf_sock_ops_kern, sk),\
-                                     reg, si->dst_reg,                       \
-                                     offsetof(struct bpf_sock_ops_kern, sk));\
-               *insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD),       \
-                                     reg, si->src_reg,                       \
-                                     offsetof(OBJ, OBJ_FIELD));              \
-               *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg,               \
-                                     offsetof(struct bpf_sock_ops_kern,      \
-                                              temp));                        \
-       } while (0)
-
-#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE)           \
-       do {                                                                  \
-               if (TYPE == BPF_WRITE)                                        \
-                       SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);        \
-               else                                                          \
-                       SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ);        \
-       } while (0)
-
-       case offsetof(struct bpf_sock_ops, snd_cwnd):
-               SOCK_OPS_GET_FIELD(snd_cwnd, snd_cwnd, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, srtt_us):
-               SOCK_OPS_GET_FIELD(srtt_us, srtt_us, struct tcp_sock);
-               break;
-
        case offsetof(struct bpf_sock_ops, bpf_sock_ops_cb_flags):
                SOCK_OPS_GET_FIELD(bpf_sock_ops_cb_flags, bpf_sock_ops_cb_flags,
                                   struct tcp_sock);
                break;
 
-       case offsetof(struct bpf_sock_ops, snd_ssthresh):
-               SOCK_OPS_GET_FIELD(snd_ssthresh, snd_ssthresh, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, rcv_nxt):
-               SOCK_OPS_GET_FIELD(rcv_nxt, rcv_nxt, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, snd_nxt):
-               SOCK_OPS_GET_FIELD(snd_nxt, snd_nxt, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, snd_una):
-               SOCK_OPS_GET_FIELD(snd_una, snd_una, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, mss_cache):
-               SOCK_OPS_GET_FIELD(mss_cache, mss_cache, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, ecn_flags):
-               SOCK_OPS_GET_FIELD(ecn_flags, ecn_flags, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, rate_delivered):
-               SOCK_OPS_GET_FIELD(rate_delivered, rate_delivered,
-                                  struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, rate_interval_us):
-               SOCK_OPS_GET_FIELD(rate_interval_us, rate_interval_us,
-                                  struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, packets_out):
-               SOCK_OPS_GET_FIELD(packets_out, packets_out, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, retrans_out):
-               SOCK_OPS_GET_FIELD(retrans_out, retrans_out, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, total_retrans):
-               SOCK_OPS_GET_FIELD(total_retrans, total_retrans,
-                                  struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, segs_in):
-               SOCK_OPS_GET_FIELD(segs_in, segs_in, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, data_segs_in):
-               SOCK_OPS_GET_FIELD(data_segs_in, data_segs_in, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, segs_out):
-               SOCK_OPS_GET_FIELD(segs_out, segs_out, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, data_segs_out):
-               SOCK_OPS_GET_FIELD(data_segs_out, data_segs_out,
-                                  struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, lost_out):
-               SOCK_OPS_GET_FIELD(lost_out, lost_out, struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, sacked_out):
-               SOCK_OPS_GET_FIELD(sacked_out, sacked_out, struct tcp_sock);
-               break;
-
        case offsetof(struct bpf_sock_ops, sk_txhash):
                SOCK_OPS_GET_OR_SET_FIELD(sk_txhash, sk_txhash,
                                          struct sock, type);
                break;
-
-       case offsetof(struct bpf_sock_ops, bytes_received):
-               SOCK_OPS_GET_FIELD(bytes_received, bytes_received,
-                                  struct tcp_sock);
-               break;
-
-       case offsetof(struct bpf_sock_ops, bytes_acked):
-               SOCK_OPS_GET_FIELD(bytes_acked, bytes_acked, struct tcp_sock);
-               break;
-
        }
        return insn - insn_buf;
 }
@@ -7698,6 +8136,7 @@ const struct bpf_verifier_ops flow_dissector_verifier_ops = {
 };
 
 const struct bpf_prog_ops flow_dissector_prog_ops = {
+       .test_run               = bpf_prog_test_run_flow_dissector,
 };
 
 int sk_detach_filter(struct sock *sk)