Use const and add is_valid_regdom() check
[platform/upstream/crda.git] / crda.c
1 /*
2  * Central Regulatory Domain Agent for Linux
3  *
4  * Userspace helper which sends regulatory domains to Linux via nl80211
5  */
6
7 #include <errno.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <sys/mman.h>
11 #include <sys/stat.h>
12 #include <fcntl.h>
13 #include <arpa/inet.h>
14
15 #include <netlink/genl/genl.h>
16 #include <netlink/genl/family.h>
17 #include <netlink/genl/ctrl.h>
18 #include <netlink/msg.h>
19 #include <netlink/attr.h>
20 #include <linux/nl80211.h>
21
22 #include "regdb.h"
23
24 #ifdef USE_OPENSSL
25 #include <openssl/objects.h>
26 #include <openssl/bn.h>
27 #include <openssl/rsa.h>
28 #include <openssl/sha.h>
29
30 #include "keys-ssl.c"
31 #endif
32
33 #ifdef USE_GCRYPT
34 #include <gcrypt.h>
35
36 #include "keys-gcrypt.c"
37 #endif
38
39 struct nl80211_state {
40         struct nl_handle *nl_handle;
41         struct nl_cache *nl_cache;
42         struct genl_family *nl80211;
43 };
44
45 static int nl80211_init(struct nl80211_state *state)
46 {
47         int err;
48
49         state->nl_handle = nl_handle_alloc();
50         if (!state->nl_handle) {
51                 fprintf(stderr, "Failed to allocate netlink handle.\n");
52                 return -ENOMEM;
53         }
54
55         if (genl_connect(state->nl_handle)) {
56                 fprintf(stderr, "Failed to connect to generic netlink.\n");
57                 err = -ENOLINK;
58                 goto out_handle_destroy;
59         }
60
61         state->nl_cache = genl_ctrl_alloc_cache(state->nl_handle);
62         if (!state->nl_cache) {
63                 fprintf(stderr, "Failed to allocate generic netlink cache.\n");
64                 err = -ENOMEM;
65                 goto out_handle_destroy;
66         }
67
68         state->nl80211 = genl_ctrl_search_by_name(state->nl_cache, "nl80211");
69         if (!state->nl80211) {
70                 fprintf(stderr, "nl80211 not found.\n");
71                 err = -ENOENT;
72                 goto out_cache_free;
73         }
74
75         return 0;
76
77  out_cache_free:
78         nl_cache_free(state->nl_cache);
79  out_handle_destroy:
80         nl_handle_destroy(state->nl_handle);
81         return err;
82 }
83
84 static void nl80211_cleanup(struct nl80211_state *state)
85 {
86         genl_family_put(state->nl80211);
87         nl_cache_free(state->nl_cache);
88         nl_handle_destroy(state->nl_handle);
89 }
90
91 static int reg_handler(struct nl_msg *msg, void *arg)
92 {
93         return NL_SKIP;
94 }
95
96 static int wait_handler(struct nl_msg *msg, void *arg)
97 {
98         int *finished = arg;
99         *finished = 1;
100         return NL_STOP;
101 }
102
103
104 static int error_handler(struct sockaddr_nl *nla, struct nlmsgerr *err, void *arg)
105 {
106         fprintf(stderr, "nl80211 error %d\n", err->error);
107         exit(err->error);
108 }
109
110 int isalpha_upper(char letter)
111 {
112         if (letter >= 'A' && letter <= 'Z')
113                 return 1;
114         return 0;
115 }
116
117 static int is_alpha2(const char *alpha2)
118 {
119         if (isalpha_upper(alpha2[0]) && isalpha_upper(alpha2[1]))
120                 return 1;
121         return 0;
122 }
123
124 static int is_world_regdom(const char *alpha2)
125 {
126         if (alpha2[0] == '0' && alpha2[1] == '0')
127                 return 1;
128         return 0;
129 }
130
131 static int is_valid_regdom(const char * alpha2)
132 {
133         if (strlen(alpha2) != 2)
134                 return 0;
135
136         if (!is_alpha2(alpha2) && !is_world_regdom(alpha2)) {
137                 return 0;
138         }
139
140         return 1;
141 }
142
143 /* ptr is 32 big endian. You don't need to convert it before passing to this
144  * function */
145
146 static void *get_file_ptr(__u8 *db, int dblen, int structlen, __be32 ptr)
147 {
148         __u32 p = ntohl(ptr);
149
150         if (p > dblen - structlen) {
151                 fprintf(stderr, "Invalid database file, bad pointer!\n");
152                 exit(3);
153         }
154
155         return (void *)(db + p);
156 }
157
158 static int put_reg_rule(__u8 *db, int dblen, __be32 ruleptr, struct nl_msg *msg)
159 {
160         struct regdb_file_reg_rule *rule;
161         struct regdb_file_freq_range *freq;
162         struct regdb_file_power_rule *power;
163
164         rule    = get_file_ptr(db, dblen, sizeof(*rule), ruleptr);
165         freq    = get_file_ptr(db, dblen, sizeof(*freq), rule->freq_range_ptr);
166         power   = get_file_ptr(db, dblen, sizeof(*power), rule->power_rule_ptr);
167
168         NLA_PUT_U32(msg, NL80211_ATTR_REG_RULE_FLAGS,           ntohl(rule->flags));
169         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_START,         ntohl(freq->start_freq));
170         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_END,           ntohl(freq->end_freq));
171         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_MAX_BW,        ntohl(freq->max_bandwidth));
172         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN,  ntohl(power->max_antenna_gain));
173         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_EIRP,      ntohl(power->max_eirp));
174
175         return 0;
176
177 nla_put_failure:
178         return -1;
179 }
180
181 int main(int argc, char **argv)
182 {
183         int fd;
184         struct stat stat;
185         __u8 *db;
186         struct regdb_file_header *header;
187         struct regdb_file_reg_country *countries;
188         int dblen, siglen, num_countries, i, j, r;
189         char alpha2[2];
190         char *env_country;
191         struct nl80211_state nlstate;
192         struct nl_cb *cb = NULL;
193         struct nl_msg *msg;
194         int found_country = 0;
195         int finished = 0;
196
197         struct regdb_file_reg_rules_collection *rcoll;
198         struct regdb_file_reg_country *country;
199         struct nlattr *nl_reg_rules;
200         int num_rules;
201
202 #ifdef USE_OPENSSL
203         RSA *rsa;
204         __u8 hash[SHA_DIGEST_LENGTH];
205         int ok = 0;
206 #endif
207 #ifdef USE_GCRYPT
208         gcry_mpi_t mpi_e, mpi_n;
209         gcry_sexp_t rsa, signature, data;
210         __u8 hash[20];
211         int ok = 0;
212 #endif
213
214         const char regdb[] = "/usr/lib/crda/regulatory.bin";
215
216         if (argc != 1) {
217                 fprintf(stderr, "Usage: %s\n", argv[0]);
218                 return -EINVAL;
219         }
220
221         env_country = getenv("COUNTRY");
222         if (!env_country) {
223                 fprintf(stderr, "COUNTRY environment variable not set.\n");
224                 return -EINVAL;
225         }
226
227         if (!is_valid_regdom(env_country)) {
228                 fprintf(stderr, "COUNTRY environment variable must be an "
229                         "ISO ISO 3166-1-alpha-2 (uppercase) or 00\n");
230                 return -EINVAL;
231         }
232
233         memcpy(alpha2, env_country, 2);
234
235         fd = open(regdb, O_RDONLY);
236         if (fd < 0) {
237                 perror("failed to open db file");
238                 return -ENOENT;
239         }
240
241         if (fstat(fd, &stat)) {
242                 perror("failed to fstat db file");
243                 return -EIO;
244         }
245
246         dblen = stat.st_size;
247
248         db = mmap(NULL, dblen, PROT_READ, MAP_PRIVATE, fd, 0);
249         if (db == MAP_FAILED) {
250                 perror("failed to mmap db file");
251                 return -EIO;
252         }
253
254         /* db file starts with a struct regdb_file_header */
255         header = get_file_ptr(db, dblen, sizeof(*header), 0);
256
257         if (ntohl(header->magic) != REGDB_MAGIC) {
258                 fprintf(stderr, "Invalid database magic\n");
259                 return -EINVAL;
260         }
261
262         if (ntohl(header->version) != REGDB_VERSION) {
263                 fprintf(stderr, "Invalid database version\n");
264                 return -EINVAL;
265         }
266
267         siglen = ntohl(header->signature_length);
268         /* adjust dblen so later sanity checks don't run into the signature */
269         dblen -= siglen;
270
271         if (dblen <= sizeof(*header)) {
272                 fprintf(stderr, "Invalid signature length %d\n", siglen);
273                 return -EINVAL;
274         }
275
276         /* verify signature */
277 #ifdef USE_OPENSSL
278         rsa = RSA_new();
279         if (!rsa) {
280                 fprintf(stderr, "Failed to create RSA key\n");
281                 return -EINVAL;
282         }
283
284         if (SHA1(db, dblen, hash) != hash) {
285                 fprintf(stderr, "Failed to calculate SHA sum\n");
286                 RSA_free(rsa);
287                 return -EINVAL;
288         }
289
290         for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
291                 rsa->e = &keys[i].e;
292                 rsa->n = &keys[i].n;
293
294                 if (RSA_size(rsa) != siglen)
295                         continue;
296
297                 ok = RSA_verify(NID_sha1, hash, SHA_DIGEST_LENGTH,
298                                 db + dblen, siglen, rsa) == 1;
299                 if (ok)
300                         break;
301         }
302
303         rsa->e = NULL;
304         rsa->n = NULL;
305         RSA_free(rsa);
306
307         if (!ok) {
308                 fprintf(stderr, "Database signature wrong\n");
309                 return -EINVAL;
310         }
311
312         BN_print_fp(stdout, &keys[0].n);
313 #endif
314
315 #ifdef USE_GCRYPT
316         /* initialise */
317         gcry_check_version(NULL);
318
319         /* hash the db */
320         gcry_md_hash_buffer(GCRY_MD_SHA1, hash, db, dblen);
321
322         if (gcry_sexp_build(&data, NULL, "(data (flags pkcs1) (hash sha1 %b))",
323                             20, hash)) {
324                 fprintf(stderr, "failed to build data expression\n");
325                 return 2;
326         }
327
328         if (gcry_sexp_build(&signature, NULL, "(sig-val (rsa (s %b)))",
329                             siglen, db + dblen)) {
330                 fprintf(stderr, "failed to build signature expression\n");
331                 return 2;
332         }
333
334         for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
335                 if (gcry_mpi_scan(&mpi_e, GCRYMPI_FMT_USG,
336                                   keys[0].e, keys[0].len_e, NULL) ||
337                     gcry_mpi_scan(&mpi_n, GCRYMPI_FMT_USG,
338                                   keys[0].n, keys[0].len_n, NULL)) {
339                         fprintf(stderr, "failed to convert numbers\n");
340                         return 2;
341                 }
342
343                 if (gcry_sexp_build(&rsa, NULL,
344                                     "(public-key (rsa (n %m) (e %m)))",
345                                     mpi_n, mpi_e)) {
346                         fprintf(stderr, "failed to build rsa key\n");
347                         return 2;
348                 }
349
350                 if (!gcry_pk_verify(signature, data, rsa)) {
351                         ok = 1;
352                         break;
353                 }
354         }
355
356         if (!ok) {
357                 fprintf(stderr, "Database signature wrong\n");
358                 return 2;
359         }
360 #endif
361
362         num_countries = ntohl(header->reg_country_num);
363         countries = get_file_ptr(db, dblen,
364                                  sizeof(struct regdb_file_reg_country) * num_countries,
365                                  header->reg_country_ptr);
366
367         for (i = 0; i < num_countries; i++) {
368                 country = countries + i;
369                 if (memcmp(country->alpha2, alpha2, 2) == 0) {
370                         found_country = 1;
371                         break;
372                 }
373         }
374
375         if (!found_country) {
376                 fprintf(stderr, "failed to find a country match in regulatory database\n");
377                 return -1;
378         }
379
380         r = nl80211_init(&nlstate);
381         if (r)
382                 return -EIO;
383
384         msg = nlmsg_alloc();
385         if (!msg) {
386                 fprintf(stderr, "Failed to allocate netlink message.\n");
387                 r = -1;
388                 goto out;
389         }
390
391         genlmsg_put(msg, 0, 0, genl_family_get_id(nlstate.nl80211), 0,
392                 0, NL80211_CMD_SET_REG, 0);
393
394         rcoll = get_file_ptr(db, dblen, sizeof(*rcoll), country->reg_collection_ptr);
395         num_rules = ntohl(rcoll->reg_rule_num);
396         /* re-get pointer with sanity checking for num_rules */
397         rcoll = get_file_ptr(db, dblen,
398                              sizeof(*rcoll) + num_rules * sizeof(__be32),
399                              country->reg_collection_ptr);
400
401         NLA_PUT_STRING(msg, NL80211_ATTR_REG_ALPHA2, (char *) country->alpha2);
402
403         nl_reg_rules = nla_nest_start(msg, NL80211_ATTR_REG_RULES);
404         if (!nl_reg_rules) {
405                 r = -1;
406                 goto nla_put_failure;
407         }
408
409         for (j = 0; j < num_rules; j++) {
410                 struct nlattr *nl_reg_rule;
411                 nl_reg_rule = nla_nest_start(msg, i);
412                 if (!nl_reg_rule)
413                         goto nla_put_failure;
414
415                 r = put_reg_rule(db, dblen, rcoll->reg_rule_ptrs[j], msg);
416                 if (r)
417                         goto nla_put_failure;
418
419                 nla_nest_end(msg, nl_reg_rule);
420         }
421
422         nla_nest_end(msg, nl_reg_rules);
423
424         cb = nl_cb_alloc(NL_CB_CUSTOM);
425         if (!cb)
426                 goto cb_out;
427
428         r = nl_send_auto_complete(nlstate.nl_handle, msg);
429
430         if (r < 0) {
431                 fprintf(stderr, "failed to send regulatory request: %d\n", r);
432                 goto cb_out;
433         }
434
435         nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, reg_handler, NULL);
436         nl_cb_set(cb, NL_CB_ACK, NL_CB_CUSTOM, wait_handler, &finished);
437         nl_cb_err(cb, NL_CB_CUSTOM, error_handler, NULL);
438
439         if (!finished) {
440                 r = nl_wait_for_ack(nlstate.nl_handle);
441                 if (r < 0) {
442                         fprintf(stderr, "failed to set regulatory domain: %d\n", r);
443                         goto cb_out;
444                 }
445         }
446
447 cb_out:
448         nl_cb_put(cb);
449 nla_put_failure:
450         nlmsg_free(msg);
451 out:
452         nl80211_cleanup(&nlstate);
453         return r;
454 }