Remove unused headers.
[platform/upstream/crda.git] / crda.c
1 /*
2  * Central Regulatory Domain Agent for Linux
3  *
4  * Userspace helper which sends regulatory domains to Linux via nl80211
5  */
6
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <arpa/inet.h>
13
14 #include <netlink/genl/genl.h>
15 #include <netlink/genl/family.h>
16 #include <netlink/genl/ctrl.h>  
17 #include <netlink/msg.h>
18 #include <netlink/attr.h>
19 #include <linux/nl80211.h>
20
21 #include "regdb.h"
22
23 #ifdef USE_OPENSSL
24 #include <openssl/objects.h>
25 #include <openssl/bn.h>
26 #include <openssl/rsa.h>
27 #include <openssl/sha.h>
28
29 #include "keys-ssl.c"
30 #endif
31
32 #ifdef USE_GCRYPT
33 #include <gcrypt.h>
34
35 #include "keys-gcrypt.c"
36 #endif
37
/*
 * Netlink state shared by the helpers in this file.  All three members are
 * filled in by nl80211_init() and released by nl80211_cleanup().
 */
struct nl80211_state {
	/* handle returned by nl_handle_alloc() and connected with genl_connect() */
	struct nl_handle *nl_handle;
	/* generic netlink controller cache from genl_ctrl_alloc_cache() */
	struct nl_cache *nl_cache;
	/* "nl80211" family looked up in nl_cache */
	struct genl_family *nl80211;
};
43
44 static int nl80211_init(struct nl80211_state *state)
45 {
46         int err;
47
48         state->nl_handle = nl_handle_alloc();
49         if (!state->nl_handle) {
50                 fprintf(stderr, "Failed to allocate netlink handle.\n");
51                 return -ENOMEM;
52         }
53
54         if (genl_connect(state->nl_handle)) {
55                 fprintf(stderr, "Failed to connect to generic netlink.\n");
56                 err = -ENOLINK;
57                 goto out_handle_destroy;
58         }
59
60         state->nl_cache = genl_ctrl_alloc_cache(state->nl_handle);
61         if (!state->nl_cache) {
62                 fprintf(stderr, "Failed to allocate generic netlink cache.\n");
63                 err = -ENOMEM;
64                 goto out_handle_destroy;
65         }
66
67         state->nl80211 = genl_ctrl_search_by_name(state->nl_cache, "nl80211");
68         if (!state->nl80211) {
69                 fprintf(stderr, "nl80211 not found.\n");
70                 err = -ENOENT;
71                 goto out_cache_free;
72         }
73
74         return 0;
75
76  out_cache_free:
77         nl_cache_free(state->nl_cache);
78  out_handle_destroy:
79         nl_handle_destroy(state->nl_handle);
80         return err;
81 }
82
83 static void nl80211_cleanup(struct nl80211_state *state)
84 {
85         genl_family_put(state->nl80211);
86         nl_cache_free(state->nl_cache);
87         nl_handle_destroy(state->nl_handle);
88 }
89
90 static int reg_handler(struct nl_msg *msg, void *arg)
91 {
92         printf("=== reg_handler() called\n");
93         return NL_SKIP;
94 }
95
96 static int wait_handler(struct nl_msg *msg, void *arg)
97 {
98         int *finished = arg;
99         *finished = 1;
100         return NL_STOP;
101 }
102
103
104 static int error_handler(struct sockaddr_nl *nla, struct nlmsgerr *err, void *arg)
105 {
106         fprintf(stderr, "nl80211 error %d\n", err->error);
107         exit(err->error);
108 }
109
/*
 * Return 1 if @letter is an upper-case ASCII letter ('A'..'Z'), 0 otherwise.
 * Unlike isupper() this does not depend on the current locale.
 */
int isalpha_upper(char letter)
{
	return letter >= 'A' && letter <= 'Z';
}
116
/*
 * Return 1 if the first two characters of @alpha2 are both upper-case ASCII
 * letters, 0 otherwise.  The second character is only examined when the
 * first one passed, so an empty string is handled safely.
 */
