1 // SPDX-License-Identifier: GPL-2.0
10 #include <linux/filter.h>
11 #include <linux/bpf.h>
12 #include <linux/if_packet.h>
13 #include <linux/if_vlan.h>
14 #include <linux/virtio_net.h>
16 #include <net/ethernet.h>
17 #include <netinet/ip.h>
18 #include <netinet/udp.h>
27 #include <sys/socket.h>
29 #include <sys/types.h>
32 #include "psock_lib.h"
34 static bool cfg_use_bind;
35 static bool cfg_use_csum_off;
36 static bool cfg_use_csum_off_bad;
37 static bool cfg_use_dgram;
38 static bool cfg_use_gso;
39 static bool cfg_use_qdisc_bypass;
40 static bool cfg_use_vlan;
41 static bool cfg_use_vnet;
43 static char *cfg_ifname = "lo";
44 static int cfg_mtu = 1500;
45 static int cfg_payload_len = DATA_LEN;
46 static int cfg_truncate_len = INT_MAX;
47 static uint16_t cfg_port = 8000;
/*
 * Buffer size: a virtio_net_hdr, an Ethernet header and an L3 payload one
 * byte larger than the largest possible MTU, so that the test can build and
 * send frames that exceed the maximum MTU.
 */
#define TEST_SZ	(sizeof(struct virtio_net_hdr) + ETH_HLEN + ETH_MAX_MTU + 1)

/* frame assembled by build_packet() and handed to do_send() */
static char tbuf[TEST_SZ];
/* destination of every recv() in do_rx() */
static char rbuf[TEST_SZ];
/*
 * add_csum_hword() - sum @num_u16 consecutive 16-bit words starting at @start.
 *
 * The running total is kept in an unsigned long starting from 0, so carries
 * out of bit 15 are kept for the caller (build_ip_csum) to fold.
 * The words are summed as stored; no byte-order conversion is visible.
 * NOTE(review): the loop body and return statement are not visible in this
 * excerpt; presumably each start[i] is added to sum and sum is returned.
 */
54 static unsigned long add_csum_hword(const uint16_t *start, int num_u16)
56 unsigned long sum = 0;
59 for (i = 0; i < num_u16; i++)
/*
 * build_ip_csum() - Internet (ones'-complement) checksum helper.
 *
 * Adds the @num_u16 16-bit words at @start (via add_csum_hword) to a
 * caller-supplied partial sum. It then folds everything above bit 15 back
 * into the low 16 bits (end-around carry).
 *
 * Callers seed the partial sum with:
 *  - 0 for the IPv4 header checksum;
 *  - protocol + UDP length for the UDP pseudo-header.
 *
 * NOTE(review): several pieces are not visible in this excerpt:
 *  - the third parameter's declaration;
 *  - whatever encloses the fold statement (a single fold can itself produce
 *    a new carry);
 *  - the return statement.
 * Presumably the fold repeats until no carry remains and the result is
 * complemented before it is returned — confirm.
 */
