Allow transport_write calls to be non-blocking
authorHardening <rdp.effort@gmail.com>
Wed, 21 May 2014 15:32:14 +0000 (17:32 +0200)
committerHardening <rdp.effort@gmail.com>
Wed, 21 May 2014 15:42:31 +0000 (17:42 +0200)
This big patch allows to have non-blocking writes. To achieve
this, it slightly changes the way transport is handled. The misc transport
layers are handled with OpenSSL BIOs. In the chain we insert a
bufferedBIO that will bufferize write calls that couldn't be honored.

For an access with Tls security the BIO chain would look like this:
  FreeRdp Code ===> SSL bio ===> buffered BIO ===> socket BIO

The buffered BIO will store bytes that couldn't be send because of
blocking write calls.

This patch also rework TSG so that it would look like this in the
case of SSL security with TSG:
                                         (TSG in)
                              > SSL BIO => buffered BIO ==> socket BIO
                             /
FreeRdp => SSL BIO => TSG BIO
                             \
                              > SSL BIO => buffered BIO ==> socket BIO
                                        (TSG out)

So from the FreeRDP point of view sending something is only BIO_writing
on the frontBio (last BIO on the left).

20 files changed:
include/freerdp/crypto/tls.h
include/freerdp/peer.h
include/freerdp/settings.h
libfreerdp/core/gateway/http.c
libfreerdp/core/gateway/http.h
libfreerdp/core/gateway/ncacn_http.c
libfreerdp/core/gateway/rpc.c
libfreerdp/core/gateway/rpc.h
libfreerdp/core/gateway/rpc_bind.c
libfreerdp/core/gateway/rpc_client.c
libfreerdp/core/gateway/rts.c
libfreerdp/core/gateway/rts_signature.c
libfreerdp/core/gateway/tsg.c
libfreerdp/core/peer.c
libfreerdp/core/settings.c
libfreerdp/core/tcp.c
libfreerdp/core/tcp.h
libfreerdp/core/transport.c
libfreerdp/core/transport.h
libfreerdp/crypto/tls.c

index bf55213..180007e 100644 (file)
@@ -70,7 +70,6 @@ struct rdp_tls
        SSL* ssl;
        BIO* bio;
        void* tsg;
-       int sockfd;
        SSL_CTX* ctx;
        BYTE* PublicKey;
        BIO_METHOD* methods;
@@ -84,17 +83,11 @@ struct rdp_tls
        int alertDescription;
 };
 
-FREERDP_API int tls_connect(rdpTls* tls);
-FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file);
+FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying);
+FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file);
 FREERDP_API BOOL tls_disconnect(rdpTls* tls);
 
-FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length);
-FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length);
-
-FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length);
-
-FREERDP_API int tls_wait_read(rdpTls* tls);
-FREERDP_API int tls_wait_write(rdpTls* tls);
+FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length);
 
 FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description);
 
index c89d37a..4fbe75b 100644 (file)
@@ -34,7 +34,10 @@ typedef void (*psPeerContextFree)(freerdp_peer* client, rdpContext* context);
 typedef BOOL (*psPeerInitialize)(freerdp_peer* client);
 typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount);
 typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client);
+typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client);
 typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client);
+typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client);
+typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client);
 typedef BOOL (*psPeerClose)(freerdp_peer* client);
 typedef void (*psPeerDisconnect)(freerdp_peer* client);
 typedef BOOL (*psPeerCapabilities)(freerdp_peer* client);
@@ -62,6 +65,7 @@ struct rdp_freerdp_peer
        psPeerInitialize Initialize;
        psPeerGetFileDescriptor GetFileDescriptor;
        psPeerGetEventHandle GetEventHandle;
+       psPeerGetReceiveEventHandle GetReceiveEventHandle;
        psPeerCheckFileDescriptor CheckFileDescriptor;
        psPeerClose Close;
        psPeerDisconnect Disconnect;
@@ -81,6 +85,9 @@ struct rdp_freerdp_peer
        BOOL activated;
        BOOL authenticated;
        SEC_WINNT_AUTH_IDENTITY identity;
+
+       psPeerIsWriteBlocked IsWriteBlocked;
+       psPeerDrainOutputBuffer DrainOutputBuffer;
 };
 
 #ifdef __cplusplus
index 6e921eb..dab7875 100644 (file)
@@ -798,7 +798,8 @@ struct rdp_settings
        ALIGN64 char* Password; /* 22 */
        ALIGN64 char* Domain; /* 23 */
        ALIGN64 char* PasswordHash; /* 24 */
-       UINT64 padding0064[64 - 25]; /* 25 */
+       ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */
+       UINT64 padding0064[64 - 26]; /* 26 */
        UINT64 padding0128[128 - 64]; /* 64 */
 
        /**
index c9f33f0..610b230 100644 (file)
 #include <winpr/stream.h>
 #include <winpr/string.h>
 
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+#include <valgrind/memcheck.h>
+#endif
+
 #include "http.h"
 
 HttpContext* http_context_new()
@@ -472,7 +476,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
        nbytes = 0;
        length = 10000;
        content = NULL;
-       buffer = malloc(length);
+       buffer = calloc(length, 1);
        if (!buffer)
                return NULL;
 
@@ -487,14 +491,20 @@ HttpResponse* http_response_recv(rdpTls* tls)
        {
                while (nbytes < 5)
                {
-                       status = tls_read(tls, p, length - nbytes);
+                       status = BIO_read(tls->bio, p, length - nbytes);
             
-                       if (status < 0)
-                               goto out_error;
+                       if (status <= 0)
+                       {
+                               if (!BIO_should_retry(tls->bio))
+                                       goto out_error;
 
-                       if (!status)
+                               USleep(100);
                                continue;
+                       }
 
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+                       VALGRIND_MAKE_MEM_DEFINED(p, status);
+#endif
                        nbytes += status;
                        p = (BYTE*) &buffer[nbytes];
                }
@@ -503,7 +513,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
         
                if (!header_end)
                {
-                       fprintf(stderr, "http_response_recv: invalid response:\n");
+                       fprintf(stderr, "%s: invalid response:\n", __FUNCTION__);
                        winpr_HexDump(buffer, status);
                        goto out_error;
                }
@@ -517,7 +527,7 @@ HttpResponse* http_response_recv(rdpTls* tls)
 
                        header_end[0] = '\0';
                        header_end[1] = '\0';
-                       content = &header_end[2];
+                       content = header_end + 2;
 
                        count = 0;
                        line = (char*) buffer;
@@ -552,11 +562,14 @@ HttpResponse* http_response_recv(rdpTls* tls)
                        if (!http_response_parse_header(http_response))
                                goto out_error;
 
-                       if (http_response->ContentLength > 0)
+                       http_response->bodyLen = nbytes - (content - (char *)buffer);
+                       if (http_response->bodyLen > 0)
                        {
-                               http_response->Content = _strdup(content);
-                               if (!http_response->Content)
+                               http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen);
+                               if (!http_response->BodyContent)
                                        goto out_error;
+
+                               CopyMemory(http_response->BodyContent, content, http_response->bodyLen);
                        }
 
                        break;
@@ -627,7 +640,7 @@ void http_response_free(HttpResponse* http_response)
        ListDictionary_Free(http_response->Authenticates);
 
        if (http_response->ContentLength > 0)
-               free(http_response->Content);
+               free(http_response->BodyContent);
 
        free(http_response);
 }
index 748b45a..ded9ba2 100644 (file)
@@ -84,7 +84,8 @@ struct _http_response
 
        wListDictionary *Authenticates;
        int ContentLength;
-       char* Content;
+       BYTE *BodyContent;
+       int bodyLen;
 };
 
 void http_response_print(HttpResponse* http_response);
index 270dafb..b5beff4 100644 (file)
@@ -98,6 +98,8 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
        rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm;
 
        http_response = http_response_recv(rpc->TlsIn);
+       if (!http_response)
+               return -1;
 
        if (ListDictionary_Contains(http_response->Authenticates, "NTLM"))
        {
@@ -105,14 +107,12 @@ int rpc_ncacn_http_recv_in_channel_response(rdpRpc* rpc)
                if (!token64)
                        goto out;
 
-               ntlm_token_data = NULL;
                crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
        }
 
+out:
        ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
        ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
-
-out:
        http_response_free(http_response);
 
        return 0;
@@ -123,25 +123,19 @@ int rpc_ncacn_http_ntlm_init(rdpRpc* rpc, TSG_CHANNEL channel)
        rdpNtlm* ntlm = NULL;
        rdpSettings* settings = rpc->settings;
        freerdp* instance = (freerdp*) rpc->settings->instance;
-       BOOL promptPassword = FALSE;
 
        if (channel == TSG_CHANNEL_IN)
                ntlm = rpc->NtlmHttpIn->ntlm;
        else if (channel == TSG_CHANNEL_OUT)
                ntlm = rpc->NtlmHttpOut->ntlm;
 
-       if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
-                       || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
-       {
-               promptPassword = TRUE;
-       }
-
-       if (promptPassword)
+       if (!settings->GatewayPassword || !settings->GatewayUsername ||
+                       !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername))
        {
                if (instance->GatewayAuthenticate)
                {
-                       BOOL proceed = instance->GatewayAuthenticate(instance,
-                                       &settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain);
+                       BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername,
+                                                                               &settings->GatewayPassword, &settings->GatewayDomain);
 
                        if (!proceed)
                        {
@@ -240,12 +234,10 @@ int rpc_ncacn_http_recv_out_channel_response(rdpRpc* rpc)
                char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM");
                crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
        }
-
        ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
        ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
-
+       
        http_response_free(http_response);
-
        return 0;
 }
 
@@ -259,15 +251,12 @@ BOOL rpc_ntlm_http_out_connect(rdpRpc* rpc)
                success = TRUE;
 
                /* Send OUT Channel Request */
-
                rpc_ncacn_http_send_out_channel_request(rpc);
 
                /* Receive OUT Channel Response */
-
                rpc_ncacn_http_recv_out_channel_response(rpc);
 
                /* Send OUT Channel Request */
-
                rpc_ncacn_http_send_out_channel_request(rpc);
 
                ntlm_client_uninit(ntlm);
@@ -296,13 +285,11 @@ void rpc_ntlm_http_init_channel(rdpRpc* rpc, rdpNtlmHttp* ntlm_http, TSG_CHANNEL
 
        if (channel == TSG_CHANNEL_IN)
        {
-               http_context_set_pragma(ntlm_http->context,
-                       "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
+               http_context_set_pragma(ntlm_http->context,     "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
        }
        else if (channel == TSG_CHANNEL_OUT)
        {
-               http_context_set_pragma(ntlm_http->context,
-                               "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", "
+               http_context_set_pragma(ntlm_http->context,     "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, "
                                "SessionId=fbd9c34f-397d-471d-a109-1b08cc554624");
        }
 }
index c91a710..2432ab0 100644 (file)
 #include <winpr/dsparse.h>
 
 #include <openssl/rand.h>
+#include <openssl/bio.h>
+
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+#include <valgrind/memcheck.h>
+#endif
 
 #include "http.h"
 #include "ntlm.h"
@@ -235,80 +240,77 @@ BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* buffer, UINT32* offset, UINT32* l
 {
        UINT32 alloc_hint = 0;
        rpcconn_hdr_t* header;
+       UINT32 frag_length;
+       UINT32 auth_length;
+       UINT32 auth_pad_length;
+       UINT32 sec_trailer_offset;
+       rpc_sec_trailer* sec_trailer;
 
        *offset = RPC_COMMON_FIELDS_LENGTH;
        header = ((rpcconn_hdr_t*) buffer);
 
-       if (header->common.ptype == PTYPE_RESPONSE)
-       {
-               *offset += 8;
-               rpc_offset_align(offset, 8);
-               alloc_hint = header->response.alloc_hint;
-       }
-       else if (header->common.ptype == PTYPE_REQUEST)
-       {
-               *offset += 4;
-               rpc_offset_align(offset, 8);
-               alloc_hint = header->request.alloc_hint;
-       }
-       else if (header->common.ptype == PTYPE_RTS)
-       {
-               *offset += 4;
-       }
-       else
+       switch (header->common.ptype)
        {
-               return FALSE;
+               case PTYPE_RESPONSE:
+                       *offset += 8;
+                       rpc_offset_align(offset, 8);
+                       alloc_hint = header->response.alloc_hint;
+                       break;
+               case PTYPE_REQUEST:
+                       *offset += 4;
+                       rpc_offset_align(offset, 8);
+                       alloc_hint = header->request.alloc_hint;
+                       break;
+               case PTYPE_RTS:
+                       *offset += 4;
+                       break;
+               default:
+                       fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype);
+                       return FALSE;
        }
 
-       if (length)
+       if (!length)
+               return TRUE;
+
+       if (header->common.ptype == PTYPE_REQUEST)
        {
-               if (header->common.ptype == PTYPE_REQUEST)
-               {
-                       UINT32 sec_trailer_offset;
+               UINT32 sec_trailer_offset;
 
-                       sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
-                       *length = sec_trailer_offset - *offset;
-               }
-               else
-               {
-                       UINT32 frag_length;
-                       UINT32 auth_length;
-                       UINT32 auth_pad_length;
-                       UINT32 sec_trailer_offset;
-                       rpc_sec_trailer* sec_trailer;
+               sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
+               *length = sec_trailer_offset - *offset;
+               return TRUE;
+       }
 
-                       frag_length = header->common.frag_length;
-                       auth_length = header->common.auth_length;
 
-                       sec_trailer_offset = frag_length - auth_length - 8;
-                       sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
-                       auth_pad_length = sec_trailer->auth_pad_length;
+       frag_length = header->common.frag_length;
+       auth_length = header->common.auth_length;
+
+       sec_trailer_offset = frag_length - auth_length - 8;
+       sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
+       auth_pad_length = sec_trailer->auth_pad_length;
 
 #if 0
-                       fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
-                                       sec_trailer->auth_type,
-                                       sec_trailer->auth_level,
-                                       sec_trailer->auth_pad_length,
-                                       sec_trailer->auth_reserved,
-                                       sec_trailer->auth_context_id);
+       fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
+                       sec_trailer->auth_type,
+                       sec_trailer->auth_level,
+                       sec_trailer->auth_pad_length,
+                       sec_trailer->auth_reserved,
+                       sec_trailer->auth_context_id);
 #endif
 
-                       /**
-                        * According to [MS-RPCE], auth_pad_length is the number of padding
-                        * octets used to 4-byte align the security trailer, but in practice
-                        * we get values up to 15, which indicates 16-byte alignment.
-                        */
-
-                       if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
-                       {
-                               fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
-                                               (frag_length - (sec_trailer_offset + 8)));
-                       }
+       /**
+        * According to [MS-RPCE], auth_pad_length is the number of padding
+        * octets used to 4-byte align the security trailer, but in practice
+        * we get values up to 15, which indicates 16-byte alignment.
+        */
 
-                       *length = frag_length - auth_length - 24 - 8 - auth_pad_length;
-               }
+       if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
+       {
+               fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
+                               (frag_length - (sec_trailer_offset + 8)));
        }
 
