libfreerdp-core: remove receive queue, make tsg connection sequence more event-driven
authorMarc-André Moreau <marcandre.moreau@gmail.com>
Mon, 2 Feb 2015 13:19:07 +0000 (08:19 -0500)
committerMarc-André Moreau <marcandre.moreau@gmail.com>
Mon, 2 Feb 2015 13:19:07 +0000 (08:19 -0500)
libfreerdp/core/gateway/rpc.c
libfreerdp/core/gateway/rpc.h
libfreerdp/core/gateway/rpc_client.c
libfreerdp/core/gateway/rpc_client.h
libfreerdp/core/gateway/tsg.c
libfreerdp/core/gateway/tsg.h
libfreerdp/core/transport.c

index 8a88184..ab62787 100644 (file)
@@ -464,151 +464,6 @@ out_free_pdu:
        return -1;
 }
 
-int rpc_check(rdpRpc* rpc)
-{
-       RPC_PDU* pdu;
-       rpcconn_rts_hdr_t* rts;
-
-       if (rpc->State < RPC_CLIENT_STATE_ESTABLISHED)
-       {
-               switch (rpc->VirtualConnection->State)
-               {
-                       case VIRTUAL_CONNECTION_STATE_INITIAL:
-                               break;
-
-                       case VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT:
-                               break;
-
-                       case VIRTUAL_CONNECTION_STATE_WAIT_A3W:
-
-                               pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
-
-                               if (!pdu)
-                                       return -1;
-
-                               rts = (rpcconn_rts_hdr_t*) Stream_Buffer(pdu->s);
-
-                               if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
-                               {
-                                       WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/A3");
-                                       return -1;
-                               }
-
-                               rts_recv_CONN_A3_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
-
-                               rpc_client_receive_pool_return(rpc, pdu);
-
-                               rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_WAIT_C2;
-                               WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_WAIT_C2");
-
-                               break;
-
-                       case VIRTUAL_CONNECTION_STATE_WAIT_C2:
-
-                               pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
-
-                               if (!pdu)
-                                       return FALSE;
-
-                               rts = (rpcconn_rts_hdr_t*) Stream_Buffer(pdu->s);
-
-                               if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
-                               {
-                                       WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/C2");
-                                       return FALSE;
-                               }
-
-                               rts_recv_CONN_C2_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
-
-                               rpc_client_receive_pool_return(rpc, pdu);
-
-                               rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_OPENED;
-                               WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_OPENED");
-
-                               rpc->State = RPC_CLIENT_STATE_ESTABLISHED;
-
-                               break;
-
-                       case VIRTUAL_CONNECTION_STATE_OPENED:
-                               break;
-
-                       case VIRTUAL_CONNECTION_STATE_FINAL:
-                               break;
-               }
-
-               return 1;
-       }
-
-       if (rpc->State == RPC_CLIENT_STATE_ESTABLISHED)
-       {
-               if (rpc_send_bind_pdu(rpc) < 0)
-               {
-                       WLog_ERR(TAG, "rpc_send_bind_pdu failure");
-                       return -1;
-               }
-
-               rpc->State = RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK;
-       }
-       else if (rpc->State ==  RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK)
-       {
-               pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
-
-               if (!pdu)
-               {
-                       WLog_ERR(TAG, "rpc_recv_dequeue_pdu failure");
-                       return -1;
-               }
-
-               if (rpc_recv_bind_ack_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)) <= 0)
-               {
-                       WLog_ERR(TAG, "rpc_recv_bind_ack_pdu failure");
-                       return -1;
-               }
-
-               rpc_client_receive_pool_return(rpc, pdu);
-
-               if (rpc_send_rpc_auth_3_pdu(rpc) <= 0)
-               {
-                       WLog_ERR(TAG, "rpc_secure_bind: error sending rpc_auth_3 pdu!");
-                       return -1;
-               }
-
-               rpc->State = RPC_CLIENT_STATE_CONTEXT_NEGOTIATED;
-       }
-       else
-       {
-               WLog_ERR(TAG, "rpc_check: invalid state: %d", rpc->State);
-               return -1;
-       }
-
-       return 1;
-}
-
-int rpc_connect(rdpRpc* rpc)
-{
-       rpc->TlsIn = rpc->transport->TlsIn;
-       rpc->TlsOut = rpc->transport->TlsOut;
-
-       if (!rts_connect(rpc))
-       {
-               WLog_ERR(TAG, "rts_connect error!");
-               return -1;
-       }
-
-       rpc_client_start(rpc);
-
-       while (rpc->State != RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
-       {
-               if (rpc_check(rpc) < 0)
-               {
-                       WLog_ERR(TAG, "rpc_check failure");
-                       return -1;
-               }
-       }
-
-       return 1;
-}
-
 void rpc_client_virtual_connection_init(rdpRpc* rpc, RpcVirtualConnection* connection)
 {
        connection->DefaultInChannel->State = CLIENT_IN_CHANNEL_STATE_INITIAL;
index 488d82d..3504a64 100644 (file)
@@ -707,16 +707,11 @@ struct rpc_client
        HANDLE StopEvent;
 
        RPC_PDU* pdu;
-       wQueue* ReceivePool;
-       wQueue* ReceiveQueue;
-
        HANDLE PipeEvent;
        RingBuffer ReceivePipe;
        wStream* ReceiveFragment;
        CRITICAL_SECTION PipeLock;
-
        wArrayList* ClientCallList;
-
        BOOL SynchronousReceive;
 };
 typedef struct rpc_client RpcClient;
@@ -769,8 +764,6 @@ struct rdp_rpc
        wArrayList* VirtualConnectionCookieTable;
 };
 
