Merge tag 'gfs2-v5.19-rc4-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git...
[platform/kernel/linux-starfive.git] / tools / testing / selftests / net / mptcp / mptcp_sockopt.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <time.h>
18 #include <unistd.h>
19
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <sys/wait.h>
23
24 #include <netdb.h>
25 #include <netinet/in.h>
26
27 #include <linux/tcp.h>
28
29 static int pf = AF_INET;
30
31 #ifndef IPPROTO_MPTCP
32 #define IPPROTO_MPTCP 262
33 #endif
34 #ifndef SOL_MPTCP
35 #define SOL_MPTCP 284
36 #endif
37
38 #ifndef MPTCP_INFO
39 struct mptcp_info {
40         __u8    mptcpi_subflows;
41         __u8    mptcpi_add_addr_signal;
42         __u8    mptcpi_add_addr_accepted;
43         __u8    mptcpi_subflows_max;
44         __u8    mptcpi_add_addr_signal_max;
45         __u8    mptcpi_add_addr_accepted_max;
46         __u32   mptcpi_flags;
47         __u32   mptcpi_token;
48         __u64   mptcpi_write_seq;
49         __u64   mptcpi_snd_una;
50         __u64   mptcpi_rcv_nxt;
51         __u8    mptcpi_local_addr_used;
52         __u8    mptcpi_local_addr_max;
53         __u8    mptcpi_csum_enabled;
54 };
55
56 struct mptcp_subflow_data {
57         __u32           size_subflow_data;              /* size of this structure in userspace */
58         __u32           num_subflows;                   /* must be 0, set by kernel */
59         __u32           size_kernel;                    /* must be 0, set by kernel */
60         __u32           size_user;                      /* size of one element in data[] */
61 } __attribute__((aligned(8)));
62
63 struct mptcp_subflow_addrs {
64         union {
65                 __kernel_sa_family_t sa_family;
66                 struct sockaddr sa_local;
67                 struct sockaddr_in sin_local;
68                 struct sockaddr_in6 sin6_local;
69                 struct __kernel_sockaddr_storage ss_local;
70         };
71         union {
72                 struct sockaddr sa_remote;
73                 struct sockaddr_in sin_remote;
74                 struct sockaddr_in6 sin6_remote;
75                 struct __kernel_sockaddr_storage ss_remote;
76         };
77 };
78
79 #define MPTCP_INFO              1
80 #define MPTCP_TCPINFO           2
81 #define MPTCP_SUBFLOW_ADDRS     3
82 #endif
83
84 struct so_state {
85         struct mptcp_info mi;
86         uint64_t mptcpi_rcv_delta;
87         uint64_t tcpi_rcv_delta;
88 };
89
90 static void die_perror(const char *msg)
91 {
92         perror(msg);
93         exit(1);
94 }
95
96 static void die_usage(int r)
97 {
98         fprintf(stderr, "Usage: mptcp_sockopt [-6]\n");
99         exit(r);
100 }
101
102 static void xerror(const char *fmt, ...)
103 {
104         va_list ap;
105
106         va_start(ap, fmt);
107         vfprintf(stderr, fmt, ap);
108         va_end(ap);
109         fputc('\n', stderr);
110         exit(1);
111 }
112
113 static const char *getxinfo_strerr(int err)
114 {
115         if (err == EAI_SYSTEM)
116                 return strerror(errno);
117
118         return gai_strerror(err);
119 }
120
121 static void xgetaddrinfo(const char *node, const char *service,
122                          const struct addrinfo *hints,
123                          struct addrinfo **res)
124 {
125         int err = getaddrinfo(node, service, hints, res);
126
127         if (err) {
128                 const char *errstr = getxinfo_strerr(err);
129
130                 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
131                         node ? node : "", service ? service : "", errstr);
132                 exit(1);
133         }
134 }
135
136 static int sock_listen_mptcp(const char * const listenaddr,
137                              const char * const port)
138 {
139         int sock = -1;
140         struct addrinfo hints = {
141                 .ai_protocol = IPPROTO_TCP,
142                 .ai_socktype = SOCK_STREAM,
143                 .ai_flags = AI_PASSIVE | AI_NUMERICHOST
144         };
145
146         hints.ai_family = pf;
147
148         struct addrinfo *a, *addr;
149         int one = 1;
150
151         xgetaddrinfo(listenaddr, port, &hints, &addr);
152         hints.ai_family = pf;
153
154         for (a = addr; a; a = a->ai_next) {
155                 sock = socket(a->ai_family, a->ai_socktype, IPPROTO_MPTCP);
156                 if (sock < 0)
157                         continue;
158
159                 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
160                                      sizeof(one)))
161                         perror("setsockopt");
162
163                 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
164                         break; /* success */
165
166                 perror("bind");
167                 close(sock);
168                 sock = -1;
169         }
170
171         freeaddrinfo(addr);
172
173         if (sock < 0)
174                 xerror("could not create listen socket");
175
176         if (listen(sock, 20))
177                 die_perror("listen");
178
179         return sock;
180 }
181
182 static int sock_connect_mptcp(const char * const remoteaddr,
183                               const char * const port, int proto)
184 {
185         struct addrinfo hints = {
186                 .ai_protocol = IPPROTO_TCP,
187                 .ai_socktype = SOCK_STREAM,
188         };
189         struct addrinfo *a, *addr;
190         int sock = -1;
191
192         hints.ai_family = pf;
193
194         xgetaddrinfo(remoteaddr, port, &hints, &addr);
195         for (a = addr; a; a = a->ai_next) {
196                 sock = socket(a->ai_family, a->ai_socktype, proto);
197                 if (sock < 0)
198                         continue;
199
200                 if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
201                         break; /* success */
202
203                 die_perror("connect");
204         }
205
206         if (sock < 0)
207                 xerror("could not create connect socket");
208
209         freeaddrinfo(addr);
210         return sock;
211 }
212
213 static void parse_opts(int argc, char **argv)
214 {
215         int c;
216
217         while ((c = getopt(argc, argv, "h6")) != -1) {
218                 switch (c) {
219                 case 'h':
220                         die_usage(0);
221                         break;
222                 case '6':
223                         pf = AF_INET6;
224                         break;
225                 default:
226                         die_usage(1);
227                         break;
228                 }
229         }
230 }
231
232 static void do_getsockopt_bogus_sf_data(int fd, int optname)
233 {
234         struct mptcp_subflow_data good_data;
235         struct bogus_data {
236                 struct mptcp_subflow_data d;
237                 char buf[2];
238         } bd;
239         socklen_t olen, _olen;
240         int ret;
241
242         memset(&bd, 0, sizeof(bd));
243         memset(&good_data, 0, sizeof(good_data));
244
245         olen = sizeof(good_data);
246         good_data.size_subflow_data = olen;
247
248         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
249         assert(ret < 0); /* 0 size_subflow_data */
250         assert(olen == sizeof(good_data));
251
252         bd.d = good_data;
253
254         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
255         assert(ret == 0);
256         assert(olen == sizeof(good_data));
257         assert(bd.d.num_subflows == 1);
258         assert(bd.d.size_kernel > 0);
259         assert(bd.d.size_user == 0);
260
261         bd.d = good_data;
262         _olen = rand() % olen;
263         olen = _olen;
264         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
265         assert(ret < 0);        /* bogus olen */
266         assert(olen == _olen);  /* must be unchanged */
267
268         bd.d = good_data;
269         olen = sizeof(good_data);
270         bd.d.size_kernel = 1;
271         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
272         assert(ret < 0); /* size_kernel not 0 */
273
274         bd.d = good_data;
275         olen = sizeof(good_data);
276         bd.d.num_subflows = 1;
277         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
278         assert(ret < 0); /* num_subflows not 0 */
279
280         /* forward compat check: larger struct mptcp_subflow_data on 'old' kernel */
281         bd.d = good_data;
282         olen = sizeof(bd);
283         bd.d.size_subflow_data = sizeof(bd);
284
285         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
286         assert(ret == 0);
287
288         /* olen must be truncated to real data size filled by kernel: */
289         assert(olen == sizeof(good_data));
290
291         assert(bd.d.size_subflow_data == sizeof(bd));
292
293         bd.d = good_data;
294         bd.d.size_subflow_data += 1;
295         bd.d.size_user = 1;
296         olen = bd.d.size_subflow_data + 1;
297         _olen = olen;
298
299         ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &_olen);
300         assert(ret == 0);
301
302         /* no truncation, kernel should have filled 1 byte of optname payload in buf[1]: */
303         assert(olen == _olen);
304
305         assert(bd.d.size_subflow_data == sizeof(good_data) + 1);
306         assert(bd.buf[0] == 0);
307 }
308
309 static void do_getsockopt_mptcp_info(struct so_state *s, int fd, size_t w)
310 {
311         struct mptcp_info i;
312         socklen_t olen;
313         int ret;
314
315         olen = sizeof(i);
316         ret = getsockopt(fd, SOL_MPTCP, MPTCP_INFO, &i, &olen);
317
318         if (ret < 0)
319                 die_perror("getsockopt MPTCP_INFO");
320
321         assert(olen == sizeof(i));
322
323         if (s->mi.mptcpi_write_seq == 0)
324                 s->mi = i;
325
326         assert(s->mi.mptcpi_write_seq + w == i.mptcpi_write_seq);
327
328         s->mptcpi_rcv_delta = i.mptcpi_rcv_nxt - s->mi.mptcpi_rcv_nxt;
329 }
330
331 static void do_getsockopt_tcp_info(struct so_state *s, int fd, size_t r, size_t w)
332 {
333         struct my_tcp_info {
334                 struct mptcp_subflow_data d;
335                 struct tcp_info ti[2];
336         } ti;
337         int ret, tries = 5;
338         socklen_t olen;
339
340         do {
341                 memset(&ti, 0, sizeof(ti));
342
343                 ti.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
344                 ti.d.size_user = sizeof(struct tcp_info);
345                 olen = sizeof(ti);
346
347                 ret = getsockopt(fd, SOL_MPTCP, MPTCP_TCPINFO, &ti, &olen);
348                 if (ret < 0)
349                         xerror("getsockopt MPTCP_TCPINFO (tries %d, %m)");
350
351                 assert(olen <= sizeof(ti));
352                 assert(ti.d.size_user == ti.d.size_kernel);
353                 assert(ti.d.size_user == sizeof(struct tcp_info));
354                 assert(ti.d.num_subflows == 1);
355
356                 assert(olen > (socklen_t)sizeof(struct mptcp_subflow_data));
357                 olen -= sizeof(struct mptcp_subflow_data);
358                 assert(olen == sizeof(struct tcp_info));
359
360                 if (ti.ti[0].tcpi_bytes_sent == w &&
361                     ti.ti[0].tcpi_bytes_received == r)
362                         goto done;
363
364                 if (r == 0 && ti.ti[0].tcpi_bytes_sent == w &&
365                     ti.ti[0].tcpi_bytes_received) {
366                         s->tcpi_rcv_delta = ti.ti[0].tcpi_bytes_received;
367                         goto done;
368                 }
369
370                 /* wait and repeat, might be that tx is still ongoing */
371                 sleep(1);
372         } while (tries-- > 0);
373
374         xerror("tcpi_bytes_sent %" PRIu64 ", want %zu. tcpi_bytes_received %" PRIu64 ", want %zu",
375                 ti.ti[0].tcpi_bytes_sent, w, ti.ti[0].tcpi_bytes_received, r);
376
377 done:
378         do_getsockopt_bogus_sf_data(fd, MPTCP_TCPINFO);
379 }
380
381 static void do_getsockopt_subflow_addrs(int fd)
382 {
383         struct sockaddr_storage remote, local;
384         socklen_t olen, rlen, llen;
385         int ret;
386         struct my_addrs {
387                 struct mptcp_subflow_data d;
388                 struct mptcp_subflow_addrs addr[2];
389         } addrs;
390
391         memset(&addrs, 0, sizeof(addrs));
392         memset(&local, 0, sizeof(local));
393         memset(&remote, 0, sizeof(remote));
394
395         addrs.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
396         addrs.d.size_user = sizeof(struct mptcp_subflow_addrs);
397         olen = sizeof(addrs);
398
399         ret = getsockopt(fd, SOL_MPTCP, MPTCP_SUBFLOW_ADDRS, &addrs, &olen);
400         if (ret < 0)
401                 die_perror("getsockopt MPTCP_SUBFLOW_ADDRS");
402
403         assert(olen <= sizeof(addrs));
404         assert(addrs.d.size_user == addrs.d.size_kernel);
405         assert(addrs.d.size_user == sizeof(struct mptcp_subflow_addrs));
406         assert(addrs.d.num_subflows == 1);
407
408         assert(olen > (socklen_t)sizeof(struct mptcp_subflow_data));
409         olen -= sizeof(struct mptcp_subflow_data);
410         assert(olen == sizeof(struct mptcp_subflow_addrs));
411
412         llen = sizeof(local);
413         ret = getsockname(fd, (struct sockaddr *)&local, &llen);
414         if (ret < 0)
415                 die_perror("getsockname");
416         rlen = sizeof(remote);
417         ret = getpeername(fd, (struct sockaddr *)&remote, &rlen);
418         if (ret < 0)
419                 die_perror("getpeername");
420
421         assert(rlen > 0);
422         assert(rlen == llen);
423
424         assert(remote.ss_family == local.ss_family);
425
426         assert(memcmp(&local, &addrs.addr[0].ss_local, sizeof(local)) == 0);
427         assert(memcmp(&remote, &addrs.addr[0].ss_remote, sizeof(remote)) == 0);
428
429         memset(&addrs, 0, sizeof(addrs));
430
431         addrs.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
432         addrs.d.size_user = sizeof(sa_family_t);
433         olen = sizeof(addrs.d) + sizeof(sa_family_t);
434
435         ret = getsockopt(fd, SOL_MPTCP, MPTCP_SUBFLOW_ADDRS, &addrs, &olen);
436         assert(ret == 0);
437         assert(olen == sizeof(addrs.d) + sizeof(sa_family_t));
438
439         assert(addrs.addr[0].sa_family == pf);
440         assert(addrs.addr[0].sa_family == local.ss_family);
441
442         assert(memcmp(&local, &addrs.addr[0].ss_local, sizeof(local)) != 0);
443         assert(memcmp(&remote, &addrs.addr[0].ss_remote, sizeof(remote)) != 0);
444
445         do_getsockopt_bogus_sf_data(fd, MPTCP_SUBFLOW_ADDRS);
446 }
447
448 static void do_getsockopts(struct so_state *s, int fd, size_t r, size_t w)
449 {
450         do_getsockopt_mptcp_info(s, fd, w);
451
452         do_getsockopt_tcp_info(s, fd, r, w);
453
454         do_getsockopt_subflow_addrs(fd);
455 }
456
457 static void connect_one_server(int fd, int pipefd)
458 {
459         char buf[4096], buf2[4096];
460         size_t len, i, total;
461         struct so_state s;
462         bool eof = false;
463         ssize_t ret;
464
465         memset(&s, 0, sizeof(s));
466
467         len = rand() % (sizeof(buf) - 1);
468
469         if (len < 128)
470                 len = 128;
471
472         for (i = 0; i < len ; i++) {
473                 buf[i] = rand() % 26;
474                 buf[i] += 'A';
475         }
476
477         buf[i] = '\n';
478
479         do_getsockopts(&s, fd, 0, 0);
480
481         /* un-block server */
482         ret = read(pipefd, buf2, 4);
483         assert(ret == 4);
484         close(pipefd);
485
486         assert(strncmp(buf2, "xmit", 4) == 0);
487
488         ret = write(fd, buf, len);
489         if (ret < 0)
490                 die_perror("write");
491
492         if (ret != (ssize_t)len)
493                 xerror("short write");
494
495         total = 0;
496         do {
497                 ret = read(fd, buf2 + total, sizeof(buf2) - total);
498                 if (ret < 0)
499                         die_perror("read");
500                 if (ret == 0) {
501                         eof = true;
502                         break;
503                 }
504
505                 total += ret;
506         } while (total < len);
507
508         if (total != len)
509                 xerror("total %lu, len %lu eof %d\n", total, len, eof);
510
511         if (memcmp(buf, buf2, len))
512                 xerror("data corruption");
513
514         if (s.tcpi_rcv_delta)
515                 assert(s.tcpi_rcv_delta <= total);
516
517         do_getsockopts(&s, fd, ret, ret);
518
519         if (eof)
520                 total += 1; /* sequence advances due to FIN */
521
522         assert(s.mptcpi_rcv_delta == (uint64_t)total);
523         close(fd);
524 }
525
526 static void process_one_client(int fd, int pipefd)
527 {
528         ssize_t ret, ret2, ret3;
529         struct so_state s;
530         char buf[4096];
531
532         memset(&s, 0, sizeof(s));
533         do_getsockopts(&s, fd, 0, 0);
534
535         ret = write(pipefd, "xmit", 4);
536         assert(ret == 4);
537
538         ret = read(fd, buf, sizeof(buf));
539         if (ret < 0)
540                 die_perror("read");
541
542         assert(s.mptcpi_rcv_delta <= (uint64_t)ret);
543
544         if (s.tcpi_rcv_delta)
545                 assert(s.tcpi_rcv_delta == (uint64_t)ret);
546
547         ret2 = write(fd, buf, ret);
548         if (ret2 < 0)
549                 die_perror("write");
550
551         /* wait for hangup */
552         ret3 = read(fd, buf, 1);
553         if (ret3 != 0)
554                 xerror("expected EOF, got %lu", ret3);
555
556         do_getsockopts(&s, fd, ret, ret2);
557         if (s.mptcpi_rcv_delta != (uint64_t)ret + 1)
558                 xerror("mptcpi_rcv_delta %" PRIu64 ", expect %" PRIu64, s.mptcpi_rcv_delta, ret + 1, s.mptcpi_rcv_delta - ret);
559         close(fd);
560 }
561
562 static int xaccept(int s)
563 {
564         int fd = accept(s, NULL, 0);
565
566         if (fd < 0)
567                 die_perror("accept");
568
569         return fd;
570 }
571
572 static int server(int pipefd)
573 {
574         int fd = -1, r;
575
576         switch (pf) {
577         case AF_INET:
578                 fd = sock_listen_mptcp("127.0.0.1", "15432");
579                 break;
580         case AF_INET6:
581                 fd = sock_listen_mptcp("::1", "15432");
582                 break;
583         default:
584                 xerror("Unknown pf %d\n", pf);
585                 break;
586         }
587
588         r = write(pipefd, "conn", 4);
589         assert(r == 4);
590
591         alarm(15);
592         r = xaccept(fd);
593
594         process_one_client(r, pipefd);
595
596         return 0;
597 }
598
599 static void test_ip_tos_sockopt(int fd)
600 {
601         uint8_t tos_in, tos_out;
602         socklen_t s;
603         int r;
604
605         tos_in = rand() & 0xfc;
606         r = setsockopt(fd, SOL_IP, IP_TOS, &tos_in, sizeof(tos_out));
607         if (r != 0)
608                 die_perror("setsockopt IP_TOS");
609
610         tos_out = 0;
611         s = sizeof(tos_out);
612         r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
613         if (r != 0)
614                 die_perror("getsockopt IP_TOS");
615
616         if (tos_in != tos_out)
617                 xerror("tos %x != %x socklen_t %d\n", tos_in, tos_out, s);
618
619         if (s != 1)
620                 xerror("tos should be 1 byte");
621
622         s = 0;
623         r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
624         if (r != 0)
625                 die_perror("getsockopt IP_TOS 0");
626         if (s != 0)
627                 xerror("expect socklen_t == 0");
628
629         s = -1;
630         r = getsockopt(fd, SOL_IP, IP_TOS, &tos_out, &s);
631         if (r != -1 && errno != EINVAL)
632                 die_perror("getsockopt IP_TOS did not indicate -EINVAL");
633         if (s != -1)
634                 xerror("expect socklen_t == -1");
635 }
636
637 static int client(int pipefd)
638 {
639         int fd = -1;
640
641         alarm(15);
642
643         switch (pf) {
644         case AF_INET:
645                 fd = sock_connect_mptcp("127.0.0.1", "15432", IPPROTO_MPTCP);
646                 break;
647         case AF_INET6:
648                 fd = sock_connect_mptcp("::1", "15432", IPPROTO_MPTCP);
649                 break;
650         default:
651                 xerror("Unknown pf %d\n", pf);
652         }
653
654         test_ip_tos_sockopt(fd);
655
656         connect_one_server(fd, pipefd);
657
658         return 0;
659 }
660
661 static pid_t xfork(void)
662 {
663         pid_t p = fork();
664
665         if (p < 0)
666                 die_perror("fork");
667
668         return p;
669 }
670
671 static int rcheck(int wstatus, const char *what)
672 {
673         if (WIFEXITED(wstatus)) {
674                 if (WEXITSTATUS(wstatus) == 0)
675                         return 0;
676                 fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
677                 return WEXITSTATUS(wstatus);
678         } else if (WIFSIGNALED(wstatus)) {
679                 xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
680         } else if (WIFSTOPPED(wstatus)) {
681                 xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
682         }
683
684         return 111;
685 }
686
687 static void init_rng(void)
688 {
689         int fd = open("/dev/urandom", O_RDONLY);
690
691         if (fd >= 0) {
692                 unsigned int foo;
693                 ssize_t ret;
694
695                 /* can't fail */
696                 ret = read(fd, &foo, sizeof(foo));
697                 assert(ret == sizeof(foo));
698
699                 close(fd);
700                 srand(foo);
701         } else {
702                 srand(time(NULL));
703         }
704 }
705
706 int main(int argc, char *argv[])
707 {
708         int e1, e2, wstatus;
709         pid_t s, c, ret;
710         int pipefds[2];
711
712         parse_opts(argc, argv);
713
714         init_rng();
715
716         e1 = pipe(pipefds);
717         if (e1 < 0)
718                 die_perror("pipe");
719
720         s = xfork();
721         if (s == 0)
722                 return server(pipefds[1]);
723
724         close(pipefds[1]);
725
726         /* wait until server bound a socket */
727         e1 = read(pipefds[0], &e1, 4);
728         assert(e1 == 4);
729
730         c = xfork();
731         if (c == 0)
732                 return client(pipefds[0]);
733
734         close(pipefds[0]);
735
736         ret = waitpid(s, &wstatus, 0);
737         if (ret == -1)
738                 die_perror("waitpid");
739         e1 = rcheck(wstatus, "server");
740         ret = waitpid(c, &wstatus, 0);
741         if (ret == -1)
742                 die_perror("waitpid");
743         e2 = rcheck(wstatus, "client");
744
745         return e1 ? e1 : e2;
746 }