Implement chunked encoding for RDG_OUT_DATA
authorakarl10 <mike@mwsys.mine.bz>
Fri, 22 Jan 2021 09:53:00 +0000 (10:53 +0100)
committerakallabeth <akallabeth@users.noreply.github.com>
Thu, 25 Feb 2021 08:51:41 +0000 (09:51 +0100)
(cherry picked from commit cfb1a878a3b1c2c469c3d15e135f21f388c177b8)

libfreerdp/core/gateway/http.c
libfreerdp/core/gateway/http.h
libfreerdp/core/gateway/rdg.c

index 8481010..85f0908 100644 (file)
@@ -63,7 +63,7 @@ struct _http_request
        char* Authorization;
        size_t ContentLength;
        char* Content;
-       char* TransferEncoding;
+       TRANSFER_ENCODING TransferEncoding;
 };
 
 struct _http_response
@@ -76,6 +76,7 @@ struct _http_response
 
        size_t ContentLength;
        const char* ContentType;
+       TRANSFER_ENCODING TransferEncoding;
 
        size_t BodyLength;
        BYTE* BodyContent;
@@ -342,16 +343,12 @@ BOOL http_request_set_auth_param(HttpRequest* request, const char* AuthParam)
        return TRUE;
 }
 
-BOOL http_request_set_transfer_encoding(HttpRequest* request, const char* TransferEncoding)
+BOOL http_request_set_transfer_encoding(HttpRequest* request, TRANSFER_ENCODING TransferEncoding)
 {
-       if (!request || !TransferEncoding)
+       if (!request || TransferEncoding == TransferEncodingUnknown)
                return FALSE;
 
-       free(request->TransferEncoding);
-       request->TransferEncoding = _strdup(TransferEncoding);
-
-       if (!request->TransferEncoding)
-               return FALSE;
+       request->TransferEncoding = TransferEncoding;
 
        return TRUE;
 }
@@ -448,9 +445,14 @@ wStream* http_request_write(HttpContext* context, HttpRequest* request)
                        goto fail;
        }
 
-       if (request->TransferEncoding)
+       if (request->TransferEncoding != TransferEncodingIdentity)
        {
-               if (!http_encode_body_line(s, "Transfer-Encoding", request->TransferEncoding))
+               if (request->TransferEncoding == TransferEncodingChunked)
+               {
+                       if (!http_encode_body_line(s, "Transfer-Encoding", "chunked"))
+                               goto fail;
+               }
+               else
                        goto fail;
        }
        else
@@ -480,7 +482,12 @@ fail:
 
 HttpRequest* http_request_new(void)
 {
-       return (HttpRequest*)calloc(1, sizeof(HttpRequest));
+       HttpRequest* request = (HttpRequest*)calloc(1, sizeof(HttpRequest));
+       if (!request)
+               return NULL;
+
+       request->TransferEncoding = TransferEncodingIdentity;
+       return request;
 }
 
 void http_request_free(HttpRequest* request)
@@ -494,7 +501,6 @@ void http_request_free(HttpRequest* request)
        free(request->Content);
        free(request->Method);
        free(request->URI);
-       free(request->TransferEncoding);
        free(request);
 }
 
@@ -572,6 +578,15 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
                if (!response->ContentType)
                        return FALSE;
        }
+       else if (_stricmp(name, "Transfer-Encoding") == 0)
+       {
+               if (_stricmp(value, "identity") == 0)
+                       response->TransferEncoding = TransferEncodingIdentity;
+               else if (_stricmp(value, "chunked") == 0)
+                       response->TransferEncoding = TransferEncodingChunked;
+               else
+                       response->TransferEncoding = TransferEncodingUnknown;
+       }
        else if (_stricmp(name, "WWW-Authenticate") == 0)
        {
                char* separator = NULL;
@@ -948,6 +963,8 @@ HttpResponse* http_response_new(void)
 
        ListDictionary_KeyObject(response->Authenticates)->fnObjectEquals = strings_equals_nocase;
        ListDictionary_ValueObject(response->Authenticates)->fnObjectEquals = strings_equals_nocase;
+
+       response->TransferEncoding = TransferEncodingIdentity;
        return response;
 fail:
        http_response_free(response);
@@ -1006,13 +1023,21 @@ SSIZE_T http_response_get_body_length(HttpResponse* response)
        return (SSIZE_T)response->BodyLength;
 }
 
