Close ssl_sock before returning error in connect_https_socket()
[platform/upstream/openconnect.git] / cstp.c
diff --git a/cstp.c b/cstp.c
index dbc24b9..dacb2ae 100644 (file)
--- a/cstp.c
+++ b/cstp.c
 #include <errno.h>
 #include <stdlib.h>
 #include <stdio.h>
-
-#include <openssl/ssl.h>
-#include <openssl/err.h>
+#include <sys/types.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <stdarg.h>
 
 #include "openconnect-internal.h"
 
@@ -86,6 +88,64 @@ static int  __attribute__ ((format (printf, 3, 4)))
        return ret;
 }
 
+/* Calculate MTU to request. Old servers simply use the X-CSTP-MTU: header,
+ * which represents the tunnel MTU, while new servers do calculations on the
+ * X-CSTP-Base-MTU: header which represents the cleartext MTU between client
+ * and server.
+ *
+ * If possible, the legacy MTU value should be the TCP MSS less 5 bytes of
+ * TLS and 8 bytes of CSTP overhead. We can get the MSS from either the
+ * TCP_INFO or TCP_MAXSEG sockopts.
+ *
+ * The base MTU comes from the TCP_INFO sockopt under Linux, but I don't know
+ * how to work it out on other systems. So leave it blank and do things the
+ * legacy way there. Contributions welcome...
+ *
+ * If we don't even have TCP_MAXSEG, then default to sending a legacy MTU of
+ * 1406 which is what we always used to do.
+ */
+static void calculate_mtu(struct openconnect_info *vpninfo, int *base_mtu, int *mtu)
+{
+       *mtu = vpninfo->reqmtu;
+       *base_mtu = vpninfo->basemtu;
+
+#if defined(__linux__) && defined(TCP_INFO)
+       if (!*mtu || !*base_mtu) {
+               struct tcp_info ti;
+               socklen_t ti_size = sizeof(ti);
+
+               if (!getsockopt(vpninfo->ssl_fd, IPPROTO_TCP, TCP_INFO,
+                               &ti, &ti_size)) {
+                       vpn_progress(vpninfo, PRG_TRACE,
+                                    _("TCP_INFO rcv mss %d, snd mss %d, adv mss %d, pmtu %d\n"),
+                                    ti.tcpi_rcv_mss, ti.tcpi_snd_mss, ti.tcpi_advmss, ti.tcpi_pmtu);
+                       if (!*base_mtu) *base_mtu = ti.tcpi_pmtu;
+                       if (!*mtu) {
+                               if (ti.tcpi_rcv_mss < ti.tcpi_snd_mss)
+                                       *mtu = ti.tcpi_rcv_mss - 13;
+                               else
+                                       *mtu = ti.tcpi_snd_mss - 13;
+                       }
+               }
+       }
+#endif
+#ifdef TCP_MAXSEG
+       if (!*mtu) {
+               int mss;
+               socklen_t mss_size = sizeof(mss);
+               if (!getsockopt(vpninfo->ssl_fd, IPPROTO_TCP, TCP_MAXSEG,
+                               &mss, &mss_size)) {
+                       vpn_progress(vpninfo, PRG_TRACE, _("TCP_MAXSEG %d\n"), mss);
+                       *mtu = mss - 13;
+               }
+       }
+#endif
+       if (!*mtu) {
+               /* Default */
+               *mtu = 1406;
+       }
+}
+
 static int start_cstp_connection(struct openconnect_info *vpninfo)
 {
        char buf[65536];
@@ -100,6 +160,7 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
        const char *old_addr6 = vpninfo->vpn_addr6;
        const char *old_netmask6 = vpninfo->vpn_netmask6;
        struct split_include *inc;
+       int base_mtu, mtu;
 
        /* Clear old options which will be overwritten */
        vpninfo->vpn_addr = vpninfo->vpn_netmask = NULL;
@@ -121,7 +182,12 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                free(inc);
                inc = next;
        }
-       vpninfo->split_includes = vpninfo->split_excludes = NULL;
+       for (inc = vpninfo->split_dns; inc; ) {
+               struct split_include *next = inc->next;
+               free(inc);
+               inc = next;
+       }
+       vpninfo->split_dns = vpninfo->split_includes = vpninfo->split_excludes = NULL;
 
        /* Create (new) random master key for DTLS connection, if needed */
        if (vpninfo->dtls_times.last_rekey + vpninfo->dtls_times.rekey <
@@ -132,6 +198,8 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
        }
 
  retry:
