dnsproxy: Handle partial TCP messages from client
author Jukka Rissanen <jukka.rissanen@linux.intel.com>
Fri, 5 Apr 2013 09:33:08 +0000 (12:33 +0300)
committer Patrik Flykt <patrik.flykt@linux.intel.com>
Fri, 5 Apr 2013 13:05:53 +0000 (16:05 +0300)
We were not correctly handling partial TCP messages sent by a client.
This meant that ConnMan would block if a client using TCP did not
send a full DNS request.

When a new TCP client connects in tcp_listener_event(), a new
TCP channel is created so that a partial client message can be
handled if the whole message is not yet available. The partial TCP
message is pieced together in tcp_client_event(). Once the whole
TCP message has been received, it is processed in read_tcp_data().
The client has to send the complete DNS request within two seconds;
otherwise ConnMan closes the connection.

src/dnsproxy.c

index 8366fa4..388ea4b 100644 (file)
@@ -32,6 +32,7 @@
 #include <netinet/in.h>
 #include <sys/types.h>
 #include <sys/socket.h>
+#include <fcntl.h>
 #include <netdb.h>
 #include <resolv.h>
 #include <gweb/gresolv.h>
@@ -136,6 +137,20 @@ struct listener_data {
        guint tcp6_listener_watch;
 };
 
