libfreerdp-core: handle TCP/TLS connection at TSG RPC channel level
authorMarc-André Moreau <marcandre.moreau@gmail.com>
Thu, 12 Feb 2015 19:40:26 +0000 (14:40 -0500)
committerMarc-André Moreau <marcandre.moreau@gmail.com>
Thu, 12 Feb 2015 19:40:26 +0000 (14:40 -0500)
libfreerdp/core/gateway/rpc.c
libfreerdp/core/gateway/rpc.h
libfreerdp/core/gateway/tsg.c

index 62df2cd..18b9933 100644 (file)
@@ -117,53 +117,6 @@ const RPC_SECURITY_PROVIDER_INFO RPC_SECURITY_PROVIDER_INFO_TABLE[] =
  *
  */
 
-BOOL rpc_connect(rdpRpc* rpc)
-{
-       RpcInChannel* inChannel;
-       RpcOutChannel* outChannel;
-
-       inChannel = rpc->VirtualConnection->DefaultInChannel;
-       outChannel = rpc->VirtualConnection->DefaultOutChannel;
-
-       rpc_virtual_connection_transition_to_state(rpc, rpc->VirtualConnection, VIRTUAL_CONNECTION_STATE_INITIAL);
-
-       /* Connect IN Channel */
-
-       rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_CONNECTED);
-
-       if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) inChannel) < 0)
-               return FALSE;
-
-       /* Send IN Channel Request */
-
-       if (rpc_ncacn_http_send_in_channel_request(rpc, inChannel) < 0)
-       {
-               WLog_ERR(TAG, "rpc_ncacn_http_send_in_channel_request failure");
-               return FALSE;
-       }
-
-       rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_SECURITY);
-
-       /* Connect OUT Channel */
-
-       rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED);
-
-       if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) outChannel) < 0)
-               return FALSE;
-
-       /* Send OUT Channel Request */
-
-       if (rpc_ncacn_http_send_out_channel_request(rpc, outChannel, FALSE) < 0)
-       {
-               WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure");
-               return FALSE;
-       }
-
-       rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_SECURITY);
-
-       return TRUE;
-}
-
 void rpc_pdu_header_print(rpcconn_hdr_t* header)
 {
        WLog_INFO(TAG,  "rpc_vers: %d", header->common.rpc_vers);
@@ -780,6 +733,133 @@ void rpc_virtual_connection_free(RpcVirtualConnection* connection)
        free(connection);
 }
 
