selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / xdping_kern.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. */
3
4 #define KBUILD_MODNAME "foo"
5 #include <stddef.h>
6 #include <string.h>
7 #include <linux/bpf.h>
8 #include <linux/icmp.h>
9 #include <linux/in.h>
10 #include <linux/if_ether.h>
11 #include <linux/if_packet.h>
12 #include <linux/if_vlan.h>
13 #include <linux/ip.h>
14
15 #include <bpf/bpf_helpers.h>
16 #include <bpf/bpf_endian.h>
17
18 #include "xdping.h"
19
20 struct {
21         __uint(type, BPF_MAP_TYPE_HASH);
22         __uint(max_entries, 256);
23         __type(key, __u32);
24         __type(value, struct pinginfo);
25 } ping_map SEC(".maps");
26
27 static __always_inline void swap_src_dst_mac(void *data)
28 {
29         unsigned short *p = data;
30         unsigned short dst[3];
31
32         dst[0] = p[0];
33         dst[1] = p[1];
34         dst[2] = p[2];
35         p[0] = p[3];
36         p[1] = p[4];
37         p[2] = p[5];
38         p[3] = dst[0];
39         p[4] = dst[1];
40         p[5] = dst[2];
41 }
42
43 static __always_inline __u16 csum_fold_helper(__wsum sum)
44 {
45         sum = (sum & 0xffff) + (sum >> 16);
46         return ~((sum & 0xffff) + (sum >> 16));
47 }
48
49 static __always_inline __u16 ipv4_csum(void *data_start, int data_size)
50 {
51         __wsum sum;
52
53         sum = bpf_csum_diff(0, 0, data_start, data_size, 0);
54         return csum_fold_helper(sum);
55 }
56
57 #define ICMP_ECHO_LEN           64
58
59 static __always_inline int icmp_check(struct xdp_md *ctx, int type)
60 {
61         void *data_end = (void *)(long)ctx->data_end;
62         void *data = (void *)(long)ctx->data;
63         struct ethhdr *eth = data;
64         struct icmphdr *icmph;
65         struct iphdr *iph;
66
67         if (data + sizeof(*eth) + sizeof(*iph) + ICMP_ECHO_LEN > data_end)
68                 return XDP_PASS;
69
70         if (eth->h_proto != bpf_htons(ETH_P_IP))
71                 return XDP_PASS;
72
73         iph = data + sizeof(*eth);
74
75         if (iph->protocol != IPPROTO_ICMP)
76                 return XDP_PASS;
77
78         if (bpf_ntohs(iph->tot_len) - sizeof(*iph) != ICMP_ECHO_LEN)
79                 return XDP_PASS;
80
81         icmph = data + sizeof(*eth) + sizeof(*iph);
82
83         if (icmph->type != type)
84                 return XDP_PASS;
85
86         return XDP_TX;
87 }
88
89 SEC("xdp")
90 int xdping_client(struct xdp_md *ctx)
91 {
92         void *data = (void *)(long)ctx->data;
93         struct pinginfo *pinginfo = NULL;
94         struct ethhdr *eth = data;
95         struct icmphdr *icmph;
96         struct iphdr *iph;
97         __u64 recvtime;
98         __be32 raddr;
99         __be16 seq;
100         int ret;
101         __u8 i;
102
103         ret = icmp_check(ctx, ICMP_ECHOREPLY);
104
105         if (ret != XDP_TX)
106                 return ret;
107
108         iph = data + sizeof(*eth);
109         icmph = data + sizeof(*eth) + sizeof(*iph);
110         raddr = iph->saddr;
111
112         /* Record time reply received. */
113         recvtime = bpf_ktime_get_ns();
114         pinginfo = bpf_map_lookup_elem(&ping_map, &raddr);
115         if (!pinginfo || pinginfo->seq != icmph->un.echo.sequence)
116                 return XDP_PASS;
117
118         if (pinginfo->start) {
119 #pragma clang loop unroll(full)
120                 for (i = 0; i < XDPING_MAX_COUNT; i++) {
121                         if (pinginfo->times[i] == 0)
122                                 break;
123                 }
124                 /* verifier is fussy here... */
125                 if (i < XDPING_MAX_COUNT) {
126                         pinginfo->times[i] = recvtime -
127                                              pinginfo->start;
128                         pinginfo->start = 0;
129                         i++;
130                 }
131                 /* No more space for values? */
132                 if (i == pinginfo->count || i == XDPING_MAX_COUNT)
133                         return XDP_PASS;
134         }
135
136         /* Now convert reply back into echo request. */
137         swap_src_dst_mac(data);
138         iph->saddr = iph->daddr;
139         iph->daddr = raddr;
140         icmph->type = ICMP_ECHO;
141         seq = bpf_htons(bpf_ntohs(icmph->un.echo.sequence) + 1);
142         icmph->un.echo.sequence = seq;
143         icmph->checksum = 0;
144         icmph->checksum = ipv4_csum(icmph, ICMP_ECHO_LEN);
145
146         pinginfo->seq = seq;
147         pinginfo->start = bpf_ktime_get_ns();
148
149         return XDP_TX;
150 }
151
152 SEC("xdp")
153 int xdping_server(struct xdp_md *ctx)
154 {
155         void *data = (void *)(long)ctx->data;
156         struct ethhdr *eth = data;
157         struct icmphdr *icmph;
158         struct iphdr *iph;
159         __be32 raddr;
160         int ret;
161
162         ret = icmp_check(ctx, ICMP_ECHO);
163
164         if (ret != XDP_TX)
165                 return ret;
166
167         iph = data + sizeof(*eth);
168         icmph = data + sizeof(*eth) + sizeof(*iph);
169         raddr = iph->saddr;
170
171         /* Now convert request into echo reply. */
172         swap_src_dst_mac(data);
173         iph->saddr = iph->daddr;
174         iph->daddr = raddr;
175         icmph->type = ICMP_ECHOREPLY;
176         icmph->checksum = 0;
177         icmph->checksum = ipv4_csum(icmph, ICMP_ECHO_LEN);
178
179         return XDP_TX;
180 }
181
182 char _license[] SEC("license") = "GPL";