Add compression support
authorDavid Woodhouse <David.Woodhouse@intel.com>
Mon, 22 Sep 2008 09:31:03 +0000 (02:31 -0700)
committerDavid Woodhouse <David.Woodhouse@intel.com>
Mon, 22 Sep 2008 09:31:03 +0000 (02:31 -0700)
anyconnect.h
main.c
mainloop.c
ssl.c

index df364d3..81f720f 100644 (file)
 #include <openssl/ssl.h>
 #include <poll.h>
 #include <zlib.h>
+#include <stdint.h>
 
 struct pkt {
        int type;
        int len;
        struct pkt *next;
-       unsigned char hdr[8];
        unsigned char data[];
 };
        
@@ -51,7 +51,9 @@ struct anyconnect_info {
        time_t last_ssl_tx;
 
        z_stream inflate_strm;
+       uint32_t inflate_adler32;
        z_stream deflate_strm;
+       uint32_t deflate_adler32;
 
        unsigned char dtls_secret[48];
        SSL_CTX *dtls_ctx;
@@ -74,6 +76,8 @@ struct anyconnect_info {
 
        int deflate;
        const char *useragent;
+
+       char *quit_reason;
 };
 
 /* tun.c */
@@ -97,4 +101,4 @@ extern int verbose;
 int vpn_add_pollfd(struct anyconnect_info *vpninfo, int fd, short events);
 int vpn_mainloop(struct anyconnect_info *vpninfo);
 int queue_new_packet(struct pkt **q, int type, void *buf, int len);
-
+int inflate_and_queue_packet(struct anyconnect_info *vpninfo, int type, void *buf, int len);
diff --git a/main.c b/main.c
index 829461a..d42090c 100644 (file)
--- a/main.c
+++ b/main.c
@@ -114,8 +114,7 @@ int main(int argc, char **argv)
                        break;
 
                case 'd':
-                       fprintf(stderr, "Deflate not yet supported\n");
-                       //vpninfo->deflate = 1;
+                       vpninfo->deflate = 1;
                        break;
                }
        }
@@ -124,6 +123,17 @@ int main(int argc, char **argv)
                exit(1);
        }
 