+int rpc_channel_tls_connect(RpcChannel* channel, int timeout)
+{
+       rdpTcp* tcp;
+       rdpTls* tls;
+       int tlsStatus;
+       rdpRpc* rpc = channel->rpc;
+       rdpContext* context = rpc->context;
+       rdpSettings* settings = context->settings;
+
+       tcp = channel->tcp = freerdp_tcp_new(settings);
+
+       if (!freerdp_tcp_connect(tcp, settings->GatewayHostname, settings->GatewayPort, timeout))
+               return -1;
+
+       if (!freerdp_tcp_set_blocking_mode(tcp, FALSE))
+               return -1;
+
+       tls = channel->tls = tls_new(settings);
+
+       if (!tls)
+               return -1;
+
+       tls->hostname = settings->GatewayHostname;
+       tls->port = settings->GatewayPort;
+       tls->isGatewayTransport = TRUE;
+
+       tlsStatus = tls_connect(tls, tcp->bufferedBio);
+
+       if (tlsStatus < 1)
+       {
+               if (tlsStatus < 0)
+               {
+                       if (!freerdp_get_last_error(context))
+                               freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
+               }
+               else
+               {
+                       if (!freerdp_get_last_error(context))
+                               freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
+               }
+
+               return -1;
+       }
+
+       return 1;
+}
+
+int rpc_in_channel_connect(RpcInChannel* inChannel, int timeout)
+{
+       rdpRpc* rpc = inChannel->rpc;
+
+       /* Connect IN Channel */
+
+       if (rpc_channel_tls_connect((RpcChannel*) inChannel, timeout) < 0)
+               return -1;
+
+       rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_CONNECTED);
+
+       if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) inChannel) < 0)
+               return -1;
+
+       /* Send IN Channel Request */
+
+       if (rpc_ncacn_http_send_in_channel_request(rpc, inChannel) < 0)
+       {
+               WLog_ERR(TAG, "rpc_ncacn_http_send_in_channel_request failure");
+               return -1;
+       }
+
+       rpc_in_channel_transition_to_state(inChannel, CLIENT_IN_CHANNEL_STATE_SECURITY);
+
+       return 1;
+}
+
+int rpc_out_channel_connect(RpcOutChannel* outChannel, int timeout)
+{
+       rdpRpc* rpc = outChannel->rpc;
+
+       /* Connect OUT Channel */
+
+       if (rpc_channel_tls_connect((RpcChannel*) outChannel, timeout) < 0)
+               return -1;
+
+       rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_CONNECTED);
+
+       if (rpc_ncacn_http_ntlm_init(rpc, (RpcChannel*) outChannel) < 0)
+               return FALSE;
+
+       /* Send OUT Channel Request */
+
+       if (rpc_ncacn_http_send_out_channel_request(rpc, outChannel, FALSE) < 0)
+       {
+               WLog_ERR(TAG, "rpc_ncacn_http_send_out_channel_request failure");
+               return FALSE;
+       }
+
+       rpc_out_channel_transition_to_state(outChannel, CLIENT_OUT_CHANNEL_STATE_SECURITY);
+
+       return 1;
+}
+
+BOOL rpc_connect(rdpRpc* rpc, int timeout)
+{
+       RpcInChannel* inChannel;
+       RpcOutChannel* outChannel;
+       RpcVirtualConnection* connection;
+
+       rpc->VirtualConnection = rpc_virtual_connection_new(rpc);
+
+       if (!rpc->VirtualConnection)
+               return FALSE;
+
+       connection = rpc->VirtualConnection;
+       inChannel = connection->DefaultInChannel;
+       outChannel = connection->DefaultOutChannel;
+
+       rpc_virtual_connection_transition_to_state(rpc, connection, VIRTUAL_CONNECTION_STATE_INITIAL);
+
+       if (rpc_in_channel_connect(inChannel, timeout) < 0)
+               return FALSE;
+
+       if (rpc_out_channel_connect(outChannel, timeout) < 0)
+               return FALSE;
+
+       return TRUE;
+}
+
 rdpRpc* rpc_new(rdpTransport* transport)
 {
        rdpRpc* rpc = (rdpRpc*) calloc(1, sizeof(rdpRpc));
@@ -820,11 +900,6 @@ rdpRpc* rpc_new(rdpTransport* transport)
        rpc->CurrentKeepAliveInterval = rpc->KeepAliveInterval;
        rpc->CurrentKeepAliveTime = 0;
 
-       rpc->VirtualConnection = rpc_virtual_connection_new(rpc);
-
-       if (!rpc->VirtualConnection)
-               goto out_free_virtual_connection;
-
        rpc->CallId = 2;
 
        if (rpc_client_new(rpc) < 0)
@@ -833,8 +908,6 @@ rdpRpc* rpc_new(rdpTransport* transport)
        return rpc;
 out_free_rpc_client:
        rpc_client_free(rpc);
-out_free_virtual_connection:
-       rpc_virtual_connection_free(rpc->VirtualConnection);
 out_free:
        free(rpc);
        return NULL;
@@ -853,7 +926,11 @@ void rpc_free(rdpRpc* rpc)
                        rpc->ntlm = NULL;
                }
 
-               rpc_virtual_connection_free(rpc->VirtualConnection);
+               if (rpc->VirtualConnection)
+               {
+                       rpc_virtual_connection_free(rpc->VirtualConnection);
+                       rpc->VirtualConnection = NULL;
+               }
 
                free(rpc);
        }
index a5c4dd5..ccb2186 100644 (file)
@@ -787,7 +787,7 @@ int rpc_out_channel_transition_to_state(RpcOutChannel* outChannel, CLIENT_OUT_CH
 int rpc_virtual_connection_transition_to_state(rdpRpc* rpc,
                RpcVirtualConnection* connection, VIRTUAL_CONNECTION_STATE state);
 
-BOOL rpc_connect(rdpRpc* rpc);
+BOOL rpc_connect(rdpRpc* rpc, int timeout);
 
 rdpRpc* rpc_new(rdpTransport* transport);
 void rpc_free(rdpRpc* rpc);
index 9a5dbb9..d3225aa 100644 (file)
@@ -1410,9 +1410,28 @@ int tsg_check(rdpTsg* tsg)
        return status;
 }
 