+/*
+ * The TCP client requires some extra handling as we need to
+ * be prepared to receive also partial DNS requests.
+ */
+struct tcp_partial_client_data {
+       int family;
+       struct listener_data *ifdata;
+       GIOChannel *channel;
+       guint watch;
+       unsigned char *buf;
+       unsigned int buf_end;
+       guint timeout;
+};
+
 struct cache_data {
        time_t inserted;
        time_t valid_until;
@@ -168,6 +183,11 @@ struct domain_rr {
 } __attribute__ ((packed));
 
 /*
+ * Max length of the DNS TCP packet.
+ */
+#define TCP_MAX_BUF_LEN 4096
+
+/*
  * We limit how long the cached DNS entry stays in the cache.
  * By default the TTL (time-to-live) of the DNS response is used
  * when setting the cache entry life time. The value is in seconds.
@@ -194,6 +214,7 @@ static GSList *server_list = NULL;
 static GSList *request_list = NULL;
 static GHashTable *listener_table = NULL;
 static time_t next_refresh;
+static GHashTable *partial_tcp_req_table;
 
 static guint16 get_id()
 {
@@ -516,12 +537,19 @@ static gboolean request_timeout(gpointer user_data)
        if (req->resplen > 0 && req->resp != NULL) {
                int sk, err;
 
-               sk = get_req_udp_socket(req);
-               if (sk < 0)
-                       return FALSE;
+               if (req->protocol == IPPROTO_UDP) {
+                       sk = get_req_udp_socket(req);
+                       if (sk < 0)
+                               return FALSE;
 
-               err = sendto(sk, req->resp, req->resplen, MSG_NOSIGNAL,
-                                               &req->sa, req->sa_len);
+                       err = sendto(sk, req->resp, req->resplen, MSG_NOSIGNAL,
+                               &req->sa, req->sa_len);
+               } else {
+                       sk = req->client_sk;
+                       err = send(sk, req->resp, req->resplen, MSG_NOSIGNAL);
+                       if (err < 0)
+                               close(sk);
+               }
                if (err < 0)
                        return FALSE;
        } else if (req->request && req->numserv == 0) {
@@ -547,6 +575,16 @@ static gboolean request_timeout(gpointer user_data)
                }
        }
 
+       /*
+        * We cannot leave TCP client hanging so just kick it out
+        * if we get a request timeout from server.
+        */
+       if (req->protocol == IPPROTO_TCP) {
+               DBG("client %d removed", req->client_sk);
+               g_hash_table_remove(partial_tcp_req_table,
+                               GINT_TO_POINTER(req->client_sk));
+       }
+
        req->timeout = 0;
        destroy_request_data(req);
 
@@ -1773,7 +1811,6 @@ static int forward_dns_reply(unsigned char *reply, int reply_len, int protocol,
        } else {
                sk = req->client_sk;
                err = send(sk, req->resp, req->resplen, MSG_NOSIGNAL);
-               close(sk);
        }
 
        if (err < 0)
@@ -2553,64 +2590,95 @@ static int parse_request(unsigned char *buf, int len,
        return 0;
 }
 
-static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
-                               struct listener_data *ifdata, int family,
-                               guint *listener_watch)
+static void client_reset(struct tcp_partial_client_data *client)
 {
-       unsigned char buf[768];
-       char query[512];
+       if (client == NULL)
+               return;
+
+       if (client->channel != NULL) {
+               DBG("client %d closing",
+                       g_io_channel_unix_get_fd(client->channel));
+
+               g_io_channel_unref(client->channel);
+               client->channel = NULL;
+       }
+
+       if (client->watch > 0) {
+               g_source_remove(client->watch);
+               client->watch = 0;
+       }
+
+       if (client->timeout > 0) {
+               g_source_remove(client->timeout);
+               client->timeout = 0;
+       }
+
+       g_free(client->buf);
+       client->buf = NULL;
+
+       client->buf_end = 0;
+}
+
+static unsigned int get_msg_len(unsigned char *buf)
+{
+       return buf[0]<<8 | buf[1];
+}
+
+static gboolean read_tcp_data(struct tcp_partial_client_data *client,
+                               void *client_addr, socklen_t client_addr_len,
+                               int read_len)
+{
+       char query[TCP_MAX_BUF_LEN];
        struct request_data *req;
-       int sk, client_sk, len, err;
-       struct sockaddr_in6 client_addr6;
-       socklen_t client_addr6_len = sizeof(client_addr6);
-       struct sockaddr_in client_addr4;
-       socklen_t client_addr4_len = sizeof(client_addr4);
-       void *client_addr;
-       socklen_t *client_addr_len;
+       int client_sk, err;
+       unsigned int msg_len;
        GSList *list;
        int waiting_for_connect = FALSE, qtype = 0;
        struct cache_entry *entry;
 
-       DBG("condition 0x%x", condition);
-
-       if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
-               if (*listener_watch > 0)
-                       g_source_remove(*listener_watch);
-               *listener_watch = 0;
-
-               connman_error("Error with TCP listener channel");
+       client_sk = g_io_channel_unix_get_fd(client->channel);
 
+       if (read_len == 0) {
+               DBG("client %d closed, pending %d bytes",
+                       client_sk, client->buf_end);
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
                return FALSE;
        }
 
-       sk = g_io_channel_unix_get_fd(channel);
+       DBG("client %d received %d bytes", client_sk, read_len);
 
-       if (family == AF_INET) {
-               client_addr = &client_addr4;
-               client_addr_len = &client_addr4_len;
-       } else {
-               client_addr = &client_addr6;
-               client_addr_len = &client_addr6_len;
-       }
+       client->buf_end += read_len;
 
-       client_sk = accept(sk, client_addr, client_addr_len);
+       if (client->buf_end < 2)
+               return TRUE;
 
-       if (client_sk < 0) {
-               connman_error("Accept failure on TCP listener");
-               *listener_watch = 0;
+       msg_len = get_msg_len(client->buf);
+       if (msg_len > TCP_MAX_BUF_LEN) {
+               DBG("client %d sent too much data %d", client_sk, msg_len);
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
                return FALSE;
        }
 
-       len = recv(client_sk, buf, sizeof(buf), 0);
-       if (len < 2)
+read_another:
+       DBG("client %d msg len %d end %d past end %d", client_sk, msg_len,
+               client->buf_end, client->buf_end - (msg_len + 2));
+
+       if (client->buf_end < (msg_len + 2)) {
+               DBG("client %d still missing %d bytes",
+                       client_sk,
+                       msg_len + 2 - client->buf_end);
                return TRUE;
+       }
 
-       DBG("Received %d bytes (id 0x%04x) from %d", len,
-               buf[2] | buf[3] << 8, client_sk);
+       DBG("client %d all data %d received", client_sk, msg_len);
 
-       err = parse_request(buf + 2, len - 2, query, sizeof(query));
+       err = parse_request(client->buf + 2, msg_len,
+                       query, sizeof(query));
        if (err < 0 || (g_slist_length(server_list) == 0)) {
-               send_response(client_sk, buf, len, NULL, 0, IPPROTO_TCP);
+               send_response(client_sk, client->buf, msg_len + 2,
+                       NULL, 0, IPPROTO_TCP);
                return TRUE;
        }
 
@@ -2618,29 +2686,29 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
        if (req == NULL)
                return TRUE;
 
-       memcpy(&req->sa, client_addr, *client_addr_len);
-       req->sa_len = *client_addr_len;
+       memcpy(&req->sa, client_addr, client_addr_len);
+       req->sa_len = client_addr_len;
        req->client_sk = client_sk;
        req->protocol = IPPROTO_TCP;
-       req->family = family;
+       req->family = client->family;
 
-       req->srcid = buf[2] | (buf[3] << 8);
+       req->srcid = client->buf[2] | (client->buf[3] << 8);
        req->dstid = get_id();
        req->altid = get_id();
-       req->request_len = len;
+       req->request_len = msg_len + 2;
 
-       buf[2] = req->dstid & 0xff;
-       buf[3] = req->dstid >> 8;
+       client->buf[2] = req->dstid & 0xff;
+       client->buf[3] = req->dstid >> 8;
 
        req->numserv = 0;
-       req->ifdata = ifdata;
+       req->ifdata = client->ifdata;
        req->append_domain = FALSE;
 
        /*
         * Check if the answer is found in the cache before
         * creating sockets to the server.
         */
-       entry = cache_check(buf, &qtype, IPPROTO_TCP);
+       entry = cache_check(client->buf, &qtype, IPPROTO_TCP);
        if (entry != NULL) {
                int ttl_left = 0;
                struct cache_data *data;
@@ -2660,7 +2728,7 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
                                        req->srcid, data->answers, ttl_left);
 
                        g_free(req);
-                       return TRUE;
+                       goto out;
                } else
                        DBG("data missing, ignoring cache for this query");
        }
