rdg websocket support
authorMichael Saxl <mike@mwsys.mine.bz>
Mon, 25 Jan 2021 15:20:18 +0000 (16:20 +0100)
committerakallabeth <akallabeth@users.noreply.github.com>
Thu, 25 Feb 2021 08:51:41 +0000 (09:51 +0100)
(cherry picked from commit 99489318f4832dbeca2130ae2c45f26c4e9a85bd)

libfreerdp/core/gateway/http.c
libfreerdp/core/gateway/http.h
libfreerdp/core/gateway/rdg.c

index 85f0908..91d1962 100644 (file)
@@ -30,6 +30,9 @@
 
 #include <freerdp/log.h>
 
+/* websocket need sha1 for Sec-Websocket-Accept */
+#include <winpr/crypto.h>
+
 #ifdef HAVE_VALGRIND_MEMCHECK_H
 #include <valgrind/memcheck.h>
 #endif
@@ -40,6 +43,8 @@
 
 #define RESPONSE_SIZE_LIMIT 64 * 1024 * 1024
 
+#define WEBSOCKET_MAGIC_GUID "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
+
 struct _http_context
 {
        char* Method;
@@ -52,6 +57,8 @@ struct _http_context
        char* Pragma;
        char* RdgConnectionId;
        char* RdgAuthScheme;
+       BOOL websocketUpgrade;
+       char SecWebsocketKey[16];
 };
 
 struct _http_request
@@ -77,6 +84,8 @@ struct _http_response
        size_t ContentLength;
        const char* ContentType;
        TRANSFER_ENCODING TransferEncoding;
+       const char* SecWebsocketVersion;
+       const char* SecWebsocketAccept;
 
        size_t BodyLength;
        BYTE* BodyContent;
@@ -259,6 +268,30 @@ BOOL http_context_set_rdg_connection_id(HttpContext* context, const char* RdgCon
        return TRUE;
 }
 
+BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable)
+{
+       if (!context)
+               return FALSE;
+
+       if (enable)
+       {
+               int i;
+               winpr_RAND((BYTE*)context->SecWebsocketKey, 15);
+               for (i = 0; i < 16; i++)
+                       context->SecWebsocketKey[i] = (context->SecWebsocketKey[i] | 0x40) & 0x5f;
+               context->SecWebsocketKey[15] = '\0';
+       }
+       else
+               context->SecWebsocketKey[0] = '\0';
+       context->websocketUpgrade = enable;
+       return TRUE;
+}
+
+BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context)
+{
+       return context->websocketUpgrade;
+}
+
 BOOL http_context_set_rdg_auth_scheme(HttpContext* context, const char* RdgAuthScheme)
 {
        if (!context || !RdgAuthScheme)
@@ -426,13 +459,26 @@ wStream* http_request_write(HttpContext* context, HttpRequest* request)
 
        if (!http_encode_header_line(s, request->Method, request->URI) ||
            !http_encode_body_line(s, "Cache-Control", context->CacheControl) ||
-           !http_encode_body_line(s, "Connection", context->Connection) ||
            !http_encode_body_line(s, "Pragma", context->Pragma) ||
            !http_encode_body_line(s, "Accept", context->Accept) ||
            !http_encode_body_line(s, "User-Agent", context->UserAgent) ||
            !http_encode_body_line(s, "Host", context->Host))
                goto fail;
 
+       if (!context->websocketUpgrade)
+       {
+               if (!http_encode_body_line(s, "Connection", context->Connection))
+                       goto fail;
+       }
+       else
+       {
+               if (!http_encode_body_line(s, "Connection", "Upgrade") ||
+                   !http_encode_body_line(s, "Upgrade", "websocket") ||
+                   !http_encode_body_line(s, "Sec-Websocket-Version", "13") ||
+                   !http_encode_body_line(s, "Sec-Websocket-Key", context->SecWebsocketKey))
+                       goto fail;
+       }
+
        if (context->RdgConnectionId)
        {
                if (!http_encode_body_line(s, "RDG-Connection-Id", context->RdgConnectionId))
@@ -556,7 +602,6 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
                                              const char* value)
 {
        BOOL status = TRUE;
-
        if (!response || !name)
                return FALSE;
 
@@ -587,6 +632,20 @@ static BOOL http_response_parse_header_field(HttpResponse* response, const char*
                else
                        response->TransferEncoding = TransferEncodingUnknown;
        }
+       else if (_stricmp(name, "Sec-WebSocket-Version") == 0)
+       {
+               response->SecWebsocketVersion = value;
+
+               if (!response->SecWebsocketVersion)
+                       return FALSE;
+       }
+       else if (_stricmp(name, "Sec-WebSocket-Accept") == 0)
+       {
+               response->SecWebsocketAccept = value;
+
+               if (!response->SecWebsocketAccept)
+                       return FALSE;
+       }
        else if (_stricmp(name, "WWW-Authenticate") == 0)
        {
                char* separator = NULL;
@@ -1041,3 +1100,56 @@ TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response)
 
        return response->TransferEncoding;
 }
+
+BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response)
+{
+       BOOL isWebsocket = FALSE;
+       WINPR_DIGEST_CTX* sha1 = NULL;
+       char* base64accept = NULL;
+       BYTE sha1_digest[WINPR_SHA1_DIGEST_LENGTH];
+
+       if (!http || !response)
+               return FALSE;
+
+       if (!http->websocketUpgrade || response->StatusCode != HTTP_STATUS_SWITCH_PROTOCOLS)
+               return FALSE;
+
+       if (response->SecWebsocketVersion && _stricmp(response->SecWebsocketVersion, "13") != 0)
+               return FALSE;
+
+       if (!response->SecWebsocketAccept)
+               return FALSE;
+
+       /* now check if Sec-Websocket-Accept is correct */
+
+       sha1 = winpr_Digest_New();
+       if (!sha1)
+               goto out;
+
+       if (!winpr_Digest_Init(sha1, WINPR_MD_SHA1))
+               goto out;
+
+       if (!winpr_Digest_Update(sha1, (const BYTE*)http->SecWebsocketKey,
+                                strlen(http->SecWebsocketKey)))
+               goto out;
+       if (!winpr_Digest_Update(sha1, (const BYTE*)WEBSOCKET_MAGIC_GUID, strlen(WEBSOCKET_MAGIC_GUID)))
+               goto out;
+
+       if (!winpr_Digest_Final(sha1, sha1_digest, sizeof(sha1_digest)))
+               goto out;
+
+       base64accept = crypto_base64_encode(sha1_digest, WINPR_SHA1_DIGEST_LENGTH);
+       if (!base64accept)
+               goto out;
+
+       if (_stricmp(response->SecWebsocketAccept, base64accept) != 0)
+       {
+               WLog_WARN(TAG, "Webserver gave Websocket Upgrade response but sanity check failed");
+               goto out;
+       }
+       isWebsocket = TRUE;
+out:
+       winpr_Digest_Free(sha1);
+       free(base64accept);
+       return isWebsocket;
+}
index af1be0f..00bfb03 100644 (file)
@@ -52,6 +52,8 @@ FREERDP_LOCAL BOOL http_context_set_rdg_connection_id(HttpContext* context,
                                                       const char* RdgConnectionId);
 FREERDP_LOCAL BOOL http_context_set_rdg_auth_scheme(HttpContext* context,
                                                     const char* RdgAuthScheme);