65 static uint16_t build_ip_csum(const uint16_t *start, int num_u16,
68 sum += add_csum_hword(start, num_u16);
71 sum = (sum & 0xffff) + (sum >> 16);
/*
 * build_vnet_header() - fill the struct virtio_net_hdr at @header.
 *
 * hdr_len always covers the Ethernet + IPv4 + UDP headers.
 *
 * With checksum offload (-c), VIRTIO_NET_HDR_F_NEEDS_CSUM is set:
 *  - csum_start points at the UDP header, measured from the start of the
 *    Ethernet header;
 *  - csum_offset points at udphdr.check.
 * With -C, csum_start is pushed forward so the check field falls at the end
 * of the packet. This is an intentionally invalid request.
 *
 * The UDP GSO request uses gso_size = cfg_mtu minus the IPv4 header.
 * Fields not assigned here presumably stay zero, because @header is the
 * static tbuf.
 *
 * NOTE(review): the following are not visible in this excerpt:
 *  - the condition guarding gso_type/gso_size;
 *  - the tail of the -C adjustment;
 *  - the return statement.
 * Presumably the GSO fields are guarded by cfg_use_gso, and the function
 * returns sizeof(*vh) (build_packet adds the result to its offset).
 *
 * NOTE(review): hdr_len and csum_start are computed with ETH_HLEN only.
 * With -V, build_eth_header places an 802.1Q tag after ETH_HLEN, which
 * appears to shift the IP/UDP headers by four bytes. Confirm whether -V
 * combined with -c/-g is meant to be supported.
 */
76 static int build_vnet_header(void *header)
78 struct virtio_net_hdr *vh = header;
80 vh->hdr_len = ETH_HLEN + sizeof(struct iphdr) + sizeof(struct udphdr);
82 if (cfg_use_csum_off) {
83 vh->flags |= VIRTIO_NET_HDR_F_NEEDS_CSUM;
84 vh->csum_start = ETH_HLEN + sizeof(struct iphdr);
85 vh->csum_offset = __builtin_offsetof(struct udphdr, check);
87 /* position check field exactly one byte beyond end of packet */
88 if (cfg_use_csum_off_bad)
89 vh->csum_start += sizeof(struct udphdr) + cfg_payload_len -
94 vh->gso_type = VIRTIO_NET_HDR_GSO_UDP;
95 vh->gso_size = cfg_mtu - sizeof(struct iphdr);
/*
 * build_eth_header() - write the Ethernet header at @header.
 *
 * Only the EtherType is written here. The MAC addresses are left as found
 * (presumably zero, since the caller passes the static tbuf).
 *
 * In the VLAN variant, an 802.1Q tag follows the 14-byte header:
 *  - the outer EtherType is ETH_P_8021Q;
 *  - the tag's second 16-bit word carries the inner EtherType ETH_P_IP;
 *  - the tag's first word (TCI) is not written.
 *
 * NOTE(review): the condition selecting the VLAN variant (presumably
 * cfg_use_vlan) and the return statements are not visible in this excerpt.
 * Presumably the function returns the header length, including the tag if
 * present. `header + ETH_HLEN` is arithmetic on void *, a GNU C extension.
 */
101 static int build_eth_header(void *header)
103 struct ethhdr *eth = header;
106 uint16_t *tag = header + ETH_HLEN;
108 eth->h_proto = htons(ETH_P_8021Q);
109 tag[1] = htons(ETH_P_IP);
113 eth->h_proto = htons(ETH_P_IP);
/*
 * build_ipv4_header() - write an IPv4 header for a UDP datagram carrying
 * @payload_len bytes of data.
 *
 * The source is 172.17.0.2 and the destination 172.17.0.1 (main() assigns
 * 172.17.0.1/24 to lo). Returns the header length in bytes (ihl * 4).
 *
 * The header checksum covers ihl * 2 16-bit words, i.e. the whole header.
 * The check field presumably still reads zero at that point (static tbuf).
 *
 * NOTE(review): the ihl/version/ttl assignments are not visible in this
 * excerpt. Other issues:
 *  - tot_len is 16 bits wide, so htons() truncates a large @payload_len.
 *    build_packet's bounds check only runs after this header is written.
 *  - (172 << 24) overflows signed int, which is undefined behaviour
 *    (C11 6.5.7p4). 172U << 24 would be well defined.
 */
117 static int build_ipv4_header(void *header, int payload_len)
119 struct iphdr *iph = header;
124 iph->tot_len = htons(sizeof(*iph) + sizeof(struct udphdr) + payload_len);
125 iph->id = htons(1337);
126 iph->protocol = IPPROTO_UDP;
127 iph->saddr = htonl((172 << 24) | (17 << 16) | 2);
128 iph->daddr = htonl((172 << 24) | (17 << 16) | 1);
129 iph->check = build_ip_csum((void *) iph, iph->ihl << 1, 0);
131 return iph->ihl << 2;
/*
 * build_udp_header() - write the UDP header at @header for @payload_len
 * bytes of data.
 *
 * Fields written: source port 9, destination port cfg_port, and length =
 * header + payload. Returns sizeof(struct udphdr).
 *
 * With checksum offload (-c), the check field is seeded with the
 * pseudo-header partial sum. It adds:
 *  - the 2 * alen = 8 bytes just before @header, summed as alen = 4
 *    16-bit words. These are the IPv4 saddr and daddr, assuming an
 *    option-less IPv4 header immediately precedes the UDP header.
 *  - the protocol and the UDP length.
 * alen serves both as a byte count and as a word count; both readings cover
 * the same 8 bytes.
 *
 * NOTE(review): the non-offload branch is not visible in this excerpt.
 * For partial-checksum offload, the check field is conventionally seeded
 * with the *uncomplemented* folded pseudo-header sum; confirm that
 * build_ip_csum's return value suits that here.
 * `header - (2 * alen)` is void * arithmetic (GNU C extension).
 */
134 static int build_udp_header(void *header, int payload_len)
136 const int alen = sizeof(uint32_t);
137 struct udphdr *udph = header;
138 int len = sizeof(*udph) + payload_len;
140 udph->source = htons(9);
141 udph->dest = htons(cfg_port);
142 udph->len = htons(len);
144 if (cfg_use_csum_off)
145 udph->check = build_ip_csum(header - (2 * alen), alen,
146 htons(IPPROTO_UDP) + udph->len);
150 return sizeof(*udph);
/*
 * build_packet() - assemble the complete test frame in tbuf and return its
 * length in bytes.
 *
 * Layout:
 *   virtio_net_hdr | Ethernet | IPv4 | UDP | @payload_len bytes of DATA_CHAR
 * Each header builder's return value advances the write offset.
 *
 * The virtio_net_hdr is always written here. do_send() skips it under a
 * condition not visible in this excerpt.
 *
 * NOTE(review): the declaration/initialisation of `off` is not visible in
 * this excerpt. The size check only runs after the headers are written, and
 * `off + payload_len` is int arithmetic:
 *  - a negative @payload_len converts to a huge value when compared with
 *    sizeof(tbuf), so it is rejected;
 *  - a value near INT_MAX overflows int (undefined behaviour) before the
 *    comparison.
 */
153 static int build_packet(int payload_len)
157 off += build_vnet_header(tbuf);
158 off += build_eth_header(tbuf + off);
159 off += build_ipv4_header(tbuf + off, payload_len);
160 off += build_udp_header(tbuf + off, payload_len);
162 if (off + payload_len > sizeof(tbuf))
163 error(1, 0, "payload length exceeds max");
165 memset(tbuf + off, DATA_CHAR, payload_len);
167 return off + payload_len;
170 static void do_bind(int fd)
172 struct sockaddr_ll laddr = {0};
174 laddr.sll_family = AF_PACKET;
175 laddr.sll_protocol = htons(ETH_P_IP);
176 laddr.sll_ifindex = if_nametoindex(cfg_ifname);
177 if (!laddr.sll_ifindex)
178 error(1, errno, "if_nametoindex");
180 if (bind(fd, (void *)&laddr, sizeof(laddr)))
181 error(1, errno, "bind");
/*
 * do_send() - transmit @len bytes of @buf on packet socket @fd.
 *
 * @buf is the frame built by build_packet(). The function exits via error()
 * on failure or a short write, and prints the byte count sent on success.
 *
 * Preparing the buffer: the virtio_net_hdr prefix is skipped (buf/len
 * adjusted). The guarding condition is not visible in this excerpt;
 * presumably it applies when PACKET_VNET_HDR is not enabled.
 *
 * Two transmit paths; the condition choosing between them is not visible
 * (presumably cfg_use_bind, cf. do_bind):
 *  - write() on an already-bound socket;
 *  - sendto() with a sockaddr_ll naming ETH_P_IP and the cfg_ifname index.
 *
 * NOTE(review): the `ret` declaration and the conditions before the two
 * final error() calls are not visible. ret and len are int but are printed
 * with %u.
 */
184 static void do_send(int fd, char *buf, int len)
189 buf += sizeof(struct virtio_net_hdr);
190 len -= sizeof(struct virtio_net_hdr);
198 ret = write(fd, buf, len);
200 struct sockaddr_ll laddr = {0};
202 laddr.sll_protocol = htons(ETH_P_IP);
203 laddr.sll_ifindex = if_nametoindex(cfg_ifname);
204 if (!laddr.sll_ifindex)
205 error(1, errno, "if_nametoindex");
207 ret = sendto(fd, buf, len, 0, (void *)&laddr, sizeof(laddr));
211 error(1, errno, "write");
213 error(1, 0, "write: %u %u", ret, len);
215 fprintf(stderr, "tx: %u\n", ret);
/*
 * do_tx() - open the transmit socket, send one test frame, and close it.
 *
 * Steps:
 *  1. Open a PF_PACKET socket: SOCK_DGRAM with -d, otherwise SOCK_RAW.
 *     Protocol 0 means the socket receives nothing (packet(7)); it is used
 *     only to transmit.
 *  2. Apply PACKET_QDISC_BYPASS (-q) and PACKET_VNET_HDR as configured.
 *  3. Build the frame and cap its length at cfg_truncate_len (-t).
 *  4. Hand it to do_send().
 *
 * NOTE(review): the following are not visible in this excerpt:
 *  - the declarations of fd/len/one;
 *  - the socket() error condition and any bind step;
 *  - the start of the `if` that line 235 continues (presumably testing
 *    cfg_use_vnet);
 *  - the close() call and the return statement (presumably len).
 */
218 static int do_tx(void)
223 fd = socket(PF_PACKET, cfg_use_dgram ? SOCK_DGRAM : SOCK_RAW, 0);
225 error(1, errno, "socket t");
230 if (cfg_use_qdisc_bypass &&
231 setsockopt(fd, SOL_PACKET, PACKET_QDISC_BYPASS, &one, sizeof(one)))
232 error(1, errno, "setsockopt qdisc bypass");
235 setsockopt(fd, SOL_PACKET, PACKET_VNET_HDR, &one, sizeof(one)))
236 error(1, errno, "setsockopt vnet");
238 len = build_packet(cfg_payload_len);
240 if (cfg_truncate_len < len)
241 len = cfg_truncate_len;
243 do_send(fd, tbuf, len);
246 error(1, errno, "close t");
/*
 * setup_rx() - create the UDP receiver used to check end-to-end delivery.
 *
 * Opens an AF_INET datagram socket bound to INADDR_ANY:cfg_port. It sets a
 * 100 ms receive timeout (SO_RCVTIMEO), so a recv() in do_rx() fails instead
 * of blocking if nothing arrives. Exits via error() on any failure.
 *
 * NOTE(review): the fd declaration, the socket() error condition and the
 * return statement are not visible in this excerpt; presumably the function
 * returns fd.
 */
251 static int setup_rx(void)
253 struct timeval tv = { .tv_usec = 100 * 1000 };
254 struct sockaddr_in raddr = {0};
257 fd = socket(PF_INET, SOCK_DGRAM, 0);
259 error(1, errno, "socket r");
261 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
262 error(1, errno, "setsockopt rcv timeout");
264 raddr.sin_family = AF_INET;
265 raddr.sin_port = htons(cfg_port);
266 raddr.sin_addr.s_addr = htonl(INADDR_ANY);
268 if (bind(fd, (void *)&raddr, sizeof(raddr)))
269 error(1, errno, "bind r");
/*
 * do_rx() - receive one message on @fd into rbuf and verify it.
 *
 * Checks that the received message is exactly @expected_len bytes and
 * byte-identical to @expected. Exits via error() otherwise; prints the
 * length received on success.
 *
 * NOTE(review): the `ret` declaration and the condition before the "recv"
 * error are not visible in this excerpt (presumably ret == -1, which also
 * covers the SO_RCVTIMEO timeout). ret and expected_len are int but are
 * printed with %u.
 */
274 static void do_rx(int fd, int expected_len, char *expected)
278 ret = recv(fd, rbuf, sizeof(rbuf), 0);
280 error(1, errno, "recv");
281 if (ret != expected_len)
282 error(1, 0, "recv: %u != %u", ret, expected_len);
284 if (memcmp(rbuf, expected, ret))
285 error(1, 0, "recv: data mismatch");
287 fprintf(stderr, "rx: %u\n", ret);
/*
 * setup_sniffer() - create a raw packet socket that observes the transmitted
 * frame.
 *
 * Sets a 100 ms receive timeout, then calls pair_udp_setfilter() (from
 * psock_lib.h; presumably it installs the BPF filter that run_test() relies
 * on). Exits via error() on failure.
 *
 * NOTE(review): with protocol 0, a packet socket receives nothing until it
 * is bound to a protocol (packet(7)). No bind is visible in this excerpt;
 * presumably one (e.g. do_bind) follows the filter setup. The fd
 * declaration, the socket() error condition and the return are also not
 * visible.
 */
290 static int setup_sniffer(void)
292 struct timeval tv = { .tv_usec = 100 * 1000 };
295 fd = socket(PF_PACKET, SOCK_RAW, 0);
297 error(1, errno, "socket p");
299 if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)))
300 error(1, errno, "setsockopt rcv timeout");
302 pair_udp_setfilter(fd);
/*
 * parse_opts() - fill the cfg_* globals from the command line, then reject
 * inconsistent option combinations.
 *
 * Options (getopt string "bcCdgl:qt:vV"):
 *   -c  checksum offload        -C  bad checksum offload (needs -c)
 *   -d  SOCK_DGRAM              -g  GSO (needs -c)
 *   -l  payload length          -q  PACKET_QDISC_BYPASS
 *   -t  truncate length         -v  vnet header (needed by -c)
 *   -V  VLAN (conflicts with -d)
 *   -b  presumably sets cfg_use_bind
 * The case labels are not visible in this excerpt. The mapping above follows
 * the getopt string order, the optarg users and the error messages below.
 *
 * NOTE(review): -l/-t are parsed with strtoul() into int without checking
 * endptr or range, so garbage or negative input is accepted silently.
 * -V combined with -c/-g is not rejected, even though build_vnet_header
 * computes its offsets with ETH_HLEN only.
 */
