selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / setget_sockopt.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) Meta Platforms, Inc. and affiliates. */
3
4 #include "vmlinux.h"
5 #include "bpf_tracing_net.h"
6 #include <bpf/bpf_core_read.h>
7 #include <bpf/bpf_helpers.h>
8 #include <bpf/bpf_tracing.h>
9
10 #ifndef ARRAY_SIZE
11 #define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
12 #endif
13
14 extern unsigned long CONFIG_HZ __kconfig;
15
16 const volatile char veth[IFNAMSIZ];
17 const volatile int veth_ifindex;
18
19 int nr_listen;
20 int nr_passive;
21 int nr_active;
22 int nr_connect;
23 int nr_binddev;
24 int nr_socket_post_create;
25 int nr_fin_wait1;
26
27 struct sockopt_test {
28         int opt;
29         int new;
30         int restore;
31         int expected;
32         int tcp_expected;
33         unsigned int flip:1;
34 };
35
36 static const char not_exist_cc[] = "not_exist";
37 static const char cubic_cc[] = "cubic";
38 static const char reno_cc[] = "reno";
39
40 static const struct sockopt_test sol_socket_tests[] = {
41         { .opt = SO_REUSEADDR, .flip = 1, },
42         { .opt = SO_SNDBUF, .new = 8123, .expected = 8123 * 2, },
43         { .opt = SO_RCVBUF, .new = 8123, .expected = 8123 * 2, },
44         { .opt = SO_KEEPALIVE, .flip = 1, },
45         { .opt = SO_PRIORITY, .new = 0xeb9f, .expected = 0xeb9f, },
46         { .opt = SO_REUSEPORT, .flip = 1, },
47         { .opt = SO_RCVLOWAT, .new = 8123, .expected = 8123, },
48         { .opt = SO_MARK, .new = 0xeb9f, .expected = 0xeb9f, },
49         { .opt = SO_MAX_PACING_RATE, .new = 0xeb9f, .expected = 0xeb9f, },
50         { .opt = SO_TXREHASH, .flip = 1, },
51         { .opt = 0, },
52 };
53
54 static const struct sockopt_test sol_tcp_tests[] = {
55         { .opt = TCP_NODELAY, .flip = 1, },
56         { .opt = TCP_KEEPIDLE, .new = 123, .expected = 123, .restore = 321, },
57         { .opt = TCP_KEEPINTVL, .new = 123, .expected = 123, .restore = 321, },
58         { .opt = TCP_KEEPCNT, .new = 123, .expected = 123, .restore = 124, },
59         { .opt = TCP_SYNCNT, .new = 123, .expected = 123, .restore = 124, },
60         { .opt = TCP_WINDOW_CLAMP, .new = 8123, .expected = 8123, .restore = 8124, },
61         { .opt = TCP_CONGESTION, },
62         { .opt = TCP_THIN_LINEAR_TIMEOUTS, .flip = 1, },
63         { .opt = TCP_USER_TIMEOUT, .new = 123400, .expected = 123400, },
64         { .opt = TCP_NOTSENT_LOWAT, .new = 1314, .expected = 1314, },
65         { .opt = 0, },
66 };
67
68 static const struct sockopt_test sol_ip_tests[] = {
69         { .opt = IP_TOS, .new = 0xe1, .expected = 0xe1, .tcp_expected = 0xe0, },
70         { .opt = 0, },
71 };
72
73 static const struct sockopt_test sol_ipv6_tests[] = {
74         { .opt = IPV6_TCLASS, .new = 0xe1, .expected = 0xe1, .tcp_expected = 0xe0, },
75         { .opt = IPV6_AUTOFLOWLABEL, .flip = 1, },
76         { .opt = 0, },
77 };
78
79 struct loop_ctx {
80         void *ctx;
81         struct sock *sk;
82 };
83
84 static int bpf_test_sockopt_flip(void *ctx, struct sock *sk,
85                                  const struct sockopt_test *t,
86                                  int level)
87 {
88         int old, tmp, new, opt = t->opt;
89
90         opt = t->opt;
91
92         if (bpf_getsockopt(ctx, level, opt, &old, sizeof(old)))
93                 return 1;
94         /* kernel initialized txrehash to 255 */
95         if (level == SOL_SOCKET && opt == SO_TXREHASH && old != 0 && old != 1)
96                 old = 1;
97
98         new = !old;
99         if (bpf_setsockopt(ctx, level, opt, &new, sizeof(new)))
100                 return 1;
101         if (bpf_getsockopt(ctx, level, opt, &tmp, sizeof(tmp)) ||
102             tmp != new)
103                 return 1;
104
105         if (bpf_setsockopt(ctx, level, opt, &old, sizeof(old)))
106                 return 1;
107
108         return 0;
109 }
110
111 static int bpf_test_sockopt_int(void *ctx, struct sock *sk,
112                                 const struct sockopt_test *t,
113                                 int level)
114 {
115         int old, tmp, new, expected, opt;
116
117         opt = t->opt;
118         new = t->new;
119         if (sk->sk_type == SOCK_STREAM && t->tcp_expected)
120                 expected = t->tcp_expected;
121         else
122                 expected = t->expected;
123
124         if (bpf_getsockopt(ctx, level, opt, &old, sizeof(old)) ||
125             old == new)
126                 return 1;
127
128         if (bpf_setsockopt(ctx, level, opt, &new, sizeof(new)))
129                 return 1;
130         if (bpf_getsockopt(ctx, level, opt, &tmp, sizeof(tmp)) ||
131             tmp != expected)
132                 return 1;
133
134         if (t->restore)
135                 old = t->restore;
136         if (bpf_setsockopt(ctx, level, opt, &old, sizeof(old)))
137                 return 1;
138
139         return 0;
140 }
141
142 static int bpf_test_socket_sockopt(__u32 i, struct loop_ctx *lc)
143 {
144         const struct sockopt_test *t;
145
146         if (i >= ARRAY_SIZE(sol_socket_tests))
147                 return 1;
148
149         t = &sol_socket_tests[i];
150         if (!t->opt)
151                 return 1;
152
153         if (t->flip)
154                 return bpf_test_sockopt_flip(lc->ctx, lc->sk, t, SOL_SOCKET);
155
156         return bpf_test_sockopt_int(lc->ctx, lc->sk, t, SOL_SOCKET);
157 }
158
159 static int bpf_test_ip_sockopt(__u32 i, struct loop_ctx *lc)
160 {
161         const struct sockopt_test *t;
162
163         if (i >= ARRAY_SIZE(sol_ip_tests))
164                 return 1;
165
166         t = &sol_ip_tests[i];
167         if (!t->opt)
168                 return 1;
169
170         if (t->flip)
171                 return bpf_test_sockopt_flip(lc->ctx, lc->sk, t, IPPROTO_IP);
172
173         return bpf_test_sockopt_int(lc->ctx, lc->sk, t, IPPROTO_IP);
174 }
175
176 static int bpf_test_ipv6_sockopt(__u32 i, struct loop_ctx *lc)
177 {
178         const struct sockopt_test *t;
179
180         if (i >= ARRAY_SIZE(sol_ipv6_tests))
181                 return 1;
182
183         t = &sol_ipv6_tests[i];
184         if (!t->opt)
185                 return 1;
186
187         if (t->flip)
188                 return bpf_test_sockopt_flip(lc->ctx, lc->sk, t, IPPROTO_IPV6);
189
190         return bpf_test_sockopt_int(lc->ctx, lc->sk, t, IPPROTO_IPV6);
191 }
192
193 static int bpf_test_tcp_sockopt(__u32 i, struct loop_ctx *lc)
194 {
195         const struct sockopt_test *t;
196         struct sock *sk;
197         void *ctx;
198
199         if (i >= ARRAY_SIZE(sol_tcp_tests))
200                 return 1;
201
202         t = &sol_tcp_tests[i];
203         if (!t->opt)
204                 return 1;
205
206         ctx = lc->ctx;
207         sk = lc->sk;
208
209         if (t->opt == TCP_CONGESTION) {
210                 char old_cc[16], tmp_cc[16];
211                 const char *new_cc;
212                 int new_cc_len;
213
214                 if (!bpf_setsockopt(ctx, IPPROTO_TCP, TCP_CONGESTION,
215                                     (void *)not_exist_cc, sizeof(not_exist_cc)))
216                         return 1;
217                 if (bpf_getsockopt(ctx, IPPROTO_TCP, TCP_CONGESTION, old_cc, sizeof(old_cc)))
218                         return 1;
219                 if (!bpf_strncmp(old_cc, sizeof(old_cc), cubic_cc)) {
220                         new_cc = reno_cc;
221                         new_cc_len = sizeof(reno_cc);
222                 } else {
223                         new_cc = cubic_cc;
224                         new_cc_len = sizeof(cubic_cc);
225                 }
226                 if (bpf_setsockopt(ctx, IPPROTO_TCP, TCP_CONGESTION, (void *)new_cc,
227                                    new_cc_len))
228                         return 1;
229                 if (bpf_getsockopt(ctx, IPPROTO_TCP, TCP_CONGESTION, tmp_cc, sizeof(tmp_cc)))
230                         return 1;
231                 if (bpf_strncmp(tmp_cc, sizeof(tmp_cc), new_cc))
232                         return 1;
233                 if (bpf_setsockopt(ctx, IPPROTO_TCP, TCP_CONGESTION, old_cc, sizeof(old_cc)))
234                         return 1;
235                 return 0;
236         }
237
238         if (t->flip)
239                 return bpf_test_sockopt_flip(ctx, sk, t, IPPROTO_TCP);
240
241         return bpf_test_sockopt_int(ctx, sk, t, IPPROTO_TCP);
242 }
243
244 static int bpf_test_sockopt(void *ctx, struct sock *sk)
245 {
246         struct loop_ctx lc = { .ctx = ctx, .sk = sk, };
247         __u16 family, proto;
248         int n;
249
250         family = sk->sk_family;
251         proto = sk->sk_protocol;
252
253         n = bpf_loop(ARRAY_SIZE(sol_socket_tests), bpf_test_socket_sockopt, &lc, 0);
254         if (n != ARRAY_SIZE(sol_socket_tests))
255                 return -1;
256
257         if (proto == IPPROTO_TCP) {
258                 n = bpf_loop(ARRAY_SIZE(sol_tcp_tests), bpf_test_tcp_sockopt, &lc, 0);
259                 if (n != ARRAY_SIZE(sol_tcp_tests))
260                         return -1;
261         }
262
263         if (family == AF_INET) {
264                 n = bpf_loop(ARRAY_SIZE(sol_ip_tests), bpf_test_ip_sockopt, &lc, 0);
265                 if (n != ARRAY_SIZE(sol_ip_tests))
266                         return -1;
267         } else {
268                 n = bpf_loop(ARRAY_SIZE(sol_ipv6_tests), bpf_test_ipv6_sockopt, &lc, 0);
269                 if (n != ARRAY_SIZE(sol_ipv6_tests))
270                         return -1;
271         }
272
273         return 0;
274 }
275
276 static int binddev_test(void *ctx)
277 {
278         const char empty_ifname[] = "";
279         int ifindex, zero = 0;
280
281         if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTODEVICE,
282                            (void *)veth, sizeof(veth)))
283                 return -1;
284         if (bpf_getsockopt(ctx, SOL_SOCKET, SO_BINDTOIFINDEX,
285                            &ifindex, sizeof(int)) ||
286             ifindex != veth_ifindex)
287                 return -1;
288
289         if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTODEVICE,
290                            (void *)empty_ifname, sizeof(empty_ifname)))
291                 return -1;
292         if (bpf_getsockopt(ctx, SOL_SOCKET, SO_BINDTOIFINDEX,
293                            &ifindex, sizeof(int)) ||
294             ifindex != 0)
295                 return -1;
296
297         if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTOIFINDEX,
298                            (void *)&veth_ifindex, sizeof(int)))
299                 return -1;
300         if (bpf_getsockopt(ctx, SOL_SOCKET, SO_BINDTOIFINDEX,
301                            &ifindex, sizeof(int)) ||
302             ifindex != veth_ifindex)
303                 return -1;
304
305         if (bpf_setsockopt(ctx, SOL_SOCKET, SO_BINDTOIFINDEX,
306                            &zero, sizeof(int)))
307                 return -1;
308         if (bpf_getsockopt(ctx, SOL_SOCKET, SO_BINDTOIFINDEX,
309                            &ifindex, sizeof(int)) ||
310             ifindex != 0)
311                 return -1;
312
313         return 0;
314 }
315
316 static int test_tcp_maxseg(void *ctx, struct sock *sk)
317 {
318         int val = 1314, tmp;
319
320         if (sk->sk_state != TCP_ESTABLISHED)
321                 return bpf_setsockopt(ctx, IPPROTO_TCP, TCP_MAXSEG,
322                                       &val, sizeof(val));
323
324         if (bpf_getsockopt(ctx, IPPROTO_TCP, TCP_MAXSEG, &tmp, sizeof(tmp)) ||
325             tmp > val)
326                 return -1;
327
328         return 0;
329 }
330
331 static int test_tcp_saved_syn(void *ctx, struct sock *sk)
332 {
333         __u8 saved_syn[20];
334         int one = 1;
335
336         if (sk->sk_state == TCP_LISTEN)
337                 return bpf_setsockopt(ctx, IPPROTO_TCP, TCP_SAVE_SYN,
338                                       &one, sizeof(one));
339
340         return bpf_getsockopt(ctx, IPPROTO_TCP, TCP_SAVED_SYN,
341                               saved_syn, sizeof(saved_syn));
342 }
343
344 SEC("lsm_cgroup/socket_post_create")
345 int BPF_PROG(socket_post_create, struct socket *sock, int family,
346              int type, int protocol, int kern)
347 {
348         struct sock *sk = sock->sk;
349
350         if (!sk)
351                 return 1;
352
353         nr_socket_post_create += !bpf_test_sockopt(sk, sk);
354         nr_binddev += !binddev_test(sk);
355
356         return 1;
357 }
358
359 SEC("sockops")
360 int skops_sockopt(struct bpf_sock_ops *skops)
361 {
362         struct bpf_sock *bpf_sk = skops->sk;
363         struct sock *sk;
364
365         if (!bpf_sk)
366                 return 1;
367
368         sk = (struct sock *)bpf_skc_to_tcp_sock(bpf_sk);
369         if (!sk)
370                 return 1;
371
372         switch (skops->op) {
373         case BPF_SOCK_OPS_TCP_LISTEN_CB:
374                 nr_listen += !(bpf_test_sockopt(skops, sk) ||
375                                test_tcp_maxseg(skops, sk) ||
376                                test_tcp_saved_syn(skops, sk));
377                 break;
378         case BPF_SOCK_OPS_TCP_CONNECT_CB:
379                 nr_connect += !(bpf_test_sockopt(skops, sk) ||
380                                 test_tcp_maxseg(skops, sk));
381                 break;
382         case BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB:
383                 nr_active += !(bpf_test_sockopt(skops, sk) ||
384                                test_tcp_maxseg(skops, sk));
385                 break;
386         case BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB:
387                 nr_passive += !(bpf_test_sockopt(skops, sk) ||
388                                 test_tcp_maxseg(skops, sk) ||
389                                 test_tcp_saved_syn(skops, sk));
390                 bpf_sock_ops_cb_flags_set(skops,
391                                           skops->bpf_sock_ops_cb_flags |
392                                           BPF_SOCK_OPS_STATE_CB_FLAG);
393                 break;
394         case BPF_SOCK_OPS_STATE_CB:
395                 if (skops->args[1] == BPF_TCP_CLOSE_WAIT)
396                         nr_fin_wait1 += !bpf_test_sockopt(skops, sk);
397                 break;
398         }
399
400         return 1;
401 }
402
403 char _license[] SEC("license") = "GPL";