+       if (vpninfo->deflate) {
+               if (inflateInit2(&vpninfo->inflate_strm, -12) ||
+                   deflateInit2(&vpninfo->deflate_strm, Z_DEFAULT_COMPRESSION,
+                                Z_DEFLATED, -12, 9, Z_DEFAULT_STRATEGY)) {
+                       fprintf(stderr, "Compression setup failed\n");
+                       vpninfo->deflate = 0;
+               }
+       }
+       vpninfo->deflate_adler32 = 1;
+       vpninfo->inflate_adler32 = 1;
+
        if (make_ssl_connection(vpninfo)) {
                fprintf(stderr, "Creating SSL connection failed\n");
                exit(1);
index bd4c6ec..df07c09 100644 (file)
 #include <limits.h>
 #include <sys/select.h>
 #include <signal.h>
+#include <arpa/inet.h>
 
 #include "anyconnect.h"
 
+int inflate_and_queue_packet(struct anyconnect_info *vpninfo, int type, void *buf, int len)
+{
+       struct pkt **q = &vpninfo->incoming_queue;
+
+       while (*q)
+               q = &(*q)->next;
+
+       *q = malloc(sizeof(struct pkt) + vpninfo->mtu);
+       if (!*q)
+               return -ENOMEM;
+
+       (*q)->type = type;
+       (*q)->next = NULL;
+
+       vpninfo->inflate_strm.next_in = buf;
+       vpninfo->inflate_strm.avail_in = len - 4;
+
+       vpninfo->inflate_strm.next_out = (*q)->data;
+       vpninfo->inflate_strm.avail_out = vpninfo->mtu;
+       vpninfo->inflate_strm.total_out = 0;
+
+       if (inflate(&vpninfo->inflate_strm, Z_SYNC_FLUSH)) {
+               fprintf(stderr, "inflate failed\n");
+               free(*q);
+               *q = NULL;
+               return -EINVAL;
+       }
+
+       (*q)->len = vpninfo->inflate_strm.total_out;
+
+       vpninfo->inflate_adler32 = adler32(vpninfo->inflate_adler32,
+                                          (*q)->data, (*q)->len);
+
+       if (vpninfo->inflate_adler32 != ntohl( *(uint32_t *)(buf + len - 4))) {
+               vpninfo->quit_reason = "Compression (inflate) adler32 failure";
+       }
+
+       return 0;
+}
+
 int queue_new_packet(struct pkt **q, int type, void *buf, int len)
 {
        while (*q)
@@ -70,7 +111,7 @@ int vpn_mainloop(struct anyconnect_info *vpninfo)
        sa.sa_handler = handle_sigint;
        
        sigaction(SIGINT, &sa, NULL);
-       while (!killed) {
+       while (!killed && !vpninfo->quit_reason) {
                int did_work = 0;
                int timeout = INT_MAX;
 
@@ -87,8 +128,16 @@ int vpn_mainloop(struct anyconnect_info *vpninfo)
                        printf("Did no work; sleeping for %d ms...\n", timeout);
 
                poll(vpninfo->pfds, vpninfo->nfds, timeout);
+               if (vpninfo->pfds[vpninfo->ssl_pfd].revents & POLL_HUP) {
+                       fprintf(stderr, "Server closed connection!\n");
+                       /* OpenSSL doesn't seem to cope properly with this... */
+                       exit(1);
+               }
        }
-       ssl_bye(vpninfo, "Client received SIGINT\n");
+       if (!vpninfo->quit_reason)
+               vpninfo->quit_reason = "Client received SIGINT";
+
+       ssl_bye(vpninfo, vpninfo->quit_reason);
 
        return 0;
 }
diff --git a/ssl.c b/ssl.c
index a8da639..523cf6a 100644 (file)
--- a/ssl.c
+++ b/ssl.c
@@ -192,6 +192,8 @@ int start_ssl_connection(struct anyconnect_info *vpninfo)
                return -EINVAL;
        }
 
+       /* We may have advertised it, but we only do it if the server agrees */
+       vpninfo->deflate = 0;
 
        while ((i=my_SSL_gets(vpninfo->https_ssl, buf, sizeof(buf)))) {
                struct vpn_option *new_option;
@@ -222,15 +224,26 @@ int start_ssl_connection(struct anyconnect_info *vpninfo)
                        return -ENOMEM;
                }
 
-               if (!strcmp(buf, "X-CSTP-Keepalive"))
-                       vpninfo->ssl_keepalive = atol(colon);
-                   
                if (!strncmp(buf, "X-DTLS-", 7)) {
                        *next_dtls_option = new_option;
                        next_dtls_option = &new_option->next;
-               } else {
-                       *next_cstp_option = new_option;
-                       next_cstp_option = &new_option->next;
+                       continue;
+               }
+               /* CSTP options... */
+               *next_cstp_option = new_option;
+               next_cstp_option = &new_option->next;
+
+               if (!strcmp(buf + 7, "Keepalive")) {
+                       vpninfo->ssl_keepalive = atol(colon);
+               } else if (!strcmp(buf + 7, "Content-Encoding")) {
+                       if (!strcmp(colon, "deflate"))
+                               vpninfo->deflate = 1;
+                       else {
+                               fprintf(stderr, 
+                                       "Unknown CSTP-Content-Encoding %s\n",
+                                       colon);
+                               return -EINVAL;
+                       }
                }
        }
 
@@ -242,7 +255,7 @@ int start_ssl_connection(struct anyconnect_info *vpninfo)
 
        fcntl(vpninfo->ssl_fd, F_SETFL, fcntl(vpninfo->ssl_fd, F_GETFL) | O_NONBLOCK);
 
-       vpninfo->ssl_pfd = vpn_add_pollfd(vpninfo, vpninfo->ssl_fd, POLLIN);
+       vpninfo->ssl_pfd = vpn_add_pollfd(vpninfo, vpninfo->ssl_fd, POLLIN|POLLHUP|POLLERR);
        vpninfo->last_ssl_tx = time(NULL);
        return 0;
 }
@@ -284,13 +297,8 @@ int ssl_mainloop(struct anyconnect_info *vpninfo, int *timeout)
        while ( (len = SSL_read(vpninfo->https_ssl, buf, sizeof(buf))) > 0) {
 
                if (buf[0] != 'S' || buf[1] != 'T' ||
-                   buf[2] != 'F' || buf[3] != 1 || buf[7]) {
-               unknown_pkt:
-                       printf("Unknown packet %02x %02x %02x %02x %02x %02x %02x %02x\n",
-                              buf[0], buf[1], buf[2], buf[3],
-                              buf[4], buf[5], buf[6], buf[7]);
-                       continue;
-               }
+                   buf[2] != 'F' || buf[3] != 1 || buf[7])
+                       goto unknown_pkt;
                if (len != 8 + (buf[4] << 8) + buf[5]) {
                        printf("Unexpected packet length. SSL_read returned %d but packet is\n",
                               len);
@@ -299,28 +307,87 @@ int ssl_mainloop(struct anyconnect_info *vpninfo, int *timeout)
                               buf[4], buf[5], buf[6], buf[7]);
                        continue;
                }