-int rpc_connect(rdpRpc* rpc);
-
 void rpc_pdu_header_print(rpcconn_hdr_t* header);
 void rpc_pdu_header_init(rdpRpc* rpc, rpcconn_hdr_t* header);
 
index 4849199..bedadec 100644 (file)
@@ -29,6 +29,7 @@
 #include <winpr/thread.h>
 #include <winpr/stream.h>
 
+#include "rpc_bind.h"
 #include "rpc_fault.h"
 #include "rpc_client.h"
 #include "../rdp.h"
 
 #define SYNCHRONOUS_TIMEOUT 5000
 
+static void rpc_pdu_reset(RPC_PDU* pdu)
+{
+       pdu->Type = 0;
+       pdu->Flags = 0;
+       pdu->CallId = 0;
+       Stream_SetPosition(pdu->s, 0);
+}
+
 RPC_PDU* rpc_pdu_new()
 {
        RPC_PDU* pdu;
@@ -54,9 +63,7 @@ RPC_PDU* rpc_pdu_new()
                return NULL;
        }
 
-       pdu->Type = 0;
-       pdu->Flags = 0;
-       pdu->CallId = 0;
+       rpc_pdu_reset(pdu);
 
        return pdu;
 }
@@ -70,30 +77,6 @@ static void rpc_pdu_free(RPC_PDU* pdu)
        free(pdu);
 }
 
-RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc)
-{
-       RPC_PDU* pdu = NULL;
-
-       if (WaitForSingleObject(Queue_Event(rpc->client->ReceivePool), 0) == WAIT_OBJECT_0)
-               pdu = Queue_Dequeue(rpc->client->ReceivePool);
-
-       if (!pdu)
-               pdu = rpc_pdu_new();
-
-       pdu->Type = 0;
-       pdu->Flags = 0;
-       pdu->CallId = 0;
-
-       Stream_SetPosition(pdu->s, 0);
-
-       return pdu;
-}
-
-int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu)
-{
-       return Queue_Enqueue(rpc->client->ReceivePool, pdu) == TRUE ? 0 : -1;
-}
-
 int rpc_client_receive_pipe_write(rdpRpc* rpc, const BYTE* buffer, size_t length)
 {
        int status = 0;
@@ -143,7 +126,114 @@ int rpc_client_receive_pipe_read(rdpRpc* rpc, BYTE* buffer, size_t length)
 
 int rpc_client_recv_pdu(rdpRpc* rpc, RPC_PDU* pdu)
 {
-       Queue_Enqueue(rpc->client->ReceiveQueue, pdu);
+       rpcconn_rts_hdr_t* rts;
+       rdpTsg* tsg = rpc->transport->tsg;
+
+       if (rpc->State < RPC_CLIENT_STATE_ESTABLISHED)
+       {
+               switch (rpc->VirtualConnection->State)
+               {
+                       case VIRTUAL_CONNECTION_STATE_INITIAL:
+                               break;
+
+                       case VIRTUAL_CONNECTION_STATE_OUT_CHANNEL_WAIT:
+                               break;
+
+                       case VIRTUAL_CONNECTION_STATE_WAIT_A3W:
+
+                               rts = (rpcconn_rts_hdr_t*) Stream_Buffer(pdu->s);
+
+                               if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_A3_SIGNATURE, rts))
+                               {
+                                       WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/A3");
+                                       return -1;
+                               }
+
+                               rts_recv_CONN_A3_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
+
+                               rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_WAIT_C2;
+                               WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_WAIT_C2");
+
+                               break;
+
+                       case VIRTUAL_CONNECTION_STATE_WAIT_C2:
+
+                               rts = (rpcconn_rts_hdr_t*) Stream_Buffer(pdu->s);
+
+                               if (!rts_match_pdu_signature(rpc, &RTS_PDU_CONN_C2_SIGNATURE, rts))
+                               {
+                                       WLog_ERR(TAG, "unexpected RTS PDU: Expected CONN/C2");
+                                       return FALSE;
+                               }
+
+                               rts_recv_CONN_C2_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s));
+
+                               rpc->VirtualConnection->State = VIRTUAL_CONNECTION_STATE_OPENED;
+                               WLog_DBG(TAG, "VIRTUAL_CONNECTION_STATE_OPENED");
+
+                               rpc->State = RPC_CLIENT_STATE_ESTABLISHED;
+
+                               if (rpc_send_bind_pdu(rpc) < 0)
+                               {
+                                       WLog_ERR(TAG, "rpc_send_bind_pdu failure");
+                                       return -1;
+                               }
+
+                               rpc->State = RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK;
+
+                               break;
+
+                       case VIRTUAL_CONNECTION_STATE_OPENED:
+                               break;
+
+                       case VIRTUAL_CONNECTION_STATE_FINAL:
+                               break;
+               }
+
+               return 1;
+       }
+
+       if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
+       {
+               if (rpc->State ==  RPC_CLIENT_STATE_WAIT_SECURE_BIND_ACK)
+               {
+                       if (rpc_recv_bind_ack_pdu(rpc, Stream_Buffer(pdu->s), Stream_Length(pdu->s)) <= 0)
+                       {
+                               WLog_ERR(TAG, "rpc_recv_bind_ack_pdu failure");
+                               return -1;
+                       }
+
+                       if (rpc_send_rpc_auth_3_pdu(rpc) <= 0)
+                       {
+                               WLog_ERR(TAG, "rpc_secure_bind: error sending rpc_auth_3 pdu!");
+                               return -1;
+                       }
+
+                       rpc->State = RPC_CLIENT_STATE_CONTEXT_NEGOTIATED;
+
+                       if (!TsProxyCreateTunnel(tsg, NULL, NULL, NULL, NULL))
+                       {
+                               WLog_ERR(TAG, "TsProxyCreateTunnel failure");
+                               tsg->state = TSG_STATE_FINAL;
+                               return -1;
+                       }
+
+                       tsg->state = TSG_STATE_INITIAL;
+               }
+               else
+               {
+                       WLog_ERR(TAG, "rpc_client_recv_pdu: invalid rpc->State: %d", rpc->State);
+                       return -1;
+               }
+
+               return 1;
+       }
+
+       if (tsg->state != TSG_STATE_PIPE_CREATED)
+       {
+               return tsg_recv_pdu(tsg, pdu);
+       }
+
        return 1;
 }
 