+FREERDP_LOCAL BOOL http_context_enable_websocket_upgrade(HttpContext* context, BOOL enable);
+FREERDP_LOCAL BOOL http_context_is_websocket_upgrade_enabled(HttpContext* context);
 
 /* HTTP request */
 typedef struct _http_request HttpRequest;
@@ -85,5 +87,6 @@ FREERDP_LOCAL long http_response_get_status_code(HttpResponse* response);
 FREERDP_LOCAL SSIZE_T http_response_get_body_length(HttpResponse* response);
 FREERDP_LOCAL const char* http_response_get_auth_token(HttpResponse* response, const char* method);
 FREERDP_LOCAL TRANSFER_ENCODING http_response_get_transfer_encoding(HttpResponse* response);
+FREERDP_LOCAL BOOL http_response_is_websocket(HttpContext* http, HttpResponse* response);
 
 #endif /* FREERDP_LIB_CORE_GATEWAY_HTTP_H */
index 04ef0a0..4e26948 100644 (file)
 #define HTTP_CAPABILITY_REAUTH 0x10
 #define HTTP_CAPABILITY_UDP_TRANSPORT 0x20
 
+#define WEBSOCKET_MASK_BIT 0x80
+#define WEBSOCKET_FIN_BIT 0x80
+
+typedef enum _WEBSOCKET_OPCODE
+{
+       WebsocketContinuationOpcode = 0x0,
+       WebsocketTextOpcode = 0x1,
+       WebsocketBinaryOpcode = 0x2,
+       WebsocketCloseOpcode = 0x8,
+       WebsocketPingOpcode = 0x9,
+       WebsocketPongOpcode = 0xa,
+} WEBSOCKET_OPCODE;
+
+typedef enum _WEBSOCKET_STATE
+{
+       WebsocketStateOpcodeAndFin,
+       WebsocketStateLengthAndMasking,
+       WebsocketStateShortLength,
+       WebsocketStateLongLength,
+       WebSocketStateMaskingKey,
+       WebSocketStatePayload,
+} WEBSOCKET_STATE;
+
+typedef struct
+{
+       size_t payloadLength;
+       uint32_t maskingKey;
+       BOOL masking;
+       BOOL closeSent;
+       BYTE opcode;
+       BYTE fragmentOriginalOpcode;
+       BYTE lengthAndMaskPosition;
+       WEBSOCKET_STATE state;
+       wStream* responseStreamBuffer;
+} rdg_http_websocket_context;
 typedef enum _CHUNK_STATE
 {
        ChunkStateLenghHeader,
@@ -122,9 +157,11 @@ typedef struct
 typedef struct
 {
        TRANSFER_ENCODING httpTransferEncoding;
+       BOOL isWebsocketTransport;
        union _context
        {
                rdg_http_encoding_chunked_context chunked;
+               rdg_http_websocket_context websocket;
        } context;
 } rdg_http_encoding_context;
 
@@ -293,9 +330,9 @@ static BOOL rdg_read_http_unicode_string(wStream* s, const WCHAR** string, UINT1
        return TRUE;
 }
 
-static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
+static BOOL rdg_write_chunked(BIO* bio, wStream* sPacket)
 {
-       size_t s;
+       size_t len;
        int status;
        wStream* sChunk;
        char chunkSize[11];
@@ -309,22 +346,414 @@ static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
        Stream_Write(sChunk, Stream_Buffer(sPacket), Stream_Length(sPacket));
        Stream_Write(sChunk, "\r\n", 2);
        Stream_SealLength(sChunk);
-       s = Stream_Length(sChunk);
+       len = Stream_Length(sChunk);
 
-       if (s > INT_MAX)
+       if (len > INT_MAX)
                return FALSE;
 
-       status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
+       status = BIO_write(bio, Stream_Buffer(sChunk), (int)len);
        Stream_Free(sChunk, TRUE);
 
+       if (status != len)
+               return FALSE;
+
+       return TRUE;
+}
+
+static BOOL rdg_write_websocket(BIO* bio, wStream* sPacket, WEBSOCKET_OPCODE opcode)
+{
+       size_t len;
+       size_t fullLen;
+       int status;
+       wStream* sWS;
+
+       uint32_t maskingKey;
+
+       size_t streamPos;
+
+       len = Stream_Length(sPacket);
+       Stream_SetPosition(sPacket, 0);
+
+       if (len > INT_MAX)
+               return FALSE;
+
+       if (len < 126)
+               fullLen = len + 6; /* 2 byte "mini header" + 4 byte masking key */
+       else if (len < 0x10000)
+               fullLen = len + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
+       else
+               fullLen = len + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
+
+       sWS = Stream_New(NULL, fullLen);
+       if (!sWS)
+               return FALSE;
+
+       winpr_RAND((BYTE*)&maskingKey, 4);
+
+       Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | opcode);
+       if (len < 126)
+               Stream_Write_UINT8(sWS, len | WEBSOCKET_MASK_BIT);
+       else if (len < 0x10000)
+       {
+               Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
+               Stream_Write_UINT16_BE(sWS, len);
+       }
+       else
+       {
+               Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
+               Stream_Write_UINT32_BE(sWS, 0); /* payload is limited to INT_MAX */
+               Stream_Write_UINT32_BE(sWS, len);
+       }
+       Stream_Write_UINT32(sWS, maskingKey);
+
+       /* mask as much as possible with 32bit access */
+       for (streamPos = 0; streamPos + 4 <= len; streamPos += 4)
+       {
+               uint32_t data;
+               Stream_Read_UINT32(sPacket, data);
+               Stream_Write_UINT32(sWS, data ^ maskingKey);
+       }
+
+       /* mask the rest byte by byte */
+       for (; streamPos < len; streamPos++)
+       {
+               BYTE data;
+               BYTE* partialMask = ((BYTE*)&maskingKey) + (streamPos % 4);
+               Stream_Read_UINT8(sPacket, data);
+               Stream_Write_UINT8(sWS, data ^ *partialMask);
+       }
+
+       Stream_SealLength(sWS);
+
+       status = BIO_write(bio, Stream_Buffer(sWS), Stream_Length(sWS));
+       Stream_Free(sWS, TRUE);
+
+       if (status != fullLen)
+               return FALSE;
+
+       return TRUE;
+}
+
+static BOOL rdg_write_packet(rdpRdg* rdg, wStream* sPacket)
+{
+       if (rdg->transferEncoding.isWebsocketTransport)
+       {
+               if (rdg->transferEncoding.context.websocket.closeSent)
+                       return FALSE;
+               return rdg_write_websocket(rdg->tlsOut->bio, sPacket, WebsocketBinaryOpcode);
+       }
+
+       return rdg_write_chunked(rdg->tlsIn->bio, sPacket);
+}
+
+static int rdg_websocket_read_data(BIO* bio, BYTE* pBuffer, size_t size,
+                                   rdg_http_websocket_context* encodingContext)
+{
+       int status;
+
+       if (encodingContext->payloadLength == 0)
+       {
+               encodingContext->state = WebsocketStateOpcodeAndFin;
+               return 0;
+       }
+
+       status =
+           BIO_read(bio, pBuffer,
+                    (encodingContext->payloadLength < size ? encodingContext->payloadLength : size));
+       if (status <= 0)
+               return status;
+
+       encodingContext->payloadLength -= status;
+
+       if (encodingContext->payloadLength == 0)
+               encodingContext->state = WebsocketStateOpcodeAndFin;
+
+       return status;
+}
+
+static int rdg_websocket_read_discard(BIO* bio, rdg_http_websocket_context* encodingContext)
+{
+       char _dummy[256];
+       int status;
+
+       if (encodingContext->payloadLength == 0)
+       {
+               encodingContext->state = WebsocketStateOpcodeAndFin;
+               return 0;
+       }
+
+       status = BIO_read(bio, _dummy, sizeof(_dummy));
+       if (status <= 0)
+               return status;
+
+       encodingContext->payloadLength -= status;
+
+       if (encodingContext->payloadLength == 0)
+               encodingContext->state = WebsocketStateOpcodeAndFin;
+
+       return status;
+}
+
+static int rdg_websocket_read_wstream(BIO* bio, wStream* s,
+                                      rdg_http_websocket_context* encodingContext)
+{
+       int status;
+
+       if (encodingContext->payloadLength == 0)
+       {
+               encodingContext->state = WebsocketStateOpcodeAndFin;
+               return 0;
+       }
+       if (s == NULL || Stream_GetRemainingCapacity(s) != encodingContext->payloadLength)
+               return -1;
+
+       status = BIO_read(bio, Stream_Pointer(s), encodingContext->payloadLength);
+       if (status <= 0)
+               return status;
+
+       Stream_Seek(s, status);
+
+       encodingContext->payloadLength -= status;
+
+       if (encodingContext->payloadLength == 0)
+       {
+               encodingContext->state = WebsocketStateOpcodeAndFin;
+               Stream_SealLength(s);
+               Stream_SetPosition(s, 0);
+       }
+
+       return status;
+}
+
+static BOOL rdg_websocket_reply_close(BIO* bio, wStream* s)
+{
+       /* write back close */
+       wStream* closeFrame;
+       uint16_t maskingKey1;
+       uint16_t maskingKey2;
+       int status;
+       size_t closeDataLen;
+
+       closeDataLen = 0;
+       if (s != NULL && Stream_Length(s) >= 2)
+               closeDataLen = 2;
+
+       closeFrame = Stream_New(NULL, 6 + closeDataLen);
+       Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
+       Stream_Write_UINT8(closeFrame, closeDataLen | WEBSOCKET_MASK_BIT); /* no payload */
+       winpr_RAND((BYTE*)&maskingKey1, 2);
+       winpr_RAND((BYTE*)&maskingKey2, 2);
+       Stream_Write_UINT16(closeFrame, maskingKey1);
+       Stream_Write_UINT16(closeFrame, maskingKey2); /* unused half, max 2 bytes of data */
+
+       if (closeDataLen == 2)
+       {
+               uint16_t data;
+               Stream_Read_UINT16(s, data);
+               Stream_Write_UINT16(s, data ^ maskingKey1);
+       }
+       Stream_SealLength(closeFrame);
+
+       status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame));
+       /* server MUST close socket now. The server is not allowed anymore to
+        * send frames but if he does, nothing bad would happen */
        if (status < 0)
                return FALSE;
