selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / test_l4lb_noinline.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2017 Facebook
3 #include <stddef.h>
4 #include <stdbool.h>
5 #include <string.h>
6 #include <linux/pkt_cls.h>
7 #include <linux/bpf.h>
8 #include <linux/in.h>
9 #include <linux/if_ether.h>
10 #include <linux/ip.h>
11 #include <linux/ipv6.h>
12 #include <linux/icmp.h>
13 #include <linux/icmpv6.h>
14 #include <linux/tcp.h>
15 #include <linux/udp.h>
16 #include <bpf/bpf_helpers.h>
17 #include "test_iptunnel_common.h"
18 #include <bpf/bpf_endian.h>
19
20 static __always_inline __u32 rol32(__u32 word, unsigned int shift)
21 {
22         return (word << shift) | (word >> ((-shift) & 31));
23 }
24
25 /* copy paste of jhash from kernel sources to make sure llvm
26  * can compile it into valid sequence of bpf instructions
27  */
28 #define __jhash_mix(a, b, c)                    \
29 {                                               \
30         a -= c;  a ^= rol32(c, 4);  c += b;     \
31         b -= a;  b ^= rol32(a, 6);  a += c;     \
32         c -= b;  c ^= rol32(b, 8);  b += a;     \
33         a -= c;  a ^= rol32(c, 16); c += b;     \
34         b -= a;  b ^= rol32(a, 19); a += c;     \
35         c -= b;  c ^= rol32(b, 4);  b += a;     \
36 }
37
38 #define __jhash_final(a, b, c)                  \
39 {                                               \
40         c ^= b; c -= rol32(b, 14);              \
41         a ^= c; a -= rol32(c, 11);              \
42         b ^= a; b -= rol32(a, 25);              \
43         c ^= b; c -= rol32(b, 16);              \
44         a ^= c; a -= rol32(c, 4);               \
45         b ^= a; b -= rol32(a, 14);              \
46         c ^= b; c -= rol32(b, 24);              \
47 }
48
49 #define JHASH_INITVAL           0xdeadbeef
50
51 typedef unsigned int u32;
52
53 static __noinline u32 jhash(const void *key, u32 length, u32 initval)
54 {
55         u32 a, b, c;
56         const unsigned char *k = key;
57
58         a = b = c = JHASH_INITVAL + length + initval;
59
60         while (length > 12) {
61                 a += *(u32 *)(k);
62                 b += *(u32 *)(k + 4);
63                 c += *(u32 *)(k + 8);
64                 __jhash_mix(a, b, c);
65                 length -= 12;
66                 k += 12;
67         }
68         switch (length) {
69         case 12: c += (u32)k[11]<<24;
70         case 11: c += (u32)k[10]<<16;
71         case 10: c += (u32)k[9]<<8;
72         case 9:  c += k[8];
73         case 8:  b += (u32)k[7]<<24;
74         case 7:  b += (u32)k[6]<<16;
75         case 6:  b += (u32)k[5]<<8;
76         case 5:  b += k[4];
77         case 4:  a += (u32)k[3]<<24;
78         case 3:  a += (u32)k[2]<<16;
79         case 2:  a += (u32)k[1]<<8;
80         case 1:  a += k[0];
81                  __jhash_final(a, b, c);
82         case 0: /* Nothing left to add */
83                 break;
84         }
85
86         return c;
87 }
88
89 static __noinline u32 __jhash_nwords(u32 a, u32 b, u32 c, u32 initval)
90 {
91         a += initval;
92         b += initval;
93         c += initval;
94         __jhash_final(a, b, c);
95         return c;
96 }
97
98 static __noinline u32 jhash_2words(u32 a, u32 b, u32 initval)
99 {
100         return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2));
101 }
102
103 #define PCKT_FRAGMENTED 65343
104 #define IPV4_HDR_LEN_NO_OPT 20
105 #define IPV4_PLUS_ICMP_HDR 28
106 #define IPV6_PLUS_ICMP_HDR 48
107 #define RING_SIZE 2
108 #define MAX_VIPS 12
109 #define MAX_REALS 5
110 #define CTL_MAP_SIZE 16
111 #define CH_RINGS_SIZE (MAX_VIPS * RING_SIZE)
112 #define F_IPV6 (1 << 0)
113 #define F_HASH_NO_SRC_PORT (1 << 0)
114 #define F_ICMP (1 << 0)
115 #define F_SYN_SET (1 << 1)
116
117 struct packet_description {
118         union {
119                 __be32 src;
120                 __be32 srcv6[4];
121         };
122         union {
123                 __be32 dst;
124                 __be32 dstv6[4];
125         };
126         union {
127                 __u32 ports;
128                 __u16 port16[2];
129         };
130         __u8 proto;
131         __u8 flags;
132 };
133
134 struct ctl_value {
135         union {
136                 __u64 value;
137                 __u32 ifindex;
138                 __u8 mac[6];
139         };
140 };
141
142 struct vip_meta {
143         __u32 flags;
144         __u32 vip_num;
145 };
146
147 struct real_definition {
148         union {
149                 __be32 dst;
150                 __be32 dstv6[4];
151         };
152         __u8 flags;
153 };
154
155 struct vip_stats {
156         __u64 bytes;
157         __u64 pkts;
158 };
159
160 struct eth_hdr {
161         unsigned char eth_dest[ETH_ALEN];
162         unsigned char eth_source[ETH_ALEN];
163         unsigned short eth_proto;
164 };
165
166 struct {
167         __uint(type, BPF_MAP_TYPE_HASH);
168         __uint(max_entries, MAX_VIPS);
169         __type(key, struct vip);
170         __type(value, struct vip_meta);
171 } vip_map SEC(".maps");
172
173 struct {
174         __uint(type, BPF_MAP_TYPE_ARRAY);
175         __uint(max_entries, CH_RINGS_SIZE);
176         __type(key, __u32);
177         __type(value, __u32);
178 } ch_rings SEC(".maps");
179
180 struct {
181         __uint(type, BPF_MAP_TYPE_ARRAY);
182         __uint(max_entries, MAX_REALS);
183         __type(key, __u32);
184         __type(value, struct real_definition);
185 } reals SEC(".maps");
186
187 struct {
188         __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
189         __uint(max_entries, MAX_VIPS);
190         __type(key, __u32);
191         __type(value, struct vip_stats);
192 } stats SEC(".maps");
193
194 struct {
195         __uint(type, BPF_MAP_TYPE_ARRAY);
196         __uint(max_entries, CTL_MAP_SIZE);
197         __type(key, __u32);
198         __type(value, struct ctl_value);
199 } ctl_array SEC(".maps");
200
201 static __noinline __u32 get_packet_hash(struct packet_description *pckt, bool ipv6)
202 {
203         if (ipv6)
204                 return jhash_2words(jhash(pckt->srcv6, 16, MAX_VIPS),
205                                     pckt->ports, CH_RINGS_SIZE);
206         else
207                 return jhash_2words(pckt->src, pckt->ports, CH_RINGS_SIZE);
208 }
209
210 static __noinline bool get_packet_dst(struct real_definition **real,
211                                       struct packet_description *pckt,
212                                       struct vip_meta *vip_info,
213                                       bool is_ipv6)
214 {
215         __u32 hash = get_packet_hash(pckt, is_ipv6);
216         __u32 key = RING_SIZE * vip_info->vip_num + hash % RING_SIZE;
217         __u32 *real_pos;
218
219         if (hash != 0x358459b7 /* jhash of ipv4 packet */  &&
220             hash != 0x2f4bc6bb /* jhash of ipv6 packet */)
221                 return false;
222
223         real_pos = bpf_map_lookup_elem(&ch_rings, &key);
224         if (!real_pos)
225                 return false;
226         key = *real_pos;
227         *real = bpf_map_lookup_elem(&reals, &key);
228         if (!(*real))
229                 return false;
230         return true;
231 }
232
233 static __noinline int parse_icmpv6(void *data, void *data_end, __u64 off,
234                                    struct packet_description *pckt)
235 {
236         struct icmp6hdr *icmp_hdr;
237         struct ipv6hdr *ip6h;
238
239         icmp_hdr = data + off;
240         if (icmp_hdr + 1 > data_end)
241                 return TC_ACT_SHOT;
242         if (icmp_hdr->icmp6_type != ICMPV6_PKT_TOOBIG)
243                 return TC_ACT_OK;
244         off += sizeof(struct icmp6hdr);
245         ip6h = data + off;
246         if (ip6h + 1 > data_end)
247                 return TC_ACT_SHOT;
248         pckt->proto = ip6h->nexthdr;
249         pckt->flags |= F_ICMP;
250         memcpy(pckt->srcv6, ip6h->daddr.s6_addr32, 16);
251         memcpy(pckt->dstv6, ip6h->saddr.s6_addr32, 16);
252         return TC_ACT_UNSPEC;
253 }
254
255 static __noinline int parse_icmp(void *data, void *data_end, __u64 off,
256                                  struct packet_description *pckt)
257 {
258         struct icmphdr *icmp_hdr;
259         struct iphdr *iph;
260
261         icmp_hdr = data + off;
262         if (icmp_hdr + 1 > data_end)
263                 return TC_ACT_SHOT;
264         if (icmp_hdr->type != ICMP_DEST_UNREACH ||
265             icmp_hdr->code != ICMP_FRAG_NEEDED)
266                 return TC_ACT_OK;
267         off += sizeof(struct icmphdr);
268         iph = data + off;
269         if (iph + 1 > data_end)
270                 return TC_ACT_SHOT;
271         if (iph->ihl != 5)
272                 return TC_ACT_SHOT;
273         pckt->proto = iph->protocol;
274         pckt->flags |= F_ICMP;
275         pckt->src = iph->daddr;
276         pckt->dst = iph->saddr;
277         return TC_ACT_UNSPEC;
278 }
279
280 static __noinline bool parse_udp(void *data, __u64 off, void *data_end,
281                                  struct packet_description *pckt)
282 {
283         struct udphdr *udp;
284         udp = data + off;
285
286         if (udp + 1 > data_end)
287                 return false;
288
289         if (!(pckt->flags & F_ICMP)) {
290                 pckt->port16[0] = udp->source;
291                 pckt->port16[1] = udp->dest;
292         } else {
293                 pckt->port16[0] = udp->dest;
294                 pckt->port16[1] = udp->source;
295         }
296         return true;
297 }
298
299 static __noinline bool parse_tcp(void *data, __u64 off, void *data_end,
300                                  struct packet_description *pckt)
301 {
302         struct tcphdr *tcp;
303
304         tcp = data + off;
305         if (tcp + 1 > data_end)
306                 return false;
307
308         if (tcp->syn)
309                 pckt->flags |= F_SYN_SET;
310
311         if (!(pckt->flags & F_ICMP)) {
312                 pckt->port16[0] = tcp->source;
313                 pckt->port16[1] = tcp->dest;
314         } else {
315                 pckt->port16[0] = tcp->dest;
316                 pckt->port16[1] = tcp->source;
317         }
318         return true;
319 }
320
321 static __noinline int process_packet(void *data, __u64 off, void *data_end,
322                                      bool is_ipv6, struct __sk_buff *skb)
323 {
324         void *pkt_start = (void *)(long)skb->data;
325         struct packet_description pckt = {};
326         struct eth_hdr *eth = pkt_start;
327         struct bpf_tunnel_key tkey = {};
328         struct vip_stats *data_stats;
329         struct real_definition *dst;
330         struct vip_meta *vip_info;
331         struct ctl_value *cval;
332         __u32 v4_intf_pos = 1;
333         __u32 v6_intf_pos = 2;
334         struct ipv6hdr *ip6h;
335         struct vip vip = {};
336         struct iphdr *iph;
337         int tun_flag = 0;
338         __u16 pkt_bytes;
339         __u64 iph_len;
340         __u32 ifindex;
341         __u8 protocol;
342         __u32 vip_num;
343         int action;
344
345         tkey.tunnel_ttl = 64;
346         if (is_ipv6) {
347                 ip6h = data + off;
348                 if (ip6h + 1 > data_end)
349                         return TC_ACT_SHOT;
350
351                 iph_len = sizeof(struct ipv6hdr);
352                 protocol = ip6h->nexthdr;
353                 pckt.proto = protocol;
354                 pkt_bytes = bpf_ntohs(ip6h->payload_len);
355                 off += iph_len;
356                 if (protocol == IPPROTO_FRAGMENT) {
357                         return TC_ACT_SHOT;
358                 } else if (protocol == IPPROTO_ICMPV6) {
359                         action = parse_icmpv6(data, data_end, off, &pckt);
360                         if (action >= 0)
361                                 return action;
362                         off += IPV6_PLUS_ICMP_HDR;
363                 } else {
364                         memcpy(pckt.srcv6, ip6h->saddr.s6_addr32, 16);
365                         memcpy(pckt.dstv6, ip6h->daddr.s6_addr32, 16);
366                 }
367         } else {
368                 iph = data + off;
369                 if (iph + 1 > data_end)
370                         return TC_ACT_SHOT;
371                 if (iph->ihl != 5)
372                         return TC_ACT_SHOT;
373
374                 protocol = iph->protocol;
375                 pckt.proto = protocol;
376                 pkt_bytes = bpf_ntohs(iph->tot_len);
377                 off += IPV4_HDR_LEN_NO_OPT;
378
379                 if (iph->frag_off & PCKT_FRAGMENTED)
380                         return TC_ACT_SHOT;
381                 if (protocol == IPPROTO_ICMP) {
382                         action = parse_icmp(data, data_end, off, &pckt);
383                         if (action >= 0)
384                                 return action;
385                         off += IPV4_PLUS_ICMP_HDR;
386                 } else {
387                         pckt.src = iph->saddr;
388                         pckt.dst = iph->daddr;
389                 }
390         }
391         protocol = pckt.proto;
392
393         if (protocol == IPPROTO_TCP) {
394                 if (!parse_tcp(data, off, data_end, &pckt))
395                         return TC_ACT_SHOT;
396         } else if (protocol == IPPROTO_UDP) {
397                 if (!parse_udp(data, off, data_end, &pckt))
398                         return TC_ACT_SHOT;
399         } else {
400                 return TC_ACT_SHOT;
401         }
402
403         if (is_ipv6)
404                 memcpy(vip.daddr.v6, pckt.dstv6, 16);
405         else
406                 vip.daddr.v4 = pckt.dst;
407
408         vip.dport = pckt.port16[1];
409         vip.protocol = pckt.proto;
410         vip_info = bpf_map_lookup_elem(&vip_map, &vip);
411         if (!vip_info) {
412                 vip.dport = 0;
413                 vip_info = bpf_map_lookup_elem(&vip_map, &vip);
414                 if (!vip_info)
415                         return TC_ACT_SHOT;
416                 pckt.port16[1] = 0;
417         }
418
419         if (vip_info->flags & F_HASH_NO_SRC_PORT)
420                 pckt.port16[0] = 0;
421
422         if (!get_packet_dst(&dst, &pckt, vip_info, is_ipv6))
423                 return TC_ACT_SHOT;
424
425         if (dst->flags & F_IPV6) {
426                 cval = bpf_map_lookup_elem(&ctl_array, &v6_intf_pos);
427                 if (!cval)
428                         return TC_ACT_SHOT;
429                 ifindex = cval->ifindex;
430                 memcpy(tkey.remote_ipv6, dst->dstv6, 16);
431                 tun_flag = BPF_F_TUNINFO_IPV6;
432         } else {
433                 cval = bpf_map_lookup_elem(&ctl_array, &v4_intf_pos);
434                 if (!cval)
435                         return TC_ACT_SHOT;
436                 ifindex = cval->ifindex;
437                 tkey.remote_ipv4 = dst->dst;
438         }
439         vip_num = vip_info->vip_num;
440         data_stats = bpf_map_lookup_elem(&stats, &vip_num);
441         if (!data_stats)
442                 return TC_ACT_SHOT;
443         data_stats->pkts++;
444         data_stats->bytes += pkt_bytes;
445         bpf_skb_set_tunnel_key(skb, &tkey, sizeof(tkey), tun_flag);
446         *(u32 *)eth->eth_dest = tkey.remote_ipv4;
447         return bpf_redirect(ifindex, 0);
448 }
449
450 SEC("tc")
451 int balancer_ingress(struct __sk_buff *ctx)
452 {
453         void *data_end = (void *)(long)ctx->data_end;
454         void *data = (void *)(long)ctx->data;
455         struct eth_hdr *eth = data;
456         __u32 eth_proto;
457         __u32 nh_off;
458
459         nh_off = sizeof(struct eth_hdr);
460         if (data + nh_off > data_end)
461                 return TC_ACT_SHOT;
462         eth_proto = eth->eth_proto;
463         if (eth_proto == bpf_htons(ETH_P_IP))
464                 return process_packet(data, nh_off, data_end, false, ctx);
465         else if (eth_proto == bpf_htons(ETH_P_IPV6))
466                 return process_packet(data, nh_off, data_end, true, ctx);
467         else
468                 return TC_ACT_SHOT;
469 }
470 char _license[] SEC("license") = "GPL";