@@ -156,6 +246,7 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
        RpcClientCall* call;
        rpcconn_hdr_t* header;
 
+       pdu = rpc->client->pdu;
        buffer = (BYTE*) Stream_Buffer(fragment);
        header = (rpcconn_hdr_t*) Stream_Buffer(fragment);
 
@@ -210,11 +301,6 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
 
                if (call->OpNum != TsProxySetupReceivePipeOpnum)
                {
-                       if (!rpc->client->pdu)
-                               rpc->client->pdu = rpc_client_receive_pool_take(rpc);
-
-                       pdu = rpc->client->pdu;
-
                        Stream_EnsureCapacity(pdu->s, header->response.alloc_hint);
                        Stream_Write(pdu->s, &buffer[StubOffset], StubLength);
                        rpc->StubFragCount++;
@@ -226,7 +312,7 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
                                pdu->CallId = rpc->StubCallId;
                                Stream_SealLength(pdu->s);
                                rpc_client_recv_pdu(rpc, pdu);
-                               rpc->client->pdu = NULL;
+                               rpc_pdu_reset(pdu);
                                rpc->StubFragCount = 0;
                                rpc->StubCallId = 0;
                        }
@@ -243,17 +329,12 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
                        }
                }
 
-               return 0;
+               return 1;
        }
        else if (header->common.ptype == PTYPE_RTS)
        {
                if (rpc->State < RPC_CLIENT_STATE_CONTEXT_NEGOTIATED)
                {
-                       if (!rpc->client->pdu)
-                               rpc->client->pdu = rpc_client_receive_pool_take(rpc);
-
-                       pdu = rpc->client->pdu;
-
                        pdu->Flags = 0;
                        pdu->Type = header->common.ptype;
                        pdu->CallId = header->common.call_id;
@@ -261,8 +342,7 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
                        Stream_Write(pdu->s, buffer, Stream_Length(fragment));
                        Stream_SealLength(pdu->s);
                        rpc_client_recv_pdu(rpc, pdu);
-                       rpc->client->pdu = NULL;
-                       return 0;
+                       rpc_pdu_reset(pdu);
                }
                else
                {
@@ -272,14 +352,11 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
                        WLog_DBG(TAG, "Receiving Out-of-Sequence RTS PDU");
                        rts_recv_out_of_sequence_pdu(rpc, buffer, header->common.frag_length);
                }
+
+               return 1;
        }
        else if (header->common.ptype == PTYPE_BIND_ACK)
        {
-               if (!rpc->client->pdu)
-                       rpc->client->pdu = rpc_client_receive_pool_take(rpc);
-
-               pdu = rpc->client->pdu;
-
                pdu->Flags = 0;
                pdu->Type = header->common.ptype;
                pdu->CallId = header->common.call_id;
@@ -287,8 +364,9 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
                Stream_Write(pdu->s, buffer, Stream_Length(fragment));
                Stream_SealLength(pdu->s);
                rpc_client_recv_pdu(rpc, pdu);
-               rpc->client->pdu = NULL;
-               return 0;
+               rpc_pdu_reset(pdu);
+
+               return 1;
        }
        else if (header->common.ptype == PTYPE_FAULT)
        {
@@ -298,9 +376,10 @@ int rpc_client_recv_fragment(rdpRpc* rpc, wStream* fragment)
        else
        {
                WLog_ERR(TAG, "unexpected RPC PDU type 0x%04X", header->common.ptype);
+               return -1;
        }
 
-       return 0;
+       return 1;
 }
 
 int rpc_client_recv(rdpRpc* rpc)