-               if (buf[6] == 4) {
-                       /* Keepalive response */
+               switch(buf[6]) {
+               case 4: /* Keepalive response */
                        if (verbose)
                                printf("Got keepalive response\n");
                        continue;
+
+               case 0: /* Uncompressed Data */
+                       queue_new_packet(&vpninfo->incoming_queue, AF_INET, buf + 8,
+                                        (buf[4] << 8) + buf[5]);
+                       work_done = 1;
+                       continue;
+
+               case 8: /* Compressed data */
+                       if (!vpninfo->deflate) {
+                               fprintf(stderr, "Compressed packet received in !deflate mode\n");
+                               goto unknown_pkt;
+                       }
+                       inflate_and_queue_packet(vpninfo, AF_INET, buf + 8, len - 8);
+                       work_done = 1;
+                       continue;
+
+               case 9:
+                       fprintf(stderr, "received server terminate packet\n");
+                       vpninfo->quit_reason = "Server request";
+                       /* Do not pass Go. Do not collect £200 */
+                       exit(1);
                }
-               if (buf[6] != 0) /* Data */
-                       goto unknown_pkt;
-               
-               queue_new_packet(&vpninfo->incoming_queue, AF_INET, buf + 8,
-                                (buf[4] << 8) + buf[5]);
-               work_done = 1;
+
+       unknown_pkt:
+               printf("Unknown packet %02x %02x %02x %02x %02x %02x %02x %02x\n",
+                      buf[0], buf[1], buf[2], buf[3],
+                      buf[4], buf[5], buf[6], buf[7]);
+               exit(1);
        }
+               
+
 
        while (vpninfo->outgoing_queue) {
                struct pkt *this = vpninfo->outgoing_queue;
+               char buf[2048];
+
+               memcpy(buf, data_hdr, 8);
+
                vpninfo->outgoing_queue = this->next;
-               memcpy(this->hdr, data_hdr, 8);
-               this->hdr[4] = this->len >> 8;
-               this->hdr[5] = this->len & 0xff;
-                     
-               SSL_write(vpninfo->https_ssl, this->hdr, this->len + 8);
+
+               if (vpninfo->deflate) {
+                       int ret;
+                       vpninfo->deflate_strm.next_in = this->data;
+                       vpninfo->deflate_strm.avail_in = this->len;
+                       vpninfo->deflate_strm.next_out = (void *)buf + 8;
+                       vpninfo->deflate_strm.avail_out = 2040;
+                       vpninfo->deflate_strm.total_out = 0;
+
+                       ret = deflate(&vpninfo->deflate_strm, Z_SYNC_FLUSH);
+                       if (ret) {
+                               fprintf(stderr, "deflate failed %d\n", ret);
+                               goto uncompr;
+                       }
+
+                       buf[6] = 8;
+                       buf[4] = (vpninfo->deflate_strm.total_out + 4) >> 8;
+                       buf[5] = (vpninfo->deflate_strm.total_out + 4) & 0xff;
+
+                       /* Add ongoing adler32 to tail of compressed packet */
+                       vpninfo->deflate_adler32 = adler32(vpninfo->deflate_adler32,
+                                                          this->data, this->len);
+
+                       buf[8 + vpninfo->deflate_strm.total_out] = vpninfo->deflate_adler32 >> 24;
+                       buf[9 + vpninfo->deflate_strm.total_out] = (vpninfo->deflate_adler32 >> 16) & 0xff;
+                       buf[10 + vpninfo->deflate_strm.total_out] = (vpninfo->deflate_adler32 >> 8) & 0xff;
+                       buf[11 + vpninfo->deflate_strm.total_out] = vpninfo->deflate_adler32 & 0xff;
+
+                       SSL_write(vpninfo->https_ssl, buf, 
+                                 vpninfo->deflate_strm.total_out + 12);
+               } else {
+               uncompr:
+                       buf[4] = this->len >> 8;
+                       buf[5] = this->len & 0xff;
+                       memcpy(buf + 8, this->data, this->len);
+                       SSL_write(vpninfo->https_ssl, buf, this->len + 8);
+               }
                vpninfo->last_ssl_tx = time(NULL);
        }