selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / test_xdp.c
1 /* Copyright (c) 2016,2017 Facebook
2  *
3  * This program is free software; you can redistribute it and/or
4  * modify it under the terms of version 2 of the GNU General Public
5  * License as published by the Free Software Foundation.
6  */
7 #include <stddef.h>
8 #include <string.h>
9 #include <linux/bpf.h>
10 #include <linux/if_ether.h>
11 #include <linux/if_packet.h>
12 #include <linux/ip.h>
13 #include <linux/ipv6.h>
14 #include <linux/in.h>
15 #include <linux/udp.h>
16 #include <linux/tcp.h>
17 #include <linux/pkt_cls.h>
18 #include <sys/socket.h>
19 #include <bpf/bpf_helpers.h>
20 #include <bpf/bpf_endian.h>
21 #include "test_iptunnel_common.h"
22
23 struct {
24         __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
25         __uint(max_entries, 256);
26         __type(key, __u32);
27         __type(value, __u64);
28 } rxcnt SEC(".maps");
29
30 struct {
31         __uint(type, BPF_MAP_TYPE_HASH);
32         __uint(max_entries, MAX_IPTNL_ENTRIES);
33         __type(key, struct vip);
34         __type(value, struct iptnl_info);
35 } vip2tnl SEC(".maps");
36
37 static __always_inline void count_tx(__u32 protocol)
38 {
39         __u64 *rxcnt_count;
40
41         rxcnt_count = bpf_map_lookup_elem(&rxcnt, &protocol);
42         if (rxcnt_count)
43                 *rxcnt_count += 1;
44 }
45
/*
 * Extract the destination port from a transport header.
 *
 * @trans_data: start of the transport (L4) header
 * @data_end:   first byte past the end of the packet data
 * @protocol:   L4 protocol number (IPPROTO_*)
 *
 * Returns the raw 'dest' field of the TCP or UDP header (no byte-order
 * conversion is applied), 0 for any other protocol, or -1 when the TCP/UDP
 * header does not fit completely before @data_end.
 */
