selftests: mptcp: sockopt: relax expected returned size
[platform/kernel/linux-starfive.git] / tools / testing / selftests / net / mptcp / mptcp_sockopt.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <time.h>
18 #include <unistd.h>
19
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <sys/wait.h>
23
24 #include <netdb.h>
25 #include <netinet/in.h>
26
27 #include <linux/tcp.h>
28
29 static int pf = AF_INET;
30
31 #ifndef IPPROTO_MPTCP
32 #define IPPROTO_MPTCP 262
33 #endif
34 #ifndef SOL_MPTCP
35 #define SOL_MPTCP 284
36 #endif
37
38 #ifndef MPTCP_INFO
39 struct mptcp_info {
40         __u8    mptcpi_subflows;
41         __u8    mptcpi_add_addr_signal;
42         __u8    mptcpi_add_addr_accepted;
43         __u8    mptcpi_subflows_max;
44         __u8    mptcpi_add_addr_signal_max;
45         __u8    mptcpi_add_addr_accepted_max;
46         __u32   mptcpi_flags;
47         __u32   mptcpi_token;
48         __u64   mptcpi_write_seq;
49         __u64   mptcpi_snd_una;
50         __u64   mptcpi_rcv_nxt;
51         __u8    mptcpi_local_addr_used;
52         __u8    mptcpi_local_addr_max;
53         __u8    mptcpi_csum_enabled;
54 };
55
56 struct mptcp_subflow_data {
57         __u32           size_subflow_data;              /* size of this structure in userspace */
58         __u32           num_subflows;                   /* must be 0, set by kernel */
59         __u32           size_kernel;                    /* must be 0, set by kernel */
60         __u32           size_user;                      /* size of one element in data[] */
61 } __attribute__((aligned(8)));
62
63 struct mptcp_subflow_addrs {
64         union {
65                 __kernel_sa_family_t sa_family;
66                 struct sockaddr sa_local;
67                 struct sockaddr_in sin_local;
68                 struct sockaddr_in6 sin6_local;
69                 struct __kernel_sockaddr_storage ss_local;
70         };
71         union {
72                 struct sockaddr sa_remote;
73                 struct sockaddr_in sin_remote;
74                 struct sockaddr_in6 sin6_remote;
75                 struct __kernel_sockaddr_storage ss_remote;
76         };
77 };
78
79 #define MPTCP_INFO              1
80 #define MPTCP_TCPINFO           2
81 #define MPTCP_SUBFLOW_ADDRS     3
82 #endif
83
84 struct so_state {
85         struct mptcp_info mi;
86         uint64_t mptcpi_rcv_delta;
87         uint64_t tcpi_rcv_delta;
88 };
89
90 #ifndef MIN
91 #define MIN(a, b) ((a) < (b) ? (a) : (b))
92 #endif
93
94 static void die_perror(const char *msg)
95 {
96         perror(msg);
97         exit(1);
98 }
99
100 static void die_usage(int r)
101 {
102         fprintf(stderr, "Usage: mptcp_sockopt [-6]\n");
103         exit(r);
104 }
105
106 static void xerror(const char *fmt, ...)
107 {
108         va_list ap;
109
110         va_start(ap, fmt);
111         vfprintf(stderr, fmt, ap);
112         va_end(ap);
113         fputc('\n', stderr);
114         exit(1);
115 }
116
117 static const char *getxinfo_strerr(int err)
118 {
119         if (err == EAI_SYSTEM)
120                 return strerror(errno);
121
122         return gai_strerror(err);
123 }
124
125 static void xgetaddrinfo(const char *node, const char *service,
126                          const struct addrinfo *hints,
127                          struct addrinfo **res)
128 {
129         int err = getaddrinfo(node, service, hints, res);
130
131         if (err) {
132                 const char *errstr = getxinfo_strerr(err);
133
134                 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
135                         node ? node : "", service ? service : "", errstr);
136                 exit(1);
137         }
138 }
139
140 static int sock_listen_mptcp(const char * const listenaddr,
141                              const char * const port)
142 {
143         int sock = -1;
144         struct addrinfo hints = {
145                 .ai_protocol = IPPROTO_TCP,
146                 .ai_socktype = SOCK_STREAM,
147                 .ai_flags = AI_PASSIVE | AI_NUMERICHOST
148         };
149
150         hints.ai_family = pf;
151
152         struct addrinfo *a, *addr;
153         int one = 1;
154
155         xgetaddrinfo(listenaddr, port, &hints, &addr);
156         hints.ai_family = pf;
157
158         for (a = addr; a; a = a->ai_next) {
159                 sock = socket(a->ai_family, a->ai_socktype, IPPROTO_MPTCP);
160                 if (sock < 0)
161                         continue;
162
163                 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
164                                      sizeof(one)))
165                         perror("setsockopt");
166
167                 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
168                         break; /* success */
169
170                 perror("bind");
171                 close(sock);
172                 sock = -1;
173         }
174
175         freeaddrinfo(addr);
176
177         if (sock < 0)
178                 xerror("could not create listen socket");
179
180         if (listen(sock, 20))
181                 die_perror("listen");
182
183         return sock;
184 }
185
186 static int sock_connect_mptcp(const char * const remoteaddr,
187                               const char * const port, int proto)
188 {
189         struct addrinfo hints = {
190                 .ai_protocol = IPPROTO_TCP,
191                 .ai_socktype = SOCK_STREAM,
192         };
193         struct addrinfo *a, *addr;
194         int sock = -1;
195
196         hints.ai_family = pf;
197
198         xgetaddrinfo(remoteaddr, port, &hints, &addr);
199         for (a = addr; a; a = a->ai_next) {
200                 sock = socket(a->ai_family, a->ai_socktype, proto);
201                 if (sock < 0)
202                         continue;
203
204                 if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
205                         break; /* success */
206
207                 die_perror("connect");
208         }
209
210         if (sock < 0)
211                 xerror("could not create connect socket");
212
213         freeaddrinfo(addr);
214         return sock;
215 }
216
217 static void parse_opts(int argc, char **argv)
218 {
219         int c;
220
221         while ((c = getopt(argc, argv, "h6")) != -1) {
222                 switch (c) {
223                 case 'h':
224                         die_usage(0);
225                         break;
226                 case '6':
227                         pf = AF_INET6;
228                         break;
229                 default:
230                         die_usage(1);
231                         break;
232                 }
233         }
234 }
235
236 static void do_getsockopt_bogus_sf_data(int fd, int optname)
237 {
238         struct mptcp_subflow_data good_data;
239         struct bogus_data {
240                 struct mptcp_subflow_data d;
241                 char buf[2];
242         } bd;
243         socklen_t olen, _olen;
244         int ret;
245
246         memset(&bd, 0, sizeof(bd));
247         memset(&good_data, 0, sizeof(good_data));
248
249         olen = sizeof(good_data);
250         good_data.size_subflow_data = olen;
251
252         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
253         assert(ret < 0); /* 0 size_subflow_data */
254         assert(olen == sizeof(good_data));
255
256         bd.d = good_data;
257
258         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
259         assert(ret == 0);
260         assert(olen == sizeof(good_data));
261         assert(bd.d.num_subflows == 1);
262         assert(bd.d.size_kernel > 0);
263         assert(bd.d.size_user == 0);
264
265         bd.d = good_data;
266         _olen = rand() % olen;
267         olen = _olen;
268         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
269         assert(ret < 0);        /* bogus olen */
270         assert(olen == _olen);  /* must be unchanged */
271
272         bd.d = good_data;
273         olen = sizeof(good_data);
274         bd.d.size_kernel = 1;
275         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
276         assert(ret < 0); /* size_kernel not 0 */
277
278         bd.d = good_data;
279         olen = sizeof(good_data);
280         bd.d.num_subflows = 1;
281         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
282         assert(ret < 0); /* num_subflows not 0 */
283
284         /* forward compat check: larger struct mptcp_subflow_data on 'old' kernel */
285         bd.d = good_data;
286         olen = sizeof(bd);
287         bd.d.size_subflow_data = sizeof(bd);
288
289         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
290         assert(ret == 0);
291
292         /* olen must be truncated to real data size filled by kernel: */
293         assert(olen == sizeof(good_data));
294
295         assert(bd.d.size_subflow_data == sizeof(bd));
296
297         bd.d = good_data;
298         bd.d.size_subflow_data += 1;
299         bd.d.size_user = 1;
300         olen = bd.d.size_subflow_data + 1;
301         _olen = olen;
302
303         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &_olen);
304         assert(ret == 0);
305
306         /* no truncation, kernel should have filled 1 byte of optname payload in buf[1]: */
307         assert(olen == _olen);
308
309         assert(bd.d.size_subflow_data == sizeof(good_data) + 1);
310         assert(bd.buf[0] == 0);
311 }
312
313 static void do_getsockopt_mptcp_info(struct so_state *s, int fd, size_t w)
314 {
315         struct mptcp_info i;
316         socklen_t olen;
317         int ret;
318
319         olen = sizeof(i);
320         ret = getsockopt(fd, SOL_MPTCP, MPTCP_INFO, &i, &olen);
321
322         if (ret < 0)
323                 die_perror("getsockopt MPTCP_INFO");
324
325         assert(olen == sizeof(i));
326
327         if (s->mi.mptcpi_write_seq == 0)
328                 s->mi = i;
329
330         assert(s->mi.mptcpi_write_seq + w == i.mptcpi_write_seq);
331
332         s->mptcpi_rcv_delta = i.mptcpi_rcv_nxt - s->mi.mptcpi_rcv_nxt;
333 }
334
335 static void do_getsockopt_tcp_info(struct so_state *s, int fd, size_t r, size_t w)
336 {
337         struct my_tcp_info {
338                 struct mptcp_subflow_data d;
339                 struct tcp_info ti[2];
340         } ti;
341         int ret, tries = 5;
342         socklen_t olen;
343
344         do {
345                 memset(&ti, 0, sizeof(ti));
346
347                 ti.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
348                 ti.d.size_user = sizeof(struct tcp_info);
349                 olen = sizeof(ti);
350
351                 ret = getsockopt(fd, SOL_MPTCP, MPTCP_TCPINFO, &ti, &olen);
352                 if (ret < 0)
353                         xerror("getsockopt MPTCP_TCPINFO (tries %d, %m)");
354
355                 assert(olen <= sizeof(ti));
356                 assert(ti.d.size_kernel > 0);
357                 assert(ti.d.size_user ==
358                        MIN(ti.d.size_kernel, sizeof(struct tcp_info)));
359                 assert(ti.d.num_subflows == 1);
360
361                 assert(olen > (socklen_t)sizeof(struct mptcp_subflow_data));
362                 olen -= sizeof(struct mptcp_subflow_data);
363                 assert(olen == ti.d.size_user);
364
365                 if (ti.ti[0].tcpi_bytes_sent == w &&
366                     ti.ti[0].tcpi_bytes_received == r)
367                         goto done;
368
369                 if (r == 0 && ti.ti[0].tcpi_bytes_sent == w &&
370                     ti.ti[0].tcpi_bytes_received) {
371                         s->tcpi_rcv_delta = ti.ti[0].tcpi_bytes_received;
372                         goto done;
373                 }
374
375                 /* wait and repeat, might be that tx is still ongoing */
376                 sleep(1);
377         } while (tries-- > 0);
378
379         xerror("tcpi_bytes_sent %" PRIu64 ", want %zu. tcpi_bytes_received %" PRIu64 ", want %zu",
380                 ti.ti[0].tcpi_bytes_sent, w, ti.ti[0].tcpi_bytes_received, r);
381
382 done:
383         do_getsockopt_bogus_sf_data(fd, MPTCP_TCPINFO);
384 }
385
386 static void do_getsockopt_subflow_addrs(int fd)
387 {
388         struct sockaddr_storage remote, local;
389         socklen_t olen, rlen, llen;
390         int ret;
391         struct my_addrs {
392                 struct mptcp_subflow_data d;
393                 struct mptcp_subflow_addrs addr[2];
394         } addrs;
395
396         memset(&addrs, 0, sizeof(addrs));
397         memset(&local, 0, sizeof(local));
398         memset(&remote, 0, sizeof(remote));
399
400         addrs.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
401         addrs.d.size_user = sizeof(struct mptcp_subflow_addrs);
402         olen = sizeof(addrs);
403
404         ret = getsockopt(fd, SOL_MPTCP, MPTCP_SUBFLOW_ADDRS, &addrs, &olen);
405         if (ret < 0)
406                 die_perror("getsockopt MPTCP_SUBFLOW_ADDRS");
407
408         assert(olen <= sizeof(addrs));
409         assert(addrs.d.size_kernel > 0);
410         assert(addrs.d.size_user ==
411                MIN(addrs.d.size_kernel, sizeof(struct mptcp_subflow_addrs)));
412         assert(addrs.d.num_subflows == 1);
413
414         assert(olen > (socklen_t)sizeof(struct mptcp_subflow_data));
415         olen -= sizeof(struct mptcp_subflow_data);
416         assert(olen == addrs.d.size_user);
417
418         llen = sizeof(local);
419         ret = getsockname(fd, (struct sockaddr *)&local, &llen);
420         if (ret < 0)
421                 die_perror("getsockname");
422         rlen = sizeof(remote);
423         ret = getpeername(fd, (struct sockaddr *)&remote, &rlen);
424         if (ret < 0)
425                 die_perror("getpeername");
426
427         assert(rlen > 0);
428         assert(rlen == llen);
429
430         assert(remote.ss_family == local.ss_family);
431
432         assert(memcmp(&local, &addrs.addr[0].ss_local, sizeof(local)) == 0);
433         assert(memcmp(&remote, &addrs.addr[0].ss_remote, sizeof(remote)) == 0);
434
435         memset(&addrs, 0, sizeof(addrs));
436
437         addrs.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
438         addrs.d.size_user = sizeof(sa_family_t);
439         olen = sizeof(addrs.d) + sizeof(sa_family_t);
440
441         ret = getsockopt(fd, SOL_MPTCP, MPTCP_SUBFLOW_ADDRS, &addrs, &olen);
442         assert(ret == 0);
443         assert(olen == sizeof(addrs.d) + sizeof(sa_family_t));
444
445         assert(addrs.addr[0].sa_family == pf);
446         assert(addrs.addr[0].sa_family == local.ss_family);
447
448         assert(memcmp(&local, &addrs.addr[0].ss_local, sizeof(local)) != 0);
449         assert(memcmp(&remote, &addrs.addr[0].ss_remote, sizeof(remote)) != 0);
450
451         do_getsockopt_bogus_sf_data(fd, MPTCP_SUBFLOW_ADDRS);
452 }
453
454 static void do_getsockopts(struct so_state *s, int fd, size_t r, size_t w)
455 {
456         do_getsockopt_mptcp_info(s, fd, w);
457
458         do_getsockopt_tcp_info(s, fd, r, w);
459
460         do_getsockopt_subflow_addrs(fd);
461 }
462
463 static void connect_one_server(int fd, int pipefd)
464 {
465         char buf[4096], buf2[4096];
466         size_t len, i, total;
467         struct so_state s;
468         bool eof = false;
469         ssize_t ret;
470
471         memset(&s, 0, sizeof(s));
472
473         len = rand() % (sizeof(buf) - 1);
474
475         if (len < 128)
476                 len = 128;
477
478         for (i = 0; i < len ; i++) {
479                 buf[i] = rand() % 26;
480                 buf[i] += 'A';
481         }
482
483         buf[i] = '\n';
484
485         do_getsockopts(&s, fd, 0, 0);
486
487         /* un-block server */
488         ret = read(pipefd, buf2, 4);
489         assert(ret == 4);
490         close(pipefd);
491
492         assert(strncmp(buf2, "xmit", 4) == 0);
493
494         ret = write(fd, buf, len);
495         if (ret < 0)
496                 die_perror("write");
497
498         if (ret != (ssize_t)len)
499                 xerror("short write");
500
501         total = 0;
502         do {
503                 ret = read(fd, buf2 + total, sizeof(buf2) - total);
504                 if (ret < 0)
505                         die_perror("read");
506                 if (ret == 0) {
507                         eof = true;
508                         break;
509                 }
510
511                 total += ret;
512         } while (total < len);
513
514         if (total != len)
515                 xerror("total %lu, len %lu eof %d\n", total, len, eof);
516
517         if (memcmp(buf, buf2, len))
518                 xerror("data corruption");
519
520         if (s.tcpi_rcv_delta)
521                 assert(s.tcpi_rcv_delta <= total);
522
523         do_getsockopts(&s, fd, ret, ret);
524
525         if (eof)
526                 total += 1; /* sequence advances due to FIN */
527
528         assert(s.mptcpi_rcv_delta == (uint64_t)total);
529         close(fd);
530 }
531
532 static void process_one_client(int fd, int pipefd)
533 {
534         ssize_t ret, ret2, ret3;
535         struct so_state s;
536         char buf[4096];
537
538         memset(&s, 0, sizeof(s));
539         do_getsockopts(&s, fd, 0, 0);
540
541         ret = write(pipefd, "xmit", 4);
542         assert(ret == 4);
543
544         ret = read(fd, buf, sizeof(buf));
545         if (ret < 0)
546                 die_perror("read");
547
548         assert(s.mptcpi_rcv_delta <= (uint64_t)ret);
549
550         if (s.tcpi_rcv_delta)
551                 assert(s.tcpi_rcv_delta == (uint64_t)ret);
552
553         ret2 = write(fd, buf, ret);
554         if (ret2 < 0)
555                 die_perror("write");
556
557         /* wait for hangup */
558         ret3 = read(fd, buf, 1);
559         if (ret3 != 0)
560                 xerror("expected EOF, got %lu", ret3);
561
562         do_getsockopts(&s, fd, ret, ret2);
563         if (s.mptcpi_rcv_delta != (uint64_t)ret + 1)
564                 xerror("mptcpi_rcv_delta %" PRIu64 ", expect %" PRIu64, s.mptcpi_rcv_delta, ret + 1, s.mptcpi_rcv_delta - ret);
565         close(fd);
566 }
567
568 static int xaccept(int s)
569 {
570         int fd = accept(s, NULL, 0);
571
572         if (fd < 0)
573                 die_perror("accept");
574
575         return fd;
576 }
577
578 static int server(int pipefd)
579 {
580         int fd = -1, r;
581
582         switch (pf) {
583         case AF_INET:
584                 fd = sock_listen_mptcp("127.0.0.1", "15432");
585                 break;
586         case AF_INET6:
587                 fd = sock_listen_mptcp("::1", "15432");
588                 break;
589         default:
590                 xerror("Unknown pf %d\n", pf);
591                 break;
592         }
593
594         r = write(pipefd, "conn", 4);
595         assert(r == 4);
596
597         alarm(15);
598         r = xaccept(fd);
599
600         process_one_client(r, pipefd);
601
602         return 0;
603 }
604
605 static void test_ip_tos_sockopt(int fd)
606 {
607         uint8_t tos_in, tos_out;
608         socklen_t s;
609         int r;
610
611         tos_in = rand() & 0xfc;
612         r = setsockopt(fd, SOL_IP, IP_TOS, &tos_in, sizeof(tos_out));
613         if (r != 0)
614                 die_perror("setsockopt IP_TOS");
615
616         tos_out = 0;
617         s = sizeof(tos_out);
618         r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
619         if (r != 0)
620                 die_perror("getsockopt IP_TOS");
621
622         if (tos_in != tos_out)
623                 xerror("tos %x != %x socklen_t %d\n", tos_in, tos_out, s);
624
625         if (s != 1)
626                 xerror("tos should be 1 byte");
627
628         s = 0;
629         r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
630         if (r != 0)
631                 die_perror("getsockopt IP_TOS 0");
632         if (s != 0)
633                 xerror("expect socklen_t == 0");
634
635         s = -1;
636         r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
637         if (r != -1 && errno != EINVAL)
638                 die_perror("getsockopt IP_TOS did not indicate -EINVAL");
639         if (s != -1)
640                 xerror("expect socklen_t == -1");
641 }
642
643 static int client(int pipefd)
644 {
645         int fd = -1;
646
647         alarm(15);
648
649         switch (pf) {
650         case AF_INET:
651                 fd = sock_connect_mptcp("127.0.0.1", "15432", IPPROTO_MPTCP);
652                 break;
653         case AF_INET6:
654                 fd = sock_connect_mptcp("::1", "15432", IPPROTO_MPTCP);
655                 break;
656         default:
657                 xerror("Unknown pf %d\n", pf);
658         }
659
660         test_ip_tos_sockopt(fd);
661
662         connect_one_server(fd, pipefd);
663
664         return 0;
665 }
666
667 static pid_t xfork(void)
668 {
669         pid_t p = fork();
670
671         if (p < 0)
672                 die_perror("fork");
673
674         return p;
675 }
676
677 static int rcheck(int wstatus, const char *what)
678 {
679         if (WIFEXITED(wstatus)) {
680                 if (WEXITSTATUS(wstatus) == 0)
681                         return 0;
682                 fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
683                 return WEXITSTATUS(wstatus);
684         } else if (WIFSIGNALED(wstatus)) {
685                 xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
686         } else if (WIFSTOPPED(wstatus)) {
687                 xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
688         }
689
690         return 111;
691 }
692
693 static void init_rng(void)
694 {
695         int fd = open("/dev/urandom", O_RDONLY);
696
697         if (fd >= 0) {
698                 unsigned int foo;
699                 ssize_t ret;
700
701                 /* can't fail */
702                 ret = read(fd, &foo, sizeof(foo));
703                 assert(ret == sizeof(foo));
704
705                 close(fd);
706                 srand(foo);
707         } else {
708                 srand(time(NULL));
709         }
710 }
711
712 int main(int argc, char *argv[])
713 {
714         int e1, e2, wstatus;
715         pid_t s, c, ret;
716         int pipefds[2];
717
718         parse_opts(argc, argv);
719
720         init_rng();
721
722         e1 = pipe(pipefds);
723         if (e1 < 0)
724                 die_perror("pipe");
725
726         s = xfork();
727         if (s == 0)
728                 return server(pipefds[1]);
729
730         close(pipefds[1]);
731
732         /* wait until server bound a socket */
733         e1 = read(pipefds[0], &e1, 4);
734         assert(e1 == 4);
735
736         c = xfork();
737         if (c == 0)
738                 return client(pipefds[0]);
739
740         close(pipefds[0]);
741
742         ret = waitpid(s, &wstatus, 0);
743         if (ret == -1)
744                 die_perror("waitpid");
745         e1 = rcheck(wstatus, "server");
746         ret = waitpid(c, &wstatus, 0);
747         if (ret == -1)
748                 die_perror("waitpid");
749         e2 = rcheck(wstatus, "client");
750
751         return e1 ? e1 : e2;
752 }