@@ -376,12 +455,9 @@ int rpc_client_recv(rdpRpc* rpc)
 
                        Stream_SetPosition(rpc->client->ReceiveFragment, 0);
                }
-
-               //if (WaitForSingleObject(Queue_Event(rpc->client->SendQueue), 0) == WAIT_OBJECT_0)
-               //      break;
        }
 
-       return 0;
+       return 1;
 }
 
 /**
@@ -466,30 +542,6 @@ int rpc_send_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length)
        return status;
 }
 
-RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc, BOOL blocking)
-{
-       RPC_PDU* pdu;
-       DWORD timeout;
-       DWORD waitStatus;
-
-       timeout = blocking ? SYNCHRONOUS_TIMEOUT * 4 : 0;
-
-       waitStatus = WaitForSingleObject(Queue_Event(rpc->client->ReceiveQueue), timeout);
-
-       if (waitStatus == WAIT_TIMEOUT)
-       {
-               WLog_ERR(TAG, "timed out waiting for receive event");
-               return NULL;
-       }
-
-       if (waitStatus != WAIT_OBJECT_0)
-               return NULL;
-
-       pdu = (RPC_PDU*) Queue_Dequeue(rpc->client->ReceiveQueue);
-
-       return pdu;
-}
-
 static void* rpc_client_thread(void* arg)
 {
        DWORD nCount;
@@ -537,6 +589,26 @@ static void* rpc_client_thread(void* arg)
        return NULL;
 }
 
+int rpc_client_start(rdpRpc* rpc)
+{
+       rpc->client->Thread = CreateThread(NULL, 0,
+               (LPTHREAD_START_ROUTINE) rpc_client_thread, rpc, 0, NULL);
+
+       return 0;
+}
+
+int rpc_client_stop(rdpRpc* rpc)
+{
+       if (rpc->client->Thread)
+       {
+               SetEvent(rpc->client->StopEvent);
+               WaitForSingleObject(rpc->client->Thread, INFINITE);
+               rpc->client->Thread = NULL;
+       }
+
+       return 0;
+}
+
 int rpc_client_new(rdpRpc* rpc)
 {
        RpcClient* client;
@@ -553,22 +625,11 @@ int rpc_client_new(rdpRpc* rpc)
        if (!client->StopEvent)
                return -1;
 
-       client->pdu = NULL;
+       client->pdu = rpc_pdu_new();
 
-       client->ReceivePool = Queue_New(TRUE, -1, -1);
-
-       if (!client->ReceivePool)
+       if (!client->pdu)
                return -1;
 
-       Queue_Object(client->ReceivePool)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
-
-       client->ReceiveQueue = Queue_New(TRUE, -1, -1);
-
-       if (!client->ReceiveQueue)
-               return -1;
-
-       Queue_Object(client->ReceiveQueue)->fnObjectFree = (OBJECT_FREE_FN) rpc_pdu_free;
-
        client->ReceiveFragment = Stream_New(NULL, rpc->max_recv_frag);
 
        if (!client->ReceiveFragment)
@@ -591,35 +652,16 @@ int rpc_client_new(rdpRpc* rpc)
                return -1;
 
        ArrayList_Object(client->ClientCallList)->fnObjectFree = (OBJECT_FREE_FN) rpc_client_call_free;
-       return 0;
-}
 
-int rpc_client_start(rdpRpc* rpc)
-{
-       rpc->client->Thread = CreateThread(NULL, 0,
-               (LPTHREAD_START_ROUTINE) rpc_client_thread, rpc, 0, NULL);
-
-       return 0;
-}
-
-int rpc_client_stop(rdpRpc* rpc)
-{
-       if (rpc->client->Thread)
-       {
-               SetEvent(rpc->client->StopEvent);
-               WaitForSingleObject(rpc->client->Thread, INFINITE);
-               rpc->client->Thread = NULL;
-       }
-
-       return 0;
+       return 1;
 }
 
-int rpc_client_free(rdpRpc* rpc)
+void rpc_client_free(rdpRpc* rpc)
 {
        RpcClient* client = rpc->client;
 
        if (!client)
-               return 0;
+               return;
 
        rpc_client_stop(rpc);
 
@@ -636,12 +678,6 @@ int rpc_client_free(rdpRpc* rpc)
        if (client->pdu)
                rpc_pdu_free(client->pdu);
 
-       if (client->ReceivePool)
-               Queue_Free(client->ReceivePool);
-
-       if (client->ReceiveQueue)
-               Queue_Free(client->ReceiveQueue);
-
        if (client->ClientCallList)
                ArrayList_Free(client->ClientCallList);
 
@@ -653,6 +689,4 @@ int rpc_client_free(rdpRpc* rpc)
 
        free(client);
        rpc->client = NULL;
-
-       return 0;
 }
index edafaa3..fe27ce6 100644 (file)
 
 #include "rpc.h"
 
-#include <winpr/interlocked.h>
-
-wStream* rpc_client_fragment_pool_take(rdpRpc* rpc);
-int rpc_client_fragment_pool_return(rdpRpc* rpc, wStream* fragment);
-
-RPC_PDU* rpc_client_receive_pool_take(rdpRpc* rpc);
-int rpc_client_receive_pool_return(rdpRpc* rpc, RPC_PDU* pdu);
-
 RpcClientCall* rpc_client_call_find_by_id(rdpRpc* rpc, UINT32 CallId);
 
 RpcClientCall* rpc_client_call_new(UINT32 CallId, UINT32 OpNum);
 void rpc_client_call_free(RpcClientCall* client_call);
 
 int rpc_send_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
-int rpc_send_enqueue_pdu(rdpRpc* rpc, BYTE* buffer, UINT32 length);
-int rpc_send_dequeue_pdu(rdpRpc* rpc);
-
-int rpc_recv_enqueue_pdu(rdpRpc* rpc);
-RPC_PDU* rpc_recv_dequeue_pdu(rdpRpc* rpc, BOOL blocking);
-
 int rpc_client_receive_pipe_read(rdpRpc* rpc, BYTE* buffer, size_t length);
 
 int rpc_client_new(rdpRpc* rpc);
 int rpc_client_start(rdpRpc* rpc);
 int rpc_client_stop(rdpRpc* rpc);
-int rpc_client_free(rdpRpc* rpc);
+void rpc_client_free(rdpRpc* rpc);
 
 #endif /* FREERDP_CORE_RPC_CLIENT_H */