+       *length = frag_length - auth_length - 24 - 8 - auth_pad_length;
        return TRUE;
 }
 
@@ -316,12 +318,23 @@ int rpc_out_read(rdpRpc* rpc, BYTE* data, int length)
 {
        int status;
 
-       status = tls_read(rpc->TlsOut, data, length);
+       status = BIO_read(rpc->TlsOut->bio, data, length);
+       /* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length,
+        *                                      status, BIO_should_retry(rpc->TlsOut->bio)); */
+       if (status > 0) {
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+               VALGRIND_MAKE_MEM_DEFINED(data, status);
+#endif
+               return status;
+       }
 
-       return status;
+       if (BIO_should_retry(rpc->TlsOut->bio))
+               return 0;
+
+       return -1;
 }
 
-int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
+int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length)
 {
        int status;
 
@@ -330,7 +343,7 @@ int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
        return status;
 }
 
-int rpc_in_write(rdpRpc* rpc, BYTE* data, int length)
+int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length)
 {
        int status;
 
@@ -360,20 +373,21 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
 
        ntlm = rpc->ntlm;
 
-       if ((!ntlm) || (!ntlm->table))
+       if (!ntlm || !ntlm->table)
        {
-               fprintf(stderr, "rpc_write: invalid ntlm context\n");
+               fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__);
                return -1;
        }
 
        if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK)
        {
-               fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n");
+               fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__);
                return -1;
        }
 
-       request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t));
-       ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t));
+       request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t));
+       if (!request_pdu)
+               return -1;
 
        rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu);
 
@@ -386,7 +400,11 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
        request_pdu->opnum = opnum;
 
        clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
-       ArrayList_Add(rpc->client->ClientCallList, clientCall);
+       if (!clientCall)
+               goto out_free_pdu;
+
+       if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
+               goto out_free_clientCall;
 
        if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
                rpc->PipeCallId = request_pdu->call_id;
@@ -407,8 +425,9 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
 
        request_pdu->frag_length = offset;
 
-       buffer = (BYTE*) malloc(request_pdu->frag_length);
-
+       buffer = (BYTE*) calloc(1, request_pdu->frag_length);
+       if (!buffer)
+               goto out_free_pdu;
        CopyMemory(buffer, request_pdu, 24);
 
        offset = 24;
@@ -427,15 +446,15 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
        Buffers[0].cbBuffer = offset;
 
        Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature;
-       Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer);
-       ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer);
+       Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer);
+       if (!Buffers[1].pvBuffer)
+               return -1;
 
        Message.cBuffers = 2;
        Message.ulVersion = SECBUFFER_VERSION;
        Message.pBuffers = (PSecBuffer) &Buffers;
 
        encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++);
-
        if (encrypt_status != SEC_E_OK)
        {
                fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status);
@@ -447,12 +466,18 @@ int rpc_write(rdpRpc* rpc, BYTE* data, int length, UINT16 opnum)
        offset += Buffers[1].cbBuffer;
        free(Buffers[1].pvBuffer);
 
-       if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0)
+       if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0)
                length = -1;
 
        free(request_pdu);
 
        return length;
+
+out_free_clientCall:
+       rpc_client_call_free(clientCall);
+out_free_pdu:
+       free(request_pdu);
+       return -1;
 }
 
 BOOL rpc_connect(rdpRpc* rpc)
@@ -592,13 +617,17 @@ rdpRpc* rpc_new(rdpTransport* transport)
 
        rpc->CallId = 2;
 
-       rpc_client_new(rpc);
+       if (rpc_client_new(rpc) < 0)
+               goto out_free_virtualConnectionCookieTable;
 
        rpc->client->SynchronousSend = TRUE;
        rpc->client->SynchronousReceive = TRUE;
 
        return rpc;
 
+out_free_virtualConnectionCookieTable:
+       rpc_client_free(rpc);
+       ArrayList_Free(rpc->VirtualConnectionCookieTable);
 out_free_virtual_connection:
        rpc_client_virtual_connection_free(rpc->VirtualConnection);
 out_free_ntlm_http_out:
index d10d665..c86a861 100644 (file)
@@ -772,8 +772,8 @@ UINT32 rpc_offset_pad(UINT32* offset, UINT32 pad);
 
 int rpc_out_read(rdpRpc* rpc, BYTE* data, int length);
 
-int rpc_out_write(rdpRpc* rpc, BYTE* data, int length);
-int rpc_in_write(rdpRpc* rpc, BYTE* data, int length);
+int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length);
+int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length);
 
 BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length);
 
index cf02a80..ceae951 100644 (file)
@@ -103,6 +103,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
        DEBUG_RPC("Sending bind PDU");
 
        rpc->ntlm = ntlm_new();
+       if (!rpc->ntlm)
+               return -1;
 
        if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
                        || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
@@ -129,17 +131,22 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
                                settings->Username = _strdup(settings->GatewayUsername);
                                settings->Domain = _strdup(settings->GatewayDomain);
                                settings->Password = _strdup(settings->GatewayPassword);
+
+                               if (!settings->Username || !settings->Domain || settings->Password)
+                                       return -1;
                        }
                }
        }
 
-       ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL);
-       ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname);
-
-       ntlm_authenticate(rpc->ntlm);
+       if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) ||
+               !ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) ||
+               !ntlm_authenticate(rpc->ntlm)
+               )
+               return -1;
 
-       bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t));
-       ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t));
+       bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t));
+       if (!bind_pdu)
+               return -1;
 
        rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu);
 
@@ -159,6 +166,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
        bind_pdu->p_context_elem.reserved2 = 0;
 
        bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem);
+       if (!bind_pdu->p_context_elem.p_cont_elem)
+               return -1;
 
        p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0];
 
@@ -196,6 +205,8 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
        bind_pdu->frag_length = offset;
 
        buffer = (BYTE*) malloc(bind_pdu->frag_length);
+       if (!buffer)
+               return -1;
 
        CopyMemory(buffer, bind_pdu, 24);
        CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4);
@@ -214,7 +225,10 @@ int rpc_send_bind_pdu(rdpRpc* rpc)
        length = bind_pdu->frag_length;
 
        clientCall = rpc_client_call_new(bind_pdu->call_id, 0);
-       ArrayList_Add(rpc->client->ClientCallList, clientCall);
+       if (!clientCall)
+               return -1;
+       if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
+               return -1;
 
        if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0)
                length = -1;
index dff88b3..c3613f6 100644 (file)
@@ -34,9 +34,7 @@
 #include <winpr/stream.h>
 
 #include "rpc_fault.h"
-
 #include "rpc_client.h"
-
 #include "../rdp.h"
 
 #define SYNCHRONOUS_TIMEOUT 5000
@@ -69,8 +67,15 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
 
        if (!pdu)
        {
-               pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
+               pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU));
+               if (!pdu)
+                       return NULL;
                pdu->s = Stream_New(NULL, rpc->max_recv_frag);
+               if (!pdu->s)
+               {
+                       free(pdu);
+                       return NULL;
+               }
        }
 
        pdu->CallId = 0;
@@ -84,8 +89,7 @@ RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
 
 int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
 {
-       Queue_Enqueue(rpc->client->ReceivePool, pdu);
-       return 0;
+       return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
 }
 
 int rpc_client_on_fragment_received_event(rdpRpc* rpc)
@@ -97,7 +101,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
        rpcconn_hdr_t* header;
        freerdp* instance;
 
-       instance = (freerdp*) rpc->transport->settings->instance;
+       instance = (freerdp *)rpc->transport->settings->instance;
 
        if (!rpc->client->pdu)
                rpc->client->pdu = rpc_client_receive_pool_take(rpc);
@@ -125,34 +129,29 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
                return 0;
        }
 
-       if (header->common.ptype == PTYPE_RTS)
+       switch (header->common.ptype)
        {
-               if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED)
-               {
-                       //fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n");
+               case PTYPE_RTS:
+                       if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
+                       {
+                               fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__);
+                               return 0;
+                       }
+                       fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
                        rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
-
                        rpc_client_fragment_pool_return(rpc, fragment);
-               }
-               else
-               {
-                       fprintf(stderr, "warning: unhandled RTS PDU\n");
-               }
+                       return 0;
 
-               return 0;
-       }
-       else if (header->common.ptype == PTYPE_FAULT)
-       {
-               rpc_recv_fault_pdu(header);
-               Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
-               return -1;
-       }
-
-       if (header->common.ptype != PTYPE_RESPONSE)
-       {
-               fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype);
-               Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
-               return -1;
+               case PTYPE_FAULT:
+                       rpc_recv_fault_pdu(header);
+                       Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
+                       return -1;
+               case PTYPE_RESPONSE:
+                       break;
+               default:
+                       fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype);
+                       Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
+                       return -1;
        }
 
        rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
@@ -160,7 +159,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
 
        if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
        {
-               fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n");
+               fprintf(stderr, "%s: expected stub\n", __FUNCTION__);
                Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
                return -1;
        }
@@ -196,7 +195,7 @@ int rpc_client_on_fragment_received_event(rdpRpc* rpc)
 
        if (rpc->StubCallId != header->common.call_id)
        {
-               fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n",
+               fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__,
                                rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
        }
 
@@ -243,27 +242,34 @@ int rpc_client_on_read_event(rdpRpc* rpc)
        int status = -1;
        rpcconn_common_hdr_t* header;
 
-       if (!rpc->client->RecvFrag)
-               rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
-
-       position = Stream_GetPosition(rpc->client->RecvFrag);
-
-       if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
+       while (1)
        {
-               status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
-                               RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
+               if (!rpc->client->RecvFrag)
+                       rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
 
-               if (status < 0)
+               position = Stream_GetPosition(rpc->client->RecvFrag);
+
+               while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
                {
-                       fprintf(stderr, "rpc_client_frag_read: error reading header\n");
-                       return -1;
+                       status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
+                                       RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
+
+                       if (status < 0)
+                       {
+                               fprintf(stderr, "rpc_client_frag_read: error reading header\n");
+                               return -1;
+                       }
+
+                       if (!status)
+                               return 0;
+
+                       Stream_Seek(rpc->client->RecvFrag, status);
                }
 
-               Stream_Seek(rpc->client->RecvFrag, status);
-       }
+               if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
+                       return status;
+
 
-       if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH)
-       {
                header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag);
 
                if (header->frag_length > rpc->max_recv_frag)
@@ -274,45 +280,44 @@ int rpc_client_on_read_event(rdpRpc* rpc)
                        return -1;
                }
 
-               if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
+               while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
                {
                        status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
                                        header->frag_length - Stream_GetPosition(rpc->client->RecvFrag));
 
                        if (status < 0)
                        {
-                               fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n");
+                               fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__);
                                return -1;
                        }
 
+                       if (!status)
+                               return 0;
+
                        Stream_Seek(rpc->client->RecvFrag, status);
                }
-       }
-       else
-       {
-               return status;
-       }
 
-       if (status < 0)
-               return -1;
+               if (status < 0)
+                       return -1;
 
-       status = Stream_GetPosition(rpc->client->RecvFrag) - position;
+               status = Stream_GetPosition(rpc->client->RecvFrag) - position;
 
-       if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
-       {
-               /* complete fragment received */
+               if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
+               {
+                       /* complete fragment received */
 
-               Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
-               Stream_SetPosition(rpc->client->RecvFrag, 0);
+                       Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
+                       Stream_SetPosition(rpc->client->RecvFrag, 0);
 
-               Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
-               rpc->client->RecvFrag = NULL;
+                       Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
+                       rpc->client->RecvFrag = NULL;
 
-               if (rpc_client_on_fragment_received_event(rpc) < 0)
-                       return -1;
+                       if (rpc_client_on_fragment_received_event(rpc) < 0)
+                               return -1;
+               }
        }
 
