bpf: add helper to check for a valid SYN cookie
[platform/kernel/linux-rpi.git] / net / core / filter.c
index f274620..d2511fe 100644 (file)
@@ -1796,8 +1796,6 @@ static const struct bpf_func_proto bpf_skb_pull_data_proto = {
 
 BPF_CALL_1(bpf_sk_fullsock, struct sock *, sk)
 {
-       sk = sk_to_full_sk(sk);
-
        return sk_fullsock(sk) ? (unsigned long)sk : (unsigned long)NULL;
 }
 
@@ -5158,15 +5156,15 @@ static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple,
        return sk;
 }
 
-/* bpf_sk_lookup performs the core lookup for different types of sockets,
+/* bpf_skc_lookup performs the core lookup for different types of sockets,
  * taking a reference on the socket if it doesn't have the flag SOCK_RCU_FREE.
  * Returns the socket as an 'unsigned long' to simplify the casting in the
  * callers to satisfy BPF_CALL declarations.
  */
-static unsigned long
-__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
-               u64 flags)
+static struct sock *
+__bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+                struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+                u64 flags)
 {
        struct sock *sk = NULL;
        u8 family = AF_UNSPEC;
@@ -5194,15 +5192,27 @@ __bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                put_net(net);
        }
 
+out:
+       return sk;
+}
+
+static struct sock *
+__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+               struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id,
+               u64 flags)
+{
+       struct sock *sk = __bpf_skc_lookup(skb, tuple, len, caller_net,
+                                          ifindex, proto, netns_id, flags);
+
        if (sk)
                sk = sk_to_full_sk(sk);
-out:
-       return (unsigned long) sk;
+
+       return sk;
 }
 
-static unsigned long
-bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
-             u8 proto, u64 netns_id, u64 flags)
+static struct sock *
+bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+              u8 proto, u64 netns_id, u64 flags)
 {
        struct net *caller_net;
        int ifindex;
@@ -5215,14 +5225,47 @@ bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
                ifindex = 0;
        }
 
-       return __bpf_sk_lookup(skb, tuple, len, caller_net, ifindex,
-                             proto, netns_id, flags);
+       return __bpf_skc_lookup(skb, tuple, len, caller_net, ifindex, proto,
+                               netns_id, flags);
+}
+
+static struct sock *
+bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len,
+             u8 proto, u64 netns_id, u64 flags)
+{
+       struct sock *sk = bpf_skc_lookup(skb, tuple, len, proto, netns_id,
+                                        flags);
+
+       if (sk)
+               sk = sk_to_full_sk(sk);
+
+       return sk;
+}
+
+BPF_CALL_5(bpf_skc_lookup_tcp, struct sk_buff *, skb,
+          struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+       return (unsigned long)bpf_skc_lookup(skb, tuple, len, IPPROTO_TCP,
+                                            netns_id, flags);
 }
 
