selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / bpf_iter_setsockopt.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021 Facebook */
3 #include "bpf_iter.h"
4 #include "bpf_tracing_net.h"
5 #include <bpf/bpf_helpers.h>
6 #include <bpf/bpf_endian.h>
7
/*
 * bpf_tcp_sk(skc) - view a struct sock_common pointer as a TCP socket.
 *
 * Statement-expression macro. It assigns the caller's locals `tp` and `sk`,
 * so both must be in scope at the expansion site:
 *   - when @skc is NULL, tp and sk are both set to NULL;
 *   - otherwise tp = bpf_skc_to_tcp_sock(@skc) and sk is that same pointer
 *     cast to struct sock *.
 * The whole expression evaluates to tp. @skc is evaluated exactly once.
 */
#define bpf_tcp_sk(skc) ({						\
	struct sock_common *_sk_common = (skc);				\
	tp = _sk_common ? bpf_skc_to_tcp_sock(_sk_common) : NULL;	\
	sk = (struct sock *)tp;						\
	tp;								\
})
18
19 unsigned short reuse_listen_hport = 0;
20 unsigned short listen_hport = 0;
21 char cubic_cc[TCP_CA_NAME_MAX] = "bpf_cubic";
22 char dctcp_cc[TCP_CA_NAME_MAX] = "bpf_dctcp";
23 bool random_retry = false;
24
25 static bool tcp_cc_eq(const char *a, const char *b)
26 {
27         int i;
28
29         for (i = 0; i < TCP_CA_NAME_MAX; i++) {
30                 if (a[i] != b[i])
31                         return false;
32                 if (!a[i])
33                         break;
34         }
35
36         return true;
37 }
38
39 SEC("iter/tcp")
40 int change_tcp_cc(struct bpf_iter__tcp *ctx)
41 {
42         char cur_cc[TCP_CA_NAME_MAX];
43         struct tcp_sock *tp;
44         struct sock *sk;
45
46         if (!bpf_tcp_sk(ctx->sk_common))
47                 return 0;
48
49         if (sk->sk_family != AF_INET6 ||
50             (sk->sk_state != TCP_LISTEN &&
51              sk->sk_state != TCP_ESTABLISHED) ||
52             (sk->sk_num != reuse_listen_hport &&
53              sk->sk_num != listen_hport &&
54              bpf_ntohs(sk->sk_dport) != listen_hport))
55                 return 0;
56
57         if (bpf_getsockopt(tp, SOL_TCP, TCP_CONGESTION,
58                            cur_cc, sizeof(cur_cc)))
59                 return 0;
60
61         if (!tcp_cc_eq(cur_cc, cubic_cc))
62                 return 0;
63
64         if (random_retry && bpf_get_prandom_u32() % 4 == 1)
65                 return 1;
66
67         bpf_setsockopt(tp, SOL_TCP, TCP_CONGESTION, dctcp_cc, sizeof(dctcp_cc));
68         return 0;
69 }
70
71 char _license[] SEC("license") = "GPL";