@@ -2680,7 +2748,8 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
 
        if (waiting_for_connect == FALSE) {
                /* No server is waiting for connect */
-               send_response(client_sk, buf, len, NULL, 0, IPPROTO_TCP);
+               send_response(client_sk, client->buf,
+                       req->request_len, NULL, 0, IPPROTO_TCP);
                g_free(req);
                return TRUE;
        }
@@ -2693,18 +2762,20 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
         */
        req->request = g_try_malloc0(req->request_len);
        if (req->request == NULL) {
-               send_response(client_sk, buf, len, NULL, 0, IPPROTO_TCP);
+               send_response(client_sk, client->buf,
+                       req->request_len, NULL, 0, IPPROTO_TCP);
                g_free(req);
-               return TRUE;
+               goto out;
        }
-       memcpy(req->request, buf, req->request_len);
+       memcpy(req->request, client->buf, req->request_len);
 
        req->name = g_try_malloc0(sizeof(query));
        if (req->name == NULL) {
-               send_response(client_sk, buf, len, NULL, 0, IPPROTO_TCP);
+               send_response(client_sk, client->buf,
+                       req->request_len, NULL, 0, IPPROTO_TCP);
                g_free(req->request);
                g_free(req);
-               return TRUE;
+               goto out;
        }
        memcpy(req->name, query, sizeof(query));
 
@@ -2712,9 +2783,262 @@ static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
 
        request_list = g_slist_append(request_list, req);
 
+out:
+       if (client->buf_end > (msg_len + 2)) {
+               DBG("client %d buf %p -> %p end %d len %d new %d",
+                       client_sk,
+                       client->buf + msg_len + 2,
+                       client->buf, client->buf_end,
+                       TCP_MAX_BUF_LEN - client->buf_end,
+                       client->buf_end - (msg_len + 2));
+               memmove(client->buf, client->buf + msg_len + 2,
+                       TCP_MAX_BUF_LEN - client->buf_end);
+               client->buf_end = client->buf_end - (msg_len + 2);
+
+               /*
+                * If we have a full message waiting, just read it
+                * immediately.
+                */
+               msg_len = get_msg_len(client->buf);
+               if ((msg_len + 2) == client->buf_end) {
+                       DBG("client %d reading another %d bytes", client_sk,
+                                                               msg_len + 2);
+                       goto read_another;
+               }
+       } else {
+               DBG("client %d clearing reading buffer", client_sk);
+
+               client->buf_end = 0;
+               memset(client->buf, 0, TCP_MAX_BUF_LEN);
+
+               /*
+                * We received all the packets from client so we must also
+                * remove the timeout handler here otherwise we might get
+                * timeout while waiting the results from server.
+                */
+               g_source_remove(client->timeout);
+               client->timeout = 0;
+       }
+
        return TRUE;
 }
 
