Kill NL80211_ATTR_NUM_REG_RULES.
[platform/upstream/crda.git] / crda.c
1 /*
2  * Central Regulatory Domain Agent for Linux
3  *
4  * Userspace helper which sends regulatory domains to Linux via nl80211
5  */
6
7 #include <errno.h>
8 #include <stdlib.h>
9 #include <stdio.h>
10 #include <string.h>
11 #include <ctype.h> /* For isalnum(), remove later */
12 #include <sys/mman.h>
13 #include <sys/types.h>
14 #include <sys/stat.h>
15 #include <unistd.h>
16 #include <fcntl.h>
17 #include <arpa/inet.h>
18
19 #include <netlink/genl/genl.h>
20 #include <netlink/genl/family.h>
21 #include <netlink/genl/ctrl.h>  
22 #include <netlink/msg.h>
23 #include <netlink/attr.h>
24 #include <linux/nl80211.h>
25
26 #include "regdb.h"
27
28 #ifdef USE_OPENSSL
29 #include <openssl/objects.h>
30 #include <openssl/bn.h>
31 #include <openssl/rsa.h>
32 #include <openssl/sha.h>
33
34 #include "keys-ssl.c"
35 #endif
36
37 #ifdef USE_GCRYPT
38 #include <gcrypt.h>
39
40 #include "keys-gcrypt.c"
41 #endif
42
/*
 * Handles needed to talk to the kernel's nl80211 interface through
 * generic netlink.  Populated by nl80211_init() and released by
 * nl80211_cleanup().
 */
struct nl80211_state {
	struct nl_handle	*nl_handle;	/* generic netlink socket handle */
	struct nl_cache		*nl_cache;	/* generic netlink controller cache */
	struct genl_family	*nl80211;	/* family looked up by name "nl80211" */
};
48
49 static int nl80211_init(struct nl80211_state *state)
50 {
51         int err;
52
53         state->nl_handle = nl_handle_alloc();
54         if (!state->nl_handle) {
55                 fprintf(stderr, "Failed to allocate netlink handle.\n");
56                 return -ENOMEM;
57         }
58
59         if (genl_connect(state->nl_handle)) {
60                 fprintf(stderr, "Failed to connect to generic netlink.\n");
61                 err = -ENOLINK;
62                 goto out_handle_destroy;
63         }
64
65         state->nl_cache = genl_ctrl_alloc_cache(state->nl_handle);
66         if (!state->nl_cache) {
67                 fprintf(stderr, "Failed to allocate generic netlink cache.\n");
68                 err = -ENOMEM;
69                 goto out_handle_destroy;
70         }
71
72         state->nl80211 = genl_ctrl_search_by_name(state->nl_cache, "nl80211");
73         if (!state->nl80211) {
74                 fprintf(stderr, "nl80211 not found.\n");
75                 err = -ENOENT;
76                 goto out_cache_free;
77         }
78
79         return 0;
80
81  out_cache_free:
82         nl_cache_free(state->nl_cache);
83  out_handle_destroy:
84         nl_handle_destroy(state->nl_handle);
85         return err;
86 }
87
88 static void nl80211_cleanup(struct nl80211_state *state)
89 {
90         genl_family_put(state->nl80211);
91         nl_cache_free(state->nl_cache);
92         nl_handle_destroy(state->nl_handle);
93 }
94
95 static int reg_handler(struct nl_msg *msg, void *arg)
96 {
97         printf("=== reg_handler() called\n");
98         return NL_SKIP;
99 }
100
101 static int wait_handler(struct nl_msg *msg, void *arg)
102 {
103         int *finished = arg;
104         *finished = 1;
105         return NL_STOP;
106 }
107
108
109 static int error_handler(struct sockaddr_nl *nla, struct nlmsgerr *err, void *arg)
110 {
111         fprintf(stderr, "nl80211 error %d\n", err->error);
112         exit(err->error);
113 }
114
/*
 * Return 1 if @letter is an upper-case ASCII letter ('A'..'Z'), else 0.
 *
 * Uses character literals instead of the raw codes 65/90, so the
 * intended range is visible at a glance.  Unlike isupper(), the result
 * does not depend on the current locale.
 */