index f3ae7be..ff59541 100644 (file)
@@ -237,7 +237,6 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        UINT32 MessageSwitchValue = 0;
        UINT32 IsMessagePresent;
        UINT32 MsgBytes;
-       rdpRpc* rpc = tsg->rpc;
        PTSG_PACKET_CAPABILITIES tsgCaps;
        PTSG_PACKET_VERSIONCAPS versionCaps;
        PTSG_PACKET_CAPS_RESPONSE packetCapsResponse;
@@ -527,13 +526,12 @@ BOOL TsProxyCreateTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                return FALSE;
        }
 
-       rpc_client_receive_pool_return(rpc, pdu);
        free(packet);
        return TRUE;
 }
 
 BOOL TsProxyCreateTunnel(rdpTsg* tsg, PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPacketResponse,
-                                                PTUNNEL_CONTEXT_HANDLE_SERIALIZE* tunnelContext, UINT32* tunnelId)
+                       PTUNNEL_CONTEXT_HANDLE_SERIALIZE* tunnelContext, UINT32* tunnelId)
 {
        /**
         * OpNum = 1
@@ -622,7 +620,6 @@ BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        UINT32 SwitchValue;
        UINT32 idleTimeout;
        PTSG_PACKET packet;
-       rdpRpc* rpc = tsg->rpc;
        PTSG_PACKET_RESPONSE packetResponse;
 
        if (!pdu)
@@ -711,8 +708,6 @@ BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                offset += SizeValue; /* ResponseData */
        }
        
-       rpc_client_receive_pool_return(rpc, pdu);
-       
        free(packetResponse);
        free(packet);
        