+static gboolean tcp_client_event(GIOChannel *channel, GIOCondition condition,
+                               gpointer user_data)
+{
+       struct tcp_partial_client_data *client = user_data;
+       struct sockaddr_in6 client_addr6;
+       socklen_t client_addr6_len = sizeof(client_addr6);
+       struct sockaddr_in client_addr4;
+       socklen_t client_addr4_len = sizeof(client_addr4);
+       void *client_addr;
+       socklen_t *client_addr_len;
+       int len, client_sk;
+
+       client_sk = g_io_channel_unix_get_fd(channel);
+
+       if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
+
+               connman_error("Error with TCP client %d channel", client_sk);
+               return FALSE;
+       }
+
+       switch (client->family) {
+       case AF_INET:
+               client_addr = &client_addr4;
+               client_addr_len = &client_addr4_len;
+               break;
+       case AF_INET6:
+               client_addr = &client_addr6;
+               client_addr_len = &client_addr6_len;
+               break;
+       default:
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
+               connman_error("client %p corrupted", client);
+               return FALSE;
+       }
+
+       len = recvfrom(client_sk, client->buf + client->buf_end,
+                       TCP_MAX_BUF_LEN - client->buf_end, 0,
+                       client_addr, client_addr_len);
+       if (len < 0) {
+               if (errno == EAGAIN || errno == EWOULDBLOCK)
+                       return TRUE;
+
+               DBG("client %d cannot read errno %d/%s", client_sk, -errno,
+                       strerror(errno));
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
+               return FALSE;
+       }
+
+       return read_tcp_data(client, client_addr, *client_addr_len, len);
+}
+
+static gboolean client_timeout(gpointer user_data)
+{
+       struct tcp_partial_client_data *client = user_data;
+       int sock;
+
+       sock = g_io_channel_unix_get_fd(client->channel);
+
+       DBG("client %d timeout pending %d bytes", sock, client->buf_end);
+
+       g_hash_table_remove(partial_tcp_req_table, GINT_TO_POINTER(sock));
+
+       return FALSE;
+}
+
+static gboolean tcp_listener_event(GIOChannel *channel, GIOCondition condition,
+                               struct listener_data *ifdata, int family,
+                               guint *listener_watch)
+{
+       int sk, client_sk, len;
+       unsigned int msg_len;
+       struct tcp_partial_client_data *client;
+       struct sockaddr_in6 client_addr6;
+       socklen_t client_addr6_len = sizeof(client_addr6);
+       struct sockaddr_in client_addr4;
+       socklen_t client_addr4_len = sizeof(client_addr4);
+       void *client_addr;
+       socklen_t *client_addr_len;
+       struct timeval tv;
+       fd_set readfds;
+
+       DBG("condition 0x%02x channel %p ifdata %p family %d",
+               condition, channel, ifdata, family);
+
+       if (condition & (G_IO_NVAL | G_IO_ERR | G_IO_HUP)) {
+               if (*listener_watch > 0)
+                       g_source_remove(*listener_watch);
+               *listener_watch = 0;
+
+               connman_error("Error with TCP listener channel");
+
+               return FALSE;
+       }
+
+       sk = g_io_channel_unix_get_fd(channel);
+
+       if (family == AF_INET) {
+               client_addr = &client_addr4;
+               client_addr_len = &client_addr4_len;
+       } else {
+               client_addr = &client_addr6;
+               client_addr_len = &client_addr6_len;
+       }
+
+       tv.tv_sec = tv.tv_usec = 0;
+       FD_ZERO(&readfds);
+       FD_SET(sk, &readfds);
+
+       select(sk + 1, &readfds, NULL, NULL, &tv);
+       if (FD_ISSET(sk, &readfds)) {
+               client_sk = accept(sk, client_addr, client_addr_len);
+               DBG("client %d accepted", client_sk);
+       } else {
+               DBG("No data to read from master %d, waiting.", sk);
+               return TRUE;
+       }
+
+       if (client_sk < 0) {
+               connman_error("Accept failure on TCP listener");
+               *listener_watch = 0;
+               return FALSE;
+       }
+
+       fcntl(client_sk, F_SETFL, O_NONBLOCK);
+
+       client = g_hash_table_lookup(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
+       if (client == NULL) {
+               client = g_try_new0(struct tcp_partial_client_data, 1);
+               if (client == NULL)
+                       return FALSE;
+
+               g_hash_table_insert(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk),
+                                       client);
+
+               client->channel = g_io_channel_unix_new(client_sk);
+               g_io_channel_set_close_on_unref(client->channel, TRUE);
+
+               client->watch = g_io_add_watch(client->channel,
+                                               G_IO_IN, tcp_client_event,
+                                               (gpointer)client);
+
+               client->ifdata = ifdata;
+
+               DBG("client %d created %p", client_sk, client);
+       } else {
+               DBG("client %d already exists %p", client_sk, client);
+       }
+
+       if (client->buf == NULL) {
+               client->buf = g_try_malloc(TCP_MAX_BUF_LEN);
+               if (client->buf == NULL)
+                       return FALSE;
+       }
+       memset(client->buf, 0, TCP_MAX_BUF_LEN);
+       client->buf_end = 0;
+       client->family = family;
+
+       if (client->timeout == 0)
+               client->timeout = g_timeout_add_seconds(2, client_timeout,
+                                                       client);
+
+       /*
+        * Check how much data there is. If all is there, then we can
+        * proceed normally, otherwise read the bits until everything
+        * is received or timeout occurs.
+        */
+       len = recv(client_sk, client->buf, TCP_MAX_BUF_LEN, 0);
+       if (len < 0) {
+               if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                       DBG("client %d no data to read, waiting", client_sk);
+                       return TRUE;
+               }
+
+               DBG("client %d cannot read errno %d/%s", client_sk, -errno,
+                       strerror(errno));
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
+               return TRUE;
+       }
+
+       if (len < 2) {
+               DBG("client %d not enough data to read, waiting", client_sk);
+               client->buf_end += len;
+               return TRUE;
+       }
+
+       msg_len = get_msg_len(client->buf);
+       if (msg_len > TCP_MAX_BUF_LEN) {
+               DBG("client %d invalid message length %u ignoring packet",
+                       client_sk, msg_len);
+               g_hash_table_remove(partial_tcp_req_table,
+                                       GINT_TO_POINTER(client_sk));
+               return TRUE;
+       }
+
+       /*
+        * The packet length bytes do not contain the total message length,
+        * that is the reason to -2 below.
+        */
+       if (msg_len != (unsigned int)(len - 2)) {
+               DBG("client %d sent %d bytes but expecting %u pending %d",
+                       client_sk, len, msg_len + 2, msg_len + 2 - len);
+
+               client->buf_end += len;
+               return TRUE;
+       }
+
+       return read_tcp_data(client, client_addr, *client_addr_len, len);
+}
+
 static gboolean tcp4_listener_event(GIOChannel *channel, GIOCondition condition,
                                gpointer user_data)
 {
@@ -2925,11 +3249,16 @@ static GIOChannel *get_listener(int family, int protocol, int index)
                return NULL;
        }
 