int isalpha_upper(char letter)
{
	return letter >= 'A' && letter <= 'Z';
}
121
/*
 * Return 1 when the first two characters of @alpha2 are both upper-case
 * ASCII letters, 0 otherwise.  Only two characters are examined.  The
 * second is read only if the first already qualified, so a 1-character
 * string such as "" is safe.
 */
static int is_alpha2(char *alpha2)
{
	int both_upper = isalpha_upper(alpha2[0]) && isalpha_upper(alpha2[1]);

	return both_upper;
}
128
/*
 * Return 1 if @alpha2 starts with the two digit characters "00", the
 * code main() accepts in place of a real alpha2 (the world regulatory
 * domain, judging by this function's name), else 0.
 *
 * This compares against the character '0' (0x30), not the NUL byte.
 * Only the first two characters are examined, and the second is read
 * only if the first matched.
 */
static int is_world_regdom(char *alpha2)
{
	return alpha2[0] == '0' && alpha2[1] == '0';
}
136
/*
 * Resolve the big-endian offset @ptr, read from the database, into a
 * pointer into the mapped file @db.  First check that a structure of
 * @structlen bytes starting at that offset lies entirely within the
 * first @dblen bytes.
 *
 * A bad offset or length is treated as a corrupt database: an error is
 * printed and the process exits with status 3.
 */