+       return TRUE;
+}
+
+static BOOL rdg_websocket_reply_pong(BIO* bio, wStream* s)
+{
+       wStream* closeFrame;
+       uint32_t maskingKey;
+       int status;
 
+       if (s != NULL)
+               return rdg_write_websocket(bio, s, WebsocketPongOpcode);
+
+       closeFrame = Stream_New(NULL, 6);
+       Stream_Write_UINT8(closeFrame, WEBSOCKET_FIN_BIT | WebsocketPongOpcode);
+       Stream_Write_UINT8(closeFrame, 0 | WEBSOCKET_MASK_BIT); /* no payload */
+       winpr_RAND((BYTE*)&maskingKey, 4);
+       Stream_Write_UINT32(closeFrame, maskingKey); /* dummy masking key. */
+       Stream_SealLength(closeFrame);
+
+       status = BIO_write(bio, Stream_Buffer(closeFrame), Stream_Length(closeFrame));
+
+       if (status < 0)
+               return FALSE;
        return TRUE;
 }
 
+static int rdg_websocket_handle_payload(BIO* bio, BYTE* pBuffer, size_t size,
+                                        rdg_http_websocket_context* encodingContext)
+{
+       int status;
+       BYTE effectiveOpcode = ((encodingContext->opcode & 0xf) == WebsocketContinuationOpcode
+                                   ? encodingContext->fragmentOriginalOpcode & 0xf
+                                   : encodingContext->opcode & 0xf);
+
+       switch (effectiveOpcode)
+       {
+               case WebsocketBinaryOpcode:
+               {
+                       status = rdg_websocket_read_data(bio, pBuffer, size, encodingContext);
+                       if (status < 0)
+                               return status;
+
+                       return status;
+               }
+               break;
+               case WebsocketPingOpcode:
+               {
+                       if (encodingContext->responseStreamBuffer == NULL)
+                               encodingContext->responseStreamBuffer =
+                                   Stream_New(NULL, encodingContext->payloadLength);
+
+                       status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer,
+                                                           encodingContext);
+                       if (status < 0)
+                               return status;
+
+                       if (encodingContext->payloadLength == 0)
+                       {
+                               if (!encodingContext->closeSent)
+                                       rdg_websocket_reply_pong(bio, encodingContext->responseStreamBuffer);
+
+                               if (encodingContext->responseStreamBuffer)
+                                       Stream_Free(encodingContext->responseStreamBuffer, TRUE);
+                               encodingContext->responseStreamBuffer = NULL;
+                       }
+               }
+               break;
+               case WebsocketCloseOpcode:
+               {
+                       if (encodingContext->responseStreamBuffer == NULL)
+                               encodingContext->responseStreamBuffer =
+                                   Stream_New(NULL, encodingContext->payloadLength);
+
+                       status = rdg_websocket_read_wstream(bio, encodingContext->responseStreamBuffer,
+                                                           encodingContext);
+                       if (status < 0)
+                               return status;
+
+                       if (encodingContext->payloadLength == 0)
+                       {
+                               rdg_websocket_reply_close(bio, encodingContext->responseStreamBuffer);
+                               encodingContext->closeSent = TRUE;
+
+                               if (encodingContext->responseStreamBuffer)
+                                       Stream_Free(encodingContext->responseStreamBuffer, TRUE);
+                               encodingContext->responseStreamBuffer = NULL;
+                       }
+               }
+               break;
+               default:
+                       WLog_WARN(TAG, "Unimplemented websocket opcode %x. Dropping", effectiveOpcode & 0xf);
+
+                       status = rdg_websocket_read_discard(bio, encodingContext);
+                       if (status < 0)
+                               return status;
+       }
+       /* return how many bytes have been written to pBuffer.
+        * Only WebsocketBinaryOpcode writes into it and it returns directly */
+       return 0;
+}
+
+static int rdg_websocket_read(BIO* bio, BYTE* pBuffer, size_t size,
+                              rdg_http_websocket_context* encodingContext)
+{
+       int status;
+       int effectiveDataLen = 0;
+       assert(encodingContext != NULL);
+       while (TRUE)
+       {
+               switch (encodingContext->state)
+               {
+                       case WebsocketStateOpcodeAndFin:
+                       {
+                               BYTE buffer[1];
+                               status = BIO_read(bio, (char*)buffer, 1);
+                               if (status <= 0)
+                                       return (effectiveDataLen > 0 ? effectiveDataLen : status);
+
+                               encodingContext->opcode = buffer[0];
+                               if (((encodingContext->opcode & 0xf) != WebsocketContinuationOpcode) &&
+                                   (encodingContext->opcode & 0xf) < 0x08)
+                                       encodingContext->fragmentOriginalOpcode = encodingContext->opcode;
+                               encodingContext->state = WebsocketStateLengthAndMasking;
+                       }
+                       break;
+                       case WebsocketStateLengthAndMasking:
+                       {
+                               BYTE buffer[1];
+                               BYTE len;
+                               status = BIO_read(bio, (char*)buffer, 1);
+                               if (status <= 0)
+                                       return (effectiveDataLen > 0 ? effectiveDataLen : status);
+
+                               encodingContext->masking = ((buffer[0] & WEBSOCKET_MASK_BIT) == WEBSOCKET_MASK_BIT);
+                               encodingContext->lengthAndMaskPosition = 0;
+                               encodingContext->payloadLength = 0;
+                               len = buffer[0] & 0x7f;
+                               if (len < 126)
+                               {
+                                       encodingContext->payloadLength = len;
+                                       encodingContext->state = (encodingContext->masking ? WebSocketStateMaskingKey
+                                                                                          : WebSocketStatePayload);
+                               }
+                               else if (len == 126)
+                                       encodingContext->state = WebsocketStateShortLength;
+                               else
+                                       encodingContext->state = WebsocketStateLongLength;
+                       }
+                       break;
+                       case WebsocketStateShortLength:
+                       case WebsocketStateLongLength:
+                       {
+                               BYTE buffer[1];
+                               BYTE lenLength = (encodingContext->state == WebsocketStateShortLength ? 2 : 8);
+                               while (encodingContext->lengthAndMaskPosition < lenLength)
+                               {
+                                       status = BIO_read(bio, (char*)buffer, 1);
+                                       if (status <= 0)
+                                               return (effectiveDataLen > 0 ? effectiveDataLen : status);
+
+                                       encodingContext->payloadLength =
+                                           (encodingContext->payloadLength) << 8 | buffer[0];
+                                       encodingContext->lengthAndMaskPosition += status;
+                               }
+                               encodingContext->state =
+                                   (encodingContext->masking ? WebSocketStateMaskingKey : WebSocketStatePayload);
+                       }
+                       break;
+                       case WebSocketStateMaskingKey:
+                       {
+                               WLog_WARN(
+                                   TAG, "Websocket Server sends data with masking key. This is against RFC 6455.");
+                               return -1;
+                       }
+                       break;
+                       case WebSocketStatePayload:
+                       {
+                               status = rdg_websocket_handle_payload(bio, pBuffer, size, encodingContext);
+                               if (status < 0)
+                                       return (effectiveDataLen > 0 ? effectiveDataLen : status);
+
+                               effectiveDataLen += status;
+
+                               if ((size_t)status == size)
+                                       return effectiveDataLen;
+                               pBuffer += status;
+                               size -= status;
+                       }
+               }
+       }
+       /* should be unreachable */
+       return -1;
+}
 static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
