From 7a39dcd7e2eea8b9b77cda2ffde5129563be7e1d Mon Sep 17 00:00:00 2001 From: Armin Novak Date: Tue, 10 Jul 2018 12:04:27 +0200 Subject: [PATCH] Updated reconnect to handle cases where PostConnect was not called freerdp_reconnect might be called after a freerdp_connect failed due to a TCP timeout waiting for user input. In such cases we need to know if PostConect was already called and do that if not. --- libfreerdp/core/connection.c | 37 +++++++++++++++++++--- libfreerdp/core/connection.h | 3 +- libfreerdp/core/freerdp.c | 3 +- libfreerdp/core/nego.c | 73 ++++++++++++++++---------------------------- 4 files changed, 63 insertions(+), 53 deletions(-) diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index 05a902f..6f3536a 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -351,6 +351,35 @@ BOOL rdp_client_disconnect_and_clear(rdpRdp* rdp) return TRUE; } +static BOOL rdp_client_reconnect_channels(rdpRdp* rdp) +{ + BOOL status; + rdpContext* context; + rdpChannels* channels; + + if (!rdp || !rdp->context || !rdp->context->channels) + return FALSE; + + context = rdp->context; + channels = context->channels; + + if (context->instance->ConnectionCallbackState == CLIENT_STATE_INITIAL) + return FALSE; + + if (context->instance->ConnectionCallbackState == CLIENT_STATE_PRECONNECT_PASSED) + { + if (!IFCALLRESULT(FALSE, context->instance->PostConnect, context->instance)) + return FALSE; + + context->instance->ConnectionCallbackState = CLIENT_STATE_POSTCONNECT_PASSED; + } + + if (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED) + status = (freerdp_channels_post_connect(context->channels, context->instance) == CHANNEL_RC_OK); + + return status; +} + BOOL rdp_client_redirect(rdpRdp* rdp) { BOOL status; @@ -424,8 +453,8 @@ BOOL rdp_client_redirect(rdpRdp* rdp) status = rdp_client_connect(rdp); - if (status && (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED)) - status = (freerdp_channels_post_connect(context->channels, context->instance) == CHANNEL_RC_OK); + if (status) + status = rdp_client_reconnect_channels(rdp); return status; } @@ -447,8 +476,8 @@ BOOL rdp_client_reconnect(rdpRdp* rdp) status = rdp_client_connect(rdp); - if (status && (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED)) - status = (freerdp_channels_post_connect(channels, context->instance) == CHANNEL_RC_OK); + if (status) + status = rdp_client_reconnect_channels(rdp); return status; } diff --git a/libfreerdp/core/connection.h b/libfreerdp/core/connection.h index 78dbba5..1fddae2 100644 --- a/libfreerdp/core/connection.h +++ b/libfreerdp/core/connection.h @@ -53,8 +53,7 @@ enum CLIENT_CONNECTION_STATE { CLIENT_STATE_INITIAL, CLIENT_STATE_PRECONNECT_PASSED, - CLIENT_STATE_POSTCONNECT_PASSED, - CLIENT_STATE_POSTDISCONNECT_PASSED + CLIENT_STATE_POSTCONNECT_PASSED }; FREERDP_LOCAL BOOL rdp_client_connect(rdpRdp* rdp); diff --git a/libfreerdp/core/freerdp.c b/libfreerdp/core/freerdp.c index d7563b9..72cae89 100644 --- a/libfreerdp/core/freerdp.c +++ b/libfreerdp/core/freerdp.c @@ -400,6 +400,7 @@ BOOL freerdp_check_event_handles(rdpContext* context) { if (freerdp_get_last_error(context) == FREERDP_ERROR_SUCCESS) WLog_ERR(TAG, "checkChannelErrorEvent() failed - %"PRIi32"", status); + return FALSE; } @@ -407,6 +408,7 @@ BOOL freerdp_check_event_handles(rdpContext* context) { int rc = freerdp_message_queue_process_pending_messages( context->instance, FREERDP_INPUT_MESSAGE_QUEUE); + if (rc < 0) return FALSE; else @@ -512,7 +514,6 @@ BOOL freerdp_disconnect(freerdp* instance) } IFCALL(instance->PostDisconnect, instance); - instance->ConnectionCallbackState = CLIENT_STATE_POSTDISCONNECT_PASSED; if (instance->update->pcap_rfx) { diff --git a/libfreerdp/core/nego.c b/libfreerdp/core/nego.c index 2bcba6b..dcfc6ff 100644 --- a/libfreerdp/core/nego.c +++ b/libfreerdp/core/nego.c @@ -103,7 +103,6 @@ BOOL nego_connect(rdpNego* nego) { WLog_DBG(TAG, "Security Layer Negotiation is disabled"); /* attempt only the highest enabled protocol (see nego_attempt_*) */ - nego->EnabledProtocols[PROTOCOL_NLA] = FALSE; nego->EnabledProtocols[PROTOCOL_TLS] = FALSE; nego->EnabledProtocols[PROTOCOL_RDP] = FALSE; @@ -152,13 +151,13 @@ BOOL nego_connect(rdpNego* nego) do { WLog_DBG(TAG, "state: %s", NEGO_STATE_STRINGS[nego->state]); - nego_send(nego); if (nego->state == NEGO_STATE_FAIL) { if (freerdp_get_last_error(nego->transport->context) == FREERDP_ERROR_SUCCESS) WLog_ERR(TAG, "Protocol Security Negotiation Failure"); + nego->state = NEGO_STATE_FINAL; return FALSE; } @@ -167,7 +166,6 @@ BOOL nego_connect(rdpNego* nego) } WLog_DBG(TAG, "Negotiated %s security", PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]); - /* update settings with negotiated protocol security */ settings->RequestedProtocols = nego->RequestedProtocols; settings->SelectedProtocol = nego->SelectedProtocol; @@ -183,14 +181,16 @@ BOOL nego_connect(rdpNego* nego) * Advertise all supported encryption methods if the client * implementation did not set any security methods */ - settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_56BIT | ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS; + settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_56BIT | + ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS; } } /* finally connect security layer (if not already done) */ if (!nego_security_connect(nego)) { - WLog_DBG(TAG, "Failed to connect with %s security", PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]); + WLog_DBG(TAG, "Failed to connect with %s security", + PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]); return FALSE; } @@ -242,7 +242,7 @@ BOOL nego_security_connect(rdpNego* nego) * @return */ -BOOL nego_tcp_connect(rdpNego* nego) +static BOOL nego_tcp_connect(rdpNego* nego) { if (!nego->TcpConnected) { @@ -280,7 +280,8 @@ BOOL nego_tcp_connect(rdpNego* nego) BOOL nego_transport_connect(rdpNego* nego) { - nego_tcp_connect(nego); + if (!nego_tcp_connect(nego)) + return FALSE; if (nego->TcpConnected && !nego->NegotiateSecurityLayer) return nego_security_connect(nego); @@ -301,7 +302,6 @@ BOOL nego_transport_disconnect(rdpNego* nego) nego->TcpConnected = FALSE; nego->SecurityConnected = FALSE; - return TRUE; } @@ -317,7 +317,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego) UINT32 cbSize; UINT16 cchPCB = 0; WCHAR* wszPCB = NULL; - WLog_DBG(TAG, "Sending preconnection PDU"); if (!nego_tcp_connect(nego)) @@ -334,6 +333,7 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego) } s = Stream_New(NULL, cbSize); + if (!s) { WLog_ERR(TAG, "Stream_New failed!"); @@ -361,7 +361,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego) } Stream_Free(s, TRUE); - return TRUE; } @@ -373,7 +372,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego) void nego_attempt_ext(rdpNego* nego) { nego->RequestedProtocols = PROTOCOL_NLA | PROTOCOL_TLS | PROTOCOL_EXT; - WLog_DBG(TAG, "Attempting NLA extended security"); if (!nego_transport_connect(nego)) @@ -419,7 +417,6 @@ void nego_attempt_ext(rdpNego* nego) void nego_attempt_nla(rdpNego* nego) { nego->RequestedProtocols = PROTOCOL_NLA | PROTOCOL_TLS; - WLog_DBG(TAG, "Attempting NLA security"); if (!nego_transport_connect(nego)) @@ -463,7 +460,6 @@ void nego_attempt_nla(rdpNego* nego) void nego_attempt_tls(rdpNego* nego) { nego->RequestedProtocols = PROTOCOL_TLS; - WLog_DBG(TAG, "Attempting TLS security"); if (!nego_transport_connect(nego)) @@ -503,7 +499,6 @@ void nego_attempt_tls(rdpNego* nego) void nego_attempt_rdp(rdpNego* nego) { nego->RequestedProtocols = PROTOCOL_RDP; - WLog_DBG(TAG, "Attempting RDP security"); if (!nego_transport_connect(nego)) @@ -534,7 +529,6 @@ BOOL nego_recv_response(rdpNego* nego) { int status; wStream* s; - s = Stream_New(NULL, 1024); if (!s) @@ -552,7 +546,6 @@ BOOL nego_recv_response(rdpNego* nego) } status = nego_recv(nego->transport, s, nego); - Stream_Free(s, TRUE); if (status < 0) @@ -588,14 +581,12 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra) if (li > 6) { /* rdpNegData (optional) */ - Stream_Read_UINT8(s, type); /* Type */ switch (type) { case TYPE_RDP_NEG_RSP: nego_process_negotiation_response(nego, s); - WLog_DBG(TAG, "selected_protocol: %"PRIu32"", nego->SelectedProtocol); /* enhanced security selected ? */ @@ -603,12 +594,13 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra) if (nego->SelectedProtocol) { if ((nego->SelectedProtocol == PROTOCOL_NLA) && - (!nego->EnabledProtocols[PROTOCOL_NLA])) + (!nego->EnabledProtocols[PROTOCOL_NLA])) { nego->state = NEGO_STATE_FAIL; } + if ((nego->SelectedProtocol == PROTOCOL_TLS) && - (!nego->EnabledProtocols[PROTOCOL_TLS])) + (!nego->EnabledProtocols[PROTOCOL_TLS])) { nego->state = NEGO_STATE_FAIL; } @@ -617,6 +609,7 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra) { nego->state = NEGO_STATE_FAIL; } + break; case TYPE_RDP_NEG_FAILURE: @@ -663,13 +656,11 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s) * string terminated by a 0x0D0A two-byte sequence: * Cookie:[space]mstshash=[ANSISTRING][\x0D\x0A] */ - - BYTE *str = NULL; + BYTE* str = NULL; UINT16 crlf = 0; size_t pos, len; BOOL result = FALSE; BOOL isToken = FALSE; - str = Stream_Pointer(s); pos = Stream_GetPosition(s); @@ -693,8 +684,10 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s) while (Stream_GetRemainingLength(s) >= 2) { Stream_Read_UINT16(s, crlf); + if (crlf == 0x0A0D) break; + Stream_Rewind(s, 1); } @@ -703,6 +696,7 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s) Stream_Rewind(s, 2); len = Stream_GetPosition(s) - pos; Stream_Write_UINT16(s, 0); + if (strlen((char*)str) == len) { if (isToken) @@ -716,12 +710,12 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s) { Stream_SetPosition(s, pos); WLog_ERR(TAG, "invalid %s received", - isToken ? "routing token" : "cookie"); + isToken ? "routing token" : "cookie"); } else { WLog_DBG(TAG, "received %s [%s]", - isToken ? "routing token" : "cookie", str); + isToken ? "routing token" : "cookie", str); } return result; @@ -760,7 +754,6 @@ BOOL nego_read_request(rdpNego* nego, wStream* s) if (Stream_GetRemainingLength(s) >= 8) { /* rdpNegData (optional) */ - Stream_Read_UINT8(s, type); /* Type */ if (type != TYPE_RDP_NEG_REQ) @@ -808,8 +801,8 @@ BOOL nego_send_negotiation_request(rdpNego* nego) size_t bm, em; BYTE flags = 0; int cookie_length; - s = Stream_New(NULL, 512); + if (!s) { WLog_ERR(TAG, "Stream_New failed!"); @@ -827,8 +820,8 @@ BOOL nego_send_negotiation_request(rdpNego* nego) /* Ensure Routing Token is correctly terminated - may already be present in string */ if ((nego->RoutingTokenLength > 2) && - (nego->RoutingToken[nego->RoutingTokenLength - 2] == 0x0D) && - (nego->RoutingToken[nego->RoutingTokenLength - 1] == 0x0A)) + (nego->RoutingToken[nego->RoutingTokenLength - 2] == 0x0D) && + (nego->RoutingToken[nego->RoutingTokenLength - 1] == 0x0A)) { WLog_DBG(TAG, "Routing token looks correctly terminated - use verbatim"); length += nego->RoutingTokenLength; @@ -860,7 +853,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego) if ((nego->RequestedProtocols > PROTOCOL_RDP) || (nego->sendNegoData)) { /* RDP_NEG_DATA must be present for TLS and NLA */ - if (nego->RestrictedAdminModeRequired) flags |= RESTRICTED_ADMIN_MODE_REQUIRED; @@ -876,7 +868,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego) tpkt_write_header(s, length); tpdu_write_connection_request(s, length - 5); Stream_SetPosition(s, em); - Stream_SealLength(s); if (transport_write(nego->transport, s) < 0) @@ -886,7 +877,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego) } Stream_Free(s, TRUE); - return TRUE; } @@ -900,13 +890,10 @@ void nego_process_negotiation_request(rdpNego* nego, wStream* s) { BYTE flags; UINT16 length; - Stream_Read_UINT8(s, flags); Stream_Read_UINT16(s, length); Stream_Read_UINT32(s, nego->RequestedProtocols); - WLog_DBG(TAG, "RDP_NEG_REQ: RequestedProtocol: 0x%08"PRIX32"", nego->RequestedProtocols); - nego->state = NEGO_STATE_FINAL; } @@ -919,7 +906,6 @@ void nego_process_negotiation_request(rdpNego* nego, wStream* s) void nego_process_negotiation_response(rdpNego* nego, wStream* s) { UINT16 length; - WLog_DBG(TAG, "RDP_NEG_RSP"); if (Stream_GetRemainingLength(s) < 7) @@ -932,7 +918,6 @@ void nego_process_negotiation_response(rdpNego* nego, wStream* s) Stream_Read_UINT8(s, nego->flags); Stream_Read_UINT16(s, length); Stream_Read_UINT32(s, nego->SelectedProtocol); - nego->state = NEGO_STATE_FINAL; } @@ -947,9 +932,7 @@ void nego_process_negotiation_failure(rdpNego* nego, wStream* s) BYTE flags; UINT16 length; UINT32 failureCode; - WLog_DBG(TAG, "RDP_NEG_FAILURE"); - Stream_Read_UINT8(s, flags); Stream_Read_UINT16(s, length); Stream_Read_UINT32(s, failureCode); @@ -999,11 +982,10 @@ BOOL nego_send_negotiation_response(rdpNego* nego) wStream* s; BYTE flags; rdpSettings* settings; - status = TRUE; settings = nego->transport->settings; - s = Stream_New(NULL, 512); + if (!s) { WLog_ERR(TAG, "Stream_New failed!"); @@ -1018,11 +1000,9 @@ BOOL nego_send_negotiation_response(rdpNego* nego) { UINT32 errorCode = (nego->SelectedProtocol & ~PROTOCOL_FAILED_NEGO); flags = 0; - Stream_Write_UINT8(s, TYPE_RDP_NEG_FAILURE); Stream_Write_UINT8(s, flags); /* flags */ Stream_Write_UINT16(s, 8); /* RDP_NEG_DATA length (8) */ - Stream_Write_UINT32(s, errorCode); length += 8; status = FALSE; @@ -1047,7 +1027,6 @@ BOOL nego_send_negotiation_response(rdpNego* nego) tpkt_write_header(s, length); tpdu_write_connection_confirm(s, length - 5); Stream_SetPosition(s, em); - Stream_SealLength(s); if (transport_write(nego->transport, s) < 0) @@ -1148,9 +1127,7 @@ rdpNego* nego_new(rdpTransport* transport) return NULL; nego->transport = transport; - nego_init(nego); - return nego; } @@ -1276,8 +1253,10 @@ BOOL nego_set_routing_token(rdpNego* nego, BYTE* RoutingToken, DWORD RoutingToke free(nego->RoutingToken); nego->RoutingTokenLength = RoutingTokenLength; nego->RoutingToken = (BYTE*) malloc(nego->RoutingTokenLength); + if (!nego->RoutingToken) return FALSE; + CopyMemory(nego->RoutingToken, RoutingToken, nego->RoutingTokenLength); return TRUE; } @@ -1300,8 +1279,10 @@ BOOL nego_set_cookie(rdpNego* nego, char* cookie) return TRUE; nego->cookie = _strdup(cookie); + if (!nego->cookie) return FALSE; + return TRUE; } -- 2.7.4