static void *get_file_ptr(__u8 *db, int dblen, int structlen, __be32 ptr)
{
	__u32 p = ntohl(ptr);

	/*
	 * Validate the signed operands before forming the bound.  If
	 * dblen < structlen (or structlen went negative through an
	 * overflowed size computation), "dblen - structlen" is negative.
	 * Compared against the unsigned p, it would be converted to a huge
	 * value and let an out-of-bounds offset through.
	 */
	if (structlen < 0 || dblen < structlen ||
	    p > (__u32)(dblen - structlen)) {
		fprintf(stderr, "Invalid database file, bad pointer!\n");
		exit(3);
	}

	return (void *)(db + p);
}
148
149 static int put_reg_rule(__u8 *db, int dblen, __be32 ruleptr, struct nl_msg *msg)
150 {
151         struct regdb_file_reg_rule *rule;
152         struct regdb_file_freq_range *freq;
153         struct regdb_file_power_rule *power;
154
155         rule    = get_file_ptr(db, dblen, sizeof(*rule), ruleptr);
156         freq    = get_file_ptr(db, dblen, sizeof(*freq), rule->freq_range_ptr);
157         power   = get_file_ptr(db, dblen, sizeof(*power), rule->power_rule_ptr);
158
159         NLA_PUT_U32(msg, NL80211_ATTR_REG_RULE_FLAGS,           ntohl(rule->flags));
160         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_START,         ntohl(freq->start_freq));
161         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_END,           ntohl(freq->end_freq));
162         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_MAX_BW,        ntohl(freq->max_bandwidth));
163         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN,  ntohl(power->max_antenna_gain));
164         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_EIRP,      ntohl(power->max_eirp));
165
166         return 0;
167
168 nla_put_failure:
169         return -1;
170 }
171
172 int main(int argc, char **argv)
173 {
174         int fd;
175         struct stat stat;
176         __u8 *db;
177         struct regdb_file_header *header;
178         struct regdb_file_reg_country *countries;
179         int dblen, siglen, num_countries, i, j, r;
180         char alpha2[2];
181         struct nl80211_state nlstate;
182         struct nl_cb *cb = NULL;
183         struct nl_msg *msg;
184         int found_country = 0;
185         int finished = 0;
186
187         struct regdb_file_reg_rules_collection *rcoll;
188         struct regdb_file_reg_country *country;
189         struct nlattr *nl_reg_rules;
190         int num_rules;
191
192 #ifdef USE_OPENSSL
193         RSA *rsa;
194         __u8 hash[SHA_DIGEST_LENGTH];
195         int ok = 0;
196 #endif
197 #ifdef USE_GCRYPT
198         gcry_mpi_t mpi_e, mpi_n;
199         gcry_sexp_t rsa, signature, data;
200         __u8 hash[20];
201         int ok = 0;
202 #endif
203         char *regdb = "/usr/lib/crda/regulatory.bin";
204
205         if (argc != 2) {
206                 fprintf(stderr, "Usage: %s <ISO-3166 alpha2 country code>\n", argv[0]);
207                 return -EINVAL;
208         }
209         
210         if (!is_alpha2(argv[1]) && !is_world_regdom(argv[1])) {
211                 fprintf(stderr, "Invalid alpha2\n");
212                 return -EINVAL;
213         }
214
215         memcpy(alpha2, argv[1], 2);
216
217         r = nl80211_init(&nlstate);
218         if (r)
219                 return -EIO;
220
221         fd = open(regdb, O_RDONLY);
222         if (fd < 0) {
223                 perror("failed to open db file");
224                 r = -ENOENT;
225                 goto out;
226         }
227
228         if (fstat(fd, &stat)) {
229                 perror("failed to fstat db file");
230                 r = -EIO;
231                 goto out;
232         }
233
234         dblen = stat.st_size;
235
236         db = mmap(NULL, dblen, PROT_READ, MAP_PRIVATE, fd, 0);
237         if (db == MAP_FAILED) {
238                 perror("failed to mmap db file");
239                 r = -EIO;
240                 goto out;
241         }
242
243         header = get_file_ptr(db, dblen, sizeof(*header), 0);
244
245         if (ntohl(header->magic) != REGDB_MAGIC) {
246                 fprintf(stderr, "Invalid database magic\n");
247                 r = -EINVAL;
248                 goto out;
249         }
250
251         if (ntohl(header->version) != REGDB_VERSION) {
252                 fprintf(stderr, "Invalid database version\n");
253                 r = -EINVAL;
254                 goto out;
255         }
256
257         siglen = ntohl(header->signature_length);
258         /* adjust dblen so later sanity checks don't run into the signature */
259         dblen -= siglen;
260
261         if (dblen <= sizeof(*header)) {
262                 fprintf(stderr, "Invalid signature length %d\n", siglen);
263                 r = -EINVAL;
264                 goto out;
265         }
266
267         /* verify signature */
268 #ifdef USE_OPENSSL
269         rsa = RSA_new();
270         if (!rsa) {
271                 fprintf(stderr, "Failed to create RSA key\n");
272                 r = -EINVAL;
273                 goto out;
274         }
275
276         if (SHA1(db, dblen, hash) != hash) {
277                 fprintf(stderr, "Failed to calculate SHA sum\n");
278                 r = -EINVAL;
279                 goto out;
280         }
281
282         for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
283                 rsa->e = &keys[i].e;
284                 rsa->n = &keys[i].n;
285
286                 if (RSA_size(rsa) != siglen)
287                         continue;
288
289                 ok = RSA_verify(NID_sha1, hash, SHA_DIGEST_LENGTH,
290                                 db + dblen, siglen, rsa) == 1;
291                 if (ok)
292                         break;
293         }
294
295         if (!ok) {
296                 fprintf(stderr, "Database signature wrong\n");
297                 r = -EINVAL;
298                 goto out;
299         }
300
301         rsa->e = NULL;
302         rsa->n = NULL;
303         RSA_free(rsa);
304
305         BN_print_fp(stdout, &keys[0].n);
306
307 #endif
308
309 #ifdef USE_GCRYPT
310         /* hash the db */
311         gcry_md_hash_buffer(GCRY_MD_SHA1, hash, db, dblen);
312
313         if (gcry_sexp_build(&data, NULL, "(data (flags pkcs1) (hash sha1 %b))",
314                             20, hash)) {
315                 fprintf(stderr, "failed to build data expression\n");
316                 return 2;
317         }
318
319         if (gcry_sexp_build(&signature, NULL, "(sig-val (rsa (s %b)))",
320                             siglen, db + dblen)) {
321                 fprintf(stderr, "failed to build signature expression\n");
322                 return 2;
323         }
324
325         for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
326                 if (gcry_mpi_scan(&mpi_e, GCRYMPI_FMT_USG,
327                                   keys[0].e, keys[0].len_e, NULL) ||
328                     gcry_mpi_scan(&mpi_n, GCRYMPI_FMT_USG,
329                                   keys[0].n, keys[0].len_n, NULL)) {
330                         fprintf(stderr, "failed to convert numbers\n");
331                         return 2;
332                 }
333
334                 if (gcry_sexp_build(&rsa, NULL,
335                                     "(public-key (rsa (n %m) (e %m)))",
336                                     mpi_n, mpi_e)) {
337                         fprintf(stderr, "failed to build rsa key\n");
338                         return 2;
339                 }
340
341                 if (!gcry_pk_verify(signature, data, rsa)) {
342                         ok = 1;
343                         break;
344                 }
345         }
346
347         if (!ok) {
348                 fprintf(stderr, "Database signature wrong\n");
349                 return 2;
350         }
351
352 #endif
353
354         num_countries = ntohl(header->reg_country_num);
355         countries = get_file_ptr(db, dblen,
356                                  sizeof(struct regdb_file_reg_country) * num_countries,
357                                  header->reg_country_ptr);
358
359         for (i = 0; i < num_countries; i++) {
360                 country = countries + i;
361                 if (memcmp(country->alpha2, alpha2, 2) == 0) {
362                         found_country = 1;
363                         break;
364                 }
365         }
366
367         if (!found_country) {
368                 fprintf(stderr, "failed to find a country match in regulatory database\n");
369                 return -1;
370         }
371
372         msg = nlmsg_alloc();
373         if (!msg) {
374                 fprintf(stderr, "failed to allocate netlink msg\n");
375                 return -1;
376         }
377
378         genlmsg_put(msg, 0, 0, genl_family_get_id(nlstate.nl80211), 0,
379                 0, NL80211_CMD_SET_REG, 0);
380
381
382         rcoll = get_file_ptr(db, dblen, sizeof(*rcoll), country->reg_collection_ptr);
383         num_rules = ntohl(rcoll->reg_rule_num);
384         /* re-get pointer with sanity checking for num_rules */
385         rcoll = get_file_ptr(db, dblen,
386                              sizeof(*rcoll) + num_rules * sizeof(__be32),
387                              country->reg_collection_ptr);
388
389         NLA_PUT_STRING(msg, NL80211_ATTR_REG_ALPHA2, (char *) country->alpha2);
390
391         nl_reg_rules = nla_nest_start(msg, NL80211_ATTR_REG_RULES);
392         if (!nl_reg_rules) {
393                 r = -1;
394                 goto nla_put_failure;
395         }
396
397         for (j = 0; j < num_rules; j++) {
398                 struct nlattr *nl_reg_rule;
399                 nl_reg_rule = nla_nest_start(msg, i);
400                 if (!nl_reg_rule)
401                         goto nla_put_failure;
402
403                 r = put_reg_rule(db, dblen, rcoll->reg_rule_ptrs[j], msg);
404                 if (r)
405                         goto nla_put_failure;
406
407                 nla_nest_end(msg, nl_reg_rule);
408         }
409
410         nla_nest_end(msg, nl_reg_rules);
411
412         cb = nl_cb_alloc(NL_CB_CUSTOM);
413         if (!cb)
414                 goto cb_out;
415
416         r = nl_send_auto_complete(nlstate.nl_handle, msg);
417
418         if (r < 0) {
419                 fprintf(stderr, "failed to send regulatory request: %d\n", r);
420                 goto cb_out;
421         }
422
423         nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, reg_handler, NULL);
424         nl_cb_set(cb, NL_CB_ACK, NL_CB_CUSTOM, wait_handler, &finished);
425         nl_cb_err(cb, NL_CB_CUSTOM, error_handler, NULL);
426
427         if (!finished) {
428                 r = nl_wait_for_ack(nlstate.nl_handle);
429                 if (r < 0) {
430                         fprintf(stderr, "failed to set regulatory domain: %d\n", r);
431                         goto cb_out;
432                 }
433         }
434
435 cb_out:
436         nl_cb_put(cb);
437 nla_put_failure:
438         nlmsg_free(msg);
439 out:
440         nl80211_cleanup(&nlstate);
441         return r;
442 }