selftests/net: ipsec: fix constant out of range
[platform/kernel/linux-rpi.git] / tools / testing / selftests / net / ipsec.c
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * ipsec.c - Check xfrm on veth inside a net-ns.
4  * Copyright (c) 2018 Dmitry Safonov
5  */
6
7 #define _GNU_SOURCE
8
9 #include <arpa/inet.h>
10 #include <asm/types.h>
11 #include <errno.h>
12 #include <fcntl.h>
13 #include <limits.h>
14 #include <linux/limits.h>
15 #include <linux/netlink.h>
16 #include <linux/random.h>
17 #include <linux/rtnetlink.h>
18 #include <linux/veth.h>
19 #include <linux/xfrm.h>
20 #include <netinet/in.h>
21 #include <net/if.h>
22 #include <sched.h>
23 #include <stdbool.h>
24 #include <stdint.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <sys/mman.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/syscall.h>
32 #include <sys/types.h>
33 #include <sys/wait.h>
34 #include <time.h>
35 #include <unistd.h>
36
37 #include "../kselftest.h"
38
/* Log through kselftest, prefixing every message with the pid and source line. */
#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

/* Like printk(), but also appends strerror(errno) via the "%m" conversion. */
#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

/* Compile-time assertion usable as an expression statement inside a function. */
#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host numbers of the two usable addresses of the nr-th /30 subnet.
 * The argument is parenthesized so that e.g. child_ip(i + 1) expands to
 * 4 * (i + 1) + 1 rather than 4 * i + 1 + 1.
 */
#define child_ip(nr)	(4 * (nr) + 1)
#define grchild_ip(nr)	(4 * (nr) + 2)

#define VETH_FMT	"ktst-%d"
#define VETH_LEN	12

/* Number of entries in xfrm_key_entries[]; must be kept in sync with it. */
#define XFRM_ALGO_NR_KEYS 29
62
/* Net-namespace fds: the test's own parent netns and two child netns. */
static int nsfd_parent  = -1;
static int nsfd_childa  = -1;
static int nsfd_childb  = -1;
static long page_size;

/*
 * ksft_cnt is static in kselftest, so isn't shared with children.
 * We have to send a test result back to parent and count there.
 * results_fd is a pipe with test feedback from children.
 */
static int results_fd[2];

/* Pause between two consecutive pings, in nanoseconds (50 ms). */
static const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
/* Receive timeout of a ping socket; udp_ping_init() uses it as microseconds. */
static const unsigned int ping_timeout		= 300;
/* Number of ping rounds per do_ping() call. */
static const unsigned int ping_count		= 100;
/* Minimum number of successful rounds for do_ping() to report success. */
static const unsigned int ping_success		= 80;
79
/*
 * Key length for one algorithm name as it appears in struct xfrm_desc.
 * key_len is in bits, the unit of alg_key_len in struct xfrm_algo.
 */
struct xfrm_key_entry {
	char algo_name[35];
	int key_len;
};

/*
 * Lookup table used by xfrm_fill_key(). Entry count must match
 * XFRM_ALGO_NR_KEYS. Only this file uses it, hence internal linkage.
 */
static struct xfrm_key_entry xfrm_key_entries[] = {
	{"digest_null", 0},
	{"ecb(cipher_null)", 0},
	{"cbc(des)", 64},
	{"hmac(md5)", 128},
	{"cmac(aes)", 128},
	{"xcbc(aes)", 128},
	{"cbc(cast5)", 128},
	{"cbc(serpent)", 128},
	{"hmac(sha1)", 160},
	{"hmac(rmd160)", 160},
	{"cbc(des3_ede)", 192},
	{"hmac(sha256)", 256},
	{"cbc(aes)", 256},
	{"cbc(camellia)", 256},
	{"cbc(twofish)", 256},
	{"rfc3686(ctr(aes))", 288},
	{"hmac(sha384)", 384},
	{"cbc(blowfish)", 448},
	{"hmac(sha512)", 512},
	{"rfc4106(gcm(aes))-128", 160},
	{"rfc4543(gcm(aes))-128", 160},
	{"rfc4309(ccm(aes))-128", 152},
	{"rfc4106(gcm(aes))-192", 224},
	{"rfc4543(gcm(aes))-192", 224},
	{"rfc4309(ccm(aes))-192", 216},
	{"rfc4106(gcm(aes))-256", 288},
	{"rfc4543(gcm(aes))-256", 288},
	{"rfc4309(ccm(aes))-256", 280},
	{"rfc7539(chacha20,poly1305)-128", 0}
};
116
/*
 * Fill buflen bytes at buf with rand() output. Each full int-sized chunk gets
 * one rand() value. A trailing partial chunk gets the leading bytes of one
 * more value.
 * Bytes are copied with memcpy(), so buf need not be int-aligned. The buffer
 * is also never written through an int lvalue, which the previous int-pointer
 * version did even for char arrays.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *p = buf;

	while (buflen > 0) {
		int r = rand();
		size_t chunk = buflen < sizeof(r) ? buflen : sizeof(r);

		memcpy(p, &r, chunk);
		p += chunk;
		buflen -= chunk;
	}
}
137
138 static int unshare_open(void)
139 {
140         const char *netns_path = "/proc/self/ns/net";
141         int fd;
142
143         if (unshare(CLONE_NEWNET) != 0) {
144                 pr_err("unshare()");
145                 return -1;
146         }
147
148         fd = open(netns_path, O_RDONLY);
149         if (fd <= 0) {
150                 pr_err("open(%s)", netns_path);
151                 return -1;
152         }
153
154         return fd;
155 }
156
157 static int switch_ns(int fd)
158 {
159         if (setns(fd, CLONE_NEWNET)) {
160                 pr_err("setns()");
161                 return -1;
162         }
163         return 0;
164 }
165
166 /*
167  * Running the test inside a new parent net namespace to bother less
168  * about cleanup on error-path.
169  */
170 static int init_namespaces(void)
171 {
172         nsfd_parent = unshare_open();
173         if (nsfd_parent <= 0)
174                 return -1;
175
176         nsfd_childa = unshare_open();
177         if (nsfd_childa <= 0)
178                 return -1;
179
180         if (switch_ns(nsfd_parent))
181                 return -1;
182
183         nsfd_childb = unshare_open();
184         if (nsfd_childb <= 0)
185                 return -1;
186
187         if (switch_ns(nsfd_parent))
188                 return -1;
189         return 0;
190 }
191
/*
 * Lazily open a netlink socket of protocol proto.
 * If *sock already holds a socket (> 0), the caller's sequence counter is
 * advanced and the socket is reused. Otherwise a new socket is opened and
 * *seq_nr is seeded with a random value.
 * Returns 0 on success, -1 (logged) if socket() fails.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Advance the counter itself, not the local copy of the pointer. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}
209
/*
 * Where the next rtattr is appended: just past the current end of the
 * message, with nlmsg_len rounded up by RTA_ALIGN().
 */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *msg_start = (char *)nh;
	size_t aligned_len = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(msg_start + aligned_len);
}
214
/*
 * Append one rtattr at the aligned end of netlink message nh. The attribute
 * has type rta_type and carries size bytes copied from payload. req_sz is the
 * total size of the buffer holding nh. nlmsg_len is updated to cover the new
 * attribute. payload may be NULL when size is 0, which rtattr_begin() does to
 * open a nested attribute.
 * Returns 0 on success, -1 (logged) if the attribute does not fit.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, so RTA_ALIGN() of nlmsg_len is enough. */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/* memcpy() from a NULL source is undefined even for 0 bytes. */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}
