Updated reconnect to handle cases where PostConnect was not called
authorArmin Novak <armin.novak@thincast.com>
Tue, 10 Jul 2018 10:04:27 +0000 (12:04 +0200)
committerArmin Novak <armin.novak@thincast.com>
Tue, 10 Jul 2018 10:04:27 +0000 (12:04 +0200)
freerdp_reconnect might be called after a freerdp_connect failed due
to a TCP timeout waiting for user input.
In such cases we need to know if PostConect was already called and
do that if not.

libfreerdp/core/connection.c
libfreerdp/core/connection.h
libfreerdp/core/freerdp.c
libfreerdp/core/nego.c

index 05a902f..6f3536a 100644 (file)
@@ -351,6 +351,35 @@ BOOL rdp_client_disconnect_and_clear(rdpRdp* rdp)
        return TRUE;
 }
 
+static BOOL rdp_client_reconnect_channels(rdpRdp* rdp)
+{
+       BOOL status;
+       rdpContext* context;
+       rdpChannels* channels;
+
+       if (!rdp || !rdp->context || !rdp->context->channels)
+               return FALSE;
+
+       context = rdp->context;
+       channels = context->channels;
+
+       if (context->instance->ConnectionCallbackState == CLIENT_STATE_INITIAL)
+               return FALSE;
+
+       if (context->instance->ConnectionCallbackState == CLIENT_STATE_PRECONNECT_PASSED)
+       {
+               if (!IFCALLRESULT(FALSE, context->instance->PostConnect, context->instance))
+                       return FALSE;
+
+               context->instance->ConnectionCallbackState = CLIENT_STATE_POSTCONNECT_PASSED;
+       }
+
+       if (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED)
+               status = (freerdp_channels_post_connect(context->channels, context->instance) == CHANNEL_RC_OK);
+
+       return status;
+}
+
 BOOL rdp_client_redirect(rdpRdp* rdp)
 {
        BOOL status;
@@ -424,8 +453,8 @@ BOOL rdp_client_redirect(rdpRdp* rdp)
 
        status = rdp_client_connect(rdp);
 
-       if (status && (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED))
-               status = (freerdp_channels_post_connect(context->channels, context->instance) == CHANNEL_RC_OK);
+       if (status)
+               status = rdp_client_reconnect_channels(rdp);
 
        return status;
 }
@@ -447,8 +476,8 @@ BOOL rdp_client_reconnect(rdpRdp* rdp)
 
        status = rdp_client_connect(rdp);
 
-       if (status && (context->instance->ConnectionCallbackState == CLIENT_STATE_POSTCONNECT_PASSED))
-               status = (freerdp_channels_post_connect(channels, context->instance) == CHANNEL_RC_OK);
+       if (status)
+               status = rdp_client_reconnect_channels(rdp);
 
        return status;
 }
index 78dbba5..1fddae2 100644 (file)
@@ -53,8 +53,7 @@ enum CLIENT_CONNECTION_STATE
 {
        CLIENT_STATE_INITIAL,
        CLIENT_STATE_PRECONNECT_PASSED,
-       CLIENT_STATE_POSTCONNECT_PASSED,
-       CLIENT_STATE_POSTDISCONNECT_PASSED
+       CLIENT_STATE_POSTCONNECT_PASSED
 };
 
 FREERDP_LOCAL BOOL rdp_client_connect(rdpRdp* rdp);
index d7563b9..72cae89 100644 (file)
@@ -400,6 +400,7 @@ BOOL freerdp_check_event_handles(rdpContext* context)
        {
                if (freerdp_get_last_error(context) == FREERDP_ERROR_SUCCESS)
                        WLog_ERR(TAG, "checkChannelErrorEvent() failed - %"PRIi32"", status);
+
                return FALSE;
        }
 