+       calculate_mtu(vpninfo, &base_mtu, &mtu);
+
        buf[0] = 0;
        buf_append(buf, sizeof(buf), "CONNECT /CSCOSSLC/tunnel HTTP/1.1\r\n");
        buf_append(buf, sizeof(buf), "Host: %s\r\n", vpninfo->hostname);
@@ -141,7 +209,9 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
        buf_append(buf, sizeof(buf), "X-CSTP-Hostname: %s\r\n", vpninfo->localname);
        if (vpninfo->deflate && i < sizeof(buf))
                buf_append(buf, sizeof(buf), "X-CSTP-Accept-Encoding: deflate;q=1.0\r\n");
-       buf_append(buf, sizeof(buf), "X-CSTP-MTU: %d\r\n", vpninfo->mtu);
+       if (base_mtu)
+               buf_append(buf, sizeof(buf), "X-CSTP-Base-MTU: %d\r\n", base_mtu);
+       buf_append(buf, sizeof(buf), "X-CSTP-MTU: %d\r\n", mtu);
        buf_append(buf, sizeof(buf), "X-CSTP-Address-Type: %s\r\n",
                               vpninfo->disable_ipv6?"IPv4":"IPv6,IPv4");
        buf_append(buf, sizeof(buf), "X-DTLS-Master-Secret: ");
