DA: Skip initializing failed_bssids list when eapol failure case
[platform/upstream/connman.git] / vpn / plugins / libwireguard.c
1 // SPDX-License-Identifier: LGPL-2.1+
2 /*
3  * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  * Copyright (C) 2008-2012 Pablo Neira Ayuso <pablo@netfilter.org>.
5  */
6
7 #define _GNU_SOURCE
8
9 #include <errno.h>
10 #include <linux/genetlink.h>
11 #include <linux/if_link.h>
12 #include <linux/netlink.h>
13 #include <linux/rtnetlink.h>
14 #include <netinet/in.h>
15 #include <stdbool.h>
16 #include <stdio.h>
17 #include <stdlib.h>
18 #include <string.h>
19 #include <sys/socket.h>
20 #include <time.h>
21 #include <unistd.h>
22 #include <fcntl.h>
23 #include <assert.h>
24
25 #include <libmnl/libmnl.h>
26
27 #include "src/shared/mnlg.h"
28 #include "wireguard.h"
29
30 /* wireguard.h netlink uapi: */
31
32 #define WG_GENL_NAME "wireguard"
33 #define WG_GENL_VERSION 1
34
35 enum wg_cmd {
36         WG_CMD_GET_DEVICE,
37         WG_CMD_SET_DEVICE,
38         __WG_CMD_MAX
39 };
40
41 enum wgdevice_flag {
42         WGDEVICE_F_REPLACE_PEERS = 1U << 0
43 };
44 enum wgdevice_attribute {
45         WGDEVICE_A_UNSPEC,
46         WGDEVICE_A_IFINDEX,
47         WGDEVICE_A_IFNAME,
48         WGDEVICE_A_PRIVATE_KEY,
49         WGDEVICE_A_PUBLIC_KEY,
50         WGDEVICE_A_FLAGS,
51         WGDEVICE_A_LISTEN_PORT,
52         WGDEVICE_A_FWMARK,
53         WGDEVICE_A_PEERS,
54         __WGDEVICE_A_LAST
55 };
56
57 enum wgpeer_flag {
58         WGPEER_F_REMOVE_ME = 1U << 0,
59         WGPEER_F_REPLACE_ALLOWEDIPS = 1U << 1
60 };
61 enum wgpeer_attribute {
62         WGPEER_A_UNSPEC,
63         WGPEER_A_PUBLIC_KEY,
64         WGPEER_A_PRESHARED_KEY,
65         WGPEER_A_FLAGS,
66         WGPEER_A_ENDPOINT,
67         WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
68         WGPEER_A_LAST_HANDSHAKE_TIME,
69         WGPEER_A_RX_BYTES,
70         WGPEER_A_TX_BYTES,
71         WGPEER_A_ALLOWEDIPS,
72         WGPEER_A_PROTOCOL_VERSION,
73         __WGPEER_A_LAST
74 };
75
76 enum wgallowedip_attribute {
77         WGALLOWEDIP_A_UNSPEC,
78         WGALLOWEDIP_A_FAMILY,
79         WGALLOWEDIP_A_IPADDR,
80         WGALLOWEDIP_A_CIDR_MASK,
81         __WGALLOWEDIP_A_LAST
82 };
83
84 /* wireguard-specific parts: */
85
86 struct inflatable_buffer {
87         char *buffer;
88         char *next;
89         bool good;
90         size_t len;
91         size_t pos;
92 };
93
94 #define max(a, b) ((a) > (b) ? (a) : (b))
95
96 static int add_next_to_inflatable_buffer(struct inflatable_buffer *buffer)
97 {
98         size_t len, expand_to;
99         char *new_buffer;
100
101         if (!buffer->good || !buffer->next) {
102                 free(buffer->next);
103                 buffer->good = false;
104                 return 0;
105         }
106
107         len = strlen(buffer->next) + 1;
108
109         if (len == 1) {
110                 free(buffer->next);
111                 buffer->good = false;
112                 return 0;
113         }
114
115         if (buffer->len - buffer->pos <= len) {
116                 expand_to = max(buffer->len * 2, buffer->len + len + 1);
117                 new_buffer = realloc(buffer->buffer, expand_to);
118                 if (!new_buffer) {
119                         free(buffer->next);
120                         buffer->good = false;
121                         return -errno;
122                 }
123                 memset(&new_buffer[buffer->len], 0, expand_to - buffer->len);
124                 buffer->buffer = new_buffer;
125                 buffer->len = expand_to;
126         }
127         memcpy(&buffer->buffer[buffer->pos], buffer->next, len);
128         free(buffer->next);
129         buffer->good = false;
130         buffer->pos += len;
131         return 0;
132 }
133
134 static int parse_linkinfo(const struct nlattr *attr, void *data)
135 {
136         struct inflatable_buffer *buffer = data;
137
138         if (mnl_attr_get_type(attr) == IFLA_INFO_KIND && !strcmp(WG_GENL_NAME, mnl_attr_get_str(attr)))
139                 buffer->good = true;
140         return MNL_CB_OK;
141 }
142
143 static int parse_infomsg(const struct nlattr *attr, void *data)
144 {
145         struct inflatable_buffer *buffer = data;
146
147         if (mnl_attr_get_type(attr) == IFLA_LINKINFO)
148                 return mnl_attr_parse_nested(attr, parse_linkinfo, data);
149         else if (mnl_attr_get_type(attr) == IFLA_IFNAME)
150                 buffer->next = strdup(mnl_attr_get_str(attr));
151         return MNL_CB_OK;
152 }
153
154 static int read_devices_cb(const struct nlmsghdr *nlh, void *data)
155 {
156         struct inflatable_buffer *buffer = data;
157         int ret;
158
159         buffer->good = false;
160         buffer->next = NULL;
161         ret = mnl_attr_parse(nlh, sizeof(struct ifinfomsg), parse_infomsg, data);
162         if (ret != MNL_CB_OK)
163                 return ret;
164         ret = add_next_to_inflatable_buffer(buffer);
165         if (ret < 0)
166                 return ret;
167         if (nlh->nlmsg_type != NLMSG_DONE)
168                 return MNL_CB_OK + 1;
169         return MNL_CB_OK;
170 }
171
172 static int fetch_device_names(struct inflatable_buffer *buffer)
173 {
174         struct mnl_socket *nl = NULL;
175         char *rtnl_buffer = NULL;
176         size_t message_len;
177         unsigned int portid, seq;
178         ssize_t len;
179         int ret = 0;
180         struct nlmsghdr *nlh;
181         struct ifinfomsg *ifm;
182
183         ret = -ENOMEM;
184         rtnl_buffer = calloc(MNL_SOCKET_BUFFER_SIZE, 1);
185         if (!rtnl_buffer)
186                 goto cleanup;
187
188         nl = mnl_socket_open(NETLINK_ROUTE);
189         if (!nl) {
190                 ret = -errno;
191                 goto cleanup;
192         }
193
194         if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
195                 ret = -errno;
196                 goto cleanup;
197         }
198
199         seq = time(NULL);
200         portid = mnl_socket_get_portid(nl);
201         nlh = mnl_nlmsg_put_header(rtnl_buffer);
202         nlh->nlmsg_type = RTM_GETLINK;
203         nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP;
204         nlh->nlmsg_seq = seq;
205         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
206         ifm->ifi_family = AF_UNSPEC;
207         message_len = nlh->nlmsg_len;
208
209         if (mnl_socket_sendto(nl, rtnl_buffer, message_len) < 0) {
210                 ret = -errno;
211                 goto cleanup;
212         }
213
214 another:
215         if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, MNL_SOCKET_BUFFER_SIZE)) < 0) {
216                 ret = -errno;
217                 goto cleanup;
218         }
219         if ((len = mnl_cb_run(rtnl_buffer, len, seq, portid, read_devices_cb, buffer)) < 0) {
220                 /* Netlink returns NLM_F_DUMP_INTR if the set of all tunnels changed
221                  * during the dump. That's unfortunate, but is pretty common on busy
222                  * systems that are adding and removing tunnels all the time. Rather
223                  * than retrying, potentially indefinitely, we just work with the
224                  * partial results. */
225                 if (errno != EINTR) {
226                         ret = -errno;
227                         goto cleanup;
228                 }
229         }
230         if (len == MNL_CB_OK + 1)
231                 goto another;
232         ret = 0;
233
234 cleanup:
235         free(rtnl_buffer);
236         if (nl)
237                 mnl_socket_close(nl);
238         return ret;
239 }
240
241 static int add_del_iface(const char *ifname, bool add)
242 {
243         struct mnl_socket *nl = NULL;
244         char *rtnl_buffer;
245         ssize_t len;
246         int ret;
247         struct nlmsghdr *nlh;
248         struct ifinfomsg *ifm;
249         struct nlattr *nest;
250
251         rtnl_buffer = calloc(MNL_SOCKET_BUFFER_SIZE, 1);
252         if (!rtnl_buffer) {
253                 ret = -ENOMEM;
254                 goto cleanup;
255         }
256
257         nl = mnl_socket_open(NETLINK_ROUTE);
258         if (!nl) {
259                 ret = -errno;
260                 goto cleanup;
261         }
262
263         if (mnl_socket_bind(nl, 0, MNL_SOCKET_AUTOPID) < 0) {
264                 ret = -errno;
265                 goto cleanup;
266         }
267
268         nlh = mnl_nlmsg_put_header(rtnl_buffer);
269         nlh->nlmsg_type = add ? RTM_NEWLINK : RTM_DELLINK;
270         nlh->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK | (add ? NLM_F_CREATE | NLM_F_EXCL : 0);
271         nlh->nlmsg_seq = time(NULL);
272         ifm = mnl_nlmsg_put_extra_header(nlh, sizeof(*ifm));
273         ifm->ifi_family = AF_UNSPEC;
274         mnl_attr_put_strz(nlh, IFLA_IFNAME, ifname);
275         nest = mnl_attr_nest_start(nlh, IFLA_LINKINFO);
276         mnl_attr_put_strz(nlh, IFLA_INFO_KIND, WG_GENL_NAME);
277         mnl_attr_nest_end(nlh, nest);
278
279         if (mnl_socket_sendto(nl, rtnl_buffer, nlh->nlmsg_len) < 0) {
280                 ret = -errno;
281                 goto cleanup;
282         }
283         if ((len = mnl_socket_recvfrom(nl, rtnl_buffer, MNL_SOCKET_BUFFER_SIZE)) < 0) {
284                 ret = -errno;
285                 goto cleanup;
286         }
287         if (mnl_cb_run(rtnl_buffer, len, nlh->nlmsg_seq, mnl_socket_get_portid(nl), NULL, NULL) < 0) {
288                 ret = -errno;
289                 goto cleanup;
290         }
291         ret = 0;
292
293 cleanup:
294         free(rtnl_buffer);
295         if (nl)
296                 mnl_socket_close(nl);
297         return ret;
298 }
299
300 int wg_set_device(wg_device *dev)
301 {
302         int ret = 0;
303         wg_peer *peer = NULL;
304         wg_allowedip *allowedip = NULL;
305         struct nlattr *peers_nest, *peer_nest, *allowedips_nest, *allowedip_nest;
306         struct nlmsghdr *nlh;
307         struct mnlg_socket *nlg;
308
309         nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
310         if (!nlg)
311                 return -errno;
312
313 again:
314         nlh = mnlg_msg_prepare(nlg, WG_CMD_SET_DEVICE, NLM_F_REQUEST | NLM_F_ACK);
315         mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, dev->name);
316
317         if (!peer) {
318                 uint32_t flags = 0;
319
320                 if (dev->flags & WGDEVICE_HAS_PRIVATE_KEY)
321                         mnl_attr_put(nlh, WGDEVICE_A_PRIVATE_KEY, sizeof(dev->private_key), dev->private_key);
322                 if (dev->flags & WGDEVICE_HAS_LISTEN_PORT)
323                         mnl_attr_put_u16(nlh, WGDEVICE_A_LISTEN_PORT, dev->listen_port);
324                 if (dev->flags & WGDEVICE_HAS_FWMARK)
325                         mnl_attr_put_u32(nlh, WGDEVICE_A_FWMARK, dev->fwmark);
326                 if (dev->flags & WGDEVICE_REPLACE_PEERS)
327                         flags |= WGDEVICE_F_REPLACE_PEERS;
328                 if (flags)
329                         mnl_attr_put_u32(nlh, WGDEVICE_A_FLAGS, flags);
330         }
331         if (!dev->first_peer)
332                 goto send;
333         peers_nest = peer_nest = allowedips_nest = allowedip_nest = NULL;
334         peers_nest = mnl_attr_nest_start(nlh, WGDEVICE_A_PEERS);
335         for (peer = peer ? peer : dev->first_peer; peer; peer = peer->next_peer) {
336                 uint32_t flags = 0;
337
338                 peer_nest = mnl_attr_nest_start_check(nlh, MNL_SOCKET_BUFFER_SIZE, 0);
339                 if (!peer_nest)
340                         goto toobig_peers;
341                 if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_PUBLIC_KEY, sizeof(peer->public_key), peer->public_key))
342                         goto toobig_peers;
343                 if (peer->flags & WGPEER_REMOVE_ME)
344                         flags |= WGPEER_F_REMOVE_ME;
345                 if (!allowedip) {
346                         if (peer->flags & WGPEER_REPLACE_ALLOWEDIPS)
347                                 flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
348                         if (peer->flags & WGPEER_HAS_PRESHARED_KEY) {
349                                 if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_PRESHARED_KEY, sizeof(peer->preshared_key), peer->preshared_key))
350                                         goto toobig_peers;
351                         }
352                         if (peer->endpoint.addr.sa_family == AF_INET) {
353                                 if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr4), &peer->endpoint.addr4))
354                                         goto toobig_peers;
355                         } else if (peer->endpoint.addr.sa_family == AF_INET6) {
356                                 if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_ENDPOINT, sizeof(peer->endpoint.addr6), &peer->endpoint.addr6))
357                                         goto toobig_peers;
358                         }
359                         if (peer->flags & WGPEER_HAS_PERSISTENT_KEEPALIVE_INTERVAL) {
360                                 if (!mnl_attr_put_u16_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval))
361                                         goto toobig_peers;
362                         }
363                 }
364                 if (flags) {
365                         if (!mnl_attr_put_u32_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_FLAGS, flags))
366                                 goto toobig_peers;
367                 }
368                 if (peer->first_allowedip) {
369                         if (!allowedip)
370                                 allowedip = peer->first_allowedip;
371                         allowedips_nest = mnl_attr_nest_start_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGPEER_A_ALLOWEDIPS);
372                         if (!allowedips_nest)
373                                 goto toobig_allowedips;
374                         for (; allowedip; allowedip = allowedip->next_allowedip) {
375                                 allowedip_nest = mnl_attr_nest_start_check(nlh, MNL_SOCKET_BUFFER_SIZE, 0);
376                                 if (!allowedip_nest)
377                                         goto toobig_allowedips;
378                                 if (!mnl_attr_put_u16_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_FAMILY, allowedip->family))
379                                         goto toobig_allowedips;
380                                 if (allowedip->family == AF_INET) {
381                                         if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip4), &allowedip->ip4))
382                                                 goto toobig_allowedips;
383                                 } else if (allowedip->family == AF_INET6) {
384                                         if (!mnl_attr_put_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_IPADDR, sizeof(allowedip->ip6), &allowedip->ip6))
385                                                 goto toobig_allowedips;
386                                 }
387                                 if (!mnl_attr_put_u8_check(nlh, MNL_SOCKET_BUFFER_SIZE, WGALLOWEDIP_A_CIDR_MASK, allowedip->cidr))
388                                         goto toobig_allowedips;
389                                 mnl_attr_nest_end(nlh, allowedip_nest);
390                                 allowedip_nest = NULL;
391                         }
392                         mnl_attr_nest_end(nlh, allowedips_nest);
393                         allowedips_nest = NULL;
394                 }
395
396                 mnl_attr_nest_end(nlh, peer_nest);
397                 peer_nest = NULL;
398         }
399         mnl_attr_nest_end(nlh, peers_nest);
400         peers_nest = NULL;
401         goto send;
402 toobig_allowedips:
403         if (allowedip_nest)
404                 mnl_attr_nest_cancel(nlh, allowedip_nest);
405         if (allowedips_nest)
406                 mnl_attr_nest_end(nlh, allowedips_nest);
407         mnl_attr_nest_end(nlh, peer_nest);
408         mnl_attr_nest_end(nlh, peers_nest);
409         goto send;
410 toobig_peers:
411         if (peer_nest)
412                 mnl_attr_nest_cancel(nlh, peer_nest);
413         mnl_attr_nest_end(nlh, peers_nest);
414         goto send;
415 send:
416         if (mnlg_socket_send(nlg, nlh) < 0) {
417                 ret = -errno;
418                 goto out;
419         }
420         errno = 0;
421         if (mnlg_socket_recv_run(nlg, NULL, NULL) < 0) {
422                 ret = errno ? -errno : -EINVAL;
423                 goto out;
424         }
425         if (peer)
426                 goto again;
427
428 out:
429         mnlg_socket_close(nlg);
430         errno = -ret;
431         return ret;
432 }
433
434 static int parse_allowedip(const struct nlattr *attr, void *data)
435 {
436         wg_allowedip *allowedip = data;
437
438         switch (mnl_attr_get_type(attr)) {
439         case WGALLOWEDIP_A_UNSPEC:
440                 break;
441         case WGALLOWEDIP_A_FAMILY:
442                 if (!mnl_attr_validate(attr, MNL_TYPE_U16))
443                         allowedip->family = mnl_attr_get_u16(attr);
444                 break;
445         case WGALLOWEDIP_A_IPADDR:
446                 if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip4))
447                         memcpy(&allowedip->ip4, mnl_attr_get_payload(attr), sizeof(allowedip->ip4));
448                 else if (mnl_attr_get_payload_len(attr) == sizeof(allowedip->ip6))
449                         memcpy(&allowedip->ip6, mnl_attr_get_payload(attr), sizeof(allowedip->ip6));
450                 break;
451         case WGALLOWEDIP_A_CIDR_MASK:
452                 if (!mnl_attr_validate(attr, MNL_TYPE_U8))
453                         allowedip->cidr = mnl_attr_get_u8(attr);
454                 break;
455         }
456
457         return MNL_CB_OK;
458 }
459
460 static int parse_allowedips(const struct nlattr *attr, void *data)
461 {
462         wg_peer *peer = data;
463         wg_allowedip *new_allowedip = calloc(1, sizeof(wg_allowedip));
464         int ret;
465
466         if (!new_allowedip)
467                 return MNL_CB_ERROR;
468         if (!peer->first_allowedip)
469                 peer->first_allowedip = peer->last_allowedip = new_allowedip;
470         else {
471                 peer->last_allowedip->next_allowedip = new_allowedip;
472                 peer->last_allowedip = new_allowedip;
473         }
474         ret = mnl_attr_parse_nested(attr, parse_allowedip, new_allowedip);
475         if (!ret)
476                 return ret;
477         if (!((new_allowedip->family == AF_INET && new_allowedip->cidr <= 32) || (new_allowedip->family == AF_INET6 && new_allowedip->cidr <= 128))) {
478                 errno = EAFNOSUPPORT;
479                 return MNL_CB_ERROR;
480         }
481         return MNL_CB_OK;
482 }
483
484 bool wg_key_is_zero(const wg_key key)
485 {
486         volatile uint8_t acc = 0;
487         unsigned int i;
488
489         for (i = 0; i < sizeof(wg_key); ++i) {
490                 acc |= key[i];
491                 __asm__ ("" : "=r" (acc) : "0" (acc));
492         }
493         return 1 & ((acc - 1) >> 8);
494 }
495
496 static int parse_peer(const struct nlattr *attr, void *data)
497 {
498         wg_peer *peer = data;
499
500         switch (mnl_attr_get_type(attr)) {
501         case WGPEER_A_UNSPEC:
502                 break;
503         case WGPEER_A_PUBLIC_KEY:
504                 if (mnl_attr_get_payload_len(attr) == sizeof(peer->public_key)) {
505                         memcpy(peer->public_key, mnl_attr_get_payload(attr), sizeof(peer->public_key));
506                         peer->flags |= WGPEER_HAS_PUBLIC_KEY;
507                 }
508                 break;
509         case WGPEER_A_PRESHARED_KEY:
510                 if (mnl_attr_get_payload_len(attr) == sizeof(peer->preshared_key)) {
511                         memcpy(peer->preshared_key, mnl_attr_get_payload(attr), sizeof(peer->preshared_key));
512                         if (!wg_key_is_zero(peer->preshared_key))
513                                 peer->flags |= WGPEER_HAS_PRESHARED_KEY;
514                 }
515                 break;
516         case WGPEER_A_ENDPOINT: {
517                 struct sockaddr *addr;
518
519                 if (mnl_attr_get_payload_len(attr) < sizeof(*addr))
520                         break;
521                 addr = mnl_attr_get_payload(attr);
522                 if (addr->sa_family == AF_INET && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr4))
523                         memcpy(&peer->endpoint.addr4, addr, sizeof(peer->endpoint.addr4));
524                 else if (addr->sa_family == AF_INET6 && mnl_attr_get_payload_len(attr) == sizeof(peer->endpoint.addr6))
525                         memcpy(&peer->endpoint.addr6, addr, sizeof(peer->endpoint.addr6));
526                 break;
527         }
528         case WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL:
529                 if (!mnl_attr_validate(attr, MNL_TYPE_U16))
530                         peer->persistent_keepalive_interval = mnl_attr_get_u16(attr);
531                 break;
532         case WGPEER_A_LAST_HANDSHAKE_TIME:
533                 if (mnl_attr_get_payload_len(attr) == sizeof(peer->last_handshake_time))
534                         memcpy(&peer->last_handshake_time, mnl_attr_get_payload(attr), sizeof(peer->last_handshake_time));
535                 break;
536         case WGPEER_A_RX_BYTES:
537                 if (!mnl_attr_validate(attr, MNL_TYPE_U64))
538                         peer->rx_bytes = mnl_attr_get_u64(attr);
539                 break;
540         case WGPEER_A_TX_BYTES:
541                 if (!mnl_attr_validate(attr, MNL_TYPE_U64))
542                         peer->tx_bytes = mnl_attr_get_u64(attr);
543                 break;
544         case WGPEER_A_ALLOWEDIPS:
545                 return mnl_attr_parse_nested(attr, parse_allowedips, peer);
546         }
547
548         return MNL_CB_OK;
549 }
550
551 static int parse_peers(const struct nlattr *attr, void *data)
552 {
553         wg_device *device = data;
554         wg_peer *new_peer = calloc(1, sizeof(wg_peer));
555         int ret;
556
557         if (!new_peer)
558                 return MNL_CB_ERROR;
559         if (!device->first_peer)
560                 device->first_peer = device->last_peer = new_peer;
561         else {
562                 device->last_peer->next_peer = new_peer;
563                 device->last_peer = new_peer;
564         }
565         ret = mnl_attr_parse_nested(attr, parse_peer, new_peer);
566         if (!ret)
567                 return ret;
568         if (!(new_peer->flags & WGPEER_HAS_PUBLIC_KEY)) {
569                 errno = ENXIO;
570                 return MNL_CB_ERROR;
571         }
572         return MNL_CB_OK;
573 }
574
575 static int parse_device(const struct nlattr *attr, void *data)
576 {
577         wg_device *device = data;
578
579         switch (mnl_attr_get_type(attr)) {
580         case WGDEVICE_A_UNSPEC:
581                 break;
582         case WGDEVICE_A_IFINDEX:
583                 if (!mnl_attr_validate(attr, MNL_TYPE_U32))
584                         device->ifindex = mnl_attr_get_u32(attr);
585                 break;
586         case WGDEVICE_A_IFNAME:
587                 if (!mnl_attr_validate(attr, MNL_TYPE_STRING)) {
588                         strncpy(device->name, mnl_attr_get_str(attr), sizeof(device->name) - 1);
589                         device->name[sizeof(device->name) - 1] = '\0';
590                 }
591                 break;
592         case WGDEVICE_A_PRIVATE_KEY:
593                 if (mnl_attr_get_payload_len(attr) == sizeof(device->private_key)) {
594                         memcpy(device->private_key, mnl_attr_get_payload(attr), sizeof(device->private_key));
595                         device->flags |= WGDEVICE_HAS_PRIVATE_KEY;
596                 }
597                 break;
598         case WGDEVICE_A_PUBLIC_KEY:
599                 if (mnl_attr_get_payload_len(attr) == sizeof(device->public_key)) {
600                         memcpy(device->public_key, mnl_attr_get_payload(attr), sizeof(device->public_key));
601                         device->flags |= WGDEVICE_HAS_PUBLIC_KEY;
602                 }
603                 break;
604         case WGDEVICE_A_LISTEN_PORT:
605                 if (!mnl_attr_validate(attr, MNL_TYPE_U16))
606                         device->listen_port = mnl_attr_get_u16(attr);
607                 break;
608         case WGDEVICE_A_FWMARK:
609                 if (!mnl_attr_validate(attr, MNL_TYPE_U32))
610                         device->fwmark = mnl_attr_get_u32(attr);
611                 break;
612         case WGDEVICE_A_PEERS:
613                 return mnl_attr_parse_nested(attr, parse_peers, device);
614         }
615
616         return MNL_CB_OK;
617 }
618
619 static int read_device_cb(const struct nlmsghdr *nlh, void *data)
620 {
621         return mnl_attr_parse(nlh, sizeof(struct genlmsghdr), parse_device, data);
622 }
623
624 static void coalesce_peers(wg_device *device)
625 {
626         wg_peer *old_next_peer, *peer = device->first_peer;
627
628         while (peer && peer->next_peer) {
629                 if (memcmp(peer->public_key, peer->next_peer->public_key, sizeof(wg_key))) {
630                         peer = peer->next_peer;
631                         continue;
632                 }
633                 if (!peer->first_allowedip) {
634                         peer->first_allowedip = peer->next_peer->first_allowedip;
635                         peer->last_allowedip = peer->next_peer->last_allowedip;
636                 } else {
637                         peer->last_allowedip->next_allowedip = peer->next_peer->first_allowedip;
638                         peer->last_allowedip = peer->next_peer->last_allowedip;
639                 }
640                 old_next_peer = peer->next_peer;
641                 peer->next_peer = old_next_peer->next_peer;
642                 free(old_next_peer);
643         }
644 }
645
646 int wg_get_device(wg_device **device, const char *device_name)
647 {
648         int ret = 0;
649         struct nlmsghdr *nlh;
650         struct mnlg_socket *nlg;
651
652 try_again:
653         *device = calloc(1, sizeof(wg_device));
654         if (!*device)
655                 return -errno;
656
657         nlg = mnlg_socket_open(WG_GENL_NAME, WG_GENL_VERSION);
658         if (!nlg) {
659                 wg_free_device(*device);
660                 *device = NULL;
661                 return -errno;
662         }
663
664         nlh = mnlg_msg_prepare(nlg, WG_CMD_GET_DEVICE, NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP);
665         mnl_attr_put_strz(nlh, WGDEVICE_A_IFNAME, device_name);
666         if (mnlg_socket_send(nlg, nlh) < 0) {
667                 ret = -errno;
668                 goto out;
669         }
670         errno = 0;
671         if (mnlg_socket_recv_run(nlg, read_device_cb, *device) < 0) {
672                 ret = errno ? -errno : -EINVAL;
673                 goto out;
674         }
675         coalesce_peers(*device);
676
677 out:
678         if (nlg)
679                 mnlg_socket_close(nlg);
680         if (ret) {
681                 wg_free_device(*device);
682                 if (ret == -EINTR)
683                         goto try_again;
684                 *device = NULL;
685         }
686         errno = -ret;
687         return ret;
688 }
689
690 /* first\0second\0third\0forth\0last\0\0 */
691 char *wg_list_device_names(void)
692 {
693         struct inflatable_buffer buffer = { .len = MNL_SOCKET_BUFFER_SIZE };
694         int ret;
695
696         ret = -ENOMEM;
697         buffer.buffer = calloc(1, buffer.len);
698         if (!buffer.buffer)
699                 goto err;
700
701         ret = fetch_device_names(&buffer);
702 err:
703         errno = -ret;
704         if (errno) {
705                 free(buffer.buffer);
706                 return NULL;
707         }
708         return buffer.buffer;
709 }
710
711 int wg_add_device(const char *device_name)
712 {
713         return add_del_iface(device_name, true);
714 }
715
716 int wg_del_device(const char *device_name)
717 {
718         return add_del_iface(device_name, false);
719 }
720
721 void wg_free_device(wg_device *dev)
722 {
723         wg_peer *peer, *np;
724         wg_allowedip *allowedip, *na;
725
726         if (!dev)
727                 return;
728         for (peer = dev->first_peer, np = peer ? peer->next_peer : NULL; peer; peer = np, np = peer ? peer->next_peer : NULL) {
729                 for (allowedip = peer->first_allowedip, na = allowedip ? allowedip->next_allowedip : NULL; allowedip; allowedip = na, na = allowedip ? allowedip->next_allowedip : NULL)
730                         free(allowedip);
731                 free(peer);
732         }
733         free(dev);
734 }
735
736 static void encode_base64(char dest[static 4], const uint8_t src[static 3])
737 {
738         const uint8_t input[] = { (src[0] >> 2) & 63, ((src[0] << 4) | (src[1] >> 4)) & 63, ((src[1] << 2) | (src[2] >> 6)) & 63, src[2] & 63 };
739         unsigned int i;
740
741         for (i = 0; i < 4; ++i)
742                 dest[i] = input[i] + 'A'
743                           + (((25 - input[i]) >> 8) & 6)
744                           - (((51 - input[i]) >> 8) & 75)
745                           - (((61 - input[i]) >> 8) & 15)
746                           + (((62 - input[i]) >> 8) & 3);
747
748 }
749
750 void wg_key_to_base64(wg_key_b64_string base64, const wg_key key)
751 {
752         unsigned int i;
753
754         for (i = 0; i < 32 / 3; ++i)
755                 encode_base64(&base64[i * 4], &key[i * 3]);
756         encode_base64(&base64[i * 4], (const uint8_t[]){ key[i * 3 + 0], key[i * 3 + 1], 0 });
757         base64[sizeof(wg_key_b64_string) - 2] = '=';
758         base64[sizeof(wg_key_b64_string) - 1] = '\0';
759 }
760
761 static int decode_base64(const char src[static 4])
762 {
763         int val = 0;
764         unsigned int i;
765
766         for (i = 0; i < 4; ++i)
767                 val |= (-1
768                             + ((((('A' - 1) - src[i]) & (src[i] - ('Z' + 1))) >> 8) & (src[i] - 64))
769                             + ((((('a' - 1) - src[i]) & (src[i] - ('z' + 1))) >> 8) & (src[i] - 70))
770                             + ((((('0' - 1) - src[i]) & (src[i] - ('9' + 1))) >> 8) & (src[i] + 5))
771                             + ((((('+' - 1) - src[i]) & (src[i] - ('+' + 1))) >> 8) & 63)
772                             + ((((('/' - 1) - src[i]) & (src[i] - ('/' + 1))) >> 8) & 64)
773                         ) << (18 - 6 * i);
774         return val;
775 }
776
777 int wg_key_from_base64(wg_key key, const wg_key_b64_string base64)
778 {
779         unsigned int i;
780         int val;
781         volatile uint8_t ret = 0;
782
783         if (strlen(base64) != sizeof(wg_key_b64_string) - 1 || base64[sizeof(wg_key_b64_string) - 2] != '=') {
784                 errno = EINVAL;
785                 goto out;
786         }
787
788         for (i = 0; i < 32 / 3; ++i) {
789                 val = decode_base64(&base64[i * 4]);
790                 ret |= (uint32_t)val >> 31;
791                 key[i * 3 + 0] = (val >> 16) & 0xff;
792                 key[i * 3 + 1] = (val >> 8) & 0xff;
793                 key[i * 3 + 2] = val & 0xff;
794         }
795         val = decode_base64((const char[]){ base64[i * 4 + 0], base64[i * 4 + 1], base64[i * 4 + 2], 'A' });
796         ret |= ((uint32_t)val >> 31) | (val & 0xff);
797         key[i * 3 + 0] = (val >> 16) & 0xff;
798         key[i * 3 + 1] = (val >> 8) & 0xff;
799         errno = EINVAL & ~((ret - 1) >> 8);
800 out:
801         return -errno;
802 }
803
804 typedef int64_t fe[16];
805
806 static __attribute__((noinline)) void memzero_explicit(void *s, size_t count)
807 {
808         memset(s, 0, count);
809         __asm__ __volatile__("": :"r"(s) :"memory");
810 }
811
812 static void carry(fe o)
813 {
814         int i;
815
816         for (i = 0; i < 16; ++i) {
817                 o[(i + 1) % 16] += (i == 15 ? 38 : 1) * (o[i] >> 16);
818                 o[i] &= 0xffff;
819         }
820 }
821
822 static void cswap(fe p, fe q, int b)
823 {
824         int i;
825         int64_t t, c = ~(b - 1);
826
827         for (i = 0; i < 16; ++i) {
828                 t = c & (p[i] ^ q[i]);
829                 p[i] ^= t;
830                 q[i] ^= t;
831         }
832
833         memzero_explicit(&t, sizeof(t));
834         memzero_explicit(&c, sizeof(c));
835         memzero_explicit(&b, sizeof(b));
836 }
837
838 static void pack(uint8_t *o, const fe n)
839 {
840         int i, j, b;
841         fe m, t;
842
843         memcpy(t, n, sizeof(t));
844         carry(t);
845         carry(t);
846         carry(t);
847         for (j = 0; j < 2; ++j) {
848                 m[0] = t[0] - 0xffed;
849                 for (i = 1; i < 15; ++i) {
850                         m[i] = t[i] - 0xffff - ((m[i - 1] >> 16) & 1);
851                         m[i - 1] &= 0xffff;
852                 }
853                 m[15] = t[15] - 0x7fff - ((m[14] >> 16) & 1);
854                 b = (m[15] >> 16) & 1;
855                 m[14] &= 0xffff;
856                 cswap(t, m, 1 - b);
857         }
858         for (i = 0; i < 16; ++i) {
859                 o[2 * i] = t[i] & 0xff;
860                 o[2 * i + 1] = t[i] >> 8;
861         }
862
863         memzero_explicit(m, sizeof(m));
864         memzero_explicit(t, sizeof(t));
865         memzero_explicit(&b, sizeof(b));
866 }
867
868 static void add(fe o, const fe a, const fe b)
869 {
870         int i;
871
872         for (i = 0; i < 16; ++i)
873                 o[i] = a[i] + b[i];
874 }
875
876 static void subtract(fe o, const fe a, const fe b)
877 {
878         int i;
879
880         for (i = 0; i < 16; ++i)
881                 o[i] = a[i] - b[i];
882 }
883
884 static void multmod(fe o, const fe a, const fe b)
885 {
886         int i, j;
887         int64_t t[31] = { 0 };
888
889         for (i = 0; i < 16; ++i) {
890                 for (j = 0; j < 16; ++j)
891                         t[i + j] += a[i] * b[j];
892         }
893         for (i = 0; i < 15; ++i)
894                 t[i] += 38 * t[i + 16];
895         memcpy(o, t, sizeof(fe));
896         carry(o);
897         carry(o);
898
899         memzero_explicit(t, sizeof(t));
900 }
901
902 static void invert(fe o, const fe i)
903 {
904         fe c;
905         int a;
906
907         memcpy(c, i, sizeof(c));
908         for (a = 253; a >= 0; --a) {
909                 multmod(c, c, c);
910                 if (a != 2 && a != 4)
911                         multmod(c, c, i);
912         }
913         memcpy(o, c, sizeof(fe));
914
915         memzero_explicit(c, sizeof(c));
916 }
917
918 static void clamp_key(uint8_t *z)
919 {
920         z[31] = (z[31] & 127) | 64;
921         z[0] &= 248;
922 }
923
924 void wg_generate_public_key(wg_key public_key, const wg_key private_key)
925 {
926         int i, r;
927         uint8_t z[32];
928         fe a = { 1 }, b = { 9 }, c = { 0 }, d = { 1 }, e, f;
929
930         memcpy(z, private_key, sizeof(z));
931         clamp_key(z);
932
933         for (i = 254; i >= 0; --i) {
934                 r = (z[i >> 3] >> (i & 7)) & 1;
935                 cswap(a, b, r);
936                 cswap(c, d, r);
937                 add(e, a, c);
938                 subtract(a, a, c);
939                 add(c, b, d);
940                 subtract(b, b, d);
941                 multmod(d, e, e);
942                 multmod(f, a, a);
943                 multmod(a, c, a);
944                 multmod(c, b, e);
945                 add(e, a, c);
946                 subtract(a, a, c);
947                 multmod(b, a, a);
948                 subtract(c, d, f);
949                 multmod(a, c, (const fe){ 0xdb41, 1 });
950                 add(a, a, d);
951                 multmod(c, c, a);
952                 multmod(a, d, f);
953                 multmod(d, b, (const fe){ 9 });
954                 multmod(b, e, e);
955                 cswap(a, b, r);
956                 cswap(c, d, r);
957         }
958         invert(c, c);
959         multmod(a, a, c);
960         pack(public_key, a);
961
962         memzero_explicit(&r, sizeof(r));
963         memzero_explicit(z, sizeof(z));
964         memzero_explicit(a, sizeof(a));
965         memzero_explicit(b, sizeof(b));
966         memzero_explicit(c, sizeof(c));
967         memzero_explicit(d, sizeof(d));
968         memzero_explicit(e, sizeof(e));
969         memzero_explicit(f, sizeof(f));
970 }
971
972 void wg_generate_private_key(wg_key private_key)
973 {
974         wg_generate_preshared_key(private_key);
975         clamp_key(private_key);
976 }
977
978 void wg_generate_preshared_key(wg_key preshared_key)
979 {
980         ssize_t ret;
981         size_t i;
982         int fd;
983 #if defined(__OpenBSD__) || (defined(__APPLE__) && MAC_OS_X_VERSION_MIN_REQUIRED >= MAC_OS_X_VERSION_10_12) || (defined(__GLIBC__) && (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25)))
984         if (!getentropy(preshared_key, sizeof(wg_key)))
985                 return;
986 #endif
987 #if defined(__NR_getrandom) && defined(__linux__)
988         if (syscall(__NR_getrandom, preshared_key, sizeof(wg_key), 0) == sizeof(wg_key))
989                 return;
990 #endif
991         fd = open("/dev/urandom", O_RDONLY);
992         assert(fd >= 0);
993         for (i = 0; i < sizeof(wg_key); i += ret) {
994                 ret = read(fd, preshared_key + i, sizeof(wg_key) - i);
995                 assert(ret > 0);
996         }
997         close(fd);
998 }