@@ -407,6 +408,7 @@ BOOL freerdp_check_event_handles(rdpContext* context)
        {
                int rc = freerdp_message_queue_process_pending_messages(
                             context->instance, FREERDP_INPUT_MESSAGE_QUEUE);
+
                if (rc < 0)
                        return FALSE;
                else
@@ -512,7 +514,6 @@ BOOL freerdp_disconnect(freerdp* instance)
        }
 
        IFCALL(instance->PostDisconnect, instance);
-       instance->ConnectionCallbackState = CLIENT_STATE_POSTDISCONNECT_PASSED;
 
        if (instance->update->pcap_rfx)
        {
index 2bcba6b..dcfc6ff 100644 (file)
@@ -103,7 +103,6 @@ BOOL nego_connect(rdpNego* nego)
                {
                        WLog_DBG(TAG, "Security Layer Negotiation is disabled");
                        /* attempt only the highest enabled protocol (see nego_attempt_*) */
-
                        nego->EnabledProtocols[PROTOCOL_NLA] = FALSE;
                        nego->EnabledProtocols[PROTOCOL_TLS] = FALSE;
                        nego->EnabledProtocols[PROTOCOL_RDP] = FALSE;
@@ -152,13 +151,13 @@ BOOL nego_connect(rdpNego* nego)
                do
                {
                        WLog_DBG(TAG, "state: %s", NEGO_STATE_STRINGS[nego->state]);
-
                        nego_send(nego);
 
                        if (nego->state == NEGO_STATE_FAIL)
                        {
                                if (freerdp_get_last_error(nego->transport->context) == FREERDP_ERROR_SUCCESS)
                                        WLog_ERR(TAG, "Protocol Security Negotiation Failure");
+
                                nego->state = NEGO_STATE_FINAL;
                                return FALSE;
                        }
@@ -167,7 +166,6 @@ BOOL nego_connect(rdpNego* nego)
        }
 
        WLog_DBG(TAG, "Negotiated %s security", PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]);
-
        /* update settings with negotiated protocol security */
        settings->RequestedProtocols = nego->RequestedProtocols;
        settings->SelectedProtocol = nego->SelectedProtocol;
@@ -183,14 +181,16 @@ BOOL nego_connect(rdpNego* nego)
                         * Advertise all supported encryption methods if the client
                         * implementation did not set any security methods
                         */
-                       settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_56BIT | ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS;
+                       settings->EncryptionMethods = ENCRYPTION_METHOD_40BIT | ENCRYPTION_METHOD_56BIT |
+                                                     ENCRYPTION_METHOD_128BIT | ENCRYPTION_METHOD_FIPS;
                }
        }
 
        /* finally connect security layer (if not already done) */
        if (!nego_security_connect(nego))
        {
-               WLog_DBG(TAG, "Failed to connect with %s security", PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]);
+               WLog_DBG(TAG, "Failed to connect with %s security",
+                        PROTOCOL_SECURITY_STRINGS[nego->SelectedProtocol]);
                return FALSE;
        }
 
@@ -242,7 +242,7 @@ BOOL nego_security_connect(rdpNego* nego)
  * @return
  */
 
-BOOL nego_tcp_connect(rdpNego* nego)
+static BOOL nego_tcp_connect(rdpNego* nego)
 {
        if (!nego->TcpConnected)
        {
@@ -280,7 +280,8 @@ BOOL nego_tcp_connect(rdpNego* nego)
 
 BOOL nego_transport_connect(rdpNego* nego)
 {
-       nego_tcp_connect(nego);
+       if (!nego_tcp_connect(nego))
+               return FALSE;
 
        if (nego->TcpConnected && !nego->NegotiateSecurityLayer)
                return nego_security_connect(nego);
@@ -301,7 +302,6 @@ BOOL nego_transport_disconnect(rdpNego* nego)
 
        nego->TcpConnected = FALSE;
        nego->SecurityConnected = FALSE;
-
        return TRUE;
 }
 
@@ -317,7 +317,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
        UINT32 cbSize;
        UINT16 cchPCB = 0;
        WCHAR* wszPCB = NULL;
-
        WLog_DBG(TAG, "Sending preconnection PDU");
 
        if (!nego_tcp_connect(nego))
@@ -334,6 +333,7 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
        }
 
        s = Stream_New(NULL, cbSize);
+
        if (!s)
        {
                WLog_ERR(TAG, "Stream_New failed!");
@@ -361,7 +361,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
        }
 
        Stream_Free(s, TRUE);
-
        return TRUE;
 }
 
