MAINTAINERS: update the LSM maintainer info
[platform/kernel/linux-starfive.git] / tools / testing / selftests / net / mptcp / mptcp_inq.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19
20 #include <sys/ioctl.h>
21 #include <sys/socket.h>
22 #include <sys/types.h>
23 #include <sys/wait.h>
24
25 #include <netdb.h>
26 #include <netinet/in.h>
27
28 #include <linux/tcp.h>
29 #include <linux/sockios.h>
30
31 #ifndef IPPROTO_MPTCP
32 #define IPPROTO_MPTCP 262
33 #endif
34 #ifndef SOL_MPTCP
35 #define SOL_MPTCP 284
36 #endif
37
38 static int pf = AF_INET;
39 static int proto_tx = IPPROTO_MPTCP;
40 static int proto_rx = IPPROTO_MPTCP;
41
42 static void die_perror(const char *msg)
43 {
44         perror(msg);
45         exit(1);
46 }
47
48 static void die_usage(int r)
49 {
50         fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
51         exit(r);
52 }
53
54 static void xerror(const char *fmt, ...)
55 {
56         va_list ap;
57
58         va_start(ap, fmt);
59         vfprintf(stderr, fmt, ap);
60         va_end(ap);
61         fputc('\n', stderr);
62         exit(1);
63 }
64
65 static const char *getxinfo_strerr(int err)
66 {
67         if (err == EAI_SYSTEM)
68                 return strerror(errno);
69
70         return gai_strerror(err);
71 }
72
73 static void xgetaddrinfo(const char *node, const char *service,
74                          const struct addrinfo *hints,
75                          struct addrinfo **res)
76 {
77         int err = getaddrinfo(node, service, hints, res);
78
79         if (err) {
80                 const char *errstr = getxinfo_strerr(err);
81
82                 fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
83                         node ? node : "", service ? service : "", errstr);
84                 exit(1);
85         }
86 }
87
88 static int sock_listen_mptcp(const char * const listenaddr,
89                              const char * const port)
90 {
91         int sock;
92         struct addrinfo hints = {
93                 .ai_protocol = IPPROTO_TCP,
94                 .ai_socktype = SOCK_STREAM,
95                 .ai_flags = AI_PASSIVE | AI_NUMERICHOST
96         };
97
98         hints.ai_family = pf;
99
100         struct addrinfo *a, *addr;
101         int one = 1;
102
103         xgetaddrinfo(listenaddr, port, &hints, &addr);
104         hints.ai_family = pf;
105
106         for (a = addr; a; a = a->ai_next) {
107                 sock = socket(a->ai_family, a->ai_socktype, proto_rx);
108                 if (sock < 0)
109                         continue;
110
111                 if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
112                                      sizeof(one)))
113                         perror("setsockopt");
114
115                 if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
116                         break; /* success */
117
118                 perror("bind");
119                 close(sock);
120                 sock = -1;
121         }
122
123         freeaddrinfo(addr);
124
125         if (sock < 0)
126                 xerror("could not create listen socket");
127
128         if (listen(sock, 20))
129                 die_perror("listen");
130
131         return sock;
132 }
133
134 static int sock_connect_mptcp(const char * const remoteaddr,
135                               const char * const port, int proto)
136 {
137         struct addrinfo hints = {
138                 .ai_protocol = IPPROTO_TCP,
139                 .ai_socktype = SOCK_STREAM,
140         };
141         struct addrinfo *a, *addr;
142         int sock = -1;
143
144         hints.ai_family = pf;
145
146         xgetaddrinfo(remoteaddr, port, &hints, &addr);
147         for (a = addr; a; a = a->ai_next) {
148                 sock = socket(a->ai_family, a->ai_socktype, proto);
149                 if (sock < 0)
150                         continue;
151
152                 if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
153                         break; /* success */
154
155                 die_perror("connect");
156         }
157
158         if (sock < 0)
159                 xerror("could not create connect socket");
160
161         freeaddrinfo(addr);
162         return sock;
163 }
164
165 static int protostr_to_num(const char *s)
166 {
167         if (strcasecmp(s, "tcp") == 0)
168                 return IPPROTO_TCP;
169         if (strcasecmp(s, "mptcp") == 0)
170                 return IPPROTO_MPTCP;
171
172         die_usage(1);
173         return 0;
174 }
175
176 static void parse_opts(int argc, char **argv)
177 {
178         int c;
179
180         while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
181                 switch (c) {
182                 case 'h':
183                         die_usage(0);
184                         break;
185                 case '6':
186                         pf = AF_INET6;
187                         break;
188                 case 't':
189                         proto_tx = protostr_to_num(optarg);
190                         break;
191                 case 'r':
192                         proto_rx = protostr_to_num(optarg);
193                         break;
194                 default:
195                         die_usage(1);
196                         break;
197                 }
198         }
199 }
200
201 /* wait up to timeout milliseconds */
202 static void wait_for_ack(int fd, int timeout, size_t total)
203 {
204         int i;
205
206         for (i = 0; i < timeout; i++) {
207                 int nsd, ret, queued = -1;
208                 struct timespec req;
209
210                 ret = ioctl(fd, TIOCOUTQ, &queued);
211                 if (ret < 0)
212                         die_perror("TIOCOUTQ");
213
214                 ret = ioctl(fd, SIOCOUTQNSD, &nsd);
215                 if (ret < 0)
216                         die_perror("SIOCOUTQNSD");
217
218                 if ((size_t)queued > total)
219                         xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total);
220                 assert(nsd <= queued);
221
222                 if (queued == 0)
223                         return;
224
225                 /* wait for peer to ack rx of all data */
226                 req.tv_sec = 0;
227                 req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
228                 nanosleep(&req, NULL);
229         }
230
231         xerror("still tx data queued after %u ms\n", timeout);
232 }
233
234 static void connect_one_server(int fd, int unixfd)
235 {
236         size_t len, i, total, sent;
237         char buf[4096], buf2[4096];
238         ssize_t ret;
239
240         len = rand() % (sizeof(buf) - 1);
241
242         if (len < 128)
243                 len = 128;
244
245         for (i = 0; i < len ; i++) {
246                 buf[i] = rand() % 26;
247                 buf[i] += 'A';
248         }
249
250         buf[i] = '\n';
251
252         /* un-block server */
253         ret = read(unixfd, buf2, 4);
254         assert(ret == 4);
255
256         assert(strncmp(buf2, "xmit", 4) == 0);
257
258         ret = write(unixfd, &len, sizeof(len));
259         assert(ret == (ssize_t)sizeof(len));
260
261         ret = write(fd, buf, len);
262         if (ret < 0)
263                 die_perror("write");
264
265         if (ret != (ssize_t)len)
266                 xerror("short write");
267
268         ret = read(unixfd, buf2, 4);
269         assert(strncmp(buf2, "huge", 4) == 0);
270
271         total = rand() % (16 * 1024 * 1024);
272         total += (1 * 1024 * 1024);
273         sent = total;
274
275         ret = write(unixfd, &total, sizeof(total));
276         assert(ret == (ssize_t)sizeof(total));
277
278         wait_for_ack(fd, 5000, len);
279
280         while (total > 0) {
281                 if (total > sizeof(buf))
282                         len = sizeof(buf);
283                 else
284                         len = total;
285
286                 ret = write(fd, buf, len);
287                 if (ret < 0)
288                         die_perror("write");
289                 total -= ret;
290
291                 /* we don't have to care about buf content, only
292                  * number of total bytes sent
293                  */
294         }
295
296         ret = read(unixfd, buf2, 4);
297         assert(ret == 4);
298         assert(strncmp(buf2, "shut", 4) == 0);
299
300         wait_for_ack(fd, 5000, sent);
301
302         ret = write(fd, buf, 1);
303         assert(ret == 1);
304         close(fd);
305         ret = write(unixfd, "closed", 6);
306         assert(ret == 6);
307
308         close(unixfd);
309 }
310
311 static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
312 {
313         struct cmsghdr *cmsg;
314
315         for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
316                 if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
317                         memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
318                         return;
319                 }
320         }
321
322         xerror("could not find TCP_CM_INQ cmsg type");
323 }
324
325 static void process_one_client(int fd, int unixfd)
326 {
327         unsigned int tcp_inq;
328         size_t expect_len;
329         char msg_buf[4096];
330         char buf[4096];
331         char tmp[16];
332         struct iovec iov = {
333                 .iov_base = buf,
334                 .iov_len = 1,
335         };
336         struct msghdr msg = {
337                 .msg_iov = &iov,
338                 .msg_iovlen = 1,
339                 .msg_control = msg_buf,
340                 .msg_controllen = sizeof(msg_buf),
341         };
342         ssize_t ret, tot;
343
344         ret = write(unixfd, "xmit", 4);
345         assert(ret == 4);
346
347         ret = read(unixfd, &expect_len, sizeof(expect_len));
348         assert(ret == (ssize_t)sizeof(expect_len));
349
350         if (expect_len > sizeof(buf))
351                 xerror("expect len %zu exceeds buffer size", expect_len);
352
353         for (;;) {
354                 struct timespec req;
355                 unsigned int queued;
356
357                 ret = ioctl(fd, FIONREAD, &queued);
358                 if (ret < 0)
359                         die_perror("FIONREAD");
360                 if (queued > expect_len)
361                         xerror("FIONREAD returned %u, but only %zu expected\n",
362                                queued, expect_len);
363                 if (queued == expect_len)
364                         break;
365
366                 req.tv_sec = 0;
367                 req.tv_nsec = 1000 * 1000ul;
368                 nanosleep(&req, NULL);
369         }
370
371         /* read one byte, expect cmsg to return expected - 1 */
372         ret = recvmsg(fd, &msg, 0);
373         if (ret < 0)
374                 die_perror("recvmsg");
375
376         if (msg.msg_controllen == 0)
377                 xerror("msg_controllen is 0");
378
379         get_tcp_inq(&msg, &tcp_inq);
380
381         assert((size_t)tcp_inq == (expect_len - 1));
382
383         iov.iov_len = sizeof(buf);
384         ret = recvmsg(fd, &msg, 0);
385         if (ret < 0)
386                 die_perror("recvmsg");
387
388         /* should have gotten exact remainder of all pending data */
389         assert(ret == (ssize_t)tcp_inq);
390
391         /* should be 0, all drained */
392         get_tcp_inq(&msg, &tcp_inq);
393         assert(tcp_inq == 0);
394
395         /* request a large swath of data. */
396         ret = write(unixfd, "huge", 4);
397         assert(ret == 4);
398
399         ret = read(unixfd, &expect_len, sizeof(expect_len));
400         assert(ret == (ssize_t)sizeof(expect_len));
401
402         /* peer should send us a few mb of data */
403         if (expect_len <= sizeof(buf))
404                 xerror("expect len %zu too small\n", expect_len);
405
406         tot = 0;
407         do {
408                 iov.iov_len = sizeof(buf);
409                 ret = recvmsg(fd, &msg, 0);
410                 if (ret < 0)
411                         die_perror("recvmsg");
412
413                 tot += ret;
414
415                 get_tcp_inq(&msg, &tcp_inq);
416
417                 if (tcp_inq > expect_len - tot)
418                         xerror("inq %d, remaining %d total_len %d\n",
419                                tcp_inq, expect_len - tot, (int)expect_len);
420
421                 assert(tcp_inq <= expect_len - tot);
422         } while ((size_t)tot < expect_len);
423
424         ret = write(unixfd, "shut", 4);
425         assert(ret == 4);
426
427         /* wait for hangup. Should have received one more byte of data. */
428         ret = read(unixfd, tmp, sizeof(tmp));
429         assert(ret == 6);
430         assert(strncmp(tmp, "closed", 6) == 0);
431
432         sleep(1);
433
434         iov.iov_len = 1;
435         ret = recvmsg(fd, &msg, 0);
436         if (ret < 0)
437                 die_perror("recvmsg");
438         assert(ret == 1);
439
440         get_tcp_inq(&msg, &tcp_inq);
441
442         /* tcp_inq should be 1 due to received fin. */
443         assert(tcp_inq == 1);
444
445         iov.iov_len = 1;
446         ret = recvmsg(fd, &msg, 0);
447         if (ret < 0)
448                 die_perror("recvmsg");
449
450         /* expect EOF */
451         assert(ret == 0);
452         get_tcp_inq(&msg, &tcp_inq);
453         assert(tcp_inq == 1);
454
455         close(fd);
456 }
457
458 static int xaccept(int s)
459 {
460         int fd = accept(s, NULL, 0);
461
462         if (fd < 0)
463                 die_perror("accept");
464
465         return fd;
466 }
467
468 static int server(int unixfd)
469 {
470         int fd = -1, r, on = 1;
471
472         switch (pf) {
473         case AF_INET:
474                 fd = sock_listen_mptcp("127.0.0.1", "15432");
475                 break;
476         case AF_INET6:
477                 fd = sock_listen_mptcp("::1", "15432");
478                 break;
479         default:
480                 xerror("Unknown pf %d\n", pf);
481                 break;
482         }
483
484         r = write(unixfd, "conn", 4);
485         assert(r == 4);
486
487         alarm(15);
488         r = xaccept(fd);
489
490         if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
491                 die_perror("setsockopt");
492
493         process_one_client(r, unixfd);
494
495         return 0;
496 }
497
498 static int client(int unixfd)
499 {
500         int fd = -1;
501
502         alarm(15);
503
504         switch (pf) {
505         case AF_INET:
506                 fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
507                 break;
508         case AF_INET6:
509                 fd = sock_connect_mptcp("::1", "15432", proto_tx);
510                 break;
511         default:
512                 xerror("Unknown pf %d\n", pf);
513         }
514
515         connect_one_server(fd, unixfd);
516
517         return 0;
518 }
519
520 static void init_rng(void)
521 {
522         int fd = open("/dev/urandom", O_RDONLY);
523         unsigned int foo;
524
525         if (fd > 0) {
526                 int ret = read(fd, &foo, sizeof(foo));
527
528                 if (ret < 0)
529                         srand(fd + foo);
530                 close(fd);
531         }
532
533         srand(foo);
534 }
535
536 static pid_t xfork(void)
537 {
538         pid_t p = fork();
539
540         if (p < 0)
541                 die_perror("fork");
542         else if (p == 0)
543                 init_rng();
544
545         return p;
546 }
547
548 static int rcheck(int wstatus, const char *what)
549 {
550         if (WIFEXITED(wstatus)) {
551                 if (WEXITSTATUS(wstatus) == 0)
552                         return 0;
553                 fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
554                 return WEXITSTATUS(wstatus);
555         } else if (WIFSIGNALED(wstatus)) {
556                 xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
557         } else if (WIFSTOPPED(wstatus)) {
558                 xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
559         }
560
561         return 111;
562 }
563
564 int main(int argc, char *argv[])
565 {
566         int e1, e2, wstatus;
567         pid_t s, c, ret;
568         int unixfds[2];
569
570         parse_opts(argc, argv);
571
572         e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
573         if (e1 < 0)
574                 die_perror("pipe");
575
576         s = xfork();
577         if (s == 0)
578                 return server(unixfds[1]);
579
580         close(unixfds[1]);
581
582         /* wait until server bound a socket */
583         e1 = read(unixfds[0], &e1, 4);
584         assert(e1 == 4);
585
586         c = xfork();
587         if (c == 0)
588                 return client(unixfds[0]);
589
590         close(unixfds[0]);
591
592         ret = waitpid(s, &wstatus, 0);
593         if (ret == -1)
594                 die_perror("waitpid");
595         e1 = rcheck(wstatus, "server");
596         ret = waitpid(c, &wstatus, 0);
597         if (ret == -1)
598                 die_perror("waitpid");
599         e2 = rcheck(wstatus, "client");
600
601         return e1 ? e1 : e2;
602 }