static __always_inline int get_dport(void *trans_data, void *data_end,
                                     __u8 protocol)
{
        if (protocol == IPPROTO_TCP) {
                struct tcphdr *tcp = trans_data;

                if ((void *)(tcp + 1) > data_end)
                        return -1;
                return tcp->dest;
        }

        if (protocol == IPPROTO_UDP) {
                struct udphdr *udp = trans_data;

                if ((void *)(udp + 1) > data_end)
                        return -1;
                return udp->dest;
        }

        /* Protocols without a port field are keyed with port 0. */
        return 0;
}
67
68 static __always_inline void set_ethhdr(struct ethhdr *new_eth,
69                                        const struct ethhdr *old_eth,
70                                        const struct iptnl_info *tnl,
71                                        __be16 h_proto)
72 {
73         memcpy(new_eth->h_source, old_eth->h_dest, sizeof(new_eth->h_source));
74         memcpy(new_eth->h_dest, tnl->dmac, sizeof(new_eth->h_dest));
75         new_eth->h_proto = h_proto;
76 }
77
78 static __always_inline int handle_ipv4(struct xdp_md *xdp)
79 {
80         void *data_end = (void *)(long)xdp->data_end;
81         void *data = (void *)(long)xdp->data;
82         struct iptnl_info *tnl;
83         struct ethhdr *new_eth;
84         struct ethhdr *old_eth;
85         struct iphdr *iph = data + sizeof(struct ethhdr);
86         __u16 *next_iph;
87         __u16 payload_len;
88         struct vip vip = {};
89         int dport;
90         __u32 csum = 0;
91         int i;
92
93         if (iph + 1 > data_end)
94                 return XDP_DROP;
95
96         dport = get_dport(iph + 1, data_end, iph->protocol);
97         if (dport == -1)
98                 return XDP_DROP;
99
100         vip.protocol = iph->protocol;
101         vip.family = AF_INET;
102         vip.daddr.v4 = iph->daddr;
103         vip.dport = dport;
104         payload_len = bpf_ntohs(iph->tot_len);
105
106         tnl = bpf_map_lookup_elem(&vip2tnl, &vip);
107         /* It only does v4-in-v4 */
108         if (!tnl || tnl->family != AF_INET)
109                 return XDP_PASS;
110
111         if (bpf_xdp_adjust_head(xdp, 0 - (int)sizeof(struct iphdr)))
112                 return XDP_DROP;
113
114         data = (void *)(long)xdp->data;
115         data_end = (void *)(long)xdp->data_end;
116
117         new_eth = data;
118         iph = data + sizeof(*new_eth);
119         old_eth = data + sizeof(*iph);
120
121         if (new_eth + 1 > data_end ||
122             old_eth + 1 > data_end ||
123             iph + 1 > data_end)
124                 return XDP_DROP;
125
126         set_ethhdr(new_eth, old_eth, tnl, bpf_htons(ETH_P_IP));
127
128         iph->version = 4;
129         iph->ihl = sizeof(*iph) >> 2;
130         iph->frag_off = 0;
131         iph->protocol = IPPROTO_IPIP;
132         iph->check = 0;
133         iph->tos = 0;
134         iph->tot_len = bpf_htons(payload_len + sizeof(*iph));
135         iph->daddr = tnl->daddr.v4;
136         iph->saddr = tnl->saddr.v4;
137         iph->ttl = 8;
138
139         next_iph = (__u16 *)iph;
140 #pragma clang loop unroll(full)
141         for (i = 0; i < sizeof(*iph) >> 1; i++)
142                 csum += *next_iph++;
143
144         iph->check = ~((csum & 0xffff) + (csum >> 16));
145
146         count_tx(vip.protocol);
147
148         return XDP_TX;
149 }
150
151 static __always_inline int handle_ipv6(struct xdp_md *xdp)
152 {
153         void *data_end = (void *)(long)xdp->data_end;
154         void *data = (void *)(long)xdp->data;
155         struct iptnl_info *tnl;
156         struct ethhdr *new_eth;
157         struct ethhdr *old_eth;
158         struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
159         __u16 payload_len;
160         struct vip vip = {};
161         int dport;
162
163         if (ip6h + 1 > data_end)
164                 return XDP_DROP;
165
166         dport = get_dport(ip6h + 1, data_end, ip6h->nexthdr);
167         if (dport == -1)
168                 return XDP_DROP;
169
170         vip.protocol = ip6h->nexthdr;
171         vip.family = AF_INET6;
172         memcpy(vip.daddr.v6, ip6h->daddr.s6_addr32, sizeof(vip.daddr));
173         vip.dport = dport;
174         payload_len = ip6h->payload_len;
175
176         tnl = bpf_map_lookup_elem(&vip2tnl, &vip);
177         /* It only does v6-in-v6 */
178         if (!tnl || tnl->family != AF_INET6)
179                 return XDP_PASS;
180
181         if (bpf_xdp_adjust_head(xdp, 0 - (int)sizeof(struct ipv6hdr)))
182                 return XDP_DROP;
183
184         data = (void *)(long)xdp->data;
185         data_end = (void *)(long)xdp->data_end;
186
187         new_eth = data;
188         ip6h = data + sizeof(*new_eth);
189         old_eth = data + sizeof(*ip6h);
190
191         if (new_eth + 1 > data_end || old_eth + 1 > data_end ||
192             ip6h + 1 > data_end)
193                 return XDP_DROP;
194
195         set_ethhdr(new_eth, old_eth, tnl, bpf_htons(ETH_P_IPV6));
196
197         ip6h->version = 6;
198         ip6h->priority = 0;
199         memset(ip6h->flow_lbl, 0, sizeof(ip6h->flow_lbl));
200         ip6h->payload_len = bpf_htons(bpf_ntohs(payload_len) + sizeof(*ip6h));
201         ip6h->nexthdr = IPPROTO_IPV6;
202         ip6h->hop_limit = 8;
203         memcpy(ip6h->saddr.s6_addr32, tnl->saddr.v6, sizeof(tnl->saddr.v6));
204         memcpy(ip6h->daddr.s6_addr32, tnl->daddr.v6, sizeof(tnl->daddr.v6));
205
206         count_tx(vip.protocol);
207
208         return XDP_TX;
209 }
210
211 SEC("xdp")
212 int _xdp_tx_iptunnel(struct xdp_md *xdp)
213 {
214         void *data_end = (void *)(long)xdp->data_end;
215         void *data = (void *)(long)xdp->data;
216         struct ethhdr *eth = data;
217         __u16 h_proto;
218
219         if (eth + 1 > data_end)
220                 return XDP_DROP;
221
222         h_proto = eth->h_proto;
223
224         if (h_proto == bpf_htons(ETH_P_IP))
225                 return handle_ipv4(xdp);
226         else if (h_proto == bpf_htons(ETH_P_IPV6))
227
228                 return handle_ipv6(xdp);
229         else
230                 return XDP_DROP;
231 }
232
233 char _license[] SEC("license") = "GPL";