@@ -373,7 +372,6 @@ BOOL nego_send_preconnection_pdu(rdpNego* nego)
 void nego_attempt_ext(rdpNego* nego)
 {
        nego->RequestedProtocols = PROTOCOL_NLA | PROTOCOL_TLS | PROTOCOL_EXT;
-
        WLog_DBG(TAG, "Attempting NLA extended security");
 
        if (!nego_transport_connect(nego))
@@ -419,7 +417,6 @@ void nego_attempt_ext(rdpNego* nego)
 void nego_attempt_nla(rdpNego* nego)
 {
        nego->RequestedProtocols = PROTOCOL_NLA | PROTOCOL_TLS;
-
        WLog_DBG(TAG, "Attempting NLA security");
 
        if (!nego_transport_connect(nego))
@@ -463,7 +460,6 @@ void nego_attempt_nla(rdpNego* nego)
 void nego_attempt_tls(rdpNego* nego)
 {
        nego->RequestedProtocols = PROTOCOL_TLS;
-
        WLog_DBG(TAG, "Attempting TLS security");
 
        if (!nego_transport_connect(nego))
@@ -503,7 +499,6 @@ void nego_attempt_tls(rdpNego* nego)
 void nego_attempt_rdp(rdpNego* nego)
 {
        nego->RequestedProtocols = PROTOCOL_RDP;
-
        WLog_DBG(TAG, "Attempting RDP security");
 
        if (!nego_transport_connect(nego))
@@ -534,7 +529,6 @@ BOOL nego_recv_response(rdpNego* nego)
 {
        int status;
        wStream* s;
-
        s = Stream_New(NULL, 1024);
 
        if (!s)
@@ -552,7 +546,6 @@ BOOL nego_recv_response(rdpNego* nego)
        }
 
        status = nego_recv(nego->transport, s, nego);
-
        Stream_Free(s, TRUE);
 
        if (status < 0)
@@ -588,14 +581,12 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
        if (li > 6)
        {
                /* rdpNegData (optional) */
-
                Stream_Read_UINT8(s, type); /* Type */
 
                switch (type)
                {
                        case TYPE_RDP_NEG_RSP:
                                nego_process_negotiation_response(nego, s);
-
                                WLog_DBG(TAG, "selected_protocol: %"PRIu32"", nego->SelectedProtocol);
 
                                /* enhanced security selected ? */
@@ -603,12 +594,13 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
                                if (nego->SelectedProtocol)
                                {
                                        if ((nego->SelectedProtocol == PROTOCOL_NLA) &&
-                                               (!nego->EnabledProtocols[PROTOCOL_NLA]))
+                                           (!nego->EnabledProtocols[PROTOCOL_NLA]))
                                        {
                                                nego->state = NEGO_STATE_FAIL;
                                        }
+
                                        if ((nego->SelectedProtocol == PROTOCOL_TLS) &&
-                                               (!nego->EnabledProtocols[PROTOCOL_TLS]))
+                                           (!nego->EnabledProtocols[PROTOCOL_TLS]))
                                        {
                                                nego->state = NEGO_STATE_FAIL;
                                        }
@@ -617,6 +609,7 @@ int nego_recv(rdpTransport* transport, wStream* s, void* extra)
                                {
                                        nego->state = NEGO_STATE_FAIL;
                                }
+
                                break;
 
                        case TYPE_RDP_NEG_FAILURE:
@@ -663,13 +656,11 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
         * string terminated by a 0x0D0A two-byte sequence:
         * Cookie:[space]mstshash=[ANSISTRING][\x0D\x0A]
         */
-
-       BYTE *str = NULL;
+       BYTE* str = NULL;
        UINT16 crlf = 0;
        size_t pos, len;
        BOOL result = FALSE;
        BOOL isToken = FALSE;
-
        str = Stream_Pointer(s);
        pos = Stream_GetPosition(s);
 
@@ -693,8 +684,10 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
        while (Stream_GetRemainingLength(s) >= 2)
        {
                Stream_Read_UINT16(s, crlf);
+
                if (crlf == 0x0A0D)
                        break;
+
                Stream_Rewind(s, 1);
        }
 
@@ -703,6 +696,7 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
                Stream_Rewind(s, 2);
                len = Stream_GetPosition(s) - pos;
                Stream_Write_UINT16(s, 0);
+
                if (strlen((char*)str) == len)
                {
                        if (isToken)
@@ -716,12 +710,12 @@ static BOOL nego_read_request_token_or_cookie(rdpNego* nego, wStream* s)
        {
                Stream_SetPosition(s, pos);
                WLog_ERR(TAG, "invalid %s received",
-                       isToken ? "routing token" : "cookie");
+                        isToken ? "routing token" : "cookie");
        }
        else
        {
                WLog_DBG(TAG, "received %s [%s]",
-                       isToken ? "routing token" : "cookie", str);
+                        isToken ? "routing token" : "cookie", str);
        }
 
        return result;
@@ -760,7 +754,6 @@ BOOL nego_read_request(rdpNego* nego, wStream* s)
        if (Stream_GetRemainingLength(s) >= 8)
        {
                /* rdpNegData (optional) */
-
                Stream_Read_UINT8(s, type); /* Type */
 
                if (type != TYPE_RDP_NEG_REQ)
@@ -808,8 +801,8 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
        size_t bm, em;
        BYTE flags = 0;
        int cookie_length;
-
        s = Stream_New(NULL, 512);
+
        if (!s)
        {
                WLog_ERR(TAG, "Stream_New failed!");
@@ -827,8 +820,8 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
                /* Ensure Routing Token is correctly terminated - may already be present in string */
 
                if ((nego->RoutingTokenLength > 2) &&
-                               (nego->RoutingToken[nego->RoutingTokenLength - 2] == 0x0D) &&
-                               (nego->RoutingToken[nego->RoutingTokenLength - 1] == 0x0A))
+                   (nego->RoutingToken[nego->RoutingTokenLength - 2] == 0x0D) &&
+                   (nego->RoutingToken[nego->RoutingTokenLength - 1] == 0x0A))
                {
                        WLog_DBG(TAG, "Routing token looks correctly terminated - use verbatim");
                        length += nego->RoutingTokenLength;
@@ -860,7 +853,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
        if ((nego->RequestedProtocols > PROTOCOL_RDP) || (nego->sendNegoData))
        {
                /* RDP_NEG_DATA must be present for TLS and NLA */
-
                if (nego->RestrictedAdminModeRequired)
                        flags |= RESTRICTED_ADMIN_MODE_REQUIRED;
 
@@ -876,7 +868,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
        tpkt_write_header(s, length);
        tpdu_write_connection_request(s, length - 5);
        Stream_SetPosition(s, em);
-
        Stream_SealLength(s);
 
        if (transport_write(nego->transport, s) < 0)
@@ -886,7 +877,6 @@ BOOL nego_send_negotiation_request(rdpNego* nego)
        }
 
        Stream_Free(s, TRUE);
-
        return TRUE;
 }
 
@@ -900,13 +890,10 @@ void nego_process_negotiation_request(rdpNego* nego, wStream* s)
 {
        BYTE flags;
        UINT16 length;
-
        Stream_Read_UINT8(s, flags);
        Stream_Read_UINT16(s, length);
        Stream_Read_UINT32(s, nego->RequestedProtocols);
-
        WLog_DBG(TAG, "RDP_NEG_REQ: RequestedProtocol: 0x%08"PRIX32"", nego->RequestedProtocols);
-
        nego->state = NEGO_STATE_FINAL;
 }
 
@@ -919,7 +906,6 @@ void nego_process_negotiation_request(rdpNego* nego, wStream* s)
 void nego_process_negotiation_response(rdpNego* nego, wStream* s)
 {
        UINT16 length;
-
        WLog_DBG(TAG, "RDP_NEG_RSP");
 
        if (Stream_GetRemainingLength(s) < 7)
@@ -932,7 +918,6 @@ void nego_process_negotiation_response(rdpNego* nego, wStream* s)
        Stream_Read_UINT8(s, nego->flags);
        Stream_Read_UINT16(s, length);
        Stream_Read_UINT32(s, nego->SelectedProtocol);
-
        nego->state = NEGO_STATE_FINAL;
 }
 
@@ -947,9 +932,7 @@ void nego_process_negotiation_failure(rdpNego* nego, wStream* s)
        BYTE flags;
        UINT16 length;
        UINT32 failureCode;
-
        WLog_DBG(TAG, "RDP_NEG_FAILURE");
-
        Stream_Read_UINT8(s, flags);
        Stream_Read_UINT16(s, length);
        Stream_Read_UINT32(s, failureCode);
@@ -999,11 +982,10 @@ BOOL nego_send_negotiation_response(rdpNego* nego)
        wStream* s;
        BYTE flags;
        rdpSettings* settings;
-
        status = TRUE;
        settings = nego->transport->settings;
-
        s = Stream_New(NULL, 512);
+
        if (!s)
        {
                WLog_ERR(TAG, "Stream_New failed!");
@@ -1018,11 +1000,9 @@ BOOL nego_send_negotiation_response(rdpNego* nego)
        {
                UINT32 errorCode = (nego->SelectedProtocol & ~PROTOCOL_FAILED_NEGO);
                flags = 0;
-
                Stream_Write_UINT8(s, TYPE_RDP_NEG_FAILURE);
                Stream_Write_UINT8(s, flags); /* flags */
                Stream_Write_UINT16(s, 8); /* RDP_NEG_DATA length (8) */
-
                Stream_Write_UINT32(s, errorCode);
                length += 8;
                status = FALSE;
@@ -1047,7 +1027,6 @@ BOOL nego_send_negotiation_response(rdpNego* nego)
        tpkt_write_header(s, length);
        tpdu_write_connection_confirm(s, length - 5);
        Stream_SetPosition(s, em);
-
        Stream_SealLength(s);
 
        if (transport_write(nego->transport, s) < 0)
@@ -1148,9 +1127,7 @@ rdpNego* nego_new(rdpTransport* transport)
                return NULL;
 
        nego->transport = transport;
-
        nego_init(nego);
-
        return nego;
 }
 
@@ -1276,8 +1253,10 @@ BOOL nego_set_routing_token(rdpNego* nego, BYTE* RoutingToken, DWORD RoutingToke
        free(nego->RoutingToken);
        nego->RoutingTokenLength = RoutingTokenLength;
        nego->RoutingToken = (BYTE*) malloc(nego->RoutingTokenLength);
+
        if (!nego->RoutingToken)
                return FALSE;
+
        CopyMemory(nego->RoutingToken, RoutingToken, nego->RoutingTokenLength);
        return TRUE;
 }
@@ -1300,8 +1279,10 @@ BOOL nego_set_cookie(rdpNego* nego, char* cookie)
                return TRUE;
 
        nego->cookie = _strdup(cookie);
+
        if (!nego->cookie)
                return FALSE;
+
        return TRUE;
 }