-       return status;
+       return 0;
 }
 
 /**
@@ -349,13 +354,12 @@ RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum)
        RpcClientCall* clientCall;
 
        clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall));
+       if (!clientCall)
+               return NULL;
 
-       if (clientCall)
-       {
-               clientCall->CallId = CallId;
-               clientCall->OpNum = OpNum;
-               clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
-       }
+       clientCall->CallId = CallId;
+       clientCall->OpNum = OpNum;
+       clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
 
        return clientCall;
 }
@@ -371,16 +375,22 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
        int status;
 
        pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
+       if (!pdu)
+               return -1;
+
        pdu->s = Stream_New(buffer, length);
+       if (!pdu->s)
+               goto out_free;
 
-       Queue_Enqueue(rpc->client->SendQueue, pdu);
+       if (!Queue_Enqueue(rpc->client->SendQueue, pdu))
+               goto out_free_stream;
 
        if (rpc->client->SynchronousSend)
        {
                status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT);
                if (status == WAIT_TIMEOUT)
                {
-                       fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n");
+                       fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent);
                        return -1;
                }
 
@@ -388,6 +398,12 @@ int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
        }
 
        return 0;
+
+out_free_stream:
+       Stream_Free(pdu->s, TRUE);
+out_free:
+       free(pdu);
+       return -1;
 }
 
 int rpc_send_dequeue_pdu(rdpRpc* rpc)
@@ -396,13 +412,14 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
        RPC_PDU* pdu;
        RpcClientCall* clientCall;
        rpcconn_common_hdr_t* header;
+       RpcInChannel *inChannel;
 
        pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue);
-
        if (!pdu)
                return 0;
 
-       WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE);
+       inChannel = rpc->VirtualConnection->DefaultInChannel;
+       WaitForSingleObject(inChannel->Mutex, INFINITE);
 
        status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
 
@@ -410,7 +427,7 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
        clientCall = rpc_client_call_find_by_id(rpc, header->call_id);
        clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
 
-       ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex);
+       ReleaseMutex(inChannel->Mutex);
 
        /*
         * This protocol specifies that only RPC PDUs are subject to the flow control abstract
@@ -421,8 +438,8 @@ int rpc_send_dequeue_pdu(rdpRpc* rpc)
 
        if (header->ptype == PTYPE_REQUEST)
        {
-               rpc->VirtualConnection->DefaultInChannel->BytesSent += status;
-               rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status;
+               inChannel->BytesSent += status;
+               inChannel->SenderAvailableWindow -= status;
        }
 
        Stream_Free(pdu->s, TRUE);
@@ -440,57 +457,48 @@ RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc)
        DWORD dwMilliseconds;
        DWORD result;
 
-       pdu = NULL;
-       dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
+       dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
 
        result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
        if (result == WAIT_TIMEOUT)
        {
-               fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n");
+               fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__);
                return NULL;
        }
 
-       if (result == WAIT_OBJECT_0)
-       {
-               pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue);
+       if (result != WAIT_OBJECT_0)
+               return NULL;
 
-#ifdef WITH_DEBUG_TSG
-               if (pdu)
-               {
-                       fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
-                       winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
-                       fprintf(stderr, "\n");
-               }
-#endif
+       pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue);
 
-               return pdu;
+#ifdef WITH_DEBUG_TSG
+       if (pdu)
+       {
+               fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
+               winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
+               fprintf(stderr, "\n");
+       }
+       else
+       {
+               fprintf(stderr, "Receiving a NULL PDU\n");
        }
+#endif
 
        return pdu;
 }
 
 RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc)
 {
-       RPC_PDU* pdu;
        DWORD dwMilliseconds;
        DWORD result;
 
-       pdu = NULL;
        dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
 
        result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
-       if (result == WAIT_TIMEOUT)
-       {
+       if (result != WAIT_OBJECT_0)
                return NULL;
-       }
-
-       if (result == WAIT_OBJECT_0)
-       {
-               pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
-               return pdu;
-       }
 
-       return pdu;
+       return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue);
 }
 
 static void* rpc_client_thread(void* arg)
@@ -500,40 +508,52 @@ static void* rpc_client_thread(void* arg)
        DWORD nCount;
        HANDLE events[3];
        HANDLE ReadEvent;
+       int fd;
 
        rpc = (rdpRpc*) arg;
+       fd = BIO_get_fd(rpc->TlsOut->bio, NULL);
 
-       ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd);
+       ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd);
 
        nCount = 0;
        events[nCount++] = rpc->client->StopEvent;
        events[nCount++] = Queue_Event(rpc->client->SendQueue);
        events[nCount++] = ReadEvent;
 
+       /* Do a first free run in case some bytes were set from the HTTP headers.
+        * We also have to do it because most of the time the underlying socket has notified,
+        * and the ssl layer has eaten all bytes, so we won't be notified any more even if the
+        * bytes are buffered locally
+        */
+       if (rpc_client_on_read_event(rpc) < 0)
+       {
+               fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__);
+               goto out;
+       }
+
        while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED)
        {
                status = WaitForMultipleObjects(nCount, events, FALSE, 100);
 
-               if (status != WAIT_TIMEOUT)
+               if (status == WAIT_TIMEOUT)
+                       continue;
+
+               if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
+                       break;
+
+               if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
                {
-                       if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
-                       {
+                       if (rpc_client_on_read_event(rpc) < 0)
                                break;
-                       }
-
-                       if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
-                       {
-                               if (rpc_client_on_read_event(rpc) < 0)
-                                       break;
-                       }
+               }
 
-                       if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
-                       {
-                               rpc_send_dequeue_pdu(rpc);
-                       }
+               if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
+               {
+                       rpc_send_dequeue_pdu(rpc);
                }
        }
 
+out:
        CloseHandle(ReadEvent);
 
        return NULL;
@@ -541,6 +561,9 @@ static void* rpc_client_thread(void* arg)
 
 static void rpc_pdu_free(RPC_PDU* pdu)
 {
+       if (!pdu)
+               return;
+
        Stream_Free(pdu->s, TRUE);
        free(pdu);
 }
@@ -554,35 +577,55 @@ int rpc_client_new(rdpRpc* rpc)
 {
        RpcClient* client = NULL;
 
-       client = (RpcClient*) calloc(1, sizeof(RpcClient));
+       client = (RpcClient *)calloc(1, sizeof(RpcClient));
+       rpc->client = client;
+       if (!client)
+               return -1;
 
-       if (client)
-       {
-               client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
-               client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       client->Thread = CreateThread(NULL, 0,
+                       (LPTHREAD_START_ROUTINE) rpc_client_thread,
+                       rpc, CREATE_SUSPENDED, NULL);
+       if (!client->Thread)
+               return -1;
 
-               client->SendQueue = Queue_New(TRUE, -1, -1);
-               Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
+       client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       if (!client->StopEvent)
+               return -1;
+       client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       if (!client->PduSentEvent)
+               return -1;
 
-               client->pdu = NULL;
-               client->ReceivePool = Queue_New(TRUE, -1, -1);
-               client->ReceiveQueue = Queue_New(TRUE, -1, -1);
-               Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
-               Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
+       client->SendQueue = Queue_New(TRUE, -1, -1);
+       if (!client->SendQueue)
+               return -1;
+       Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
 
-               client->RecvFrag = NULL;
-               client->FragmentPool = Queue_New(TRUE, -1, -1);
-               client->FragmentQueue = Queue_New(TRUE, -1, -1);
+       client->pdu = NULL;
+       client->ReceivePool = Queue_New(TRUE, -1, -1);
+       if (!client->ReceivePool)
+               return -1;
+       Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
 
-               Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
-               Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
+       client->ReceiveQueue = Queue_New(TRUE, -1, -1);
+       if (!client->ReceiveQueue)
+               return -1;
+       Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
 
-               client->ClientCallList = ArrayList_New(TRUE);
-               ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
-       }
+       client->RecvFrag = NULL;
+       client->FragmentPool = Queue_New(TRUE, -1, -1);
+       if (!client->FragmentPool)
+               return -1;
+       Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
 
-       rpc->client = client;
+       client->FragmentQueue = Queue_New(TRUE, -1, -1);
+       if (!client->FragmentQueue)
+               return -1;
+       Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
 
+       client->ClientCallList = ArrayList_New(TRUE);
+       if (!client->ClientCallList)
+               return -1;
+       ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
        return 0;
 }
 
@@ -604,9 +647,7 @@ int rpc_client_stop(rdpRpc* rpc)
                rpc->client->Thread = NULL;
        }
 
-       rpc_client_free(rpc);
-
-       return 0;
+       return rpc_client_free(rpc);
 }
 
 int rpc_client_free(rdpRpc* rpc)
@@ -615,31 +656,39 @@ int rpc_client_free(rdpRpc* rpc)
 
        client = rpc->client;
 
-       if (client)
-       {
+       if (!client)
+               return 0;
+
+       if (client->SendQueue)
                Queue_Free(client->SendQueue);
 
-               if (client->RecvFrag)
-                       rpc_fragment_free(client->RecvFrag);
+       if (client->RecvFrag)
+               rpc_fragment_free(client->RecvFrag);
 
+       if (client->FragmentPool)
                Queue_Free(client->FragmentPool);
+       if (client->FragmentQueue)
                Queue_Free(client->FragmentQueue);
 
-               if (client->pdu)
-                       rpc_pdu_free(client->pdu);
+       if (client->pdu)
+               rpc_pdu_free(client->pdu);
 
+       if (client->ReceivePool)
                Queue_Free(client->ReceivePool);
+       if (client->ReceiveQueue)
                Queue_Free(client->ReceiveQueue);
 
+       if (client->ClientCallList)
                ArrayList_Free(client->ClientCallList);
 
+       if (client->StopEvent)
                CloseHandle(client->StopEvent);
+       if (client->PduSentEvent)
                CloseHandle(client->PduSentEvent);
 
+       if (client->Thread)
                CloseHandle(client->Thread);
 
-               free(client);
-       }
-
+       free(client);
        return 0;
 }
index 42ce2ad..d57a424 100644 (file)
@@ -93,25 +93,25 @@ BOOL rts_connect(rdpRpc* rpc)
 
        if (!rpc_ntlm_http_out_connect(rpc))
        {
-               fprintf(stderr, "rpc_out_connect_http error!\n");
+               fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__);
                return FALSE;
        }
 
        if (rts_send_CONN_A1_pdu(rpc) != 0)
        {
-               fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n");
+               fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__);
                return FALSE;
        }
 
        if (!rpc_ntlm_http_in_connect(rpc))
        {
-               fprintf(stderr, "rpc_in_connect_http error!\n");
+               fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__);
                return FALSE;
        }
 
-       if (rts_send_CONN_B1_pdu(rpc) != 0)
+       if (rts_send_CONN_B1_pdu(rpc) < 0)
        {
-               fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n");
+               fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__);
                return FALSE;
        }
 
@@ -147,10 +147,15 @@ BOOL rts_connect(rdpRpc* rpc)
         */
 
        http_response = http_response_recv(rpc->TlsOut);
+       if (!http_response)
+       {
+               fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__);
+               return FALSE;
+       }
 
        if (http_response->StatusCode != HTTP_STATUS_OK)
        {
-               fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode);
+               fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode);
                http_response_print(http_response);
                http_response_free(http_response);
 
@@ -170,6 +175,14 @@ BOOL rts_connect(rdpRpc* rpc)
                return FALSE;
        }
 
+       if (http_response->bodyLen)
+       {
+               /* inject bytes we have read in the body as a received packet for the RPC client */
+               rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
+               Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen);
+               CopyMemory(rpc->client->RecvFrag, http_response->BodyContent,  http_response->bodyLen);
+       }
+
        //http_response_print(http_response);
        http_response_free(http_response);
 
@@ -195,7 +208,6 @@ BOOL rts_connect(rdpRpc* rpc)
        rpc_client_start(rpc);
 
        pdu = rpc_recv_dequeue_pdu(rpc);
-
        if (!pdu)
                return FALSE;
 
@@ -203,7 +215,7 @@ BOOL rts_connect(rdpRpc* rpc)
 
        if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
        {
-               fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n");
+               fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__);
                return FALSE;
        }
 
@@ -236,7 +248,6 @@ BOOL rts_connect(rdpRpc* rpc)
         */
 
        pdu = rpc_recv_dequeue_pdu(rpc);
-
        if (!pdu)
                return FALSE;
 
@@ -244,7 +255,7 @@ BOOL rts_connect(rdpRpc* rpc)
 
        if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
        {
-               fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n");
+               fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__);
                return FALSE;
        }
 
@@ -261,7 +272,7 @@ BOOL rts_connect(rdpRpc* rpc)
        return TRUE;
 }
 
-#if defined WITH_DEBUG_RTS && 0
+#ifdef WITH_DEBUG_RTS
 
 static const char* const RTS_CMD_STRINGS[] =
 {
@@ -317,6 +328,7 @@ static const char* const RTS_CMD_STRINGS[] =
 
 void rts_pdu_header_init(rpcconn_rts_hdr_t* header)
 {
+       ZeroMemory(header, sizeof(*header));
        header->rpc_vers = 5;
        header->rpc_vers_minor = 0;
        header->ptype = PTYPE_RTS;
@@ -681,6 +693,8 @@ int rts_send_CONN_A1_pdu(rdpRpc* rpc)
        ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow;
 
        buffer = (BYTE*) malloc(header.frag_length);
+       if (!buffer)
+               return -1;
 
        CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
        rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@@ -718,6 +732,7 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
        BYTE* INChannelCookie;
        BYTE* AssociationGroupId;
        BYTE* VirtualConnectionCookie;
+       int status;
 
        rts_pdu_header_init(&header);
        header.frag_length = 104;
@@ -734,6 +749,8 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
        AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId);
 
        buffer = (BYTE*) malloc(header.frag_length);
+       if (!buffer)
+               return -1;
 
        CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
        rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
@@ -745,11 +762,11 @@ int rts_send_CONN_B1_pdu(rdpRpc* rpc)
 
        length = header.frag_length;
 
-       rpc_in_write(rpc, buffer, length);
+       status = rpc_in_write(rpc, buffer, length);
 
        free(buffer);
 
-       return 0;
+       return status;
 }
 
 /* CONN/C Sequence */
@@ -795,12 +812,15 @@ int rts_send_keep_alive_pdu(rdpRpc* rpc)
        DEBUG_RPC("Sending Keep-Alive RTS PDU");
 
        buffer = (BYTE*) malloc(header.frag_length);
+       if (!buffer)
+               return -1;
        CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
        rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */
 
        length = header.frag_length;
 
-       rpc_in_write(rpc, buffer, length);
+       if (rpc_in_write(rpc, buffer, length) < 0)
+               return -1;
        free(buffer);
 
        return length;
@@ -830,6 +850,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
                        rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised;
 
        buffer = (BYTE*) malloc(header.frag_length);
+       if (!buffer)
+               return -1;
 
        CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
        rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */
@@ -839,7 +861,8 @@ int rts_send_flow_control_ack_pdu(rdpRpc* rpc)
 
        length = header.frag_length;
 
-       rpc_in_write(rpc, buffer, length);
+       if (rpc_in_write(rpc, buffer, length) < 0)
+               return -1;
        free(buffer);
 
        return 0;
@@ -923,12 +946,15 @@ int rts_send_ping_pdu(rdpRpc* rpc)
        DEBUG_RPC("Sending Ping RTS PDU");
 
        buffer = (BYTE*) malloc(header.frag_length);
+       if (!buffer)
+               return -1;
 
        CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
 
        length = header.frag_length;
 
-       rpc_in_write(rpc, buffer, length);
+       if (rpc_in_write(rpc, buffer, length) < 0)
+               return -1;
        free(buffer);
 
        return length;
@@ -1020,22 +1046,18 @@ int rts_recv_out_of_sequence_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
        rts_extract_pdu_signature(rpc, &signature, rts);
        SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL);
 
