685cbe8933ded472d7aa71b212a0287ff5cdf1f8
[platform/kernel/linux-starfive.git] / tools / testing / selftests / net / nettest.c
1 // SPDX-License-Identifier: GPL-2.0
2 /* nettest - used for functional tests of networking APIs
3  *
4  * Copyright (c) 2013-2019 David Ahern <dsahern@gmail.com>. All rights reserved.
5  */
6
7 #define _GNU_SOURCE
8 #include <features.h>
9 #include <sys/types.h>
10 #include <sys/ioctl.h>
11 #include <sys/socket.h>
12 #include <sys/wait.h>
13 #include <linux/tcp.h>
14 #include <arpa/inet.h>
15 #include <net/if.h>
16 #include <netinet/in.h>
17 #include <netdb.h>
18 #include <fcntl.h>
19 #include <libgen.h>
20 #include <limits.h>
21 #include <sched.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <unistd.h>
27 #include <time.h>
28 #include <errno.h>
29
30 #ifndef IPV6_UNICAST_IF
31 #define IPV6_UNICAST_IF         76
32 #endif
33 #ifndef IPV6_MULTICAST_IF
34 #define IPV6_MULTICAST_IF       17
35 #endif
36
37 #define DEFAULT_PORT 12345
38
39 #define NS_PREFIX "/run/netns/"
40
41 #ifndef MAX
42 #define MAX(a, b)  ((a) > (b) ? (a) : (b))
43 #endif
44 #ifndef MIN
45 #define MIN(a, b)  ((a) < (b) ? (a) : (b))
46 #endif
47
48 struct sock_args {
49         /* local address */
50         const char *local_addr_str;
51         union {
52                 struct in_addr  in;
53                 struct in6_addr in6;
54         } local_addr;
55
56         /* remote address */
57         const char *remote_addr_str;
58         union {
59                 struct in_addr  in;
60                 struct in6_addr in6;
61         } remote_addr;
62         int scope_id;  /* remote scope; v6 send only */
63
64         struct in_addr grp;     /* multicast group */
65
66         unsigned int has_local_ip:1,
67                      has_remote_ip:1,
68                      has_grp:1,
69                      has_expected_laddr:1,
70                      has_expected_raddr:1,
71                      bind_test_only:1;
72
73         unsigned short port;
74
75         int type;      /* DGRAM, STREAM, RAW */
76         int protocol;
77         int version;   /* AF_INET/AF_INET6 */
78
79         int use_setsockopt;
80         int use_cmsg;
81         const char *dev;
82         int ifindex;
83
84         const char *clientns;
85         const char *serverns;
86
87         const char *password;
88         /* prefix for MD5 password */
89         const char *md5_prefix_str;
90         union {
91                 struct sockaddr_in v4;
92                 struct sockaddr_in6 v6;
93         } md5_prefix;
94         unsigned int prefix_len;
95
96         /* expected addresses and device index for connection */
97         const char *expected_dev;
98         int expected_ifindex;
99
100         /* local address */
101         const char *expected_laddr_str;
102         union {
103                 struct in_addr  in;
104                 struct in6_addr in6;
105         } expected_laddr;
106
107         /* remote address */
108         const char *expected_raddr_str;
109         union {
110                 struct in_addr  in;
111                 struct in6_addr in6;
112         } expected_raddr;
113 };
114
115 static int server_mode;
116 static unsigned int prog_timeout = 5;
117 static unsigned int interactive;
118 static int iter = 1;
119 static char *msg = "Hello world!";
120 static int msglen;
121 static int quiet;
122 static int try_broadcast = 1;
123
124 static char *timestamp(char *timebuf, int buflen)
125 {
126         time_t now;
127
128         now = time(NULL);
129         if (strftime(timebuf, buflen, "%T", localtime(&now)) == 0) {
130                 memset(timebuf, 0, buflen);
131                 strncpy(timebuf, "00:00:00", buflen-1);
132         }
133
134         return timebuf;
135 }
136
137 static void log_msg(const char *format, ...)
138 {
139         char timebuf[64];
140         va_list args;
141
142         if (quiet)
143                 return;
144
145         fprintf(stdout, "%s %s:",
146                 timestamp(timebuf, sizeof(timebuf)),
147                 server_mode ? "server" : "client");
148         va_start(args, format);
149         vfprintf(stdout, format, args);
150         va_end(args);
151
152         fflush(stdout);
153 }
154
155 static void log_error(const char *format, ...)
156 {
157         char timebuf[64];
158         va_list args;
159
160         if (quiet)
161                 return;
162
163         fprintf(stderr, "%s %s:",
164                 timestamp(timebuf, sizeof(timebuf)),
165                 server_mode ? "server" : "client");
166         va_start(args, format);
167         vfprintf(stderr, format, args);
168         va_end(args);
169
170         fflush(stderr);
171 }
172
173 static void log_err_errno(const char *fmt, ...)
174 {
175         char timebuf[64];
176         va_list args;
177
178         if (quiet)
179                 return;
180
181         fprintf(stderr, "%s %s: ",
182                 timestamp(timebuf, sizeof(timebuf)),
183                 server_mode ? "server" : "client");
184         va_start(args, fmt);
185         vfprintf(stderr, fmt, args);
186         va_end(args);
187
188         fprintf(stderr, ": %d: %s\n", errno, strerror(errno));
189         fflush(stderr);
190 }
191
192 static void log_address(const char *desc, struct sockaddr *sa)
193 {
194         char addrstr[64];
195
196         if (quiet)
197                 return;
198
199         if (sa->sa_family == AF_INET) {
200                 struct sockaddr_in *s = (struct sockaddr_in *) sa;
201
202                 log_msg("%s %s:%d",
203                         desc,
204                         inet_ntop(AF_INET, &s->sin_addr, addrstr,
205                                   sizeof(addrstr)),
206                         ntohs(s->sin_port));
207
208         } else if (sa->sa_family == AF_INET6) {
209                 struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
210
211                 log_msg("%s [%s]:%d",
212                         desc,
213                         inet_ntop(AF_INET6, &s6->sin6_addr, addrstr,
214                                   sizeof(addrstr)),
215                         ntohs(s6->sin6_port));
216         }
217
218         printf("\n");
219
220         fflush(stdout);
221 }
222
223 static int switch_ns(const char *ns)
224 {
225         char path[PATH_MAX];
226         int fd, ret;
227
228         if (geteuid())
229                 log_error("warning: likely need root to set netns %s!\n", ns);
230
231         snprintf(path, sizeof(path), "%s%s", NS_PREFIX, ns);
232         fd = open(path, 0);
233         if (fd < 0) {
234                 log_err_errno("Failed to open netns path; can not switch netns");
235                 return 1;
236         }
237
238         ret = setns(fd, CLONE_NEWNET);
239         close(fd);
240
241         return ret;
242 }
243
244 static int tcp_md5sig(int sd, void *addr, socklen_t alen, struct sock_args *args)
245 {
246         int keylen = strlen(args->password);
247         struct tcp_md5sig md5sig = {};
248         int opt = TCP_MD5SIG;
249         int rc;
250
251         md5sig.tcpm_keylen = keylen;
252         memcpy(md5sig.tcpm_key, args->password, keylen);
253
254         if (args->prefix_len) {
255                 opt = TCP_MD5SIG_EXT;
256                 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_PREFIX;
257
258                 md5sig.tcpm_prefixlen = args->prefix_len;
259                 addr = &args->md5_prefix;
260         }
261         memcpy(&md5sig.tcpm_addr, addr, alen);
262
263         if (args->ifindex) {
264                 opt = TCP_MD5SIG_EXT;
265                 md5sig.tcpm_flags |= TCP_MD5SIG_FLAG_IFINDEX;
266
267                 md5sig.tcpm_ifindex = args->ifindex;
268         }
269
270         rc = setsockopt(sd, IPPROTO_TCP, opt, &md5sig, sizeof(md5sig));
271         if (rc < 0) {
272                 /* ENOENT is harmless. Returned when a password is cleared */
273                 if (errno == ENOENT)
274                         rc = 0;
275                 else
276                         log_err_errno("setsockopt(TCP_MD5SIG)");
277         }
278
279         return rc;
280 }
281
282 static int tcp_md5_remote(int sd, struct sock_args *args)
283 {
284         struct sockaddr_in sin = {
285                 .sin_family = AF_INET,
286         };
287         struct sockaddr_in6 sin6 = {
288                 .sin6_family = AF_INET6,
289         };
290         void *addr;
291         int alen;
292
293         switch (args->version) {
294         case AF_INET:
295                 sin.sin_port = htons(args->port);
296                 sin.sin_addr = args->remote_addr.in;
297                 addr = &sin;
298                 alen = sizeof(sin);
299                 break;
300         case AF_INET6:
301                 sin6.sin6_port = htons(args->port);
302                 sin6.sin6_addr = args->remote_addr.in6;
303                 addr = &sin6;
304                 alen = sizeof(sin6);
305                 break;
306         default:
307                 log_error("unknown address family\n");
308                 exit(1);
309         }
310
311         if (tcp_md5sig(sd, addr, alen, args))
312                 return -1;
313
314         return 0;
315 }
316
317 static int get_ifidx(const char *ifname)
318 {
319         struct ifreq ifdata;
320         int sd, rc;
321
322         if (!ifname || *ifname == '\0')
323                 return -1;
324
325         memset(&ifdata, 0, sizeof(ifdata));
326
327         strcpy(ifdata.ifr_name, ifname);
328
329         sd = socket(PF_INET, SOCK_DGRAM, IPPROTO_IP);
330         if (sd < 0) {
331                 log_err_errno("socket failed");
332                 return -1;
333         }
334
335         rc = ioctl(sd, SIOCGIFINDEX, (char *)&ifdata);
336         close(sd);
337         if (rc != 0) {
338                 log_err_errno("ioctl(SIOCGIFINDEX) failed");
339                 return -1;
340         }
341
342         return ifdata.ifr_ifindex;
343 }
344
345 static int bind_to_device(int sd, const char *name)
346 {
347         int rc;
348
349         rc = setsockopt(sd, SOL_SOCKET, SO_BINDTODEVICE, name, strlen(name)+1);
350         if (rc < 0)
351                 log_err_errno("setsockopt(SO_BINDTODEVICE)");
352
353         return rc;
354 }
355
356 static int get_bind_to_device(int sd, char *name, size_t len)
357 {
358         int rc;
359         socklen_t optlen = len;
360
361         name[0] = '\0';
362         rc = getsockopt(sd, SOL_SOCKET, SO_BINDTODEVICE, name, &optlen);
363         if (rc < 0)
364                 log_err_errno("setsockopt(SO_BINDTODEVICE)");
365
366         return rc;
367 }
368
369 static int check_device(int sd, struct sock_args *args)
370 {
371         int ifindex = 0;
372         char name[32];
373
374         if (get_bind_to_device(sd, name, sizeof(name)))
375                 *name = '\0';
376         else
377                 ifindex = get_ifidx(name);
378
379         log_msg("    bound to device %s/%d\n",
380                 *name ? name : "<none>", ifindex);
381
382         if (!args->expected_ifindex)
383                 return 0;
384
385         if (args->expected_ifindex != ifindex) {
386                 log_error("Device index mismatch: expected %d have %d\n",
387                           args->expected_ifindex, ifindex);
388                 return 1;
389         }
390
391         log_msg("Device index matches: expected %d have %d\n",
392                 args->expected_ifindex, ifindex);
393
394         return 0;
395 }
396
397 static int set_pktinfo_v4(int sd)
398 {
399         int one = 1;
400         int rc;
401
402         rc = setsockopt(sd, SOL_IP, IP_PKTINFO, &one, sizeof(one));
403         if (rc < 0 && rc != -ENOTSUP)
404                 log_err_errno("setsockopt(IP_PKTINFO)");
405
406         return rc;
407 }
408
409 static int set_recvpktinfo_v6(int sd)
410 {
411         int one = 1;
412         int rc;
413
414         rc = setsockopt(sd, SOL_IPV6, IPV6_RECVPKTINFO, &one, sizeof(one));
415         if (rc < 0 && rc != -ENOTSUP)
416                 log_err_errno("setsockopt(IPV6_RECVPKTINFO)");
417
418         return rc;
419 }
420
421 static int set_recverr_v4(int sd)
422 {
423         int one = 1;
424         int rc;
425
426         rc = setsockopt(sd, SOL_IP, IP_RECVERR, &one, sizeof(one));
427         if (rc < 0 && rc != -ENOTSUP)
428                 log_err_errno("setsockopt(IP_RECVERR)");
429
430         return rc;
431 }
432
433 static int set_recverr_v6(int sd)
434 {
435         int one = 1;
436         int rc;
437
438         rc = setsockopt(sd, SOL_IPV6, IPV6_RECVERR, &one, sizeof(one));
439         if (rc < 0 && rc != -ENOTSUP)
440                 log_err_errno("setsockopt(IPV6_RECVERR)");
441
442         return rc;
443 }
444
445 static int set_unicast_if(int sd, int ifindex, int version)
446 {
447         int opt = IP_UNICAST_IF;
448         int level = SOL_IP;
449         int rc;
450
451         ifindex = htonl(ifindex);
452
453         if (version == AF_INET6) {
454                 opt = IPV6_UNICAST_IF;
455                 level = SOL_IPV6;
456         }
457         rc = setsockopt(sd, level, opt, &ifindex, sizeof(ifindex));
458         if (rc < 0)
459                 log_err_errno("setsockopt(IP_UNICAST_IF)");
460
461         return rc;
462 }
463
464 static int set_multicast_if(int sd, int ifindex)
465 {
466         struct ip_mreqn mreq = { .imr_ifindex = ifindex };
467         int rc;
468
469         rc = setsockopt(sd, SOL_IP, IP_MULTICAST_IF, &mreq, sizeof(mreq));
470         if (rc < 0)
471                 log_err_errno("setsockopt(IP_MULTICAST_IF)");
472
473         return rc;
474 }
475
476 static int set_membership(int sd, uint32_t grp, uint32_t addr, int ifindex)
477 {
478         uint32_t if_addr = addr;
479         struct ip_mreqn mreq;
480         int rc;
481
482         if (addr == htonl(INADDR_ANY) && !ifindex) {
483                 log_error("Either local address or device needs to be given for multicast membership\n");
484                 return -1;
485         }
486
487         mreq.imr_multiaddr.s_addr = grp;
488         mreq.imr_address.s_addr = if_addr;
489         mreq.imr_ifindex = ifindex;
490
491         rc = setsockopt(sd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq));
492         if (rc < 0) {
493                 log_err_errno("setsockopt(IP_ADD_MEMBERSHIP)");
494                 return -1;
495         }
496
497         return 0;
498 }
499
500 static int set_broadcast(int sd)
501 {
502         unsigned int one = 1;
503         int rc = 0;
504
505         if (setsockopt(sd, SOL_SOCKET, SO_BROADCAST, &one, sizeof(one)) != 0) {
506                 log_err_errno("setsockopt(SO_BROADCAST)");
507                 rc = -1;
508         }
509
510         return rc;
511 }
512
513 static int set_reuseport(int sd)
514 {
515         unsigned int one = 1;
516         int rc = 0;
517
518         if (setsockopt(sd, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)) != 0) {
519                 log_err_errno("setsockopt(SO_REUSEPORT)");
520                 rc = -1;
521         }
522
523         return rc;
524 }
525
526 static int set_reuseaddr(int sd)
527 {
528         unsigned int one = 1;
529         int rc = 0;
530
531         if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) != 0) {
532                 log_err_errno("setsockopt(SO_REUSEADDR)");
533                 rc = -1;
534         }
535
536         return rc;
537 }
538
539 static int str_to_uint(const char *str, int min, int max, unsigned int *value)
540 {
541         int number;
542         char *end;
543
544         errno = 0;
545         number = (unsigned int) strtoul(str, &end, 0);
546
547         /* entire string should be consumed by conversion
548          * and value should be between min and max
549          */
550         if (((*end == '\0') || (*end == '\n')) && (end != str) &&
551             (errno != ERANGE) && (min <= number) && (number <= max)) {
552                 *value = number;
553                 return 0;
554         }
555
556         return -1;
557 }
558
559 static int resolve_devices(struct sock_args *args)
560 {
561         if (args->dev) {
562                 args->ifindex = get_ifidx(args->dev);
563                 if (args->ifindex < 0) {
564                         log_error("Invalid device name\n");
565                         return 1;
566                 }
567         }
568
569         if (args->expected_dev) {
570                 unsigned int tmp;
571
572                 if (str_to_uint(args->expected_dev, 0, INT_MAX, &tmp) == 0) {
573                         args->expected_ifindex = (int)tmp;
574                 } else {
575                         args->expected_ifindex = get_ifidx(args->expected_dev);
576                         if (args->expected_ifindex < 0) {
577                                 fprintf(stderr, "Invalid expected device\n");
578                                 return 1;
579                         }
580                 }
581         }
582
583         return 0;
584 }
585
586 static int expected_addr_match(struct sockaddr *sa, void *expected,
587                                const char *desc)
588 {
589         char addrstr[64];
590         int rc = 0;
591
592         if (sa->sa_family == AF_INET) {
593                 struct sockaddr_in *s = (struct sockaddr_in *) sa;
594                 struct in_addr *exp_in = (struct in_addr *) expected;
595
596                 if (s->sin_addr.s_addr != exp_in->s_addr) {
597                         log_error("%s address does not match expected %s",
598                                   desc,
599                                   inet_ntop(AF_INET, exp_in,
600                                             addrstr, sizeof(addrstr)));
601                         rc = 1;
602                 }
603         } else if (sa->sa_family == AF_INET6) {
604                 struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
605                 struct in6_addr *exp_in = (struct in6_addr *) expected;
606
607                 if (memcmp(&s6->sin6_addr, exp_in, sizeof(*exp_in))) {
608                         log_error("%s address does not match expected %s",
609                                   desc,
610                                   inet_ntop(AF_INET6, exp_in,
611                                             addrstr, sizeof(addrstr)));
612                         rc = 1;
613                 }
614         } else {
615                 log_error("%s address does not match expected - unknown family",
616                           desc);
617                 rc = 1;
618         }
619
620         if (!rc)
621                 log_msg("%s address matches expected\n", desc);
622
623         return rc;
624 }
625
626 static int show_sockstat(int sd, struct sock_args *args)
627 {
628         struct sockaddr_in6 local_addr, remote_addr;
629         socklen_t alen = sizeof(local_addr);
630         struct sockaddr *sa;
631         const char *desc;
632         int rc = 0;
633
634         desc = server_mode ? "server local:" : "client local:";
635         sa = (struct sockaddr *) &local_addr;
636         if (getsockname(sd, sa, &alen) == 0) {
637                 log_address(desc, sa);
638
639                 if (args->has_expected_laddr) {
640                         rc = expected_addr_match(sa, &args->expected_laddr,
641                                                  "local");
642                 }
643         } else {
644                 log_err_errno("getsockname failed");
645         }
646
647         sa = (struct sockaddr *) &remote_addr;
648         desc = server_mode ? "server peer:" : "client peer:";
649         if (getpeername(sd, sa, &alen) == 0) {
650                 log_address(desc, sa);
651
652                 if (args->has_expected_raddr) {
653                         rc |= expected_addr_match(sa, &args->expected_raddr,
654                                                  "remote");
655                 }
656         } else {
657                 log_err_errno("getpeername failed");
658         }
659
660         return rc;
661 }
662
663 enum addr_type {
664         ADDR_TYPE_LOCAL,
665         ADDR_TYPE_REMOTE,
666         ADDR_TYPE_MCAST,
667         ADDR_TYPE_EXPECTED_LOCAL,
668         ADDR_TYPE_EXPECTED_REMOTE,
669         ADDR_TYPE_MD5_PREFIX,
670 };
671
672 static int convert_addr(struct sock_args *args, const char *_str,
673                         enum addr_type atype)
674 {
675         int pfx_len_max = args->version == AF_INET6 ? 128 : 32;
676         int family = args->version;
677         char *str, *dev, *sep;
678         struct in6_addr *in6;
679         struct in_addr  *in;
680         const char *desc;
681         void *addr;
682         int rc = 0;
683
684         str = strdup(_str);
685         if (!str)
686                 return -ENOMEM;
687
688         switch (atype) {
689         case ADDR_TYPE_LOCAL:
690                 desc = "local";
691                 addr = &args->local_addr;
692                 break;
693         case ADDR_TYPE_REMOTE:
694                 desc = "remote";
695                 addr = &args->remote_addr;
696                 break;
697         case ADDR_TYPE_MCAST:
698                 desc = "mcast grp";
699                 addr = &args->grp;
700                 break;
701         case ADDR_TYPE_EXPECTED_LOCAL:
702                 desc = "expected local";
703                 addr = &args->expected_laddr;
704                 break;
705         case ADDR_TYPE_EXPECTED_REMOTE:
706                 desc = "expected remote";
707                 addr = &args->expected_raddr;
708                 break;
709         case ADDR_TYPE_MD5_PREFIX:
710                 desc = "md5 prefix";
711                 if (family == AF_INET) {
712                         args->md5_prefix.v4.sin_family = AF_INET;
713                         addr = &args->md5_prefix.v4.sin_addr;
714                 } else if (family == AF_INET6) {
715                         args->md5_prefix.v6.sin6_family = AF_INET6;
716                         addr = &args->md5_prefix.v6.sin6_addr;
717                 } else
718                         return 1;
719
720                 sep = strchr(str, '/');
721                 if (sep) {
722                         *sep = '\0';
723                         sep++;
724                         if (str_to_uint(sep, 1, pfx_len_max,
725                                         &args->prefix_len) != 0) {
726                                 fprintf(stderr, "Invalid port\n");
727                                 return 1;
728                         }
729                 } else {
730                         args->prefix_len = pfx_len_max;
731                 }
732                 break;
733         default:
734                 log_error("unknown address type");
735                 exit(1);
736         }
737
738         switch (family) {
739         case AF_INET:
740                 in  = (struct in_addr *) addr;
741                 if (str) {
742                         if (inet_pton(AF_INET, str, in) == 0) {
743                                 log_error("Invalid %s IP address\n", desc);
744                                 rc = -1;
745                                 goto out;
746                         }
747                 } else {
748                         in->s_addr = htonl(INADDR_ANY);
749                 }
750                 break;
751
752         case AF_INET6:
753                 dev = strchr(str, '%');
754                 if (dev) {
755                         *dev = '\0';
756                         dev++;
757                 }
758
759                 in6 = (struct in6_addr *) addr;
760                 if (str) {
761                         if (inet_pton(AF_INET6, str, in6) == 0) {
762                                 log_error("Invalid %s IPv6 address\n", desc);
763                                 rc = -1;
764                                 goto out;
765                         }
766                 } else {
767                         *in6 = in6addr_any;
768                 }
769                 if (dev) {
770                         args->scope_id = get_ifidx(dev);
771                         if (args->scope_id < 0) {
772                                 log_error("Invalid scope on %s IPv6 address\n",
773                                           desc);
774                                 rc = -1;
775                                 goto out;
776                         }
777                 }
778                 break;
779
780         default:
781                 log_error("Invalid address family\n");
782         }
783
784 out:
785         free(str);
786         return rc;
787 }
788
789 static int validate_addresses(struct sock_args *args)
790 {
791         if (args->local_addr_str &&
792             convert_addr(args, args->local_addr_str, ADDR_TYPE_LOCAL) < 0)
793                 return 1;
794
795         if (args->remote_addr_str &&
796             convert_addr(args, args->remote_addr_str, ADDR_TYPE_REMOTE) < 0)
797                 return 1;
798
799         if (args->md5_prefix_str &&
800             convert_addr(args, args->md5_prefix_str,
801                          ADDR_TYPE_MD5_PREFIX) < 0)
802                 return 1;
803
804         if (args->expected_laddr_str &&
805             convert_addr(args, args->expected_laddr_str,
806                          ADDR_TYPE_EXPECTED_LOCAL))
807                 return 1;
808
809         if (args->expected_raddr_str &&
810             convert_addr(args, args->expected_raddr_str,
811                          ADDR_TYPE_EXPECTED_REMOTE))
812                 return 1;
813
814         return 0;
815 }
816
817 static int get_index_from_cmsg(struct msghdr *m)
818 {
819         struct cmsghdr *cm;
820         int ifindex = 0;
821         char buf[64];
822
823         for (cm = (struct cmsghdr *)CMSG_FIRSTHDR(m);
824              m->msg_controllen != 0 && cm;
825              cm = (struct cmsghdr *)CMSG_NXTHDR(m, cm)) {
826
827                 if (cm->cmsg_level == SOL_IP &&
828                     cm->cmsg_type == IP_PKTINFO) {
829                         struct in_pktinfo *pi;
830
831                         pi = (struct in_pktinfo *)(CMSG_DATA(cm));
832                         inet_ntop(AF_INET, &pi->ipi_addr, buf, sizeof(buf));
833                         ifindex = pi->ipi_ifindex;
834                 } else if (cm->cmsg_level == SOL_IPV6 &&
835                            cm->cmsg_type == IPV6_PKTINFO) {
836                         struct in6_pktinfo *pi6;
837
838                         pi6 = (struct in6_pktinfo *)(CMSG_DATA(cm));
839                         inet_ntop(AF_INET6, &pi6->ipi6_addr, buf, sizeof(buf));
840                         ifindex = pi6->ipi6_ifindex;
841                 }
842         }
843
844         if (ifindex) {
845                 log_msg("    pktinfo: ifindex %d dest addr %s\n",
846                         ifindex, buf);
847         }
848         return ifindex;
849 }
850
851 static int send_msg_no_cmsg(int sd, void *addr, socklen_t alen)
852 {
853         int err;
854
855 again:
856         err = sendto(sd, msg, msglen, 0, addr, alen);
857         if (err < 0) {
858                 if (errno == EACCES && try_broadcast) {
859                         try_broadcast = 0;
860                         if (!set_broadcast(sd))
861                                 goto again;
862                         errno = EACCES;
863                 }
864
865                 log_err_errno("sendto failed");
866                 return 1;
867         }
868
869         return 0;
870 }
871
872 static int send_msg_cmsg(int sd, void *addr, socklen_t alen,
873                          int ifindex, int version)
874 {
875         unsigned char cmsgbuf[64];
876         struct iovec iov[2];
877         struct cmsghdr *cm;
878         struct msghdr m;
879         int err;
880
881         iov[0].iov_base = msg;
882         iov[0].iov_len = msglen;
883         m.msg_iov = iov;
884         m.msg_iovlen = 1;
885         m.msg_name = (caddr_t)addr;
886         m.msg_namelen = alen;
887
888         memset(cmsgbuf, 0, sizeof(cmsgbuf));
889         cm = (struct cmsghdr *)cmsgbuf;
890         m.msg_control = (caddr_t)cm;
891
892         if (version == AF_INET) {
893                 struct in_pktinfo *pi;
894
895                 cm->cmsg_level = SOL_IP;
896                 cm->cmsg_type = IP_PKTINFO;
897                 cm->cmsg_len = CMSG_LEN(sizeof(struct in_pktinfo));
898                 pi = (struct in_pktinfo *)(CMSG_DATA(cm));
899                 pi->ipi_ifindex = ifindex;
900
901                 m.msg_controllen = cm->cmsg_len;
902
903         } else if (version == AF_INET6) {
904                 struct in6_pktinfo *pi6;
905
906                 cm->cmsg_level = SOL_IPV6;
907                 cm->cmsg_type = IPV6_PKTINFO;
908                 cm->cmsg_len = CMSG_LEN(sizeof(struct in6_pktinfo));
909
910                 pi6 = (struct in6_pktinfo *)(CMSG_DATA(cm));
911                 pi6->ipi6_ifindex = ifindex;
912
913                 m.msg_controllen = cm->cmsg_len;
914         }
915
916 again:
917         err = sendmsg(sd, &m, 0);
918         if (err < 0) {
919                 if (errno == EACCES && try_broadcast) {
920                         try_broadcast = 0;
921                         if (!set_broadcast(sd))
922                                 goto again;
923                         errno = EACCES;
924                 }
925
926                 log_err_errno("sendmsg failed");
927                 return 1;
928         }
929
930         return 0;
931 }
932
933
934 static int send_msg(int sd, void *addr, socklen_t alen, struct sock_args *args)
935 {
936         if (args->type == SOCK_STREAM) {
937                 if (write(sd, msg, msglen) < 0) {
938                         log_err_errno("write failed sending msg to peer");
939                         return 1;
940                 }
941         } else if (args->ifindex && args->use_cmsg) {
942                 if (send_msg_cmsg(sd, addr, alen, args->ifindex, args->version))
943                         return 1;
944         } else {
945                 if (send_msg_no_cmsg(sd, addr, alen))
946                         return 1;
947         }
948
949         log_msg("Sent message:\n");
950         log_msg("    %.24s%s\n", msg, msglen > 24 ? " ..." : "");
951
952         return 0;
953 }
954
955 static int socket_read_dgram(int sd, struct sock_args *args)
956 {
957         unsigned char addr[sizeof(struct sockaddr_in6)];
958         struct sockaddr *sa = (struct sockaddr *) addr;
959         socklen_t alen = sizeof(addr);
960         struct iovec iov[2];
961         struct msghdr m = {
962                 .msg_name = (caddr_t)addr,
963                 .msg_namelen = alen,
964                 .msg_iov = iov,
965                 .msg_iovlen = 1,
966         };
967         unsigned char cmsgbuf[256];
968         struct cmsghdr *cm = (struct cmsghdr *)cmsgbuf;
969         char buf[16*1024];
970         int ifindex;
971         int len;
972
973         iov[0].iov_base = (caddr_t)buf;
974         iov[0].iov_len = sizeof(buf);
975
976         memset(cmsgbuf, 0, sizeof(cmsgbuf));
977         m.msg_control = (caddr_t)cm;
978         m.msg_controllen = sizeof(cmsgbuf);
979
980         len = recvmsg(sd, &m, 0);
981         if (len == 0) {
982                 log_msg("peer closed connection.\n");
983                 return 0;
984         } else if (len < 0) {
985                 log_msg("failed to read message: %d: %s\n",
986                         errno, strerror(errno));
987                 return -1;
988         }
989
990         buf[len] = '\0';
991
992         log_address("Message from:", sa);
993         log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
994
995         ifindex = get_index_from_cmsg(&m);
996         if (args->expected_ifindex) {
997                 if (args->expected_ifindex != ifindex) {
998                         log_error("Device index mismatch: expected %d have %d\n",
999                                   args->expected_ifindex, ifindex);
1000                         return -1;
1001                 }
1002                 log_msg("Device index matches: expected %d have %d\n",
1003                         args->expected_ifindex, ifindex);
1004         }
1005
1006         if (!interactive && server_mode) {
1007                 if (sa->sa_family == AF_INET6) {
1008                         struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
1009                         struct in6_addr *in6 = &s6->sin6_addr;
1010
1011                         if (IN6_IS_ADDR_V4MAPPED(in6)) {
1012                                 const uint32_t *pa = (uint32_t *) &in6->s6_addr;
1013                                 struct in_addr in4;
1014                                 struct sockaddr_in *sin;
1015
1016                                 sin = (struct sockaddr_in *) addr;
1017                                 pa += 3;
1018                                 in4.s_addr = *pa;
1019                                 sin->sin_addr = in4;
1020                                 sin->sin_family = AF_INET;
1021                                 if (send_msg_cmsg(sd, addr, alen,
1022                                                   ifindex, AF_INET) < 0)
1023                                         goto out_err;
1024                         }
1025                 }
1026 again:
1027                 iov[0].iov_len = len;
1028
1029                 if (args->version == AF_INET6) {
1030                         struct sockaddr_in6 *s6 = (struct sockaddr_in6 *) sa;
1031
1032                         if (args->dev) {
1033                                 /* avoid PKTINFO conflicts with bindtodev */
1034                                 if (sendto(sd, buf, len, 0,
1035                                            (void *) addr, alen) < 0)
1036                                         goto out_err;
1037                         } else {
1038                                 /* kernel is allowing scope_id to be set to VRF
1039                                  * index for LLA. for sends to global address
1040                                  * reset scope id
1041                                  */
1042                                 s6->sin6_scope_id = ifindex;
1043                                 if (sendmsg(sd, &m, 0) < 0)
1044                                         goto out_err;
1045                         }
1046                 } else {
1047                         int err;
1048
1049                         err = sendmsg(sd, &m, 0);
1050                         if (err < 0) {
1051                                 if (errno == EACCES && try_broadcast) {
1052                                         try_broadcast = 0;
1053                                         if (!set_broadcast(sd))
1054                                                 goto again;
1055                                         errno = EACCES;
1056                                 }
1057                                 goto out_err;
1058                         }
1059                 }
1060                 log_msg("Sent message:\n");
1061                 log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1062         }
1063
1064         return 1;
1065 out_err:
1066         log_err_errno("failed to send msg to peer");
1067         return -1;
1068 }
1069
1070 static int socket_read_stream(int sd)
1071 {
1072         char buf[1024];
1073         int len;
1074
1075         len = read(sd, buf, sizeof(buf)-1);
1076         if (len == 0) {
1077                 log_msg("client closed connection.\n");
1078                 return 0;
1079         } else if (len < 0) {
1080                 log_msg("failed to read message\n");
1081                 return -1;
1082         }
1083
1084         buf[len] = '\0';
1085         log_msg("Incoming message:\n");
1086         log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1087
1088         if (!interactive && server_mode) {
1089                 if (write(sd, buf, len) < 0) {
1090                         log_err_errno("failed to send buf");
1091                         return -1;
1092                 }
1093                 log_msg("Sent message:\n");
1094                 log_msg("     %.24s%s\n", buf, len > 24 ? " ..." : "");
1095         }
1096
1097         return 1;
1098 }
1099
1100 static int socket_read(int sd, struct sock_args *args)
1101 {
1102         if (args->type == SOCK_STREAM)
1103                 return socket_read_stream(sd);
1104
1105         return socket_read_dgram(sd, args);
1106 }
1107
1108 static int stdin_to_socket(int sd, int type, void *addr, socklen_t alen)
1109 {
1110         char buf[1024];
1111         int len;
1112
1113         if (fgets(buf, sizeof(buf), stdin) == NULL)
1114                 return 0;
1115
1116         len = strlen(buf);
1117         if (type == SOCK_STREAM) {
1118                 if (write(sd, buf, len) < 0) {
1119                         log_err_errno("failed to send buf");
1120                         return -1;
1121                 }
1122         } else {
1123                 int err;
1124
1125 again:
1126                 err = sendto(sd, buf, len, 0, addr, alen);
1127                 if (err < 0) {
1128                         if (errno == EACCES && try_broadcast) {
1129                                 try_broadcast = 0;
1130                                 if (!set_broadcast(sd))
1131                                         goto again;
1132                                 errno = EACCES;
1133                         }
1134                         log_err_errno("failed to send msg to peer");
1135                         return -1;
1136                 }
1137         }
1138         log_msg("Sent message:\n");
1139         log_msg("    %.24s%s\n", buf, len > 24 ? " ..." : "");
1140
1141         return 1;
1142 }
1143
1144 static void set_recv_attr(int sd, int version)
1145 {
1146         if (version == AF_INET6) {
1147                 set_recvpktinfo_v6(sd);
1148                 set_recverr_v6(sd);
1149         } else {
1150                 set_pktinfo_v4(sd);
1151                 set_recverr_v4(sd);
1152         }
1153 }
1154
1155 static int msg_loop(int client, int sd, void *addr, socklen_t alen,
1156                     struct sock_args *args)
1157 {
1158         struct timeval timeout = { .tv_sec = prog_timeout }, *ptval = NULL;
1159         fd_set rfds;
1160         int nfds;
1161         int rc;
1162
1163         if (args->type != SOCK_STREAM)
1164                 set_recv_attr(sd, args->version);
1165
1166         if (msg) {
1167                 msglen = strlen(msg);
1168
1169                 /* client sends first message */
1170                 if (client) {
1171                         if (send_msg(sd, addr, alen, args))
1172                                 return 1;
1173                 }
1174                 if (!interactive) {
1175                         ptval = &timeout;
1176                         if (!prog_timeout)
1177                                 timeout.tv_sec = 5;
1178                 }
1179         }
1180
1181         nfds = interactive ? MAX(fileno(stdin), sd)  + 1 : sd + 1;
1182         while (1) {
1183                 FD_ZERO(&rfds);
1184                 FD_SET(sd, &rfds);
1185                 if (interactive)
1186                         FD_SET(fileno(stdin), &rfds);
1187
1188                 rc = select(nfds, &rfds, NULL, NULL, ptval);
1189                 if (rc < 0) {
1190                         if (errno == EINTR)
1191                                 continue;
1192
1193                         rc = 1;
1194                         log_err_errno("select failed");
1195                         break;
1196                 } else if (rc == 0) {
1197                         log_error("Timed out waiting for response\n");
1198                         rc = 2;
1199                         break;
1200                 }
1201
1202                 if (FD_ISSET(sd, &rfds)) {
1203                         rc = socket_read(sd, args);
1204                         if (rc < 0) {
1205                                 rc = 1;
1206                                 break;
1207                         }
1208                         if (rc == 0)
1209                                 break;
1210                 }
1211
1212                 rc = 0;
1213
1214                 if (FD_ISSET(fileno(stdin), &rfds)) {
1215                         if (stdin_to_socket(sd, args->type, addr, alen) <= 0)
1216                                 break;
1217                 }
1218
1219                 if (interactive)
1220                         continue;
1221
1222                 if (iter != -1) {
1223                         --iter;
1224                         if (iter == 0)
1225                                 break;
1226                 }
1227
1228                 log_msg("Going into quiet mode\n");
1229                 quiet = 1;
1230
1231                 if (client) {
1232                         if (send_msg(sd, addr, alen, args)) {
1233                                 rc = 1;
1234                                 break;
1235                         }
1236                 }
1237         }
1238
1239         return rc;
1240 }
1241
1242 static int msock_init(struct sock_args *args, int server)
1243 {
1244         uint32_t if_addr = htonl(INADDR_ANY);
1245         struct sockaddr_in laddr = {
1246                 .sin_family = AF_INET,
1247                 .sin_port = htons(args->port),
1248         };
1249         int one = 1;
1250         int sd;
1251
1252         if (!server && args->has_local_ip)
1253                 if_addr = args->local_addr.in.s_addr;
1254
1255         sd = socket(PF_INET, SOCK_DGRAM, 0);
1256         if (sd < 0) {
1257                 log_err_errno("socket");
1258                 return -1;
1259         }
1260
1261         if (setsockopt(sd, SOL_SOCKET, SO_REUSEADDR,
1262                        (char *)&one, sizeof(one)) < 0) {
1263                 log_err_errno("Setting SO_REUSEADDR error");
1264                 goto out_err;
1265         }
1266
1267         if (setsockopt(sd, SOL_SOCKET, SO_BROADCAST,
1268                        (char *)&one, sizeof(one)) < 0)
1269                 log_err_errno("Setting SO_BROADCAST error");
1270
1271         if (args->dev && bind_to_device(sd, args->dev) != 0)
1272                 goto out_err;
1273         else if (args->use_setsockopt &&
1274                  set_multicast_if(sd, args->ifindex))
1275                 goto out_err;
1276
1277         laddr.sin_addr.s_addr = if_addr;
1278
1279         if (bind(sd, (struct sockaddr *) &laddr, sizeof(laddr)) < 0) {
1280                 log_err_errno("bind failed");
1281                 goto out_err;
1282         }
1283
1284         if (server &&
1285             set_membership(sd, args->grp.s_addr,
1286                            args->local_addr.in.s_addr, args->ifindex))
1287                 goto out_err;
1288
1289         return sd;
1290 out_err:
1291         close(sd);
1292         return -1;
1293 }
1294
1295 static int msock_server(struct sock_args *args)
1296 {
1297         return msock_init(args, 1);
1298 }
1299
1300 static int msock_client(struct sock_args *args)
1301 {
1302         return msock_init(args, 0);
1303 }
1304
1305 static int bind_socket(int sd, struct sock_args *args)
1306 {
1307         struct sockaddr_in serv_addr = {
1308                 .sin_family = AF_INET,
1309         };
1310         struct sockaddr_in6 serv6_addr = {
1311                 .sin6_family = AF_INET6,
1312         };
1313         void *addr;
1314         socklen_t alen;
1315
1316         if (!args->has_local_ip && args->type == SOCK_RAW)
1317                 return 0;
1318
1319         switch (args->version) {
1320         case AF_INET:
1321                 serv_addr.sin_port = htons(args->port);
1322                 serv_addr.sin_addr = args->local_addr.in;
1323                 addr = &serv_addr;
1324                 alen = sizeof(serv_addr);
1325                 break;
1326
1327         case AF_INET6:
1328                 serv6_addr.sin6_port = htons(args->port);
1329                 serv6_addr.sin6_addr = args->local_addr.in6;
1330                 addr = &serv6_addr;
1331                 alen = sizeof(serv6_addr);
1332                 break;
1333
1334         default:
1335                 log_error("Invalid address family\n");
1336                 return -1;
1337         }
1338
1339         if (bind(sd, addr, alen) < 0) {
1340                 log_err_errno("error binding socket");
1341                 return -1;
1342         }
1343
1344         return 0;
1345 }
1346
1347 static int lsock_init(struct sock_args *args)
1348 {
1349         long flags;
1350         int sd;
1351
1352         sd = socket(args->version, args->type, args->protocol);
1353         if (sd < 0) {
1354                 log_err_errno("Error opening socket");
1355                 return  -1;
1356         }
1357
1358         if (set_reuseaddr(sd) != 0)
1359                 goto err;
1360
1361         if (set_reuseport(sd) != 0)
1362                 goto err;
1363
1364         if (args->dev && bind_to_device(sd, args->dev) != 0)
1365                 goto err;
1366         else if (args->use_setsockopt &&
1367                  set_unicast_if(sd, args->ifindex, args->version))
1368                 goto err;
1369
1370         if (bind_socket(sd, args))
1371                 goto err;
1372
1373         if (args->bind_test_only)
1374                 goto out;
1375
1376         if (args->type == SOCK_STREAM && listen(sd, 1) < 0) {
1377                 log_err_errno("listen failed");
1378                 goto err;
1379         }
1380
1381         flags = fcntl(sd, F_GETFL);
1382         if ((flags < 0) || (fcntl(sd, F_SETFL, flags|O_NONBLOCK) < 0)) {
1383                 log_err_errno("Failed to set non-blocking option");
1384                 goto err;
1385         }
1386
1387         if (fcntl(sd, F_SETFD, FD_CLOEXEC) < 0)
1388                 log_err_errno("Failed to set close-on-exec flag");
1389
1390 out:
1391         return sd;
1392
1393 err:
1394         close(sd);
1395         return -1;
1396 }
1397
1398 static void ipc_write(int fd, int message)
1399 {
1400         /* Not in both_mode, so there's no process to signal */
1401         if (fd < 0)
1402                 return;
1403
1404         if (write(fd, &message, sizeof(message)) < 0)
1405                 log_err_errno("Failed to send client status");
1406 }
1407
1408 static int do_server(struct sock_args *args, int ipc_fd)
1409 {
1410         /* ipc_fd = -1 if no parent process to signal */
1411         struct timeval timeout = { .tv_sec = prog_timeout }, *ptval = NULL;
1412         unsigned char addr[sizeof(struct sockaddr_in6)] = {};
1413         socklen_t alen = sizeof(addr);
1414         int lsd, csd = -1;
1415
1416         fd_set rfds;
1417         int rc;
1418
1419         if (args->serverns) {
1420                 if (switch_ns(args->serverns)) {
1421                         log_error("Could not set server netns to %s\n",
1422                                   args->serverns);
1423                         goto err_exit;
1424                 }
1425                 log_msg("Switched server netns\n");
1426         }
1427
1428         if (resolve_devices(args) || validate_addresses(args))
1429                 goto err_exit;
1430
1431         if (prog_timeout)
1432                 ptval = &timeout;
1433
1434         if (args->has_grp)
1435                 lsd = msock_server(args);
1436         else
1437                 lsd = lsock_init(args);
1438
1439         if (lsd < 0)
1440                 goto err_exit;
1441
1442         if (args->bind_test_only) {
1443                 close(lsd);
1444                 ipc_write(ipc_fd, 1);
1445                 return 0;
1446         }
1447
1448         if (args->type != SOCK_STREAM) {
1449                 ipc_write(ipc_fd, 1);
1450                 rc = msg_loop(0, lsd, (void *) addr, alen, args);
1451                 close(lsd);
1452                 return rc;
1453         }
1454
1455         if (args->password && tcp_md5_remote(lsd, args)) {
1456                 close(lsd);
1457                 goto err_exit;
1458         }
1459
1460         ipc_write(ipc_fd, 1);
1461         while (1) {
1462                 log_msg("\n");
1463                 log_msg("waiting for client connection.\n");
1464                 FD_ZERO(&rfds);
1465                 FD_SET(lsd, &rfds);
1466
1467                 rc = select(lsd+1, &rfds, NULL, NULL, ptval);
1468                 if (rc == 0) {
1469                         rc = 2;
1470                         break;
1471                 }
1472
1473                 if (rc < 0) {
1474                         if (errno == EINTR)
1475                                 continue;
1476
1477                         log_err_errno("select failed");
1478                         break;
1479                 }
1480
1481                 if (FD_ISSET(lsd, &rfds)) {
1482
1483                         csd = accept(lsd, (void *) addr, &alen);
1484                         if (csd < 0) {
1485                                 log_err_errno("accept failed");
1486                                 break;
1487                         }
1488
1489                         rc = show_sockstat(csd, args);
1490                         if (rc)
1491                                 break;
1492
1493                         rc = check_device(csd, args);
1494                         if (rc)
1495                                 break;
1496                 }
1497
1498                 rc = msg_loop(0, csd, (void *) addr, alen, args);
1499                 close(csd);
1500
1501                 if (!interactive)
1502                         break;
1503         }
1504
1505         close(lsd);
1506
1507         return rc;
1508 err_exit:
1509         ipc_write(ipc_fd, 0);
1510         return 1;
1511 }
1512
1513 static int wait_for_connect(int sd)
1514 {
1515         struct timeval _tv = { .tv_sec = prog_timeout }, *tv = NULL;
1516         fd_set wfd;
1517         int val = 0, sz = sizeof(val);
1518         int rc;
1519
1520         FD_ZERO(&wfd);
1521         FD_SET(sd, &wfd);
1522
1523         if (prog_timeout)
1524                 tv = &_tv;
1525
1526         rc = select(FD_SETSIZE, NULL, &wfd, NULL, tv);
1527         if (rc == 0) {
1528                 log_error("connect timed out\n");
1529                 return -2;
1530         } else if (rc < 0) {
1531                 log_err_errno("select failed");
1532                 return -3;
1533         }
1534
1535         if (getsockopt(sd, SOL_SOCKET, SO_ERROR, &val, (socklen_t *)&sz) < 0) {
1536                 log_err_errno("getsockopt(SO_ERROR) failed");
1537                 return -4;
1538         }
1539
1540         if (val != 0) {
1541                 log_error("connect failed: %d: %s\n", val, strerror(val));
1542                 return -1;
1543         }
1544
1545         return 0;
1546 }
1547
1548 static int connectsock(void *addr, socklen_t alen, struct sock_args *args)
1549 {
1550         int sd, rc = -1;
1551         long flags;
1552
1553         sd = socket(args->version, args->type, args->protocol);
1554         if (sd < 0) {
1555                 log_err_errno("Failed to create socket");
1556                 return -1;
1557         }
1558
1559         flags = fcntl(sd, F_GETFL);
1560         if ((flags < 0) || (fcntl(sd, F_SETFL, flags|O_NONBLOCK) < 0)) {
1561                 log_err_errno("Failed to set non-blocking option");
1562                 goto err;
1563         }
1564
1565         if (set_reuseport(sd) != 0)
1566                 goto err;
1567
1568         if (args->dev && bind_to_device(sd, args->dev) != 0)
1569                 goto err;
1570         else if (args->use_setsockopt &&
1571                  set_unicast_if(sd, args->ifindex, args->version))
1572                 goto err;
1573
1574         if (args->has_local_ip && bind_socket(sd, args))
1575                 goto err;
1576
1577         if (args->type != SOCK_STREAM)
1578                 goto out;
1579
1580         if (args->password && tcp_md5sig(sd, addr, alen, args))
1581                 goto err;
1582
1583         if (args->bind_test_only)
1584                 goto out;
1585
1586         if (connect(sd, addr, alen) < 0) {
1587                 if (errno != EINPROGRESS) {
1588                         log_err_errno("Failed to connect to remote host");
1589                         rc = -1;
1590                         goto err;
1591                 }
1592                 rc = wait_for_connect(sd);
1593                 if (rc < 0)
1594                         goto err;
1595         }
1596 out:
1597         return sd;
1598
1599 err:
1600         close(sd);
1601         return rc;
1602 }
1603
1604 static int do_client(struct sock_args *args)
1605 {
1606         struct sockaddr_in sin = {
1607                 .sin_family = AF_INET,
1608         };
1609         struct sockaddr_in6 sin6 = {
1610                 .sin6_family = AF_INET6,
1611         };
1612         void *addr;
1613         int alen;
1614         int rc = 0;
1615         int sd;
1616
1617         if (!args->has_remote_ip && !args->has_grp) {
1618                 fprintf(stderr, "remote IP or multicast group not given\n");
1619                 return 1;
1620         }
1621
1622         if (args->clientns) {
1623                 if (switch_ns(args->clientns)) {
1624                         log_error("Could not set client netns to %s\n",
1625                                   args->clientns);
1626                         return 1;
1627                 }
1628                 log_msg("Switched client netns\n");
1629         }
1630
1631         if (resolve_devices(args) || validate_addresses(args))
1632                 return 1;
1633
1634         if ((args->use_setsockopt || args->use_cmsg) && !args->ifindex) {
1635                 fprintf(stderr, "Device binding not specified\n");
1636                 return 1;
1637         }
1638         if (args->use_setsockopt || args->use_cmsg)
1639                 args->dev = NULL;
1640
1641         switch (args->version) {
1642         case AF_INET:
1643                 sin.sin_port = htons(args->port);
1644                 if (args->has_grp)
1645                         sin.sin_addr = args->grp;
1646                 else
1647                         sin.sin_addr = args->remote_addr.in;
1648                 addr = &sin;
1649                 alen = sizeof(sin);
1650                 break;
1651         case AF_INET6:
1652                 sin6.sin6_port = htons(args->port);
1653                 sin6.sin6_addr = args->remote_addr.in6;
1654                 sin6.sin6_scope_id = args->scope_id;
1655                 addr = &sin6;
1656                 alen = sizeof(sin6);
1657                 break;
1658         }
1659
1660         if (args->has_grp)
1661                 sd = msock_client(args);
1662         else
1663                 sd = connectsock(addr, alen, args);
1664
1665         if (sd < 0)
1666                 return -sd;
1667
1668         if (args->bind_test_only)
1669                 goto out;
1670
1671         if (args->type == SOCK_STREAM) {
1672                 rc = show_sockstat(sd, args);
1673                 if (rc != 0)
1674                         goto out;
1675         }
1676
1677         rc = msg_loop(1, sd, addr, alen, args);
1678
1679 out:
1680         close(sd);
1681
1682         return rc;
1683 }
1684
1685 static char *random_msg(int len)
1686 {
1687         int i, n = 0, olen = len + 1;
1688         char *m;
1689
1690         if (len <= 0)
1691                 return NULL;
1692
1693         m = malloc(olen);
1694         if (!m)
1695                 return NULL;
1696
1697         while (len > 26) {
1698                 i = snprintf(m + n, olen - n, "%.26s",
1699                              "abcdefghijklmnopqrstuvwxyz");
1700                 n += i;
1701                 len -= i;
1702         }
1703         i = snprintf(m + n, olen - n, "%.*s", len,
1704                      "abcdefghijklmnopqrstuvwxyz");
1705         return m;
1706 }
1707
1708 static int ipc_child(int fd, struct sock_args *args)
1709 {
1710         server_mode = 1; /* to tell log_msg in case we are in both_mode */
1711
1712         return do_server(args, fd);
1713 }
1714
1715 static int ipc_parent(int cpid, int fd, struct sock_args *args)
1716 {
1717         int client_status;
1718         int status;
1719         int buf;
1720
1721         /* do the client-side function here in the parent process,
1722          * waiting to be told when to continue
1723          */
1724         if (read(fd, &buf, sizeof(buf)) <= 0) {
1725                 log_err_errno("Failed to read IPC status from status");
1726                 return 1;
1727         }
1728         if (!buf) {
1729                 log_error("Server failed; can not continue\n");
1730                 return 1;
1731         }
1732         log_msg("Server is ready\n");
1733
1734         client_status = do_client(args);
1735         log_msg("parent is done!\n");
1736
1737         if (kill(cpid, 0) == 0)
1738                 kill(cpid, SIGKILL);
1739
1740         wait(&status);
1741         return client_status;
1742 }
1743
1744 #define GETOPT_STR  "sr:l:p:t:g:P:DRn:M:m:d:BN:O:SCi6L:0:1:2:Fbq"
1745
1746 static void print_usage(char *prog)
1747 {
1748         printf(
1749         "usage: %s OPTS\n"
1750         "Required:\n"
1751         "    -r addr       remote address to connect to (client mode only)\n"
1752         "    -p port       port to connect to (client mode)/listen on (server mode)\n"
1753         "                  (default: %d)\n"
1754         "    -s            server mode (default: client mode)\n"
1755         "    -t            timeout seconds (default: none)\n"
1756         "\n"
1757         "Optional:\n"
1758         "    -B            do both client and server via fork and IPC\n"
1759         "    -N ns         set client to network namespace ns (requires root)\n"
1760         "    -O ns         set server to network namespace ns (requires root)\n"
1761         "    -F            Restart server loop\n"
1762         "    -6            IPv6 (default is IPv4)\n"
1763         "    -P proto      protocol for socket: icmp, ospf (default: none)\n"
1764         "    -D|R          datagram (D) / raw (R) socket (default stream)\n"
1765         "    -l addr       local address to bind to\n"
1766         "\n"
1767         "    -d dev        bind socket to given device name\n"
1768         "    -S            use setsockopt (IP_UNICAST_IF or IP_MULTICAST_IF)\n"
1769         "                  to set device binding\n"
1770         "    -C            use cmsg and IP_PKTINFO to specify device binding\n"
1771         "\n"
1772         "    -L len        send random message of given length\n"
1773         "    -n num        number of times to send message\n"
1774         "\n"
1775         "    -M password   use MD5 sum protection\n"
1776         "    -m prefix/len prefix and length to use for MD5 key\n"
1777         "    -g grp        multicast group (e.g., 239.1.1.1)\n"
1778         "    -i            interactive mode (default is echo and terminate)\n"
1779         "\n"
1780         "    -0 addr       Expected local address\n"
1781         "    -1 addr       Expected remote address\n"
1782         "    -2 dev        Expected device name (or index) to receive packet\n"
1783         "\n"
1784         "    -b            Bind test only.\n"
1785         "    -q            Be quiet. Run test without printing anything.\n"
1786         , prog, DEFAULT_PORT);
1787 }
1788
1789 int main(int argc, char *argv[])
1790 {
1791         struct sock_args args = {
1792                 .version = AF_INET,
1793                 .type    = SOCK_STREAM,
1794                 .port    = DEFAULT_PORT,
1795         };
1796         struct protoent *pe;
1797         int both_mode = 0;
1798         unsigned int tmp;
1799         int forever = 0;
1800         int fd[2];
1801         int cpid;
1802
1803         /* process inputs */
1804         extern char *optarg;
1805         int rc = 0;
1806
1807         /*
1808          * process input args
1809          */
1810
1811         while ((rc = getopt(argc, argv, GETOPT_STR)) != -1) {
1812                 switch (rc) {
1813                 case 'B':
1814                         both_mode = 1;
1815                         break;
1816                 case 's':
1817                         server_mode = 1;
1818                         break;
1819                 case 'F':
1820                         forever = 1;
1821                         break;
1822                 case 'l':
1823                         args.has_local_ip = 1;
1824                         args.local_addr_str = optarg;
1825                         break;
1826                 case 'r':
1827                         args.has_remote_ip = 1;
1828                         args.remote_addr_str = optarg;
1829                         break;
1830                 case 'p':
1831                         if (str_to_uint(optarg, 1, 65535, &tmp) != 0) {
1832                                 fprintf(stderr, "Invalid port\n");
1833                                 return 1;
1834                         }
1835                         args.port = (unsigned short) tmp;
1836                         break;
1837                 case 't':
1838                         if (str_to_uint(optarg, 0, INT_MAX,
1839                                         &prog_timeout) != 0) {
1840                                 fprintf(stderr, "Invalid timeout\n");
1841                                 return 1;
1842                         }
1843                         break;
1844                 case 'D':
1845                         args.type = SOCK_DGRAM;
1846                         break;
1847                 case 'R':
1848                         args.type = SOCK_RAW;
1849                         args.port = 0;
1850                         if (!args.protocol)
1851                                 args.protocol = IPPROTO_RAW;
1852                         break;
1853                 case 'P':
1854                         pe = getprotobyname(optarg);
1855                         if (pe) {
1856                                 args.protocol = pe->p_proto;
1857                         } else {
1858                                 if (str_to_uint(optarg, 0, 0xffff, &tmp) != 0) {
1859                                         fprintf(stderr, "Invalid protocol\n");
1860                                         return 1;
1861                                 }
1862                                 args.protocol = tmp;
1863                         }
1864                         break;
1865                 case 'n':
1866                         iter = atoi(optarg);
1867                         break;
1868                 case 'N':
1869                         args.clientns = optarg;
1870                         break;
1871                 case 'O':
1872                         args.serverns = optarg;
1873                         break;
1874                 case 'L':
1875                         msg = random_msg(atoi(optarg));
1876                         break;
1877                 case 'M':
1878                         args.password = optarg;
1879                         break;
1880                 case 'm':
1881                         args.md5_prefix_str = optarg;
1882                         break;
1883                 case 'S':
1884                         args.use_setsockopt = 1;
1885                         break;
1886                 case 'C':
1887                         args.use_cmsg = 1;
1888                         break;
1889                 case 'd':
1890                         args.dev = optarg;
1891                         break;
1892                 case 'i':
1893                         interactive = 1;
1894                         break;
1895                 case 'g':
1896                         args.has_grp = 1;
1897                         if (convert_addr(&args, optarg, ADDR_TYPE_MCAST) < 0)
1898                                 return 1;
1899                         args.type = SOCK_DGRAM;
1900                         break;
1901                 case '6':
1902                         args.version = AF_INET6;
1903                         break;
1904                 case 'b':
1905                         args.bind_test_only = 1;
1906                         break;
1907                 case '0':
1908                         args.has_expected_laddr = 1;
1909                         args.expected_laddr_str = optarg;
1910                         break;
1911                 case '1':
1912                         args.has_expected_raddr = 1;
1913                         args.expected_raddr_str = optarg;
1914                         break;
1915                 case '2':
1916                         args.expected_dev = optarg;
1917                         break;
1918                 case 'q':
1919                         quiet = 1;
1920                         break;
1921                 default:
1922                         print_usage(argv[0]);
1923                         return 1;
1924                 }
1925         }
1926
1927         if (args.password &&
1928             ((!args.has_remote_ip && !args.md5_prefix_str) ||
1929               args.type != SOCK_STREAM)) {
1930                 log_error("MD5 passwords apply to TCP only and require a remote ip for the password\n");
1931                 return 1;
1932         }
1933
1934         if (args.md5_prefix_str && !args.password) {
1935                 log_error("Prefix range for MD5 protection specified without a password\n");
1936                 return 1;
1937         }
1938
1939         if (iter == 0) {
1940                 fprintf(stderr, "Invalid number of messages to send\n");
1941                 return 1;
1942         }
1943
1944         if (args.type == SOCK_STREAM && !args.protocol)
1945                 args.protocol = IPPROTO_TCP;
1946         if (args.type == SOCK_DGRAM && !args.protocol)
1947                 args.protocol = IPPROTO_UDP;
1948
1949         if ((args.type == SOCK_STREAM || args.type == SOCK_DGRAM) &&
1950              args.port == 0) {
1951                 fprintf(stderr, "Invalid port number\n");
1952                 return 1;
1953         }
1954
1955         if ((both_mode || !server_mode) && !args.has_grp &&
1956             !args.has_remote_ip && !args.has_local_ip) {
1957                 fprintf(stderr,
1958                         "Local (server mode) or remote IP (client IP) required\n");
1959                 return 1;
1960         }
1961
1962         if (interactive) {
1963                 prog_timeout = 0;
1964                 msg = NULL;
1965         }
1966
1967         if (both_mode) {
1968                 if (pipe(fd) < 0) {
1969                         perror("pipe");
1970                         exit(1);
1971                 }
1972
1973                 cpid = fork();
1974                 if (cpid < 0) {
1975                         perror("fork");
1976                         exit(1);
1977                 }
1978                 if (cpid)
1979                         return ipc_parent(cpid, fd[0], &args);
1980
1981                 return ipc_child(fd[1], &args);
1982         }
1983
1984         if (server_mode) {
1985                 do {
1986                         rc = do_server(&args, -1);
1987                 } while (forever);
1988
1989                 return rc;
1990         }
1991         return do_client(&args);
1992 }