234
/*
 * Append an attribute exactly like rtattr_pack(), but return its header so
 * rta_len can later be widened by rtattr_end(). Returns NULL on failure.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	struct rtattr *start = rtattr_hdr(nh);

	return rtattr_pack(nh, req_sz, rta_type, payload, size) ? NULL : start;
}
245
/* Open a nested attribute that has no payload of its own; close it with rtattr_end(). */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = NULL;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, 0);
}
251
/*
 * Close a nested attribute: set attr->rta_len so the attribute spans from its
 * header to the current end of message nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	char *attr_start = (char *)attr;
	char *msg_tail = (char *)nh + nh->nlmsg_len;

	attr->rta_len = msg_tail - attr_start;
}
258
/*
 * Append the VETH_INFO_PEER nest that describes the peer end of a veth pair.
 * The nest holds an ifinfomsg header, then IFLA_IFNAME (peer) and
 * IFLA_NET_NS_FD (ns).
 * Returns 0 on success, -1 if the request buffer is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct ifinfomsg pi = {
		.ifi_family	= AF_UNSPEC,
		.ifi_change	= 0xFFFFFFFF,
	};
	struct rtattr *peer_attr;

	peer_attr = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &pi, sizeof(pi));
	if (!peer_attr ||
	    rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, peer_attr);
	return 0;
}
283
/*
 * Receive the kernel's reply to a request sent with NLM_F_ACK on sock.
 * Returns:
 *   0 for a successful ack (NLMSG_ERROR with error == 0),
 *   the error value carried by an NLMSG_ERROR reply (a negative errno),
 *   -1 if recv() fails, the reply is truncated, or it is not NLMSG_ERROR.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t len;

	memset(&answer, 0, sizeof(answer));
	len = recv(sock, &answer, sizeof(answer), 0);
	if (len < 0) {
		pr_err("recv()");
		return -1;
	}
	/* Never inspect fields that recv() did not actually fill. */
	if ((size_t)len < sizeof(answer.hdr)) {
		printk("netlink reply too short: %zd bytes", len);
		return -1;
	}
	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}
	if ((size_t)len < sizeof(answer.hdr) + sizeof(answer.error)) {
		printk("truncated NLMSG_ERROR: %zd bytes", len);
		return -1;
	}
	if (answer.error) {
		printk("NLMSG_ERROR: %d: %s",
			answer.error, strerror(-answer.error));
		return answer.error;
	}

	return 0;
}
306
307 static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
308                 const char *peerb, int ns_b)
309 {
310         uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
311         struct {
312                 struct nlmsghdr         nh;
313                 struct ifinfomsg        info;
314                 char                    attrbuf[MAX_PAYLOAD];
315         } req;
316         const char veth_type[] = "veth";
317         struct rtattr *link_info, *info_data;
318
319         memset(&req, 0, sizeof(req));
320         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
321         req.nh.nlmsg_type       = RTM_NEWLINK;
322         req.nh.nlmsg_flags      = flags;
323         req.nh.nlmsg_seq        = seq;
324         req.info.ifi_family     = AF_UNSPEC;
325         req.info.ifi_change     = 0xFFFFFFFF;
326
327         if (rtattr_pack(&req.nh, sizeof(req), IFLA_IFNAME, peera, strlen(peera)))
328                 return -1;
329
330         if (rtattr_pack(&req.nh, sizeof(req), IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
331                 return -1;
332
333         link_info = rtattr_begin(&req.nh, sizeof(req), IFLA_LINKINFO);
334         if (!link_info)
335                 return -1;
336
337         if (rtattr_pack(&req.nh, sizeof(req), IFLA_INFO_KIND, veth_type, sizeof(veth_type)))
338                 return -1;
339
340         info_data = rtattr_begin(&req.nh, sizeof(req), IFLA_INFO_DATA);
341         if (!info_data)
342                 return -1;
343
344         if (veth_pack_peerb(&req.nh, sizeof(req), peerb, ns_b))
345                 return -1;
346
347         rtattr_end(&req.nh, info_data);
348         rtattr_end(&req.nh, link_info);
349
350         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
351                 pr_err("send()");
352                 return -1;
353         }
354         return netlink_check_answer(sock);
355 }
356
357 static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
358                 struct in_addr addr, uint8_t prefix)
359 {
360         uint16_t flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_EXCL | NLM_F_CREATE;
361         struct {
362                 struct nlmsghdr         nh;
363                 struct ifaddrmsg        info;
364                 char                    attrbuf[MAX_PAYLOAD];
365         } req;
366
367         memset(&req, 0, sizeof(req));
368         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
369         req.nh.nlmsg_type       = RTM_NEWADDR;
370         req.nh.nlmsg_flags      = flags;
371         req.nh.nlmsg_seq        = seq;
372         req.info.ifa_family     = AF_INET;
373         req.info.ifa_prefixlen  = prefix;
374         req.info.ifa_index      = if_nametoindex(intf);
375
376 #ifdef DEBUG
377         {
378                 char addr_str[IPV4_STR_SZ] = {};
379
380                 strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);
381
382                 printk("ip addr set %s", addr_str);
383         }
384 #endif
385
386         if (rtattr_pack(&req.nh, sizeof(req), IFA_LOCAL, &addr, sizeof(addr)))
387                 return -1;
388
389         if (rtattr_pack(&req.nh, sizeof(req), IFA_ADDRESS, &addr, sizeof(addr)))
390                 return -1;
391
392         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
393                 pr_err("send()");
394                 return -1;
395         }
396         return netlink_check_answer(sock);
397 }
398
399 static int link_set_up(int sock, uint32_t seq, const char *intf)
400 {
401         struct {
402                 struct nlmsghdr         nh;
403                 struct ifinfomsg        info;
404                 char                    attrbuf[MAX_PAYLOAD];
405         } req;
406
407         memset(&req, 0, sizeof(req));
408         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
409         req.nh.nlmsg_type       = RTM_NEWLINK;
410         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
411         req.nh.nlmsg_seq        = seq;
412         req.info.ifi_family     = AF_UNSPEC;
413         req.info.ifi_change     = 0xFFFFFFFF;
414         req.info.ifi_index      = if_nametoindex(intf);
415         req.info.ifi_flags      = IFF_UP;
416         req.info.ifi_change     = IFF_UP;
417
418         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
419                 pr_err("send()");
420                 return -1;
421         }
422         return netlink_check_answer(sock);
423 }
424
425 static int ip4_route_set(int sock, uint32_t seq, const char *intf,
426                 struct in_addr src, struct in_addr dst)
427 {
428         struct {
429                 struct nlmsghdr nh;
430                 struct rtmsg    rt;
431                 char            attrbuf[MAX_PAYLOAD];
432         } req;
433         unsigned int index = if_nametoindex(intf);
434
435         memset(&req, 0, sizeof(req));
436         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.rt));
437         req.nh.nlmsg_type       = RTM_NEWROUTE;
438         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
439         req.nh.nlmsg_seq        = seq;
440         req.rt.rtm_family       = AF_INET;
441         req.rt.rtm_dst_len      = 32;
442         req.rt.rtm_table        = RT_TABLE_MAIN;
443         req.rt.rtm_protocol     = RTPROT_BOOT;
444         req.rt.rtm_scope        = RT_SCOPE_LINK;
445         req.rt.rtm_type         = RTN_UNICAST;
446
447         if (rtattr_pack(&req.nh, sizeof(req), RTA_DST, &dst, sizeof(dst)))
448                 return -1;
449
450         if (rtattr_pack(&req.nh, sizeof(req), RTA_PREFSRC, &src, sizeof(src)))
451                 return -1;
452
453         if (rtattr_pack(&req.nh, sizeof(req), RTA_OIF, &index, sizeof(index)))
454                 return -1;
455
456         if (send(sock, &req, req.nh.nlmsg_len, 0) < 0) {
457                 pr_err("send()");
458                 return -1;
459         }
460
461         return netlink_check_answer(sock);
462 }
463
464 static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
465                 struct in_addr tunsrc, struct in_addr tundst)
466 {
467         if (ip4_addr_set(route_sock, (*route_seq)++, "lo",
468                         tunsrc, PREFIX_LEN)) {
469                 printk("Failed to set ipv4 addr");
470                 return -1;
471         }
472
473         if (ip4_route_set(route_sock, (*route_seq)++, veth, tunsrc, tundst)) {
474                 printk("Failed to set ipv4 route");
475                 return -1;
476         }
477
478         return 0;
479 }
480
481 static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
482 {
483         struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
484         struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
485         struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
486         int route_sock = -1, ret = -1;
487         uint32_t route_seq;
488
489         if (switch_ns(nsfd))
490                 return -1;
491
492         if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
493                 printk("Failed to open netlink route socket in child");
494                 return -1;
495         }
496
497         if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN)) {
498                 printk("Failed to set ipv4 addr");
499                 goto err;
500         }
501
502         if (link_set_up(route_sock, route_seq++, veth)) {
503                 printk("Failed to bring up %s", veth);
504                 goto err;
505         }
506
507         if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst)) {
508                 printk("Failed to add tunnel route on %s", veth);
509                 goto err;
510         }
511         ret = 0;
512
513 err:
514         close(route_sock);
515         return ret;
516 }
517
#define ALGO_LEN	64

/* Kind of xfrm test that a struct xfrm_desc describes. */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
	SPDINFO_ATTRS,
};

/*
 * Human-readable names indexed by enum desc_type, followed by an empty
 * string. Only this file uses it, hence internal linkage.
 */
static const char *desc_name[] = {
	"create tunnel",
	"alloc spi",
	"monitor acquire",
	"expire state",
	"expire policy",
	"spdinfo attributes",
	""
};

