selftests/bpf: Fix erroneous bitmask operation
[platform/kernel/linux-rpi.git] / tools / testing / selftests / bpf / progs / test_l4lb.c
1 /* Copyright (c) 2017 Facebook
2  *
3  * This program is free software; you can redistribute it and/or
4  * modify it under the terms of version 2 of the GNU General Public
5  * License as published by the Free Software Foundation.
6  */
7 #include <stddef.h>
8 #include <stdbool.h>
9 #include <string.h>
10 #include <linux/pkt_cls.h>
11 #include <linux/bpf.h>
12 #include <linux/in.h>
13 #include <linux/if_ether.h>
14 #include <linux/ip.h>
15 #include <linux/ipv6.h>
16 #include <linux/icmp.h>
17 #include <linux/icmpv6.h>
18 #include <linux/tcp.h>
19 #include <linux/udp.h>
20 #include <bpf/bpf_helpers.h>
21 #include "test_iptunnel_common.h"
22 #include <bpf/bpf_endian.h>
23
24 static inline __u32 rol32(__u32 word, unsigned int shift)
25 {
26         return (word << shift) | (word >> ((-shift) & 31));
27 }
28
29 /* copy paste of jhash from kernel sources to make sure llvm
30  * can compile it into valid sequence of bpf instructions
31  */
32 #define __jhash_mix(a, b, c)                    \
33 {                                               \
34         a -= c;  a ^= rol32(c, 4);  c += b;     \
35         b -= a;  b ^= rol32(a, 6);  a += c;     \
36         c -= b;  c ^= rol32(b, 8);  b += a;     \
37         a -= c;  a ^= rol32(c, 16); c += b;     \
38         b -= a;  b ^= rol32(a, 19); a += c;     \
39         c -= b;  c ^= rol32(b, 4);  b += a;     \
40 }
41
42 #define __jhash_final(a, b, c)                  \
43 {                                               \
44         c ^= b; c -= rol32(b, 14);              \
45         a ^= c; a -= rol32(c, 11);              \
46         b ^= a; b -= rol32(a, 25);              \
47         c ^= b; c -= rol32(b, 16);              \
48         a ^= c; a -= rol32(c, 4);               \
49         b ^= a; b -= rol32(a, 14);              \
50         c ^= b; c -= rol32(b, 24);              \
51 }
52
53 #define JHASH_INITVAL           0xdeadbeef
54
55 typedef unsigned int u32;
56
57 static inline u32 jhash(const void *key, u32 length, u32 initval)
58 {
59         u32 a, b, c;
60         const unsigned char *k = key;
61
62         a = b = c = JHASH_INITVAL + length + initval;
63
64         while (length > 12) {
65                 a += *(u32 *)(k);
66                 b += *(u32 *)(k + 4);
67                 c += *(u32 *)(k + 8);
68                 __jhash_mix(a, b, c);
69                 length -= 12;
70                 k += 12;
71         }
72         switch (length) {
73         case 12: c += (u32)k[11]<<24;
74         case 11: c += (u32)k[10]<<16;
75         case 10: c += (u32)k[9]<<8;
76         case 9:  c += k[8];
77         case 8:  b += (u32)k[7]<<24;
78         case 7:  b += (u32)k[6]<<16;
79         case 6:  b += (u32)k[5]<<8;
80         case 5:  b += k[4];
81         case 4:  a += (u32)k[3]<<24;
82         case 3:  a += (u32)k[2]<<16;
83         case 2:  a += (u32)k[1]<<8;
84         case 1:  a += k[0];
85                  __jhash_final(a, b, c);
86         case 0: /* Nothing left to add */
87                 break;
88         }
89
90         return c;
91 }
92
93 static inline u32 __jhash_nwords(u32 a, u32 b, u32 c, u32 initval)
94 {
95         a += initval;
96         b += initval;
97         c += initval;
98         __jhash_final(a, b, c);
99         return c;
100 }
101
102 static inline u32 jhash_2words(u32 a, u32 b, u32 initval)
103 {
104         return __jhash_nwords(a, b, 0, initval + JHASH_INITVAL + (2 << 2));
105 }
106
107 #define PCKT_FRAGMENTED 65343
108 #define IPV4_HDR_LEN_NO_OPT 20
109 #define IPV4_PLUS_ICMP_HDR 28
110 #define IPV6_PLUS_ICMP_HDR 48
111 #define RING_SIZE 2
112 #define MAX_VIPS 12
113 #define MAX_REALS 5
114 #define CTL_MAP_SIZE 16
115 #define CH_RINGS_SIZE (MAX_VIPS * RING_SIZE)
116 #define F_IPV6 (1 << 0)
117 #define F_HASH_NO_SRC_PORT (1 << 0)
118 #define F_ICMP (1 << 0)
119 #define F_SYN_SET (1 << 1)
120
121 struct packet_description {
122         union {
123                 __be32 src;
124                 __be32 srcv6[4];
125         };
126         union {
127                 __be32 dst;
128                 __be32 dstv6[4];
129         };
130         union {
131                 __u32 ports;
132                 __u16 port16[2];
133         };
134         __u8 proto;
135         __u8 flags;
136 };
137
138 struct ctl_value {
139         union {
140                 __u64 value;
141                 __u32 ifindex;
142                 __u8 mac[6];
143         };
144 };
145
146 struct vip_meta {
147         __u32 flags;
148         __u32 vip_num;
149 };
150
151 struct real_definition {
152         union {
153                 __be32 dst;
154                 __be32 dstv6[4];
155         };
156         __u8 flags;
157 };
158
159 struct vip_stats {
160         __u64 bytes;
161         __u64 pkts;
162 };
163
164 struct eth_hdr {
165         unsigned char eth_dest[ETH_ALEN];
166         unsigned char eth_source[ETH_ALEN];
167         unsigned short eth_proto;
168 };
169
170 struct {
171         __uint(type, BPF_MAP_TYPE_HASH);
172         __uint(max_entries, MAX_VIPS);
173         __type(key, struct vip);
174         __type(value, struct vip_meta);
175 } vip_map SEC(".maps");
176
177 struct {
178         __uint(type, BPF_MAP_TYPE_ARRAY);
179         __uint(max_entries, CH_RINGS_SIZE);
180         __type(key, __u32);
181         __type(value, __u32);
182 } ch_rings SEC(".maps");
183
184 struct {
185         __uint(type, BPF_MAP_TYPE_ARRAY);
186         __uint(max_entries, MAX_REALS);
187         __type(key, __u32);
188         __type(value, struct real_definition);
189 } reals SEC(".maps");
190
191 struct {
192         __uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
193         __uint(max_entries, MAX_VIPS);
194         __type(key, __u32);
195         __type(value, struct vip_stats);
196 } stats SEC(".maps");
197
198 struct {
199         __uint(type, BPF_MAP_TYPE_ARRAY);
200         __uint(max_entries, CTL_MAP_SIZE);
201         __type(key, __u32);
202         __type(value, struct ctl_value);
203 } ctl_array SEC(".maps");
204
205 static __always_inline __u32 get_packet_hash(struct packet_description *pckt,
206                                              bool ipv6)
207 {
208         if (ipv6)
209                 return jhash_2words(jhash(pckt->srcv6, 16, MAX_VIPS),
210                                     pckt->ports, CH_RINGS_SIZE);
211         else
212                 return jhash_2words(pckt->src, pckt->ports, CH_RINGS_SIZE);
213 }
214
215 static __always_inline bool get_packet_dst(struct real_definition **real,
216                                            struct packet_description *pckt,
217                                            struct vip_meta *vip_info,
218                                            bool is_ipv6)
219 {
220         __u32 hash = get_packet_hash(pckt, is_ipv6) % RING_SIZE;
221         __u32 key = RING_SIZE * vip_info->vip_num + hash;
222         __u32 *real_pos;
223
224         real_pos = bpf_map_lookup_elem(&ch_rings, &key);
225         if (!real_pos)
226                 return false;
227         key = *real_pos;
228         *real = bpf_map_lookup_elem(&reals, &key);
229         if (!(*real))
230                 return false;
231         return true;
232 }
233
234 static __always_inline int parse_icmpv6(void *data, void *data_end, __u64 off,
235                                         struct packet_description *pckt)
236 {
237         struct icmp6hdr *icmp_hdr;
238         struct ipv6hdr *ip6h;
239
240         icmp_hdr = data + off;
241         if (icmp_hdr + 1 > data_end)
242                 return TC_ACT_SHOT;
243         if (icmp_hdr->icmp6_type != ICMPV6_PKT_TOOBIG)
244                 return TC_ACT_OK;
245         off += sizeof(struct icmp6hdr);
246         ip6h = data + off;
247         if (ip6h + 1 > data_end)
248                 return TC_ACT_SHOT;
249         pckt->proto = ip6h->nexthdr;
250         pckt->flags |= F_ICMP;
251         memcpy(pckt->srcv6, ip6h->daddr.s6_addr32, 16);
252         memcpy(pckt->dstv6, ip6h->saddr.s6_addr32, 16);
253         return TC_ACT_UNSPEC;
254 }
255
256 static __always_inline int parse_icmp(void *data, void *data_end, __u64 off,
257                                       struct packet_description *pckt)
258 {
259         struct icmphdr *icmp_hdr;
260         struct iphdr *iph;
261
262         icmp_hdr = data + off;
263         if (icmp_hdr + 1 > data_end)
264                 return TC_ACT_SHOT;
265         if (icmp_hdr->type != ICMP_DEST_UNREACH ||
266             icmp_hdr->code != ICMP_FRAG_NEEDED)
267                 return TC_ACT_OK;
268         off += sizeof(struct icmphdr);
269         iph = data + off;
270         if (iph + 1 > data_end)
271                 return TC_ACT_SHOT;
272         if (iph->ihl != 5)
273                 return TC_ACT_SHOT;
274         pckt->proto = iph->protocol;
275         pckt->flags |= F_ICMP;
276         pckt->src = iph->daddr;
277         pckt->dst = iph->saddr;
278         return TC_ACT_UNSPEC;
279 }
280
281 static __always_inline bool parse_udp(void *data, __u64 off, void *data_end,
282                                       struct packet_description *pckt)
283 {
284         struct udphdr *udp;
285         udp = data + off;
286
287         if (udp + 1 > data_end)
288                 return false;
289
290         if (!(pckt->flags & F_ICMP)) {
291                 pckt->port16[0] = udp->source;
292                 pckt->port16[1] = udp->dest;
293         } else {
294                 pckt->port16[0] = udp->dest;
295                 pckt->port16[1] = udp->source;
296         }
297         return true;
298 }
299
300 static __always_inline bool parse_tcp(void *data, __u64 off, void *data_end,
301                                       struct packet_description *pckt)
302 {
303         struct tcphdr *tcp;
304
305         tcp = data + off;
306         if (tcp + 1 > data_end)
307                 return false;
308
309         if (tcp->syn)
310                 pckt->flags |= F_SYN_SET;
311
312         if (!(pckt->flags & F_ICMP)) {
313                 pckt->port16[0] = tcp->source;
314                 pckt->port16[1] = tcp->dest;
315         } else {
316                 pckt->port16[0] = tcp->dest;
317                 pckt->port16[1] = tcp->source;
318         }
319         return true;
320 }
321
322 static __always_inline int process_packet(void *data, __u64 off, void *data_end,
323                                           bool is_ipv6, struct __sk_buff *skb)
324 {
325         void *pkt_start = (void *)(long)skb->data;
326         struct packet_description pckt = {};
327         struct eth_hdr *eth = pkt_start;
328         struct bpf_tunnel_key tkey = {};
329         struct vip_stats *data_stats;
330         struct real_definition *dst;
331         struct vip_meta *vip_info;
332         struct ctl_value *cval;
333         __u32 v4_intf_pos = 1;
334         __u32 v6_intf_pos = 2;
335         struct ipv6hdr *ip6h;
336         struct vip vip = {};
337         struct iphdr *iph;
338         int tun_flag = 0;
339         __u16 pkt_bytes;
340         __u64 iph_len;
341         __u32 ifindex;
342         __u8 protocol;
343         __u32 vip_num;
344         int action;
345
346         tkey.tunnel_ttl = 64;
347         if (is_ipv6) {
348                 ip6h = data + off;
349                 if (ip6h + 1 > data_end)
350                         return TC_ACT_SHOT;
351
352                 iph_len = sizeof(struct ipv6hdr);
353                 protocol = ip6h->nexthdr;
354                 pckt.proto = protocol;
355                 pkt_bytes = bpf_ntohs(ip6h->payload_len);
356                 off += iph_len;
357                 if (protocol == IPPROTO_FRAGMENT) {
358                         return TC_ACT_SHOT;
359                 } else if (protocol == IPPROTO_ICMPV6) {
360                         action = parse_icmpv6(data, data_end, off, &pckt);
361                         if (action >= 0)
362                                 return action;
363                         off += IPV6_PLUS_ICMP_HDR;
364                 } else {
365                         memcpy(pckt.srcv6, ip6h->saddr.s6_addr32, 16);
366                         memcpy(pckt.dstv6, ip6h->daddr.s6_addr32, 16);
367                 }
368         } else {
369                 iph = data + off;
370                 if (iph + 1 > data_end)
371                         return TC_ACT_SHOT;
372                 if (iph->ihl != 5)
373                         return TC_ACT_SHOT;
374
375                 protocol = iph->protocol;
376                 pckt.proto = protocol;
377                 pkt_bytes = bpf_ntohs(iph->tot_len);
378                 off += IPV4_HDR_LEN_NO_OPT;
379
380                 if (iph->frag_off & PCKT_FRAGMENTED)
381                         return TC_ACT_SHOT;
382                 if (protocol == IPPROTO_ICMP) {
383                         action = parse_icmp(data, data_end, off, &pckt);
384                         if (action >= 0)
385                                 return action;
386                         off += IPV4_PLUS_ICMP_HDR;
387                 } else {
388                         pckt.src = iph->saddr;
389                         pckt.dst = iph->daddr;
390                 }
391         }
392         protocol = pckt.proto;
393
394         if (protocol == IPPROTO_TCP) {
395                 if (!parse_tcp(data, off, data_end, &pckt))
396                         return TC_ACT_SHOT;
397         } else if (protocol == IPPROTO_UDP) {
398                 if (!parse_udp(data, off, data_end, &pckt))
399                         return TC_ACT_SHOT;
400         } else {
401                 return TC_ACT_SHOT;
402         }
403
404         if (is_ipv6)
405                 memcpy(vip.daddr.v6, pckt.dstv6, 16);
406         else
407                 vip.daddr.v4 = pckt.dst;
408
409         vip.dport = pckt.port16[1];
410         vip.protocol = pckt.proto;
411         vip_info = bpf_map_lookup_elem(&vip_map, &vip);
412         if (!vip_info) {
413                 vip.dport = 0;
414                 vip_info = bpf_map_lookup_elem(&vip_map, &vip);
415                 if (!vip_info)
416                         return TC_ACT_SHOT;
417                 pckt.port16[1] = 0;
418         }
419
420         if (vip_info->flags & F_HASH_NO_SRC_PORT)
421                 pckt.port16[0] = 0;
422
423         if (!get_packet_dst(&dst, &pckt, vip_info, is_ipv6))
424                 return TC_ACT_SHOT;
425
426         if (dst->flags & F_IPV6) {
427                 cval = bpf_map_lookup_elem(&ctl_array, &v6_intf_pos);
428                 if (!cval)
429                         return TC_ACT_SHOT;
430                 ifindex = cval->ifindex;
431                 memcpy(tkey.remote_ipv6, dst->dstv6, 16);
432                 tun_flag = BPF_F_TUNINFO_IPV6;
433         } else {
434                 cval = bpf_map_lookup_elem(&ctl_array, &v4_intf_pos);
435                 if (!cval)
436                         return TC_ACT_SHOT;
437                 ifindex = cval->ifindex;
438                 tkey.remote_ipv4 = dst->dst;
439         }
440         vip_num = vip_info->vip_num;
441         data_stats = bpf_map_lookup_elem(&stats, &vip_num);
442         if (!data_stats)
443                 return TC_ACT_SHOT;
444         data_stats->pkts++;
445         data_stats->bytes += pkt_bytes;
446         bpf_skb_set_tunnel_key(skb, &tkey, sizeof(tkey), tun_flag);
447         *(u32 *)eth->eth_dest = tkey.remote_ipv4;
448         return bpf_redirect(ifindex, 0);
449 }
450
451 SEC("tc")
452 int balancer_ingress(struct __sk_buff *ctx)
453 {
454         void *data_end = (void *)(long)ctx->data_end;
455         void *data = (void *)(long)ctx->data;
456         struct eth_hdr *eth = data;
457         __u32 eth_proto;
458         __u32 nh_off;
459
460         nh_off = sizeof(struct eth_hdr);
461         if (data + nh_off > data_end)
462                 return TC_ACT_SHOT;
463         eth_proto = eth->eth_proto;
464         if (eth_proto == bpf_htons(ETH_P_IP))
465                 return process_packet(data, nh_off, data_end, false, ctx);
466         else if (eth_proto == bpf_htons(ETH_P_IPV6))
467                 return process_packet(data, nh_off, data_end, true, ctx);
468         else
469                 return TC_ACT_SHOT;
470 }
471 char _license[] SEC("license") = "GPL";