Move to using select() instead of poll(). poll() doesn't work on MacOS
authorDavid Woodhouse <David.Woodhouse@intel.com>
Wed, 15 Oct 2008 12:20:38 +0000 (13:20 +0100)
committerDavid Woodhouse <David.Woodhouse@intel.com>
Wed, 15 Oct 2008 12:20:38 +0000 (13:20 +0100)
cstp.c
dtls.c
mainloop.c
openconnect.h
ssl.c
tun.c

diff --git a/cstp.c b/cstp.c
index 675a94d..d1fdd5c 100644 (file)
--- a/cstp.c
+++ b/cstp.c
@@ -100,8 +100,7 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
                vpninfo->progress(vpninfo, PRG_ERR, "Error fetching HTTPS response\n");
                if (!retried) {
                        retried = 1;
-                       SSL_free(vpninfo->https_ssl);
-                       close(vpninfo->ssl_fd);
+                       openconnect_close_https(vpninfo);
                
                        if (openconnect_open_https(vpninfo)) {
                                vpninfo->progress(vpninfo, PRG_ERR,
@@ -253,7 +252,11 @@ static int start_cstp_connection(struct openconnect_info *vpninfo)
        BIO_set_nbio(SSL_get_wbio(vpninfo->https_ssl),1);
 
        fcntl(vpninfo->ssl_fd, F_SETFL, fcntl(vpninfo->ssl_fd, F_GETFL) | O_NONBLOCK);
-       vpninfo->ssl_pfd = vpn_add_pollfd(vpninfo, vpninfo->ssl_fd, POLLIN|POLLHUP|POLLERR);
+       if (vpninfo->select_nfds <= vpninfo->ssl_fd)
+               vpninfo->select_nfds = vpninfo->ssl_fd + 1;
+
+       FD_SET(vpninfo->ssl_fd, &vpninfo->select_rfds);
+       FD_SET(vpninfo->ssl_fd, &vpninfo->select_efds);
 
        vpninfo->ssl_times.last_rx = vpninfo->ssl_times.last_tx = time(NULL);
        return 0;
@@ -435,7 +438,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
        if (vpninfo->current_ssl_pkt) {
        handle_outgoing:
                vpninfo->ssl_times.last_tx = time(NULL);
-               vpninfo->pfds[vpninfo->ssl_pfd].events &= ~POLLOUT;
+               FD_CLR(vpninfo->ssl_fd, &vpninfo->select_wfds);
                ret = SSL_write(vpninfo->https_ssl,
                                vpninfo->current_ssl_pkt->hdr,
                                vpninfo->current_ssl_pkt->len + 8);
@@ -445,7 +448,8 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
                        case SSL_ERROR_WANT_WRITE:
                                /* Waiting for the socket to become writable -- it's
                                   probably stalled, and/or the buffers are full */
-                               vpninfo->pfds[vpninfo->ssl_pfd].events |= POLLOUT;
+                               FD_SET(vpninfo->ssl_fd, &vpninfo->select_wfds);
+
                        case SSL_ERROR_WANT_READ:
                                if (ka_stalled_dpd_time(&vpninfo->ssl_times, timeout))
                                        goto peer_dead;
@@ -491,9 +495,7 @@ int cstp_mainloop(struct openconnect_info *vpninfo, int *timeout)
        case KA_DPD_DEAD:
        peer_dead:
                vpninfo->progress(vpninfo, PRG_ERR, "CSTP Dead Peer Detection detected dead peer!\n");
-               SSL_free(vpninfo->https_ssl);
-               vpninfo->https_ssl = NULL;
-               close(vpninfo->ssl_fd);
+               openconnect_close_https(vpninfo);
 
                /* It's already deflated in the old stream. Extremely 
                   non-trivial to reconstitute it; just throw it away */
diff --git a/dtls.c b/dtls.c
index 8b4bec0..97335d8 100644 (file)
--- a/dtls.c
+++ b/dtls.c
@@ -189,7 +189,12 @@ int connect_dtls_socket(struct openconnect_info *vpninfo)
 
        vpninfo->new_dtls_fd = dtls_fd;
        vpninfo->new_dtls_ssl = dtls_ssl;
-       vpninfo->pfds[vpninfo->new_dtls_pfd].fd = vpninfo->new_dtls_fd;
+       
+       if (vpninfo->select_nfds <= dtls_fd)
+               vpninfo->select_nfds = dtls_fd + 1;
+
+       FD_SET(dtls_fd, &vpninfo->select_rfds);
+       FD_SET(dtls_fd, &vpninfo->select_efds);
 
        time(&vpninfo->new_dtls_started);
        return dtls_try_handshake(vpninfo);
@@ -206,12 +211,13 @@ int dtls_try_handshake(struct openconnect_info *vpninfo)
                        /* We are replacing an old connection */
                        SSL_free(vpninfo->dtls_ssl);
                        close(vpninfo->dtls_fd);
+                       FD_CLR(vpninfo->dtls_fd, &vpninfo->select_rfds);
+                       FD_CLR(vpninfo->dtls_fd, &vpninfo->select_wfds);
+                       FD_CLR(vpninfo->dtls_fd, &vpninfo->select_efds);
                }
-               vpninfo->pfds[vpninfo->dtls_pfd].fd = vpninfo->new_dtls_fd;
                vpninfo->dtls_ssl = vpninfo->new_dtls_ssl;
                vpninfo->dtls_fd = vpninfo->new_dtls_fd;
 
-               vpninfo->pfds[vpninfo->new_dtls_pfd].fd = -1;
                vpninfo->new_dtls_ssl = NULL;
                vpninfo->new_dtls_fd = -1;
 
@@ -233,7 +239,8 @@ int dtls_try_handshake(struct openconnect_info *vpninfo)
 
        /* Kill the new (failed) connection... */
        SSL_free(vpninfo->new_dtls_ssl);
-       vpninfo->pfds[vpninfo->new_dtls_pfd].fd = -1;
+       FD_CLR(vpninfo->new_dtls_fd, &vpninfo->select_rfds);
+       FD_CLR(vpninfo->new_dtls_fd, &vpninfo->select_efds);
        close(vpninfo->new_dtls_fd);
        vpninfo->new_dtls_ssl = NULL;
        vpninfo->new_dtls_fd = -1;
@@ -244,7 +251,9 @@ int dtls_try_handshake(struct openconnect_info *vpninfo)
        if (vpninfo->dtls_ssl) {
                SSL_free(vpninfo->dtls_ssl);
                close(vpninfo->dtls_fd);
-               vpninfo->pfds[vpninfo->dtls_pfd].fd = -1;
+               FD_CLR(vpninfo->dtls_fd, &vpninfo->select_rfds);
+               FD_CLR(vpninfo->dtls_fd, &vpninfo->select_wfds);
+               FD_CLR(vpninfo->dtls_fd, &vpninfo->select_efds);
                vpninfo->dtls_ssl = NULL;
                vpninfo->dtls_fd = -1;
        }
@@ -258,7 +267,9 @@ static int dtls_restart(struct openconnect_info *vpninfo)
        if (vpninfo->dtls_ssl) {
                SSL_free(vpninfo->dtls_ssl);
                close(vpninfo->dtls_fd);
-               vpninfo->pfds[vpninfo->dtls_pfd].fd = -1;
+               FD_CLR(vpninfo->dtls_fd, &vpninfo->select_rfds);
+               FD_CLR(vpninfo->dtls_fd, &vpninfo->select_wfds);
+               FD_CLR(vpninfo->dtls_fd, &vpninfo->select_efds);
                vpninfo->dtls_ssl = NULL;
                vpninfo->dtls_fd = -1;
        }
@@ -315,11 +326,7 @@ int setup_dtls(struct openconnect_info *vpninfo)
                return -EINVAL;
        }
 
-       vpninfo->dtls_pfd = vpn_add_pollfd(vpninfo, -1,
-                                          POLLIN|POLLHUP|POLLERR);
-       vpninfo->new_dtls_pfd = vpn_add_pollfd(vpninfo, -1,
-                                          POLLIN|POLLHUP|POLLERR);
-
+       
        if (connect_dtls_socket(vpninfo))
                return -EINVAL;
 
index d892e85..166faf6 100644 (file)
@@ -54,20 +54,6 @@ int queue_new_packet(struct pkt **q, int type, void *buf, int len)
        return 0;
 }
 
-int vpn_add_pollfd(struct openconnect_info *vpninfo, int fd, short events)
-{
-       vpninfo->nfds++;
-       vpninfo->pfds = realloc(vpninfo->pfds, sizeof(struct pollfd) * vpninfo->nfds);
-       if (!vpninfo->pfds) {
-               vpninfo->progress(vpninfo, PRG_ERR, "Failed to reallocate pfds\n");
-               exit(1);
-       }
-       vpninfo->pfds[vpninfo->nfds - 1].fd = fd;
-       vpninfo->pfds[vpninfo->nfds - 1].events = events;
-
-       return vpninfo->nfds - 1;
-}
-
 static int killed;
 
 static void handle_sigint(int sig)
@@ -87,6 +73,8 @@ int vpn_mainloop(struct openconnect_info *vpninfo)
        while (!vpninfo->quit_reason) {
                int did_work = 0;
                int timeout = INT_MAX;
+               struct timeval tv;
+               fd_set rfds, wfds, efds;
 
                if (vpninfo->new_dtls_ssl)
                        dtls_try_handshake(vpninfo);
@@ -125,9 +113,14 @@ int vpn_mainloop(struct openconnect_info *vpninfo)
 
                vpninfo->progress(vpninfo, PRG_TRACE, 
                                  "Did no work; sleeping for %d ms...\n", timeout);
-
-               poll(vpninfo->pfds, vpninfo->nfds, timeout);
-               if (vpninfo->pfds[vpninfo->ssl_pfd].revents & POLL_HUP) {
+               memcpy(&rfds, &vpninfo->select_rfds, sizeof(rfds));
+               memcpy(&wfds, &vpninfo->select_wfds, sizeof(wfds));
+               memcpy(&efds, &vpninfo->select_efds, sizeof(efds));
+
+               tv.tv_sec = timeout / 1000;
+               tv.tv_usec = (timeout * 1000) % 1000000;
+               select(vpninfo->select_nfds, &rfds, &wfds, &efds, &tv);
+               if (FD_ISSET(vpninfo->ssl_fd, &efds)) {
                        vpninfo->progress(vpninfo, PRG_ERR, "Server closed connection!\n");
                        /* OpenSSL doesn't seem to cope properly with this... */
                        exit(1);
index 2e9479a..4d85ffd 100644 (file)
 #define __OPENCONNECT_ANYCONNECT_H
 
 #include <openssl/ssl.h>
-#include <poll.h>
 #include <zlib.h>
 #include <stdint.h>
 #include <sys/socket.h>
+#include <sys/select.h>
+#include <sys/time.h>
+#include <sys/types.h>
+#include <unistd.h>
 
 struct pkt {
        int type;
@@ -113,15 +116,15 @@ struct openconnect_info {
        const char *vpn_nbns[3];
        const char *vpn_domain;
 
-       struct pollfd *pfds;
-       int nfds;
+       int select_nfds;
+       fd_set select_rfds;
+       fd_set select_wfds;
+       fd_set select_efds;
+
        int tun_fd;
        int ssl_fd;
        int dtls_fd;
        int new_dtls_fd;
-       int ssl_pfd;
-       int dtls_pfd;
-       int new_dtls_pfd;
 
        struct pkt *incoming_queue;
        struct pkt *outgoing_queue;
@@ -177,6 +180,7 @@ int  __attribute__ ((format (printf, 2, 3)))
                openconnect_SSL_printf(SSL *ssl, const char *fmt, ...);
 int openconnect_SSL_gets(SSL *ssl, char *buf, size_t len);
 int openconnect_open_https(struct openconnect_info *vpninfo);
+void openconnect_close_https(struct openconnect_info *vpninfo);
 
 /* main.c */
 extern int verbose;
diff --git a/ssl.c b/ssl.c
index 7f34df3..4a5ec66 100644 (file)
--- a/ssl.c
+++ b/ssl.c
@@ -253,6 +253,17 @@ int openconnect_open_https(struct openconnect_info *vpninfo)
        return 0;
 }
 
+void openconnect_close_https(struct openconnect_info *vpninfo)
+{
+       SSL_free(vpninfo->https_ssl);
+       vpninfo->https_ssl = NULL;
+       close(vpninfo->ssl_fd);
+       FD_CLR(vpninfo->ssl_fd, &vpninfo->select_rfds);
+       FD_CLR(vpninfo->ssl_fd, &vpninfo->select_wfds);
+       FD_CLR(vpninfo->ssl_fd, &vpninfo->select_efds);
+       vpninfo->ssl_fd = -1;
+}
+
 void openconnect_init_openssl(void)
 {
        SSL_library_init ();
diff --git a/tun.c b/tun.c
index 437404f..f2c63a5 100644 (file)
--- a/tun.c
+++ b/tun.c
@@ -154,7 +154,6 @@ int setup_tun(struct openconnect_info *vpninfo)
 {
        struct ifreq ifr;
        int tun_fd;
-       int pfd;
 
        if (vpninfo->script_tun) {
                pid_t child;
@@ -225,7 +224,11 @@ int setup_tun(struct openconnect_info *vpninfo)
        fcntl(tun_fd, F_SETFD, FD_CLOEXEC);
 
        vpninfo->tun_fd = tun_fd;
-       pfd = vpn_add_pollfd(vpninfo, vpninfo->tun_fd, POLLIN);
+       
+       if (vpninfo->select_nfds <= tun_fd)
+               vpninfo->select_nfds = tun_fd + 1;
+
+       FD_SET(tun_fd, &vpninfo->select_rfds);
 
        fcntl(vpninfo->tun_fd, F_SETFL, fcntl(vpninfo->tun_fd, F_GETFL) | O_NONBLOCK);