Consolidate db signature validation in one function.
[platform/upstream/crda.git] / crda.c
1 /*
2  * Central Regulatory Domain Agent for Linux
3  *
4  * Userspace helper which sends regulatory domains to Linux via nl80211
5  */
6
7 #include <errno.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <sys/mman.h>
11 #include <sys/stat.h>
12 #include <fcntl.h>
13 #include <arpa/inet.h>
14
15 #include <netlink/genl/genl.h>
16 #include <netlink/genl/family.h>
17 #include <netlink/genl/ctrl.h>
18 #include <netlink/msg.h>
19 #include <netlink/attr.h>
20 #include <linux/nl80211.h>
21
22 #include "regdb.h"
23
24 struct nl80211_state {
25         struct nl_handle *nl_handle;
26         struct nl_cache *nl_cache;
27         struct genl_family *nl80211;
28 };
29
30 static int nl80211_init(struct nl80211_state *state)
31 {
32         int err;
33
34         state->nl_handle = nl_handle_alloc();
35         if (!state->nl_handle) {
36                 fprintf(stderr, "Failed to allocate netlink handle.\n");
37                 return -ENOMEM;
38         }
39
40         if (genl_connect(state->nl_handle)) {
41                 fprintf(stderr, "Failed to connect to generic netlink.\n");
42                 err = -ENOLINK;
43                 goto out_handle_destroy;
44         }
45
46         state->nl_cache = genl_ctrl_alloc_cache(state->nl_handle);
47         if (!state->nl_cache) {
48                 fprintf(stderr, "Failed to allocate generic netlink cache.\n");
49                 err = -ENOMEM;
50                 goto out_handle_destroy;
51         }
52
53         state->nl80211 = genl_ctrl_search_by_name(state->nl_cache, "nl80211");
54         if (!state->nl80211) {
55                 fprintf(stderr, "nl80211 not found.\n");
56                 err = -ENOENT;
57                 goto out_cache_free;
58         }
59
60         return 0;
61
62  out_cache_free:
63         nl_cache_free(state->nl_cache);
64  out_handle_destroy:
65         nl_handle_destroy(state->nl_handle);
66         return err;
67 }
68
69 static void nl80211_cleanup(struct nl80211_state *state)
70 {
71         genl_family_put(state->nl80211);
72         nl_cache_free(state->nl_cache);
73         nl_handle_destroy(state->nl_handle);
74 }
75
76 static int reg_handler(struct nl_msg *msg, void *arg)
77 {
78         return NL_SKIP;
79 }
80
81 static int wait_handler(struct nl_msg *msg, void *arg)
82 {
83         int *finished = arg;
84         *finished = 1;
85         return NL_STOP;
86 }
87
88
89 static int error_handler(struct sockaddr_nl *nla, struct nlmsgerr *err, void *arg)
90 {
91         fprintf(stderr, "nl80211 error %d\n", err->error);
92         exit(err->error);
93 }
94
95 int isalpha_upper(char letter)
96 {
97         if (letter >= 'A' && letter <= 'Z')
98                 return 1;
99         return 0;
100 }
101
102 static int is_alpha2(const char *alpha2)
103 {
104         if (isalpha_upper(alpha2[0]) && isalpha_upper(alpha2[1]))
105                 return 1;
106         return 0;
107 }
108
109 static int is_world_regdom(const char *alpha2)
110 {
111         if (alpha2[0] == '0' && alpha2[1] == '0')
112                 return 1;
113         return 0;
114 }
115
116 static int is_valid_regdom(const char * alpha2)
117 {
118         if (strlen(alpha2) != 2)
119                 return 0;
120
121         if (!is_alpha2(alpha2) && !is_world_regdom(alpha2)) {
122                 return 0;
123         }
124
125         return 1;
126 }
127
128 /* ptr is 32 big endian. You don't need to convert it before passing to this
129  * function */
130
131 static void *get_file_ptr(__u8 *db, int dblen, int structlen, __be32 ptr)
132 {
133         __u32 p = ntohl(ptr);
134
135         if (p > dblen - structlen) {
136                 fprintf(stderr, "Invalid database file, bad pointer!\n");
137                 exit(3);
138         }
139
140         return (void *)(db + p);
141 }
142
143 static int put_reg_rule(__u8 *db, int dblen, __be32 ruleptr, struct nl_msg *msg)
144 {
145         struct regdb_file_reg_rule *rule;
146         struct regdb_file_freq_range *freq;
147         struct regdb_file_power_rule *power;
148
149         rule    = get_file_ptr(db, dblen, sizeof(*rule), ruleptr);
150         freq    = get_file_ptr(db, dblen, sizeof(*freq), rule->freq_range_ptr);
151         power   = get_file_ptr(db, dblen, sizeof(*power), rule->power_rule_ptr);
152
153         NLA_PUT_U32(msg, NL80211_ATTR_REG_RULE_FLAGS,           ntohl(rule->flags));
154         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_START,         ntohl(freq->start_freq));
155         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_END,           ntohl(freq->end_freq));
156         NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_MAX_BW,        ntohl(freq->max_bandwidth));
157         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN,  ntohl(power->max_antenna_gain));
158         NLA_PUT_U32(msg, NL80211_ATTR_POWER_RULE_MAX_EIRP,      ntohl(power->max_eirp));
159
160         return 0;
161
162 nla_put_failure:
163         return -1;
164 }
165
166 int main(int argc, char **argv)
167 {
168         int fd;
169         struct stat stat;
170         __u8 *db;
171         struct regdb_file_header *header;
172         struct regdb_file_reg_country *countries;
173         int dblen, siglen, num_countries, i, j, r;
174         char alpha2[2];
175         char *env_country;
176         struct nl80211_state nlstate;
177         struct nl_cb *cb = NULL;
178         struct nl_msg *msg;
179         int found_country = 0;
180         int finished = 0;
181
182         struct regdb_file_reg_rules_collection *rcoll;
183         struct regdb_file_reg_country *country;
184         struct nlattr *nl_reg_rules;
185         int num_rules;
186
187         const char regdb[] = "/usr/lib/crda/regulatory.bin";
188
189         if (argc != 1) {
190                 fprintf(stderr, "Usage: %s\n", argv[0]);
191                 return -EINVAL;
192         }
193
194         env_country = getenv("COUNTRY");
195         if (!env_country) {
196                 fprintf(stderr, "COUNTRY environment variable not set.\n");
197                 return -EINVAL;
198         }
199
200         if (!is_valid_regdom(env_country)) {
201                 fprintf(stderr, "COUNTRY environment variable must be an "
202                         "ISO ISO 3166-1-alpha-2 (uppercase) or 00\n");
203                 return -EINVAL;
204         }
205
206         memcpy(alpha2, env_country, 2);
207
208         fd = open(regdb, O_RDONLY);
209         if (fd < 0) {
210                 perror("failed to open db file");
211                 return -ENOENT;
212         }
213
214         if (fstat(fd, &stat)) {
215                 perror("failed to fstat db file");
216                 return -EIO;
217         }
218
219         dblen = stat.st_size;
220
221         db = mmap(NULL, dblen, PROT_READ, MAP_PRIVATE, fd, 0);
222         if (db == MAP_FAILED) {
223                 perror("failed to mmap db file");
224                 return -EIO;
225         }
226
227         /* db file starts with a struct regdb_file_header */
228         header = get_file_ptr(db, dblen, sizeof(*header), 0);
229
230         if (ntohl(header->magic) != REGDB_MAGIC) {
231                 fprintf(stderr, "Invalid database magic\n");
232                 return -EINVAL;
233         }
234
235         if (ntohl(header->version) != REGDB_VERSION) {
236                 fprintf(stderr, "Invalid database version\n");
237                 return -EINVAL;
238         }
239
240         siglen = ntohl(header->signature_length);
241         /* adjust dblen so later sanity checks don't run into the signature */
242         dblen -= siglen;
243
244         if (dblen <= sizeof(*header)) {
245                 fprintf(stderr, "Invalid signature length %d\n", siglen);
246                 return -EINVAL;
247         }
248
249         /* verify signature */
250         if (!crda_verify_db_signature(db, dblen, siglen))
251                 return -EINVAL;
252
253         num_countries = ntohl(header->reg_country_num);
254         countries = get_file_ptr(db, dblen,
255                                  sizeof(struct regdb_file_reg_country) * num_countries,
256                                  header->reg_country_ptr);
257
258         for (i = 0; i < num_countries; i++) {
259                 country = countries + i;
260                 if (memcmp(country->alpha2, alpha2, 2) == 0) {
261                         found_country = 1;
262                         break;
263                 }
264         }
265
266         if (!found_country) {
267                 fprintf(stderr, "failed to find a country match in regulatory database\n");
268                 return -1;
269         }
270
271         r = nl80211_init(&nlstate);
272         if (r)
273                 return -EIO;
274
275         msg = nlmsg_alloc();
276         if (!msg) {
277                 fprintf(stderr, "Failed to allocate netlink message.\n");
278                 r = -1;
279                 goto out;
280         }
281
282         genlmsg_put(msg, 0, 0, genl_family_get_id(nlstate.nl80211), 0,
283                 0, NL80211_CMD_SET_REG, 0);
284
285         rcoll = get_file_ptr(db, dblen, sizeof(*rcoll), country->reg_collection_ptr);
286         num_rules = ntohl(rcoll->reg_rule_num);
287         /* re-get pointer with sanity checking for num_rules */
288         rcoll = get_file_ptr(db, dblen,
289                              sizeof(*rcoll) + num_rules * sizeof(__be32),
290                              country->reg_collection_ptr);
291
292         NLA_PUT_STRING(msg, NL80211_ATTR_REG_ALPHA2, (char *) country->alpha2);
293
294         nl_reg_rules = nla_nest_start(msg, NL80211_ATTR_REG_RULES);
295         if (!nl_reg_rules) {
296                 r = -1;
297                 goto nla_put_failure;
298         }
299
300         for (j = 0; j < num_rules; j++) {
301                 struct nlattr *nl_reg_rule;
302                 nl_reg_rule = nla_nest_start(msg, i);
303                 if (!nl_reg_rule)
304                         goto nla_put_failure;
305
306                 r = put_reg_rule(db, dblen, rcoll->reg_rule_ptrs[j], msg);
307                 if (r)
308                         goto nla_put_failure;
309
310                 nla_nest_end(msg, nl_reg_rule);
311         }
312
313         nla_nest_end(msg, nl_reg_rules);
314
315         cb = nl_cb_alloc(NL_CB_CUSTOM);
316         if (!cb)
317                 goto cb_out;
318
319         r = nl_send_auto_complete(nlstate.nl_handle, msg);
320
321         if (r < 0) {
322                 fprintf(stderr, "failed to send regulatory request: %d\n", r);
323                 goto cb_out;
324         }
325
326         nl_cb_set(cb, NL_CB_VALID, NL_CB_CUSTOM, reg_handler, NULL);
327         nl_cb_set(cb, NL_CB_ACK, NL_CB_CUSTOM, wait_handler, &finished);
328         nl_cb_err(cb, NL_CB_CUSTOM, error_handler, NULL);
329
330         if (!finished) {
331                 r = nl_wait_for_ack(nlstate.nl_handle);
332                 if (r < 0) {
333                         fprintf(stderr, "failed to set regulatory domain: %d\n", r);
334                         goto cb_out;
335                 }
336         }
337
338 cb_out:
339         nl_cb_put(cb);
340 nla_put_failure:
341         nlmsg_free(msg);
342 out:
343         nl80211_cleanup(&nlstate);
344         return r;
345 }