-const char* http_response_get_auth_token(HttpResponse* respone, const char* method)
+const char* http_response_get_auth_token(HttpResponse* response, const char* method)
 {
-       if (!respone || !method)
+       if (!response || !method)
                return NULL;
 
-       if (!ListDictionary_Contains(respone->Authenticates, method))
+       if (!ListDictionary_Contains(response->Authenticates, method))
                return NULL;
 
-       return ListDictionary_GetItemValue(respone->Authenticates, method);
+       return ListDictionary_GetItemValue(response->Authenticates, method);
+}
+
+TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response)
+{
+       if (!response)
+               return TransferEncodingUnknown;
+
+       return response->TransferEncoding;
 }
index 6c5756b..af1be0f 100644 (file)
 #include <freerdp/api.h>
 #include <freerdp/crypto/tls.h>
 
+typedef enum _TRANSFER_ENCODING
+{
+       TransferEncodingUnknown,
+       TransferEncodingIdentity,
+       TransferEncodingChunked
+} TRANSFER_ENCODING;
+
 /* HTTP context */
 typedef struct _http_context HttpContext;
 
@@ -61,7 +68,7 @@ FREERDP_LOCAL BOOL http_request_set_uri(HttpRequest* request, const char* URI);
 FREERDP_LOCAL BOOL http_request_set_auth_scheme(HttpRequest* request, const char* AuthScheme);
 FREERDP_LOCAL BOOL http_request_set_auth_param(HttpRequest* request, const char* AuthParam);
 FREERDP_LOCAL BOOL http_request_set_transfer_encoding(HttpRequest* request,
-                                                      const char* TransferEncoding);
+                                                      TRANSFER_ENCODING TransferEncoding);
 
 FREERDP_LOCAL wStream* http_request_write(HttpContext* context, HttpRequest* request);
 
@@ -76,6 +83,7 @@ FREERDP_LOCAL HttpResponse* http_response_recv(rdpTls* tls, BOOL readContentLeng
 
 FREERDP_LOCAL long http_response_get_status_code(HttpResponse* response);
 FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response);
-FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* respone, const char* method);
+FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* response, const char* method);
+FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response);
 
 #endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */
index 6400d6a..9dd4fc4 100644 (file)
 #define HTTP_CAPABILITY_REAUTH 0x10
 #define HTTP_CAPABILITY_UDP_TRANSPORT 0x20
 
+typedef enum _CHUNK_STATE
+{
+       ChunkStateLenghHeader,
+       ChunkStateData,
+       ChunkStateFooter
+} CHUNK_STATE;
+
+typedef struct
+{
+       size_t nextOffset;
+       size_t headerFooterPos;
+       CHUNK_STATE state;
+       char* lenBuffer;
+} rdg_http_encoding_chunked_context;
+
+typedef struct
+{
+       TRANSFER_ENCODING httpTransferEncoding;
+       union _context
+       {
+               rdg_http_encoding_chunked_context* chunked;
+       } context;
+} rdg_http_encoding_context;
+
 struct rdp_rdg
 {
        rdpContext* context;
@@ -124,6 +148,7 @@ struct rdp_rdg
        int timeout;
        UINT16 extAuth;
        UINT16 reserved2;
+       rdg_http_encoding_context* transferEncoding;
 };
 
 enum
@@ -298,7 +323,131 @@ static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
        return TRUE;
 }
 