/*
 * One xfrm test configuration. Which algorithm fields may be set depends on
 * proto (see xfrm_state_pack_algo()): AH uses a_algo, COMP uses c_algo, and
 * ESP uses either e_algo + a_algo or ae_algo (with icv_len).
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;
	char		a_algo[ALGO_LEN];	/* authentication */
	char		e_algo[ALGO_LEN];	/* encryption */
	char		c_algo[ALGO_LEN];	/* compression */
	char		ae_algo[ALGO_LEN];	/* AEAD */
	unsigned int	icv_len;
	/* unsigned key_len; */
};
546
547 enum msg_type {
548         MSG_ACK         = 0,
549         MSG_EXIT,
550         MSG_PING,
551         MSG_XFRM_PREPARE,
552         MSG_XFRM_ADD,
553         MSG_XFRM_DEL,
554         MSG_XFRM_CLEANUP,
555 };
556
557 struct test_desc {
558         enum msg_type type;
559         union {
560                 struct {
561                         in_addr_t reply_ip;
562                         unsigned int port;
563                 } ping;
564                 struct xfrm_desc xfrm_desc;
565         } body;
566 };
567
568 struct test_result {
569         struct xfrm_desc desc;
570         unsigned int res;
571 };
572
573 static void write_test_result(unsigned int res, struct xfrm_desc *d)
574 {
575         struct test_result tr = {};
576         ssize_t ret;
577
578         tr.desc = *d;
579         tr.res = res;
580
581         ret = write(results_fd[1], &tr, sizeof(tr));
582         if (ret != sizeof(tr))
583                 pr_err("Failed to write the result in pipe %zd", ret);
584 }
585
586 static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
587 {
588         ssize_t bytes = write(fd, msg, sizeof(*msg));
589
590         /* Make sure that write/read is atomic to a pipe */
591         BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);
592
593         if (bytes < 0) {
594                 pr_err("write()");
595                 if (exit_of_fail)
596                         exit(KSFT_FAIL);
597         }
598         if (bytes != sizeof(*msg)) {
599                 pr_err("sent part of the message %zd/%zu", bytes, sizeof(*msg));
600                 if (exit_of_fail)
601                         exit(KSFT_FAIL);
602         }
603 }
604
605 static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
606 {
607         ssize_t bytes = read(fd, msg, sizeof(*msg));
608
609         if (bytes < 0) {
610                 pr_err("read()");
611                 if (exit_of_fail)
612                         exit(KSFT_FAIL);
613         }
614         if (bytes != sizeof(*msg)) {
615                 pr_err("got incomplete message %zd/%zu", bytes, sizeof(*msg));
616                 if (exit_of_fail)
617                         exit(KSFT_FAIL);
618         }
619 }
620
/*
 * Create the two UDP sockets one side of a ping exchange needs:
 *   sock[0] is bound to listen_ip on a kernel-chosen port, which is returned
 *           through *server_port. Receives on it time out after u_timeout
 *           microseconds (SO_RCVTIMEO).
 *   sock[1] is an unbound socket used for sending.
 * Returns 0 on success, -1 (logged) on failure, with no socket left open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server;
	/* SO_RCVTIMEO rejects tv_usec >= 1s (EDOM), so split the timeout. */
	struct timeval t = {
		.tv_sec		= u_timeout / 1000000,
		.tv_usec	= u_timeout % 1000000,
	};
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	/* Zero sin_zero and everything else not set explicitly below. */
	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= 0;
	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, (const char *)&t, sizeof t)) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}