static int is_alpha2(char *alpha2)
{
	return isalpha_upper(alpha2[0]) && isalpha_upper(alpha2[1]);
}
123
/*
 * Return 1 if @alpha2 starts with the two digit characters "00" (the code
 * this tool accepts for the world regulatory domain), 0 otherwise.  The
 * second character is only examined when the first one matched.
 */
static int is_world_regdom(char *alpha2)
{
	return alpha2[0] == '0' && alpha2[1] == '0';
}
131
/*
 * Translate a big-endian file offset into a pointer into the mapped database.
 *
 * @db:        start of the database image
 * @dblen:     number of valid bytes at @db
 * @structlen: size of the object the caller is going to read at the offset
 * @ptr:       offset from the start of the file, in network byte order
 *
 * Returns db + offset when [offset, offset + structlen) lies entirely inside
 * the image.  Otherwise an error is printed and the process exits with
 * status 3; the function does not return in that case.
 */
static void *get_file_ptr(__u8 *db, int dblen, int structlen, __be32 ptr)
{
	__u32 p = ntohl(ptr);

	/*
	 * Reject negative lengths and objects larger than the image before
	 * subtracting: comparing the unsigned offset against a negative
	 * (dblen - structlen) would convert the difference to a huge unsigned
	 * value and let any offset through.
	 */
	if (dblen < 0 || structlen < 0 || structlen > dblen ||
	    p > (__u32)(dblen - structlen)) {
		fprintf(stderr, "Invalid database file, bad pointer!\n");
		exit(3);
	}

	return (void *)(db + p);
}
143
144 static int put_reg_rule(__u8 *db, int dblen, __be32 ruleptr, struct nl_msg *msg)
145 {
146         struct regdb_file_reg_rule *rule;
147         struct regdb_file_freq_range *freq;
148         struct regdb_file_power_rule *power;
149
150         rule    = get_file_ptr(db, dblen, sizeof(*rule), ruleptr);
151         freq    = get_file_ptr(db, dblen, sizeof(*freq), rule->freq_range_ptr);
152         power   = get_file_ptr(db, dblen, sizeof(*power), rule->power_rule_ptr);
153
154         NLA_PUT_U32(msg, NL80211_ATTR_REG_RULE_FLAGS,           ntohl(rule->flags));
155         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_START,         ntohl(freq->start_freq));
156         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_END,           ntohl(freq->end_freq));
157         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_MAX_BW,        ntohl(freq->max_bandwidth));
158         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN,  ntohl(power->max_antenna_gain));
159         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_EIRP,      ntohl(power->max_eirp));
160
161         return 0;
162
163 nla_put_failure:
164         return -1;
165 }
166
167 int main(int argc, char **argv)
168 {
169         int fd;
170         struct stat stat;
171         __u8 *db;
172         struct regdb_file_header *header;
173         struct regdb_file_reg_country *countries;
174         int dblen, siglen, num_countries, i, j, r;
175         char alpha2[2];
176         struct nl80211_state nlstate;
177         struct nl_cb *cb = NULL;
178         struct nl_msg *msg;
179         int found_country = 0;
180         int finished = 0;
181
182         struct regdb_file_reg_rules_collection *rcoll;
183         struct regdb_file_reg_country *country;
184         struct nlattr *nl_reg_rules;
185         int num_rules;
186
187 #ifdef USE_OPENSSL
188         RSA *rsa;
189         __u8 hash[SHA_DIGEST_LENGTH];
190         int ok = 0;
191 #endif
192 #ifdef USE_GCRYPT
193         gcry_mpi_t mpi_e, mpi_n;
194         gcry_sexp_t rsa, signature, data;
195         __u8 hash[20];
196         int ok = 0;
197 #endif
198         char *regdb = "/usr/lib/crda/regulatory.bin";
199
200         if (argc != 2) {
201                 fprintf(stderr, "Usage: %s <ISO-3166 alpha2 country code>\n", argv[0]);
202                 return -EINVAL;
203         }
204         
205         if (!is_alpha2(argv[1]) && !is_world_regdom(argv[1])) {
206                 fprintf(stderr, "Invalid alpha2\n");
207                 return -EINVAL;
208         }
209
210         memcpy(alpha2, argv[1], 2);
211
212         r = nl80211_init(&nlstate);
213         if (r)
214                 return -EIO;
215
216         fd = open(regdb, O_RDONLY);
217         if (fd < 0) {
218                 perror("failed to open db file");
219                 r = -ENOENT;
220                 goto out;
221         }
222
223         if (fstat(fd, &stat)) {
224                 perror("failed to fstat db file");
225                 r = -EIO;
226                 goto out;
227         }
228
229         dblen = stat.st_size;
230
231         db = mmap(NULL, dblen, PROT_READ, MAP_PRIVATE, fd, 0);
232         if (db == MAP_FAILED) {
233                 perror("failed to mmap db file");
234                 r = -EIO;
235                 goto out;
236         }
237
238         header = get_file_ptr(db, dblen, sizeof(*header), 0);
239
240         if (ntohl(header->magic) != REGDB_MAGIC) {
241                 fprintf(stderr, "Invalid database magic\n");
242                 r = -EINVAL;
243                 goto out;
244         }
245
246         if (ntohl(header->version) != REGDB_VERSION) {
247                 fprintf(stderr, "Invalid database version\n");
248                 r = -EINVAL;
249                 goto out;
250         }
251
252         siglen = ntohl(header->signature_length);
253         /* adjust dblen so later sanity checks don't run into the signature */
254         dblen -= siglen;
255
256         if (dblen <= sizeof(*header)) {
257                 fprintf(stderr, "Invalid signature length %d\n", siglen);
258                 r = -EINVAL;
259                 goto out;
260         }
261
262         /* verify signature */
263 #ifdef USE_OPENSSL
264         rsa = RSA_new();
265         if (!rsa) {
266                 fprintf(stderr, "Failed to create RSA key\n");
267                 r = -EINVAL;
268                 goto out;
269         }
270
271         if (SHA1(db, dblen, hash) != hash) {
272                 fprintf(stderr, "Failed to calculate SHA sum\n");
273                 r = -EINVAL;
274                 goto out;
275         }
276
277         for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
278                 rsa->e = &keys[i].e;
279                 rsa->n = &keys[i].n;
280
281                 if (RSA_size(rsa) != siglen)
282                         continue;
283
284                 ok = RSA_verify(NID_sha1, hash, SHA_DIGEST_LENGTH,
285                                 db + dblen, siglen, rsa) == 1;
286                 if (ok)
287                         break;
288         }
289
290         if (!ok) {
291                 fprintf(stderr, "Database signature wrong\n");
292                 r = -EINVAL;
293                 goto out;
294         }
295
296         rsa->e = NULL;
297         rsa->n = NULL;
298         RSA_free(rsa);
299
300         BN_print_fp(stdout, &keys[0].n);
301
302 #endif
303
304 #ifdef USE_GCRYPT
305         /* hash the db */
306         gcry_md_hash_buffer(GCRY_MD_SHA1, hash, db, dblen);
307
308         if (gcry_sexp_build(&data, NULL, "(data (flags pkcs1) (hash sha1 %b))",
309                             20, hash)) {
310                 fprintf(stderr, "failed to build data expression\n");
311                 return 2;
312         }
313
314         if (gcry_sexp_build(&signature, NULL, "(sig-val (rsa (s %b)))",
315                             siglen, db + dblen)) {
316                 fprintf(stderr, "failed to build signature expression\n");
317                 return 2;
318         }
319
320         for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
321                 if (gcry_mpi_scan(&mpi_e, GCRYMPI_FMT_USG,
322                                   keys[0].e, keys[0].len_e, NULL) ||
323                     gcry_mpi_scan(&mpi_n, GCRYMPI_FMT_USG,
324                                   keys[0].n, keys[0].len_n, NULL)) {
325                         fprintf(stderr, "failed to convert numbers\n");
326                         return 2;
327                 }
328
329                 if (gcry_sexp_build(&rsa, NULL,
330                                     "(public-key (rsa (n %m) (e %m)))",
331                                     mpi_n, mpi_e)) {
332                         fprintf(stderr, "failed to build rsa key\n");
333                         return 2;
334                 }
335
336                 if (!gcry_pk_verify(signature, data, rsa)) {
337                         ok = 1;
338                         break;
339                 }
340         }
341
342         if (!ok) {
343                 fprintf(stderr, "Database signature wrong\n");
344                 return 2;
345         }
346
347 #endif
348
349         num_countries = ntohl(header->reg_country_num);
350         countries = get_file_ptr(db, dblen,
351                                  sizeof(struct regdb_file_reg_country) * num_countries,
352                                  header->reg_country_ptr);
353
354         for (i = 0; i < num_countries; i++) {
355                 country = countries + i;
356                 if (memcmp(country->alpha2, alpha2, 2) == 0) {
357                         found_country = 1;
358                         break;
359                 }
360         }
361
362         if (!found_country) {
363                 fprintf(stderr, "failed to find a country match in regulatory database\n");
364                 return -1;
365         }
366
367         msg = nlmsg_alloc();
368         if (!msg) {
369                 fprintf(stderr, "failed to allocate netlink msg\n");
370                 return -1;
371         }
372
373         genlmsg_put(msg, 0, 0, genl_family_get_id(nlstate.nl80211), 0,
374                 0, NL80211_CMD_SET_REG, 0);
375
376
377         rcoll = get_file_ptr(db, dblen, sizeof(*rcoll), country->reg_collection_ptr);
378         num_rules = ntohl(rcoll->reg_rule_num);
379         /* re-get pointer with sanity checking for num_rules */
380         rcoll = get_file_ptr(db, dblen,
381                              sizeof(*rcoll) + num_rules * sizeof(__be32),
382                              country->reg_collection_ptr);
383
384         NLA_PUT_STRING(msg, NL80211_ATTR_REG_ALPHA2, (char *) country->alpha2);
385
386         nl_reg_rules = nla_nest_start(msg, NL80211_ATTR_REG_RULES);
387         if (!nl_reg_rules) {
388                 r = -1;
389                 goto nla_put_failure;
390         }
391
392         for (j = 0; j < num_rules; j++) {
393                 struct nlattr *nl_reg_rule;
394                 nl_reg_rule = nla_nest_start(msg, i);
395                 if (!nl_reg_rule)
396                         goto nla_put_failure;
397
398                 r = put_reg_rule(db, dblen, rcoll->reg_rule_ptrs[j], msg);
399                 if (r)
400                         goto nla_put_failure;
401
402                 nla_nest_end(msg, nl_reg_rule);
403         }
404
405         nla_nest_end(msg, nl_reg_rules);
406
407         cb = nl_cb_alloc(NL_CB_CUSTOM);
408         if (!cb)
409                 goto cb_out;
410
411         r = nl_send_auto_complete(nlstate.nl_handle, msg);
412
413         if (r < 0) {
414                 fprintf(stderr, "failed to send regulatory request: %d\n", r);
415                 goto cb_out;
416         }
417
418         nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, reg_handler, NULL);
419         nl_cb_set(cb, NL_CB_ACK, NL_CB_CUSTOM, wait_handler, &finished);
420         nl_cb_err(cb, NL_CB_CUSTOM, error_handler, NULL);
421
422         if (!finished) {
423                 r = nl_wait_for_ack(nlstate.nl_handle);
424                 if (r < 0) {
425                         fprintf(stderr, "failed to set regulatory domain: %d\n", r);
426                         goto cb_out;
427                 }
428         }
429
430 cb_out:
431         nl_cb_put(cb);
432 nla_put_failure:
433         nlmsg_free(msg);
434 out:
435         nl80211_cleanup(&nlstate);
436         return r;
437 }