dnsproxy: Do not copy two extra bytes if using TCP
[framework/connectivity/connman.git] / src / dnsproxy.c
index cf907bc..21512c6 100644 (file)
@@ -35,9 +35,6 @@
 
 #include <glib.h>
 
-#define CONNMAN_API_SUBJECT_TO_CHANGE
-#include <connman/ondemand.h>
-
 #include "connman.h"
 
 #if __BYTE_ORDER == __LITTLE_ENDIAN
@@ -115,17 +112,23 @@ struct request_data {
        gpointer name;
        gpointer resp;
        gsize resplen;
+       struct listener_data *ifdata;
+       gboolean append_domain;
+};
+
+struct listener_data {
+       char *ifname;
+       GIOChannel *udp_listener_channel;
+       guint udp_listener_watch;
+       GIOChannel *tcp_listener_channel;
+       guint tcp_listener_watch;
 };
 
 static GSList *server_list = NULL;
 static GSList *request_list = NULL;
 static GSList *request_pending_list = NULL;
 static guint16 request_id = 0x0000;
-
-static GIOChannel *udp_listener_channel = NULL;
-static guint udp_listener_watch = 0;
-static GIOChannel *tcp_listener_channel = NULL;
-static guint tcp_listener_watch = 0;
+static GHashTable *listener_table = NULL;
 
 static int protocol_offset(int protocol)
 {
@@ -167,7 +170,13 @@ static struct server_data *find_server(const char *interface,
        for (list = server_list; list; list = list->next) {
                struct server_data *data = list->data;
 
-               if (data->interface == NULL || data->server == NULL)
+               if (interface == NULL && data->interface == NULL &&
+                               g_str_equal(data->server, server) == TRUE &&
+                               data->protocol == protocol)
+                       return data;
+
+               if (interface == NULL ||
+                               data->interface == NULL || data->server == NULL)
                        continue;
 
                if (g_str_equal(data->interface, interface) == TRUE &&
@@ -207,27 +216,37 @@ static void send_response(int sk, unsigned char *buf, int len,
        hdr->arcount = 0;
 
        err = sendto(sk, buf, len, 0, to, tolen);
+       if (err < 0) {
+               connman_error("Failed to send DNS response: %s",
+                               strerror(errno));
+               return;
+       }
 }
 
 static gboolean request_timeout(gpointer user_data)
 {
        struct request_data *req = user_data;
+       struct listener_data *ifdata;
 
        DBG("id 0x%04x", req->srcid);
 
        if (req == NULL)
                return FALSE;
 
+       ifdata = req->ifdata;
+
        request_list = g_slist_remove(request_list, req);
        req->numserv--;
 
        if (req->resplen > 0 && req->resp != NULL) {
                int sk, err;
 
-               sk = g_io_channel_unix_get_fd(udp_listener_channel);
+               sk = g_io_channel_unix_get_fd(ifdata->udp_listener_channel);
 
                err = sendto(sk, req->resp, req->resplen, 0,
-                            &req->sa, req->sa_len);
+                                               &req->sa, req->sa_len);
+               if (err < 0)
+                       return FALSE;
        } else if (req->request && req->numserv == 0) {
                struct domain_hdr *hdr;
 
@@ -235,16 +254,17 @@ static gboolean request_timeout(gpointer user_data)
                        hdr = (void *) (req->request + 2);
                        hdr->id = req->srcid;
                        send_response(req->client_sk, req->request,
-                                       req->request_len, NULL, 0, IPPROTO_TCP);
+                               req->request_len, NULL, 0, IPPROTO_TCP);
 
                } else if (req->protocol == IPPROTO_UDP) {
                        int sk;
 
                        hdr = (void *) (req->request);
                        hdr->id = req->srcid;
-                       sk = g_io_channel_unix_get_fd(udp_listener_channel);
+                       sk = g_io_channel_unix_get_fd(
+                                               ifdata->udp_listener_channel);
                        send_response(sk, req->request, req->request_len,
-                                     &req->sa, req->sa_len, IPPROTO_UDP);
+                                       &req->sa, req->sa_len, IPPROTO_UDP);
                }
        }
 
@@ -259,6 +279,7 @@ static int append_query(unsigned char *buf, unsigned int size,
 {
        unsigned char *ptr = buf;
        char *offset;
+       int len;
 
        DBG("query %s domain %s", query, domain);
 
@@ -268,11 +289,12 @@ static int append_query(unsigned char *buf, unsigned int size,
 
                tmp = strchr(offset, '.');
                if (tmp == NULL) {
-                       if (strlen(offset) == 0)
+                       len = strlen(offset);
+                       if (len == 0)
                                break;
-                       *ptr = strlen(offset);
-                       memcpy(ptr + 1, offset, strlen(offset));
-                       ptr += strlen(offset) + 1;
+                       *ptr = len;
+                       memcpy(ptr + 1, offset, len);
+                       ptr += len + 1;
                        break;
                }
 
@@ -289,11 +311,12 @@ static int append_query(unsigned char *buf, unsigned int size,
 
                tmp = strchr(offset, '.');
                if (tmp == NULL) {
-                       if (strlen(offset) == 0)
+                       len = strlen(offset);
+                       if (len == 0)
                                break;
-                       *ptr = strlen(offset);
-                       memcpy(ptr + 1, offset, strlen(offset));
-                       ptr += strlen(offset) + 1;
+                       *ptr = len;
+                       memcpy(ptr + 1, offset, len);
+                       ptr += len + 1;
                        break;
                }
 
@@ -314,6 +337,7 @@ static int ns_resolv(struct server_data *server, struct request_data *req,
 {
        GList *list;
        int sk, err;
+       char *dot, *lookup = (char *) name;
 
        sk = g_io_channel_unix_get_fd(server->channel);
 
@@ -321,6 +345,14 @@ static int ns_resolv(struct server_data *server, struct request_data *req,
 
        req->numserv++;
 
+       /* If we have more than one dot, we don't add domains */
+       dot = strchr(lookup, '.');
+       if (dot != NULL && dot != lookup + strlen(lookup) - 1)
+               return 0;
+
+       if (server->domains != NULL && server->domains->data != NULL)
+               req->append_domain = TRUE;
+
        for (list = server->domains; list; list = list->next) {
                char *domain;
                unsigned char alt[1024];
@@ -355,16 +387,18 @@ static int ns_resolv(struct server_data *server, struct request_data *req,
 
                memcpy(alt + offset + altlen,
                        request + offset + altlen - domlen,
-                               req->request_len - altlen + domlen);
+                               req->request_len - altlen - offset + domlen);
 
                if (server->protocol == IPPROTO_TCP) {
-                       int req_len = req->request_len + domlen - 1;
+                       int req_len = req->request_len + domlen - 2;
 
                        alt[0] = (req_len >> 8) & 0xff;
                        alt[1] = req_len & 0xff;
                }
 
-               err = send(sk, alt, req->request_len + domlen + 1, 0);
+               err = send(sk, alt, req->request_len + domlen, 0);
+               if (err < 0)
+                       return -EIO;
 
                req->numserv++;
        }
@@ -377,6 +411,7 @@ static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
        struct domain_hdr *hdr;
        struct request_data *req;
        int dns_id, sk, err, offset = protocol_offset(protocol);
+       struct listener_data *ifdata;
 
        if (offset < 0)
                return offset;
@@ -392,12 +427,43 @@ static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
 
        DBG("id 0x%04x rcode %d", hdr->id, hdr->rcode);
 
+       ifdata = req->ifdata;
+
        reply[offset] = req->srcid & 0xff;
        reply[offset + 1] = req->srcid >> 8;
 
        req->numresp++;
 
        if (hdr->rcode == 0 || req->resp == NULL) {
+
+               /*
+                * If the domain name was append
+                * remove it before forwarding the reply.
+                */
+               if (req->append_domain == TRUE) {
+                       unsigned char *ptr;
+                       uint8_t host_len;
+                       unsigned int domain_len;
+
+                       /*
+                        * ptr points to the first char of the hostname.
+                        * ->hostname.domain.net
+                        */
+                       ptr = reply + offset + sizeof(struct domain_hdr);
+                       host_len = *ptr;
+                       domain_len = strlen((const char *)ptr) - host_len - 1;
+
+                       /*
+                        * remove the domain name and replaced it by the end
+                        * of reply.
+                        */
+                       memmove(ptr + host_len + 1,
+                               ptr + host_len + domain_len + 1,
+                               reply_len - (ptr - reply + domain_len));
+
+                       reply_len = reply_len - domain_len;
+               }
+
                g_free(req->resp);
                req->resplen = 0;
 
@@ -418,7 +484,7 @@ static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol)
        request_list = g_slist_remove(request_list, req);
 
        if (protocol == IPPROTO_UDP) {
-               sk = g_io_channel_unix_get_fd(udp_listener_channel);
+               sk = g_io_channel_unix_get_fd(ifdata->udp_listener_channel);
                err = sendto(sk, req->resp, req->resplen, 0,
                             &req->sa, req->sa_len);
        } else {
@@ -485,6 +551,8 @@ static gboolean udp_server_event(GIOChannel *channel, GIOCondition condition,
                return TRUE;
 
        err = forward_dns_reply(buf, len, IPPROTO_UDP);
+       if (err < 0)
+               return TRUE;
 
        return TRUE;
 }
@@ -532,7 +600,7 @@ hangup:
                        hdr = (void *) (req->request + 2);
                        hdr->id = req->srcid;
                        send_response(req->client_sk, req->request,
-                                       req->request_len, NULL, 0, IPPROTO_TCP);
+                               req->request_len, NULL, 0, IPPROTO_TCP);
 
                        request_list = g_slist_remove(request_list, req);
                }
@@ -833,6 +901,9 @@ static void append_domain(const char *interface, const char *domain)
 
        DBG("interface %s domain %s", interface, domain);
 
+       if (domain == NULL)
+               return;
+
        for (list = server_list; list; list = list->next) {
                struct server_data *data = list->data;
                GList *dom_list;
@@ -862,7 +933,7 @@ static void append_domain(const char *interface, const char *domain)
        }
 }
 
-static int dnsproxy_append(const char *interface, const char *domain,
+int __connman_dnsproxy_append(const char *interface, const char *domain,
                                                        const char *server)
 {
        struct server_data *data;
@@ -906,7 +977,7 @@ static void remove_server(const char *interface, const char *domain,
        destroy_server(data);
 }
 
-static int dnsproxy_remove(const char *interface, const char *domain,
+int __connman_dnsproxy_remove(const char *interface, const char *domain,
                                                        const char *server)
 {
        DBG("interface %s server %s", interface, server);
@@ -923,7 +994,7 @@ static int dnsproxy_remove(const char *interface, const char *domain,
        return 0;
 }
 
-static void dnsproxy_flush(void)
+void __connman_dnsproxy_flush(void)
 {
        GSList *list;
 
@@ -941,14 +1012,6 @@ static void dnsproxy_flush(void)
        }
 }
 
-static struct connman_resolver dnsproxy_resolver = {
-       .name           = "dnsproxy",
-       .priority       = CONNMAN_RESOLVER_PRIORITY_HIGH,
-       .append         = dnsproxy_append,
-       .remove         = dnsproxy_remove,
-       .flush          = dnsproxy_flush,
-};
-
 static void dnsproxy_offline_mode(connman_bool_t enabled)
 {
        GSList *list;
@@ -1091,13 +1154,14 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
        struct sockaddr_in6 client_addr;
        socklen_t client_addr_len = sizeof(client_addr);
        GSList *list;
+       struct listener_data *ifdata = user_data;
 
        DBG("condition 0x%x", condition);
 
        if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
-               if (tcp_listener_watch > 0)
-                       g_source_remove(tcp_listener_watch);
-               tcp_listener_watch = 0;
+               if (ifdata->tcp_listener_watch > 0)
+                       g_source_remove(ifdata->tcp_listener_watch);
+               ifdata->tcp_listener_watch = 0;
 
                connman_error("Error with TCP listener channel");
 
@@ -1109,7 +1173,7 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
        client_sk = accept(sk, (void *)&client_addr, &client_addr_len);
        if (client_sk < 0) {
                connman_error("Accept failure on TCP listener");
-               tcp_listener_watch = 0;
+               ifdata->tcp_listener_watch = 0;
                return FALSE;
        }
 
@@ -1120,8 +1184,7 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
        DBG("Received %d bytes (id 0x%04x)", len, buf[2] | buf[3] << 8);
 
        err = parse_request(buf + 2, len - 2, query, sizeof(query));
-       if (err < 0 || (g_slist_length(server_list) == 0 &&
-                               connman_ondemand_connected())) {
+       if (err < 0 || (g_slist_length(server_list) == 0)) {
                send_response(client_sk, buf, len, NULL, 0, IPPROTO_TCP);
                return TRUE;
        }
@@ -1148,6 +1211,8 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
        buf[3] = req->dstid >> 8;
 
        req->numserv = 0;
+       req->ifdata = (struct listener_data *) ifdata;
+       req->append_domain = FALSE;
        request_list = g_slist_append(request_list, req);
 
        for (list = server_list; list; list = list->next) {
@@ -1213,10 +1278,11 @@ static gboolean udp_listener_event(GIOChannel *channel, GIOCondition condition,
        struct sockaddr_in6 client_addr;
        socklen_t client_addr_len = sizeof(client_addr);
        int sk, err, len;
+       struct listener_data *ifdata = user_data;
 
        if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
                connman_error("Error with UDP listener channel");
-               udp_listener_watch = 0;
+               ifdata->udp_listener_watch = 0;
                return FALSE;
        }
 
@@ -1231,10 +1297,9 @@ static gboolean udp_listener_event(GIOChannel *channel, GIOCondition condition,
        DBG("Received %d bytes (id 0x%04x)", len, buf[0] | buf[1] << 8);
 
        err = parse_request(buf, len, query, sizeof(query));
-       if (err < 0 || (g_slist_length(server_list) == 0 &&
-                               connman_ondemand_connected())) {
+       if (err < 0 || (g_slist_length(server_list) == 0)) {
                send_response(sk, buf, len, (void *)&client_addr,
-                             client_addr_len, IPPROTO_UDP);
+                               client_addr_len, IPPROTO_UDP);
                return TRUE;
        }
 
@@ -1259,45 +1324,19 @@ static gboolean udp_listener_event(GIOChannel *channel, GIOCondition condition,
        buf[0] = req->dstid & 0xff;
        buf[1] = req->dstid >> 8;
 
-       if (!connman_ondemand_connected()) {
-               DBG("Starting on demand connection");
-               /*
-                * We're not connected, let's queue the request and start
-                * an on-demand connection.
-                */
-               req->request = g_try_malloc0(req->request_len);
-               if (req->request == NULL)
-                       return TRUE;
-
-               memcpy(req->request, buf, req->request_len);
-
-               req->name = g_try_malloc0(sizeof(query));
-               if (req->name == NULL) {
-                       g_free(req->request);
-                       return TRUE;
-               }
-               memcpy(req->name, query, sizeof(query));
-
-               request_pending_list = g_slist_append(request_pending_list,
-                                                                       req);
-
-               connman_ondemand_start("", 300);
-
-               return TRUE;
-       }
-
-
        req->numserv = 0;
+       req->ifdata = (struct listener_data *) ifdata;
        req->timeout = g_timeout_add_seconds(5, request_timeout, req);
+       req->append_domain = FALSE;
        request_list = g_slist_append(request_list, req);
 
        return resolv(req, buf, query);
 }
 
-static int create_dns_listener(int protocol)
+static int create_dns_listener(int protocol, struct listener_data *ifdata)
 {
        GIOChannel *channel;
-       const char *ifname = "lo", *proto;
+       const char *proto;
        union {
                struct sockaddr sa;
                struct sockaddr_in6 sin6;
@@ -1307,17 +1346,18 @@ static int create_dns_listener(int protocol)
        int sk, type, v6only = 0;
        int family = AF_INET6;
 
-       DBG("");
+
+       DBG("interface %s", ifdata->ifname);
 
        switch (protocol) {
        case IPPROTO_UDP:
                proto = "UDP";
-               type = SOCK_DGRAM;
+               type = SOCK_DGRAM | SOCK_CLOEXEC;
                break;
 
        case IPPROTO_TCP:
                proto = "TCP";
-               type = SOCK_STREAM;
+               type = SOCK_STREAM | SOCK_CLOEXEC;
                break;
 
        default:
@@ -1336,7 +1376,8 @@ static int create_dns_listener(int protocol)
        }
 
        if (setsockopt(sk, SOL_SOCKET, SO_BINDTODEVICE,
-                                       ifname, strlen(ifname) + 1) < 0) {
+                                       ifdata->ifname,
+                                       strlen(ifdata->ifname) + 1) < 0) {
                connman_error("Failed to bind %s listener interface", proto);
                close(sk);
                return -EIO;
@@ -1387,62 +1428,64 @@ static int create_dns_listener(int protocol)
        g_io_channel_set_close_on_unref(channel, TRUE);
 
        if (protocol == IPPROTO_TCP) {
-               tcp_listener_channel = channel;
-               tcp_listener_watch = g_io_add_watch(channel,
-                                       G_IO_IN, tcp_listener_event, NULL);
+               ifdata->tcp_listener_channel = channel;
+               ifdata->tcp_listener_watch = g_io_add_watch(channel,
+                               G_IO_IN, tcp_listener_event, (gpointer) ifdata);
        } else {
-               udp_listener_channel = channel;
-               udp_listener_watch = g_io_add_watch(channel,
-                                       G_IO_IN, udp_listener_event, NULL);
+               ifdata->udp_listener_channel = channel;
+               ifdata->udp_listener_watch = g_io_add_watch(channel,
+                               G_IO_IN, udp_listener_event, (gpointer) ifdata);
        }
 
        return 0;
 }
 
-static void destroy_udp_listener(void)
+static void destroy_udp_listener(struct listener_data *ifdata)
 {
-       DBG("");
+       DBG("interface %s", ifdata->ifname);
 
-       if (udp_listener_watch > 0)
-               g_source_remove(udp_listener_watch);
+       if (ifdata->udp_listener_watch > 0)
+               g_source_remove(ifdata->udp_listener_watch);
 
-       g_io_channel_unref(udp_listener_channel);
+       g_io_channel_unref(ifdata->udp_listener_channel);
 }
 
-static void destroy_tcp_listener(void)
+static void destroy_tcp_listener(struct listener_data *ifdata)
 {
-       DBG("");
+       DBG("interface %s", ifdata->ifname);
 
-       if (tcp_listener_watch > 0)
-               g_source_remove(tcp_listener_watch);
+       if (ifdata->tcp_listener_watch > 0)
+               g_source_remove(ifdata->tcp_listener_watch);
 
-       g_io_channel_unref(tcp_listener_channel);
+       g_io_channel_unref(ifdata->tcp_listener_channel);
 }
 
-static int create_listener(void)
+static int create_listener(struct listener_data *ifdata)
 {
        int err;
 
-       err = create_dns_listener(IPPROTO_UDP);
+       err = create_dns_listener(IPPROTO_UDP, ifdata);
        if (err < 0)
                return err;
 
-       err = create_dns_listener(IPPROTO_TCP);
+       err = create_dns_listener(IPPROTO_TCP, ifdata);
        if (err < 0) {
-               destroy_udp_listener();
+               destroy_udp_listener(ifdata);
                return err;
        }
 
-       connman_resolver_append("lo", NULL, "127.0.0.1");
+       if (g_strcmp0(ifdata->ifname, "lo") == 0)
+               __connman_resolvfile_append("lo", NULL, "127.0.0.1");
 
        return 0;
 }
 
-static void destroy_listener(void)
+static void destroy_listener(struct listener_data *ifdata)
 {
        GSList *list;
 
-       connman_resolver_remove_all("lo");
+       if (g_strcmp0(ifdata->ifname, "lo") == 0)
+               __connman_resolvfile_remove("lo", NULL, "127.0.0.1");
 
        for (list = request_pending_list; list; list = list->next) {
                struct request_data *req = list->data;
@@ -1476,42 +1519,94 @@ static void destroy_listener(void)
        g_slist_free(request_list);
        request_list = NULL;
 
-       destroy_tcp_listener();
-       destroy_udp_listener();
+       destroy_tcp_listener(ifdata);
+       destroy_udp_listener(ifdata);
 }
 
-int __connman_dnsproxy_init(void)
+int __connman_dnsproxy_add_listener(const char *interface)
 {
+       struct listener_data *ifdata;
        int err;
 
-       err = create_listener();
-       if (err < 0)
+       DBG("interface %s", interface);
+
+       if (g_hash_table_lookup(listener_table, interface) != NULL)
+               return 0;
+
+       ifdata = g_try_new0(struct listener_data, 1);
+       if (ifdata == NULL)
+               return -ENOMEM;
+
+       ifdata->ifname = g_strdup(interface);
+       ifdata->udp_listener_channel = NULL;
+       ifdata->udp_listener_watch = 0;
+       ifdata->tcp_listener_channel = NULL;
+       ifdata->tcp_listener_watch = 0;
+
+       err = create_listener(ifdata);
+       if (err < 0) {
+               connman_error("Couldn't create listener for %s err %d",
+                               interface, err);
+               g_free(ifdata->ifname);
+               g_free(ifdata);
                return err;
+       }
+       g_hash_table_insert(listener_table, ifdata->ifname, ifdata);
+       return 0;
+}
+
+void __connman_dnsproxy_remove_listener(const char *interface)
+{
+       struct listener_data *ifdata;
 
-       err = connman_resolver_register(&dnsproxy_resolver);
+       DBG("interface %s", interface);
+
+       ifdata = g_hash_table_lookup(listener_table, interface);
+       if (ifdata == NULL)
+               return;
+
+       destroy_listener(ifdata);
+
+       g_hash_table_remove(listener_table, interface);
+}
+
+static void remove_listener(gpointer key, gpointer value, gpointer user_data)
+{
+       __connman_dnsproxy_remove_listener(key);
+}
+
+int __connman_dnsproxy_init(void)
+{
+       int err;
+
+       DBG("");
+
+       listener_table = g_hash_table_new_full(g_str_hash, g_str_equal,
+                                                       g_free, g_free);
+       err = __connman_dnsproxy_add_listener("lo");
        if (err < 0)
-               goto destroy;
+               return err;
 
        err = connman_notifier_register(&dnsproxy_notifier);
        if (err < 0)
-               goto unregister;
+               goto destroy;
 
        return 0;
 
-unregister:
-       connman_resolver_unregister(&dnsproxy_resolver);
-
 destroy:
-       destroy_listener();
+       __connman_dnsproxy_remove_listener("lo");
+       g_hash_table_destroy(listener_table);
 
        return err;
 }
 
 void __connman_dnsproxy_cleanup(void)
 {
+       DBG("");
+
        connman_notifier_unregister(&dnsproxy_notifier);
 
-       connman_resolver_unregister(&dnsproxy_resolver);
+       g_hash_table_foreach(listener_table, remove_listener, NULL);
 
-       destroy_listener();
+       g_hash_table_destroy(listener_table);
 }