-       if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK)
+       switch (SignatureId)
        {
-               return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
-       }
-       else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION)
-       {
-               return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
-       }
-       else if (SignatureId == RTS_PDU_PING)
-       {
-               rts_send_ping_pdu(rpc);
-       }
-       else
-       {
-               fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId);
-               rts_print_pdu_signature(rpc, &signature);
+               case RTS_PDU_FLOW_CONTROL_ACK:
+                       return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
+               case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION:
+                       return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
+               case RTS_PDU_PING:
+                       return rts_send_ping_pdu(rpc);
+               default:
+                       fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId);
+                       rts_print_pdu_signature(rpc, &signature);
+                       break;
        }
 
        return 0;
index 34598fe..47242ca 100644 (file)
@@ -234,7 +234,6 @@ BOOL rts_match_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_rt
                        return FALSE;
 
                status = rts_command_length(rpc, CommandType, &buffer[offset], length);
-
                if (status < 0)
                        return FALSE;
 
@@ -272,7 +271,6 @@ int rts_extract_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, rpcconn_r
                signature->CommandTypes[i] = CommandType;
 
                status = rts_command_length(rpc, CommandType, &buffer[offset], length);
-
                if (status < 0)
                        return FALSE;
 
@@ -294,22 +292,22 @@ UINT32 rts_identify_pdu_signature(rdpRpc* rpc, RtsPduSignature* signature, RTS_P
        {
                pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature;
 
-               if (signature->Flags == pSignature->Flags)
+               if (signature->Flags != pSignature->Flags)
+                       continue;
+
+               if (signature->NumberOfCommands != pSignature->NumberOfCommands)
+                       continue;
+
+               for (j = 0; j < signature->NumberOfCommands; j++)
                {
-                       if (signature->NumberOfCommands == pSignature->NumberOfCommands)
-                       {
-                               for (j = 0; j < signature->NumberOfCommands; j++)
-                               {
-                                       if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
-                                               continue;
-                               }
-
-                               if (entry)
-                                       *entry = &RTS_PDU_SIGNATURE_TABLE[i];
-
-                               return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
-                       }
+                       if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
+                               continue;
                }
+
+               if (entry)
+                       *entry = &RTS_PDU_SIGNATURE_TABLE[i];
+
+               return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
        }
 
        return 0;
index f130f73..5dd6888 100644 (file)
@@ -33,9 +33,9 @@
 #include <winpr/stream.h>
 
 #include "rpc_client.h"
-
 #include "tsg.h"
 
+
 /**
  * RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/
  * Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/
@@ -96,7 +96,9 @@ DWORD TsProxySendToServer(handle_t IDL_handle, byte pRpcMessage[], UINT32 count,
        }
 
        length = 28 + totalDataBytes;
-       buffer = (BYTE*) malloc(length);
+       buffer = (BYTE*) calloc(1, length);
+       if (!buffer)
+               return -1;
 
        s = Stream_New(buffer, length);
 
@@ -228,8 +230,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
                buffer = &buffer[24];
 
-       packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
-       ZeroMemory(packet, sizeof(TSG_PACKET));
+       packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
+       if (!packet)
+               return FALSE;
 
        offset = 4; // Skip Packet Pointer
        packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@@ -237,8 +240,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
 
        if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE))
        {
-               packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE));
-               ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE));
+               packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE));
+               if (!packetCapsResponse) // TODO: correct cleanup
+                       return FALSE;
                packet->tsgPacket.packetCapsResponse = packetCapsResponse;
 
                /* PacketQuarResponsePtr (4 bytes) */
@@ -258,8 +262,7 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                        IsMessagePresent = *((UINT32*) &buffer[offset]);
                        offset += 4;
                        MessageSwitchValue = *((UINT32*) &buffer[offset]);
-                       DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d",
-                                       IsMessagePresent, MessageSwitchValue);
+                       DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue);
                        offset += 4;
                }
 
@@ -289,8 +292,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                        offset += 4;
                }
 
-               versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
-               ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
+               versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
+               if (!versionCaps) // TODO: correct cleanup
+                       return FALSE;
                packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps;
 
                versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@@ -317,8 +321,10 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                /* 4-byte alignment */
                rpc_offset_align(&offset, 4);
 
-               tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES));
-               ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES));
+               tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES));
+               if (!tsgCaps)
+                       return FALSE;
+
                versionCaps->tsgCaps = tsgCaps;
 
                offset += 4; /* MaxCount (4 bytes) */
@@ -406,8 +412,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        }
        else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE))
        {
-               packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE));
-               ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE));
+               packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE));
+               if (!packetQuarEncResponse) // TODO: handle cleanup
+                       return FALSE;
                packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse;
 
                /* PacketQuarResponsePtr (4 bytes) */
@@ -443,8 +450,9 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                        offset += 4;
                }
 
-               versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
-               ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
+               versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
+               if (!versionCaps) // TODO: handle cleanup
+                       return FALSE;
                packetQuarEncResponse->versionCaps = versionCaps;
 
                versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
@@ -779,8 +787,9 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
                buffer = &buffer[24];
 
-       packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
-       ZeroMemory(packet, sizeof(TSG_PACKET));
+       packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
+       if (!packet)
+               return FALSE;
 
        offset = 4;
        packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
@@ -923,6 +932,8 @@ BOOL TsProxyCreateChannelWriteRequest(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERI
        length = 60 + (count * 2);
 
        buffer = (BYTE*) malloc(length);
+       if (!buffer)
+               return FALSE;
 
        /* TunnelContext */
        handle = (CONTEXT_HANDLE*) tunnelContext;
@@ -1526,48 +1537,53 @@ int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length)
 
                return CopyLength;
        }
-       else
-       {
-               tsg->pdu = rpc_recv_peek_pdu(rpc);
 
-               if (!tsg->pdu)
-               {
-                       if (tsg->rpc->client->SynchronousReceive)
-                               return tsg_read(tsg, data, length);
-                       else
-                               return 0;
-               }
 
-               tsg->PendingPdu = TRUE;
-               tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
-               tsg->BytesRead = 0;
+       tsg->pdu = rpc_recv_peek_pdu(rpc);
+       if (!tsg->pdu)
+       {
+               if (!tsg->rpc->client->SynchronousReceive)
+                       return 0;
 
-               CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
+               // weird !!!!
+               return tsg_read(tsg, data, length);
+       }
 
-               CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
-               tsg->BytesAvailable -= CopyLength;
-               tsg->BytesRead += CopyLength;
+       tsg->PendingPdu = TRUE;
+       tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
+       tsg->BytesRead = 0;
 
-               if (tsg->BytesAvailable < 1)
-               {
-                       tsg->PendingPdu = FALSE;
-                       rpc_recv_dequeue_pdu(rpc);
-                       rpc_client_receive_pool_return(rpc, tsg->pdu);
-               }
+       CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
 
-               return CopyLength;
+       CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
+       tsg->BytesAvailable -= CopyLength;
+       tsg->BytesRead += CopyLength;
+
+       if (tsg->BytesAvailable < 1)
+       {
+               tsg->PendingPdu = FALSE;
+               rpc_recv_dequeue_pdu(rpc);
+               rpc_client_receive_pool_return(rpc, tsg->pdu);
        }
+
+       return CopyLength;
+
 }
 
 int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length)
 {
+       int status;
+
        if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED)
        {
-               fprintf(stderr, "tsg_write error: connection lost\n");
+               fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__);
                return -1;
        }
 
-       return TsProxySendToServer((handle_t) tsg, data, 1, &length);
+       status = TsProxySendToServer((handle_t) tsg, data, 1, &length);
+       if (status < 0)
+               return -1;
+       return length;
 }
 
 BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking)
@@ -1584,18 +1600,21 @@ rdpTsg* tsg_new(rdpTransport* transport)
 {
        rdpTsg* tsg;
 
-       tsg = (rdpTsg*) malloc(sizeof(rdpTsg));
-       ZeroMemory(tsg, sizeof(rdpTsg));
-
-       if (tsg != NULL)
-       {
-               tsg->transport = transport;
-               tsg->settings = transport->settings;
-               tsg->rpc = rpc_new(tsg->transport);
-               tsg->PendingPdu = FALSE;
-       }
+       tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg));
+       if (!tsg)
+               return NULL;
 
+       tsg->transport = transport;
+       tsg->settings = transport->settings;
+       tsg->rpc = rpc_new(tsg->transport);
+       if (!tsg->rpc)
+               goto out_free;
+       tsg->PendingPdu = FALSE;
        return tsg;
+
+out_free:
+       free(tsg);
+       return NULL;
 }
 
 void tsg_free(rdpTsg* tsg)
index e1662d3..bc7431f 100644 (file)
@@ -52,13 +52,13 @@ static BOOL freerdp_peer_initialize(freerdp_peer* client)
                        fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile);
                        return FALSE;
                }
+
                if (settings->RdpServerRsaKey->ModulusLength > 256)
                {
                        fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__);
                        fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile);
                        exit(1);
                }
-
        }
 
        return TRUE;
@@ -77,12 +77,13 @@ static HANDLE freerdp_peer_get_event_handle(freerdp_peer* client)
        return client->context->rdp->transport->TcpIn->event;
 }
 
-static BOOL freerdp_peer_check_fds(freerdp_peer* client)
+
+static BOOL freerdp_peer_check_fds(freerdp_peer* peer)
 {
        int status;
        rdpRdp* rdp;
 
-       rdp = client->context->rdp;
+       rdp = peer->context->rdp;
 
        status = rdp_check_fds(rdp);
 
@@ -413,6 +414,19 @@ static int freerdp_peer_send_channel_data(freerdp_peer* client, UINT16 channelId
        return rdp_send_channel_data(client->context->rdp, channelId, data, size);
 }
 
+static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer)
+{
+       return tranport_is_write_blocked(peer->context->rdp->transport);
+}
+
+static int freerdp_peer_drain_output_buffer(freerdp_peer* peer)
+{
+
+       rdpTransport *transport = peer->context->rdp->transport;
+
+       return tranport_drain_output_buffer(transport);
+}
+
 void freerdp_peer_context_new(freerdp_peer* client)
 {
        rdpRdp* rdp;
@@ -445,6 +459,9 @@ void freerdp_peer_context_new(freerdp_peer* client)
        rdp->transport->ReceiveExtra = client;
        transport_set_blocking_mode(rdp->transport, FALSE);
 
+       client->IsWriteBlocked = freerdp_peer_is_write_blocked;
+       client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
+
        IFCALL(client->ContextNew, client, client->context);
 }
 
@@ -473,6 +490,8 @@ freerdp_peer* freerdp_peer_new(int sockfd)
                client->Close = freerdp_peer_close;
                client->Disconnect = freerdp_peer_disconnect;
                client->SendChannelData = freerdp_peer_send_channel_data;
+               client->IsWriteBlocked = freerdp_peer_is_write_blocked;
+               client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
        }
 
        return client;
@@ -480,10 +499,10 @@ freerdp_peer* freerdp_peer_new(int sockfd)
 
 void freerdp_peer_free(freerdp_peer* client)
 {
-       if (client)
-       {
-               rdp_free(client->context->rdp);
-               free(client->context);
-               free(client);
-       }
+       if (!client)
+               return;
+
+       rdp_free(client->context->rdp);
+       free(client->context);
+       free(client);
 }
index 6538ec7..6bc2515 100644 (file)
@@ -209,6 +209,7 @@ rdpSettings* freerdp_settings_new(DWORD flags)
                ZeroMemory(settings, sizeof(rdpSettings));
 
                settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE;
+               settings->WaitForOutputBufferFlush = TRUE;
 
                settings->DesktopWidth = 1024;
                settings->DesktopHeight = 768;
@@ -579,6 +580,7 @@ rdpSettings* freerdp_settings_clone(rdpSettings* settings)
                /* BOOL values */
 
                _settings->ServerMode = settings->ServerMode; /* 16 */
+               _settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */
                _settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */
                _settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */
                _settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */
index 15c4176..6676382 100644 (file)
 
 #include "tcp.h"
 