@@ -720,7 +715,7 @@ BOOL TsProxyAuthorizeTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
 }
 
 BOOL TsProxyAuthorizeTunnel(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERIALIZE tunnelContext,
-                                                       PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPacketResponse)
+                               PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPacketResponse)
 {
        /**
         * OpNum = 2
@@ -908,7 +903,7 @@ BOOL TsProxyMakeTunnelCallReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
 }
 
 BOOL TsProxyMakeTunnelCall(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERIALIZE tunnelContext,
-                                                  UINT32 procId, PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPacketResponse)
+                       UINT32 procId, PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPacketResponse)
 {
        /**
         * OpNum = 3
@@ -981,7 +976,6 @@ BOOL TsProxyCreateChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        BYTE* buffer;
        UINT32 length;
        UINT32 offset;
-       rdpRpc* rpc = tsg->rpc;
 
        if (!pdu)
                return FALSE;
@@ -997,13 +991,11 @@ BOOL TsProxyCreateChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        CopyMemory(&tsg->ChannelContext.ContextType, &buffer[offset], 4); /* ContextType (4 bytes) */
        CopyMemory(tsg->ChannelContext.ContextUuid, &buffer[offset + 4], 16); /* ContextUuid (16 bytes) */
 
-       rpc_client_receive_pool_return(rpc, pdu);
-
        return TRUE;
 }
 
 BOOL TsProxyCreateChannel(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_NOSERIALIZE tunnelContext, PTSENDPOINTINFO tsEndPointInfo,
-                                                 PCHANNEL_CONTEXT_HANDLE_SERIALIZE* channelContext, UINT32* channelId)
+                                       PCHANNEL_CONTEXT_HANDLE_SERIALIZE* channelContext, UINT32* channelId)
 {
        /**
         * OpNum = 4
@@ -1057,7 +1049,6 @@ BOOL TsProxyCloseChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        BYTE* buffer;
        UINT32 length;
        UINT32 offset;
-       rdpRpc* rpc = tsg->rpc;
 
        if (!pdu)
                return FALSE;
@@ -1069,15 +1060,12 @@ BOOL TsProxyCloseChannelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                buffer = &buffer[24];
 
        offset = 0;
-       rpc_client_receive_pool_return(rpc, pdu);
 
        return TRUE;
 }
 
 HRESULT TsProxyCloseChannel(rdpTsg* tsg, PCHANNEL_CONTEXT_HANDLE_NOSERIALIZE* context)
 {
-       RPC_PDU* pdu = NULL;
-
        /**
         * HRESULT TsProxyCloseChannel(
         * [in, out] PCHANNEL_CONTEXT_HANDLE_NOSERIALIZE* context
@@ -1092,14 +1080,6 @@ HRESULT TsProxyCloseChannel(rdpTsg* tsg, PCHANNEL_CONTEXT_HANDLE_NOSERIALIZE* co
                return FALSE;
        }
 
-       pdu = rpc_recv_dequeue_pdu(tsg->rpc, TRUE);
-
-       if (!TsProxyCloseChannelReadResponse(tsg, pdu))
-       {
-               WLog_ERR(TAG, "error reading response");
-               return FALSE;
-       }
-
        return TRUE;
 }
 
@@ -1133,7 +1113,6 @@ BOOL TsProxyCloseTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
        BYTE* buffer;
        UINT32 length;
        UINT32 offset;
-       rdpRpc* rpc = tsg->rpc;
 
        if (!pdu)
                return FALSE;
@@ -1145,15 +1124,12 @@ BOOL TsProxyCloseTunnelReadResponse(rdpTsg* tsg, RPC_PDU* pdu)
                buffer = &buffer[24];
 
        offset = 0;
-       rpc_client_receive_pool_return(rpc, pdu);
 
        return TRUE;
 }
 
 HRESULT TsProxyCloseTunnel(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_SERIALIZE* context)
 {
-       RPC_PDU* pdu = NULL;
-
        /**
         * HRESULT TsProxyCloseTunnel(
         * [in, out] PTUNNEL_CONTEXT_HANDLE_SERIALIZE* context
@@ -1168,14 +1144,6 @@ HRESULT TsProxyCloseTunnel(rdpTsg* tsg, PTUNNEL_CONTEXT_HANDLE_SERIALIZE* contex
                return FALSE;
        }
 
-       pdu = rpc_recv_dequeue_pdu(tsg->rpc, TRUE);
-
-       if (!TsProxyCloseTunnelReadResponse(tsg, pdu))
-       {
-               WLog_ERR(TAG, "error reading response");
-               return FALSE;
-       }
-
        return TRUE;
 }
 
@@ -1232,9 +1200,8 @@ BOOL TsProxySetupReceivePipe(handle_t IDL_handle, BYTE* pRpcMessage)
        return TRUE;
 }
 
-int tsg_check(rdpTsg* tsg)
+int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu)
 {
-       RPC_PDU* pdu;
        RpcClientCall* call;
        rdpRpc* rpc = tsg->rpc;
 
@@ -1242,15 +1209,6 @@ int tsg_check(rdpTsg* tsg)
        {
                case TSG_STATE_INITIAL:
 
-                       if (!TsProxyCreateTunnel(tsg, NULL, NULL, NULL, NULL))
-                       {
-                               WLog_ERR(TAG, "TsProxyCreateTunnel failure");
-                               tsg->state = TSG_STATE_FINAL;
-                               return -1;
-                       }
-
-                       pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
-
                        if (!TsProxyCreateTunnelReadResponse(tsg, pdu))
                        {
                                WLog_ERR(TAG, "TsProxyCreateTunnelReadResponse failure");
@@ -1259,10 +1217,6 @@ int tsg_check(rdpTsg* tsg)
 
                        tsg->state = TSG_STATE_CONNECTED;
 
-                       break;
-
-               case TSG_STATE_CONNECTED:
-
                        if (!TsProxyAuthorizeTunnel(tsg, &tsg->TunnelContext, NULL, NULL))
                        {
                                WLog_ERR(TAG, "TsProxyAuthorizeTunnel failure");
@@ -1270,7 +1224,9 @@ int tsg_check(rdpTsg* tsg)
                                return -1;
                        }
 
-                       pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
+                       break;
+
+               case TSG_STATE_CONNECTED:
 
                        if (!TsProxyAuthorizeTunnelReadResponse(tsg, pdu))
                        {
@@ -1280,10 +1236,6 @@ int tsg_check(rdpTsg* tsg)
 
                        tsg->state = TSG_STATE_AUTHORIZED;
 
-                       break;
-
-               case TSG_STATE_AUTHORIZED:
-
                        if (!TsProxyMakeTunnelCall(tsg, &tsg->TunnelContext, TSG_TUNNEL_CALL_ASYNC_MSG_REQUEST, NULL, NULL))
                        {
                                WLog_ERR(TAG, "TsProxyMakeTunnelCall failure");
@@ -1296,13 +1248,9 @@ int tsg_check(rdpTsg* tsg)
                                return -1;
                        }
 
-                       pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
+                       break;
 
-                       if (!pdu)
-                       {
-                               WLog_ERR(TAG, "rpc_recv_dequeue_pdu failure");
-                               return -1;
-                       }
+               case TSG_STATE_AUTHORIZED:
 
                        call = rpc_client_call_find_by_id(rpc, pdu->CallId);
 
@@ -1313,39 +1261,72 @@ int tsg_check(rdpTsg* tsg)
                                        WLog_ERR(TAG, "TsProxyMakeTunnelCallReadResponse failure");
                                        return -1;
                                }
-
-                               pdu = rpc_recv_dequeue_pdu(rpc, TRUE);
                        }
-
-                       if (!TsProxyCreateChannelReadResponse(tsg, pdu))
+                       else if (call->OpNum == TsProxyCreateChannelOpnum)
                        {
-                               WLog_ERR(TAG, "TsProxyCreateChannelReadResponse failure");
-                               return -1;
-                       }
-
-                       tsg->state = TSG_STATE_CHANNEL_CREATED;
+                               if (!TsProxyCreateChannelReadResponse(tsg, pdu))
+                               {
+                                       WLog_ERR(TAG, "TsProxyCreateChannelReadResponse failure");
+                                       return -1;
+                               }
 
-                       break;
+                               tsg->state = TSG_STATE_CHANNEL_CREATED;
 
-               case TSG_STATE_CHANNEL_CREATED:
+                               if (!TsProxySetupReceivePipe((handle_t) tsg, NULL))
+                               {
+                                       WLog_ERR(TAG, "TsProxySetupReceivePipe failure");
+                                       return -1;
+                               }
 
-                       if (!TsProxySetupReceivePipe((handle_t) tsg, NULL))
+                               tsg->state = TSG_STATE_PIPE_CREATED;
+                       }
+                       else
                        {
-                               WLog_ERR(TAG, "TsProxySetupReceivePipe failure");
-                               return -1;
+                               WLog_ERR(TAG, "TSG_STATE_AUTHORIZED unexpected OpNum: %d\n", call->OpNum);
                        }
 
-                       tsg->state = TSG_STATE_PIPE_CREATED;
+                       break;
 
+               case TSG_STATE_CHANNEL_CREATED:
                        break;
 
                case TSG_STATE_PIPE_CREATED:
                        break;
 
                case TSG_STATE_TUNNEL_CLOSE_PENDING:
+
+                       if (!TsProxyCloseChannelReadResponse(tsg, pdu))
+                       {
+                               WLog_ERR(TAG, "TsProxyCloseChannelReadResponse failure");
+                               return FALSE;
+                       }
+
+                       tsg->state = TSG_STATE_CHANNEL_CLOSE_PENDING;
+
+                       if (!TsProxyCloseChannelWriteRequest(tsg, NULL))
+                       {
+                               WLog_ERR(TAG, "TsProxyCloseChannelWriteRequest failure");
+                               return FALSE;
+                       }
+
+                       if (!TsProxyMakeTunnelCall(tsg, &tsg->TunnelContext, TSG_TUNNEL_CANCEL_ASYNC_MSG_REQUEST, NULL, NULL))
+                       {
+                               WLog_ERR(TAG, "TsProxyMakeTunnelCall failure");
+                               return FALSE;
+                       }
+
                        break;
 
                case TSG_STATE_CHANNEL_CLOSE_PENDING:
+
+                       if (!TsProxyCloseTunnelReadResponse(tsg, pdu))
+                       {
+                               WLog_ERR(TAG, "TsProxyCloseTunnelReadResponse failure");
+                               return FALSE;
+                       }
+
+                       tsg->state = TSG_STATE_FINAL;
+
                        break;
 
                case TSG_STATE_FINAL:
@@ -1362,6 +1343,9 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port)
 
        tsg->Port = port;
 
+       rpc->TlsIn = rpc->transport->TlsIn;
+       rpc->TlsOut = rpc->transport->TlsOut;
+
        free(tsg->Hostname);
        tsg->Hostname = NULL;
        ConvertToUnicode(CP_UTF8, 0, hostname, -1, &tsg->Hostname, 0);
@@ -1370,15 +1354,17 @@ BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port)
        tsg->MachineName = NULL;
        ConvertToUnicode(CP_UTF8, 0, settings->ComputerName, -1, &tsg->MachineName, 0);
 
-       tsg->state = TSG_STATE_INITIAL;
+       if (!rts_connect(rpc))
+       {
+               WLog_ERR(TAG, "rts_connect error!");
+               return -1;
+       }
+
+       rpc_client_start(rpc);
 
        while (tsg->state != TSG_STATE_PIPE_CREATED)
        {
-               if (tsg_check(tsg) < 0)
-               {
-                       WLog_ERR(TAG, "tsg_connect error");
-                       return FALSE;
-               }
+               USleep(100);
        }
 
        tsg->bio = BIO_new(BIO_s_tsg());
@@ -1414,22 +1400,13 @@ BOOL tsg_disconnect(rdpTsg* tsg)
        if (!tsg)
                return FALSE;
 
-       tsg->rpc->client->SynchronousReceive = TRUE;
-
-       /* if we are already in state pending (i.e. if a server initiated disconnect was issued)
-          we have to skip TsProxyCloseChannel - see Figure 13 in section 3.2.3
-        */
        if (tsg->state != TSG_STATE_TUNNEL_CLOSE_PENDING)
        {
                if (!TsProxyCloseChannel(tsg, NULL))
                        return FALSE;
-       }
-
-       if (!TsProxyMakeTunnelCall(tsg, &tsg->TunnelContext, TSG_TUNNEL_CANCEL_ASYNC_MSG_REQUEST, NULL, NULL))
-               return FALSE;
 
-       if (!TsProxyCloseTunnel(tsg, NULL))
-               return FALSE;
+               tsg->state = TSG_STATE_CHANNEL_CLOSE_PENDING;
+       }
 
        return TRUE;
 }