@@ -159,7 +229,7 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                             _("Error fetching HTTPS response\n"));
                if (!retried) {
                        retried = 1;
-                       openconnect_close_https(vpninfo);
+                       openconnect_close_https(vpninfo, 0);
 
                        if (openconnect_open_https(vpninfo)) {
                                vpn_progress(vpninfo, PRG_ERR,
@@ -199,6 +269,7 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
 
        /* We may have advertised it, but we only do it if the server agrees */
        vpninfo->deflate = 0;
+       mtu = 0;
 
        while ((i = openconnect_SSL_gets(vpninfo, buf, sizeof(buf)))) {
                struct vpn_option *new_option;
@@ -231,6 +302,9 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
 
                if (!new_option->option || !new_option->value) {
                        vpn_progress(vpninfo, PRG_ERR, _("No memory for options\n"));
+                       free(new_option->option);
+                       free(new_option->value);
+                       free(new_option);
                        return -ENOMEM;
                }
 
@@ -240,7 +314,11 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                        *next_dtls_option = new_option;
                        next_dtls_option = &new_option->next;
 
-                       if (!strcmp(buf + 7, "Session-ID")) {
+                       if (!strcmp(buf + 7, "MTU")) {
+                               int dtlsmtu = atol(colon);
+                               if (dtlsmtu > mtu)
+                                       mtu = dtlsmtu;
+                       } else if (!strcmp(buf + 7, "Session-ID")) {
                                if (strlen(colon) != 64) {
                                        vpn_progress(vpninfo, PRG_ERR,
                                                     _("X-DTLS-Session-ID not 64 characters; is: \"%s\"\n"),
@@ -278,16 +356,20 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                                return -EINVAL;
                        }
                } else if (!strcmp(buf + 7, "MTU")) {
-                       vpninfo->mtu = atol(colon);
+                       int cstpmtu = atol(colon);
+                       if (cstpmtu > mtu)
+                               mtu = cstpmtu;
                } else if (!strcmp(buf + 7, "Address")) {
-                       if (strchr(new_option->value, ':'))
-                               vpninfo->vpn_addr6 = new_option->value;
-                       else
+                       if (strchr(new_option->value, ':')) {
+                               if (!vpninfo->disable_ipv6)
+                                       vpninfo->vpn_addr6 = new_option->value;
+                       } else
                                vpninfo->vpn_addr = new_option->value;
                } else if (!strcmp(buf + 7, "Netmask")) {
-                       if (strchr(new_option->value, ':'))
-                               vpninfo->vpn_netmask6 = new_option->value;
-                       else
+                       if (strchr(new_option->value, ':')) {
+                               if (!vpninfo->disable_ipv6)
+                                       vpninfo->vpn_netmask6 = new_option->value;
+                       } else
                                vpninfo->vpn_netmask = new_option->value;
                } else if (!strcmp(buf + 7, "DNS")) {
                        int j;
@@ -311,6 +393,13 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                        vpninfo->vpn_proxy_pac = new_option->value;
                } else if (!strcmp(buf + 7, "Banner")) {
                        vpninfo->banner = new_option->value;
+               } else if (!strcmp(buf + 7, "Split-DNS")) {
+                       struct split_include *dns = malloc(sizeof(*dns));
+                       if (!dns)
+                               continue;
+                       dns->route = new_option->value;
+                       dns->next = vpninfo->split_dns;
+                       vpninfo->split_dns = dns;
                } else if (!strcmp(buf + 7, "Split-Include")) {
                        struct split_include *inc = malloc(sizeof(*inc));
                        if (!inc)
@@ -328,6 +417,13 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                }
        }
 
+       if (!mtu) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("No MTU received. Aborting\n"));
+               return -EINVAL;
+       }
+       vpninfo->actual_mtu = mtu;
+
        if (!vpninfo->vpn_addr && !vpninfo->vpn_addr6) {
                vpn_progress(vpninfo, PRG_ERR,
                             _("No IP address received. Aborting\n"));
@@ -442,15 +538,15 @@ int cstp_reconnect(struct openconnect_info *vpninfo)
        int timeout;
        int interval;
 
-       openconnect_close_https(vpninfo);
+       openconnect_close_https(vpninfo, 0);
 
-       /* Requeue the original packet that was deflated */
-       if (vpninfo->current_ssl_pkt == vpninfo->deflate_pkt) {
-               vpninfo->current_ssl_pkt = NULL;
-               queue_packet(&vpninfo->outgoing_queue, vpninfo->pending_deflated_pkt);
-               vpninfo->pending_deflated_pkt = NULL;
-       }
        if (vpninfo->deflate) {
+               /* Requeue the original packet that was deflated */
+               if (vpninfo->current_ssl_pkt == vpninfo->deflate_pkt) {
+                       vpninfo->current_ssl_pkt = NULL;
+                       queue_packet(&vpninfo->outgoing_queue, vpninfo->pending_deflated_pkt);
+                       vpninfo->pending_deflated_pkt = NULL;
+               }
                inflateEnd(&vpninfo->inflate_strm);
                deflateEnd(&vpninfo->deflate_strm);
        }
@@ -478,7 +574,7 @@ int cstp_reconnect(struct openconnect_info *vpninfo)
 static int inflate_and_queue_packet(struct openconnect_info *vpninfo,
                                    unsigned char *buf, int len)
 {
-       struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->mtu);
+       struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->actual_mtu);
        uint32_t pkt_sum;
 
        if (!new)
@@ -490,7 +586,7 @@ static int inflate_and_queue_packet(struct openconnect_info *vpninfo,
        vpninfo->inflate_strm.avail_in = len - 4;
 
        vpninfo->inflate_strm.next_out = new->data;
-       vpninfo->inflate_strm.avail_out = vpninfo->mtu;
+       vpninfo->inflate_strm.avail_out = vpninfo->actual_mtu;
        vpninfo->inflate_strm.total_out = 0;
 
        if (inflate(&vpninfo->inflate_strm, Z_SYNC_FLUSH)) {
@@ -710,11 +806,25 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
                ret = cstp_write(vpninfo,
                                 vpninfo->current_ssl_pkt->hdr,
                                 vpninfo->current_ssl_pkt->len + 8);
-               
                if (ret < 0)
                        goto do_reconnect;
-               else if (!ret && ka_stalled_dpd_time(&vpninfo->ssl_times, timeout))
-                       goto peer_dead;
+               else if (!ret) {
+                       /* -EAGAIN: cstp_write() will have added the SSL fd to
+                          ->select_wfds if appropriate, so we can just return
+                          and wait. Unless it's been stalled for so long that
+                          DPD kicks in and we kill the connection. */
+                       switch (ka_stalled_action(&vpninfo->ssl_times, timeout)) {
+                       case KA_DPD_DEAD:
+                               goto peer_dead;
+                       case KA_REKEY:
+                               goto do_rekey;
+                       case KA_NONE:
+                               return work_done;
+                       default:
+                               /* This should never happen */
+                               ;
+                       }
+               }
 
                if (ret != vpninfo->current_ssl_pkt->len + 8) {
                        vpn_progress(vpninfo, PRG_ERR,
@@ -742,6 +852,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
 
        switch (keepalive_action(&vpninfo->ssl_times, timeout)) {
        case KA_REKEY:
+       do_rekey:
                /* Not that this will ever happen; we don't even process
                   the setting when we're asked for it. */
                vpn_progress(vpninfo, PRG_INFO, _("CSTP rekey due\n"));