Close ssl_sock before returning error in connect_https_socket()
[platform/upstream/openconnect.git] / cstp.c
diff --git a/cstp.c b/cstp.c
index 792c17e..dacb2ae 100644 (file)
--- a/cstp.c
+++ b/cstp.c
 #include <errno.h>
 #include <stdlib.h>
 #include <stdio.h>
-#include <netinet/tcp.h>
 #include <sys/types.h>
 #include <sys/socket.h>
-
-#include <openssl/ssl.h>
-#include <openssl/err.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+#include <stdarg.h>
 
 #include "openconnect-internal.h"
 
@@ -107,15 +106,15 @@ static int  __attribute__ ((format (printf, 3, 4)))
  */
 static void calculate_mtu(struct openconnect_info *vpninfo, int *base_mtu, int *mtu)
 {
-       *mtu = vpninfo->mtu;
+       *mtu = vpninfo->reqmtu;
        *base_mtu = vpninfo->basemtu;
 
-#ifdef TCP_INFO
+#if defined(__linux__) && defined(TCP_INFO)
        if (!*mtu || !*base_mtu) {
                struct tcp_info ti;
                socklen_t ti_size = sizeof(ti);
 
-               if (!getsockopt(vpninfo->ssl_fd, SOL_TCP, TCP_INFO,
+               if (!getsockopt(vpninfo->ssl_fd, IPPROTO_TCP, TCP_INFO,
                                &ti, &ti_size)) {
                        vpn_progress(vpninfo, PRG_TRACE,
                                     _("TCP_INFO rcv mss %d, snd mss %d, adv mss %d, pmtu %d\n"),
@@ -134,7 +133,7 @@ static void calculate_mtu(struct openconnect_info *vpninfo, int *base_mtu, int *
        if (!*mtu) {
                int mss;
                socklen_t mss_size = sizeof(mss);
-               if (!getsockopt(vpninfo->ssl_fd, SOL_TCP, TCP_MAXSEG,
+               if (!getsockopt(vpninfo->ssl_fd, IPPROTO_TCP, TCP_MAXSEG,
                                &mss, &mss_size)) {
                        vpn_progress(vpninfo, PRG_TRACE, _("TCP_MAXSEG %d\n"), mss);
                        *mtu = mss - 13;
@@ -270,6 +269,7 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
 
        /* We may have advertised it, but we only do it if the server agrees */
        vpninfo->deflate = 0;
+       mtu = 0;
 
        while ((i = openconnect_SSL_gets(vpninfo, buf, sizeof(buf)))) {
                struct vpn_option *new_option;
@@ -302,6 +302,9 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
 
                if (!new_option->option || !new_option->value) {
                        vpn_progress(vpninfo, PRG_ERR, _("No memory for options\n"));
+                       free(new_option->option);
+                       free(new_option->value);
+                       free(new_option);
                        return -ENOMEM;
                }
 
@@ -312,9 +315,9 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                        next_dtls_option = &new_option->next;
 
                        if (!strcmp(buf + 7, "MTU")) {
-                               int mtu = atol(colon);
-                               if (mtu > vpninfo->mtu)
-                                       vpninfo->mtu = mtu;
+                               int dtlsmtu = atol(colon);
+                               if (dtlsmtu > mtu)
+                                       mtu = dtlsmtu;
                        } else if (!strcmp(buf + 7, "Session-ID")) {
                                if (strlen(colon) != 64) {
                                        vpn_progress(vpninfo, PRG_ERR,
@@ -353,18 +356,20 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                                return -EINVAL;
                        }
                } else if (!strcmp(buf + 7, "MTU")) {
-                       int mtu = atol(colon);
-                       if (mtu > vpninfo->mtu)
-                               vpninfo->mtu = mtu;
+                       int cstpmtu = atol(colon);
+                       if (cstpmtu > mtu)
+                               mtu = cstpmtu;
                } else if (!strcmp(buf + 7, "Address")) {
-                       if (strchr(new_option->value, ':'))
-                               vpninfo->vpn_addr6 = new_option->value;
-                       else
+                       if (strchr(new_option->value, ':')) {
+                               if (!vpninfo->disable_ipv6)
+                                       vpninfo->vpn_addr6 = new_option->value;
+                       } else
                                vpninfo->vpn_addr = new_option->value;
                } else if (!strcmp(buf + 7, "Netmask")) {
-                       if (strchr(new_option->value, ':'))
-                               vpninfo->vpn_netmask6 = new_option->value;
-                       else
+                       if (strchr(new_option->value, ':')) {
+                               if (!vpninfo->disable_ipv6)
+                                       vpninfo->vpn_netmask6 = new_option->value;
+                       } else
                                vpninfo->vpn_netmask = new_option->value;
                } else if (!strcmp(buf + 7, "DNS")) {
                        int j;
@@ -412,6 +417,13 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                }
        }
 
+       if (!mtu) {
+               vpn_progress(vpninfo, PRG_ERR,
+                            _("No MTU received. Aborting\n"));
+               return -EINVAL;
+       }
+       vpninfo->actual_mtu = mtu;
+
        if (!vpninfo->vpn_addr && !vpninfo->vpn_addr6) {
                vpn_progress(vpninfo, PRG_ERR,
                             _("No IP address received. Aborting\n"));
@@ -528,13 +540,13 @@ int cstp_reconnect(struct openconnect_info *vpninfo)
 
        openconnect_close_https(vpninfo, 0);
 
-       /* Requeue the original packet that was deflated */
-       if (vpninfo->current_ssl_pkt == vpninfo->deflate_pkt) {
-               vpninfo->current_ssl_pkt = NULL;
-               queue_packet(&vpninfo->outgoing_queue, vpninfo->pending_deflated_pkt);
-               vpninfo->pending_deflated_pkt = NULL;
-       }
        if (vpninfo->deflate) {
+               /* Requeue the original packet that was deflated */
+               if (vpninfo->current_ssl_pkt == vpninfo->deflate_pkt) {
+                       vpninfo->current_ssl_pkt = NULL;
+                       queue_packet(&vpninfo->outgoing_queue, vpninfo->pending_deflated_pkt);
+                       vpninfo->pending_deflated_pkt = NULL;
+               }
                inflateEnd(&vpninfo->inflate_strm);
                deflateEnd(&vpninfo->deflate_strm);
        }
@@ -562,7 +574,7 @@ int cstp_reconnect(struct openconnect_info *vpninfo)
 static int inflate_and_queue_packet(struct openconnect_info *vpninfo,
                                    unsigned char *buf, int len)
 {
-       struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->mtu);
+       struct pkt *new = malloc(sizeof(struct pkt) + vpninfo->actual_mtu);
        uint32_t pkt_sum;
 
        if (!new)
@@ -574,7 +586,7 @@ static int inflate_and_queue_packet(struct openconnect_info *vpninfo,
        vpninfo->inflate_strm.avail_in = len - 4;
 
        vpninfo->inflate_strm.next_out = new->data;
-       vpninfo->inflate_strm.avail_out = vpninfo->mtu;
+       vpninfo->inflate_strm.avail_out = vpninfo->actual_mtu;
        vpninfo->inflate_strm.total_out = 0;
 
        if (inflate(&vpninfo->inflate_strm, Z_SYNC_FLUSH)) {
@@ -794,11 +806,25 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
                ret = cstp_write(vpninfo,
                                 vpninfo->current_ssl_pkt->hdr,
                                 vpninfo->current_ssl_pkt->len + 8);
-               
                if (ret < 0)
                        goto do_reconnect;
-               else if (!ret && ka_stalled_dpd_time(&vpninfo->ssl_times, timeout))
-                       goto peer_dead;
+               else if (!ret) {
+                       /* -EAGAIN: cstp_write() will have added the SSL fd to
+                          ->select_wfds if appropriate, so we can just return
+                          and wait. Unless it's been stalled for so long that
+                          DPD kicks in and we kill the connection. */
+                       switch (ka_stalled_action(&vpninfo->ssl_times, timeout)) {
+                       case KA_DPD_DEAD:
+                               goto peer_dead;
+                       case KA_REKEY:
+                               goto do_rekey;
+                       case KA_NONE:
+                               return work_done;
+                       default:
+                               /* This should never happen */
+                               ;
+                       }
+               }
 
                if (ret != vpninfo->current_ssl_pkt->len + 8) {
                        vpn_progress(vpninfo, PRG_ERR,
@@ -826,6 +852,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
 
        switch (keepalive_action(&vpninfo->ssl_times, timeout)) {
        case KA_REKEY:
+       do_rekey:
                /* Not that this will ever happen; we don't even process
                   the setting when we're asked for it. */
                vpn_progress(vpninfo, PRG_INFO, _("CSTP rekey due\n"));