667
/*
 * Initiator side of one ping round. Send buf_len bytes of buf from sock[1]
 * to dest_ip:port, then wait on sock[0] for the echo and check it is
 * identical to buf.
 * Returns 0 on a matching reply, -1 otherwise. An EAGAIN from recv() (the
 * receive timeout set by udp_ping_init()) is not logged.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];	/* buf_len bytes, not buf_len pointers */
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}
704
/*
 * Responder side of one ping round. Wait on sock[0] for a datagram that must
 * equal buf (buf_len bytes), then echo buf from sock[1] to dest_ip:port.
 * Returns 0 on success, -1 otherwise. An EAGAIN from recv() (the receive
 * timeout set by udp_ping_init()) is not logged.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	char sock_buf[buf_len];	/* buf_len bytes, not buf_len pointers */
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}
743
744 typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
745                 char *buf, size_t buf_len);
746 static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
747                 bool init_side, int d_port, in_addr_t to, ping_f func)
748 {
749         struct test_desc msg;
750         unsigned int s_port, i, ping_succeeded = 0;
751         int ping_sock[2];
752         char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
753
754         if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
755                 printk("Failed to init ping");
756                 return -1;
757         }
758
759         memset(&msg, 0, sizeof(msg));
760         msg.type                = MSG_PING;
761         msg.body.ping.port      = s_port;
762         memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));
763
764         write_msg(cmd_fd, &msg, 0);
765         if (init_side) {
766                 /* The other end sends ip to ping */
767                 read_msg(cmd_fd, &msg, 0);
768                 if (msg.type != MSG_PING)
769                         return -1;
770                 to = msg.body.ping.reply_ip;
771                 d_port = msg.body.ping.port;
772         }
773
774         for (i = 0; i < ping_count ; i++) {
775                 struct timespec sleep_time = {
776                         .tv_sec = 0,
777                         .tv_nsec = ping_delay_nsec,
778                 };
779
780                 ping_succeeded += !func(ping_sock, to, d_port, buf, page_size);
781                 nanosleep(&sleep_time, 0);
782         }
783
784         close(ping_sock[0]);
785         close(ping_sock[1]);
786
787         strncpy(to_str, inet_ntoa(*(struct in_addr *)&to), IPV4_STR_SZ - 1);
788         strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);
789
790         if (ping_succeeded < ping_success) {
791                 printk("ping (%s) %s->%s failed %u/%u times",
792                         init_side ? "send" : "reply", from_str, to_str,
793                         ping_count - ping_succeeded, ping_count);
794                 return -1;
795         }
796
797 #ifdef DEBUG
798         printk("ping (%s) %s->%s succeeded %u/%u times",
799                 init_side ? "send" : "reply", from_str, to_str,
800                 ping_succeeded, ping_count);
801 #endif
802
803         return 0;
804 }
805
806 static int xfrm_fill_key(char *name, char *buf,
807                 size_t buf_len, unsigned int *key_len)
808 {
809         int i;
810
811         for (i = 0; i < XFRM_ALGO_NR_KEYS; i++) {
812                 if (strncmp(name, xfrm_key_entries[i].algo_name, ALGO_LEN) == 0)
813                         *key_len = xfrm_key_entries[i].key_len;
814         }
815
816         if (*key_len > buf_len) {
817                 printk("Can't pack a key - too big for buffer");
818                 return -1;
819         }
820
821         randomize_buffer(buf, *key_len);
822
823         return 0;
824 }
825
826 static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
827                 struct xfrm_desc *desc)
828 {
829         struct {
830                 union {
831                         struct xfrm_algo        alg;
832                         struct xfrm_algo_aead   aead;
833                         struct xfrm_algo_auth   auth;
834                 } u;
835                 char buf[XFRM_ALGO_KEY_BUF_SIZE];
836         } alg = {};
837         size_t alen, elen, clen, aelen;
838         unsigned short type;
839
840         alen = strlen(desc->a_algo);
841         elen = strlen(desc->e_algo);
842         clen = strlen(desc->c_algo);
843         aelen = strlen(desc->ae_algo);
844
845         /* Verify desc */
846         switch (desc->proto) {
847         case IPPROTO_AH:
848                 if (!alen || elen || clen || aelen) {
849                         printk("BUG: buggy ah desc");
850                         return -1;
851                 }
852                 strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
853                 if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
854                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
855                         return -1;
856                 type = XFRMA_ALG_AUTH;
857                 break;
858         case IPPROTO_COMP:
859                 if (!clen || elen || alen || aelen) {
860                         printk("BUG: buggy comp desc");
861                         return -1;
862                 }
863                 strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
864                 if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
865                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
866                         return -1;
867                 type = XFRMA_ALG_COMP;
868                 break;
869         case IPPROTO_ESP:
870                 if (!((alen && elen) ^ aelen) || clen) {
871                         printk("BUG: buggy esp desc");
872                         return -1;
873                 }
874                 if (aelen) {
875                         alg.u.aead.alg_icv_len = desc->icv_len;
876                         strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
877                         if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
878                                                 sizeof(alg.buf), &alg.u.aead.alg_key_len))
879                                 return -1;
880                         type = XFRMA_ALG_AEAD;
881                 } else {
882
883                         strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
884                         type = XFRMA_ALG_CRYPT;
885                         if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
886                                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
887                                 return -1;
888                         if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
889                                 return -1;
890
891                         strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN);
892                         type = XFRMA_ALG_AUTH;
893                         if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
894                                                 sizeof(alg.buf), &alg.u.alg.alg_key_len))
895                                 return -1;
896                 }
897                 break;
898         default:
899                 printk("BUG: unknown proto in desc");
900                 return -1;
901         }
902
903         if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
904                 return -1;
905
906         return 0;
907 }
908
/* Derive an SPI from the local (host) part of src, in network byte order. */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}
913
914 static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
915                 struct in_addr src, struct in_addr dst,
916                 struct xfrm_desc *desc)
917 {
918         struct {
919                 struct nlmsghdr         nh;
920                 struct xfrm_usersa_info info;
921                 char                    attrbuf[MAX_PAYLOAD];
922         } req;
923
924         memset(&req, 0, sizeof(req));
925         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
926         req.nh.nlmsg_type       = XFRM_MSG_NEWSA;
927         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
928         req.nh.nlmsg_seq        = seq;
929
930         /* Fill selector. */
931         memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
932         memcpy(&req.info.sel.saddr, &src, sizeof(src));
933         req.info.sel.family             = AF_INET;
934         req.info.sel.prefixlen_d        = PREFIX_LEN;
935         req.info.sel.prefixlen_s        = PREFIX_LEN;
936
937         /* Fill id */
938         memcpy(&req.info.id.daddr, &dst, sizeof(dst));
939         /* Note: zero-spi cannot be deleted */
940         req.info.id.spi = spi;
941         req.info.id.proto       = desc->proto;
942
943         memcpy(&req.info.saddr, &src, sizeof(src));
944
945         /* Fill lifteme_cfg */
946         req.info.lft.soft_byte_limit    = XFRM_INF;
947         req.info.lft.hard_byte_limit    = XFRM_INF;
948         req.info.lft.soft_packet_limit  = XFRM_INF;
949         req.info.lft.hard_packet_limit  = XFRM_INF;
950
951         req.info.family         = AF_INET;
952         req.info.mode           = XFRM_MODE_TUNNEL;
953
954         if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
955                 return -1;
956
957         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
958                 pr_err("send()");
959                 return -1;
960         }
961
962         return netlink_check_answer(xfrm_sock);
963 }
964
965 static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
966                 struct in_addr src, struct in_addr dst,
967                 struct xfrm_desc *desc)
968 {
969         if (memcmp(&info->sel.daddr, &dst, sizeof(dst)))
970                 return false;
971
972         if (memcmp(&info->sel.saddr, &src, sizeof(src)))
973                 return false;
974
975         if (info->sel.family != AF_INET                                 ||
976                         info->sel.prefixlen_d != PREFIX_LEN             ||
977                         info->sel.prefixlen_s != PREFIX_LEN)
978                 return false;
979
980         if (info->id.spi != spi || info->id.proto != desc->proto)
981                 return false;
982
983         if (memcmp(&info->id.daddr, &dst, sizeof(dst)))
984                 return false;
985
986         if (memcmp(&info->saddr, &src, sizeof(src)))
987                 return false;
988
989         if (info->lft.soft_byte_limit != XFRM_INF                       ||
990                         info->lft.hard_byte_limit != XFRM_INF           ||
991                         info->lft.soft_packet_limit != XFRM_INF         ||
992                         info->lft.hard_packet_limit != XFRM_INF)
993                 return false;
994
995         if (info->family != AF_INET || info->mode != XFRM_MODE_TUNNEL)
996                 return false;
997
998         /* XXX: check xfrm algo, see xfrm_state_pack_algo(). */
999
1000         return true;
1001 }
1002
1003 static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
1004                 struct in_addr src, struct in_addr dst,
1005                 struct xfrm_desc *desc)
1006 {
1007         struct {
1008                 struct nlmsghdr         nh;
1009                 char                    attrbuf[MAX_PAYLOAD];
1010         } req;
1011         struct {
1012                 struct nlmsghdr         nh;
1013                 union {
1014                         struct xfrm_usersa_info info;
1015                         int error;
1016                 };
1017                 char                    attrbuf[MAX_PAYLOAD];
1018         } answer;
1019         struct xfrm_address_filter filter = {};
1020         bool found = false;
1021
1022
1023         memset(&req, 0, sizeof(req));
1024         req.nh.nlmsg_len        = NLMSG_LENGTH(0);
1025         req.nh.nlmsg_type       = XFRM_MSG_GETSA;
1026         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_DUMP;
1027         req.nh.nlmsg_seq        = seq;
1028
1029         /*
1030          * Add dump filter by source address as there may be other tunnels
1031          * in this netns (if tests run in parallel).
1032          */
1033         filter.family = AF_INET;
1034         filter.splen = 0x1f;    /* 0xffffffff mask see addr_match() */
1035         memcpy(&filter.saddr, &src, sizeof(src));
1036         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
1037                                 &filter, sizeof(filter)))
1038                 return -1;
1039
1040         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1041                 pr_err("send()");
1042                 return -1;
1043         }
1044
1045         while (1) {
1046                 if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1047                         pr_err("recv()");
1048                         return -1;
1049                 }
1050                 if (answer.nh.nlmsg_type == NLMSG_ERROR) {
1051                         printk("NLMSG_ERROR: %d: %s",
1052                                 answer.error, strerror(-answer.error));
1053                         return -1;
1054                 } else if (answer.nh.nlmsg_type == NLMSG_DONE) {
1055                         if (found)
1056                                 return 0;
1057                         printk("didn't find allocated xfrm state in dump");
1058                         return -1;
1059                 } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1060                         if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
1061                                 found = true;
1062                 }
1063         }
1064 }
1065
/*
 * Add the SA pair @src -> @dst and @dst -> @src (both with SPI
 * gen_spi(@src)), then verify both through GETSA dumps. @tunsrc/@tundst
 * are not used here. Each netlink request consumes one number from *seq.
 *
 * Returns 0 on success, -1 on failure.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	const uint32_t spi = gen_spi(src);
	int failed;

	if (xfrm_state_add(xfrm_sock, (*seq)++, spi, src, dst, desc) ||
	    xfrm_state_add(xfrm_sock, (*seq)++, spi, dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both checks always run */
	failed = xfrm_state_check(xfrm_sock, (*seq)++, spi, src, dst, desc);
	failed |= xfrm_state_check(xfrm_sock, (*seq)++, spi, dst, src, desc);
	if (failed) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}
