Close ssl_sock before returning error in connect_https_socket()
[platform/upstream/openconnect.git] / cstp.c
diff --git a/cstp.c b/cstp.c
index ff0eba3..dacb2ae 100644 (file)
--- a/cstp.c
+++ b/cstp.c
@@ -106,7 +106,7 @@ static int  __attribute__ ((format (printf, 3, 4)))
  */
 static void calculate_mtu(struct openconnect_info *vpninfo, int *base_mtu, int *mtu)
 {
-       *mtu = vpninfo->mtu;
+       *mtu = vpninfo->reqmtu;
        *base_mtu = vpninfo->basemtu;
 
 #if defined(__linux__) && defined(TCP_INFO)
@@ -269,6 +269,7 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
 
        /* We may have advertised it, but we only do it if the server agrees */
        vpninfo->deflate = 0;
+       mtu = 0;
 
        while ((i = openconnect_SSL_gets(vpninfo, buf, sizeof(buf)))) {
                struct vpn_option *new_option;
@@ -301,6 +302,9 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
 
                if (!new_option->option || !new_option->value) {
                        vpn_progress(vpninfo, PRG_ERR, _("No memory for options\n"));
+                       free(new_option->option);
+                       free(new_option->value);
+                       free(new_option);
                        return -ENOMEM;
                }
 
@@ -311,9 +315,9 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                        next_dtls_option = &new_option->next;
 
                        if (!strcmp(buf + 7, "MTU")) {
-                               int mtu = atol(colon);
-                               if (mtu > vpninfo->mtu)
-                                       vpninfo->mtu = mtu;
+                               int dtlsmtu = atol(colon);
+                               if (dtlsmtu > mtu)
+                                       mtu = dtlsmtu;
                        } else if (!strcmp(buf + 7, "Session-ID")) {
                                if (strlen(colon) != 64) {
                                        vpn_progress(vpninfo, PRG_ERR,
@@ -352,18 +356,20 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                                return -EINVAL;
                        }
                } else if (!strcmp(buf + 7, "MTU")) {
-                       int mtu = atol(colon);
-                       if (mtu > vpninfo->mtu)
-                               vpninfo->mtu = mtu;
+                       int cstpmtu = atol(colon);
+                       if (cstpmtu > mtu)
+                               mtu = cstpmtu;
                } else if (!strcmp(buf + 7, "Address")) {
-                       if (strchr(new_option->value, ':'))
-                               vpninfo->vpn_addr6 = new_option->value;
-                       else
+                       if (strchr(new_option->value, ':')) {
+                               if (!vpninfo->disable_ipv6)
+                                       vpninfo->vpn_addr6 = new_option->value;
+                       } else
                                vpninfo->vpn_addr = new_option->value;
                } else if (!strcmp(buf + 7, "Netmask")) {
-                       if (strchr(new_option->value, ':'))
-                               vpninfo->vpn_netmask6 = new_option->value;
-                       else
+                       if (strchr(new_option->value, ':')) {
+                               if (!vpninfo->disable_ipv6)
+                                       vpninfo->vpn_netmask6 = new_option->value;
+                       } else
                                vpninfo->vpn_netmask = new_option->value;
                } else if (!strcmp(buf + 7, "DNS")) {
                        int j;
@@ -411,6 +417,13 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                }
        }
 
+       if (!mtu) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("No MTU received. Aborting\n"));
+               return -EINVAL;
+       }
+       vpninfo->actual_mtu = mtu;
+
        if (!vpninfo->vpn_addr && !vpninfo->vpn_addr6) {
                vpn_progress(vpninfo, PRG_ERR,
                             _("No IP address received. Aborting\n"));
@@ -527,13 +540,13 @@ int cstp_reconnect(struct openconnect_info *vpninfo)
 
        openconnect_close_https(vpninfo, 0);
 
-       /* Requeue the original packet that was deflated */
-       if (vpninfo->current_ssl_pkt == vpninfo->deflate_pkt) {
-               vpninfo->current_ssl_pkt = NULL;
-               queue_packet(&vpninfo->outgoing_queue, vpninfo->pending_deflated_pkt);
-               vpninfo->pending_deflated_pkt = NULL;
-       }
        if (vpninfo->deflate) {
+               /* Requeue the original packet that was deflated */
+               if (vpninfo->current_ssl_pkt == vpninfo->deflate_pkt) {
+                       vpninfo->current_ssl_pkt = NULL;
+                       queue_packet(&vpninfo->outgoing_queue, vpninfo->pending_deflated_pkt);
+                       vpninfo->pending_deflated_pkt = NULL;
+               }
                inflateEnd(&vpninfo->inflate_strm);
                deflateEnd(&vpninfo->deflate_strm);
        }
@@ -561,7 +574,7 @@ int cstp_reconnect(struct openconnect_info *vpninfo)
 static int inflate_and_queue_packet(struct openconnect_info *vpninfo,
                                    unsigned char *buf, int len)
 {
-       struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->mtu);
+       struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->actual_mtu);
        uint32_t pkt_sum;
 
        if (!new)
@@ -573,7 +586,7 @@ static int inflate_and_queue_packet(struct openconnect_info *vpninfo,
        vpninfo->inflate_strm.avail_in = len - 4;
 
        vpninfo->inflate_strm.next_out = new->data;
-       vpninfo->inflate_strm.avail_out = vpninfo->mtu;
+       vpninfo->inflate_strm.avail_out = vpninfo->actual_mtu;
        vpninfo->inflate_strm.total_out = 0;
 
        if (inflate(&vpninfo->inflate_strm, Z_SYNC_FLUSH)) {
@@ -793,11 +806,25 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
                ret = cstp_write(vpninfo,
                                 vpninfo->current_ssl_pkt->hdr,
                                 vpninfo->current_ssl_pkt->len + 8);
-               
                if (ret < 0)
                        goto do_reconnect;
-               else if (!ret && ka_stalled_dpd_time(&vpninfo->ssl_times, timeout))
-                       goto peer_dead;
+               else if (!ret) {
+                       /* -EAGAIN: cstp_write() will have added the SSL fd to
+                          ->select_wfds if appropriate, so we can just return
+                          and wait. Unless it's been stalled for so long that
+                          DPD kicks in and we kill the connection. */
+                       switch (ka_stalled_action(&vpninfo->ssl_times, timeout)) {
+                       case KA_DPD_DEAD:
+                               goto peer_dead;
+                       case KA_REKEY:
+                               goto do_rekey;
+                       case KA_NONE:
+                               return work_done;
+                       default:
+                               /* This should never happen */
+                               ;
+                       }
+               }
 
                if (ret != vpninfo->current_ssl_pkt->len + 8) {
                        vpn_progress(vpninfo, PRG_ERR,
@@ -825,6 +852,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
 
        switch (keepalive_action(&vpninfo->ssl_times, timeout)) {
        case KA_REKEY:
+       do_rekey:
                /* Not that this will ever happen; we don't even process
                   the setting when we're asked for it. */
                vpn_progress(vpninfo, PRG_INFO, _("CSTP rekey due\n"));