-                             rdg_http_encoding_context* encodingContext)
+                             rdg_http_encoding_chunked_context* encodingContext)
 {
        int status;
        int effectiveDataLen = 0;
@@ -335,18 +764,17 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
                {
                        case ChunkStateData:
                        {
-                               status = BIO_read(bio, pBuffer,
-                                                 (size > encodingContext->context.chunked.nextOffset
-                                                      ? encodingContext->context.chunked.nextOffset
-                                                      : size));
+                               status = BIO_read(
+                                   bio, pBuffer,
+                                   (size > encodingContext->nextOffset ? encodingContext->nextOffset : size));
                                if (status <= 0)
                                        return (effectiveDataLen > 0 ? effectiveDataLen : status);
 
-                               encodingContext->context.chunked.nextOffset -= status;
-                               if (encodingContext->context.chunked.nextOffset == 0)
+                               encodingContext->nextOffset -= status;
+                               if (encodingContext->nextOffset == 0)
                                {
-                                       encodingContext->context.chunked.state = ChunkStateFooter;
-                                       encodingContext->context.chunked.headerFooterPos = 0;
+                                       encodingContext->state = ChunkStateFooter;
+                                       encodingContext->headerFooterPos = 0;
                                }
                                effectiveDataLen += status;
 
@@ -360,17 +788,16 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
                        case ChunkStateFooter:
                        {
                                char _dummy[2];
-                               assert(encodingContext->context.chunked.nextOffset == 0);
-                               assert(encodingContext->context.chunked.headerFooterPos < 2);
-                               status =
-                                   BIO_read(bio, _dummy, 2 - encodingContext->context.chunked.headerFooterPos);
+                               assert(encodingContext->nextOffset == 0);
+                               assert(encodingContext->headerFooterPos < 2);
+                               status = BIO_read(bio, _dummy, 2 - encodingContext->headerFooterPos);
                                if (status >= 0)
                                {
-                                       encodingContext->context.chunked.headerFooterPos += status;
-                                       if (encodingContext->context.chunked.headerFooterPos == 2)
+                                       encodingContext->headerFooterPos += status;
+                                       if (encodingContext->headerFooterPos == 2)
                                        {
-                                               encodingContext->context.chunked.state = ChunkStateLenghHeader;
-                                               encodingContext->context.chunked.headerFooterPos = 0;
+                                               encodingContext->state = ChunkStateLenghHeader;
+                                               encodingContext->headerFooterPos = 0;
                                        }
                                }
                                else
@@ -381,17 +808,16 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
                        {
                                BOOL _haveNewLine = FALSE;
                                size_t tmp;
-                               char* dst = &encodingContext->context.chunked
-                                                .lenBuffer[encodingContext->context.chunked.headerFooterPos];
-                               assert(encodingContext->context.chunked.nextOffset == 0);
-                               while (encodingContext->context.chunked.headerFooterPos < 10 && !_haveNewLine)
+                               char* dst = &encodingContext->lenBuffer[encodingContext->headerFooterPos];
+                               assert(encodingContext->nextOffset == 0);
+                               while (encodingContext->headerFooterPos < 10 && !_haveNewLine)
                                {
                                        status = BIO_read(bio, dst, 1);
                                        if (status >= 0)
                                        {
                                                if (*dst == '\n')
                                                        _haveNewLine = TRUE;
-                                               encodingContext->context.chunked.headerFooterPos += status;
+                                               encodingContext->headerFooterPos += status;
                                                dst += status;
                                        }
                                        else
@@ -403,21 +829,21 @@ static int rdg_chuncked_read(BIO* bio, BYTE* pBuffer, size_t size,
                                /* strtoul is tricky, error are reported via errno, we also need
                                 * to ensure the result does not overflow */
                                errno = 0;
-                               tmp = strtoul(encodingContext->context.chunked.lenBuffer, NULL, 16);
+                               tmp = strtoul(encodingContext->lenBuffer, NULL, 16);
                                if ((errno != 0) || (tmp > SIZE_MAX))
                                        return -1;
-                               encodingContext->context.chunked.nextOffset = tmp;
-                               encodingContext->context.chunked.state = ChunkStateData;
+                               encodingContext->nextOffset = tmp;
+                               encodingContext->state = ChunkStateData;
 
-                               if (encodingContext->context.chunked.nextOffset == 0)
-                               { // end of stream
+                               if (encodingContext->nextOffset == 0)
+                               { /* end of stream */
                                        int fd = BIO_get_fd(bio, NULL);
                                        if (fd >= 0)
-                                               close(fd);
+                                               closesocket((SOCKET)fd);
 
                                        WLog_WARN(TAG, "cunked encoding end of stream received");
-                                       encodingContext->context.chunked.headerFooterPos = 0;
-                                       encodingContext->context.chunked.state = ChunkStateFooter;
+                                       encodingContext->headerFooterPos = 0;
+                                       encodingContext->state = ChunkStateFooter;
                                }
                        }
                        break;
@@ -433,12 +859,16 @@ static int rdg_socket_read(BIO* bio, BYTE* pBuffer, size_t size,
                            rdg_http_encoding_context* encodingContext)
 {
        assert(encodingContext != NULL);
+       if (encodingContext->isWebsocketTransport)
+       {
+               return rdg_websocket_read(bio, pBuffer, size, &encodingContext->context.websocket);
+       }
        switch (encodingContext->httpTransferEncoding)
        {
                case TransferEncodingIdentity:
                        return BIO_read(bio, pBuffer, size);
                case TransferEncodingChunked:
-                       return rdg_chuncked_read(bio, pBuffer, size, encodingContext);
+                       return rdg_chuncked_read(bio, pBuffer, size, &encodingContext->context.chunked);
                default:
                        return -1;
        }
@@ -730,10 +1160,7 @@ static wStream* rdg_build_http_request(rdpRdg* rdg, const char* method,
                        goto out;
        }
 
-       if (transferEncoding)
-       {
                http_request_set_transfer_encoding(request, transferEncoding);
-       }
 
        s = http_request_write(rdg->http, request);
 out:
@@ -1320,6 +1747,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
        SSIZE_T bodyLength;
        long StatusCode;
        TRANSFER_ENCODING encoding;
+       BOOL isWebsocket;
 
        if (!rdg_tls_connect(rdg, tls, peerAddress, timeout))
                return FALSE;
@@ -1344,6 +1772,7 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
                        case HTTP_STATUS_NOT_FOUND:
                        {
                                WLog_INFO(TAG, "RD Gateway does not support HTTP transport.");
+                               http_context_enable_websocket_upgrade(rdg->http, FALSE);
 
                                if (rpcFallback)
                                        *rpcFallback = TRUE;
@@ -1377,16 +1806,44 @@ static BOOL rdg_establish_data_connection(rdpRdg* rdg, rdpTls* tls, const char*
        statusCode = http_response_get_status_code(response);
        bodyLength = http_response_get_body_length(response);
        encoding = http_response_get_transfer_encoding(response);
+       isWebsocket = http_response_is_websocket(rdg->http, response);
        http_response_free(response);
        WLog_DBG(TAG, "%s authorization result: %d", method, statusCode);
 
        switch (statusCode)
        {
                case HTTP_STATUS_OK:
+                       /* old rdg endpoint without websocket support, don't request websocket for RDG_IN_DATA
+                        */
+                       http_context_enable_websocket_upgrade(rdg->http, FALSE);
                        break;
                case HTTP_STATUS_DENIED:
                        freerdp_set_last_error_log(rdg->context, FREERDP_ERROR_CONNECT_ACCESS_DENIED);
                        return FALSE;
+               case HTTP_STATUS_SWITCH_PROTOCOLS:
+                       if (!isWebsocket)
+                       {
+                               /*
+                                * webserver is broken, a fallback may be possible here
+                                * but only if already tested with oppurtonistic upgrade
+                                */
+                               if (http_context_is_websocket_upgrade_enabled(rdg->http))
+                               {
+                                       int fd = BIO_get_fd(tls->bio, NULL);
+                                       if (fd >= 0)
+                                               closesocket((SOCKET)fd);
+                                       http_context_enable_websocket_upgrade(rdg->http, FALSE);
+                                       return rdg_establish_data_connection(rdg, tls, method, peerAddress, timeout,
+                                                                            rpcFallback);
+                               }
+                               return FALSE;
+                       }
+                       rdg->transferEncoding.isWebsocketTransport = TRUE;
+                       rdg->transferEncoding.context.websocket.state = WebsocketStateOpcodeAndFin;
+                       rdg->transferEncoding.context.websocket.responseStreamBuffer = NULL;
+
+                       return TRUE;
+                       break;
                default:
                        return FALSE;
        }
@@ -1452,14 +1909,21 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
 
        if (status)
        {
-               /* Establish IN connection with the same peer/server as OUT connection,
-                * even when server hostname resolves to different IP addresses.
-                */
-               BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket);
-               peerAddress = freerdp_tcp_get_peer_address(outConnSocket);
-               status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress, timeout,
-                                                      NULL);
-               free(peerAddress);
+               if (rdg->transferEncoding.isWebsocketTransport)
+               {
+                       WLog_DBG(TAG, "Upgraded to websocket. RDG_IN_DATA not required");
+               }
+               else
+               {
+                       /* Establish IN connection with the same peer/server as OUT connection,
+                        * even when server hostname resolves to different IP addresses.
+                        */
+                       BIO_get_socket(rdg->tlsOut->underlying, &outConnSocket);
+                       peerAddress = freerdp_tcp_get_peer_address(outConnSocket);
+                       status = rdg_establish_data_connection(rdg, rdg->tlsIn, "RDG_IN_DATA", peerAddress,
+                                                              timeout, NULL);
+                       free(peerAddress);
+               }
        }
 
        if (!status)
@@ -1476,10 +1940,97 @@ BOOL rdg_connect(rdpRdg* rdg, DWORD timeout, BOOL* rpcFallback)
        return TRUE;
 }
 
-static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
+static int rdg_write_websocket_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
+{
+       size_t payloadSize;
+       size_t fullLen;
+       int status;
+       wStream* sWS;
+
+       uint32_t maskingKey;
+       BYTE* maskingKeyByte1 = (BYTE*)&maskingKey;
+       BYTE* maskingKeyByte2 = maskingKeyByte1 + 1;
+       BYTE* maskingKeyByte3 = maskingKeyByte1 + 2;
+       BYTE* maskingKeyByte4 = maskingKeyByte1 + 3;
+
+       int streamPos;
+
+       winpr_RAND((BYTE*)&maskingKey, 4);
+
+       payloadSize = isize + 10;
+       if ((isize < 0) || (isize > UINT16_MAX))
+               return -1;
+
+       if (payloadSize < 1)
+               return 0;
+
+       if (payloadSize < 126)
+               fullLen = payloadSize + 6; /* 2 byte "mini header" + 4 byte masking key */
+       else if (payloadSize < 0x10000)
+               fullLen = payloadSize + 8; /* 2 byte "mini header" + 2 byte length + 4 byte masking key */
+       else
+               fullLen = payloadSize + 14; /* 2 byte "mini header" + 8 byte length + 4 byte masking key */
+
+       sWS = Stream_New(NULL, fullLen);
+       if (!sWS)
+               return FALSE;
+
+       Stream_Write_UINT8(sWS, WEBSOCKET_FIN_BIT | WebsocketBinaryOpcode);
+       if (payloadSize < 126)
+               Stream_Write_UINT8(sWS, payloadSize | WEBSOCKET_MASK_BIT);
+       else if (payloadSize < 0x10000)
+       {
+               Stream_Write_UINT8(sWS, 126 | WEBSOCKET_MASK_BIT);
+               Stream_Write_UINT16_BE(sWS, payloadSize);
+       }
+       else
+       {
+               Stream_Write_UINT8(sWS, 127 | WEBSOCKET_MASK_BIT);
+               /* biggest packet possible is 0xffff + 0xa, so 32bit is always enough */
+               Stream_Write_UINT32_BE(sWS, 0);
+               Stream_Write_UINT32_BE(sWS, payloadSize);
+       }
+       Stream_Write_UINT32(sWS, maskingKey);
+
+       Stream_Write_UINT16(sWS, PKT_TYPE_DATA ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Type */
+       Stream_Write_UINT16(sWS, 0 ^ (*maskingKeyByte3 | *maskingKeyByte4 << 8)); /* Reserved */
+       Stream_Write_UINT32(sWS, (UINT32)payloadSize ^ maskingKey);               /* Packet length */
+       Stream_Write_UINT16(sWS,
+                           (UINT16)isize ^ (*maskingKeyByte1 | *maskingKeyByte2 << 8)); /* Data size */
+
+       /* masking key is now off by 2 bytes. fix that */
+       maskingKey = (maskingKey & 0xffff) << 16 | (maskingKey >> 16);
+
+       /* mask as much as possible with 32bit access */
+       for (streamPos = 0; streamPos + 4 <= isize; streamPos += 4)
+       {
+               uint32_t masked = *((uint32_t*)((BYTE*)buf + streamPos)) ^ maskingKey;
+               Stream_Write_UINT32(sWS, masked);
+       }
+
+       /* mask the rest byte by byte */
+       for (; streamPos < isize; streamPos++)
+       {
+               BYTE* partialMask = (BYTE*)(&maskingKey) + streamPos % 4;
+               BYTE masked = *((BYTE*)((BYTE*)buf + streamPos)) ^ *partialMask;
+               Stream_Write_UINT8(sWS, masked);
+       }
+
+       Stream_SealLength(sWS);
+
+       status = tls_write_all(rdg->tlsOut, Stream_Buffer(sWS), Stream_Length(sWS));
+       Stream_Free(sWS, TRUE);
+
+       if (status < 0)
+               return status;
+
+       return isize;
+}
+
+static int rdg_write_chunked_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
 {
        int status;
-       size_t s;
+       size_t len;
        wStream* sChunk;
        size_t size = (size_t)isize;
        size_t packetSize = size + 10;
@@ -1505,12 +2056,12 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
        Stream_Write(sChunk, buf, size);                 /* Data */
        Stream_Write(sChunk, "\r\n", 2);
        Stream_SealLength(sChunk);
-       s = Stream_Length(sChunk);
+       len = Stream_Length(sChunk);
 
-       if (s > INT_MAX)
+       if (len > INT_MAX)
                return -1;
 
-       status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
+       status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)len);
        Stream_Free(sChunk, TRUE);
 
        if (status < 0)
@@ -1519,68 +2070,68 @@ static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
        return (int)size;
 }
 
-static BOOL rdg_process_close_packet(rdpRdg* rdg)
+static int rdg_write_data_packet(rdpRdg* rdg, const BYTE* buf, int isize)
+{
+       if (rdg->transferEncoding.isWebsocketTransport)
+       {
+               if (rdg->transferEncoding.context.websocket.closeSent == TRUE)
+                       return -1;
+               return rdg_write_websocket_data_packet(rdg, buf, isize);
+       }
+       else
+               return rdg_write_chunked_data_packet(rdg, buf, isize);
+
+       return -1;
+}
+
+static BOOL rdg_process_close_packet(rdpRdg* rdg, wStream* s)
 {
        int status = -1;
-       size_t s;
-       wStream* sChunk;
+       wStream* sClose;
+       UINT32 errorCode;
        UINT32 packetSize = 12;
-       char chunkSize[11];
-       int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIx32 "\r\n", packetSize);
 
-       if (chunkLen < 0)
+       /* Read error code */
+       if (Stream_GetRemainingLength(s) < 4)
                return FALSE;
+       Stream_Read_UINT32(s, errorCode);
 
-       sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2);
+       if (errorCode != 0)
+               freerdp_set_last_error_log(rdg->context, errorCode);
 
-       if (!sChunk)
+       sClose = Stream_New(NULL, packetSize);
+       if (!sClose)
                return FALSE;
 
-       Stream_Write(sChunk, chunkSize, (size_t)chunkLen);
-       Stream_Write_UINT16(sChunk, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */
-       Stream_Write_UINT16(sChunk, 0);                               /* Reserved */
-       Stream_Write_UINT32(sChunk, packetSize);                      /* Packet length */
-       Stream_Write_UINT32(sChunk, 0);                               /* Status code */
-       Stream_Write(sChunk, "\r\n", 2);
-       Stream_SealLength(sChunk);
-       s = Stream_Length(sChunk);
+       Stream_Write_UINT16(sClose, PKT_TYPE_CLOSE_CHANNEL_RESPONSE); /* Type */
+       Stream_Write_UINT16(sClose, 0);                               /* Reserved */
+       Stream_Write_UINT32(sClose, packetSize);                      /* Packet length */
+       Stream_Write_UINT32(sClose, 0);                               /* Status code */
+       Stream_SealLength(sClose);
+       status = rdg_write_packet(rdg, sClose);
+       Stream_Free(sClose, TRUE);
 
-       if (s <= INT_MAX)
-               status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
-
-       Stream_Free(sChunk, TRUE);
        return (status < 0 ? FALSE : TRUE);
 }
 
 static BOOL rdg_process_keep_alive_packet(rdpRdg* rdg)
 {
        int status = -1;
-       size_t s;
-       wStream* sChunk;
+       wStream* sKeepAlive;
        size_t packetSize = 8;
-       char chunkSize[11];
-       int chunkLen = sprintf_s(chunkSize, sizeof(chunkSize), "%" PRIxz "\r\n", packetSize);
-
-       if ((chunkLen < 0) || (packetSize > UINT32_MAX))
-               return FALSE;
 
-       sChunk = Stream_New(NULL, (size_t)chunkLen + packetSize + 2);
+       sKeepAlive = Stream_New(NULL, packetSize);
 
-       if (!sChunk)
+       if (!sKeepAlive)
                return FALSE;
 
-       Stream_Write(sChunk, chunkSize, (size_t)chunkLen);
-       Stream_Write_UINT16(sChunk, PKT_TYPE_KEEPALIVE); /* Type */
-       Stream_Write_UINT16(sChunk, 0);                  /* Reserved */
-       Stream_Write_UINT32(sChunk, (UINT32)packetSize); /* Packet length */
-       Stream_Write(sChunk, "\r\n", 2);
-       Stream_SealLength(sChunk);
-       s = Stream_Length(sChunk);
+       Stream_Write_UINT16(sKeepAlive, PKT_TYPE_KEEPALIVE); /* Type */
+       Stream_Write_UINT16(sKeepAlive, 0);                  /* Reserved */
+       Stream_Write_UINT32(sKeepAlive, (UINT32)packetSize); /* Packet length */
+       Stream_SealLength(sKeepAlive);
+       status = rdg_write_packet(rdg, sKeepAlive);
+       Stream_Free(sKeepAlive, TRUE);
 
-       if (s <= INT_MAX)
-               status = tls_write_all(rdg->tlsIn, Stream_Buffer(sChunk), (int)s);
-
-       Stream_Free(sChunk, TRUE);
        return (status < 0 ? FALSE : TRUE);
 }
 
@@ -1663,7 +2214,7 @@ static BOOL rdg_process_control_packet(rdpRdg* rdg, int type, size_t packetLengt
        {
                case PKT_TYPE_CLOSE_CHANNEL:
                        EnterCriticalSection(&rdg->writeSection);
-                       status = rdg_process_close_packet(rdg);
+                       status = rdg_process_close_packet(rdg, s);
                        LeaveCriticalSection(&rdg->writeSection);
                        break;
 
@@ -1995,7 +2546,8 @@ rdpRdg* rdg_new(rdpContext* context)
                    !http_context_set_connection(rdg->http, "Keep-Alive") ||
                    !http_context_set_user_agent(rdg->http, "MS-RDGateway/1.0") ||
                    !http_context_set_host(rdg->http, rdg->settings->GatewayHostname) ||
-                   !http_context_set_rdg_connection_id(rdg->http, bracedUuid))
+                   !http_context_set_rdg_connection_id(rdg->http, bracedUuid) ||
+                   !http_context_enable_websocket_upgrade(rdg->http, TRUE))
                {
                        goto rdg_alloc_error;
                }
@@ -2025,6 +2577,7 @@ rdpRdg* rdg_new(rdpContext* context)
                InitializeCriticalSection(&rdg->writeSection);
 
                rdg->transferEncoding.httpTransferEncoding = TransferEncodingIdentity;
+               rdg->transferEncoding.isWebsocketTransport = FALSE;
        }
 
        return rdg;
@@ -2047,6 +2600,11 @@ void rdg_free(rdpRdg* rdg)
                BIO_free_all(rdg->frontBio);
 
        DeleteCriticalSection(&rdg->writeSection);
+       if (rdg->transferEncoding.isWebsocketTransport)
+       {
+               if (rdg->transferEncoding.context.websocket.responseStreamBuffer != NULL)
+                       Stream_Free(rdg->transferEncoding.context.websocket.responseStreamBuffer, TRUE);
+       }
 
        free(rdg);
 }