SSL* ssl;
BIO* bio;
void* tsg;
- int sockfd;
SSL_CTX* ctx;
BYTE* PublicKey;
BIO_METHOD* methods;
int alertDescription;
};
-FREERDP_API int tls_connect(rdpTls* tls);
-FREERDP_API BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file);
+FREERDP_API int tls_connect(rdpTls* tls, BIO *underlying);
+FREERDP_API BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file);
FREERDP_API BOOL tls_disconnect(rdpTls* tls);
-FREERDP_API int tls_read(rdpTls* tls, BYTE* data, int length);
-FREERDP_API int tls_write(rdpTls* tls, BYTE* data, int length);
-
-FREERDP_API int tls_write_all(rdpTls* tls, BYTE* data, int length);
-
-FREERDP_API int tls_wait_read(rdpTls* tls);
-FREERDP_API int tls_wait_write(rdpTls* tls);
+FREERDP_API int tls_write_all(rdpTls* tls, const BYTE* data, int length);
FREERDP_API int tls_set_alert_code(rdpTls* tls, int level, int description);
typedef BOOL (*psPeerInitialize)(freerdp_peer* client);
typedef BOOL (*psPeerGetFileDescriptor)(freerdp_peer* client, void** rfds, int* rcount);
typedef HANDLE (*psPeerGetEventHandle)(freerdp_peer* client);
+typedef HANDLE (*psPeerGetReceiveEventHandle)(freerdp_peer* client);
typedef BOOL (*psPeerCheckFileDescriptor)(freerdp_peer* client);
+typedef BOOL (*psPeerIsWriteBlocked)(freerdp_peer* client);
+typedef int (*psPeerDrainOutputBuffer)(freerdp_peer* client);
typedef BOOL (*psPeerClose)(freerdp_peer* client);
typedef void (*psPeerDisconnect)(freerdp_peer* client);
typedef BOOL (*psPeerCapabilities)(freerdp_peer* client);
psPeerInitialize Initialize;
psPeerGetFileDescriptor GetFileDescriptor;
psPeerGetEventHandle GetEventHandle;
+ psPeerGetReceiveEventHandle GetReceiveEventHandle;
psPeerCheckFileDescriptor CheckFileDescriptor;
psPeerClose Close;
psPeerDisconnect Disconnect;
BOOL activated;
BOOL authenticated;
SEC_WINNT_AUTH_IDENTITY identity;
+
+ psPeerIsWriteBlocked IsWriteBlocked;
+ psPeerDrainOutputBuffer DrainOutputBuffer;
};
#ifdef __cplusplus
ALIGN64 char* Password; /* 22 */
ALIGN64 char* Domain; /* 23 */
ALIGN64 char* PasswordHash; /* 24 */
- UINT64 padding0064[64 - 25]; /* 25 */
+ ALIGN64 BOOL WaitForOutputBufferFlush; /* 25 */
+ UINT64 padding0064[64 - 26]; /* 26 */
UINT64 padding0128[128 - 64]; /* 64 */
/**
#include <winpr/stream.h>
#include <winpr/string.h>
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+#include <valgrind/memcheck.h>
+#endif
+
#include "http.h"
HttpContext* http_context_new()
nbytes = 0;
length = 10000;
content = NULL;
- buffer = malloc(length);
+ buffer = calloc(length, 1);
if (!buffer)
return NULL;
{
while (nbytes < 5)
{
- status = tls_read(tls, p, length - nbytes);
+ status = BIO_read(tls->bio, p, length - nbytes);
- if (status < 0)
- goto out_error;
+ if (status <= 0)
+ {
+ if (!BIO_should_retry(tls->bio))
+ goto out_error;
- if (!status)
+ USleep(100);
continue;
+ }
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+ VALGRIND_MAKE_MEM_DEFINED(p, status);
+#endif
nbytes += status;
p = (BYTE*) &buffer[nbytes];
}
if (!header_end)
{
- fprintf(stderr, "http_response_recv: invalid response:\n");
+ fprintf(stderr, "%s: invalid response:\n", __FUNCTION__);
winpr_HexDump(buffer, status);
goto out_error;
}
header_end[0] = '\0';
header_end[1] = '\0';
- content = &header_end[2];
+ content = header_end + 2;
count = 0;
line = (char*) buffer;
if (!http_response_parse_header(http_response))
goto out_error;
- if (http_response->ContentLength > 0)
+ http_response->bodyLen = nbytes - (content - (char *)buffer);
+ if (http_response->bodyLen > 0)
{
- http_response->Content = _strdup(content);
- if (!http_response->Content)
+ http_response->BodyContent = (BYTE *)malloc(http_response->bodyLen);
+ if (!http_response->BodyContent)
goto out_error;
+
+ CopyMemory(http_response->BodyContent, content, http_response->bodyLen);
}
break;
ListDictionary_Free(http_response->Authenticates);
if (http_response->ContentLength > 0)
- free(http_response->Content);
+ free(http_response->BodyContent);
free(http_response);
}
wListDictionary *Authenticates;
int ContentLength;
- char* Content;
+ BYTE *BodyContent;
+ int bodyLen;
};
void http_response_print(HttpResponse* http_response);
rdpNtlm* ntlm = rpc->NtlmHttpIn->ntlm;
http_response = http_response_recv(rpc->TlsIn);
+ if (!http_response)
+ return -1;
if (ListDictionary_Contains(http_response->Authenticates, "NTLM"))
{
if (!token64)
goto out;
- ntlm_token_data = NULL;
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
}
+out:
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
-
-out:
http_response_free(http_response);
return 0;
rdpNtlm* ntlm = NULL;
rdpSettings* settings = rpc->settings;
freerdp* instance = (freerdp*) rpc->settings->instance;
- BOOL promptPassword = FALSE;
if (channel == TSG_CHANNEL_IN)
ntlm = rpc->NtlmHttpIn->ntlm;
else if (channel == TSG_CHANNEL_OUT)
ntlm = rpc->NtlmHttpOut->ntlm;
- if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
- || (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
- {
- promptPassword = TRUE;
- }
-
- if (promptPassword)
+ if (!settings->GatewayPassword || !settings->GatewayUsername ||
+ !strlen(settings->GatewayPassword) || !strlen(settings->GatewayUsername))
{
if (instance->GatewayAuthenticate)
{
- BOOL proceed = instance->GatewayAuthenticate(instance,
- &settings->GatewayUsername, &settings->GatewayPassword, &settings->GatewayDomain);
+ BOOL proceed = instance->GatewayAuthenticate(instance, &settings->GatewayUsername,
+ &settings->GatewayPassword, &settings->GatewayDomain);
if (!proceed)
{
char *token64 = ListDictionary_GetItemValue(http_response->Authenticates, "NTLM");
crypto_base64_decode(token64, strlen(token64), &ntlm_token_data, &ntlm_token_length);
}
-
ntlm->inputBuffer[0].pvBuffer = ntlm_token_data;
ntlm->inputBuffer[0].cbBuffer = ntlm_token_length;
-
+
http_response_free(http_response);
-
return 0;
}
success = TRUE;
/* Send OUT Channel Request */
-
rpc_ncacn_http_send_out_channel_request(rpc);
/* Receive OUT Channel Response */
-
rpc_ncacn_http_recv_out_channel_response(rpc);
/* Send OUT Channel Request */
-
rpc_ncacn_http_send_out_channel_request(rpc);
ntlm_client_uninit(ntlm);
if (channel == TSG_CHANNEL_IN)
{
- http_context_set_pragma(ntlm_http->context,
- "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
+ http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729");
}
else if (channel == TSG_CHANNEL_OUT)
{
- http_context_set_pragma(ntlm_http->context,
- "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729" ", "
+ http_context_set_pragma(ntlm_http->context, "ResourceTypeUuid=44e265dd-7daf-42cd-8560-3cdb6e7a2729, "
"SessionId=fbd9c34f-397d-471d-a109-1b08cc554624");
}
}
#include <winpr/dsparse.h>
#include <openssl/rand.h>
+#include <openssl/bio.h>
+
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+#include <valgrind/memcheck.h>
+#endif
#include "http.h"
#include "ntlm.h"
{
UINT32 alloc_hint = 0;
rpcconn_hdr_t* header;
+ UINT32 frag_length;
+ UINT32 auth_length;
+ UINT32 auth_pad_length;
+ UINT32 sec_trailer_offset;
+ rpc_sec_trailer* sec_trailer;
*offset = RPC_COMMON_FIELDS_LENGTH;
header = ((rpcconn_hdr_t*) buffer);
- if (header->common.ptype == PTYPE_RESPONSE)
- {
- *offset += 8;
- rpc_offset_align(offset, 8);
- alloc_hint = header->response.alloc_hint;
- }
- else if (header->common.ptype == PTYPE_REQUEST)
- {
- *offset += 4;
- rpc_offset_align(offset, 8);
- alloc_hint = header->request.alloc_hint;
- }
- else if (header->common.ptype == PTYPE_RTS)
- {
- *offset += 4;
- }
- else
+ switch (header->common.ptype)
{
- return FALSE;
+ case PTYPE_RESPONSE:
+ *offset += 8;
+ rpc_offset_align(offset, 8);
+ alloc_hint = header->response.alloc_hint;
+ break;
+ case PTYPE_REQUEST:
+ *offset += 4;
+ rpc_offset_align(offset, 8);
+ alloc_hint = header->request.alloc_hint;
+ break;
+ case PTYPE_RTS:
+ *offset += 4;
+ break;
+ default:
+ fprintf(stderr, "%s: unknown ptype=0x%x\n", __FUNCTION__, header->common.ptype);
+ return FALSE;
}
- if (length)
+ if (!length)
+ return TRUE;
+
+ if (header->common.ptype == PTYPE_REQUEST)
{
- if (header->common.ptype == PTYPE_REQUEST)
- {
- UINT32 sec_trailer_offset;
+ UINT32 sec_trailer_offset;
- sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
- *length = sec_trailer_offset - *offset;
- }
- else
- {
- UINT32 frag_length;
- UINT32 auth_length;
- UINT32 auth_pad_length;
- UINT32 sec_trailer_offset;
- rpc_sec_trailer* sec_trailer;
+ sec_trailer_offset = header->common.frag_length - header->common.auth_length - 8;
+ *length = sec_trailer_offset - *offset;
+ return TRUE;
+ }
- frag_length = header->common.frag_length;
- auth_length = header->common.auth_length;
- sec_trailer_offset = frag_length - auth_length - 8;
- sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
- auth_pad_length = sec_trailer->auth_pad_length;
+ frag_length = header->common.frag_length;
+ auth_length = header->common.auth_length;
+
+ sec_trailer_offset = frag_length - auth_length - 8;
+ sec_trailer = (rpc_sec_trailer*) &buffer[sec_trailer_offset];
+ auth_pad_length = sec_trailer->auth_pad_length;
#if 0
- fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
- sec_trailer->auth_type,
- sec_trailer->auth_level,
- sec_trailer->auth_pad_length,
- sec_trailer->auth_reserved,
- sec_trailer->auth_context_id);
+ fprintf(stderr, "sec_trailer: type: %d level: %d pad_length: %d reserved: %d context_id: %d\n",
+ sec_trailer->auth_type,
+ sec_trailer->auth_level,
+ sec_trailer->auth_pad_length,
+ sec_trailer->auth_reserved,
+ sec_trailer->auth_context_id);
#endif
- /**
- * According to [MS-RPCE], auth_pad_length is the number of padding
- * octets used to 4-byte align the security trailer, but in practice
- * we get values up to 15, which indicates 16-byte alignment.
- */
-
- if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
- {
- fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
- (frag_length - (sec_trailer_offset + 8)));
- }
+ /**
+ * According to [MS-RPCE], auth_pad_length is the number of padding
+ * octets used to 4-byte align the security trailer, but in practice
+ * we get values up to 15, which indicates 16-byte alignment.
+ */
- *length = frag_length - auth_length - 24 - 8 - auth_pad_length;
- }
+ if ((frag_length - (sec_trailer_offset + 8)) != auth_length)
+ {
+ fprintf(stderr, "invalid auth_length: actual: %d, expected: %d\n", auth_length,
+ (frag_length - (sec_trailer_offset + 8)));
}
+ *length = frag_length - auth_length - 24 - 8 - auth_pad_length;
return TRUE;
}
{
int status;
- status = tls_read(rpc->TlsOut, data, length);
+ status = BIO_read(rpc->TlsOut->bio, data, length);
+ /* fprintf(stderr, "%s: length=%d => status=%d shouldRetry=%d\n", __FUNCTION__, length,
+ * status, BIO_should_retry(rpc->TlsOut->bio)); */
+ if (status > 0) {
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+ VALGRIND_MAKE_MEM_DEFINED(data, status);
+#endif
+ return status;
+ }
- return status;
+ if (BIO_should_retry(rpc->TlsOut->bio))
+ return 0;
+
+ return -1;
}
-int rpc_out_write(rdpRpc* rpc, BYTE* data, int length)
+int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length)
{
int status;
return status;
}
-int rpc_in_write(rdpRpc* rpc, BYTE* data, int length)
+int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length)
{
int status;
ntlm = rpc->ntlm;
- if ((!ntlm) || (!ntlm->table))
+ if (!ntlm || !ntlm->table)
{
- fprintf(stderr, "rpc_write: invalid ntlm context\n");
+ fprintf(stderr, "%s: invalid ntlm context\n", __FUNCTION__);
return -1;
}
if (ntlm->table->QueryContextAttributes(&ntlm->context, SECPKG_ATTR_SIZES, &ntlm->ContextSizes) != SEC_E_OK)
{
- fprintf(stderr, "QueryContextAttributes SECPKG_ATTR_SIZES failure\n");
+ fprintf(stderr, "%s: QueryContextAttributes SECPKG_ATTR_SIZES failure\n", __FUNCTION__);
return -1;
}
- request_pdu = (rpcconn_request_hdr_t*) malloc(sizeof(rpcconn_request_hdr_t));
- ZeroMemory(request_pdu, sizeof(rpcconn_request_hdr_t));
+ request_pdu = (rpcconn_request_hdr_t*) calloc(1, sizeof(rpcconn_request_hdr_t));
+ if (!request_pdu)
+ return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) request_pdu);
request_pdu->opnum = opnum;
clientCall = rpc_client_call_new(request_pdu->call_id, request_pdu->opnum);
- ArrayList_Add(rpc->client->ClientCallList, clientCall);
+ if (!clientCall)
+ goto out_free_pdu;
+
+ if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
+ goto out_free_clientCall;
if (request_pdu->opnum == TsProxySetupReceivePipeOpnum)
rpc->PipeCallId = request_pdu->call_id;
request_pdu->frag_length = offset;
- buffer = (BYTE*) malloc(request_pdu->frag_length);
-
+ buffer = (BYTE*) calloc(1, request_pdu->frag_length);
+ if (!buffer)
+ goto out_free_pdu;
CopyMemory(buffer, request_pdu, 24);
offset = 24;
Buffers[0].cbBuffer = offset;
Buffers[1].cbBuffer = ntlm->ContextSizes.cbMaxSignature;
- Buffers[1].pvBuffer = malloc(Buffers[1].cbBuffer);
- ZeroMemory(Buffers[1].pvBuffer, Buffers[1].cbBuffer);
+ Buffers[1].pvBuffer = calloc(1, Buffers[1].cbBuffer);
+ if (!Buffers[1].pvBuffer)
+ return -1;
Message.cBuffers = 2;
Message.ulVersion = SECBUFFER_VERSION;
Message.pBuffers = (PSecBuffer) &Buffers;
encrypt_status = ntlm->table->EncryptMessage(&ntlm->context, 0, &Message, rpc->SendSeqNum++);
-
if (encrypt_status != SEC_E_OK)
{
fprintf(stderr, "EncryptMessage status: 0x%08X\n", encrypt_status);
offset += Buffers[1].cbBuffer;
free(Buffers[1].pvBuffer);
- if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) != 0)
+ if (rpc_send_enqueue_pdu(rpc, buffer, request_pdu->frag_length) < 0)
length = -1;
free(request_pdu);
return length;
+
+out_free_clientCall:
+ rpc_client_call_free(clientCall);
+out_free_pdu:
+ free(request_pdu);
+ return -1;
}
BOOL rpc_connect(rdpRpc* rpc)
rpc->CallId = 2;
- rpc_client_new(rpc);
+ if (rpc_client_new(rpc) < 0)
+ goto out_free_virtualConnectionCookieTable;
rpc->client->SynchronousSend = TRUE;
rpc->client->SynchronousReceive = TRUE;
return rpc;
+out_free_virtualConnectionCookieTable:
+ rpc_client_free(rpc);
+ ArrayList_Free(rpc->VirtualConnectionCookieTable);
out_free_virtual_connection:
rpc_client_virtual_connection_free(rpc->VirtualConnection);
out_free_ntlm_http_out:
int rpc_out_read(rdpRpc* rpc, BYTE* data, int length);
-int rpc_out_write(rdpRpc* rpc, BYTE* data, int length);
-int rpc_in_write(rdpRpc* rpc, BYTE* data, int length);
+int rpc_out_write(rdpRpc* rpc, const BYTE* data, int length);
+int rpc_in_write(rdpRpc* rpc, const BYTE* data, int length);
BOOL rpc_get_stub_data_info(rdpRpc* rpc, BYTE* header, UINT32* offset, UINT32* length);
DEBUG_RPC("Sending bind PDU");
rpc->ntlm = ntlm_new();
+ if (!rpc->ntlm)
+ return -1;
if ((!settings->GatewayPassword) || (!settings->GatewayUsername)
|| (!strlen(settings->GatewayPassword)) || (!strlen(settings->GatewayUsername)))
settings->Username = _strdup(settings->GatewayUsername);
settings->Domain = _strdup(settings->GatewayDomain);
settings->Password = _strdup(settings->GatewayPassword);
+
+ if (!settings->Username || !settings->Domain || settings->Password)
+ return -1;
}
}
}
- ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL);
- ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname);
-
- ntlm_authenticate(rpc->ntlm);
+ if (!ntlm_client_init(rpc->ntlm, FALSE, settings->GatewayUsername, settings->GatewayDomain, settings->GatewayPassword, NULL) ||
+ !ntlm_client_make_spn(rpc->ntlm, NULL, settings->GatewayHostname) ||
+ !ntlm_authenticate(rpc->ntlm)
+ )
+ return -1;
- bind_pdu = (rpcconn_bind_hdr_t*) malloc(sizeof(rpcconn_bind_hdr_t));
- ZeroMemory(bind_pdu, sizeof(rpcconn_bind_hdr_t));
+ bind_pdu = (rpcconn_bind_hdr_t*) calloc(1, sizeof(rpcconn_bind_hdr_t));
+ if (!bind_pdu)
+ return -1;
rpc_pdu_header_init(rpc, (rpcconn_hdr_t*) bind_pdu);
bind_pdu->p_context_elem.reserved2 = 0;
bind_pdu->p_context_elem.p_cont_elem = malloc(sizeof(p_cont_elem_t) * bind_pdu->p_context_elem.n_context_elem);
+ if (!bind_pdu->p_context_elem.p_cont_elem)
+ return -1;
p_cont_elem = &bind_pdu->p_context_elem.p_cont_elem[0];
bind_pdu->frag_length = offset;
buffer = (BYTE*) malloc(bind_pdu->frag_length);
+ if (!buffer)
+ return -1;
CopyMemory(buffer, bind_pdu, 24);
CopyMemory(&buffer[24], &bind_pdu->p_context_elem, 4);
length = bind_pdu->frag_length;
clientCall = rpc_client_call_new(bind_pdu->call_id, 0);
- ArrayList_Add(rpc->client->ClientCallList, clientCall);
+ if (!clientCall)
+ return -1;
+ if (ArrayList_Add(rpc->client->ClientCallList, clientCall) < 0)
+ return -1;
if (rpc_send_enqueue_pdu(rpc, buffer, length) != 0)
length = -1;
#include <winpr/stream.h>
#include "rpc_fault.h"
-
#include "rpc_client.h"
-
#include "../rdp.h"
#define SYNCHRONOUS_TIMEOUT 5000
if (!pdu)
{
- pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
+ pdu = (RPC_PDU *)malloc(sizeof(RPC_PDU));
+ if (!pdu)
+ return NULL;
pdu->s = Stream_New(NULL, rpc->max_recv_frag);
+ if (!pdu->s)
+ {
+ free(pdu);
+ return NULL;
+ }
}
pdu->CallId = 0;
int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
{
- Queue_Enqueue(rpc->client->ReceivePool, pdu);
- return 0;
+ return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
}
int rpc_client_on_fragment_received_event(rdpRpc* rpc)
rpcconn_hdr_t* header;
freerdp* instance;
- instance = (freerdp*) rpc->transport->settings->instance;
+ instance = (freerdp *)rpc->transport->settings->instance;
if (!rpc->client->pdu)
rpc->client->pdu = rpc_client_receive_pool_take(rpc);
return 0;
}
- if (header->common.ptype == PTYPE_RTS)
+ switch (header->common.ptype)
{
- if (rpc->VirtualConnection->State >= VIRTUAL_CONNECTION_STATE_OPENED)
- {
- //fprintf(stderr, "Receiving Out-of-Sequence RTS PDU\n");
+ case PTYPE_RTS:
+ if (rpc->VirtualConnection->State < VIRTUAL_CONNECTION_STATE_OPENED)
+ {
+ fprintf(stderr, "%s: warning: unhandled RTS PDU\n", __FUNCTION__);
+ return 0;
+ }
+ fprintf(stderr, "%s: Receiving Out-of-Sequence RTS PDU\n", __FUNCTION__);
rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
-
rpc_client_fragment_pool_return(rpc, fragment);
- }
- else
- {
- fprintf(stderr, "warning: unhandled RTS PDU\n");
- }
+ return 0;
- return 0;
- }
- else if (header->common.ptype == PTYPE_FAULT)
- {
- rpc_recv_fault_pdu(header);
- Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
- return -1;
- }
-
- if (header->common.ptype != PTYPE_RESPONSE)
- {
- fprintf(stderr, "Unexpected RPC PDU type: %d\n", header->common.ptype);
- Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
- return -1;
+ case PTYPE_FAULT:
+ rpc_recv_fault_pdu(header);
+ Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
+ return -1;
+ case PTYPE_RESPONSE:
+ break;
+ default:
+ fprintf(stderr, "%s: unexpected RPC PDU type %d\n", __FUNCTION__, header->common.ptype);
+ Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
+ return -1;
}
rpc->VirtualConnection->DefaultOutChannel->BytesReceived += header->common.frag_length;
if (!rpc_get_stub_data_info(rpc, buffer, &StubOffset, &StubLength))
{
- fprintf(stderr, "rpc_recv_pdu_fragment: expected stub\n");
+ fprintf(stderr, "%s: expected stub\n", __FUNCTION__);
Queue_Enqueue(rpc->client->ReceiveQueue, NULL);
return -1;
}
if (rpc->StubCallId != header->common.call_id)
{
- fprintf(stderr, "invalid call_id: actual: %d, expected: %d, frag_count: %d\n",
+ fprintf(stderr, "%s: invalid call_id: actual: %d, expected: %d, frag_count: %d\n", __FUNCTION__,
rpc->StubCallId, header->common.call_id, rpc->StubFragCount);
}
int status = -1;
rpcconn_common_hdr_t* header;
- if (!rpc->client->RecvFrag)
- rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
-
- position = Stream_GetPosition(rpc->client->RecvFrag);
-
- if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
+ while (1)
{
- status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
- RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
+ if (!rpc->client->RecvFrag)
+ rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
- if (status < 0)
+ position = Stream_GetPosition(rpc->client->RecvFrag);
+
+ while (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
{
- fprintf(stderr, "rpc_client_frag_read: error reading header\n");
- return -1;
+ status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
+ RPC_COMMON_FIELDS_LENGTH - Stream_GetPosition(rpc->client->RecvFrag));
+
+ if (status < 0)
+ {
+ fprintf(stderr, "rpc_client_frag_read: error reading header\n");
+ return -1;
+ }
+
+ if (!status)
+ return 0;
+
+ Stream_Seek(rpc->client->RecvFrag, status);
}
- Stream_Seek(rpc->client->RecvFrag, status);
- }
+ if (Stream_GetPosition(rpc->client->RecvFrag) < RPC_COMMON_FIELDS_LENGTH)
+ return status;
+
- if (Stream_GetPosition(rpc->client->RecvFrag) >= RPC_COMMON_FIELDS_LENGTH)
- {
header = (rpcconn_common_hdr_t*) Stream_Buffer(rpc->client->RecvFrag);
if (header->frag_length > rpc->max_recv_frag)
return -1;
}
- if (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
+ while (Stream_GetPosition(rpc->client->RecvFrag) < header->frag_length)
{
status = rpc_out_read(rpc, Stream_Pointer(rpc->client->RecvFrag),
header->frag_length - Stream_GetPosition(rpc->client->RecvFrag));
if (status < 0)
{
- fprintf(stderr, "rpc_client_frag_read: error reading fragment body\n");
+ fprintf(stderr, "%s: error reading fragment body\n", __FUNCTION__);
return -1;
}
+ if (!status)
+ return 0;
+
Stream_Seek(rpc->client->RecvFrag, status);
}
- }
- else
- {
- return status;
- }
- if (status < 0)
- return -1;
+ if (status < 0)
+ return -1;
- status = Stream_GetPosition(rpc->client->RecvFrag) - position;
+ status = Stream_GetPosition(rpc->client->RecvFrag) - position;
- if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
- {
- /* complete fragment received */
+ if (Stream_GetPosition(rpc->client->RecvFrag) >= header->frag_length)
+ {
+ /* complete fragment received */
- Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
- Stream_SetPosition(rpc->client->RecvFrag, 0);
+ Stream_Length(rpc->client->RecvFrag) = Stream_GetPosition(rpc->client->RecvFrag);
+ Stream_SetPosition(rpc->client->RecvFrag, 0);
- Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
- rpc->client->RecvFrag = NULL;
+ Queue_Enqueue(rpc->client->FragmentQueue, rpc->client->RecvFrag);
+ rpc->client->RecvFrag = NULL;
- if (rpc_client_on_fragment_received_event(rpc) < 0)
- return -1;
+ if (rpc_client_on_fragment_received_event(rpc) < 0)
+ return -1;
+ }
}
- return status;
+ return 0;
}
/**
RpcClientCall* clientCall;
clientCall = (RpcClientCall*) malloc(sizeof(RpcClientCall));
+ if (!clientCall)
+ return NULL;
- if (clientCall)
- {
- clientCall->CallId = CallId;
- clientCall->OpNum = OpNum;
- clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
- }
+ clientCall->CallId = CallId;
+ clientCall->OpNum = OpNum;
+ clientCall->State = RPC_CLIENT_CALL_STATE_SEND_PDUS;
return clientCall;
}
int status;
pdu = (RPC_PDU*) malloc(sizeof(RPC_PDU));
+ if (!pdu)
+ return -1;
+
pdu->s = Stream_New(buffer, length);
+ if (!pdu->s)
+ goto out_free;
- Queue_Enqueue(rpc->client->SendQueue, pdu);
+ if (!Queue_Enqueue(rpc->client->SendQueue, pdu))
+ goto out_free_stream;
if (rpc->client->SynchronousSend)
{
status = WaitForSingleObject(rpc->client->PduSentEvent, SYNCHRONOUS_TIMEOUT);
if (status == WAIT_TIMEOUT)
{
- fprintf(stderr, "rpc_send_enqueue_pdu: timed out waiting for pdu sent event\n");
+ fprintf(stderr, "%s: timed out waiting for pdu sent event %p\n", __FUNCTION__, rpc->client->PduSentEvent);
return -1;
}
}
return 0;
+
+out_free_stream:
+ Stream_Free(pdu->s, TRUE);
+out_free:
+ free(pdu);
+ return -1;
}
int rpc_send_dequeue_pdu(rdpRpc* rpc)
RPC_PDU* pdu;
RpcClientCall* clientCall;
rpcconn_common_hdr_t* header;
+ RpcInChannel *inChannel;
pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->SendQueue);
-
if (!pdu)
return 0;
- WaitForSingleObject(rpc->VirtualConnection->DefaultInChannel->Mutex, INFINITE);
+ inChannel = rpc->VirtualConnection->DefaultInChannel;
+ WaitForSingleObject(inChannel->Mutex, INFINITE);
status = rpc_in_write(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
clientCall = rpc_client_call_find_by_id(rpc, header->call_id);
clientCall->State = RPC_CLIENT_CALL_STATE_DISPATCHED;
- ReleaseMutex(rpc->VirtualConnection->DefaultInChannel->Mutex);
+ ReleaseMutex(inChannel->Mutex);
/*
* This protocol specifies that only RPC PDUs are subject to the flow control abstract
if (header->ptype == PTYPE_REQUEST)
{
- rpc->VirtualConnection->DefaultInChannel->BytesSent += status;
- rpc->VirtualConnection->DefaultInChannel->SenderAvailableWindow -= status;
+ inChannel->BytesSent += status;
+ inChannel->SenderAvailableWindow -= status;
}
Stream_Free(pdu->s, TRUE);
DWORD dwMilliseconds;
DWORD result;
- pdu = NULL;
- dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
+ dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT * 4 : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
if (result == WAIT_TIMEOUT)
{
- fprintf(stderr, "rpc_recv_dequeue_pdu: timed out waiting for receive event\n");
+ fprintf(stderr, "%s: timed out waiting for receive event\n", __FUNCTION__);
return NULL;
}
- if (result == WAIT_OBJECT_0)
- {
- pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue);
+ if (result != WAIT_OBJECT_0)
+ return NULL;
-#ifdef WITH_DEBUG_TSG
- if (pdu)
- {
- fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
- winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
- fprintf(stderr, "\n");
- }
-#endif
+ pdu = (RPC_PDU *)Queue_Dequeue(rpc->client->ReceiveQueue);
- return pdu;
+#ifdef WITH_DEBUG_TSG
+ if (pdu)
+ {
+ fprintf(stderr, "Receiving PDU (length: %d, CallId: %d)\n", pdu->s->length, pdu->CallId);
+ winpr_HexDump(Stream_Buffer(pdu->s), Stream_Length(pdu->s));
+ fprintf(stderr, "\n");
+ }
+ else
+ {
+ fprintf(stderr, "Receiving a NULL PDU\n");
}
+#endif
return pdu;
}
RPC_PDU* rpc_recv_peek_pdu(rdpRpc* rpc)
{
- RPC_PDU* pdu;
DWORD dwMilliseconds;
DWORD result;
- pdu = NULL;
dwMilliseconds = rpc->client->SynchronousReceive ? SYNCHRONOUS_TIMEOUT : 0;
result = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), dwMilliseconds);
- if (result == WAIT_TIMEOUT)
- {
+ if (result != WAIT_OBJECT_0)
return NULL;
- }
-
- if (result == WAIT_OBJECT_0)
- {
- pdu = (RPC_PDU*) Queue_Peek(rpc->client->ReceiveQueue);
- return pdu;
- }
- return pdu;
+ return (RPC_PDU *)Queue_Peek(rpc->client->ReceiveQueue);
}
static void* rpc_client_thread(void* arg)
DWORD nCount;
HANDLE events[3];
HANDLE ReadEvent;
+ int fd;
rpc = (rdpRpc*) arg;
+ fd = BIO_get_fd(rpc->TlsOut->bio, NULL);
- ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, rpc->TlsOut->sockfd);
+ ReadEvent = CreateFileDescriptorEvent(NULL, TRUE, FALSE, fd);
nCount = 0;
events[nCount++] = rpc->client->StopEvent;
events[nCount++] = Queue_Event(rpc->client->SendQueue);
events[nCount++] = ReadEvent;
+ /* Do a first free run in case some bytes were set from the HTTP headers.
+ * We also have to do it because most of the time the underlying socket has notified,
+ * and the ssl layer has eaten all bytes, so we won't be notified any more even if the
+ * bytes are buffered locally
+ */
+ if (rpc_client_on_read_event(rpc) < 0)
+ {
+ fprintf(stderr, "%s: an error occured when treating first packet\n", __FUNCTION__);
+ goto out;
+ }
+
while (rpc->transport->layer != TRANSPORT_LAYER_CLOSED)
{
status = WaitForMultipleObjects(nCount, events, FALSE, 100);
- if (status != WAIT_TIMEOUT)
+ if (status == WAIT_TIMEOUT)
+ continue;
+
+ if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
+ break;
+
+ if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
{
- if (WaitForSingleObject(rpc->client->StopEvent, 0) == WAIT_OBJECT_0)
- {
+ if (rpc_client_on_read_event(rpc) < 0)
break;
- }
-
- if (WaitForSingleObject(ReadEvent, 0) == WAIT_OBJECT_0)
- {
- if (rpc_client_on_read_event(rpc) < 0)
- break;
- }
+ }
- if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
- {
- rpc_send_dequeue_pdu(rpc);
- }
+ if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
+ {
+ rpc_send_dequeue_pdu(rpc);
}
}
+out:
CloseHandle(ReadEvent);
return NULL;
static void rpc_pdu_free(RPC_PDU* pdu)
{
+ if (!pdu)
+ return;
+
Stream_Free(pdu->s, TRUE);
free(pdu);
}
{
RpcClient* client = NULL;
- client = (RpcClient*) calloc(1, sizeof(RpcClient));
+ client = (RpcClient *)calloc(1, sizeof(RpcClient));
+ rpc->client = client;
+ if (!client)
+ return -1;
- if (client)
- {
- client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
- client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ client->Thread = CreateThread(NULL, 0,
+ (LPTHREAD_START_ROUTINE) rpc_client_thread,
+ rpc, CREATE_SUSPENDED, NULL);
+ if (!client->Thread)
+ return -1;
- client->SendQueue = Queue_New(TRUE, -1, -1);
- Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
+ client->StopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ if (!client->StopEvent)
+ return -1;
+ client->PduSentEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ if (!client->PduSentEvent)
+ return -1;
- client->pdu = NULL;
- client->ReceivePool = Queue_New(TRUE, -1, -1);
- client->ReceiveQueue = Queue_New(TRUE, -1, -1);
- Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
- Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
+ client->SendQueue = Queue_New(TRUE, -1, -1);
+ if (!client->SendQueue)
+ return -1;
+ Queue_Object(client->SendQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
- client->RecvFrag = NULL;
- client->FragmentPool = Queue_New(TRUE, -1, -1);
- client->FragmentQueue = Queue_New(TRUE, -1, -1);
+ client->pdu = NULL;
+ client->ReceivePool = Queue_New(TRUE, -1, -1);
+ if (!client->ReceivePool)
+ return -1;
+ Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
- Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
- Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
+ client->ReceiveQueue = Queue_New(TRUE, -1, -1);
+ if (!client->ReceiveQueue)
+ return -1;
+ Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
- client->ClientCallList = ArrayList_New(TRUE);
- ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
- }
+ client->RecvFrag = NULL;
+ client->FragmentPool = Queue_New(TRUE, -1, -1);
+ if (!client->FragmentPool)
+ return -1;
+ Queue_Object(client->FragmentPool)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
- rpc->client = client;
+ client->FragmentQueue = Queue_New(TRUE, -1, -1);
+ if (!client->FragmentQueue)
+ return -1;
+ Queue_Object(client->FragmentQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_fragment_free;
+ client->ClientCallList = ArrayList_New(TRUE);
+ if (!client->ClientCallList)
+ return -1;
+ ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
return 0;
}
rpc->client->Thread = NULL;
}
- rpc_client_free(rpc);
-
- return 0;
+ return rpc_client_free(rpc);
}
int rpc_client_free(rdpRpc* rpc)
client = rpc->client;
- if (client)
- {
+ if (!client)
+ return 0;
+
+ if (client->SendQueue)
Queue_Free(client->SendQueue);
- if (client->RecvFrag)
- rpc_fragment_free(client->RecvFrag);
+ if (client->RecvFrag)
+ rpc_fragment_free(client->RecvFrag);
+ if (client->FragmentPool)
Queue_Free(client->FragmentPool);
+ if (client->FragmentQueue)
Queue_Free(client->FragmentQueue);
- if (client->pdu)
- rpc_pdu_free(client->pdu);
+ if (client->pdu)
+ rpc_pdu_free(client->pdu);
+ if (client->ReceivePool)
Queue_Free(client->ReceivePool);
+ if (client->ReceiveQueue)
Queue_Free(client->ReceiveQueue);
+ if (client->ClientCallList)
ArrayList_Free(client->ClientCallList);
+ if (client->StopEvent)
CloseHandle(client->StopEvent);
+ if (client->PduSentEvent)
CloseHandle(client->PduSentEvent);
+ if (client->Thread)
CloseHandle(client->Thread);
- free(client);
- }
-
+ free(client);
return 0;
}
if (!rpc_ntlm_http_out_connect(rpc))
{
- fprintf(stderr, "rpc_out_connect_http error!\n");
+ fprintf(stderr, "%s: rpc_out_connect_http error!\n", __FUNCTION__);
return FALSE;
}
if (rts_send_CONN_A1_pdu(rpc) != 0)
{
- fprintf(stderr, "rpc_send_CONN_A1_pdu error!\n");
+ fprintf(stderr, "%s: rpc_send_CONN_A1_pdu error!\n", __FUNCTION__);
return FALSE;
}
if (!rpc_ntlm_http_in_connect(rpc))
{
- fprintf(stderr, "rpc_in_connect_http error!\n");
+ fprintf(stderr, "%s: rpc_in_connect_http error!\n", __FUNCTION__);
return FALSE;
}
- if (rts_send_CONN_B1_pdu(rpc) != 0)
+ if (rts_send_CONN_B1_pdu(rpc) < 0)
{
- fprintf(stderr, "rpc_send_CONN_B1_pdu error!\n");
+ fprintf(stderr, "%s: rpc_send_CONN_B1_pdu error!\n", __FUNCTION__);
return FALSE;
}
*/
http_response = http_response_recv(rpc->TlsOut);
+ if (!http_response)
+ {
+ fprintf(stderr, "%s: unable to retrieve OUT Channel Response!\n", __FUNCTION__);
+ return FALSE;
+ }
if (http_response->StatusCode != HTTP_STATUS_OK)
{
- fprintf(stderr, "rts_connect error! Status Code: %d\n", http_response->StatusCode);
+ fprintf(stderr, "%s: error! Status Code: %d\n", __FUNCTION__, http_response->StatusCode);
http_response_print(http_response);
http_response_free(http_response);
return FALSE;
}
+ if (http_response->bodyLen)
+ {
+ /* inject bytes we have read in the body as a received packet for the RPC client */
+ rpc->client->RecvFrag = rpc_client_fragment_pool_take(rpc);
+ Stream_EnsureCapacity(rpc->client->RecvFrag, http_response->bodyLen);
+ CopyMemory(rpc->client->RecvFrag, http_response->BodyContent, http_response->bodyLen);
+ }
+
//http_response_print(http_response);
http_response_free(http_response);
rpc_client_start(rpc);
pdu = rpc_recv_dequeue_pdu(rpc);
-
if (!pdu)
return FALSE;
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
{
- fprintf(stderr, "Unexpected RTS PDU: Expected CONN/A3\n");
+ fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/A3\n", __FUNCTION__);
return FALSE;
}
*/
pdu = rpc_recv_dequeue_pdu(rpc);
-
if (!pdu)
return FALSE;
if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
{
- fprintf(stderr, "Unexpected RTS PDU: Expected CONN/C2\n");
+ fprintf(stderr, "%s: unexpected RTS PDU: Expected CONN/C2\n", __FUNCTION__);
return FALSE;
}
return TRUE;
}
-#if defined WITH_DEBUG_RTS && 0
+#ifdef WITH_DEBUG_RTS
static const char* const RTS_CMD_STRINGS[] =
{
void rts_pdu_header_init(rpcconn_rts_hdr_t* header)
{
+ ZeroMemory(header, sizeof(*header));
header->rpc_vers = 5;
header->rpc_vers_minor = 0;
header->ptype = PTYPE_RTS;
ReceiveWindowSize = rpc->VirtualConnection->DefaultOutChannel->ReceiveWindow;
buffer = (BYTE*) malloc(header.frag_length);
+ if (!buffer)
+ return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
BYTE* INChannelCookie;
BYTE* AssociationGroupId;
BYTE* VirtualConnectionCookie;
+ int status;
rts_pdu_header_init(&header);
header.frag_length = 104;
AssociationGroupId = (BYTE*) &(rpc->VirtualConnection->AssociationGroupId);
buffer = (BYTE*) malloc(header.frag_length);
+ if (!buffer)
+ return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_version_command_write(&buffer[20]); /* Version (8 bytes) */
length = header.frag_length;
- rpc_in_write(rpc, buffer, length);
+ status = rpc_in_write(rpc, buffer, length);
free(buffer);
- return 0;
+ return status;
}
/* CONN/C Sequence */
DEBUG_RPC("Sending Keep-Alive RTS PDU");
buffer = (BYTE*) malloc(header.frag_length);
+ if (!buffer)
+ return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_client_keepalive_command_write(&buffer[20], rpc->CurrentKeepAliveInterval); /* ClientKeepAlive (8 bytes) */
length = header.frag_length;
- rpc_in_write(rpc, buffer, length);
+ if (rpc_in_write(rpc, buffer, length) < 0)
+ return -1;
free(buffer);
return length;
rpc->VirtualConnection->DefaultOutChannel->AvailableWindowAdvertised;
buffer = (BYTE*) malloc(header.frag_length);
+ if (!buffer)
+ return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
rts_destination_command_write(&buffer[20], FDOutProxy); /* Destination Command (8 bytes) */
length = header.frag_length;
- rpc_in_write(rpc, buffer, length);
+ if (rpc_in_write(rpc, buffer, length) < 0)
+ return -1;
free(buffer);
return 0;
DEBUG_RPC("Sending Ping RTS PDU");
buffer = (BYTE*) malloc(header.frag_length);
+ if (!buffer)
+ return -1;
CopyMemory(buffer, ((BYTE*) &header), 20); /* RTS Header (20 bytes) */
length = header.frag_length;
- rpc_in_write(rpc, buffer, length);
+ if (rpc_in_write(rpc, buffer, length) < 0)
+ return -1;
free(buffer);
return length;
rts_extract_pdu_signature(rpc, &signature, rts);
SignatureId = rts_identify_pdu_signature(rpc, &signature, NULL);
- if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK)
+ switch (SignatureId)
{
- return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
- }
- else if (SignatureId == RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION)
- {
- return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
- }
- else if (SignatureId == RTS_PDU_PING)
- {
- rts_send_ping_pdu(rpc);
- }
- else
- {
- fprintf(stderr, "Unimplemented signature id: 0x%08X\n", SignatureId);
- rts_print_pdu_signature(rpc, &signature);
+ case RTS_PDU_FLOW_CONTROL_ACK:
+ return rts_recv_flow_control_ack_pdu(rpc, buffer, length);
+ case RTS_PDU_FLOW_CONTROL_ACK_WITH_DESTINATION:
+ return rts_recv_flow_control_ack_with_destination_pdu(rpc, buffer, length);
+ case RTS_PDU_PING:
+ return rts_send_ping_pdu(rpc);
+ default:
+ fprintf(stderr, "%s: unimplemented signature id: 0x%08X\n", __FUNCTION__, SignatureId);
+ rts_print_pdu_signature(rpc, &signature);
+ break;
}
return 0;
return FALSE;
status = rts_command_length(rpc, CommandType, &buffer[offset], length);
-
if (status < 0)
return FALSE;
signature->CommandTypes[i] = CommandType;
status = rts_command_length(rpc, CommandType, &buffer[offset], length);
-
if (status < 0)
return FALSE;
{
pSignature = RTS_PDU_SIGNATURE_TABLE[i].Signature;
- if (signature->Flags == pSignature->Flags)
+ if (signature->Flags != pSignature->Flags)
+ continue;
+
+ if (signature->NumberOfCommands != pSignature->NumberOfCommands)
+ continue;
+
+ for (j = 0; j < signature->NumberOfCommands; j++)
{
- if (signature->NumberOfCommands == pSignature->NumberOfCommands)
- {
- for (j = 0; j < signature->NumberOfCommands; j++)
- {
- if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
- continue;
- }
-
- if (entry)
- *entry = &RTS_PDU_SIGNATURE_TABLE[i];
-
- return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
- }
+ if (signature->CommandTypes[j] != pSignature->CommandTypes[j])
+ continue;
}
+
+ if (entry)
+ *entry = &RTS_PDU_SIGNATURE_TABLE[i];
+
+ return RTS_PDU_SIGNATURE_TABLE[i].SignatureId;
}
return 0;
#include <winpr/stream.h>
#include "rpc_client.h"
-
#include "tsg.h"
+
/**
* RPC Functions: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378623/
* Remote Procedure Call: http://msdn.microsoft.com/en-us/library/windows/desktop/aa378651/
}
length = 28 + totalDataBytes;
- buffer = (BYTE*) malloc(length);
+ buffer = (BYTE*) calloc(1, length);
+ if (!buffer)
+ return -1;
s = Stream_New(buffer, length);
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24];
- packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
- ZeroMemory(packet, sizeof(TSG_PACKET));
+ packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
+ if (!packet)
+ return FALSE;
offset = 4; // Skip Packet Pointer
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
if ((packet->packetId == TSG_PACKET_TYPE_CAPS_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_CAPS_RESPONSE))
{
- packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) malloc(sizeof(TSG_PACKET_CAPS_RESPONSE));
- ZeroMemory(packetCapsResponse, sizeof(TSG_PACKET_CAPS_RESPONSE));
+ packetCapsResponse = (PTSG_PACKET_CAPS_RESPONSE) calloc(1, sizeof(TSG_PACKET_CAPS_RESPONSE));
+ if (!packetCapsResponse) // TODO: correct cleanup
+ return FALSE;
packet->tsgPacket.packetCapsResponse = packetCapsResponse;
/* PacketQuarResponsePtr (4 bytes) */
IsMessagePresent = *((UINT32*) &buffer[offset]);
offset += 4;
MessageSwitchValue = *((UINT32*) &buffer[offset]);
- DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d",
- IsMessagePresent, MessageSwitchValue);
+ DEBUG_TSG("IsMessagePresent %d MessageSwitchValue %d", IsMessagePresent, MessageSwitchValue);
offset += 4;
}
offset += 4;
}
- versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
- ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
+ versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
+ if (!versionCaps) // TODO: correct cleanup
+ return FALSE;
packetCapsResponse->pktQuarEncResponse.versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
/* 4-byte alignment */
rpc_offset_align(&offset, 4);
- tsgCaps = (PTSG_PACKET_CAPABILITIES) malloc(sizeof(TSG_PACKET_CAPABILITIES));
- ZeroMemory(tsgCaps, sizeof(TSG_PACKET_CAPABILITIES));
+ tsgCaps = (PTSG_PACKET_CAPABILITIES) calloc(1, sizeof(TSG_PACKET_CAPABILITIES));
+ if (!tsgCaps)
+ return FALSE;
+
versionCaps->tsgCaps = tsgCaps;
offset += 4; /* MaxCount (4 bytes) */
}
else if ((packet->packetId == TSG_PACKET_TYPE_QUARENC_RESPONSE) && (SwitchValue == TSG_PACKET_TYPE_QUARENC_RESPONSE))
{
- packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) malloc(sizeof(TSG_PACKET_QUARENC_RESPONSE));
- ZeroMemory(packetQuarEncResponse, sizeof(TSG_PACKET_QUARENC_RESPONSE));
+ packetQuarEncResponse = (PTSG_PACKET_QUARENC_RESPONSE) calloc(1, sizeof(TSG_PACKET_QUARENC_RESPONSE));
+ if (!packetQuarEncResponse) // TODO: handle cleanup
+ return FALSE;
packet->tsgPacket.packetQuarEncResponse = packetQuarEncResponse;
/* PacketQuarResponsePtr (4 bytes) */
offset += 4;
}
- versionCaps = (PTSG_PACKET_VERSIONCAPS) malloc(sizeof(TSG_PACKET_VERSIONCAPS));
- ZeroMemory(versionCaps, sizeof(TSG_PACKET_VERSIONCAPS));
+ versionCaps = (PTSG_PACKET_VERSIONCAPS) calloc(1, sizeof(TSG_PACKET_VERSIONCAPS));
+ if (!versionCaps) // TODO: handle cleanup
+ return FALSE;
packetQuarEncResponse->versionCaps = versionCaps;
versionCaps->tsgHeader.ComponentId = *((UINT16*) &buffer[offset]); /* ComponentId */
if (!(pdu->Flags & RPC_PDU_FLAG_STUB))
buffer = &buffer[24];
- packet = (PTSG_PACKET) malloc(sizeof(TSG_PACKET));
- ZeroMemory(packet, sizeof(TSG_PACKET));
+ packet = (PTSG_PACKET) calloc(1, sizeof(TSG_PACKET));
+ if (!packet)
+ return FALSE;
offset = 4;
packet->packetId = *((UINT32*) &buffer[offset]); /* PacketId */
length = 60 + (count * 2);
buffer = (BYTE*) malloc(length);
+ if (!buffer)
+ return FALSE;
/* TunnelContext */
handle = (CONTEXT_HANDLE*) tunnelContext;
return CopyLength;
}
- else
- {
- tsg->pdu = rpc_recv_peek_pdu(rpc);
- if (!tsg->pdu)
- {
- if (tsg->rpc->client->SynchronousReceive)
- return tsg_read(tsg, data, length);
- else
- return 0;
- }
- tsg->PendingPdu = TRUE;
- tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
- tsg->BytesRead = 0;
+ tsg->pdu = rpc_recv_peek_pdu(rpc);
+ if (!tsg->pdu)
+ {
+ if (!tsg->rpc->client->SynchronousReceive)
+ return 0;
- CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
+ // weird !!!!
+ return tsg_read(tsg, data, length);
+ }
- CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
- tsg->BytesAvailable -= CopyLength;
- tsg->BytesRead += CopyLength;
+ tsg->PendingPdu = TRUE;
+ tsg->BytesAvailable = Stream_Length(tsg->pdu->s);
+ tsg->BytesRead = 0;
- if (tsg->BytesAvailable < 1)
- {
- tsg->PendingPdu = FALSE;
- rpc_recv_dequeue_pdu(rpc);
- rpc_client_receive_pool_return(rpc, tsg->pdu);
- }
+ CopyLength = (length < tsg->BytesAvailable) ? length : tsg->BytesAvailable;
- return CopyLength;
+ CopyMemory(data, &tsg->pdu->s->buffer[tsg->BytesRead], CopyLength);
+ tsg->BytesAvailable -= CopyLength;
+ tsg->BytesRead += CopyLength;
+
+ if (tsg->BytesAvailable < 1)
+ {
+ tsg->PendingPdu = FALSE;
+ rpc_recv_dequeue_pdu(rpc);
+ rpc_client_receive_pool_return(rpc, tsg->pdu);
}
+
+ return CopyLength;
+
}
int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length)
{
+ int status;
+
if (tsg->rpc->transport->layer == TRANSPORT_LAYER_CLOSED)
{
- fprintf(stderr, "tsg_write error: connection lost\n");
+ fprintf(stderr, "%s: error, connection lost\n", __FUNCTION__);
return -1;
}
- return TsProxySendToServer((handle_t) tsg, data, 1, &length);
+ status = TsProxySendToServer((handle_t) tsg, data, 1, &length);
+ if (status < 0)
+ return -1;
+ return length;
}
BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking)
{
rdpTsg* tsg;
- tsg = (rdpTsg*) malloc(sizeof(rdpTsg));
- ZeroMemory(tsg, sizeof(rdpTsg));
-
- if (tsg != NULL)
- {
- tsg->transport = transport;
- tsg->settings = transport->settings;
- tsg->rpc = rpc_new(tsg->transport);
- tsg->PendingPdu = FALSE;
- }
+ tsg = (rdpTsg*) calloc(1, sizeof(rdpTsg));
+ if (!tsg)
+ return NULL;
+ tsg->transport = transport;
+ tsg->settings = transport->settings;
+ tsg->rpc = rpc_new(tsg->transport);
+ if (!tsg->rpc)
+ goto out_free;
+ tsg->PendingPdu = FALSE;
return tsg;
+
+out_free:
+ free(tsg);
+ return NULL;
}
void tsg_free(rdpTsg* tsg)
fprintf(stderr, "%s: inavlid RDP key file %s\n", __FUNCTION__, settings->RdpKeyFile);
return FALSE;
}
+
if (settings->RdpServerRsaKey->ModulusLength > 256)
{
fprintf(stderr, "%s: Key sizes > 2048 are currently not supported for RDP security.\n", __FUNCTION__);
fprintf(stderr, "%s: Set a different key file than %s\n", __FUNCTION__, settings->RdpKeyFile);
exit(1);
}
-
}
return TRUE;
return client->context->rdp->transport->TcpIn->event;
}
-static BOOL freerdp_peer_check_fds(freerdp_peer* client)
+
+static BOOL freerdp_peer_check_fds(freerdp_peer* peer)
{
int status;
rdpRdp* rdp;
- rdp = client->context->rdp;
+ rdp = peer->context->rdp;
status = rdp_check_fds(rdp);
return rdp_send_channel_data(client->context->rdp, channelId, data, size);
}
+static BOOL freerdp_peer_is_write_blocked(freerdp_peer* peer)
+{
+ return tranport_is_write_blocked(peer->context->rdp->transport);
+}
+
+static int freerdp_peer_drain_output_buffer(freerdp_peer* peer)
+{
+
+ rdpTransport *transport = peer->context->rdp->transport;
+
+ return tranport_drain_output_buffer(transport);
+}
+
void freerdp_peer_context_new(freerdp_peer* client)
{
rdpRdp* rdp;
rdp->transport->ReceiveExtra = client;
transport_set_blocking_mode(rdp->transport, FALSE);
+ client->IsWriteBlocked = freerdp_peer_is_write_blocked;
+ client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
+
IFCALL(client->ContextNew, client, client->context);
}
client->Close = freerdp_peer_close;
client->Disconnect = freerdp_peer_disconnect;
client->SendChannelData = freerdp_peer_send_channel_data;
+ client->IsWriteBlocked = freerdp_peer_is_write_blocked;
+ client->DrainOutputBuffer = freerdp_peer_drain_output_buffer;
}
return client;
void freerdp_peer_free(freerdp_peer* client)
{
- if (client)
- {
- rdp_free(client->context->rdp);
- free(client->context);
- free(client);
- }
+ if (!client)
+ return;
+
+ rdp_free(client->context->rdp);
+ free(client->context);
+ free(client);
}
ZeroMemory(settings, sizeof(rdpSettings));
settings->ServerMode = (flags & FREERDP_SETTINGS_SERVER_MODE) ? TRUE : FALSE;
+ settings->WaitForOutputBufferFlush = TRUE;
settings->DesktopWidth = 1024;
settings->DesktopHeight = 768;
/* BOOL values */
_settings->ServerMode = settings->ServerMode; /* 16 */
+ _settings->WaitForOutputBufferFlush = settings->WaitForOutputBufferFlush; /* 25 */
_settings->NetworkAutoDetect = settings->NetworkAutoDetect; /* 137 */
_settings->SupportAsymetricKeys = settings->SupportAsymetricKeys; /* 138 */
_settings->SupportErrorInfoPdu = settings->SupportErrorInfoPdu; /* 139 */
#include "tcp.h"
+long transport_bio_buffered_callback(BIO* bio, int mode, const char* argp, int argi, long argl, long ret)
+{
+ return 1;
+}
+
+static int transport_bio_buffered_write(BIO* bio, const char* buf, int num)
+{
+ int status, ret;
+ rdpTcp *tcp = (rdpTcp *)bio->ptr;
+ int nchunks, committedBytes, i;
+ DataChunk chunks[2];
+
+ ret = num;
+ BIO_clear_retry_flags(bio);
+ tcp->writeBlocked = FALSE;
+
+ /* we directly append extra bytes in the xmit buffer, this could be prevented
+ * but for now it makes the code more simple.
+ */
+ if (buf && num && !ringbuffer_write(&tcp->xmitBuffer, buf, num))
+ {
+ fprintf(stderr, "%s: an error occured when writing(toWrite=%d)\n", __FUNCTION__, num);
+ return -1;
+ }
+
+ committedBytes = 0;
+ nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer));
+ for (i = 0; i < nchunks; i++)
+ {
+ while (chunks[i].size)
+ {
+ status = BIO_write(bio->next_bio, chunks[i].data, chunks[i].size);
+ /*fprintf(stderr, "%s: i=%d/%d size=%d/%d status=%d retry=%d\n", __FUNCTION__, i, nchunks,
+ chunks[i].size, ringbuffer_used(&tcp->xmitBuffer), status,
+ BIO_should_retry(bio->next_bio)
+ );*/
+ if (status <= 0)
+ {
+ if (BIO_should_retry(bio->next_bio))
+ {
+ tcp->writeBlocked = TRUE;
+ goto out; /* EWOULDBLOCK */
+ }
+
+ /* any other is an error, but we still have to commit written bytes */
+ ret = -1;
+ goto out;
+ }
+
+ committedBytes += status;
+ chunks[i].size -= status;
+ chunks[i].data += status;
+ }
+ }
+
+out:
+ ringbuffer_commit_read_bytes(&tcp->xmitBuffer, committedBytes);
+ return ret;
+}
+
+static int transport_bio_buffered_read(BIO* bio, char* buf, int size)
+{
+ int status;
+ rdpTcp *tcp = (rdpTcp *)bio->ptr;
+
+ tcp->readBlocked = FALSE;
+ BIO_clear_retry_flags(bio);
+
+ status = BIO_read(bio->next_bio, buf, size);
+ /*fprintf(stderr, "%s: size=%d status=%d shouldRetry=%d\n", __FUNCTION__, size, status, BIO_should_retry(bio->next_bio)); */
+
+ if (status <= 0 && BIO_should_retry(bio->next_bio))
+ {
+ BIO_set_retry_read(bio);
+ tcp->readBlocked = TRUE;
+ }
+
+ return status;
+}
+
+static int transport_bio_buffered_puts(BIO* bio, const char* str)
+{
+ return 1;
+}
+
+static int transport_bio_buffered_gets(BIO* bio, char* str, int size)
+{
+ return 1;
+}
+
+static long transport_bio_buffered_ctrl(BIO* bio, int cmd, long arg1, void* arg2)
+{
+ rdpTcp *tcp = (rdpTcp *)bio->ptr;
+
+ switch (cmd)
+ {
+ case BIO_CTRL_FLUSH:
+ return 1;
+ case BIO_CTRL_WPENDING:
+ return ringbuffer_used(&tcp->xmitBuffer);
+ case BIO_CTRL_PENDING:
+ return 0;
+ default:
+ /*fprintf(stderr, "%s: passing to next BIO, bio=%p cmd=%d arg1=%d arg2=%p\n", __FUNCTION__, bio, cmd, arg1, arg2); */
+ return BIO_ctrl(bio->next_bio, cmd, arg1, arg2);
+ }
+
+ return 0;
+}
+
+static int transport_bio_buffered_new(BIO* bio)
+{
+ bio->init = 1;
+ bio->num = 0;
+ bio->ptr = NULL;
+ bio->flags = 0;
+
+ return 1;
+}
+
+static int transport_bio_buffered_free(BIO* bio)
+{
+ return 1;
+}
+
+
+static BIO_METHOD transport_bio_buffered_socket_methods =
+{
+ BIO_TYPE_BUFFERED,
+ "BufferedSocket",
+ transport_bio_buffered_write,
+ transport_bio_buffered_read,
+ transport_bio_buffered_puts,
+ transport_bio_buffered_gets,
+ transport_bio_buffered_ctrl,
+ transport_bio_buffered_new,
+ transport_bio_buffered_free,
+ NULL,
+};
+
+BIO_METHOD* BIO_s_buffered_socket(void)
+{
+ return &transport_bio_buffered_socket_methods;
+}
+
+BOOL transport_bio_buffered_drain(BIO *bio)
+{
+ rdpTcp *tcp = (rdpTcp *)bio->ptr;
+ int status;
+
+ if (!ringbuffer_used(&tcp->xmitBuffer))
+ return 1;
+
+ status = transport_bio_buffered_write(bio, NULL, 0);
+ return status >= 0;
+}
+
+
+
void tcp_get_ip_address(rdpTcp* tcp)
{
BYTE* ip;
if (hostname[0] == '/')
{
tcp->sockfd = freerdp_uds_connect(hostname);
-
if (tcp->sockfd < 0)
return FALSE;
+
+ tcp->socketBio = BIO_new_fd(tcp->sockfd, 1);
+ if (!tcp->socketBio)
+ return FALSE;
}
else
{
- tcp->sockfd = freerdp_tcp_connect(hostname, port);
+ tcp->socketBio = BIO_new(BIO_s_connect());
+ if (!tcp->socketBio)
+ return FALSE;
- if (tcp->sockfd < 0)
+ if (BIO_set_conn_hostname(tcp->socketBio, hostname) < 0 || BIO_set_conn_int_port(tcp->socketBio, &port) < 0)
return FALSE;
- SetEventFileDescriptor(tcp->event, tcp->sockfd);
+ if (BIO_do_connect(tcp->socketBio) <= 0)
+ return FALSE;
+
+ tcp->sockfd = BIO_get_fd(tcp->socketBio, NULL);
+ }
+
+ SetEventFileDescriptor(tcp->event, tcp->sockfd);
- tcp_get_ip_address(tcp);
- tcp_get_mac_address(tcp);
+ tcp_get_ip_address(tcp);
+ tcp_get_mac_address(tcp);
- option_value = 1;
- option_len = sizeof(option_value);
- setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len);
+ option_value = 1;
+ option_len = sizeof(option_value);
+ if (setsockopt(tcp->sockfd, IPPROTO_TCP, TCP_NODELAY, (void*) &option_value, option_len) < 0)
+ fprintf(stderr, "%s: unable to set TCP_NODELAY\n", __FUNCTION__);
- /* receive buffer must be a least 32 K */
- if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
+ /* receive buffer must be a least 32 K */
+ if (getsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, &option_len) == 0)
+ {
+ if (option_value < (1024 * 32))
{
- if (option_value < (1024 * 32))
+ option_value = 1024 * 32;
+ option_len = sizeof(option_value);
+ if (setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len) < 0)
{
- option_value = 1024 * 32;
- option_len = sizeof(option_value);
- setsockopt(tcp->sockfd, SOL_SOCKET, SO_RCVBUF, (void*) &option_value, option_len);
+ fprintf(stderr, "%s: unable to set receive buffer len\n", __FUNCTION__);
+ return FALSE;
}
}
-
- tcp_set_keep_alive_mode(tcp);
}
- return TRUE;
-}
-
-int tcp_read(rdpTcp* tcp, BYTE* data, int length)
-{
- return freerdp_tcp_read(tcp->sockfd, data, length);
-}
+ if (!tcp_set_keep_alive_mode(tcp))
+ return FALSE;
-int tcp_write(rdpTcp* tcp, BYTE* data, int length)
-{
- return freerdp_tcp_write(tcp->sockfd, data, length);
-}
+ tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
+ if (!tcp->bufferedBio)
+ return FALSE;
+ tcp->bufferedBio->ptr = tcp;
-int tcp_wait_read(rdpTcp* tcp)
-{
- return freerdp_tcp_wait_read(tcp->sockfd);
+ tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
+ return TRUE;
}
-int tcp_wait_write(rdpTcp* tcp)
-{
- return freerdp_tcp_wait_write(tcp->sockfd);
-}
BOOL tcp_disconnect(rdpTcp* tcp)
{
if (flags == -1)
{
- fprintf(stderr, "tcp_set_blocking_mode: fcntl failed.\n");
+ fprintf(stderr, "%s: fcntl failed, %s.\n", __FUNCTION__, strerror(errno));
return FALSE;
}
{
tcp->sockfd = sockfd;
SetEventFileDescriptor(tcp->event, tcp->sockfd);
+
+ ringbuffer_commit_read_bytes(&tcp->xmitBuffer, ringbuffer_used(&tcp->xmitBuffer));
+
+ if (tcp->socketBio)
+ {
+ if (BIO_set_fd(tcp->socketBio, sockfd, 1) < 0)
+ return -1;
+ }
+ else
+ {
+ tcp->socketBio = BIO_new_socket(sockfd, 1);
+ if (!tcp->socketBio)
+ return -1;
+ }
+
+ if (!tcp->bufferedBio)
+ {
+ tcp->bufferedBio = BIO_new(BIO_s_buffered_socket());
+ if (!tcp->bufferedBio)
+ return FALSE;
+ tcp->bufferedBio->ptr = tcp;
+
+ tcp->bufferedBio = BIO_push(tcp->bufferedBio, tcp->socketBio);
+ }
+
return 0;
}
{
rdpTcp* tcp;
- tcp = (rdpTcp*) malloc(sizeof(rdpTcp));
+ tcp = (rdpTcp *)calloc(1, sizeof(rdpTcp));
+ if (!tcp)
+ return NULL;
- if (tcp)
- {
- ZeroMemory(tcp, sizeof(rdpTcp));
+ if (!ringbuffer_init(&tcp->xmitBuffer, 0x10000))
+ goto out_free;
- tcp->sockfd = -1;
- tcp->settings = settings;
- tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
- }
+ tcp->sockfd = -1;
+ tcp->settings = settings;
+
+ tcp->event = CreateFileDescriptorEvent(NULL, FALSE, FALSE, tcp->sockfd);
+ if (!tcp->event || tcp->event == INVALID_HANDLE_VALUE)
+ goto out_ringbuffer;
return tcp;
+out_ringbuffer:
+ ringbuffer_destroy(&tcp->xmitBuffer);
+out_free:
+ free(tcp);
+ return NULL;
}
void tcp_free(rdpTcp* tcp)
{
- if (tcp)
- {
- CloseHandle(tcp->event);
- free(tcp);
- }
+ if (!tcp)
+ return;
+
+ ringbuffer_destroy(&tcp->xmitBuffer);
+ CloseHandle(tcp->event);
+ free(tcp);
}
#include <winpr/stream.h>
#include <winpr/winsock.h>
+#include <freerdp/utils/ringbuffer.h>
+#include <openssl/bio.h>
+
#ifndef MSG_NOSIGNAL
#define MSG_NOSIGNAL 0
#endif
+#define BIO_TYPE_BUFFERED 66
+
typedef struct rdp_tcp rdpTcp;
struct rdp_tcp
#ifdef _WIN32
WSAEVENT wsa_event;
#endif
+ BIO *socketBio;
+ BIO *bufferedBio;
+ RingBuffer xmitBuffer;
+ BOOL writeBlocked;
+ BOOL readBlocked;
+
HANDLE event;
};
#include <freerdp/error.h>
#include <freerdp/utils/tcp.h>
+#include <freerdp/utils/ringbuffer.h>
+#include <openssl/bio.h>
#include <time.h>
#include <errno.h>
#include <fcntl.h>
#ifndef _WIN32
#include <netdb.h>
#include <sys/socket.h>
+#include <sys/select.h>
+#include <sys/time.h>
+#endif
+
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+#include <valgrind/memcheck.h>
#endif
#include "tpkt.h"
#include "transport.h"
#include "rdp.h"
+
#define BUFFER_SIZE 16384
static void* transport_client_thread(void* arg);
tcp_attach(transport->TcpIn, sockfd);
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
+ transport->frontBio = transport->TcpIn->bufferedBio;
}
void transport_stop(rdpTransport* transport)
transport_stop(transport);
- if (transport->layer == TRANSPORT_LAYER_TLS)
- status &= tls_disconnect(transport->TlsIn);
-
- if ((transport->layer == TRANSPORT_LAYER_TSG) || (transport->layer == TRANSPORT_LAYER_TSG_TLS))
- {
- status &= tsg_disconnect(transport->tsg);
- }
- else
- {
- status &= tcp_disconnect(transport->TcpIn);
- }
+ BIO_free_all(transport->frontBio);
+ transport->frontBio = 0;
return status;
}
rdpTsg* tsg;
tsg = (rdpTsg*) bio->ptr;
- status = tsg_write(tsg, (BYTE*) buf, num);
BIO_clear_retry_flags(bio);
+ status = tsg_write(tsg, (BYTE*) buf, num);
+ if (status > 0)
+ return status;
if (status == 0)
- {
BIO_set_retry_write(bio);
- }
- return status < 0 ? 0 : num;
+ return -1;
}
static int transport_bio_tsg_read(BIO* bio, char* buf, int size)
return &transport_bio_tsg_methods;
}
+
+
BOOL transport_connect_tls(rdpTransport* transport)
{
+ rdpSettings *settings = transport->settings;
+ rdpTls *targetTls;
+ BIO *targetBio;
int tls_status;
freerdp* instance;
rdpContext* context;
if (transport->layer == TRANSPORT_LAYER_TSG)
{
transport->TsgTls = tls_new(transport->settings);
-
- transport->TsgTls->methods = BIO_s_tsg();
- transport->TsgTls->tsg = (void*) transport->tsg;
-
transport->layer = TRANSPORT_LAYER_TSG_TLS;
- transport->TsgTls->hostname = transport->settings->ServerHostname;
- transport->TsgTls->port = transport->settings->ServerPort;
-
- if (transport->TsgTls->port == 0)
- transport->TsgTls->port = 3389;
-
- tls_status = tls_connect(transport->TsgTls);
-
- if (tls_status < 1)
- {
- if (tls_status < 0)
- {
- if (!connectErrorCode)
- connectErrorCode = TLSCONNECTERROR;
-
- if (!freerdp_get_last_error(context))
- freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
- }
- else
- {
- if (!freerdp_get_last_error(context))
- freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
- }
+ targetTls = transport->TsgTls;
+ targetBio = transport->frontBio;
+ }
+ else
+ {
+ if (!transport->TlsIn)
+ transport->TlsIn = tls_new(settings);
- tls_free(transport->TsgTls);
- transport->TsgTls = NULL;
+ if (!transport->TlsOut)
+ transport->TlsOut = transport->TlsIn;
- return FALSE;
- }
+ targetTls = transport->TlsIn;
+ targetBio = transport->TcpIn->bufferedBio;
- return TRUE;
+ transport->layer = TRANSPORT_LAYER_TLS;
}
- if (!transport->TlsIn)
- transport->TlsIn = tls_new(transport->settings);
-
- if (!transport->TlsOut)
- transport->TlsOut = transport->TlsIn;
- transport->layer = TRANSPORT_LAYER_TLS;
- transport->TlsIn->sockfd = transport->TcpIn->sockfd;
+ targetTls->hostname = settings->ServerHostname;
+ targetTls->port = settings->ServerPort;
- transport->TlsIn->hostname = transport->settings->ServerHostname;
- transport->TlsIn->port = transport->settings->ServerPort;
+ if (targetTls->port == 0)
+ targetTls->port = 3389;
- if (transport->TlsIn->port == 0)
- transport->TlsIn->port = 3389;
-
- tls_status = tls_connect(transport->TlsIn);
+ tls_status = tls_connect(targetTls, targetBio);
if (tls_status < 1)
{
freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
}
- tls_free(transport->TlsIn);
-
- if (transport->TlsIn == transport->TlsOut)
- transport->TlsIn = transport->TlsOut = NULL;
- else
- transport->TlsIn = NULL;
+ return FALSE;
+ }
+ transport->frontBio = targetTls->bio;
+ if (!transport->frontBio)
+ {
+ fprintf(stderr, "%s: unable to prepend a filtering TLS bio");
return FALSE;
}
{
freerdp* instance;
rdpSettings* settings;
+ rdpCredssp *credSsp;
settings = transport->settings;
instance = (freerdp*) settings->instance;
if (!transport->credssp)
{
transport->credssp = credssp_new(instance, transport, settings);
+ if (!transport->credssp)
+ return FALSE;
+
transport_set_nla_mode(transport, TRUE);
if (settings->AuthenticationServiceClass)
{
transport->credssp->ServicePrincipalName =
credssp_make_spn(settings->AuthenticationServiceClass, settings->ServerHostname);
+ if (!transport->credssp->ServicePrincipalName)
+ return FALSE;
}
}
- if (credssp_authenticate(transport->credssp) < 0)
+ credSsp = transport->credssp;
+ if (credssp_authenticate(credSsp) < 0)
{
if (!connectErrorCode)
connectErrorCode = AUTHENTICATIONERROR;
"If credentials are valid, the NTLMSSP implementation may be to blame.\n");
transport_set_nla_mode(transport, FALSE);
- credssp_free(transport->credssp);
+ credssp_free(credSsp);
transport->credssp = NULL;
return FALSE;
}
transport_set_nla_mode(transport, FALSE);
- credssp_free(transport->credssp);
+ credssp_free(credSsp);
transport->credssp = NULL;
return TRUE;
int tls_status;
freerdp* instance;
rdpContext* context;
+ rdpSettings *settings = transport->settings;
instance = (freerdp*) transport->settings->instance;
context = instance->context;
tsg = tsg_new(transport);
+ if (!tsg)
+ return FALSE;
tsg->transport = transport;
transport->tsg = tsg;
transport->SplitInputOutput = TRUE;
if (!transport->TlsIn)
- transport->TlsIn = tls_new(transport->settings);
-
- transport->TlsIn->sockfd = transport->TcpIn->sockfd;
- transport->TlsIn->hostname = transport->settings->GatewayHostname;
- transport->TlsIn->port = transport->settings->GatewayPort;
-
- if (transport->TlsIn->port == 0)
- transport->TlsIn->port = 443;
-
+ {
+ transport->TlsIn = tls_new(settings);
+ if (!transport->TlsIn)
+ return FALSE;
+ }
if (!transport->TlsOut)
- transport->TlsOut = tls_new(transport->settings);
+ {
+ transport->TlsOut = tls_new(settings);
+ if (!transport->TlsOut)
+ return FALSE;
+ }
- transport->TlsOut->sockfd = transport->TcpOut->sockfd;
- transport->TlsOut->hostname = transport->settings->GatewayHostname;
- transport->TlsOut->port = transport->settings->GatewayPort;
+ /* put a decent default value for gateway port */
+ if (!settings->GatewayPort)
+ settings->GatewayPort = 443;
- if (transport->TlsOut->port == 0)
- transport->TlsOut->port = 443;
+ transport->TlsIn->hostname = transport->TlsOut->hostname = settings->GatewayHostname;
+ transport->TlsIn->port = transport->TlsOut->port = settings->GatewayPort;
- tls_status = tls_connect(transport->TlsIn);
+ tls_status = tls_connect(transport->TlsIn, transport->TcpIn->bufferedBio);
if (tls_status < 1)
{
if (tls_status < 0)
return FALSE;
}
- tls_status = tls_connect(transport->TlsOut);
-
+ tls_status = tls_connect(transport->TlsOut, transport->TcpOut->bufferedBio);
if (tls_status < 1)
{
if (tls_status < 0)
if (!tsg_connect(tsg, hostname, port))
return FALSE;
+ transport->frontBio = BIO_new(BIO_s_tsg());
+ transport->frontBio->ptr = tsg;
return TRUE;
}
if (transport->GatewayEnabled)
{
transport->layer = TRANSPORT_LAYER_TSG;
+ transport->SplitInputOutput = TRUE;
transport->TcpOut = tcp_new(settings);
- status = tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort);
+ if (!tcp_connect(transport->TcpIn, settings->GatewayHostname, settings->GatewayPort) ||
+ !tcp_set_blocking_mode(transport->TcpIn, FALSE))
+ return FALSE;
- if (status)
- status = tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort);
+ if (!tcp_connect(transport->TcpOut, settings->GatewayHostname, settings->GatewayPort) ||
+ !tcp_set_blocking_mode(transport->TcpOut, FALSE))
+ return FALSE;
- if (status)
- status = transport_tsg_connect(transport, hostname, port);
+ if (!transport_tsg_connect(transport, hostname, port))
+ return FALSE;
+ status = TRUE;
}
else
{
transport->SplitInputOutput = FALSE;
transport->TcpOut = transport->TcpIn;
+ transport->frontBio = transport->TcpIn->bufferedBio;
}
if (status)
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
- transport->TlsIn->sockfd = transport->TcpIn->sockfd;
- if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
+ if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
return FALSE;
+ transport->frontBio = transport->TlsIn->bio;
return TRUE;
}
transport->TlsOut = transport->TlsIn;
transport->layer = TRANSPORT_LAYER_TLS;
- transport->TlsIn->sockfd = transport->TcpIn->sockfd;
- if (!tls_accept(transport->TlsIn, transport->settings->CertificateFile, transport->settings->PrivateKeyFile))
+ if (!tls_accept(transport->TlsIn, transport->TcpIn->bufferedBio, settings->CertificateFile, settings->PrivateKeyFile))
return FALSE;
+ transport->frontBio = transport->TlsIn->bio;
/* Network Level Authentication */
return length;
}
+static int transport_wait_for_read(rdpTransport* transport)
+{
+ struct timeval tv;
+ fd_set rset, wset;
+ fd_set *rsetPtr = NULL, *wsetPtr = NULL;
+ rdpTcp *tcpIn;
+
+ tcpIn = transport->TcpIn;
+ if (tcpIn->readBlocked)
+ {
+ rsetPtr = &rset;
+ FD_ZERO(rsetPtr);
+ FD_SET(tcpIn->sockfd, rsetPtr);
+ }
+ else if (tcpIn->writeBlocked)
+ {
+ wsetPtr = &wset;
+ FD_ZERO(wsetPtr);
+ FD_SET(tcpIn->sockfd, wsetPtr);
+ }
+
+ if (!wsetPtr && !rsetPtr)
+ {
+ USleep(1000);
+ return 0;
+ }
+
+ tv.tv_sec = 0;
+ tv.tv_usec = 1000;
+
+ return select(tcpIn->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
+}
+
+
+static int transport_wait_for_write(rdpTransport* transport)
+{
+ struct timeval tv;
+ fd_set rset, wset;
+ fd_set *rsetPtr = NULL, *wsetPtr = NULL;
+ rdpTcp *tcpOut;
+
+ tcpOut = transport->SplitInputOutput ? transport->TcpOut : transport->TcpIn;
+ if (tcpOut->writeBlocked)
+ {
+ wsetPtr = &wset;
+ FD_ZERO(wsetPtr);
+ FD_SET(tcpOut->sockfd, wsetPtr);
+ }
+ else if (tcpOut->readBlocked)
+ {
+ rsetPtr = &rset;
+ FD_ZERO(rsetPtr);
+ FD_SET(tcpOut->sockfd, rsetPtr);
+ }
+
+ if (!wsetPtr && !rsetPtr)
+ {
+ USleep(1000);
+ return 0;
+ }
+
+ tv.tv_sec = 0;
+ tv.tv_usec = 1000;
+
+ return select(tcpOut->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
+}
+
+
int transport_read_layer(rdpTransport* transport, BYTE* data, int bytes)
{
int read = 0;
int status = -1;
+
while (read < bytes)
{
- if (transport->layer == TRANSPORT_LAYER_TLS)
- status = tls_read(transport->TlsIn, data + read, bytes - read);
- else if (transport->layer == TRANSPORT_LAYER_TCP)
- status = tcp_read(transport->TcpIn, data + read, bytes - read);
- else if (transport->layer == TRANSPORT_LAYER_TSG)
- status = tsg_read(transport->tsg, data + read, bytes - read);
- else if (transport->layer == TRANSPORT_LAYER_TSG_TLS) {
- status = tls_read(transport->TsgTls, data + read, bytes - read);
- }
-
- /* blocking means that we can't continue until this is read */
-
- if (!transport->blocking)
- return status;
+ status = BIO_read(transport->frontBio, data + read, bytes - read);
- if (status < 0)
+ if (!status)
{
- /* A read error indicates that the peer has dropped the connection */
transport->layer = TRANSPORT_LAYER_CLOSED;
- return status;
+ return -1;
}
- read += status;
-
- if (status == 0)
+ if (status < 0)
{
- /*
- * instead of sleeping, we should wait timeout on the
- * socket but this only happens on initial connection
- */
- USleep(transport->SleepInterval);
+ if (!BIO_should_retry(transport->frontBio))
+ {
+ /* something unexpected happened, let's close */
+ transport->layer = TRANSPORT_LAYER_CLOSED;
+ return -1;
+ }
+
+ /* non blocking will survive a partial read */
+ if (!transport->blocking)
+ return read;
+
+ /* blocking means that we can't continue until we have read the number of
+ * requested bytes */
+ if (transport_wait_for_read(transport) < 0)
+ {
+ fprintf(stderr, "%s: error when selecting for read\n", __FUNCTION__);
+ return -1;
+ }
+ continue;
}
+
+#ifdef HAVE_VALGRIND_MEMCHECK_H
+ VALGRIND_MAKE_MEM_DEFINED(data + read, bytes - read);
+#endif
+
+ read += status;
}
return read;
}
+
+
int transport_read(rdpTransport* transport, wStream* s)
{
int status;
int position;
int pduLength;
- BYTE header[4];
+ BYTE *header;
int transport_status;
position = 0;
position += status;
}
- CopyMemory(header, Stream_Buffer(s), 4); /* peek at first 4 bytes */
+ header = Stream_Buffer(s);
/* if header is present, read exactly one PDU */
return status;
}
+BOOL transport_bio_buffered_drain(BIO *bio);
+
int transport_write(rdpTransport* transport, wStream* s)
{
int length;
while (length > 0)
{
- if (transport->layer == TRANSPORT_LAYER_TLS)
- status = tls_write(transport->TlsOut, Stream_Pointer(s), length);
- else if (transport->layer == TRANSPORT_LAYER_TCP)
- status = tcp_write(transport->TcpOut, Stream_Pointer(s), length);
- else if (transport->layer == TRANSPORT_LAYER_TSG)
- status = tsg_write(transport->tsg, Stream_Pointer(s), length);
- else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
- status = tls_write(transport->TsgTls, Stream_Pointer(s), length);
+ status = BIO_write(transport->frontBio, Stream_Pointer(s), length);
- if (status < 0)
- break; /* error occurred */
-
- if (status == 0)
+ if (status <= 0)
{
- /* when sending is blocked in nonblocking mode, the receiving buffer should be checked */
+ /* the buffered BIO that is at the end of the chain always says OK for writing,
+ * so a retry means that for any reason we need to read. The most probable
+ * is a SSL or TSG BIO in the chain.
+ */
+ if (!BIO_should_retry(transport->frontBio))
+ return status;
+
+ /* non-blocking can live with blocked IOs */
if (!transport->blocking)
+ return status;
+
+ if (transport_wait_for_write(transport) < 0)
{
- /* and in case we do have buffered some data, we set the event so next loop will get it */
- if (transport_read_nonblocking(transport) > 0)
- SetEvent(transport->ReceiveEvent);
+ fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
+ return -1;
}
+ continue;
+ }
- if (transport->layer == TRANSPORT_LAYER_TLS)
- tls_wait_write(transport->TlsOut);
- else if (transport->layer == TRANSPORT_LAYER_TCP)
- tcp_wait_write(transport->TcpOut);
- else if (transport->layer == TRANSPORT_LAYER_TSG_TLS)
- tls_wait_write(transport->TsgTls);
- else
- USleep(transport->SleepInterval);
+ if (transport->blocking || transport->settings->WaitForOutputBufferFlush)
+ {
+ /* blocking transport, we must ensure the write buffer is really empty */
+ rdpTcp *out = transport->TcpOut;
+
+ while (out->writeBlocked)
+ {
+ if (transport_wait_for_write(transport) < 0)
+ {
+ fprintf(stderr, "%s: error when selecting for write\n", __FUNCTION__);
+ return -1;
+ }
+
+ if (!transport_bio_buffered_drain(out->bufferedBio))
+ {
+ fprintf(stderr, "%s: error when draining outputBuffer\n", __FUNCTION__);
+ return -1;
+ }
+ }
}
length -= status;
}
}
+BOOL tranport_is_write_blocked(rdpTransport* transport)
+{
+ if (transport->TcpIn->writeBlocked)
+ return TRUE;
+
+ return transport->SplitInputOutput &&
+ transport->TcpOut &&
+ transport->TcpOut->writeBlocked;
+}
+
+int tranport_drain_output_buffer(rdpTransport* transport)
+{
+ BOOL ret = FALSE;
+
+ /* First try to send some accumulated bytes in the send buffer */
+ if (transport->TcpIn->writeBlocked)
+ {
+ if (!transport_bio_buffered_drain(transport->TcpIn->bufferedBio))
+ return -1;
+ ret |= transport->TcpIn->writeBlocked;
+ }
+
+ if (transport->SplitInputOutput && transport->TcpOut && transport->TcpOut->writeBlocked)
+ {
+ if (!transport_bio_buffered_drain(transport->TcpOut->bufferedBio))
+ return -1;
+ ret |= transport->TcpOut->writeBlocked;
+ }
+
+ return ret;
+}
+
int transport_check_fds(rdpTransport* transport)
{
int pos;
recv_status = transport->ReceiveCallback(transport, received, transport->ReceiveExtra);
- Stream_Release(received);
-
- if (recv_status < 0)
- return -1;
-
if (recv_status == 1)
{
return 1; /* session redirection */
}
+ Stream_Release(received);
+
+ if (recv_status < 0)
+ return -1;
}
return 0;
{
rdpTransport* transport;
- transport = (rdpTransport*) malloc(sizeof(rdpTransport));
-
- if (transport)
- {
- ZeroMemory(transport, sizeof(rdpTransport));
+ transport = (rdpTransport *)calloc(1, sizeof(rdpTransport));
+ if (!transport)
+ return NULL;
- WLog_Init();
- transport->log = WLog_Get("com.freerdp.core.transport");
+ WLog_Init();
+ transport->log = WLog_Get("com.freerdp.core.transport");
+ if (!transport->log)
+ goto out_free;
- transport->TcpIn = tcp_new(settings);
+ transport->TcpIn = tcp_new(settings);
+ if (!transport->TcpIn)
+ goto out_free;
- transport->settings = settings;
+ transport->settings = settings;
- /* a small 0.1ms delay when transport is blocking. */
- transport->SleepInterval = 100;
+ /* a small 0.1ms delay when transport is blocking. */
+ transport->SleepInterval = 100;
- transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
+ transport->ReceivePool = StreamPool_New(TRUE, BUFFER_SIZE);
+ if (!transport->ReceivePool)
+ goto out_free_tcpin;
- /* receive buffer for non-blocking read. */
- transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
- transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ /* receive buffer for non-blocking read. */
+ transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
+ if (!transport->ReceiveBuffer)
+ goto out_free_receivepool;
- transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ transport->ReceiveEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ if (!transport->ReceiveEvent || transport->ReceiveEvent == INVALID_HANDLE_VALUE)
+ goto out_free_receivebuffer;
- transport->blocking = TRUE;
- transport->GatewayEnabled = FALSE;
+ transport->connectedEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
+ if (!transport->connectedEvent || transport->connectedEvent == INVALID_HANDLE_VALUE)
+ goto out_free_receiveEvent;
- InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000);
- InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000);
+ transport->blocking = TRUE;
+ transport->GatewayEnabled = FALSE;
+ transport->layer = TRANSPORT_LAYER_TCP;
- transport->layer = TRANSPORT_LAYER_TCP;
- }
+ if (!InitializeCriticalSectionAndSpinCount(&(transport->ReadLock), 4000))
+ goto out_free_connectedEvent;
+ if (!InitializeCriticalSectionAndSpinCount(&(transport->WriteLock), 4000))
+ goto out_free_readlock;
return transport;
+
+out_free_readlock:
+ DeleteCriticalSection(&(transport->ReadLock));
+out_free_connectedEvent:
+ CloseHandle(transport->connectedEvent);
+out_free_receiveEvent:
+ CloseHandle(transport->ReceiveEvent);
+out_free_receivebuffer:
+ StreamPool_Return(transport->ReceivePool, transport->ReceiveBuffer);
+out_free_receivepool:
+ StreamPool_Free(transport->ReceivePool);
+out_free_tcpin:
+ tcp_free(transport->TcpIn);
+out_free:
+ free(transport);
+ return NULL;
}
void transport_free(rdpTransport* transport)
{
- if (transport)
- {
- transport_stop(transport);
+ if (!transport)
+ return;
- if (transport->ReceiveBuffer)
- Stream_Release(transport->ReceiveBuffer);
+ transport_stop(transport);
- StreamPool_Free(transport->ReceivePool);
+ if (transport->ReceiveBuffer)
+ Stream_Release(transport->ReceiveBuffer);
- CloseHandle(transport->ReceiveEvent);
- CloseHandle(transport->connectedEvent);
+ StreamPool_Free(transport->ReceivePool);
- if (transport->TlsIn)
- tls_free(transport->TlsIn);
+ CloseHandle(transport->ReceiveEvent);
+ CloseHandle(transport->connectedEvent);
- if (transport->TlsOut != transport->TlsIn)
- tls_free(transport->TlsOut);
+ if (transport->TlsIn)
+ tls_free(transport->TlsIn);
- transport->TlsIn = NULL;
- transport->TlsOut = NULL;
+ if (transport->TlsOut != transport->TlsIn)
+ tls_free(transport->TlsOut);
- if (transport->TcpIn)
- tcp_free(transport->TcpIn);
+ transport->TlsIn = NULL;
+ transport->TlsOut = NULL;
- if (transport->TcpOut != transport->TcpIn)
- tcp_free(transport->TcpOut);
+ if (transport->TcpIn)
+ tcp_free(transport->TcpIn);
- transport->TcpIn = NULL;
- transport->TcpOut = NULL;
+ if (transport->TcpOut != transport->TcpIn)
+ tcp_free(transport->TcpOut);
- tsg_free(transport->tsg);
- transport->tsg = NULL;
+ transport->TcpIn = NULL;
+ transport->TcpOut = NULL;
- DeleteCriticalSection(&(transport->ReadLock));
- DeleteCriticalSection(&(transport->WriteLock));
+ tsg_free(transport->tsg);
+ transport->tsg = NULL;
- free(transport);
- }
+ DeleteCriticalSection(&(transport->ReadLock));
+ DeleteCriticalSection(&(transport->WriteLock));
+
+ free(transport);
}
#include <freerdp/types.h>
#include <freerdp/settings.h>
+
typedef int (*TransportRecv) (rdpTransport* transport, wStream* stream, void* extra);
struct rdp_transport
{
TRANSPORT_LAYER layer;
+ BIO *frontBio;
rdpTsg* tsg;
rdpTcp* TcpIn;
rdpTcp* TcpOut;
void transport_set_gateway_enabled(rdpTransport* transport, BOOL GatewayEnabled);
void transport_set_nla_mode(rdpTransport* transport, BOOL NlaMode);
void transport_get_read_handles(rdpTransport* transport, HANDLE* events, DWORD* count);
+BOOL tranport_is_write_blocked(rdpTransport* transport);
+BOOL tranport_drain_output_buffer(rdpTransport* transport);
wStream* transport_receive_pool_take(rdpTransport* transport);
int transport_receive_pool_return(rdpTransport* transport, wStream* pdu);
#include <winpr/stream.h>
#include <freerdp/utils/tcp.h>
+#include <freerdp/utils/ringbuffer.h>
#include <freerdp/crypto/tls.h>
-
-#ifdef HAVE_VALGRIND_MEMCHECK_H
-#include <valgrind/memcheck.h>
-#endif
+#include "../core/tcp.h"
static CryptoCert tls_get_certificate(rdpTls* tls, BOOL peer)
{
CryptoCert cert;
- X509* server_cert;
+ X509* remote_cert;
if (peer)
- server_cert = SSL_get_peer_certificate(tls->ssl);
+ remote_cert = SSL_get_peer_certificate(tls->ssl);
else
- server_cert = SSL_get_certificate(tls->ssl);
+ remote_cert = SSL_get_certificate(tls->ssl);
- if (!server_cert)
+ if (!remote_cert)
{
- fprintf(stderr, "tls_get_certificate: failed to get the server TLS certificate\n");
- cert = NULL;
+ fprintf(stderr, "%s: failed to get the server TLS certificate\n", __FUNCTION__);
+ return NULL;
}
- else
+
+ cert = malloc(sizeof(*cert));
+ if (!cert)
{
- cert = malloc(sizeof(*cert));
- cert->px509 = server_cert;
+ X509_free(remote_cert);
+ return NULL;
}
+ cert->px509 = remote_cert;
return cert;
}
PrefixLength = strlen(TLS_SERVER_END_POINT);
ChannelBindingTokenLength = PrefixLength + CertificateHashLength;
- ContextBindings = (SecPkgContext_Bindings*) malloc(sizeof(SecPkgContext_Bindings));
- ZeroMemory(ContextBindings, sizeof(SecPkgContext_Bindings));
+ ContextBindings = (SecPkgContext_Bindings*) calloc(1, sizeof(SecPkgContext_Bindings));
+ if (!ContextBindings)
+ return NULL;
ContextBindings->BindingsLength = sizeof(SEC_CHANNEL_BINDINGS) + ChannelBindingTokenLength;
- ChannelBindings = (SEC_CHANNEL_BINDINGS*) malloc(ContextBindings->BindingsLength);
- ZeroMemory(ChannelBindings, ContextBindings->BindingsLength);
+ ChannelBindings = (SEC_CHANNEL_BINDINGS*) calloc(1, ContextBindings->BindingsLength);
+ if (!ChannelBindings)
+ goto out_free;
ContextBindings->Bindings = ChannelBindings;
ChannelBindings->cbApplicationDataLength = ChannelBindingTokenLength;
CopyMemory(&ChannelBindingToken[PrefixLength], CertificateHash, CertificateHashLength);
return ContextBindings;
-}
-
-static void tls_ssl_info_callback(const SSL* ssl, int type, int val)
-{
- if (type & SSL_CB_HANDSHAKE_START)
- {
- }
+out_free:
+ free(ContextBindings);
+ return NULL;
}
-int tls_connect(rdpTls* tls)
-{
- CryptoCert cert;
- long options = 0;
- int verify_status;
- int connection_status;
-
- tls->ctx = SSL_CTX_new(TLSv1_client_method());
+BOOL tls_prepare(rdpTls* tls, BIO *underlying, const SSL_METHOD *method, int options, BOOL clientMode)
+{
+ tls->ctx = SSL_CTX_new(method);
if (!tls->ctx)
{
- fprintf(stderr, "SSL_CTX_new failed\n");
- return -1;
+ fprintf(stderr, "%s: SSL_CTX_new failed\n", __FUNCTION__);
+ return FALSE;
}
- //SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
-
- /**
- * SSL_OP_NO_COMPRESSION:
- *
- * The Microsoft RDP server does not advertise support
- * for TLS compression, but alternative servers may support it.
- * This was observed between early versions of the FreeRDP server
- * and the FreeRDP client, and caused major performance issues,
- * which is why we're disabling it.
- */
-#ifdef SSL_OP_NO_COMPRESSION
- options |= SSL_OP_NO_COMPRESSION;
-#endif
-
- /**
- * SSL_OP_TLS_BLOCK_PADDING_BUG:
- *
- * The Microsoft RDP server does *not* support TLS padding.
- * It absolutely needs to be disabled otherwise it won't work.
- */
- options |= SSL_OP_TLS_BLOCK_PADDING_BUG;
-
- /**
- * SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS:
- *
- * Just like TLS padding, the Microsoft RDP server does not
- * support empty fragments. This needs to be disabled.
- */
- options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
+ SSL_CTX_set_mode(tls->ctx, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER | SSL_MODE_ENABLE_PARTIAL_WRITE);
SSL_CTX_set_options(tls->ctx, options);
+ SSL_CTX_set_read_ahead(tls->ctx, 1);
- tls->ssl = SSL_new(tls->ctx);
-
- if (!tls->ssl)
+ tls->bio = BIO_new_ssl(tls->ctx, clientMode);
+ if (BIO_get_ssl(tls->bio, &tls->ssl) < 0)
{
- fprintf(stderr, "SSL_new failed\n");
- return -1;
+ fprintf(stderr, "%s: unable to retrieve the SSL of the connection\n", __FUNCTION__);
+ return FALSE;
}
- if (tls->tsg)
+ BIO_push(tls->bio, underlying);
+ return TRUE;
+}
+
+int tls_do_handshake(rdpTls* tls, BOOL clientMode)
+{
+ CryptoCert cert;
+ int verify_status, status;
+
+ do
{
- tls->bio = BIO_new(tls->methods);
+ struct timeval tv;
+ fd_set rset;
+ int fd;
- if (!tls->bio)
- {
- fprintf(stderr, "BIO_new failed\n");
+ status = BIO_do_handshake(tls->bio);
+ if (status == 1)
+ break;
+ if (!BIO_should_retry(tls->bio))
return -1;
- }
-
- tls->bio->ptr = tls->tsg;
- SSL_set_bio(tls->ssl, tls->bio, tls->bio);
+ /* we select() only for read even if we should test both read and write
+ * depending of what have blocked */
+ FD_ZERO(&rset);
- SSL_CTX_set_info_callback(tls->ctx, tls_ssl_info_callback);
- }
- else
- {
- if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
+ fd = BIO_get_fd(tls->bio, NULL);
+ if (fd < 0)
{
- fprintf(stderr, "SSL_set_fd failed\n");
+ fprintf(stderr, "%s: unable to retrieve BIO fd\n", __FUNCTION__);
return -1;
}
- }
- connection_status = SSL_connect(tls->ssl);
+ FD_SET(fd, &rset);
+ tv.tv_sec = 0;
+ tv.tv_usec = 10 * 1000; /* 10ms */
- if (connection_status <= 0)
- {
- if (tls_print_error("SSL_connect", tls->ssl, connection_status))
+ status = select(fd + 1, &rset, NULL, NULL, &tv);
+ if (status < 0)
{
+ fprintf(stderr, "%s: error during select()\n", __FUNCTION__);
return -1;
}
}
+ while (TRUE);
- cert = tls_get_certificate(tls, TRUE);
+ if (!clientMode)
+ return 1;
+ cert = tls_get_certificate(tls, clientMode);
if (!cert)
{
- fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
+ fprintf(stderr, "%s: tls_get_certificate failed to return the server certificate.\n", __FUNCTION__);
return -1;
}
tls->Bindings = tls_get_channel_bindings(cert->px509);
+ if (!tls->Bindings)
+ {
+ fprintf(stderr, "%s: unable to retrieve bindings\n", __FUNCTION__);
+ return -1;
+ }
if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
{
- fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
+ fprintf(stderr, "%s: crypto_cert_get_public_key failed to return the server public key.\n", __FUNCTION__);
tls_free_certificate(cert);
return -1;
}
if (verify_status < 1)
{
- fprintf(stderr, "tls_connect: certificate not trusted, aborting.\n");
+ fprintf(stderr, "%s: certificate not trusted, aborting.\n", __FUNCTION__);
tls_disconnect(tls);
+ tls_free_certificate(cert);
+ return 0;
}
tls_free_certificate(cert);
return verify_status;
}
-BOOL tls_accept(rdpTls* tls, const char* cert_file, const char* privatekey_file)
+int tls_connect(rdpTls* tls, BIO *underlying)
{
- CryptoCert cert;
- long options = 0;
- int connection_status;
+ int options = 0;
- tls->ctx = SSL_CTX_new(SSLv23_server_method());
+ /**
+ * SSL_OP_NO_COMPRESSION:
+ *
+ * The Microsoft RDP server does not advertise support
+ * for TLS compression, but alternative servers may support it.
+ * This was observed between early versions of the FreeRDP server
+ * and the FreeRDP client, and caused major performance issues,
+ * which is why we're disabling it.
+ */
+#ifdef SSL_OP_NO_COMPRESSION
+ options |= SSL_OP_NO_COMPRESSION;
+#endif
- if (tls->ctx == NULL)
- {
- fprintf(stderr, "SSL_CTX_new failed\n");
+ /**
+ * SSL_OP_TLS_BLOCK_PADDING_BUG:
+ *
+ * The Microsoft RDP server does *not* support TLS padding.
+ * It absolutely needs to be disabled otherwise it won't work.
+ */
+ options |= SSL_OP_TLS_BLOCK_PADDING_BUG;
+
+ /**
+ * SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS:
+ *
+ * Just like TLS padding, the Microsoft RDP server does not
+ * support empty fragments. This needs to be disabled.
+ */
+ options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
+
+ if (!tls_prepare(tls, underlying, TLSv1_client_method(), options, TRUE))
return FALSE;
- }
- /*
+ return tls_do_handshake(tls, TRUE);
+}
+
+
+
+BOOL tls_accept(rdpTls* tls, BIO *underlying, const char* cert_file, const char* privatekey_file)
+{
+ long options = 0;
+
+ /**
* SSL_OP_NO_SSLv2:
*
* We only want SSLv3 and TLSv1, so disable SSLv2.
*/
options |= SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS;
- SSL_CTX_set_options(tls->ctx, options);
-
- if (SSL_CTX_use_RSAPrivateKey_file(tls->ctx, privatekey_file, SSL_FILETYPE_PEM) <= 0)
- {
- fprintf(stderr, "SSL_CTX_use_RSAPrivateKey_file failed\n");
- fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
+ if (!tls_prepare(tls, underlying, SSLv23_server_method(), options, FALSE))
return FALSE;
- }
- tls->ssl = SSL_new(tls->ctx);
-
- if (!tls->ssl)
+ if (SSL_use_RSAPrivateKey_file(tls->ssl, privatekey_file, SSL_FILETYPE_PEM) <= 0)
{
- fprintf(stderr, "SSL_new failed\n");
+ fprintf(stderr, "%s: SSL_CTX_use_RSAPrivateKey_file failed\n", __FUNCTION__);
+ fprintf(stderr, "PrivateKeyFile: %s\n", privatekey_file);
return FALSE;
}
if (SSL_use_certificate_file(tls->ssl, cert_file, SSL_FILETYPE_PEM) <= 0)
{
- fprintf(stderr, "SSL_use_certificate_file failed\n");
+ fprintf(stderr, "%s: SSL_use_certificate_file failed\n", __FUNCTION__);
return FALSE;
}
- if (SSL_set_fd(tls->ssl, tls->sockfd) < 1)
- {
- fprintf(stderr, "SSL_set_fd failed\n");
+ return tls_do_handshake(tls, FALSE) > 0;
+}
+
+BOOL tls_disconnect(rdpTls* tls)
+{
+ if (!tls)
return FALSE;
- }
- while (1)
+ if (!tls->ssl)
+ return TRUE;
+
+ if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
{
- connection_status = SSL_accept(tls->ssl);
+ /**
+ * OpenSSL doesn't really expose an API for sending a TLS alert manually.
+ *
+ * The following code disables the sending of the default "close notify"
+ * and then proceeds to force sending a custom TLS alert before shutting down.
+ *
+ * Manually sending a TLS alert is necessary in certain cases,
+ * like when server-side NLA results in an authentication failure.
+ */
- if (connection_status <= 0)
- {
- switch (SSL_get_error(tls->ssl, connection_status))
- {
- case SSL_ERROR_WANT_READ:
- case SSL_ERROR_WANT_WRITE:
- break;
+ SSL_set_quiet_shutdown(tls->ssl, 1);
- default:
- if (tls_print_error("SSL_accept", tls->ssl, connection_status))
- return FALSE;
- break;
+ if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
+ SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
- }
- }
- else
- {
- break;
- }
- }
+ tls->ssl->s3->alert_dispatch = 1;
+ tls->ssl->s3->send_alert[0] = tls->alertLevel;
+ tls->ssl->s3->send_alert[1] = tls->alertDescription;
- cert = tls_get_certificate(tls, FALSE);
+ if (tls->ssl->s3->wbuf.left == 0)
+ tls->ssl->method->ssl_dispatch_alert(tls->ssl);
- if (!cert)
- {
- fprintf(stderr, "tls_connect: tls_get_certificate failed to return the server certificate.\n");
- return FALSE;
+ SSL_shutdown(tls->ssl);
}
-
- if (!crypto_cert_get_public_key(cert, &tls->PublicKey, &tls->PublicKeyLength))
+ else
{
- fprintf(stderr, "tls_connect: crypto_cert_get_public_key failed to return the server public key.\n");
- tls_free_certificate(cert);
- return FALSE;
+ SSL_shutdown(tls->ssl);
}
- free(cert);
-
- fprintf(stderr, "TLS connection accepted\n");
-
return TRUE;
}
-BOOL tls_disconnect(rdpTls* tls)
+
+BIO *findBufferedBio(BIO *front)
{
- if (!tls)
- return FALSE;
+ BIO *ret = front;
- if (tls->ssl)
+ while (ret)
{
- if (tls->alertDescription != TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY)
- {
- /**
- * OpenSSL doesn't really expose an API for sending a TLS alert manually.
- *
- * The following code disables the sending of the default "close notify"
- * and then proceeds to force sending a custom TLS alert before shutting down.
- *
- * Manually sending a TLS alert is necessary in certain cases,
- * like when server-side NLA results in an authentication failure.
- */
-
- SSL_set_quiet_shutdown(tls->ssl, 1);
-
- if ((tls->alertLevel == TLS_ALERT_LEVEL_FATAL) && (tls->ssl->session))
- SSL_CTX_remove_session(tls->ssl->ctx, tls->ssl->session);
-
- tls->ssl->s3->alert_dispatch = 1;
- tls->ssl->s3->send_alert[0] = tls->alertLevel;
- tls->ssl->s3->send_alert[1] = tls->alertDescription;
-
- if (tls->ssl->s3->wbuf.left == 0)
- tls->ssl->method->ssl_dispatch_alert(tls->ssl);
-
- SSL_shutdown(tls->ssl);
- }
- else
- {
- SSL_shutdown(tls->ssl);
- }
+ if (BIO_method_type(ret) == BIO_TYPE_BUFFERED)
+ return ret;
+ ret = ret->next_bio;
}
- return TRUE;
+ return ret;
}
-int tls_read(rdpTls* tls, BYTE* data, int length)
+int tls_write_all(rdpTls* tls, const BYTE* data, int length)
{
- int error;
- int status;
-
- if (!tls)
- return -1;
-
- if (!tls->ssl)
+ int status, nchunks, commitedBytes;
+ rdpTcp *tcp;
+ fd_set rset, wset;
+ fd_set *rsetPtr, *wsetPtr;
+ struct timeval tv;
+ BIO *bio = tls->bio;
+ DataChunk chunks[2];
+
+ BIO *bufferedBio = findBufferedBio(bio);
+ if (!bufferedBio)
+ {
+ fprintf(stderr, "%s: error unable to retrieve the bufferedBio in the BIO chain\n", __FUNCTION__);
return -1;
-
- status = SSL_read(tls->ssl, data, length);
-
- if (status == 0)
- {
- return -1; /* peer disconnected */
}
- if (status <= 0)
+ tcp = (rdpTcp *)bufferedBio->ptr;
+
+ do
{
- error = SSL_get_error(tls->ssl, status);
+ status = BIO_write(bio, data, length);
+ /*fprintf(stderr, "%s: BIO_write(len=%d) = %d (retry=%d)\n", __FUNCTION__, length, status, BIO_should_retry(bio));*/
+ if (status > 0)
+ break;
- //fprintf(stderr, "tls_read: length: %d status: %d error: 0x%08X\n",
- // length, status, error);
+ if (!BIO_should_retry(bio))
+ return -1;
- switch (error)
+ /* we try to handle SSL want_read and want_write nicely */
+ rsetPtr = wsetPtr = 0;
+ if (tcp->writeBlocked)
{
- case SSL_ERROR_NONE:
- break;
-
- case SSL_ERROR_WANT_READ:
- case SSL_ERROR_WANT_WRITE:
- status = 0;
- break;
-
- case SSL_ERROR_SYSCALL:
-#ifdef _WIN32
- if (WSAGetLastError() == WSAEWOULDBLOCK)
-#else
- if ((errno == EAGAIN) || (errno == 0))
-#endif
- {
- status = 0;
- }
- else
- {
- if (tls_print_error("SSL_read", tls->ssl, status))
- {
- status = -1;
- }
- else
- {
- status = 0;
- }
- }
- break;
-
- default:
- if (tls_print_error("SSL_read", tls->ssl, status))
- {
- status = -1;
- }
- else
- {
- status = 0;
- }
- break;
+ wsetPtr = &wset;
+ FD_ZERO(&wset);
+ FD_SET(tcp->sockfd, &wset);
+ }
+ else if (tcp->readBlocked)
+ {
+ rsetPtr = &rset;
+ FD_ZERO(&rset);
+ FD_SET(tcp->sockfd, &rset);
+ }
+ else
+ {
+ fprintf(stderr, "%s: weird we're blocked but the underlying is not read or write blocked !\n", __FUNCTION__);
+ USleep(10);
+ continue;
}
- }
-
-#ifdef HAVE_VALGRIND_MEMCHECK_H
- VALGRIND_MAKE_MEM_DEFINED(data, status);
-#endif
-
- return status;
-}
-
-int tls_write(rdpTls* tls, BYTE* data, int length)
-{
- int error;
- int status;
-
- if (!tls)
- return -1;
-
- if (!tls->ssl)
- return -1;
- status = SSL_write(tls->ssl, data, length);
+ tv.tv_sec = 0;
+ tv.tv_usec = 100 * 1000;
- if (status == 0)
- {
- return -1; /* peer disconnected */
+ status = select(tcp->sockfd + 1, rsetPtr, wsetPtr, NULL, &tv);
+ if (status < 0)
+ return -1;
}
+ while (TRUE);
- if (status < 0)
+ /* make sure the output buffer is empty */
+ commitedBytes = 0;
+ while ((nchunks = ringbuffer_peek(&tcp->xmitBuffer, chunks, ringbuffer_used(&tcp->xmitBuffer))))
{
- error = SSL_get_error(tls->ssl, status);
-
- //fprintf(stderr, "tls_write: length: %d status: %d error: 0x%08X\n", length, status, error);
+ int i;
- switch (error)
+ for (i = 0; i < nchunks; i++)
{
- case SSL_ERROR_NONE:
- break;
-
- case SSL_ERROR_WANT_READ:
- case SSL_ERROR_WANT_WRITE:
- status = 0;
- break;
-
- case SSL_ERROR_SYSCALL:
- if (errno == EAGAIN)
- {
- status = 0;
- }
- else
+ while (chunks[i].size)
+ {
+ status = BIO_write(tcp->socketBio, chunks[i].data, chunks[i].size);
+ if (status > 0)
{
- tls_print_error("SSL_write", tls->ssl, status);
- status = -1;
+ chunks[i].size -= status;
+ chunks[i].data += status;
+ commitedBytes += status;
+ continue;
}
- break;
-
- default:
- tls_print_error("SSL_write", tls->ssl, status);
- status = -1;
- break;
- }
- }
-
- return status;
-}
-
-int tls_write_all(rdpTls* tls, BYTE* data, int length)
-{
- int status;
- int sent = 0;
- do
- {
- status = tls_write(tls, &data[sent], length - sent);
+ if (!BIO_should_retry(tcp->socketBio))
+ goto out_fail;
+ FD_ZERO(&rset);
+ FD_SET(tcp->sockfd, &rset);
+ tv.tv_sec = 0;
+ tv.tv_usec = 100 * 1000;
- if (status > 0)
- sent += status;
- else if (status == 0)
- tls_wait_write(tls);
+ status = select(tcp->sockfd + 1, &rset, NULL, NULL, &tv);
+ if (status < 0)
+ goto out_fail;
+ }
- if (sent >= length)
- break;
+ }
}
- while (status >= 0);
-
- if (status > 0)
- return length;
- else
- return status;
-}
-int tls_wait_read(rdpTls* tls)
-{
- return freerdp_tcp_wait_read(tls->sockfd);
-}
+ ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
+ return length;
-int tls_wait_write(rdpTls* tls)
-{
- return freerdp_tcp_wait_write(tls->sockfd);
+out_fail:
+ ringbuffer_commit_read_bytes(&tcp->xmitBuffer, commitedBytes);
+ return -1;
}
-static void tls_errors(const char *prefix)
-{
- unsigned long error;
- while ((error = ERR_get_error()) != 0)
- fprintf(stderr, "%s: %s\n", prefix, ERR_error_string(error, NULL));
-}
-
-BOOL tls_print_error(char* func, SSL* connection, int value)
-{
- switch (SSL_get_error(connection, value))
- {
- case SSL_ERROR_ZERO_RETURN:
- fprintf(stderr, "%s: Server closed TLS connection\n", func);
- return TRUE;
-
- case SSL_ERROR_WANT_READ:
- fprintf(stderr, "%s: SSL_ERROR_WANT_READ\n", func);
- return FALSE;
-
- case SSL_ERROR_WANT_WRITE:
- fprintf(stderr, "%s: SSL_ERROR_WANT_WRITE\n", func);
- return FALSE;
-
- case SSL_ERROR_SYSCALL:
-#ifdef _WIN32
- fprintf(stderr, "%s: I/O error: %d\n", func, WSAGetLastError());
-#else
- fprintf(stderr, "%s: I/O error: %s (%d)\n", func, strerror(errno), errno);
-#endif
- tls_errors(func);
- return TRUE;
-
- case SSL_ERROR_SSL:
- fprintf(stderr, "%s: Failure in SSL library (protocol error?)\n", func);
- tls_errors(func);
- return TRUE;
-
- default:
- fprintf(stderr, "%s: Unknown error\n", func);
- tls_errors(func);
- return TRUE;
- }
-}
int tls_set_alert_code(rdpTls* tls, int level, int description)
{
if (!bio)
{
- fprintf(stderr, "tls_verify_certificate: BIO_new() failure\n");
+ fprintf(stderr, "%s: BIO_new() failure\n", __FUNCTION__);
return -1;
}
if (status < 0)
{
- fprintf(stderr, "tls_verify_certificate: PEM_write_bio_X509 failure: %d\n", status);
+ fprintf(stderr, "%s: PEM_write_bio_X509 failure: %d\n", __FUNCTION__, status);
return -1;
}
if (status < 0)
{
- fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
+ fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1;
}
if (status < 0)
{
- fprintf(stderr, "tls_verify_certificate: failed to read certificate\n");
+ fprintf(stderr, "%s: failed to read certificate\n", __FUNCTION__);
return -1;
}
status = instance->VerifyX509Certificate(instance, pemCert, length, hostname, port, 0);
}
- fprintf(stderr, "VerifyX509Certificate: (length = %d) status: %d\n%s\n",
- length, status, pemCert);
+ fprintf(stderr, "%s: (length = %d) status: %d\n%s\n", __FUNCTION__, length, status, pemCert);
free(pemCert);
BIO_free(bio);
{
rdpTls* tls;
- tls = (rdpTls*) malloc(sizeof(rdpTls));
-
- if (tls)
- {
- ZeroMemory(tls, sizeof(rdpTls));
-
- SSL_load_error_strings();
- SSL_library_init();
+ tls = (rdpTls *)calloc(1, sizeof(rdpTls));
+ if (!tls)
+ return NULL;
- tls->settings = settings;
- tls->certificate_store = certificate_store_new(settings);
+ SSL_load_error_strings();
+ SSL_library_init();
- tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
- tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
- }
+ tls->settings = settings;
+ tls->certificate_store = certificate_store_new(settings);
+ if (!tls->certificate_store)
+ goto out_free;
+ tls->alertLevel = TLS_ALERT_LEVEL_WARNING;
+ tls->alertDescription = TLS_ALERT_DESCRIPTION_CLOSE_NOTIFY;
return tls;
+
+out_free:
+ free(tls);
+ return NULL;
}
void tls_free(rdpTls* tls)
{
- if (tls)
- {
- if (tls->ssl)
- {
- SSL_free(tls->ssl);
- tls->ssl = NULL;
- }
+ if (!tls)
+ return;
- if (tls->ctx)
- {
- SSL_CTX_free(tls->ctx);
- tls->ctx = NULL;
- }
+ if (tls->ctx)
+ {
+ SSL_CTX_free(tls->ctx);
+ tls->ctx = NULL;
+ }
- if (tls->PublicKey)
- {
- free(tls->PublicKey);
- tls->PublicKey = NULL;
- }
+ if (tls->PublicKey)
+ {
+ free(tls->PublicKey);
+ tls->PublicKey = NULL;
+ }
- if (tls->Bindings)
- {
- free(tls->Bindings->Bindings);
- free(tls->Bindings);
- tls->Bindings = NULL;
- }
+ if (tls->Bindings)
+ {
+ free(tls->Bindings->Bindings);
+ free(tls->Bindings);
+ tls->Bindings = NULL;
+ }
- certificate_store_free(tls->certificate_store);
- tls->certificate_store = NULL;
+ certificate_store_free(tls->certificate_store);
+ tls->certificate_store = NULL;
- free(tls);
- }
+ free(tls);
}