-       if (protocol == IPPROTO_TCP && listen(sk, 10) < 0) {
-               connman_error("Failed to listen on TCP socket %d/%s", -errno,
-                       strerror(errno));
-               close(sk);
-               return NULL;
+       if (protocol == IPPROTO_TCP) {
+
+               if (listen(sk, 10) < 0) {
+                       connman_error("Failed to listen on TCP socket %d/%s",
+                               -errno, strerror(errno));
+                       close(sk);
+                       return NULL;
+               }
+
+               fcntl(sk, F_SETFL, O_NONBLOCK);
        }
 
        channel = g_io_channel_unix_new(sk);
@@ -3152,6 +3481,14 @@ static void remove_listener(gpointer key, gpointer value, gpointer user_data)
        destroy_listener(ifdata);
 }
 
+static void free_partial_reqs(gpointer value)
+{
+       struct tcp_partial_client_data *data = value;
+
+       client_reset(data);
+       g_free(data);
+}
+
 int __connman_dnsproxy_init(void)
 {
        int err, index;
@@ -3163,6 +3500,11 @@ int __connman_dnsproxy_init(void)
        listener_table = g_hash_table_new_full(g_direct_hash, g_direct_equal,
                                                        NULL, g_free);
 
+       partial_tcp_req_table = g_hash_table_new_full(g_direct_hash,
+                                                       g_direct_equal,
+                                                       NULL,
+                                                       free_partial_reqs);
+
        index = connman_inet_ifindex("lo");
        err = __connman_dnsproxy_add_listener(index);
        if (err < 0)
@@ -3177,6 +3519,7 @@ int __connman_dnsproxy_init(void)
 destroy:
        __connman_dnsproxy_remove_listener(index);
        g_hash_table_destroy(listener_table);
+       g_hash_table_destroy(partial_tcp_req_table);
 
        return err;
 }
@@ -3190,4 +3533,6 @@ void __connman_dnsproxy_cleanup(void)
        g_hash_table_foreach(listener_table, remove_listener, NULL);
 
        g_hash_table_destroy(listener_table);
+
+       g_hash_table_destroy(partial_tcp_req_table);
 }