-static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, int size)
+static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
+                             rdg_http_encoding_context* encodingContext)
+{
+       int status;
+       int effectiveDataLen = 0;
+       assert(encodingContext != NULL);
+       while (TRUE)
+       {
+               switch (encodingContext->context.chunked->state)
+               {
+                       case ChunkStateData:
+                       {
+                               status = BIO_read(bio, pBuffer,
+                                                 (size > encodingContext->context.chunked->nextOffset
+                                                      ? encodingContext->context.chunked->nextOffset
+                                                      : size));
+                               if (status <= 0)
+                                       return (effectiveDataLen > 0 ? effectiveDataLen : status);
+
+                               encodingContext->context.chunked->nextOffset -= status;
+                               if (encodingContext->context.chunked->nextOffset == 0)
+                               {
+                                       encodingContext->context.chunked->state = ChunkStateFooter;
+                                       encodingContext->context.chunked->headerFooterPos = 0;
+                               }
+                               effectiveDataLen += status;
+
+                               if (status == size)
+                                       return effectiveDataLen;
+
+                               pBuffer += status;
+                               size -= status;
+                       }
+                       break;
+                       case ChunkStateFooter:
+                       {
+                               char _dummy[2];
+                               assert(encodingContext->context.chunked->nextOffset == 0);
+                               assert(encodingContext->context.chunked->headerFooterPos < 2);
+                               status =
+                                   BIO_read(bio, _dummy, 2 - encodingContext->context.chunked->headerFooterPos);
+                               if (status >= 0)
+                               {
+                                       encodingContext->context.chunked->headerFooterPos += status;
+                                       if (encodingContext->context.chunked->headerFooterPos == 2)
+                                       {
+                                               encodingContext->context.chunked->state = ChunkStateLenghHeader;
+                                               encodingContext->context.chunked->headerFooterPos = 0;
+                                       }
+                               }
+                               else
+                                       return (effectiveDataLen > 0 ? effectiveDataLen : status);
+                       }
+                       break;
+                       case ChunkStateLenghHeader:
+                       {
+                               BOOL _haveNewLine = FALSE;
+                               size_t tmp;
+                               assert(encodingContext->context.chunked->nextOffset == 0);
+                               while (encodingContext->context.chunked->headerFooterPos < 10 && !_haveNewLine)
+                               {
+                                       status = BIO_read(bio,
+                                                         encodingContext->context.chunked->lenBuffer +
+                                                             encodingContext->context.chunked->headerFooterPos,
+                                                         1);
+                                       if (status >= 0)
+                                       {
+                                               if (encodingContext->context.chunked
+                                                       ->lenBuffer[encodingContext->context.chunked->headerFooterPos] ==
+                                                   '\n')
+                                                       _haveNewLine = TRUE;
+                                               encodingContext->context.chunked->headerFooterPos += status;
+                                       }
+                                       else
+                                               return (effectiveDataLen > 0 ? effectiveDataLen : status);
+                               }
+                               encodingContext->context.chunked
+                                   ->lenBuffer[encodingContext->context.chunked->headerFooterPos] = '\0';
+                               /* strtoul is tricky, error are reported via errno, we also need
+                                * to ensure the result does not overflow */
+                               errno = 0;
+                               tmp = strtoul(encodingContext->context.chunked->lenBuffer, NULL, 16);
+                               if ((errno != 0) || (tmp > SIZE_MAX))
+                                       return -1;
+                               encodingContext->context.chunked->nextOffset = tmp;
+                               encodingContext->context.chunked->state = ChunkStateData;
+
+                               if (encodingContext->context.chunked->nextOffset == 0)
+                               { // end of stream
+                                       int fd = BIO_get_fd(bio, NULL);
+                                       if (fd >= 0)
+                                               close(fd);
+
+                                       WLog_WARN(TAG, "cunked encoding end of stream received");
+                                       encodingContext->context.chunked->headerFooterPos = 0;
+                                       encodingContext->context.chunked->state = ChunkStateFooter;
+                               }
+                       }
+                       break;
+                       default:
+                               /* invalid state */
+                               return -1;
+               }
+       }
+       return -1;
+}
+
+static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size,
+                           rdg_http_encoding_context* encodingContext)
+{
+       assert(encodingContext != NULL);
+       switch (encodingContext->httpTransferEncoding)
+       {
+               case TransferEncodingIdentity:
+                       return BIO_read(bio, pBuffer, size);
+               case TransferEncodingChunked:
+                       return rdg_chuncked_read(bio, pBuffer, size, encodingContext);
+               default:
+                       return -1;
+       }
+       return -1; /* should not be reached */
+}
+
+static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, size_t size,
+                         rdg_http_encoding_context* transferEncoding)
 {
        int status;
        int readCount = 0;
@@ -306,13 +455,14 @@ static BOOL rdg_read_all(rdpTls* tls, BYTE* buffer, int size)
 
        while (readCount < size)
        {
-               status = BIO_read(tls->bio, pBuffer, size - readCount);
+               status = rdg_socket_read(tls->bio, pBuffer, size - readCount, transferEncoding);
 
                if (status <= 0)
                {
                        if (!BIO_should_retry(tls->bio))
                                return FALSE;
 
+                       Sleep(10);
                        continue;
                }
 
@@ -334,7 +484,7 @@ static wStream* rdg_receive_packet(rdpRdg* rdg)
        if (!s)
                return NULL;
 
-       if (!rdg_read_all(rdg->tlsOut, Stream_Buffer(s), header))
+       if (!rdg_read_all(rdg->tlsOut, Stream_Buffer(s), header, rdg->transferEncoding))
        {
                Stream_Free(s, TRUE);
                return NULL;
@@ -350,7 +500,8 @@ static wStream* rdg_receive_packet(rdpRdg* rdg)
                return NULL;
        }
 
-       if (!rdg_read_all(rdg->tlsOut, Stream_Buffer(s) + header, (int)packetLength - (int)header))
+       if (!rdg_read_all(rdg->tlsOut, Stream_Buffer(s) + header, (int)packetLength - (int)header,
+                         rdg->transferEncoding))
        {
                Stream_Free(s, TRUE);
                return NULL;
@@ -557,7 +708,7 @@ static BOOL rdg_set_ntlm_auth_header(rdpNtlm* ntlm, HttpRequest* request)
 }
 
 static wStream* rdg_build_http_request(rdpRdg* rdg, const char* method,
-                                       const char* transferEncoding)
+                                       TRANSFER_ENCODING transferEncoding)
 {
        wStream* s = NULL;
        HttpRequest* request = NULL;
@@ -649,7 +800,8 @@ static BOOL rdg_handle_ntlm_challenge(rdpNtlm* ntlm, HttpResponse* response)
        return TRUE;
 }
 
-static BOOL rdg_skip_seed_payload(rdpTls* tls, SSIZE_T lastResponseLength)
+static BOOL rdg_skip_seed_payload(rdpTls* tls, SSIZE_T lastResponseLength,
+                                  rdg_http_encoding_context* transferEncoding)
 {
        BYTE seed_payload[10];
        const size_t size = sizeof(seed_payload);
@@ -661,7 +813,7 @@ static BOOL rdg_skip_seed_payload(rdpTls* tls, SSIZE_T lastResponseLength)
         */
        if (lastResponseLength < (SSIZE_T)size)
        {
-               if (!rdg_read_all(tls, seed_payload, size - lastResponseLength))
+               if (!rdg_read_all(tls, seed_payload, size - lastResponseLength, transferEncoding))
                {
                        return FALSE;
                }
@@ -1062,7 +1214,7 @@ static BOOL rdg_ntlm_init(rdpRdg* rdg, rdpTls* tls)
 }
 
 static BOOL rdg_send_http_request(rdpRdg* rdg, rdpTls* tls, const char* method,
-                                  const char* transferEncoding)
+                                  TRANSFER_ENCODING transferEncoding)
 {
        size_t sz;
        wStream* s = NULL;
@@ -1178,7 +1330,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
                if (!rdg_ntlm_init(rdg, tls))
                        return FALSE;
 
-               if (!rdg_send_http_request(rdg, tls, method, NULL))
+               if (!rdg_send_http_request(rdg, tls, method, TransferEncodingIdentity))
                        return FALSE;
 
                response = http_response_recv(tls, TRUE);
@@ -1213,7 +1365,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
                http_response_free(response);
        }
 
-       if (!rdg_send_http_request(rdg, tls, method, NULL))
+       if (!rdg_send_http_request(rdg, tls, method, TransferEncodingIdentity))
                return FALSE;
 
        ntlm_free(rdg->ntlm);
@@ -1241,12 +1393,24 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
 
        if (strcmp(method, "RDG_OUT_DATA") == 0)
        {
-               if (!rdg_skip_seed_payload(tls, bodyLength))
+               if (http_response_get_transfer_encoding(response) == TransferEncodingChunked)
+               {
+                       rdg->transferEncoding->httpTransferEncoding = TransferEncodingChunked;
+                       rdg->transferEncoding->context.chunked = (rdg_http_encoding_chunked_context*)calloc(
+                           1, sizeof(rdg_http_encoding_chunked_context));
+                       rdg->transferEncoding->context.chunked->nextOffset = 0;
+                       rdg->transferEncoding->context.chunked->headerFooterPos = 0;
+                       rdg->transferEncoding->context.chunked->state = ChunkStateLenghHeader;
+                       rdg->transferEncoding->context.chunked->lenBuffer = (char*)calloc(11, sizeof(char));
+               }
+               if (!rdg_skip_seed_payload(tls, bodyLength, rdg->transferEncoding))
+               {
                        return FALSE;
+               }
        }
        else
        {
-               if (!rdg_send_http_request(rdg, tls, method, "chunked"))
+               if (!rdg_send_http_request(rdg, tls, method, TransferEncodingChunked))
                        return FALSE;
        }
 
@@ -1471,8 +1635,8 @@ static BOOL rdg_process_control_packet(rdpRdg* rdg, int type, size_t packetLengt
 
                while (readCount < payloadSize)
                {
-                       status =
-                           BIO_read(rdg->tlsOut->bio, Stream_Pointer(s), (int)payloadSize - (int)readCount);
+                       status = rdg_socket_read(rdg->tlsOut->bio, Stream_Pointer(s), payloadSize - readCount,
+                                                rdg->transferEncoding);
 
                        if (status <= 0)
                        {
@@ -1535,7 +1699,7 @@ static int rdg_read_data_packet(rdpRdg* rdg, BYTE* buffer, int size)
 {
        RdgPacketHeader header;
        size_t readCount = 0;
-       int readSize;
+       size_t readSize;
        int status;
 
        if (!rdg->packetRemainingCount)
@@ -1544,8 +1708,9 @@ static int rdg_read_data_packet(rdpRdg* rdg, BYTE* buffer, int size)
 
                while (readCount < sizeof(RdgPacketHeader))
                {
-                       status = BIO_read(rdg->tlsOut->bio, (BYTE*)(&header) + readCount,
-                                         (int)sizeof(RdgPacketHeader) - (int)readCount);
+                       status = rdg_socket_read(rdg->tlsOut->bio, (BYTE*)(&header) + readCount,
+                                                (int)sizeof(RdgPacketHeader) - (int)readCount,
+                                                rdg->transferEncoding);
 
                        if (status <= 0)
                        {
@@ -1579,8 +1744,9 @@ static int rdg_read_data_packet(rdpRdg* rdg, BYTE* buffer, int size)
 
                while (readCount < 2)
                {
-                       status = BIO_read(rdg->tlsOut->bio, (BYTE*)(&rdg->packetRemainingCount) + readCount,
-                                         2 - (int)readCount);
+                       status =
+                           rdg_socket_read(rdg->tlsOut->bio, (BYTE*)(&rdg->packetRemainingCount) + readCount,
+                                           2 - (int)readCount, rdg->transferEncoding);
 
                        if (status < 0)
                        {
@@ -1595,8 +1761,8 @@ static int rdg_read_data_packet(rdpRdg* rdg, BYTE* buffer, int size)
                }
        }
 
-       readSize = (rdg->packetRemainingCount < size ? rdg->packetRemainingCount : size);
-       status = BIO_read(rdg->tlsOut->bio, buffer, readSize);
+       readSize = (rdg->packetRemainingCount < size) ? rdg->packetRemainingCount : size;
+       status = rdg_socket_read(rdg->tlsOut->bio, buffer, readSize, rdg->transferEncoding);
 
        if (status <= 0)
        {
@@ -1860,6 +2026,11 @@ rdpRdg* rdg_new(rdpContext* context)
 
                BIO_set_data(rdg->frontBio, rdg);
                InitializeCriticalSection(&rdg->writeSection);
+
+               rdg->transferEncoding =
+                   (rdg_http_encoding_context*)calloc(1, sizeof(rdg_http_encoding_context));
+
+               rdg->transferEncoding->httpTransferEncoding = TransferEncodingIdentity;
        }
 
        return rdg;
@@ -1882,6 +2053,14 @@ void rdg_free(rdpRdg* rdg)
                BIO_free_all(rdg->frontBio);
 
        DeleteCriticalSection(&rdg->writeSection);
+
+       if (rdg->transferEncoding->httpTransferEncoding == TransferEncodingChunked)
+       {
+               free(rdg->transferEncoding->context.chunked->lenBuffer);
+               free(rdg->transferEncoding->context.chunked);
+       }
+       free(rdg->transferEncoding);
+
        free(rdg);
 }