bpf: Change bpf_setsockopt(SOL_SOCKET) to reuse sk_setsockopt()
author Martin KaFai Lau <kafai@fb.com>
Wed, 17 Aug 2022 06:18:04 +0000 (23:18 -0700)
committer Alexei Starovoitov <ast@kernel.org>
Fri, 19 Aug 2022 00:06:13 +0000 (17:06 -0700)
After the prep work in the previous patches,
this patch removes most of the duplicated code from bpf_setsockopt(SOL_SOCKET)
and reuses the implementation in sk_setsockopt() instead.

A NULL test on the sock pointer is added to the SO_RCVLOWAT case because
sk->sk_socket could be NULL in some of the bpf hooks.

The existing optname white-list is refactored into a new
function sol_socket_setsockopt().

Reviewed-by: Stanislav Fomichev <sdf@google.com>
Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/r/20220817061804.4178920-1-kafai@fb.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
include/net/sock.h
net/core/filter.c
net/core/sock.c

index 13089d8..ee44b42 100644 (file)
@@ -1828,6 +1828,8 @@ void sock_pfree(struct sk_buff *skb);
 #define sock_edemux sock_efree
 #endif
 
+int sk_setsockopt(struct sock *sk, int level, int optname,
+                 sockptr_t optval, unsigned int optlen);
 int sock_setsockopt(struct socket *sock, int level, int op,
                    sockptr_t optval, unsigned int optlen);
 
index a663d7b..6f5bcc8 100644 (file)
@@ -5013,109 +5013,43 @@ static const struct bpf_func_proto bpf_get_socket_uid_proto = {
        .arg1_type      = ARG_PTR_TO_CTX,
 };
 