1095
1096 static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
1097                 struct in_addr src, struct in_addr dst, uint8_t dir,
1098                 struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1099 {
1100         struct {
1101                 struct nlmsghdr                 nh;
1102                 struct xfrm_userpolicy_info     info;
1103                 char                            attrbuf[MAX_PAYLOAD];
1104         } req;
1105         struct xfrm_user_tmpl tmpl;
1106
1107         memset(&req, 0, sizeof(req));
1108         memset(&tmpl, 0, sizeof(tmpl));
1109         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.info));
1110         req.nh.nlmsg_type       = XFRM_MSG_NEWPOLICY;
1111         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1112         req.nh.nlmsg_seq        = seq;
1113
1114         /* Fill selector. */
1115         memcpy(&req.info.sel.daddr, &dst, sizeof(tundst));
1116         memcpy(&req.info.sel.saddr, &src, sizeof(tunsrc));
1117         req.info.sel.family             = AF_INET;
1118         req.info.sel.prefixlen_d        = PREFIX_LEN;
1119         req.info.sel.prefixlen_s        = PREFIX_LEN;
1120
1121         /* Fill lifteme_cfg */
1122         req.info.lft.soft_byte_limit    = XFRM_INF;
1123         req.info.lft.hard_byte_limit    = XFRM_INF;
1124         req.info.lft.soft_packet_limit  = XFRM_INF;
1125         req.info.lft.hard_packet_limit  = XFRM_INF;
1126
1127         req.info.dir = dir;
1128
1129         /* Fill tmpl */
1130         memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
1131         /* Note: zero-spi cannot be deleted */
1132         tmpl.id.spi = spi;
1133         tmpl.id.proto   = proto;
1134         tmpl.family     = AF_INET;
1135         memcpy(&tmpl.saddr, &src, sizeof(src));
1136         tmpl.mode       = XFRM_MODE_TUNNEL;
1137         tmpl.aalgos = (~(uint32_t)0);
1138         tmpl.ealgos = (~(uint32_t)0);
1139         tmpl.calgos = (~(uint32_t)0);
1140
1141         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
1142                 return -1;
1143
1144         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1145                 pr_err("send()");
1146                 return -1;
1147         }
1148
1149         return netlink_check_answer(xfrm_sock);
1150 }
1151
1152 static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
1153                 struct in_addr src, struct in_addr dst,
1154                 struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
1155 {
1156         if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1157                                 XFRM_POLICY_OUT, tunsrc, tundst, proto)) {
1158                 printk("Failed to add xfrm policy");
1159                 return -1;
1160         }
1161
1162         if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
1163                                 XFRM_POLICY_IN, tunsrc, tundst, proto)) {
1164                 printk("Failed to add xfrm policy");
1165                 return -1;
1166         }
1167
1168         return 0;
1169 }
1170
1171 static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
1172                 struct in_addr src, struct in_addr dst, uint8_t dir,
1173                 struct in_addr tunsrc, struct in_addr tundst)
1174 {
1175         struct {
1176                 struct nlmsghdr                 nh;
1177                 struct xfrm_userpolicy_id       id;
1178                 char                            attrbuf[MAX_PAYLOAD];
1179         } req;
1180
1181         memset(&req, 0, sizeof(req));
1182         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.id));
1183         req.nh.nlmsg_type       = XFRM_MSG_DELPOLICY;
1184         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1185         req.nh.nlmsg_seq        = seq;
1186
1187         /* Fill id */
1188         memcpy(&req.id.sel.daddr, &dst, sizeof(tundst));
1189         memcpy(&req.id.sel.saddr, &src, sizeof(tunsrc));
1190         req.id.sel.family               = AF_INET;
1191         req.id.sel.prefixlen_d          = PREFIX_LEN;
1192         req.id.sel.prefixlen_s          = PREFIX_LEN;
1193         req.id.dir = dir;
1194
1195         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1196                 pr_err("send()");
1197                 return -1;
1198         }
1199
1200         return netlink_check_answer(xfrm_sock);
1201 }
1202
1203 static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
1204                 struct in_addr src, struct in_addr dst,
1205                 struct in_addr tunsrc, struct in_addr tundst)
1206 {
1207         if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
1208                                 XFRM_POLICY_OUT, tunsrc, tundst)) {
1209                 printk("Failed to add xfrm policy");
1210                 return -1;
1211         }
1212
1213         if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
1214                                 XFRM_POLICY_IN, tunsrc, tundst)) {
1215                 printk("Failed to add xfrm policy");
1216                 return -1;
1217         }
1218
1219         return 0;
1220 }
1221
1222 static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
1223                 struct in_addr src, struct in_addr dst, uint8_t proto)
1224 {
1225         struct {
1226                 struct nlmsghdr         nh;
1227                 struct xfrm_usersa_id   id;
1228                 char                    attrbuf[MAX_PAYLOAD];
1229         } req;
1230         xfrm_address_t saddr = {};
1231
1232         memset(&req, 0, sizeof(req));
1233         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.id));
1234         req.nh.nlmsg_type       = XFRM_MSG_DELSA;
1235         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1236         req.nh.nlmsg_seq        = seq;
1237
1238         memcpy(&req.id.daddr, &dst, sizeof(dst));
1239         req.id.family           = AF_INET;
1240         req.id.proto            = proto;
1241         /* Note: zero-spi cannot be deleted */
1242         req.id.spi = spi;
1243
1244         memcpy(&saddr, &src, sizeof(src));
1245         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR, &saddr, sizeof(saddr)))
1246                 return -1;
1247
1248         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1249                 pr_err("send()");
1250                 return -1;
1251         }
1252
1253         return netlink_check_answer(xfrm_sock);
1254 }
1255
/*
 * Delete both SAs of the pair (@src -> @dst, then @dst -> @src), each with
 * SPI gen_spi(@src). @tunsrc/@tundst are not used. Each request consumes
 * one number from *seq.
 *
 * Returns 0 on success, -1 as soon as one SA cannot be removed.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	const uint32_t spi = gen_spi(src);

	if (xfrm_state_del(xfrm_sock, (*seq)++, spi, src, dst, proto) ||
	    xfrm_state_del(xfrm_sock, (*seq)++, spi, dst, src, proto)) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}
1272
1273 static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
1274                 uint32_t spi, uint8_t proto)
1275 {
1276         struct {
1277                 struct nlmsghdr                 nh;
1278                 struct xfrm_userspi_info        spi;
1279         } req;
1280         struct {
1281                 struct nlmsghdr                 nh;
1282                 union {
1283                         struct xfrm_usersa_info info;
1284                         int error;
1285                 };
1286         } answer;
1287
1288         memset(&req, 0, sizeof(req));
1289         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.spi));
1290         req.nh.nlmsg_type       = XFRM_MSG_ALLOCSPI;
1291         req.nh.nlmsg_flags      = NLM_F_REQUEST;
1292         req.nh.nlmsg_seq        = (*seq)++;
1293
1294         req.spi.info.family     = AF_INET;
1295         req.spi.min             = spi;
1296         req.spi.max             = spi;
1297         req.spi.info.id.proto   = proto;
1298
1299         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1300                 pr_err("send()");
1301                 return KSFT_FAIL;
1302         }
1303
1304         if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
1305                 pr_err("recv()");
1306                 return KSFT_FAIL;
1307         } else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
1308                 uint32_t new_spi = htonl(answer.info.id.spi);
1309
1310                 if (new_spi != spi) {
1311                         printk("allocated spi is different from requested: %#x != %#x",
1312                                         new_spi, spi);
1313                         return KSFT_FAIL;
1314                 }
1315                 return KSFT_PASS;
1316         } else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
1317                 printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
1318                 return KSFT_FAIL;
1319         }
1320
1321         printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
1322         return (answer.error) ? KSFT_FAIL : KSFT_PASS;
1323 }
1324
1325 static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
1326 {
1327         struct sockaddr_nl snl = {};
1328         socklen_t addr_len;
1329         int ret = -1;
1330
1331         snl.nl_family = AF_NETLINK;
1332         snl.nl_groups = groups;
1333
1334         if (netlink_sock(sock, seq, proto)) {
1335                 printk("Failed to open xfrm netlink socket");
1336                 return -1;
1337         }
1338
1339         if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
1340                 pr_err("bind()");
1341                 goto out_close;
1342         }
1343
1344         addr_len = sizeof(snl);
1345         if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
1346                 pr_err("getsockname()");
1347                 goto out_close;
1348         }
1349         if (addr_len != sizeof(snl)) {
1350                 printk("Wrong address length %d", addr_len);
1351                 goto out_close;
1352         }
1353         if (snl.nl_family != AF_NETLINK) {
1354                 printk("Wrong address family %d", snl.nl_family);
1355                 goto out_close;
1356         }
1357         return 0;
1358
1359 out_close:
1360         close(*sock);
1361         return ret;
1362 }
1363
1364 static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
1365 {
1366         struct {
1367                 struct nlmsghdr nh;
1368                 union {
1369                         struct xfrm_user_acquire acq;
1370                         int error;
1371                 };
1372                 char attrbuf[MAX_PAYLOAD];
1373         } req;
1374         struct xfrm_user_tmpl xfrm_tmpl = {};
1375         int xfrm_listen = -1, ret = KSFT_FAIL;
1376         uint32_t seq_listen;
1377
1378         if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
1379                 return KSFT_FAIL;
1380
1381         memset(&req, 0, sizeof(req));
1382         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.acq));
1383         req.nh.nlmsg_type       = XFRM_MSG_ACQUIRE;
1384         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1385         req.nh.nlmsg_seq        = (*seq)++;
1386
1387         req.acq.policy.sel.family       = AF_INET;
1388         req.acq.aalgos  = 0xfeed;
1389         req.acq.ealgos  = 0xbaad;
1390         req.acq.calgos  = 0xbabe;
1391
1392         xfrm_tmpl.family = AF_INET;
1393         xfrm_tmpl.id.proto = IPPROTO_ESP;
1394         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
1395                 goto out_close;
1396
1397         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1398                 pr_err("send()");
1399                 goto out_close;
1400         }
1401
1402         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1403                 pr_err("recv()");
1404                 goto out_close;
1405         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1406                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1407                 goto out_close;
1408         }
1409
1410         if (req.error) {
1411                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1412                 ret = req.error;
1413                 goto out_close;
1414         }
1415
1416         if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1417                 pr_err("recv()");
1418                 goto out_close;
1419         }
1420
1421         if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
1422                         || req.acq.calgos != 0xbabe) {
1423                 printk("xfrm_user_acquire has changed  %x %x %x",
1424                                 req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
1425                 goto out_close;
1426         }
1427
1428         ret = KSFT_PASS;
1429 out_close:
1430         close(xfrm_listen);
1431         return ret;
1432 }
1433
1434 static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
1435                 unsigned int nr, struct xfrm_desc *desc)
1436 {
1437         struct {
1438                 struct nlmsghdr nh;
1439                 union {
1440                         struct xfrm_user_expire expire;
1441                         int error;
1442                 };
1443         } req;
1444         struct in_addr src, dst;
1445         int xfrm_listen = -1, ret = KSFT_FAIL;
1446         uint32_t seq_listen;
1447
1448         src = inet_makeaddr(INADDR_B, child_ip(nr));
1449         dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1450
1451         if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
1452                 printk("Failed to add xfrm state");
1453                 return KSFT_FAIL;
1454         }
1455
1456         if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1457                 return KSFT_FAIL;
1458
1459         memset(&req, 0, sizeof(req));
1460         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.expire));
1461         req.nh.nlmsg_type       = XFRM_MSG_EXPIRE;
1462         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1463         req.nh.nlmsg_seq        = (*seq)++;
1464
1465         memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
1466         req.expire.state.id.spi         = gen_spi(src);
1467         req.expire.state.id.proto       = desc->proto;
1468         req.expire.state.family         = AF_INET;
1469         req.expire.hard                 = 0xff;
1470
1471         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1472                 pr_err("send()");
1473                 goto out_close;
1474         }
1475
1476         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1477                 pr_err("recv()");
1478                 goto out_close;
1479         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1480                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1481                 goto out_close;
1482         }
1483
1484         if (req.error) {
1485                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1486                 ret = req.error;
1487                 goto out_close;
1488         }
1489
1490         if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1491                 pr_err("recv()");
1492                 goto out_close;
1493         }
1494
1495         if (req.expire.hard != 0x1) {
1496                 printk("expire.hard is not set: %x", req.expire.hard);
1497                 goto out_close;
1498         }
1499
1500         ret = KSFT_PASS;
1501 out_close:
1502         close(xfrm_listen);
1503         return ret;
1504 }
1505
1506 static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
1507                 unsigned int nr, struct xfrm_desc *desc)
1508 {
1509         struct {
1510                 struct nlmsghdr nh;
1511                 union {
1512                         struct xfrm_user_polexpire expire;
1513                         int error;
1514                 };
1515         } req;
1516         struct in_addr src, dst, tunsrc, tundst;
1517         int xfrm_listen = -1, ret = KSFT_FAIL;
1518         uint32_t seq_listen;
1519
1520         src = inet_makeaddr(INADDR_B, child_ip(nr));
1521         dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1522         tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1523         tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1524
1525         if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
1526                                 XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
1527                 printk("Failed to add xfrm policy");
1528                 return KSFT_FAIL;
1529         }
1530
1531         if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
1532                 return KSFT_FAIL;
1533
1534         memset(&req, 0, sizeof(req));
1535         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.expire));
1536         req.nh.nlmsg_type       = XFRM_MSG_POLEXPIRE;
1537         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1538         req.nh.nlmsg_seq        = (*seq)++;
1539
1540         /* Fill selector. */
1541         memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(tundst));
1542         memcpy(&req.expire.pol.sel.saddr, &src, sizeof(tunsrc));
1543         req.expire.pol.sel.family       = AF_INET;
1544         req.expire.pol.sel.prefixlen_d  = PREFIX_LEN;
1545         req.expire.pol.sel.prefixlen_s  = PREFIX_LEN;
1546         req.expire.pol.dir              = XFRM_POLICY_OUT;
1547         req.expire.hard                 = 0xff;
1548
1549         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1550                 pr_err("send()");
1551                 goto out_close;
1552         }
1553
1554         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1555                 pr_err("recv()");
1556                 goto out_close;
1557         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1558                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1559                 goto out_close;
1560         }
1561
1562         if (req.error) {
1563                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1564                 ret = req.error;
1565                 goto out_close;
1566         }
1567
1568         if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
1569                 pr_err("recv()");
1570                 goto out_close;
1571         }
1572
1573         if (req.expire.hard != 0x1) {
1574                 printk("expire.hard is not set: %x", req.expire.hard);
1575                 goto out_close;
1576         }
1577
1578         ret = KSFT_PASS;
1579 out_close:
1580         close(xfrm_listen);
1581         return ret;
1582 }
1583
1584 static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
1585                 unsigned thresh4_l, unsigned thresh4_r,
1586                 unsigned thresh6_l, unsigned thresh6_r,
1587                 bool add_bad_attr)
1588
1589 {
1590         struct {
1591                 struct nlmsghdr         nh;
1592                 union {
1593                         uint32_t        unused;
1594                         int             error;
1595                 };
1596                 char                    attrbuf[MAX_PAYLOAD];
1597         } req;
1598         struct xfrmu_spdhthresh thresh;
1599
1600         memset(&req, 0, sizeof(req));
1601         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.unused));
1602         req.nh.nlmsg_type       = XFRM_MSG_NEWSPDINFO;
1603         req.nh.nlmsg_flags      = NLM_F_REQUEST | NLM_F_ACK;
1604         req.nh.nlmsg_seq        = (*seq)++;
1605
1606         thresh.lbits = thresh4_l;
1607         thresh.rbits = thresh4_r;
1608         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
1609                 return -1;
1610
1611         thresh.lbits = thresh6_l;
1612         thresh.rbits = thresh6_r;
1613         if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
1614                 return -1;
1615
1616         if (add_bad_attr) {
1617                 BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
1618                 if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
1619                         pr_err("adding attribute failed: no space");
1620                         return -1;
1621                 }
1622         }
1623
1624         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1625                 pr_err("send()");
1626                 return -1;
1627         }
1628
1629         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1630                 pr_err("recv()");
1631                 return -1;
1632         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1633                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1634                 return -1;
1635         }
1636
1637         if (req.error) {
1638                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1639                 return -1;
1640         }
1641
1642         return 0;
1643 }
1644
1645 static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
1646 {
1647         struct {
1648                 struct nlmsghdr                 nh;
1649                 union {
1650                         uint32_t        unused;
1651                         int             error;
1652                 };
1653                 char                    attrbuf[MAX_PAYLOAD];
1654         } req;
1655
1656         if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
1657                 pr_err("Can't set SPD HTHRESH");
1658                 return KSFT_FAIL;
1659         }
1660
1661         memset(&req, 0, sizeof(req));
1662
1663         req.nh.nlmsg_len        = NLMSG_LENGTH(sizeof(req.unused));
1664         req.nh.nlmsg_type       = XFRM_MSG_GETSPDINFO;
1665         req.nh.nlmsg_flags      = NLM_F_REQUEST;
1666         req.nh.nlmsg_seq        = (*seq)++;
1667         if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
1668                 pr_err("send()");
1669                 return KSFT_FAIL;
1670         }
1671
1672         if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
1673                 pr_err("recv()");
1674                 return KSFT_FAIL;
1675         } else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
1676                 size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
1677                 struct rtattr *attr = (void *)req.attrbuf;
1678                 int got_thresh = 0;
1679
1680                 for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
1681                         if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
1682                                 struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1683
1684                                 got_thresh++;
1685                                 if (t->lbits != 32 || t->rbits != 31) {
1686                                         pr_err("thresh differ: %u, %u",
1687                                                         t->lbits, t->rbits);
1688                                         return KSFT_FAIL;
1689                                 }
1690                         }
1691                         if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
1692                                 struct xfrmu_spdhthresh *t = RTA_DATA(attr);
1693
1694                                 got_thresh++;
1695                                 if (t->lbits != 120 || t->rbits != 16) {
1696                                         pr_err("thresh differ: %u, %u",
1697                                                         t->lbits, t->rbits);
1698                                         return KSFT_FAIL;
1699                                 }
1700                         }
1701                 }
1702                 if (got_thresh != 2) {
1703                         pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
1704                         return KSFT_FAIL;
1705                 }
1706         } else if (req.nh.nlmsg_type != NLMSG_ERROR) {
1707                 printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
1708                 return KSFT_FAIL;
1709         } else {
1710                 printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
1711                 return -1;
1712         }
1713
1714         /* Restore the default */
1715         if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
1716                 pr_err("Can't restore SPD HTHRESH");
1717                 return KSFT_FAIL;
1718         }
1719
1720         /*
1721          * At this moment xfrm uses nlmsg_parse_deprecated(), which
1722          * implies NL_VALIDATE_LIBERAL - ignoring attributes with
1723          * (type > maxtype). nla_parse_depricated_strict() would enforce
1724          * it. Or even stricter nla_parse().
1725          * Right now it's not expected to fail, but to be ignored.
1726          */
1727         if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
1728                 return KSFT_PASS;
1729
1730         return KSFT_PASS;
1731 }
1732
1733 static int child_serv(int xfrm_sock, uint32_t *seq,
1734                 unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
1735 {
1736         struct in_addr src, dst, tunsrc, tundst;
1737         struct test_desc msg;
1738         int ret = KSFT_FAIL;
1739
1740         src = inet_makeaddr(INADDR_B, child_ip(nr));
1741         dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
1742         tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
1743         tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));
1744
1745         /* UDP pinging without xfrm */
1746         if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
1747                 printk("ping failed before setting xfrm");
1748                 return KSFT_FAIL;
1749         }
1750
1751         memset(&msg, 0, sizeof(msg));
1752         msg.type = MSG_XFRM_PREPARE;
1753         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1754         write_msg(cmd_fd, &msg, 1);
1755
1756         if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1757                 printk("failed to prepare xfrm");
1758                 goto cleanup;
1759         }
1760
1761         memset(&msg, 0, sizeof(msg));
1762         msg.type = MSG_XFRM_ADD;
1763         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1764         write_msg(cmd_fd, &msg, 1);
1765         if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1766                 printk("failed to set xfrm");
1767                 goto delete;
1768         }
1769
1770         /* UDP pinging with xfrm tunnel */
1771         if (do_ping(cmd_fd, buf, page_size, tunsrc,
1772                                 true, 0, 0, udp_ping_send)) {
1773                 printk("ping failed for xfrm");
1774                 goto delete;
1775         }
1776
1777         ret = KSFT_PASS;
1778 delete:
1779         /* xfrm delete */
1780         memset(&msg, 0, sizeof(msg));
1781         msg.type = MSG_XFRM_DEL;
1782         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1783         write_msg(cmd_fd, &msg, 1);
1784
1785         if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
1786                 printk("failed ping to remove xfrm");
1787                 ret = KSFT_FAIL;
1788         }
1789
1790 cleanup:
1791         memset(&msg, 0, sizeof(msg));
1792         msg.type = MSG_XFRM_CLEANUP;
1793         memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
1794         write_msg(cmd_fd, &msg, 1);
1795         if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1796                 printk("failed ping to cleanup xfrm");
1797                 ret = KSFT_FAIL;
1798         }
1799         return ret;
1800 }
1801
1802 static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
1803 {
1804         struct xfrm_desc desc;
1805         struct test_desc msg;
1806         int xfrm_sock = -1;
1807         uint32_t seq;
1808
1809         if (switch_ns(nsfd_childa))
1810                 exit(KSFT_FAIL);
1811
1812         if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1813                 printk("Failed to open xfrm netlink socket");
1814                 exit(KSFT_FAIL);
1815         }
1816
1817         /* Check that seq sock is ready, just for sure. */
1818         memset(&msg, 0, sizeof(msg));
1819         msg.type = MSG_ACK;
1820         write_msg(cmd_fd, &msg, 1);
1821         read_msg(cmd_fd, &msg, 1);
1822         if (msg.type != MSG_ACK) {
1823                 printk("Ack failed");
1824                 exit(KSFT_FAIL);
1825         }
1826
1827         for (;;) {
1828                 ssize_t received = read(test_desc_fd, &desc, sizeof(desc));
1829                 int ret;
1830
1831                 if (received == 0) /* EOF */
1832                         break;
1833
1834                 if (received != sizeof(desc)) {
1835                         pr_err("read() returned %zd", received);
1836                         exit(KSFT_FAIL);
1837                 }
1838
1839                 switch (desc.type) {
1840                 case CREATE_TUNNEL:
1841                         ret = child_serv(xfrm_sock, &seq, nr,
1842                                          cmd_fd, buf, &desc);
1843                         break;
1844                 case ALLOCATE_SPI:
1845                         ret = xfrm_state_allocspi(xfrm_sock, &seq,
1846                                                   -1, desc.proto);
1847                         break;
1848                 case MONITOR_ACQUIRE:
1849                         ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
1850                         break;
1851                 case EXPIRE_STATE:
1852                         ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
1853                         break;
1854                 case EXPIRE_POLICY:
1855                         ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
1856                         break;
1857                 case SPDINFO_ATTRS:
1858                         ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
1859                         break;
1860                 default:
1861                         printk("Unknown desc type %d", desc.type);
1862                         exit(KSFT_FAIL);
1863                 }
1864                 write_test_result(ret, &desc);
1865         }
1866
1867         close(xfrm_sock);
1868
1869         msg.type = MSG_EXIT;
1870         write_msg(cmd_fd, &msg, 1);
1871         exit(KSFT_PASS);
1872 }
1873
1874 static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
1875                 struct test_desc *msg, int xfrm_sock, uint32_t *seq)
1876 {
1877         struct in_addr src, dst, tunsrc, tundst;
1878         bool tun_reply;
1879         struct xfrm_desc *desc = &msg->body.xfrm_desc;
1880
1881         src = inet_makeaddr(INADDR_B, grchild_ip(nr));
1882         dst = inet_makeaddr(INADDR_B, child_ip(nr));
1883         tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
1884         tundst = inet_makeaddr(INADDR_A, child_ip(nr));
1885
1886         switch (msg->type) {
1887         case MSG_EXIT:
1888                 exit(KSFT_PASS);
1889         case MSG_ACK:
1890                 write_msg(cmd_fd, msg, 1);
1891                 break;
1892         case MSG_PING:
1893                 tun_reply = memcmp(&dst, &msg->body.ping.reply_ip, sizeof(in_addr_t));
1894                 /* UDP pinging without xfrm */
1895                 if (do_ping(cmd_fd, buf, page_size, tun_reply ? tunsrc : src,
1896                                 false, msg->body.ping.port,
1897                                 msg->body.ping.reply_ip, udp_ping_reply)) {
1898                         printk("ping failed before setting xfrm");
1899                 }
1900                 break;
1901         case MSG_XFRM_PREPARE:
1902                 if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
1903                                         desc->proto)) {
1904                         xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1905                         printk("failed to prepare xfrm");
1906                 }
1907                 break;
1908         case MSG_XFRM_ADD:
1909                 if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
1910                         xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1911                         printk("failed to set xfrm");
1912                 }
1913                 break;
1914         case MSG_XFRM_DEL:
1915                 if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
1916                                         desc->proto)) {
1917                         xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
1918                         printk("failed to remove xfrm");
1919                 }
1920                 break;
1921         case MSG_XFRM_CLEANUP:
1922                 if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
1923                         printk("failed to cleanup xfrm");
1924                 }
1925                 break;
1926         default:
1927                 printk("got unknown msg type %d", msg->type);
1928         }
1929 }
1930
1931 static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
1932 {
1933         struct test_desc msg;
1934         int xfrm_sock = -1;
1935         uint32_t seq;
1936
1937         if (switch_ns(nsfd_childb))
1938                 exit(KSFT_FAIL);
1939
1940         if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
1941                 printk("Failed to open xfrm netlink socket");
1942                 exit(KSFT_FAIL);
1943         }
1944
1945         do {
1946                 read_msg(cmd_fd, &msg, 1);
1947                 grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
1948         } while (1);
1949
1950         close(xfrm_sock);
1951         exit(KSFT_FAIL);
1952 }
1953
1954 static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
1955 {
1956         int cmd_sock[2];
1957         void *data_map;
1958         pid_t child;
1959
1960         if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
1961                 return -1;
1962
1963         if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
1964                 return -1;
1965
1966         child = fork();
1967         if (child < 0) {
1968                 pr_err("fork()");
1969                 return -1;
1970         } else if (child) {
1971                 /* in parent - selftest */
1972                 return switch_ns(nsfd_parent);
1973         }
1974
1975         if (close(test_desc_fd[1])) {
1976                 pr_err("close()");
1977                 return -1;
1978         }
1979
1980         /* child */
1981         data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
1982                         MAP_SHARED | MAP_ANONYMOUS, -1, 0);
1983         if (data_map == MAP_FAILED) {
1984                 pr_err("mmap()");
1985                 return -1;
1986         }
1987
1988         randomize_buffer(data_map, page_size);
1989
1990         if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
1991                 pr_err("socketpair()");
1992                 return -1;
1993         }
1994
1995         child = fork();
1996         if (child < 0) {
1997                 pr_err("fork()");
1998                 return -1;
1999         } else if (child) {
2000                 if (close(cmd_sock[0])) {
2001                         pr_err("close()");
2002                         return -1;
2003                 }
2004                 return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
2005         }
2006         if (close(cmd_sock[1])) {
2007                 pr_err("close()");
2008                 return -1;
2009         }
2010         return grand_child_f(nr, cmd_sock[0], data_map);
2011 }
2012
2013 static void exit_usage(char **argv)
2014 {
2015         printk("Usage: %s [nr_process]", argv[0]);
2016         exit(KSFT_FAIL);
2017 }
2018
2019 static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
2020 {
2021         ssize_t ret;
2022
2023         ret = write(test_desc_fd, desc, sizeof(*desc));
2024
2025         if (ret == sizeof(*desc))
2026                 return 0;
2027
2028         pr_err("Writing test's desc failed %ld", ret);
2029
2030         return -1;
2031 }
2032
2033 static int write_desc(int proto, int test_desc_fd,
2034                 char *a, char *e, char *c, char *ae)
2035 {
2036         struct xfrm_desc desc = {};
2037
2038         desc.type = CREATE_TUNNEL;
2039         desc.proto = proto;
2040
2041         if (a)
2042                 strncpy(desc.a_algo, a, ALGO_LEN - 1);
2043         if (e)
2044                 strncpy(desc.e_algo, e, ALGO_LEN - 1);
2045         if (c)
2046                 strncpy(desc.c_algo, c, ALGO_LEN - 1);
2047         if (ae)
2048                 strncpy(desc.ae_algo, ae, ALGO_LEN - 1);
2049
2050         return __write_desc(test_desc_fd, &desc);
2051 }
2052
2053 int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
2054 char *ah_list[] = {
2055         "digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
2056         "hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
2057         "xcbc(aes)", "cmac(aes)"
2058 };
2059 char *comp_list[] = {
2060         "deflate",
2061 #if 0
2062         /* No compression backend realization */
2063         "lzs", "lzjh"
2064 #endif
2065 };
2066 char *e_list[] = {
2067         "ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
2068         "cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
2069         "cbc(twofish)", "rfc3686(ctr(aes))"
2070 };
2071 char *ae_list[] = {
2072 #if 0
2073         /* not implemented */
2074         "rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
2075         "rfc7539esp(chacha20,poly1305)"
2076 #endif
2077 };
2078
2079 const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
2080                                 + (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
2081                                 + ARRAY_SIZE(ae_list);
2082
2083 static int write_proto_plan(int fd, int proto)
2084 {
2085         unsigned int i;
2086
2087         switch (proto) {
2088         case IPPROTO_AH:
2089                 for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2090                         if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
2091                                 return -1;
2092                 }
2093                 break;
2094         case IPPROTO_COMP:
2095                 for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
2096                         if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
2097                                 return -1;
2098                 }
2099                 break;
2100         case IPPROTO_ESP:
2101                 for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
2102                         int j;
2103
2104                         for (j = 0; j < ARRAY_SIZE(e_list); j++) {
2105                                 if (write_desc(proto, fd, ah_list[i],
2106                                                         e_list[j], 0, 0))
2107                                         return -1;
2108                         }
2109                 }
2110                 for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
2111                         if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
2112                                 return -1;
2113                 }
2114                 break;
2115         default:
2116                 printk("BUG: Specified unknown proto %d", proto);
2117                 return -1;
2118         }
2119
2120         return 0;
2121 }
2122
2123 /*
2124  * Some structures in xfrm uapi header differ in size between
2125  * 64-bit and 32-bit ABI:
2126  *
2127  *             32-bit UABI               |            64-bit UABI
2128  *  -------------------------------------|-------------------------------------
2129  *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
2130  *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
2131  *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
2132  *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
2133  *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
2134  *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
2135  *
2136  * Check the affected by the UABI difference structures.
2137  * Also, check translation for xfrm_set_spdinfo: it has it's own attributes
2138  * which needs to be correctly copied, but not translated.
2139  */
2140 const unsigned int compat_plan = 5;
2141 static int write_compat_struct_tests(int test_desc_fd)
2142 {
2143         struct xfrm_desc desc = {};
2144
2145         desc.type = ALLOCATE_SPI;
2146         desc.proto = IPPROTO_AH;
2147         strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);
2148
2149         if (__write_desc(test_desc_fd, &desc))
2150                 return -1;
2151
2152         desc.type = MONITOR_ACQUIRE;
2153         if (__write_desc(test_desc_fd, &desc))
2154                 return -1;
2155
2156         desc.type = EXPIRE_STATE;
2157         if (__write_desc(test_desc_fd, &desc))
2158                 return -1;
2159
2160         desc.type = EXPIRE_POLICY;
2161         if (__write_desc(test_desc_fd, &desc))
2162                 return -1;
2163
2164         desc.type = SPDINFO_ATTRS;
2165         if (__write_desc(test_desc_fd, &desc))
2166                 return -1;
2167
2168         return 0;
2169 }
2170
2171 static int write_test_plan(int test_desc_fd)
2172 {
2173         unsigned int i;
2174         pid_t child;
2175
2176         child = fork();
2177         if (child < 0) {
2178                 pr_err("fork()");
2179                 return -1;
2180         }
2181         if (child) {
2182                 if (close(test_desc_fd))
2183                         printk("close(): %m");
2184                 return 0;
2185         }
2186
2187         if (write_compat_struct_tests(test_desc_fd))
2188                 exit(KSFT_FAIL);
2189
2190         for (i = 0; i < ARRAY_SIZE(proto_list); i++) {
2191                 if (write_proto_plan(test_desc_fd, proto_list[i]))
2192                         exit(KSFT_FAIL);
2193         }
2194
2195         exit(KSFT_PASS);
2196 }
2197
2198 static int children_cleanup(void)
2199 {
2200         unsigned ret = KSFT_PASS;
2201
2202         while (1) {
2203                 int status;
2204                 pid_t p = wait(&status);
2205
2206                 if ((p < 0) && errno == ECHILD)
2207                         break;
2208
2209                 if (p < 0) {
2210                         pr_err("wait()");
2211                         return KSFT_FAIL;
2212                 }
2213
2214                 if (!WIFEXITED(status)) {
2215                         ret = KSFT_FAIL;
2216                         continue;
2217                 }
2218
2219                 if (WEXITSTATUS(status) == KSFT_FAIL)
2220                         ret = KSFT_FAIL;
2221         }
2222
2223         return ret;
2224 }
2225
2226 typedef void (*print_res)(const char *, ...);
2227
2228 static int check_results(void)
2229 {
2230         struct test_result tr = {};
2231         struct xfrm_desc *d = &tr.desc;
2232         int ret = KSFT_PASS;
2233
2234         while (1) {
2235                 ssize_t received = read(results_fd[0], &tr, sizeof(tr));
2236                 print_res result;
2237
2238                 if (received == 0) /* EOF */
2239                         break;
2240
2241                 if (received != sizeof(tr)) {
2242                         pr_err("read() returned %zd", received);
2243                         return KSFT_FAIL;
2244                 }
2245
2246                 switch (tr.res) {
2247                 case KSFT_PASS:
2248                         result = ksft_test_result_pass;
2249                         break;
2250                 case KSFT_FAIL:
2251                 default:
2252                         result = ksft_test_result_fail;
2253                         ret = KSFT_FAIL;
2254                 }
2255
2256                 result(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
2257                        desc_name[d->type], (unsigned int)d->proto, d->a_algo,
2258                        d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
2259         }
2260
2261         return ret;
2262 }
2263
2264 int main(int argc, char **argv)
2265 {
2266         long nr_process = 1;
2267         int route_sock = -1, ret = KSFT_SKIP;
2268         int test_desc_fd[2];
2269         uint32_t route_seq;
2270         unsigned int i;
2271
2272         if (argc > 2)
2273                 exit_usage(argv);
2274
2275         if (argc > 1) {
2276                 char *endptr;
2277
2278                 errno = 0;
2279                 nr_process = strtol(argv[1], &endptr, 10);
2280                 if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
2281                                 || (errno != 0 && nr_process == 0)
2282                                 || (endptr == argv[1]) || (*endptr != '\0')) {
2283                         printk("Failed to parse [nr_process]");
2284                         exit_usage(argv);
2285                 }
2286
2287                 if (nr_process > MAX_PROCESSES || nr_process < 1) {
2288                         printk("nr_process should be between [1; %u]",
2289                                         MAX_PROCESSES);
2290                         exit_usage(argv);
2291                 }
2292         }
2293
2294         srand(time(NULL));
2295         page_size = sysconf(_SC_PAGESIZE);
2296         if (page_size < 1)
2297                 ksft_exit_skip("sysconf(): %m\n");
2298
2299         if (pipe2(test_desc_fd, O_DIRECT) < 0)
2300                 ksft_exit_skip("pipe(): %m\n");
2301
2302         if (pipe2(results_fd, O_DIRECT) < 0)
2303                 ksft_exit_skip("pipe(): %m\n");
2304
2305         if (init_namespaces())
2306                 ksft_exit_skip("Failed to create namespaces\n");
2307
2308         if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
2309                 ksft_exit_skip("Failed to open netlink route socket\n");
2310
2311         for (i = 0; i < nr_process; i++) {
2312                 char veth[VETH_LEN];
2313
2314                 snprintf(veth, VETH_LEN, VETH_FMT, i);
2315
2316                 if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
2317                         close(route_sock);
2318                         ksft_exit_fail_msg("Failed to create veth device");
2319                 }
2320
2321                 if (start_child(i, veth, test_desc_fd)) {
2322                         close(route_sock);
2323                         ksft_exit_fail_msg("Child %u failed to start", i);
2324                 }
2325         }
2326
2327         if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
2328                 ksft_exit_fail_msg("close(): %m");
2329
2330         ksft_set_plan(proto_plan + compat_plan);
2331
2332         if (write_test_plan(test_desc_fd[1]))
2333                 ksft_exit_fail_msg("Failed to write test plan to pipe");
2334
2335         ret = check_results();
2336
2337         if (children_cleanup() == KSFT_FAIL)
2338                 exit(KSFT_FAIL);
2339
2340         exit(ret);
2341 }