bpf: net: Change do_tcp_getsockopt() to take the sockptr_t argument
authorMartin KaFai Lau <martin.lau@kernel.org>
Fri, 2 Sep 2022 00:28:15 +0000 (17:28 -0700)
committerAlexei Starovoitov <ast@kernel.org>
Sat, 3 Sep 2022 03:34:30 +0000 (20:34 -0700)
Similar to the earlier patch that changes sk_getsockopt() to
take the sockptr_t argument .  This patch also changes
do_tcp_getsockopt() to take the sockptr_t argument such that
a latter patch can make bpf_getsockopt(SOL_TCP) to reuse
do_tcp_getsockopt().

Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
Link: https://lore.kernel.org/r/20220902002815.2889332-1-kafai@fb.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
net/ipv4/tcp.c

index f0d79ea..108c430 100644 (file)
@@ -4044,14 +4044,14 @@ struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk,
 }
 
 static int do_tcp_getsockopt(struct sock *sk, int level,
-               int optname, char __user *optval, int __user *optlen)
+                            int optname, sockptr_t optval, sockptr_t optlen)
 {
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tcp_sock *tp = tcp_sk(sk);
        struct net *net = sock_net(sk);
        int val, len;
 
-       if (get_user(len, optlen))
+       if (copy_from_sockptr(&len, optlen, sizeof(int)))
                return -EFAULT;
 
        len = min_t(unsigned int, len, sizeof(int));
@@ -4101,15 +4101,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
        case TCP_INFO: {
                struct tcp_info info;
 
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
 
                tcp_get_info(sk, &info);
 
                len = min_t(unsigned int, len, sizeof(info));
-               if (put_user(len, optlen))
+               if (copy_to_sockptr(optlen, &len, sizeof(int)))
                        return -EFAULT;
-               if (copy_to_user(optval, &info, len))
+               if (copy_to_sockptr(optval, &info, len))
                        return -EFAULT;
                return 0;
        }
@@ -4119,7 +4119,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                size_t sz = 0;
                int attr;
 
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
 
                ca_ops = icsk->icsk_ca_ops;
@@ -4127,9 +4127,9 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                        sz = ca_ops->get_info(sk, ~0U, &attr, &info);
 
                len = min_t(unsigned int, len, sz);
-               if (put_user(len, optlen))
+               if (copy_to_sockptr(optlen, &len, sizeof(int)))
                        return -EFAULT;
-               if (copy_to_user(optval, &info, len))
+               if (copy_to_sockptr(optval, &info, len))
                        return -EFAULT;
                return 0;
        }
@@ -4138,27 +4138,28 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                break;
 
        case TCP_CONGESTION:
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
                len = min_t(unsigned int, len, TCP_CA_NAME_MAX);
-               if (put_user(len, optlen))
+               if (copy_to_sockptr(optlen, &len, sizeof(int)))
                        return -EFAULT;
-               if (copy_to_user(optval, icsk->icsk_ca_ops->name, len))
+               if (copy_to_sockptr(optval, icsk->icsk_ca_ops->name, len))
                        return -EFAULT;
                return 0;
 
        case TCP_ULP:
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
                len = min_t(unsigned int, len, TCP_ULP_NAME_MAX);
                if (!icsk->icsk_ulp_ops) {
-                       if (put_user(0, optlen))
+                       len = 0;
+                       if (copy_to_sockptr(optlen, &len, sizeof(int)))
                                return -EFAULT;
                        return 0;
                }
-               if (put_user(len, optlen))
+               if (copy_to_sockptr(optlen, &len, sizeof(int)))
                        return -EFAULT;
-               if (copy_to_user(optval, icsk->icsk_ulp_ops->name, len))
+               if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len))
                        return -EFAULT;
                return 0;
 
@@ -4166,15 +4167,15 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                u64 key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(u64)];
                unsigned int key_len;
 
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
 
                key_len = tcp_fastopen_get_cipher(net, icsk, key) *
                                TCP_FASTOPEN_KEY_LENGTH;
                len = min_t(unsigned int, len, key_len);
-               if (put_user(len, optlen))
+               if (copy_to_sockptr(optlen, &len, sizeof(int)))
                        return -EFAULT;
-               if (copy_to_user(optval, key, len))
+               if (copy_to_sockptr(optval, key, len))
                        return -EFAULT;
                return 0;
        }
