Move common helpers and functions into crda.h
[platform/upstream/crda.git] / crda.c
diff --git a/crda.c b/crda.c
index 3a45da1..681ff73 100644 (file)
--- a/crda.c
+++ b/crda.c
@@ -6,6 +6,7 @@
 
 #include <errno.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <sys/mman.h>
 #include <sys/stat.h>
 #include <fcntl.h>
 
 #include <netlink/genl/genl.h>
 #include <netlink/genl/family.h>
-#include <netlink/genl/ctrl.h>  
+#include <netlink/genl/ctrl.h>
 #include <netlink/msg.h>
 #include <netlink/attr.h>
 #include <linux/nl80211.h>
 
 #include "regdb.h"
-
-#ifdef USE_OPENSSL
-#include <openssl/objects.h>
-#include <openssl/bn.h>
-#include <openssl/rsa.h>
-#include <openssl/sha.h>
-
-#include "keys-ssl.c"
-#endif
-
-#ifdef USE_GCRYPT
-#include <gcrypt.h>
-
-#include "keys-gcrypt.c"
-#endif
+#include "crda.h"
 
 struct nl80211_state {
        struct nl_handle *nl_handle;
@@ -87,59 +74,38 @@ static void nl80211_cleanup(struct nl80211_state *state)
        nl_handle_destroy(state->nl_handle);
 }
 
-static int reg_handler(struct nl_msg *msg, void *arg)
+static int reg_handler(struct nl_msg __attribute__((unused)) *msg,
+                       void __attribute__((unused)) *arg)
 {
-       printf("=== reg_handler() called\n");
        return NL_SKIP;
 }
 
-static int wait_handler(struct nl_msg *msg, void *arg)
+static int wait_handler(struct nl_msg __attribute__((unused)) *msg, void *arg)
 {
        int *finished = arg;
        *finished = 1;
        return NL_STOP;
 }
 
-
-static int error_handler(struct sockaddr_nl *nla, struct nlmsgerr *err, void *arg)
+static int error_handler(struct sockaddr_nl __attribute__((unused)) *nla,
+                           struct nlmsgerr *err,
+                           void __attribute__((unused)) *arg)
 {
        fprintf(stderr, "nl80211 error %d\n", err->error);
        exit(err->error);
 }
 
-int isalpha_upper(char letter)
+static int is_valid_regdom(const char *alpha2)
 {
-       if (letter >= 65 && letter <= 90)
-               return 1;
-       return 0;
-}
+       if (strlen(alpha2) != 2)
+               return 0;
 
-static int is_alpha2(char *alpha2)
-{
-       if (isalpha_upper(alpha2[0]) && isalpha_upper(alpha2[1]))
-               return 1;
-       return 0;
-}
+       if (!is_alpha2(alpha2) && !is_world_regdom(alpha2))
+               return 0;
 
-static int is_world_regdom(char *alpha2)
-{
-       /* ASCII 0 */
-       if (alpha2[0] == 48 && alpha2[1] == 48)
-               return 1;
-       return 0;
+       return 1;
 }
 