index fa43e92..38da142 100644 (file)
@@ -299,6 +299,9 @@ typedef struct _TSG_PACKET
        TSG_PACKET_TYPE_UNION tsgPacket;
 } TSG_PACKET, *PTSG_PACKET;
 
+BOOL TsProxyCreateTunnel(rdpTsg* tsg, PTSG_PACKET tsgPacket, PTSG_PACKET* tsgPacketResponse,
+                       PTUNNEL_CONTEXT_HANDLE_SERIALIZE* tunnelContext, UINT32* tunnelId);
+
 DWORD TsProxySendToServer(handle_t IDL_handle, BYTE pRpcMessage[], UINT32 count, UINT32* lengths);
 
 BOOL tsg_connect(rdpTsg* tsg, const char* hostname, UINT16 port);
@@ -307,6 +310,8 @@ BOOL tsg_disconnect(rdpTsg* tsg);
 int tsg_write(rdpTsg* tsg, BYTE* data, UINT32 length);
 int tsg_read(rdpTsg* tsg, BYTE* data, UINT32 length);
 
+int tsg_recv_pdu(rdpTsg* tsg, RPC_PDU* pdu);
+
 BOOL tsg_set_blocking_mode(rdpTsg* tsg, BOOL blocking);
 
 rdpTsg* tsg_new(rdpTransport* transport);
index 2e66962..9638552 100644 (file)
@@ -22,9 +22,6 @@
 #endif
 
 #include <assert.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
 
 #include <winpr/crt.h>
 #include <winpr/synch.h>
@@ -351,14 +348,6 @@ BOOL transport_tsg_connect(rdpTransport* transport, const char* hostname, UINT16
                return FALSE;
        }
 
-       if (rpc_connect(tsg->rpc) < 0)
-       {
-               WLog_ERR(TAG, "rpc_connect failed!");
-               return FALSE;
-       }
-
-       WLog_DBG(TAG, "rpc_connect success");
-
        if (!tsg_connect(tsg, hostname, port))
                return FALSE;