Merge git://git.kernel.org/pub/scm/linux/kernel/git/kvalo/wireless-drivers.git
[platform/kernel/linux-starfive.git] / tools / testing / selftests / bpf / network_helpers.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 #include <errno.h>
3 #include <stdbool.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <unistd.h>
7
8 #include <arpa/inet.h>
9
10 #include <linux/err.h>
11 #include <linux/in.h>
12 #include <linux/in6.h>
13
14 #include "bpf_util.h"
15 #include "network_helpers.h"
16
17 #define clean_errno() (errno == 0 ? "None" : strerror(errno))
18 #define log_err(MSG, ...) ({                                            \
19                         int __save = errno;                             \
20                         fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n", \
21                                 __FILE__, __LINE__, clean_errno(),      \
22                                 ##__VA_ARGS__);                         \
23                         errno = __save;                                 \
24 })
25
26 struct ipv4_packet pkt_v4 = {
27         .eth.h_proto = __bpf_constant_htons(ETH_P_IP),
28         .iph.ihl = 5,
29         .iph.protocol = IPPROTO_TCP,
30         .iph.tot_len = __bpf_constant_htons(MAGIC_BYTES),
31         .tcp.urg_ptr = 123,
32         .tcp.doff = 5,
33 };
34
35 struct ipv6_packet pkt_v6 = {
36         .eth.h_proto = __bpf_constant_htons(ETH_P_IPV6),
37         .iph.nexthdr = IPPROTO_TCP,
38         .iph.payload_len = __bpf_constant_htons(MAGIC_BYTES),
39         .tcp.urg_ptr = 123,
40         .tcp.doff = 5,
41 };
42
43 static int settimeo(int fd, int timeout_ms)
44 {
45         struct timeval timeout = { .tv_sec = 3 };
46
47         if (timeout_ms > 0) {
48                 timeout.tv_sec = timeout_ms / 1000;
49                 timeout.tv_usec = (timeout_ms % 1000) * 1000;
50         }
51
52         if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
53                        sizeof(timeout))) {
54                 log_err("Failed to set SO_RCVTIMEO");
55                 return -1;
56         }
57
58         if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeout,
59                        sizeof(timeout))) {
60                 log_err("Failed to set SO_SNDTIMEO");
61                 return -1;
62         }
63
64         return 0;
65 }
66
67 #define save_errno_close(fd) ({ int __save = errno; close(fd); errno = __save; })
68
69 int start_server(int family, int type, const char *addr_str, __u16 port,
70                  int timeout_ms)
71 {
72         struct sockaddr_storage addr = {};
73         socklen_t len;
74         int fd;
75
76         if (make_sockaddr(family, addr_str, port, &addr, &len))
77                 return -1;
78
79         fd = socket(family, type, 0);
80         if (fd < 0) {
81                 log_err("Failed to create server socket");
82                 return -1;
83         }
84
85         if (settimeo(fd, timeout_ms))
86                 goto error_close;
87
88         if (bind(fd, (const struct sockaddr *)&addr, len) < 0) {
89                 log_err("Failed to bind socket");
90                 goto error_close;
91         }
92
93         if (type == SOCK_STREAM) {
94                 if (listen(fd, 1) < 0) {
95                         log_err("Failed to listed on socket");
96                         goto error_close;
97                 }
98         }
99
100         return fd;
101
102 error_close:
103         save_errno_close(fd);
104         return -1;
105 }
106
107 static int connect_fd_to_addr(int fd,
108                               const struct sockaddr_storage *addr,
109                               socklen_t addrlen)
110 {
111         if (connect(fd, (const struct sockaddr *)addr, addrlen)) {
112                 log_err("Failed to connect to server");
113                 return -1;
114         }
115
116         return 0;
117 }
118
119 int connect_to_fd(int server_fd, int timeout_ms)
120 {
121         struct sockaddr_storage addr;
122         struct sockaddr_in *addr_in;
123         socklen_t addrlen, optlen;
124         int fd, type;
125
126         optlen = sizeof(type);
127         if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
128                 log_err("getsockopt(SOL_TYPE)");
129                 return -1;
130         }
131
132         addrlen = sizeof(addr);
133         if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
134                 log_err("Failed to get server addr");
135                 return -1;
136         }
137
138         addr_in = (struct sockaddr_in *)&addr;
139         fd = socket(addr_in->sin_family, type, 0);
140         if (fd < 0) {
141                 log_err("Failed to create client socket");
142                 return -1;
143         }
144
145         if (settimeo(fd, timeout_ms))
146                 goto error_close;
147
148         if (connect_fd_to_addr(fd, &addr, addrlen))
149                 goto error_close;
150
151         return fd;
152
153 error_close:
154         save_errno_close(fd);
155         return -1;
156 }
157
158 int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
159 {
160         struct sockaddr_storage addr;
161         socklen_t len = sizeof(addr);
162
163         if (settimeo(client_fd, timeout_ms))
164                 return -1;
165
166         if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) {
167                 log_err("Failed to get server addr");
168                 return -1;
169         }
170
171         if (connect_fd_to_addr(client_fd, &addr, len))
172                 return -1;
173
174         return 0;
175 }
176
177 int make_sockaddr(int family, const char *addr_str, __u16 port,
178                   struct sockaddr_storage *addr, socklen_t *len)
179 {
180         if (family == AF_INET) {
181                 struct sockaddr_in *sin = (void *)addr;
182
183                 sin->sin_family = AF_INET;
184                 sin->sin_port = htons(port);
185                 if (addr_str &&
186                     inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
187                         log_err("inet_pton(AF_INET, %s)", addr_str);
188                         return -1;
189                 }
190                 if (len)
191                         *len = sizeof(*sin);
192                 return 0;
193         } else if (family == AF_INET6) {
194                 struct sockaddr_in6 *sin6 = (void *)addr;
195
196                 sin6->sin6_family = AF_INET6;
197                 sin6->sin6_port = htons(port);
198                 if (addr_str &&
199                     inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
200                         log_err("inet_pton(AF_INET6, %s)", addr_str);
201                         return -1;
202                 }
203                 if (len)
204                         *len = sizeof(*sin6);
205                 return 0;
206         }
207         return -1;
208 }