308 static void parse_opts(int argc, char **argv)
312 while ((c = getopt(argc, argv, "bcCdgl:qt:vV")) != -1) {
318 cfg_use_csum_off = true;
321 cfg_use_csum_off_bad = true;
324 cfg_use_dgram = true;
330 cfg_payload_len = strtoul(optarg, NULL, 0);
333 cfg_use_qdisc_bypass = true;
336 cfg_truncate_len = strtoul(optarg, NULL, 0);
345 error(1, 0, "%s: parse error", argv[0]);
349 if (cfg_use_vlan && cfg_use_dgram)
350 error(1, 0, "option vlan (-V) conflicts with dgram (-d)");
352 if (cfg_use_csum_off && !cfg_use_vnet)
353 error(1, 0, "option csum offload (-c) requires vnet (-v)");
355 if (cfg_use_csum_off_bad && !cfg_use_csum_off)
356 error(1, 0, "option csum bad (-C) requires csum offload (-c)");
358 if (cfg_use_gso && !cfg_use_csum_off)
359 error(1, 0, "option gso (-g) requires csum offload (-c)");
/*
 * run_test() - send one frame and verify it at two points.
 *
 * Checks performed:
 *  - Link layer (sniffer socket): when the payload length is DATA_LEN and
 *    no VLAN tag is used, the sniffed frame must equal what was built, minus
 *    the virtio_net_hdr prefix.
 *  - Application layer (UDP socket): the datagram must carry exactly the
 *    last cfg_payload_len bytes of the frame.
 *
 * NOTE(review): the setup_rx()/do_tx() calls that assign fdr and total_len,
 * and the close() calls preceding the two error() lines, are not visible in
 * this excerpt.
 */
362 static void run_test(void)
364 int fdr, fds, total_len;
367 fds = setup_sniffer();
371 /* BPF filter accepts only this length, vlan changes MAC */
372 if (cfg_payload_len == DATA_LEN && !cfg_use_vlan)
373 do_rx(fds, total_len - sizeof(struct virtio_net_hdr),
374 tbuf + sizeof(struct virtio_net_hdr));
376 do_rx(fdr, cfg_payload_len, tbuf + total_len - cfg_payload_len);
379 error(1, errno, "close s");
381 error(1, errno, "close r");
/*
 * main() - parse options, configure loopback, run the test, print "OK".
 *
 * Loopback configuration:
 *  - lo gets MTU 1500 and address 172.17.0.1/24 (the destination used by
 *    build_ipv4_header);
 *  - net.ipv4.conf.lo.accept_local is set to 1.
 * Because it changes lo and a sysctl, this is presumably meant to run inside
 * a dedicated network namespace — confirm.
 *
 * NOTE(review): system() returns the command's wait status. It does not set
 * errno when the command merely exits non-zero, so the errno passed to
 * error() here may be unrelated. The MTU is hard-coded rather than taken
 * from cfg_mtu. The test invocation (presumably run_test()) and the return
 * statement are not visible in this excerpt.
 */
384 int main(int argc, char **argv)
386 parse_opts(argc, argv);
388 if (system("ip link set dev lo mtu 1500"))
389 error(1, errno, "ip link set mtu");
390 if (system("ip addr add dev lo 172.17.0.1/24"))
391 error(1, errno, "ip addr add");
392 if (system("sysctl -w net.ipv4.conf.lo.accept_local=1"))
393 error(1, errno, "sysctl lo.accept_local");
397 fprintf(stderr, "OK\n\n");