@@ -4200,7 +4201,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
        case TCP_REPAIR_WINDOW: {
                struct tcp_repair_window opt;
 
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
 
                if (len != sizeof(opt))
@@ -4215,7 +4216,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                opt.rcv_wnd     = tp->rcv_wnd;
                opt.rcv_wup     = tp->rcv_wup;
 
-               if (copy_to_user(optval, &opt, len))
+               if (copy_to_sockptr(optval, &opt, len))
                        return -EFAULT;
                return 0;
        }
@@ -4261,14 +4262,14 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                val = tp->save_syn;
                break;
        case TCP_SAVED_SYN: {
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
 
                lock_sock(sk);
                if (tp->saved_syn) {
                        if (len < tcp_saved_syn_len(tp->saved_syn)) {
-                               if (put_user(tcp_saved_syn_len(tp->saved_syn),
-                                            optlen)) {
+                               len = tcp_saved_syn_len(tp->saved_syn);
+                               if (copy_to_sockptr(optlen, &len, sizeof(int))) {
                                        release_sock(sk);
                                        return -EFAULT;
                                }
@@ -4276,11 +4277,11 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                                return -EINVAL;
                        }
                        len = tcp_saved_syn_len(tp->saved_syn);
-                       if (put_user(len, optlen)) {
+                       if (copy_to_sockptr(optlen, &len, sizeof(int))) {
                                release_sock(sk);
                                return -EFAULT;
                        }
-                       if (copy_to_user(optval, tp->saved_syn->data, len)) {
+                       if (copy_to_sockptr(optval, tp->saved_syn->data, len)) {
                                release_sock(sk);
                                return -EFAULT;
                        }
@@ -4289,7 +4290,7 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                } else {
                        release_sock(sk);
                        len = 0;
-                       if (put_user(len, optlen))
+                       if (copy_to_sockptr(optlen, &len, sizeof(int)))
                                return -EFAULT;
                }
                return 0;
@@ -4300,21 +4301,21 @@ static int do_tcp_getsockopt(struct sock *sk, int level,
                struct tcp_zerocopy_receive zc = {};
                int err;
 
-               if (get_user(len, optlen))
+               if (copy_from_sockptr(&len, optlen, sizeof(int)))
                        return -EFAULT;
                if (len < 0 ||
                    len < offsetofend(struct tcp_zerocopy_receive, length))
                        return -EINVAL;
                if (unlikely(len > sizeof(zc))) {
-                       err = check_zeroed_user(optval + sizeof(zc),
-                                               len - sizeof(zc));
+                       err = check_zeroed_sockptr(optval, sizeof(zc),
+                                                  len - sizeof(zc));
                        if (err < 1)
                                return err == 0 ? -EINVAL : err;
                        len = sizeof(zc);
-                       if (put_user(len, optlen))
+                       if (copy_to_sockptr(optlen, &len, sizeof(int)))
                                return -EFAULT;
                }
-               if (copy_from_user(&zc, optval, len))
+               if (copy_from_sockptr(&zc, optval, len))
                        return -EFAULT;
                if (zc.reserved)
                        return -EINVAL;
@@ -4354,7 +4355,7 @@ zerocopy_rcv_sk_err:
 zerocopy_rcv_inq:
                zc.inq = tcp_inq_hint(sk);
 zerocopy_rcv_out:
-               if (!err && copy_to_user(optval, &zc, len))
+               if (!err && copy_to_sockptr(optval, &zc, len))
                        err = -EFAULT;
                return err;
        }
@@ -4363,9 +4364,9 @@ zerocopy_rcv_out:
                return -ENOPROTOOPT;
        }
 
-       if (put_user(len, optlen))
+       if (copy_to_sockptr(optlen, &len, sizeof(int)))
                return -EFAULT;
-       if (copy_to_user(optval, &val, len))
+       if (copy_to_sockptr(optval, &val, len))
                return -EFAULT;
        return 0;
 }
@@ -4390,7 +4391,8 @@ int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval,
        if (level != SOL_TCP)
                return icsk->icsk_af_ops->getsockopt(sk, level, optname,
                                                     optval, optlen);
-       return do_tcp_getsockopt(sk, level, optname, optval, optlen);
+       return do_tcp_getsockopt(sk, level, optname, USER_SOCKPTR(optval),
+                                USER_SOCKPTR(optlen));
 }
 EXPORT_SYMBOL(tcp_getsockopt);