+long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret)
+{
+       return 1;
+}
+
+static int transport_bio_buffered_write(BIO* bio, const char* buf, int num)
+{
+       int status, ret;
+       rdpTcp *tcp = (rdpTcp *)bio->ptr;
+       int nchunks, committedBytes, i;
+       DataChunk chunks[2];
+
+       ret = num;
+       BIO_clear_retry_flags(bio);
+       tcp->writeBlocked = FALSE;
+
+       /* we directly append extra bytes in the xmit buffer, this could be prevented
+        * but for now it makes the code more simple.
+        */
+       if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, buf, num))
+       {
+               fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num);
+               return -1;
+       }
+
+       committedBytes = 0;
+       nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer));
+       for (i = 0; i < nchunks; i++)
+       {
+               while (chunks[i].size)
+               {
+                       status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size);
+                       /*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks,
+                                       chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status,
+                                       BIO_should_retry(bio->next_bio)
+                       );*/
+                       if (status <= 0)
+                       {
+                               if (BIO_should_retry(bio->next_bio))
+                               {
+                                       tcp->writeBlocked = TRUE;
+                                       goto out; /* EWOULDBLOCK */
+                               }
+
+                               /* any other is an error, but we still have to commit written bytes */
+                               ret = -1;
+                               goto out;
+                       }
+
+                       committedBytes += status;
+                       chunks[i].size -= status;
+                       chunks[i].data += status;
+               }
+       }
+
+out:
+       ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes);
+       return ret;
+}
+
+static int transport_bio_buffered_read(BIO* bio, char* buf, int size)
+{
+       int status;
+       rdpTcp *tcp = (rdpTcp *)bio->ptr;
+
+       tcp->readBlocked = FALSE;
+       BIO_clear_retry_flags(bio);
+
+       status = BIO_read(bio->next_bio, buf, size);
+       /*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */
+
+       if (status <= 0 && BIO_should_retry(bio->next_bio))
+       {
+               BIO_set_retry_read(bio);
+               tcp->readBlocked = TRUE;
+       }
+
+       return status;
+}
+
+static int transport_bio_buffered_puts(BIO* bio, const char* str)
+{
+       return 1;
+}
+
+static int transport_bio_buffered_gets(BIO* bio, char* str, int size)
+{
+       return 1;
+}
+
+static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
+{
+       rdpTcp *tcp = (rdpTcp *)bio->ptr;
+
+       switch (cmd)
+       {
+       case BIO_CTRL_FLUSH:
+               return 1;
+       case BIO_CTRL_WPENDING:
+               return ringbuffer_used(&tcp->xmitBuffer);
+       case BIO_CTRL_PENDING:
+               return 0;
+       default:
+               /*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */
+               return BIO_ctrl(bio->next_bio, cmd, arg1, arg2);
+       }
+
+       return 0;
+}
+
+static int transport_bio_buffered_new(BIO* bio)
+{
+       bio->init = 1;
+       bio->num = 0;
+       bio->ptr = NULL;
+       bio->flags = 0;
+
+       return 1;
+}
+
+static int transport_bio_buffered_free(BIO* bio)
+{
+       return 1;
+}
+
+
+static BIO_METHOD transport_bio_buffered_socket_methods =
+{
+       BIO_TYPE_BUFFERED,
+       "BufferedSocket",
+       transport_bio_buffered_write,
+       transport_bio_buffered_read,
+       transport_bio_buffered_puts,
+       transport_bio_buffered_gets,
+       transport_bio_buffered_ctrl,
+       transport_bio_buffered_new,
+       transport_bio_buffered_free,
+       NULL,
+};
+
+BIO_METHOD* BIO_s_buffered_socket(void)
+{
+       return &transport_bio_buffered_socket_methods;
+}
+
+BOOL transport_bio_buffered_drain(BIO *bio)
+{
+       rdpTcp *tcp = (rdpTcp *)bio->ptr;
+       int status;
+
+       if (!ringbuffer_used(&tcp->xmitBuffer))
+               return 1;
+
+       status = transport_bio_buffered_write(bio, NULL, 0);
+       return status >= 0;
+}
+
+
+
 void tcp_get_ip_address(rdpTcp* tcp)
 {
        BYTE* ip;
@@ -136,62 +295,65 @@ BOOL tcp_connect(rdpTcp* tcp, const char* hostname, int port)
        if (hostname[0] == '/')
        {
                tcp->sockfd = freerdp_uds_connect(hostname);
-
                if (tcp->sockfd < 0)
                        return FALSE;
+
+               tcp->socketBio = BIO_new_fd(tcp->sockfd, 1);
+               if (!tcp->socketBio)
+                       return FALSE;
        }
        else
        {
-               tcp->sockfd = freerdp_tcp_connect(hostname, port);
+               tcp->socketBio = BIO_new(BIO_s_connect());
+               if (!tcp->socketBio)
+                       return FALSE;
 
-               if (tcp->sockfd < 0)
+               if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 ||      BIO_set_conn_int_port(tcp->socketBio, &port) < 0)
                        return FALSE;
 
-               SetEventFileDescriptor(tcp->event, tcp->sockfd);
+               if (BIO_do_connect(tcp->socketBio) <= 0)
+                       return FALSE;
+
+               tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL);
+       }
+
+       SetEventFileDescriptor(tcp->event, tcp->sockfd);
 
-               tcp_get_ip_address(tcp);
-               tcp_get_mac_address(tcp);
+       tcp_get_ip_address(tcp);
+       tcp_get_mac_address(tcp);
 
-               option_value = 1;
-               option_len = sizeof(option_value);
-               setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len);
+       option_value = 1;
+       option_len = sizeof(option_value);
+       if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0)
+               fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__);
 
-               /* receive buffer must be a least 32 K */
-               if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
+       /* receive buffer must be a least 32 K */
+       if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
+       {
+               if (option_value < (1024 * 32))
                {
-                       if (option_value < (1024 * 32))
+                       option_value = 1024 * 32;
+                       option_len = sizeof(option_value);
+                       if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0)
                        {
-                               option_value = 1024 * 32;
-                               option_len = sizeof(option_value);
-                               setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len);
+                               fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__);
+                               return FALSE;
                        }
                }
-
-               tcp_set_keep_alive_mode(tcp);
        }
 
-       return TRUE;
-}
-
-int tcp_read(rdpTcp* tcp, BYTE* data, int length)
-{
-       return freerdp_tcp_read(tcp->sockfd, data, length);
-}
+       if (!tcp_set_keep_alive_mode(tcp))
+               return FALSE;
 
-int tcp_write(rdpTcp* tcp, BYTE* data, int length)
-{
-       return freerdp_tcp_write(tcp->sockfd, data, length);
-}
+       tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
+       if (!tcp->bufferedBio)
+               return FALSE;
+       tcp->bufferedBio->ptr = tcp;
 
-int tcp_wait_read(rdpTcp* tcp)
-{
-       return freerdp_tcp_wait_read(tcp->sockfd);
+       tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
+       return TRUE;
 }
 
-int tcp_wait_write(rdpTcp* tcp)
-{
-       return freerdp_tcp_wait_write(tcp->sockfd);
-}
 
 BOOL tcp_disconnect(rdpTcp* tcp)
 {
@@ -209,7 +371,7 @@ BOOL tcp_set_blocking_mode(rdpTcp* tcp, BOOL blocking)
 
        if (flags == -1)
        {
-               fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n");
+               fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno));
                return FALSE;
        }
 
@@ -297,6 +459,31 @@ int tcp_attach(rdpTcp* tcp, int sockfd)
 {
        tcp->sockfd = sockfd;
        SetEventFileDescriptor(tcp->event, tcp->sockfd);
+
+       ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer));
+
+       if (tcp->socketBio)
+       {
+               if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0)
+                       return -1;
+       }
+       else
+       {
+               tcp->socketBio = BIO_new_socket(sockfd, 1);
+               if (!tcp->socketBio)
+                       return -1;
+       }
+
+       if (!tcp->bufferedBio)
+       {
+               tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
+               if (!tcp->bufferedBio)
+                       return FALSE;
+               tcp->bufferedBio->ptr = tcp;
+
+               tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
+       }
+
        return 0;
 }
 
@@ -316,25 +503,34 @@ rdpTcp* tcp_new(rdpSettings* settings)
 {
        rdpTcp* tcp;
 
-       tcp = (rdpTcp*) malloc(sizeof(rdpTcp));
+       tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp));
+       if (!tcp)
+               return NULL;
 
-       if (tcp)
-       {
-               ZeroMemory(tcp, sizeof(rdpTcp));
+       if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000))
+               goto out_free;
 
-               tcp->sockfd = -1;
-               tcp->settings = settings;
-               tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
-       }
+       tcp->sockfd = -1;
+       tcp->settings = settings;
+
+       tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
+       if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE)
+               goto out_ringbuffer;
 
        return tcp;
+out_ringbuffer:
+       ringbuffer_destroy(&tcp->xmitBuffer);
+out_free:
+       free(tcp);
+       return NULL;
 }
 
 void tcp_free(rdpTcp* tcp)
 {
-       if (tcp)
-       {
-               CloseHandle(tcp->event);
-               free(tcp);
-       }
+       if (!tcp)
+               return;
+
+       ringbuffer_destroy(&tcp->xmitBuffer);
+       CloseHandle(tcp->event);
+       free(tcp);
 }
index b43fbaf..a8b3153 100644 (file)
 #include <winpr/stream.h>
 #include <winpr/winsock.h>
 
+#include <freerdp/utils/ringbuffer.h>
+#include <openssl/bio.h>
+
 #ifndef MSG_NOSIGNAL
 #define MSG_NOSIGNAL 0
 #endif
 
+#define BIO_TYPE_BUFFERED      66
+
 typedef struct rdp_tcp rdpTcp;
 
 struct rdp_tcp
@@ -46,6 +51,12 @@ struct rdp_tcp
 #ifdef _WIN32
        WSAEVENT wsa_event;
 #endif
+       BIO *socketBio;
+       BIO *bufferedBio;
+       RingBuffer xmitBuffer;
+       BOOL writeBlocked;
+       BOOL readBlocked;
+
        HANDLE event;
 };
 
index c194c29..f79d51a 100644 (file)
@@ -33,7 +33,9 @@
 
 #include <freerdp/error.h>
 #include <freerdp/utils/tcp.h>
+#include <freerdp/utils/ringbuffer.h>
 
+#include <openssl/bio.h>
 #include <time.h>
 #include <errno.h>
 #include <fcntl.h>
 #ifndef _WIN32
 #include <netdb.h>
 #include <sys/socket.h>
+#include <sys/select.h>
+#include <sys/time.h>
+#endif
+
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+#include <valgrind/memcheck.h>
 #endif
 
 #include "tpkt.h"
@@ -48,6 +56,7 @@
 #include "transport.h"
 #include "rdp.h"
 
+
 #define BUFFER_SIZE 16384
 
 static void* transport_client_thread(void* arg);
@@ -69,6 +78,7 @@ void transport_attach(rdpTransport* transport, int sockfd)
        tcp_attach(transport->TcpIn, sockfd);
        transport->SplitInputOutput = FALSE;
        transport->TcpOut = transport->TcpIn;
+       transport->frontBio = transport->TcpIn->bufferedBio;
 }
 
 void transport_stop(rdpTransport* transport)
@@ -98,18 +108,9 @@ BOOL transport_disconnect(rdpTransport* transport)
 
        transport_stop(transport);
 
-       if (transport->layer == TRANSPORT_LAYER_TLS)
-               status &= tls_disconnect(transport->TlsIn);
-
-       if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS))
-       {
-               status &= tsg_disconnect(transport->tsg);
-       }
-       else
-       {
-               status &= tcp_disconnect(transport->TcpIn);
-       }
+       BIO_free_all(transport->frontBio);
 
+       transport->frontBio = 0;
        return status;
 }
 
@@ -131,16 +132,16 @@ static int transport_bio_tsg_write(BIO* bio, const char* buf, int num)
        rdpTsg* tsg;
 
        tsg = (rdpTsg*) bio->ptr;
-       status = tsg_write(tsg, (BYTE*) buf, num);
 
        BIO_clear_retry_flags(bio);
+       status = tsg_write(tsg, (BYTE*) buf, num);
+       if (status > 0)
+               return status;
 
        if (status == 0)
-       {
                BIO_set_retry_write(bio);
-       }
 
-       return status < 0 ? 0 : num;
+       return -1;
 }
 
 static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
@@ -222,8 +223,13 @@ BIO_METHOD* BIO_s_tsg(void)
        return &transport_bio_tsg_methods;
 }
 
+
+
 BOOL transport_connect_tls(rdpTransport* transport)
 {
+       rdpSettings *settings = transport->settings;
+       rdpTls *targetTls;
+       BIO *targetBio;
        int tls_status;
        freerdp* instance;
        rdpContext* context;
@@ -234,61 +240,33 @@ BOOL transport_connect_tls(rdpTransport* transport)
        if (transport->layer == TRANSPORT_LAYER_TSG)
        {
                transport->TsgTls = tls_new(transport->settings);
-
-               transport->TsgTls->methods = BIO_s_tsg();
-               transport->TsgTls->tsg = (void*) transport->tsg;
-
                transport->layer = TRANSPORT_LAYER_TSG_TLS;
 
-               transport->TsgTls->hostname = transport->settings->ServerHostname;
-               transport->TsgTls->port = transport->settings->ServerPort;
-
-               if (transport->TsgTls->port == 0)
-                       transport->TsgTls->port = 3389;
-
-               tls_status = tls_connect(transport->TsgTls);
-
-               if (tls_status < 1)
-               {
-                       if (tls_status < 0)
-                       {
-                               if (!connectErrorCode)
-                                       connectErrorCode = TLSCONNECTERROR;
-
-                               if (!freerdp_get_last_error(context))
-                                       freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
-                       }
-                       else
-                       {
-                               if (!freerdp_get_last_error(context))
-                                       freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
-                       }
+               targetTls = transport->TsgTls;
+               targetBio = transport->frontBio;
+       }
+       else
+       {
+               if (!transport->TlsIn)
+                       transport->TlsIn = tls_new(settings);
 
-                       tls_free(transport->TsgTls);
-                       transport->TsgTls = NULL;
+               if (!transport->TlsOut)
+                       transport->TlsOut = transport->TlsIn;
 
-                       return FALSE;
-               }
+               targetTls = transport->TlsIn;
+               targetBio = transport->TcpIn->bufferedBio;
 
-               return TRUE;
+               transport->layer = TRANSPORT_LAYER_TLS;
        }
 
-       if (!transport->TlsIn)
-               transport->TlsIn = tls_new(transport->settings);
-
-       if (!transport->TlsOut)
-               transport->TlsOut = transport->TlsIn;
 
-       transport->layer = TRANSPORT_LAYER_TLS;
-       transport->TlsIn->sockfd = transport->TcpIn->sockfd;
+       targetTls->hostname = settings->ServerHostname;
+       targetTls->port = settings->ServerPort;
 
-       transport->TlsIn->hostname = transport->settings->ServerHostname;
-       transport->TlsIn->port = transport->settings->ServerPort;
+       if (targetTls->port == 0)
+               targetTls->port = 3389;
 
-       if (transport->TlsIn->port == 0)
-               transport->TlsIn->port = 3389;
-
-       tls_status = tls_connect(transport->TlsIn);
+       tls_status = tls_connect(targetTls, targetBio);
 
        if (tls_status < 1)
        {
@@ -306,13 +284,13 @@ BOOL transport_connect_tls(rdpTransport* transport)
                                freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
                }
 
-               tls_free(transport->TlsIn);
-
-               if (transport->TlsIn == transport->TlsOut)
-                       transport->TlsIn = transport->TlsOut = NULL;
-               else
-                       transport->TlsIn = NULL;
+               return FALSE;
+       }
 
+       transport->frontBio = targetTls->bio;
+       if (!transport->frontBio)
+       {
+               fprintf(stderr, "%s: unable to prepend a filtering TLS bio");
                return FALSE;
        }
 
@@ -323,6 +301,7 @@ BOOL transport_connect_nla(rdpTransport* transport)
 {
        freerdp* instance;
        rdpSettings* settings;
+       rdpCredssp *credSsp;
 
        settings = transport->settings;
        instance = (freerdp*) settings->instance;
@@ -338,16 +317,22 @@ BOOL transport_connect_nla(rdpTransport* transport)
        if (!transport->credssp)
        {
                transport->credssp = credssp_new(instance, transport, settings);
+               if (!transport->credssp)
+                       return FALSE;
+
                transport_set_nla_mode(transport, TRUE);
 
                if (settings->AuthenticationServiceClass)
                {
                        transport->credssp->ServicePrincipalName =
                                credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname);
+                       if (!transport->credssp->ServicePrincipalName)
+                               return FALSE;
                }
        }
 
