bpf: Change bpf_getsockopt(SOL_SOCKET) to reuse sk_getsockopt()
authorMartin KaFai Lau <martin.lau@kernel.org>
Fri, 2 Sep 2022 00:29:12 +0000 (17:29 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Sat, 3 Sep 2022 03:34:31 +0000 (20:34 -0700)
This patch changes bpf_getsockopt(SOL_SOCKET) to reuse
sk_getsockopt().  It removes all duplicated code from
bpf_getsockopt(SOL_SOCKET).

Before this patch, there were some optnames available to
bpf_setsockopt(SOL_SOCKET) but missing in bpf_getsockopt(SOL_SOCKET).
It surprises users from time to time.  For example, SO_REUSEADDR,
SO_KEEPALIVE, SO_RCVLOWAT, and SO_MAX_PACING_RATE.  This patch
automatically closes this gap without duplicating more code.
The only exception is SO_BINDTODEVICE because it needs to acquire a
blocking lock.  Thus, SO_BINDTODEVICE is not supported.

Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
Link: https://lore.kernel.org/r/20220902002912.2894040-1-kafai@fb.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
include/net/sock.h
net/core/filter.c
net/core/sock.c

index ee44b42..ea79655 100644 (file)
@@ -1833,6 +1833,8 @@ int sk_setsockopt(struct sock *sk, int level, int optname,
 int sock_setsockopt(struct socket *sock, int level, int op,
                    sockptr_t optval, unsigned int optlen);
 
+int sk_getsockopt(struct sock *sk, int level, int optname,
+                 sockptr_t optval, sockptr_t optlen);
 int sock_getsockopt(struct socket *sock, int level, int op,
                    char __user *optval, int __user *optlen);
 int sock_gettstamp(struct socket *sock, void __user *userstamp,
index 5579581..9b26653 100644 (file)
@@ -5017,8 +5017,9 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
        .arg1_type      = ARG_PTR_TO_CTX,
 };
 
-static int sol_socket_setsockopt(struct sock *sk, int optname,
-                                char *optval, int optlen)
+static int sol_socket_sockopt(struct sock *sk, int optname,
+                             char *optval, int *optlen,
+                             bool getopt)
 {
        switch (optname) {
        case SO_REUSEADDR:
@@ -5032,7 +5033,7 @@ static int sol_socket_setsockopt(struct sock *sk, int optname,
        case SO_MAX_PACING_RATE:
        case SO_BINDTOIFINDEX:
        case SO_TXREHASH:
-               if (optlen != sizeof(int))
+               if (*optlen != sizeof(int))
                        return -EINVAL;
                break;
        case SO_BINDTODEVICE:
@@ -5041,8 +5042,16 @@ static int sol_socket_setsockopt(struct sock *sk, int optname,
                return -EINVAL;
        }
 
+       if (getopt) {
+               if (optname == SO_BINDTODEVICE)
+                       return -EINVAL;
+               return sk_getsockopt(sk, SOL_SOCKET, optname,
+                                    KERNEL_SOCKPTR(optval),
+                                    KERNEL_SOCKPTR(optlen));
+       }
+
        return sk_setsockopt(sk, SOL_SOCKET, optname,
-                            KERNEL_SOCKPTR(optval), optlen);
+                            KERNEL_SOCKPTR(optval), *optlen);
 }
 
 static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname,
@@ -5168,7 +5177,7 @@ static int __bpf_setsockopt(struct sock *sk, int level, int optname,
                return -EINVAL;
 
        if (level == SOL_SOCKET)
-               return sol_socket_setsockopt(sk, optname, optval, optlen);
+               return sol_socket_sockopt(sk, optname, optval, &optlen, false);
        else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP)
                return sol_ip_setsockopt(sk, optname, optval, optlen);
        else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6)
@@ -5190,38 +5199,13 @@ static int _bpf_setsockopt(struct sock *sk, int level, int optname,
 static int __bpf_getsockopt(struct sock *sk, int level, int optname,
                            char *optval, int optlen)
 {
+       int err = 0, saved_optlen = optlen;
+
        if (!sk_fullsock(sk))
                goto err_clear;
 
        if (level == SOL_SOCKET) {
-               if (optlen != sizeof(int))
-                       goto err_clear;
-
-               switch (optname) {
-               case SO_RCVBUF:
-                       *((int *)optval) = sk->sk_rcvbuf;
-                       break;
-               case SO_SNDBUF:
-                       *((int *)optval) = sk->sk_sndbuf;
-                       break;
-               case SO_MARK:
-                       *((int *)optval) = sk->sk_mark;
-                       break;
-               case SO_PRIORITY:
-                       *((int *)optval) = sk->sk_priority;
-                       break;
-               case SO_BINDTOIFINDEX:
-                       *((int *)optval) = sk->sk_bound_dev_if;
-                       break;
-               case SO_REUSEPORT:
-                       *((int *)optval) = sk->sk_reuseport;
-                       break;
-               case SO_TXREHASH:
-                       *((int *)optval) = sk->sk_txrehash;
-                       break;
-               default:
-                       goto err_clear;
-               }
+               err = sol_socket_sockopt(sk, optname, optval, &optlen, true);
        } else if (IS_ENABLED(CONFIG_INET) &&
                   level == SOL_TCP && sk->sk_prot->getsockopt == tcp_getsockopt) {
                struct inet_connection_sock *icsk;
@@ -5278,7 +5262,12 @@ static int __bpf_getsockopt(struct sock *sk, int level, int optname,
        } else {
                goto err_clear;
        }
-       return 0;
+
+       if (err)
+               optlen = 0;
+       if (optlen < saved_optlen)
+               memset(optval + optlen, 0, saved_optlen - optlen);
+       return err;
 err_clear:
        memset(optval, 0, optlen);
        return -EINVAL;
index 7fa30fd..68e4662 100644 (file)
@@ -1583,8 +1583,8 @@ static int groups_to_user(sockptr_t dst, const struct group_info *src)
        return 0;
 }
 
-static int sk_getsockopt(struct sock *sk, int level, int optname,
-                        sockptr_t optval, sockptr_t optlen)
+int sk_getsockopt(struct sock *sk, int level, int optname,
+                 sockptr_t optval, sockptr_t optlen)
 {
        struct socket *sock = sk->sk_socket;