+BOOL tsg_set_hostname(rdpTsg* tsg, const char* hostname)
+{
+       free(tsg->Hostname);
+       tsg->Hostname = NULL;
+
+       ConvertToUnicode(CP_UTF8, 0, hostname, -1, &tsg->Hostname, 0);
+
+       return TRUE;
+}
+
+BOOL tsg_set_machine_name(rdpTsg* tsg, const char* machineName)
+{
+       free(tsg->MachineName);
+       tsg->MachineName = NULL;
+
+       ConvertToUnicode(CP_UTF8, 0, machineName, -1, &tsg->MachineName, 0);
+
+       return TRUE;
+}
+
 BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout)
 {
-       int tlsStatus;
        HANDLE events[2];
        rdpRpc* rpc = tsg->rpc;
        RpcInChannel* inChannel;
@@ -1420,95 +1439,26 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port, int timeout)
        RpcVirtualConnection* connection;
        rdpSettings* settings = rpc->settings;
        rdpTransport* transport = rpc->transport;
-       rdpContext* context = rpc->context;
 
        tsg->Port = port;
        tsg->transport = transport;
 
-       free(tsg->Hostname);
-       tsg->Hostname = NULL;
-       ConvertToUnicode(CP_UTF8, 0, hostname, -1, &tsg->Hostname, 0);
-
-       free(tsg->MachineName);
-       tsg->MachineName = NULL;
-       ConvertToUnicode(CP_UTF8, 0, settings->ComputerName, -1, &tsg->MachineName, 0);
-
-       connection = rpc->VirtualConnection;
-       inChannel = connection->DefaultInChannel;
-       outChannel = connection->DefaultOutChannel;
-
-       inChannel->tcp = freerdp_tcp_new(settings);
-       outChannel->tcp = freerdp_tcp_new(settings);
-
-       if (!freerdp_tcp_connect(inChannel->tcp, settings->GatewayHostname, settings->GatewayPort, timeout) ||
-                       !freerdp_tcp_set_blocking_mode(inChannel->tcp, FALSE))
-               return FALSE;
-
-       if (!freerdp_tcp_connect(outChannel->tcp, settings->GatewayHostname, settings->GatewayPort, timeout) ||
-                       !freerdp_tcp_set_blocking_mode(outChannel->tcp, FALSE))
-               return FALSE;
-
-       inChannel->tls = tls_new(settings);
-
-       if (!inChannel->tls)
-               return FALSE;
-
-       outChannel->tls = tls_new(settings);
-
-       if (!outChannel->tls)
-               return FALSE;
-
-       /* put a decent default value for gateway port */
        if (!settings->GatewayPort)
                settings->GatewayPort = 443;
 
-       inChannel->tls->hostname = outChannel->tls->hostname = settings->GatewayHostname;
-       inChannel->tls->port = outChannel->tls->port = settings->GatewayPort;
-
-       inChannel->tls->isGatewayTransport = TRUE;
-       tlsStatus = tls_connect(inChannel->tls, inChannel->tcp->bufferedBio);
-
-       if (tlsStatus < 1)
-       {
-               if (tlsStatus < 0)
-               {
-                       if (!freerdp_get_last_error(context))
-                               freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
-               }
-               else
-               {
-                       if (!freerdp_get_last_error(context))
-                               freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
-               }
-
-               return FALSE;
-       }
-
-       outChannel->tls->isGatewayTransport = TRUE;
-       tlsStatus = tls_connect(outChannel->tls, outChannel->tcp->bufferedBio);
-
-       if (tlsStatus < 1)
-       {
-               if (tlsStatus < 0)
-               {
-                       if (!freerdp_get_last_error(context))
-                               freerdp_set_last_error(context, FREERDP_ERROR_TLS_CONNECT_FAILED);
-               }
-               else
-               {
-                       if (!freerdp_get_last_error(context))
-                               freerdp_set_last_error(context, FREERDP_ERROR_CONNECT_CANCELLED);
-               }
-
-               return FALSE;
-       }
+       tsg_set_hostname(tsg, hostname);
+       tsg_set_machine_name(tsg, settings->ComputerName);
 
-       if (!rpc_connect(rpc))
+       if (!rpc_connect(rpc, timeout))
        {
                WLog_ERR(TAG, "rpc_connect error!");
                return FALSE;
        }
 
+       connection = rpc->VirtualConnection;
+       inChannel = connection->DefaultInChannel;
+       outChannel = connection->DefaultOutChannel;
+
        BIO_get_event(inChannel->tls->bio, &events[0]);
        BIO_get_event(outChannel->tls->bio, &events[1]);