-       if (credssp_authenticate(transport->credssp) < 0)
+       credSsp = transport->credssp;
+       if (credssp_authenticate(credSsp) < 0)
        {
                if (!connectErrorCode)
                        connectErrorCode = AUTHENTICATIONERROR;
@@ -361,14 +346,14 @@ BOOL transport_connect_nla(rdpTransport* transport)
                        "If credentials are valid, the NTLMSSP implementation may be to blame.\n");
 
                transport_set_nla_mode(transport, FALSE);
-               credssp_free(transport->credssp);
+               credssp_free(credSsp);
                transport->credssp = NULL;
 
                return FALSE;
        }
 
        transport_set_nla_mode(transport, FALSE);
-       credssp_free(transport->credssp);
+       credssp_free(credSsp);
        transport->credssp = NULL;
 
        return TRUE;
@@ -380,38 +365,41 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
        int tls_status;
        freerdp* instance;
        rdpContext* context;
+       rdpSettings *settings = transport->settings;
 
        instance = (freerdp*) transport->settings->instance;
        context = instance->context;
 
        tsg = tsg_new(transport);
+       if (!tsg)
+               return FALSE;
 
        tsg->transport = transport;
        transport->tsg = tsg;
        transport->SplitInputOutput = TRUE;
 
        if (!transport->TlsIn)
-               transport->TlsIn = tls_new(transport->settings);
-
-       transport->TlsIn->sockfd = transport->TcpIn->sockfd;
-       transport->TlsIn->hostname = transport->settings->GatewayHostname;
-       transport->TlsIn->port = transport->settings->GatewayPort;
-
-       if (transport->TlsIn->port == 0)
-               transport->TlsIn->port = 443;
-
+       {
+               transport->TlsIn = tls_new(settings);
+               if (!transport->TlsIn)
+                       return FALSE;
+       }
        if (!transport->TlsOut)
-               transport->TlsOut = tls_new(transport->settings);
+       {
+               transport->TlsOut = tls_new(settings);
+               if (!transport->TlsOut)
+                       return FALSE;
+       }
 
-       transport->TlsOut->sockfd = transport->TcpOut->sockfd;
-       transport->TlsOut->hostname = transport->settings->GatewayHostname;
-       transport->TlsOut->port = transport->settings->GatewayPort;
+       /* put a decent default value for gateway port */
+       if (!settings->GatewayPort)
+               settings->GatewayPort = 443;
 
-       if (transport->TlsOut->port == 0)
-               transport->TlsOut->port = 443;
+       transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname;
+       transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort;
 
-       tls_status = tls_connect(transport->TlsIn);
 
+       tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio);
        if (tls_status < 1)
        {
                if (tls_status < 0)
@@ -428,8 +416,7 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
                return FALSE;
        }
 
-       tls_status = tls_connect(transport->TlsOut);
-
+       tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio);
        if (tls_status < 1)
        {
                if (tls_status < 0)
@@ -449,6 +436,8 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
        if (!tsg_connect(tsg, hostname, port))
                return FALSE;
 
+       transport->frontBio = BIO_new(BIO_s_tsg());
+       transport->frontBio->ptr = tsg;
        return TRUE;
 }
 
@@ -462,15 +451,20 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
        if (transport->GatewayEnabled)
        {
                transport->layer = TRANSPORT_LAYER_TSG;
+               transport->SplitInputOutput = TRUE;
                transport->TcpOut = tcp_new(settings);
 
-               status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort);
+               if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) ||
+                               !tcp_set_blocking_mode(transport->TcpIn, FALSE))
+                       return FALSE;
 
-               if (status)
-                       status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort);
+               if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) ||
+                               !tcp_set_blocking_mode(transport->TcpOut, FALSE))
+                       return FALSE;
 
-               if (status)
-                       status = transport_tsg_connect(transport, hostname, port);
+               if (!transport_tsg_connect(transport, hostname, port))
+                       return FALSE;
+               status = TRUE;
        }
        else
        {
@@ -478,6 +472,7 @@ BOOL transport_connect(rdpTransport* transport, const char* hostname, UINT16 por
 
                transport->SplitInputOutput = FALSE;
                transport->TcpOut = transport->TcpIn;
+               transport->frontBio = transport->TcpIn->bufferedBio;
        }
 
        if (status)
@@ -510,11 +505,11 @@ BOOL transport_accept_tls(rdpTransport* transport)
                transport->TlsOut = transport->TlsIn;
 
        transport->layer = TRANSPORT_LAYER_TLS;
-       transport->TlsIn->sockfd = transport->TcpIn->sockfd;
 
-       if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
+       if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
                return FALSE;
 
+       transport->frontBio = transport->TlsIn->bio;
        return TRUE;
 }
 
@@ -533,10 +528,10 @@ BOOL transport_accept_nla(rdpTransport* transport)
                transport->TlsOut = transport->TlsIn;
 
        transport->layer = TRANSPORT_LAYER_TLS;
-       transport->TlsIn->sockfd = transport->TcpIn->sockfd;
 
-       if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
+       if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile))
                return FALSE;
+       transport->frontBio = transport->TlsIn->bio;
 
        /* Network Level Authentication */
 
@@ -630,56 +625,131 @@ UINT32 nla_header_length(wStream* s)
        return length;
 }
 
+static int transport_wait_for_read(rdpTransport* transport)
+{
+       struct timeval tv;
+       fd_set rset, wset;
+       fd_set *rsetPtr = NULL, *wsetPtr = NULL;
+       rdpTcp *tcpIn;
+
+       tcpIn = transport->TcpIn;
+       if (tcpIn->readBlocked)
+       {
+               rsetPtr = &rset;
+               FD_ZERO(rsetPtr);
+               FD_SET(tcpIn->sockfd, rsetPtr);
+       }
+       else if (tcpIn->writeBlocked)
+       {
+               wsetPtr = &wset;
+               FD_ZERO(wsetPtr);
+               FD_SET(tcpIn->sockfd, wsetPtr);
+       }
+
+       if (!wsetPtr && !rsetPtr)
+       {
+               USleep(1000);
+               return 0;
+       }
+
+       tv.tv_sec = 0;
+       tv.tv_usec = 1000;
+
+       return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
+}
+
+
+static int transport_wait_for_write(rdpTransport* transport)
+{
+       struct timeval tv;
+       fd_set rset, wset;
+       fd_set *rsetPtr = NULL, *wsetPtr = NULL;
+       rdpTcp *tcpOut;
+
+       tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn;
+       if (tcpOut->writeBlocked)
+       {
+               wsetPtr = &wset;
+               FD_ZERO(wsetPtr);
+               FD_SET(tcpOut->sockfd, wsetPtr);
+       }
+       else if (tcpOut->readBlocked)
+       {
+               rsetPtr = &rset;
+               FD_ZERO(rsetPtr);
+               FD_SET(tcpOut->sockfd, rsetPtr);
+       }
+
+       if (!wsetPtr && !rsetPtr)
+       {
+               USleep(1000);
+               return 0;
+       }
+
+       tv.tv_sec = 0;
+       tv.tv_usec = 1000;
+
+       return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
+}
+
+
 int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes)
 {
        int read = 0;
        int status = -1;
 
+
        while (read < bytes)
        {
-               if (transport->layer == TRANSPORT_LAYER_TLS)
-                       status = tls_read(transport->TlsIn, data + read, bytes - read);
-               else if (transport->layer == TRANSPORT_LAYER_TCP)
-                       status = tcp_read(transport->TcpIn, data + read, bytes - read);
-               else if (transport->layer == TRANSPORT_LAYER_TSG)
-                       status = tsg_read(transport->tsg, data + read, bytes - read);
-               else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) {
-                       status = tls_read(transport->TsgTls, data + read, bytes - read);
-               }
-
-               /* blocking means that we can't continue until this is read */
-
-               if (!transport->blocking)
-                       return status;
+               status = BIO_read(transport->frontBio, data + read, bytes - read);
 
-               if (status < 0)
+               if (!status)
                {
-                       /* A read error indicates that the peer has dropped the connection */
                        transport->layer = TRANSPORT_LAYER_CLOSED;
-                       return status;
+                       return -1;
                }
 
-               read += status;
-
-               if (status == 0)
+               if (status < 0)
                {
-                       /*
-                        * instead of sleeping, we should wait timeout on the
-                        * socket but this only happens on initial connection
-                        */
-                       USleep(transport->SleepInterval);
+                       if (!BIO_should_retry(transport->frontBio))
+                       {
+                               /* something unexpected happened, let's close */
+                               transport->layer = TRANSPORT_LAYER_CLOSED;
+                               return -1;
+                       }
+
+                       /* non blocking will survive a partial read */
+                       if (!transport->blocking)
+                               return read;
+
+                       /* blocking means that we can't continue until we have read the number of
+                        * requested bytes */
+                       if (transport_wait_for_read(transport) < 0)
+                       {
+                               fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__);
+                               return -1;
+                       }
+                       continue;
                }
+
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+               VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read);
+#endif
+
+               read += status;
        }
 
        return read;
 }
 
+
+
 int transport_read(rdpTransport* transport, wStream* s)
 {
        int status;
        int position;
        int pduLength;
-       BYTE header[4];
+       BYTE *header;
        int transport_status;
 
        position = 0;
@@ -710,7 +780,7 @@ int transport_read(rdpTransport* transport, wStream* s)
                position += status;
        }
 
-       CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */
+       header = Stream_Buffer(s);
 
        /* if header is present, read exactly one PDU */
 
@@ -802,6 +872,8 @@ static int transport_read_nonblocking(rdpTransport* transport)
        return status;
 }
 
+BOOL transport_bio_buffered_drain(BIO *bio);
+
 int transport_write(rdpTransport* transport, wStream* s)
 {
        int length;
@@ -827,36 +899,48 @@ int transport_write(rdpTransport* transport, wStream* s)
 
        while (length > 0)
        {
-               if (transport->layer == TRANSPORT_LAYER_TLS)
-                       status = tls_write(transport->TlsOut, Stream_Pointer(s), length);
-               else if (transport->layer == TRANSPORT_LAYER_TCP)
-                       status = tcp_write(transport->TcpOut, Stream_Pointer(s), length);
-               else if (transport->layer == TRANSPORT_LAYER_TSG)
-                       status = tsg_write(transport->tsg, Stream_Pointer(s), length);
-               else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
-                       status = tls_write(transport->TsgTls, Stream_Pointer(s), length);
+               status = BIO_write(transport->frontBio, Stream_Pointer(s), length);
 
-               if (status < 0)
-                       break; /* error occurred */
-
-               if (status == 0)
+               if (status <= 0)
                {
-                       /* when sending is blocked in nonblocking mode, the receiving buffer should be checked */
+                       /* the buffered BIO that is at the end of the chain always says OK for writing,
+                        * so a retry means that for any reason we need to read. The most probable
+                        * is a SSL or TSG BIO in the chain.
+                        */
+                       if (!BIO_should_retry(transport->frontBio))
+                               return status;
+
+                       /* non-blocking can live with blocked IOs */
                        if (!transport->blocking)
+                               return status;
+
+                       if (transport_wait_for_write(transport) < 0)
                        {
-                               /* and in case we do have buffered some data, we set the event so next loop will get it */
-                               if (transport_read_nonblocking(transport) > 0)
-                                       SetEvent(transport->ReceiveEvent);
+                               fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
+                               return -1;
                        }
+                       continue;
+               }
 
-                       if (transport->layer == TRANSPORT_LAYER_TLS)
-                               tls_wait_write(transport->TlsOut);
-                       else if (transport->layer == TRANSPORT_LAYER_TCP)
-                               tcp_wait_write(transport->TcpOut);
-                       else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
-                               tls_wait_write(transport->TsgTls);
-                       else
-                               USleep(transport->SleepInterval);
+               if (transport->blocking || transport->settings->WaitForOutputBufferFlush)
+               {
+                       /* blocking transport, we must ensure the write buffer is really empty */
+                       rdpTcp *out = transport->TcpOut;
+
+                       while (out->writeBlocked)
+                       {
+                               if (transport_wait_for_write(transport) < 0)
+                               {
+                                       fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
+                                       return -1;
+                               }
+
+                               if (!transport_bio_buffered_drain(out->bufferedBio))
+                               {
+                                       fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__);
+                                       return -1;
+                               }
+                       }
                }
 
                length -= status;
@@ -945,6 +1029,38 @@ void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD*
        }
 }
 
+BOOL tranport_is_write_blocked(rdpTransport* transport)
+{
+       if (transport->TcpIn->writeBlocked)
+               return TRUE;
+
+       return transport->SplitInputOutput &&
+                       transport->TcpOut &&
+                       transport->TcpOut->writeBlocked;
+}
+
+int tranport_drain_output_buffer(rdpTransport* transport)
+{
+       BOOL ret = FALSE;
+
+       /* First try to send some accumulated bytes in the send buffer */
+       if (transport->TcpIn->writeBlocked)
+       {
+               if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio))
+                       return -1;
+               ret |= transport->TcpIn->writeBlocked;
+       }
+
+       if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked)
+       {
+               if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio))
+                       return -1;
+               ret |= transport->TcpOut->writeBlocked;
+       }
+
+       return ret;
+}
+
 int transport_check_fds(rdpTransport* transport)
 {
        int pos;
@@ -1079,15 +1195,14 @@ int transport_check_fds(rdpTransport* transport)
 
                recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra);
 
-               Stream_Release(received);
-
-               if (recv_status < 0)
-                       return -1;
-
                if (recv_status == 1)
                {
                        return 1; /* session redirection */
                }
+               Stream_Release(received);
+
+               if (recv_status < 0)
+                       return -1;
        }
 
        return 0;