-static void *get_file_ptr(__u8 *db, int dblen, int structlen, __be32 ptr)
-{
-       __u32 p = ntohl(ptr);
-
-       if (p > dblen - structlen) {
-               fprintf(stderr, "Invalid database file, bad pointer!\n");
-               exit(3);
-       }
-
-       return (void *)(db + p);
-}
 
 static int put_reg_rule(__u8 *db, int dblen, __be32 ruleptr, struct nl_msg *msg)
 {
@@ -147,9 +113,9 @@ static int put_reg_rule(__u8 *db, int dblen, __be32 ruleptr, struct nl_msg *msg)
        struct regdb_file_freq_range *freq;
        struct regdb_file_power_rule *power;
 
-       rule    = get_file_ptr(db, dblen, sizeof(*rule), ruleptr);
-       freq    = get_file_ptr(db, dblen, sizeof(*freq), rule->freq_range_ptr);
-       power   = get_file_ptr(db, dblen, sizeof(*power), rule->power_rule_ptr);
+       rule  = crda_get_file_ptr(db, dblen, sizeof(*rule), ruleptr);
+       freq  = crda_get_file_ptr(db, dblen, sizeof(*freq), rule->freq_range_ptr);
+       power = crda_get_file_ptr(db, dblen, sizeof(*power), rule->power_rule_ptr);
 
        NLA_PUT_U32(msg, NL80211_ATTR_REG_RULE_FLAGS,           ntohl(rule->flags));
        NLA_PUT_U32(msg, NL80211_ATTR_FREQ_RANGE_START,         ntohl(freq->start_freq));
@@ -173,6 +139,7 @@ int main(int argc, char **argv)
        struct regdb_file_reg_country *countries;
        int dblen, siglen, num_countries, i, j, r;
        char alpha2[2];
+       char *env_country;
        struct nl80211_state nlstate;
        struct nl_cb *cb = NULL;
        struct nl_msg *msg;
@@ -184,46 +151,36 @@ int main(int argc, char **argv)
        struct nlattr *nl_reg_rules;
        int num_rules;
 
-#ifdef USE_OPENSSL
-       RSA *rsa;
-       __u8 hash[SHA_DIGEST_LENGTH];
-       int ok = 0;
-#endif
-#ifdef USE_GCRYPT
-       gcry_mpi_t mpi_e, mpi_n;
-       gcry_sexp_t rsa, signature, data;
-       __u8 hash[20];
-       int ok = 0;
-#endif
-       char *regdb = "/usr/lib/crda/regulatory.bin";
-
-       if (argc != 2) {
-               fprintf(stderr, "Usage: %s <ISO-3166 alpha2 country code>\n", argv[0]);
+       const char regdb[] = "/usr/lib/crda/regulatory.bin";
+
+       if (argc != 1) {
+               fprintf(stderr, "Usage: %s\n", argv[0]);
                return -EINVAL;
        }
-       
-       if (!is_alpha2(argv[1]) && !is_world_regdom(argv[1])) {
-               fprintf(stderr, "Invalid alpha2\n");
+
+       env_country = getenv("COUNTRY");
+       if (!env_country) {
+               fprintf(stderr, "COUNTRY environment variable not set.\n");
                return -EINVAL;
        }
 
-       memcpy(alpha2, argv[1], 2);
+       if (!is_valid_regdom(env_country)) {
+               fprintf(stderr, "COUNTRY environment variable must be an "
+                       "ISO ISO 3166-1-alpha-2 (uppercase) or 00\n");
+               return -EINVAL;
+       }
 
-       r = nl80211_init(&nlstate);
-       if (r)
-               return -EIO;
+       memcpy(alpha2, env_country, 2);
 
        fd = open(regdb, O_RDONLY);
        if (fd < 0) {
                perror("failed to open db file");
-               r = -ENOENT;
-               goto out;
+               return -ENOENT;
        }
 
        if (fstat(fd, &stat)) {
                perror("failed to fstat db file");
-               r = -EIO;
-               goto out;
+               return -EIO;
        }
 
        dblen = stat.st_size;
@@ -231,125 +188,39 @@ int main(int argc, char **argv)
        db = mmap(NULL, dblen, PROT_READ, MAP_PRIVATE, fd, 0);
        if (db == MAP_FAILED) {
                perror("failed to mmap db file");
-               r = -EIO;
-               goto out;
+               return -EIO;
        }
 
-       header = get_file_ptr(db, dblen, sizeof(*header), 0);
+       /* db file starts with a struct regdb_file_header */
+       header = crda_get_file_ptr(db, dblen, sizeof(*header), 0);
 
        if (ntohl(header->magic) != REGDB_MAGIC) {
                fprintf(stderr, "Invalid database magic\n");
-               r = -EINVAL;
-               goto out;
+               return -EINVAL;
        }
 
        if (ntohl(header->version) != REGDB_VERSION) {
                fprintf(stderr, "Invalid database version\n");
-               r = -EINVAL;
-               goto out;
+               return -EINVAL;
        }
 
        siglen = ntohl(header->signature_length);
        /* adjust dblen so later sanity checks don't run into the signature */
        dblen -= siglen;
 
-       if (dblen <= sizeof(*header)) {
+       if (dblen <= (int)sizeof(*header)) {
                fprintf(stderr, "Invalid signature length %d\n", siglen);
-               r = -EINVAL;
-               goto out;
+               return -EINVAL;
        }
 
        /* verify signature */
-#ifdef USE_OPENSSL
-       rsa = RSA_new();
-       if (!rsa) {
-               fprintf(stderr, "Failed to create RSA key\n");
-               r = -EINVAL;
-               goto out;
-       }
-
-       if (SHA1(db, dblen, hash) != hash) {
-               fprintf(stderr, "Failed to calculate SHA sum\n");
-               r = -EINVAL;
-               goto out;
-       }
-
-       for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
-               rsa->e = &keys[i].e;
-               rsa->n = &keys[i].n;
-
-               if (RSA_size(rsa) != siglen)
-                       continue;
-
-               ok = RSA_verify(NID_sha1, hash, SHA_DIGEST_LENGTH,
-                               db + dblen, siglen, rsa) == 1;
-               if (ok)
-                       break;
-       }
-
-       if (!ok) {
-               fprintf(stderr, "Database signature wrong\n");
-               r = -EINVAL;
-               goto out;
-       }
-
-       rsa->e = NULL;
-       rsa->n = NULL;
-       RSA_free(rsa);
-
-       BN_print_fp(stdout, &keys[0].n);
-
-#endif
-
-#ifdef USE_GCRYPT
-       /* hash the db */
-       gcry_md_hash_buffer(GCRY_MD_SHA1, hash, db, dblen);
-
-       if (gcry_sexp_build(&data, NULL, "(data (flags pkcs1) (hash sha1 %b))",
-                           20, hash)) {
-               fprintf(stderr, "failed to build data expression\n");
-               return 2;
-       }
-
-       if (gcry_sexp_build(&signature, NULL, "(sig-val (rsa (s %b)))",
-                           siglen, db + dblen)) {
-               fprintf(stderr, "failed to build signature expression\n");
-               return 2;
-       }
-
-       for (i = 0; i < sizeof(keys)/sizeof(keys[0]); i++) {
-               if (gcry_mpi_scan(&mpi_e, GCRYMPI_FMT_USG,
-                                 keys[0].e, keys[0].len_e, NULL) ||
-                   gcry_mpi_scan(&mpi_n, GCRYMPI_FMT_USG,
-                                 keys[0].n, keys[0].len_n, NULL)) {
-                       fprintf(stderr, "failed to convert numbers\n");
-                       return 2;
-               }
-
-               if (gcry_sexp_build(&rsa, NULL,
-                                   "(public-key (rsa (n %m) (e %m)))",
-                                   mpi_n, mpi_e)) {
-                       fprintf(stderr, "failed to build rsa key\n");
-                       return 2;
-               }
-
-               if (!gcry_pk_verify(signature, data, rsa)) {
-                       ok = 1;
-                       break;
-               }
-       }
-
-       if (!ok) {
-               fprintf(stderr, "Database signature wrong\n");
-               return 2;
-       }
-
-#endif
+       if (!crda_verify_db_signature(db, dblen, siglen))
+               return -EINVAL;
 
        num_countries = ntohl(header->reg_country_num);
-       countries = get_file_ptr(db, dblen,
-                                sizeof(struct regdb_file_reg_country) * num_countries,
-                                header->reg_country_ptr);
+       countries = crda_get_file_ptr(db, dblen,
+                       sizeof(struct regdb_file_reg_country) * num_countries,
+                       header->reg_country_ptr);
 
        for (i = 0; i < num_countries; i++) {
                country = countries + i;
@@ -360,26 +231,31 @@ int main(int argc, char **argv)
        }
 
        if (!found_country) {
-               fprintf(stderr, "failed to find a country match in regulatory database\n");
+               fprintf(stderr, "No country match in regulatory database.\n");
                return -1;
        }
 
+       r = nl80211_init(&nlstate);
+       if (r)
+               return -EIO;
+
        msg = nlmsg_alloc();
        if (!msg) {
-               fprintf(stderr, "failed to allocate netlink msg\n");
-               return -1;
+               fprintf(stderr, "Failed to allocate netlink message.\n");
+               r = -1;
+               goto out;
        }
 
        genlmsg_put(msg, 0, 0, genl_family_get_id(nlstate.nl80211), 0,
                0, NL80211_CMD_SET_REG, 0);
 
-
-       rcoll = get_file_ptr(db, dblen, sizeof(*rcoll), country->reg_collection_ptr);
+       rcoll = crda_get_file_ptr(db, dblen, sizeof(*rcoll),
+                               country->reg_collection_ptr);
        num_rules = ntohl(rcoll->reg_rule_num);
        /* re-get pointer with sanity checking for num_rules */
-       rcoll = get_file_ptr(db, dblen,
-                            sizeof(*rcoll) + num_rules * sizeof(__be32),
-                            country->reg_collection_ptr);
+       rcoll = crda_get_file_ptr(db, dblen,
+                               sizeof(*rcoll) + num_rules * sizeof(__be32),
+                               country->reg_collection_ptr);
 
        NLA_PUT_STRING(msg, NL80211_ATTR_REG_ALPHA2, (char *) country->alpha2);
 
@@ -411,7 +287,7 @@ int main(int argc, char **argv)
        r = nl_send_auto_complete(nlstate.nl_handle, msg);
 
        if (r < 0) {
-               fprintf(stderr, "failed to send regulatory request: %d\n", r);
+               fprintf(stderr, "Failed to send regulatory request: %d\n", r);
                goto cb_out;
        }
 
@@ -422,7 +298,8 @@ int main(int argc, char **argv)
        if (!finished) {
                r = nl_wait_for_ack(nlstate.nl_handle);
                if (r < 0) {
-                       fprintf(stderr, "failed to set regulatory domain: %d\n", r);
+                       fprintf(stderr, "Failed to set regulatory domain: "
+                               "%d\n", r);
                        goto cb_out;
                }
        }