+static int sol_socket_setsockopt(struct sock *sk, int optname,
+                                char *optval, int optlen)
+{ /* Gate for bpf_setsockopt(SOL_SOCKET): check optname/optlen, then delegate to sk_setsockopt(). */
+       switch (optname) { /* only the option names listed below are accepted */
+       case SO_SNDBUF:
+       case SO_RCVBUF:
+       case SO_KEEPALIVE:
+       case SO_PRIORITY:
+       case SO_REUSEPORT:
+       case SO_RCVLOWAT:
+       case SO_MARK:
+       case SO_MAX_PACING_RATE:
+       case SO_BINDTOIFINDEX:
+       case SO_TXREHASH:
+               if (optlen != sizeof(int)) /* each option above must be passed as exactly one int */
+                       return -EINVAL;
+               break;
+       case SO_BINDTODEVICE: /* no optlen check here; the length is handed to sk_setsockopt() unchanged */
+               break;
+       default: /* any option not listed above is rejected */
+               return -EINVAL;
+       }
+
+       return sk_setsockopt(sk, SOL_SOCKET, optname, /* result of sk_setsockopt() is returned as-is */
+                            KERNEL_SOCKPTR(optval), optlen); /* NOTE(review): KERNEL_SOCKPTR presumably tags optval as a kernel-space pointer - confirm */
+}
+
 static int __bpf_setsockopt(struct sock *sk, int level, int optname,
                            char *optval, int optlen)
 {
-       char devname[IFNAMSIZ];
-       int val, valbool;
-       struct net *net;
-       int ifindex;
-       int ret = 0;
+       int val, ret = 0;
 
        if (!sk_fullsock(sk))
                return -EINVAL;
 
        if (level == SOL_SOCKET) {
-               if (optlen != sizeof(int) && optname != SO_BINDTODEVICE)
-                       return -EINVAL;
-               val = *((int *)optval);
-               valbool = val ? 1 : 0;
-
-               /* Only some socketops are supported */
-               switch (optname) {
-               case SO_RCVBUF:
-                       val = min_t(u32, val, sysctl_rmem_max);
-                       val = min_t(int, val, INT_MAX / 2);
-                       sk->sk_userlocks |= SOCK_RCVBUF_LOCK;
-                       WRITE_ONCE(sk->sk_rcvbuf,
-                                  max_t(int, val * 2, SOCK_MIN_RCVBUF));
-                       break;
-               case SO_SNDBUF:
-                       val = min_t(u32, val, sysctl_wmem_max);
-                       val = min_t(int, val, INT_MAX / 2);
-                       sk->sk_userlocks |= SOCK_SNDBUF_LOCK;
-                       WRITE_ONCE(sk->sk_sndbuf,
-                                  max_t(int, val * 2, SOCK_MIN_SNDBUF));
-                       break;
-               case SO_MAX_PACING_RATE: /* 32bit version */
-                       if (val != ~0U)
-                               cmpxchg(&sk->sk_pacing_status,
-                                       SK_PACING_NONE,
-                                       SK_PACING_NEEDED);
-                       sk->sk_max_pacing_rate = (val == ~0U) ?
-                                                ~0UL : (unsigned int)val;
-                       sk->sk_pacing_rate = min(sk->sk_pacing_rate,
-                                                sk->sk_max_pacing_rate);
-                       break;
-               case SO_PRIORITY:
-                       sk->sk_priority = val;
-                       break;
-               case SO_RCVLOWAT:
-                       if (val < 0)
-                               val = INT_MAX;
-                       if (sk->sk_socket && sk->sk_socket->ops->set_rcvlowat)
-                               ret = sk->sk_socket->ops->set_rcvlowat(sk, val);
-                       else
-                               WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);
-                       break;
-               case SO_MARK:
-                       if (sk->sk_mark != val) {
-                               sk->sk_mark = val;
-                               sk_dst_reset(sk);
-                       }
-                       break;
-               case SO_BINDTODEVICE:
-                       optlen = min_t(long, optlen, IFNAMSIZ - 1);
-                       strncpy(devname, optval, optlen);
-                       devname[optlen] = 0;
-
-                       ifindex = 0;
-                       if (devname[0] != '\0') {
-                               struct net_device *dev;
-
-                               ret = -ENODEV;
-
-                               net = sock_net(sk);
-                               dev = dev_get_by_name(net, devname);
-                               if (!dev)
-                                       break;
-                               ifindex = dev->ifindex;
-                               dev_put(dev);
-                       }
-                       fallthrough;
-               case SO_BINDTOIFINDEX:
-                       if (optname == SO_BINDTOIFINDEX)
-                               ifindex = val;
-                       ret = sock_bindtoindex(sk, ifindex, false);
-                       break;
-               case SO_KEEPALIVE:
-                       if (sk->sk_prot->keepalive)
-                               sk->sk_prot->keepalive(sk, valbool);
-                       sock_valbool_flag(sk, SOCK_KEEPOPEN, valbool);
-                       break;
-               case SO_REUSEPORT:
-                       sk->sk_reuseport = valbool;
-                       break;
-               case SO_TXREHASH:
-                       if (val < -1 || val > 1) {
-                               ret = -EINVAL;
-                               break;
-                       }
-                       sk->sk_txrehash = (u8)val;
-                       break;
-               default:
-                       ret = -EINVAL;
-               }
+               return sol_socket_setsockopt(sk, optname, optval, optlen);
        } else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) {
                if (optlen != sizeof(int) || sk->sk_family != AF_INET)
                        return -EINVAL;
index 7ea46e4..2a6f847 100644 (file)
@@ -1077,8 +1077,8 @@ EXPORT_SYMBOL(sockopt_capable);
  *     at the socket level. Everything here is generic.
  */
 
-static int sk_setsockopt(struct sock *sk, int level, int optname,
-                        sockptr_t optval, unsigned int optlen)
+int sk_setsockopt(struct sock *sk, int level, int optname,
+                 sockptr_t optval, unsigned int optlen)
 {
        struct so_timestamping timestamping;
        struct socket *sock = sk->sk_socket;
@@ -1264,7 +1264,7 @@ set_sndbuf:
        case SO_RCVLOWAT:
                if (val < 0)
                        val = INT_MAX;
-               if (sock->ops->set_rcvlowat)
+               if (sock && sock->ops->set_rcvlowat)
                        ret = sock->ops->set_rcvlowat(sk, val);
                else
                        WRITE_ONCE(sk->sk_rcvlowat, val ? : 1);