@@ -1198,80 +1313,107 @@ rdpTransport* transport_new(rdpSettings* settings)
 {
        rdpTransport* transport;
 
-       transport = (rdpTransport*) malloc(sizeof(rdpTransport));
-
-       if (transport)
-       {
-               ZeroMemory(transport, sizeof(rdpTransport));
+       transport = (rdpTransport *)calloc(1, sizeof(rdpTransport));
+       if (!transport)
+               return NULL;
 
-               WLog_Init();
-               transport->log = WLog_Get("com.freerdp.core.transport");
+       WLog_Init();
+       transport->log = WLog_Get("com.freerdp.core.transport");
+       if (!transport->log)
+               goto out_free;
 
-               transport->TcpIn = tcp_new(settings);
+       transport->TcpIn = tcp_new(settings);
+       if (!transport->TcpIn)
+               goto out_free;
 
-               transport->settings = settings;
+       transport->settings = settings;
 
-               /* a small 0.1ms delay when transport is blocking. */
-               transport->SleepInterval = 100;
+       /* a small 0.1ms delay when transport is blocking. */
+       transport->SleepInterval = 100;
 
-               transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
+       transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
+       if (!transport->ReceivePool)
+               goto out_free_tcpin;
 
-               /* receive buffer for non-blocking read. */
-               transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
-               transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       /* receive buffer for non-blocking read. */
+       transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
+       if (!transport->ReceiveBuffer)
+               goto out_free_receivepool;
 
-               transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE)
+               goto out_free_receivebuffer;
 
-               transport->blocking = TRUE;
-               transport->GatewayEnabled = FALSE;
+       transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+       if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE)
+               goto out_free_receiveEvent;
 
-               InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000);
-               InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000);
+       transport->blocking = TRUE;
+       transport->GatewayEnabled = FALSE;
+       transport->layer = TRANSPORT_LAYER_TCP;
 
-               transport->layer = TRANSPORT_LAYER_TCP;
-       }
+       if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000))
+               goto out_free_connectedEvent;
+       if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000))
+               goto out_free_readlock;
 
        return transport;
+
+out_free_readlock:
+       DeleteCriticalSection(&(transport->ReadLock));
+out_free_connectedEvent:
+       CloseHandle(transport->connectedEvent);
+out_free_receiveEvent:
+       CloseHandle(transport->ReceiveEvent);
+out_free_receivebuffer:
+       StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer);
+out_free_receivepool:
+       StreamPool_Free(transport->ReceivePool);
+out_free_tcpin:
+       tcp_free(transport->TcpIn);
+out_free:
+       free(transport);
+       return NULL;
 }
 
 void transport_free(rdpTransport* transport)
 {
-       if (transport)
-       {
-               transport_stop(transport);
+       if (!transport)
+               return;
 
-               if (transport->ReceiveBuffer)
-                       Stream_Release(transport->ReceiveBuffer);
+       transport_stop(transport);
 
-               StreamPool_Free(transport->ReceivePool);
+       if (transport->ReceiveBuffer)
+               Stream_Release(transport->ReceiveBuffer);
 
-               CloseHandle(transport->ReceiveEvent);
-               CloseHandle(transport->connectedEvent);
+       StreamPool_Free(transport->ReceivePool);
 
-               if (transport->TlsIn)
-                       tls_free(transport->TlsIn);
+       CloseHandle(transport->ReceiveEvent);
+       CloseHandle(transport->connectedEvent);
 
-               if (transport->TlsOut != transport->TlsIn)
-                       tls_free(transport->TlsOut);
+       if (transport->TlsIn)
+               tls_free(transport->TlsIn);
 
-               transport->TlsIn = NULL;
-               transport->TlsOut = NULL;
+       if (transport->TlsOut != transport->TlsIn)
+               tls_free(transport->TlsOut);
 
-               if (transport->TcpIn)
-                       tcp_free(transport->TcpIn);
+       transport->TlsIn = NULL;
+       transport->TlsOut = NULL;
 
-               if (transport->TcpOut != transport->TcpIn)
-                       tcp_free(transport->TcpOut);
+       if (transport->TcpIn)
+               tcp_free(transport->TcpIn);
 
-               transport->TcpIn = NULL;
-               transport->TcpOut = NULL;
+       if (transport->TcpOut != transport->TcpIn)
+               tcp_free(transport->TcpOut);
 
-               tsg_free(transport->tsg);
-               transport->tsg = NULL;
+       transport->TcpIn = NULL;
+       transport->TcpOut = NULL;
 
-               DeleteCriticalSection(&(transport->ReadLock));
-               DeleteCriticalSection(&(transport->WriteLock));
+       tsg_free(transport->tsg);
+       transport->tsg = NULL;
 
-               free(transport);
-       }
+       DeleteCriticalSection(&(transport->ReadLock));
+       DeleteCriticalSection(&(transport->WriteLock));
+
+       free(transport);
 }
index b8834ce..8298074 100644 (file)
@@ -49,11 +49,13 @@ typedef struct rdp_transport rdpTransport;
 #include <freerdp/types.h>
 #include <freerdp/settings.h>
 
+
 typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra);
 
 struct rdp_transport
 {
        TRANSPORT_LAYER layer;
+       BIO *frontBio;
        rdpTsg* tsg;
        rdpTcp* TcpIn;
        rdpTcp* TcpOut;
@@ -102,6 +104,8 @@ BOOL transport_set_blocking_mode(rdpTransport* transport, BOOL blocking);
 void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
 void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
 void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count);
+BOOL tranport_is_write_blocked(rdpTransport* transport);
+BOOL tranport_drain_output_buffer(rdpTransport* transport);
 
 wStream* transport_receive_pool_take(rdpTransport* transport);
 int transport_receive_pool_return(rdpTransport* transport, wStream* pdu);
index 52c2177..016584f 100644 (file)
 
 #include <winpr/stream.h>
 #include <freerdp/utils/tcp.h>
+#include <freerdp/utils/ringbuffer.h>
 
 #include <freerdp/crypto/tls.h>
-
-#ifdef HAVE_VALGRIND_MEMCHECK_H
-#include <valgrind/memcheck.h>
-#endif
+#include "../core/tcp.h"
 
 static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer)
 {
        CryptoCert cert;
-       X509* server_cert;
+       X509* remote_cert;
 
        if (peer)
-               server_cert = SSL_get_peer_certificate(tls->ssl);
+               remote_cert = SSL_get_peer_certificate(tls->ssl);
        else
-               server_cert = SSL_get_certificate(tls->ssl);
+               remote_cert = SSL_get_certificate(tls->ssl);
 
-       if (!server_cert)
+       if (!remote_cert)
        {
-               fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n");
-               cert = NULL;
+               fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__);
+               return NULL;
        }
-       else
+
+       cert = malloc(sizeof(*cert));
+       if (!cert)
        {
-               cert = malloc(sizeof(*cert));
-               cert->px509 = server_cert;
+               X509_free(remote_cert);
+               return NULL;
        }
 
+       cert->px509 = remote_cert;
        return cert;
 }
 
@@ -83,12 +84,14 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
        PrefixLength = strlen(TLS_SERVER_END_POINT);
        ChannelBindingTokenLength = PrefixLength + CertificateHashLength;
 
-       ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings));
-       ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings));
+       ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings));
+       if (!ContextBindings)
+               return NULL;
 
        ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength;
-       ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength);
-       ZeroMemory(ChannelBindings, ContextBindings->BindingsLength);
+       ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength);
+       if (!ChannelBindings)
+               goto out_free;
        ContextBindings->Bindings = ChannelBindings;
 
        ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength;
@@ -99,120 +102,99 @@ SecPkgContext_Bindings* tls_get_channel_bindings(X509* cert)
        CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength);
 
        return ContextBindings;
-}
-
-static void tls_ssl_info_callback(const SSL* ssl, int type, int val)
-{
-       if (type & SSL_CB_HANDSHAKE_START)
-       {
 
-       }
+out_free:
+       free(ContextBindings);
+       return NULL;
 }
 
-int tls_connect(rdpTls* tls)
-{
-       CryptoCert cert;
-       long options = 0;
-       int verify_status;
-       int connection_status;
-
-       tls->ctx = SSL_CTX_new(TLSv1_client_method());
 
+BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode)
+{
+       tls->ctx = SSL_CTX_new(method);
        if (!tls->ctx)
        {
-               fprintf(stderr, "SSL_CTX_new failed\n");
-               return -1;
+               fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__);
+               return FALSE;
        }
 
-       //SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
-
-       /**
-        * SSL_OP_NO_COMPRESSION:
-        *
-        * The Microsoft RDP server does not advertise support
-        * for TLS compression, but alternative servers may support it.
-        * This was observed between early versions of the FreeRDP server
-        * and the FreeRDP client, and caused major performance issues,
-        * which is why we're disabling it.
-        */
-#ifdef SSL_OP_NO_COMPRESSION
-       options |= SSL_OP_NO_COMPRESSION;
-#endif
-        
-       /**
-        * SSL_OP_TLS_BLOCK_PADDING_BUG:
-        *
-        * The Microsoft RDP server does *not* support TLS padding.
-        * It absolutely needs to be disabled otherwise it won't work.
-        */
-       options |= SSL_OP_TLS_BLOCK_PADDING_BUG;
-
-       /**
-        * SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS:
-        *
-        * Just like TLS padding, the Microsoft RDP server does not
-        * support empty fragments. This needs to be disabled.
-        */
-       options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
+       SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
 
        SSL_CTX_set_options(tls->ctx, options);
+       SSL_CTX_set_read_ahead(tls->ctx, 1);
 
-       tls->ssl = SSL_new(tls->ctx);
-
-       if (!tls->ssl)
+       tls->bio = BIO_new_ssl(tls->ctx, clientMode);
+       if (BIO_get_ssl(tls->bio, &tls->ssl) < 0)
        {
-               fprintf(stderr, "SSL_new failed\n");
-               return -1;
+               fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__);
+               return FALSE;
        }
 
-       if (tls->tsg)
+       BIO_push(tls->bio, underlying);
+       return TRUE;
+}
+
+int tls_do_handshake(rdpTls* tls, BOOL clientMode)
+{
+       CryptoCert cert;
+       int verify_status, status;
+
+       do
        {
-               tls->bio = BIO_new(tls->methods);
+               struct timeval tv;
+               fd_set rset;
+               int fd;
 
-               if (!tls->bio)
-               {
-                       fprintf(stderr, "BIO_new failed\n");
+               status = BIO_do_handshake(tls->bio);
+               if (status == 1)
+                       break;
+               if (!BIO_should_retry(tls->bio))
                        return -1;
-               }
-
-               tls->bio->ptr = tls->tsg;
 
-               SSL_set_bio(tls->ssl, tls->bio, tls->bio);
+               /* we select() only for read even if we should test both read and write
+                * depending of what have blocked */
+               FD_ZERO(&rset);
 
-               SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback);
-       }
-       else
-       {
-               if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
+               fd = BIO_get_fd(tls->bio, NULL);
+               if (fd < 0)
                {
-                       fprintf(stderr, "SSL_set_fd failed\n");
+                       fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__);
                        return -1;
                }
-       }
 
-       connection_status = SSL_connect(tls->ssl);
+               FD_SET(fd, &rset);
+               tv.tv_sec = 0;
+               tv.tv_usec = 10 * 1000; /* 10ms */
 
-       if (connection_status <= 0)
-       {
-               if (tls_print_error("SSL_connect", tls->ssl, connection_status))
+               status = select(fd + 1, &rset, NULL, NULL, &tv);
+               if (status < 0)
                {
+                       fprintf(stderr, "%s: error during select()\n", __FUNCTION__);
                        return -1;
                }
        }
+       while (TRUE);
 
-       cert = tls_get_certificate(tls, TRUE);
+       if (!clientMode)
+               return 1;
 
+       cert = tls_get_certificate(tls, clientMode);
        if (!cert)
        {
-               fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
+               fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__);
                return -1;
        }
 
        tls->Bindings = tls_get_channel_bindings(cert->px509);
+       if (!tls->Bindings)
+       {
+               fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__);
+               return -1;
+       }
 
        if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
        {
-               fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
+               fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__);
                tls_free_certificate(cert);
                return -1;
        }
@@ -221,8 +203,10 @@ int tls_connect(rdpTls* tls)
 
        if (verify_status < 1)
        {
-               fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n");
+               fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__);
                tls_disconnect(tls);
+               tls_free_certificate(cert);
+               return 0;
        }
 
        tls_free_certificate(cert);
@@ -230,21 +214,52 @@ int tls_connect(rdpTls* tls)
        return verify_status;
 }
 
-BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
+int tls_connect(rdpTls* tls, BIO *underlying)
 {
-       CryptoCert cert;
-       long options = 0;
-       int connection_status;
+       int options = 0;
 
-       tls->ctx = SSL_CTX_new(SSLv23_server_method());
+       /**
+        * SSL_OP_NO_COMPRESSION:
+        *
+        * The Microsoft RDP server does not advertise support
+        * for TLS compression, but alternative servers may support it.
+        * This was observed between early versions of the FreeRDP server
+        * and the FreeRDP client, and caused major performance issues,
+        * which is why we're disabling it.
+        */
+#ifdef SSL_OP_NO_COMPRESSION
+       options |= SSL_OP_NO_COMPRESSION;
+#endif
 
-       if (tls->ctx == NULL)
-       {
-               fprintf(stderr, "SSL_CTX_new failed\n");
+       /**
+        * SSL_OP_TLS_BLOCK_PADDING_BUG:
+        *
+        * The Microsoft RDP server does *not* support TLS padding.
+        * It absolutely needs to be disabled otherwise it won't work.
+        */
+       options |= SSL_OP_TLS_BLOCK_PADDING_BUG;
+
+       /**
+        * SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS:
+        *
+        * Just like TLS padding, the Microsoft RDP server does not
+        * support empty fragments. This needs to be disabled.
+        */
+       options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
+
+       if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE))
                return FALSE;
-       }
 
-       /*
+       return tls_do_handshake(tls, TRUE);
+}
+
+
+
+BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file)
+{
+       long options = 0;
+
+       /**
         * SSL_OP_NO_SSLv2:
         *
         * We only want SSLv3 and TLSv1, so disable SSLv2.
@@ -281,336 +296,184 @@ BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
         */
        options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
 
-       SSL_CTX_set_options(tls->ctx, options);
-
-       if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0)
-       {
-               fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n");
-               fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
+       if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE))
                return FALSE;