+static const struct bpf_func_proto bpf_skc_lookup_tcp_proto = {
+       .func           = bpf_skc_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP, netns_id, flags);
+       return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP,
+                                           netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = {
@@ -5240,7 +5283,8 @@ static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = {
 BPF_CALL_5(bpf_sk_lookup_udp, struct sk_buff *, skb,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP, netns_id, flags);
+       return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP,
+                                           netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sk_lookup_udp_proto = {
@@ -5266,7 +5310,7 @@ static const struct bpf_func_proto bpf_sk_release_proto = {
        .func           = bpf_sk_release,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
-       .arg1_type      = ARG_PTR_TO_SOCKET,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
 };
 
 BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx,
@@ -5275,8 +5319,9 @@ BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx,
        struct net *caller_net = dev_net(ctx->rxq->dev);
        int ifindex = ctx->rxq->dev->ifindex;
 
-       return __bpf_sk_lookup(NULL, tuple, len, caller_net, ifindex,
-                             IPPROTO_UDP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, caller_net,
+                                             ifindex, IPPROTO_UDP, netns_id,
+                                             flags);
 }
 
 static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = {
@@ -5291,14 +5336,38 @@ static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_5(bpf_xdp_skc_lookup_tcp, struct xdp_buff *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
+{
+       struct net *caller_net = dev_net(ctx->rxq->dev);
+       int ifindex = ctx->rxq->dev->ifindex;
+
+       return (unsigned long)__bpf_skc_lookup(NULL, tuple, len, caller_net,
+                                              ifindex, IPPROTO_TCP, netns_id,
+                                              flags);
+}
+
+static const struct bpf_func_proto bpf_xdp_skc_lookup_tcp_proto = {
+       .func           = bpf_xdp_skc_lookup_tcp,
+       .gpl_only       = false,
+       .pkt_access     = true,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_xdp_sk_lookup_tcp, struct xdp_buff *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags)
 {
        struct net *caller_net = dev_net(ctx->rxq->dev);
        int ifindex = ctx->rxq->dev->ifindex;
 
-       return __bpf_sk_lookup(NULL, tuple, len, caller_net, ifindex,
-                             IPPROTO_TCP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, caller_net,
+                                             ifindex, IPPROTO_TCP, netns_id,
+                                             flags);
 }
 
 static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = {
@@ -5313,11 +5382,31 @@ static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_5(bpf_sock_addr_skc_lookup_tcp, struct bpf_sock_addr_kern *, ctx,
+          struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
+{
+       return (unsigned long)__bpf_skc_lookup(NULL, tuple, len,
+                                              sock_net(ctx->sk), 0,
+                                              IPPROTO_TCP, netns_id, flags);
+}
+
+static const struct bpf_func_proto bpf_sock_addr_skc_lookup_tcp_proto = {
+       .func           = bpf_sock_addr_skc_lookup_tcp,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCK_COMMON_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_ANYTHING,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_5(bpf_sock_addr_sk_lookup_tcp, struct bpf_sock_addr_kern *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return __bpf_sk_lookup(NULL, tuple, len, sock_net(ctx->sk), 0,
-                              IPPROTO_TCP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len,
+                                             sock_net(ctx->sk), 0, IPPROTO_TCP,
+                                             netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sock_addr_sk_lookup_tcp_proto = {
@@ -5334,8 +5423,9 @@ static const struct bpf_func_proto bpf_sock_addr_sk_lookup_tcp_proto = {
 BPF_CALL_5(bpf_sock_addr_sk_lookup_udp, struct bpf_sock_addr_kern *, ctx,
           struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags)
 {
-       return __bpf_sk_lookup(NULL, tuple, len, sock_net(ctx->sk), 0,
-                              IPPROTO_UDP, netns_id, flags);
+       return (unsigned long)__bpf_sk_lookup(NULL, tuple, len,
+                                             sock_net(ctx->sk), 0, IPPROTO_UDP,
+                                             netns_id, flags);
 }
 
 static const struct bpf_func_proto bpf_sock_addr_sk_lookup_udp_proto = {
@@ -5407,8 +5497,6 @@ u32 bpf_tcp_sock_convert_ctx_access(enum bpf_access_type type,
 
 BPF_CALL_1(bpf_tcp_sock, struct sock *, sk)
 {
-       sk = sk_to_full_sk(sk);
-
        if (sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP)
                return (unsigned long)sk;
 
@@ -5422,6 +5510,23 @@ static const struct bpf_func_proto bpf_tcp_sock_proto = {
        .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
 };
 
+BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk)
+{
+       sk = sk_to_full_sk(sk);
+
+       if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE))
+               return (unsigned long)sk;
+
+       return (unsigned long)NULL;
+}
+
+static const struct bpf_func_proto bpf_get_listener_sock_proto = {
+       .func           = bpf_get_listener_sock,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+};
+
 BPF_CALL_1(bpf_skb_ecn_set_ce, struct sk_buff *, skb)
 {
        unsigned int iphdr_len;
@@ -5448,6 +5553,74 @@ static const struct bpf_func_proto bpf_skb_ecn_set_ce_proto = {
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
 };
+
+BPF_CALL_5(bpf_tcp_check_syncookie, struct sock *, sk, void *, iph, u32, iph_len,
+          struct tcphdr *, th, u32, th_len)
+{
+#ifdef CONFIG_SYN_COOKIES
+       u32 cookie;
+       int ret;
+
+       if (unlikely(th_len < sizeof(*th)))
+               return -EINVAL;
+
+       /* sk_listener() allows TCP_NEW_SYN_RECV, which makes no sense here. */
+       if (sk->sk_protocol != IPPROTO_TCP || sk->sk_state != TCP_LISTEN)
+               return -EINVAL;
+
+       if (!sock_net(sk)->ipv4.sysctl_tcp_syncookies)
+               return -EINVAL;
+
+       if (!th->ack || th->rst || th->syn)
+               return -ENOENT;
+
+       if (tcp_synq_no_recent_overflow(sk))
+               return -ENOENT;
+
+       cookie = ntohl(th->ack_seq) - 1;
+
+       switch (sk->sk_family) {
+       case AF_INET:
+               if (unlikely(iph_len < sizeof(struct iphdr)))
+                       return -EINVAL;
+
+               ret = __cookie_v4_check((struct iphdr *)iph, th, cookie);
+               break;
+
+#if IS_BUILTIN(CONFIG_IPV6)
+       case AF_INET6:
+               if (unlikely(iph_len < sizeof(struct ipv6hdr)))
+                       return -EINVAL;
+
+               ret = __cookie_v6_check((struct ipv6hdr *)iph, th, cookie);
+               break;
+#endif /* CONFIG_IPV6 */
+
+       default:
+               return -EPROTONOSUPPORT;
+       }
+
+       if (ret > 0)
+               return 0;
+
+       return -ENOENT;
+#else
+       return -ENOTSUPP;
+#endif
+}
+
+static const struct bpf_func_proto bpf_tcp_check_syncookie_proto = {
+       .func           = bpf_tcp_check_syncookie,
+       .gpl_only       = true,
+       .pkt_access     = true,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+       .arg2_type      = ARG_PTR_TO_MEM,
+       .arg3_type      = ARG_CONST_SIZE,
+       .arg4_type      = ARG_PTR_TO_MEM,
+       .arg5_type      = ARG_CONST_SIZE,
+};
+
 #endif /* CONFIG_INET */
 
 bool bpf_helper_changes_pkt_data(void *func)
@@ -5573,6 +5746,8 @@ sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sock_addr_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_sock_addr_skc_lookup_tcp_proto;
 #endif /* CONFIG_INET */
        default:
                return bpf_base_func_proto(func_id);
@@ -5607,6 +5782,8 @@ cg_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 #ifdef CONFIG_INET
        case BPF_FUNC_tcp_sock:
                return &bpf_tcp_sock_proto;
+       case BPF_FUNC_get_listener_sock:
+               return &bpf_get_listener_sock_proto;
        case BPF_FUNC_skb_ecn_set_ce:
                return &bpf_skb_ecn_set_ce_proto;
 #endif
@@ -5702,6 +5879,12 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sk_release_proto;
        case BPF_FUNC_tcp_sock:
                return &bpf_tcp_sock_proto;
+       case BPF_FUNC_get_listener_sock:
+               return &bpf_get_listener_sock_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_skc_lookup_tcp_proto;
+       case BPF_FUNC_tcp_check_syncookie:
+               return &bpf_tcp_check_syncookie_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5737,6 +5920,10 @@ xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_xdp_sk_lookup_tcp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_xdp_skc_lookup_tcp_proto;
+       case BPF_FUNC_tcp_check_syncookie:
+               return &bpf_tcp_check_syncookie_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);
@@ -5829,6 +6016,8 @@ sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sk_lookup_udp_proto;
        case BPF_FUNC_sk_release:
                return &bpf_sk_release_proto;
+       case BPF_FUNC_skc_lookup_tcp:
+               return &bpf_skc_lookup_tcp_proto;
 #endif
        default:
                return bpf_base_func_proto(func_id);