-       }
 
-       tls->ssl = SSL_new(tls->ctx);
-
-       if (!tls->ssl)
+       if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0)
        {
-               fprintf(stderr, "SSL_new failed\n");
+               fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__);
+               fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
                return FALSE;
        }
 
        if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0)
        {
-               fprintf(stderr, "SSL_use_certificate_file failed\n");
+               fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__);
                return FALSE;
        }
 
-       if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
-       {
-               fprintf(stderr, "SSL_set_fd failed\n");
+       return tls_do_handshake(tls, FALSE) > 0;
+}
+
+BOOL tls_disconnect(rdpTls* tls)
+{
+       if (!tls)
                return FALSE;
-       }
 
-       while (1)
+       if (!tls->ssl)
+               return TRUE;
+
+       if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
        {
-               connection_status = SSL_accept(tls->ssl);
+               /**
+                * OpenSSL doesn't really expose an API for sending a TLS alert manually.
+                *
+                * The following code disables the sending of the default "close notify"
+                * and then proceeds to force sending a custom TLS alert before shutting down.
+                *
+                * Manually sending a TLS alert is necessary in certain cases,
+                * like when server-side NLA results in an authentication failure.
+                */
 
-               if (connection_status <= 0)
-               {
-                       switch (SSL_get_error(tls->ssl, connection_status))
-                       {
-                               case SSL_ERROR_WANT_READ:
-                               case SSL_ERROR_WANT_WRITE:
-                                       break;
+               SSL_set_quiet_shutdown(tls->ssl, 1);
 
-                               default:
-                                       if (tls_print_error("SSL_accept", tls->ssl, connection_status))
-                                               return FALSE;
-                                       break;
+               if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
+                       SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
 
-                       }
-               }
-               else
-               {
-                       break;
-               }
-       }
+               tls->ssl->s3->alert_dispatch = 1;
+               tls->ssl->s3->send_alert[0] = tls->alertLevel;
+               tls->ssl->s3->send_alert[1] = tls->alertDescription;
 
-       cert = tls_get_certificate(tls, FALSE);
+               if (tls->ssl->s3->wbuf.left == 0)
+                       tls->ssl->method->ssl_dispatch_alert(tls->ssl);
 
-       if (!cert)
-       {
-               fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
-               return FALSE;
+               SSL_shutdown(tls->ssl);
        }
-
-       if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
+       else
        {
-               fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
-               tls_free_certificate(cert);
-               return FALSE;
+               SSL_shutdown(tls->ssl);
        }
 
-       free(cert);
-
-       fprintf(stderr, "TLS connection accepted\n");
-
        return TRUE;
 }
 
-BOOL tls_disconnect(rdpTls* tls)
+
+BIO *findBufferedBio(BIO *front)
 {
-       if (!tls)
-               return FALSE;
+       BIO *ret = front;
 
-       if (tls->ssl)
+       while (ret)
        {
-               if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
-               {
-                       /**
-                        * OpenSSL doesn't really expose an API for sending a TLS alert manually.
-                        *
-                        * The following code disables the sending of the default "close notify"
-                        * and then proceeds to force sending a custom TLS alert before shutting down.
-                        *
-                        * Manually sending a TLS alert is necessary in certain cases,
-                        * like when server-side NLA results in an authentication failure.
-                        */
-
-                       SSL_set_quiet_shutdown(tls->ssl, 1);
-
-                       if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
-                               SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
-
-                       tls->ssl->s3->alert_dispatch = 1;
-                       tls->ssl->s3->send_alert[0] = tls->alertLevel;
-                       tls->ssl->s3->send_alert[1] = tls->alertDescription;
-
-                       if (tls->ssl->s3->wbuf.left == 0)
-                               tls->ssl->method->ssl_dispatch_alert(tls->ssl);
-
-                       SSL_shutdown(tls->ssl);
-               }
-               else
-               {
-                       SSL_shutdown(tls->ssl);
-               }
+               if (BIO_method_type(ret) == BIO_TYPE_BUFFERED)
+                       return ret;
+               ret = ret->next_bio;
        }
 
-       return TRUE;
+       return ret;
 }
 
-int tls_read(rdpTls* tls, BYTE* data, int length)
+int tls_write_all(rdpTls* tls, const BYTE* data, int length)
 {
-       int error;
-       int status;
-
-       if (!tls)
-               return -1;
-
-       if (!tls->ssl)
+       int status, nchunks, commitedBytes;
+       rdpTcp *tcp;
+       fd_set rset, wset;
+       fd_set *rsetPtr, *wsetPtr;
+       struct timeval tv;
+       BIO *bio = tls->bio;
+       DataChunk chunks[2];
+
+       BIO *bufferedBio = findBufferedBio(bio);
+       if (!bufferedBio)
+       {
+               fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__);
                return -1;
-
-       status = SSL_read(tls->ssl, data, length);
-
-       if (status == 0)
-       {
-               return -1; /* peer disconnected */
        }
 
-       if (status <= 0)
+       tcp = (rdpTcp *)bufferedBio->ptr;
+
+       do
        {
-               error = SSL_get_error(tls->ssl, status);
+               status = BIO_write(bio, data, length);
+               /*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/
+               if (status > 0)
+                       break;
 
-               //fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n",
-               //              length, status, error);
+               if (!BIO_should_retry(bio))
+                       return -1;
 
-               switch (error)
+               /* we try to handle SSL want_read and want_write nicely */
+               rsetPtr = wsetPtr = 0;
+               if (tcp->writeBlocked)
                {
-                       case SSL_ERROR_NONE:
-                               break;
-
-                       case SSL_ERROR_WANT_READ:
-                       case SSL_ERROR_WANT_WRITE:
-                               status = 0;
-                               break;
-
-                       case SSL_ERROR_SYSCALL:
-#ifdef _WIN32
-                               if (WSAGetLastError() == WSAEWOULDBLOCK)
-#else
-                               if ((errno == EAGAIN) || (errno == 0))
-#endif
-                               {
-                                       status = 0;
-                               }
-                               else
-                               {
-                                       if (tls_print_error("SSL_read", tls->ssl, status))
-                                       {
-                                               status = -1;
-                                       }
-                                       else
-                                       {
-                                               status = 0;
-                                       }
-                               }
-                               break;
-
-                       default:
-                               if (tls_print_error("SSL_read", tls->ssl, status))
-                               {
-                                       status = -1;
-                               }
-                               else
-                               {
-                                       status = 0;
-                               }
-                               break;
+                       wsetPtr = &wset;
+                       FD_ZERO(&wset);
+                       FD_SET(tcp->sockfd, &wset);
+               }
+               else if (tcp->readBlocked)
+               {
+                       rsetPtr = &rset;
+                       FD_ZERO(&rset);
+                       FD_SET(tcp->sockfd, &rset);
+               }
+               else
+               {
+                       fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__);
+                       USleep(10);
+                       continue;
                }
-       }
-
-#ifdef HAVE_VALGRIND_MEMCHECK_H
-       VALGRIND_MAKE_MEM_DEFINED(data, status);
-#endif
-
-       return status;
-}
-
-int tls_write(rdpTls* tls, BYTE* data, int length)
-{
-       int error;
-       int status;
-
-       if (!tls)
-               return -1;
-
-       if (!tls->ssl)
-               return -1;
 
-       status = SSL_write(tls->ssl, data, length);
+               tv.tv_sec = 0;
+               tv.tv_usec = 100 * 1000;
 
-       if (status == 0)
-       {
-               return -1; /* peer disconnected */
+               status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
+               if (status < 0)
+                       return -1;
        }
+       while (TRUE);
 
-       if (status < 0)
+       /* make sure the output buffer is empty */
+       commitedBytes = 0;
+       while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer))))
        {
-               error = SSL_get_error(tls->ssl, status);
-
-               //fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error);
+               int i;
 
-               switch (error)
+               for (i = 0; i < nchunks; i++)
                {
-                       case SSL_ERROR_NONE:
-                               break;
-
-                       case SSL_ERROR_WANT_READ:
-                       case SSL_ERROR_WANT_WRITE:
-                               status = 0;
-                               break;
-
-                       case SSL_ERROR_SYSCALL:
-                               if (errno == EAGAIN)
-                               {
-                                       status = 0;
-                               }
-                               else
+                       while (chunks[i].size)
+                       {
+                               status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size);
+                               if (status > 0)
                                {
-                                       tls_print_error("SSL_write", tls->ssl, status);
-                                       status = -1;
+                                       chunks[i].size -= status;
+                                       chunks[i].data += status;
+                                       commitedBytes += status;
+                                       continue;
                                }
-                               break;
-
-                       default:
-                               tls_print_error("SSL_write", tls->ssl, status);
-                               status = -1;
-                               break;
-               }
-       }
-
-       return status;
-}
-
-int tls_write_all(rdpTls* tls, BYTE* data, int length)
-{
-       int status;
-       int sent = 0;
 
-       do
-       {
-               status = tls_write(tls, &data[sent], length - sent);
+                               if (!BIO_should_retry(tcp->socketBio))
+                                       goto out_fail;
+                               FD_ZERO(&rset);
+                               FD_SET(tcp->sockfd, &rset);
+                               tv.tv_sec = 0;
+                               tv.tv_usec = 100 * 1000;
 
-               if (status > 0)
-                       sent += status;
-               else if (status == 0)
-                       tls_wait_write(tls);
+                               status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv);
+                               if (status < 0)
+                                       goto out_fail;
+                       }
 
-               if (sent >= length)
-                       break;
+               }
        }
-       while (status >= 0);
-
-       if (status > 0)
-               return length;
-       else
-               return status;
-}
 
-int tls_wait_read(rdpTls* tls)
-{
-       return freerdp_tcp_wait_read(tls->sockfd);
-}
+       ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
+       return length;
 
-int tls_wait_write(rdpTls* tls)
-{
-       return freerdp_tcp_wait_write(tls->sockfd);
+out_fail:
+       ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
+       return -1;
 }
 
-static void tls_errors(const char *prefix)
-{
-       unsigned long error;
 
-       while ((error = ERR_get_error()) != 0)
-               fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL));
-}
-
-BOOL tls_print_error(char* func, SSL* connection, int value)
-{
-       switch (SSL_get_error(connection, value))
-       {
-               case SSL_ERROR_ZERO_RETURN:
-                       fprintf(stderr, "%s: Server closed TLS connection\n", func);
-                       return TRUE;
-
-               case SSL_ERROR_WANT_READ:
-                       fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func);
-                       return FALSE;
-
-               case SSL_ERROR_WANT_WRITE:
-                       fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func);
-                       return FALSE;
-
-               case SSL_ERROR_SYSCALL:
-#ifdef _WIN32
-                       fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError());
-#else
-                       fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno);
-#endif
-                       tls_errors(func);
-                       return TRUE;
-
-               case SSL_ERROR_SSL:
-                       fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func);
-                       tls_errors(func);
-                       return TRUE;
-
-               default:
-                       fprintf(stderr, "%s: Unknown error\n", func);
-                       tls_errors(func);
-                       return TRUE;
-       }
-}
 
 int tls_set_alert_code(rdpTls* tls, int level, int description)
 {
@@ -672,7 +535,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
                
                if (!bio)
                {
-                       fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n");
+                       fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__);
                        return -1;
                }
 
@@ -680,7 +543,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
 
                if (status < 0)
                {
-                       fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status);
+                       fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status);
                        return -1;
                }
                
@@ -692,7 +555,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
                
                if (status < 0)
                {
-                       fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
+                       fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
                        return -1;
                }
                
@@ -713,7 +576,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
 
                if (status < 0)
                {
-                       fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
+                       fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
                        return -1;
                }
                
@@ -727,8 +590,7 @@ int tls_verify_certificate(rdpTls* tls, CryptoCert cert, char* hostname, int por
                        status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0);
                }
                
-               fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n",
-                       length, status, pemCert);
+               fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__,     length, status, pemCert);
 
                free(pemCert);
                BIO_free(bio);
@@ -932,57 +794,53 @@ rdpTls* tls_new(rdpSettings* settings)
 {
        rdpTls* tls;
 
-       tls = (rdpTls*) malloc(sizeof(rdpTls));
-
-       if (tls)
-       {
-               ZeroMemory(tls, sizeof(rdpTls));
-
-               SSL_load_error_strings();
-               SSL_library_init();
+       tls = (rdpTls *)calloc(1, sizeof(rdpTls));
+       if (!tls)
+               return NULL;
 
-               tls->settings = settings;
-               tls->certificate_store = certificate_store_new(settings);
+       SSL_load_error_strings();
+       SSL_library_init();
 
-               tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
-               tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
-       }
+       tls->settings = settings;
+       tls->certificate_store = certificate_store_new(settings);
+       if (!tls->certificate_store)
+               goto out_free;
 
+       tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
+       tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
        return tls;
+
+out_free:
+       free(tls);
+       return NULL;
 }
 
 void tls_free(rdpTls* tls)
 {
-       if (tls)
-       {
-               if (tls->ssl)
-               {
-                       SSL_free(tls->ssl);
-                       tls->ssl = NULL;
-               }
+       if (!tls)
+               return;
 
-               if (tls->ctx)
-               {
-                       SSL_CTX_free(tls->ctx);
-                       tls->ctx = NULL;
-               }
+       if (tls->ctx)
+       {
+               SSL_CTX_free(tls->ctx);
+               tls->ctx = NULL;
+       }
 
-               if (tls->PublicKey)
-               {
-                       free(tls->PublicKey);
-                       tls->PublicKey = NULL;
-               }
+       if (tls->PublicKey)
+       {
+               free(tls->PublicKey);
+               tls->PublicKey = NULL;
+       }
 
-               if (tls->Bindings)
-               {
-                       free(tls->Bindings->Bindings);
-                       free(tls->Bindings);
-                       tls->Bindings = NULL;
-               }
+       if (tls->Bindings)
+       {
+               free(tls->Bindings->Bindings);
+               free(tls->Bindings);
+               tls->Bindings = NULL;
+       }
 
-               certificate_store_free(tls->certificate_store);
-               tls->certificate_store = NULL;
+       certificate_store_free(tls->certificate_store);
+       tls->certificate_store = NULL;
 
-               free(tls);
-       }
+       free(tls);
 }