Improve error messages for invalid WebSocket headers (#1962)
authorStephen Toub <stoub@microsoft.com>
Wed, 22 Jan 2020 18:06:30 +0000 (13:06 -0500)
committerGitHub <noreply@github.com>
Wed, 22 Jan 2020 18:06:30 +0000 (13:06 -0500)
And add tests.

src/libraries/Common/src/System/Net/WebSockets/ManagedWebSocket.cs
src/libraries/Common/tests/System/Net/WebSockets/WebSocketCreateTest.cs
src/libraries/System.Net.WebSockets.WebSocketProtocol/src/Resources/Strings.resx
src/libraries/System.Net.WebSockets/src/Resources/Strings.resx

index edace4f..9c76f6f 100644 (file)
@@ -666,9 +666,10 @@ namespace System.Net.WebSockets
                             }
                         }
 
-                        if (!TryParseMessageHeaderFromReceiveBuffer(out header))
+                        string headerErrorMessage = TryParseMessageHeaderFromReceiveBuffer(out header);
+                        if (headerErrorMessage != null)
                         {
-                            await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted).ConfigureAwait(false);
+                            await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, headerErrorMessage).ConfigureAwait(false);
                         }
                         _receivedMaskOffsetOffset = 0;
                     }
@@ -770,6 +771,12 @@ namespace System.Net.WebSockets
                     throw new OperationCanceledException(nameof(WebSocketState.Aborted), exc);
                 }
                 _abortSource.Cancel();
+
+                if (exc is WebSocketException)
+                {
+                    throw;
+                }
+
                 throw new WebSocketException(WebSocketError.ConnectionClosedPrematurely, exc);
             }
             finally
@@ -831,7 +838,7 @@ namespace System.Net.WebSockets
                     }
                     catch (DecoderFallbackException exc)
                     {
-                        await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, exc).ConfigureAwait(false);
+                        await CloseWithReceiveErrorAndThrowAsync(WebSocketCloseStatus.ProtocolError, WebSocketError.Faulted, innerException: exc).ConfigureAwait(false);
                     }
                 }
                 ConsumeFromBuffer((int)header.PayloadLength);
@@ -947,9 +954,10 @@ namespace System.Net.WebSockets
         /// <summary>Send a close message to the server and throw an exception, in response to getting bad data from the server.</summary>
         /// <param name="closeStatus">The close status code to use.</param>
         /// <param name="error">The error reason.</param>
+        /// <param name="errorMessage">An optional error message to include in the thrown exception.</param>
         /// <param name="innerException">An optional inner exception to include in the thrown exception.</param>
         private async ValueTask CloseWithReceiveErrorAndThrowAsync(
-            WebSocketCloseStatus closeStatus, WebSocketError error, Exception innerException = null)
+            WebSocketCloseStatus closeStatus, WebSocketError error, string errorMessage = null, Exception innerException = null)
         {
             // Close the connection if it hasn't already been closed
             if (!_sentCloseFrame)
@@ -961,13 +969,15 @@ namespace System.Net.WebSockets
             _receiveBufferCount = 0;
 
             // Let the caller know we've failed
-            throw new WebSocketException(error, innerException);
+            throw errorMessage != null ?
+                new WebSocketException(error, errorMessage, innerException) :
+                new WebSocketException(error, innerException);
         }
 
         /// <summary>Parses a message header from the buffer.  This assumes the header is in the buffer.</summary>
         /// <param name="resultHeader">The read header.</param>
-        /// <returns>true if a header was read; false if the header was invalid.</returns>
-        private bool TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHeader)
+        /// <returns>null if a valid header was read; non-null containing the string error message to use if the header was invalid.</returns>
+        private string TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHeader)
         {
             Debug.Assert(_receiveBufferCount >= 2, $"Expected to at least have the first two bytes of the header.");
 
@@ -1001,12 +1011,18 @@ namespace System.Net.WebSockets
                 ConsumeFromBuffer(8);
             }
 
-            bool shouldFail = reservedSet;
+            if (reservedSet)
+            {
+                resultHeader = default;
+                return SR.net_Websockets_ReservedBitsSet;
+            }
+
             if (masked)
             {
                 if (!_isServer)
                 {
-                    shouldFail = true;
+                    resultHeader = default;
+                    return SR.net_Websockets_ClientReceivedMaskedFrame;
                 }
                 header.Mask = CombineMaskBytes(receiveBufferSpan, _receiveBufferOffset);
 
@@ -1021,7 +1037,8 @@ namespace System.Net.WebSockets
                     if (_lastReceiveHeader.Fin)
                     {
                         // Can't continue from a final message
-                        shouldFail = true;
+                        resultHeader = default;
+                        return SR.net_Websockets_ContinuationFromFinalFrame;
                     }
                     break;
 
@@ -1030,7 +1047,8 @@ namespace System.Net.WebSockets
                     if (!_lastReceiveHeader.Fin)
                     {
                         // Must continue from a non-final message
-                        shouldFail = true;
+                        resultHeader = default;
+                        return SR.net_Websockets_NonContinuationAfterNonFinalFrame;
                     }
                     break;
 
@@ -1040,19 +1058,20 @@ namespace System.Net.WebSockets
                     if (header.PayloadLength > MaxControlPayloadLength || !header.Fin)
                     {
                         // Invalid control messgae
-                        shouldFail = true;
+                        resultHeader = default;
+                        return SR.net_Websockets_InvalidControlMessage;
                     }
                     break;
 
                 default:
                     // Unknown opcode
-                    shouldFail = true;
-                    break;
+                    resultHeader = default;
+                    return SR.Format(SR.net_Websockets_UnknownOpcode, header.Opcode);
             }
 
             // Return the read header
             resultHeader = header;
-            return !shouldFail;
+            return null;
         }
 
         /// <summary>Send a close message, then receive until we get a close response message.</summary>
index 4e9fff1..f2b4d05 100644 (file)
@@ -101,6 +101,51 @@ namespace System.Net.WebSockets.Tests
             }
         }
 
+        [Theory]
+        [InlineData(0b_1000_0001, 0b_0_000_0001, false)] // fin + text, no mask + length == 1
+        [InlineData(0b_1100_0001, 0b_0_000_0001, true)] // fin + rsv1 + text, no mask + length == 1
+        [InlineData(0b_1010_0001, 0b_0_000_0001, true)] // fin + rsv2 + text, no mask + length == 1
+        [InlineData(0b_1001_0001, 0b_0_000_0001, true)] // fin + rsv3 + text, no mask + length == 1
+        [InlineData(0b_1111_0001, 0b_0_000_0001, true)] // fin + rsv1 + rsv2 + rsv3 + text, no mask + length == 1
+        [InlineData(0b_1000_0001, 0b_1_000_0001, true)] // fin + text, mask + length == 1
+        [InlineData(0b_1000_0011, 0b_0_000_0001, true)] // fin + opcode==3, no mask + length == 1
+        [InlineData(0b_1000_0100, 0b_0_000_0001, true)] // fin + opcode==4, no mask + length == 1
+        [InlineData(0b_1000_0101, 0b_0_000_0001, true)] // fin + opcode==5, no mask + length == 1
+        [InlineData(0b_1000_0110, 0b_0_000_0001, true)] // fin + opcode==6, no mask + length == 1
+        [InlineData(0b_1000_0111, 0b_0_000_0001, true)] // fin + opcode==7, no mask + length == 1
+        public async Task ReceiveAsync_InvalidFrameHeader_AbortsAndThrowsException(byte firstByte, byte secondByte, bool shouldFail)
+        {
+            using (Socket listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            using (Socket client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                listener.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                listener.Listen(1);
+
+                await client.ConnectAsync(listener.LocalEndPoint);
+                using (Socket server = await listener.AcceptAsync())
+                {
+                    WebSocket websocket = CreateFromStream(new NetworkStream(client, ownsSocket: false), isServer: false, null, Timeout.InfiniteTimeSpan);
+
+                    await server.SendAsync(new ArraySegment<byte>(new byte[3] { firstByte, secondByte, (byte)'a' }), SocketFlags.None);
+
+                    var buffer = new byte[1];
+                    Task<WebSocketReceiveResult> t = websocket.ReceiveAsync(new ArraySegment<byte>(buffer), CancellationToken.None);
+                    if (shouldFail)
+                    {
+                        await Assert.ThrowsAsync<WebSocketException>(() => t);
+                        Assert.Equal(WebSocketState.Aborted, websocket.State);
+                    }
+                    else
+                    {
+                        WebSocketReceiveResult result = await t;
+                        Assert.True(result.EndOfMessage);
+                        Assert.Equal(1, result.Count);
+                        Assert.Equal('a', (char)buffer[0]);
+                    }
+                }
+            }
+        }
+
         [Fact]
         public async Task ReceiveAsync_ServerSplitHeader_ValidDataReceived()
         {
index a67e9e0..1c8cf89 100644 (file)
   <data name="net_WebSockets_ArgumentOutOfRange_TooSmall" xml:space="preserve">
     <value>The argument must be a value greater than {0}.</value>
   </data>
+  <data name="net_Websockets_ReservedBitsSet" xml:space="preserve">
+    <value>The WebSocket received a frame with one or more reserved bits set.</value>
+  </data>
+  <data name="net_Websockets_ClientReceivedMaskedFrame" xml:space="preserve">
+    <value>The WebSocket server sent a masked frame.</value>
+  </data>
+  <data name="net_Websockets_ContinuationFromFinalFrame" xml:space="preserve">
+    <value>The WebSocket received a continuation frame from a previous final message.</value>
+  </data>
+  <data name="net_Websockets_NonContinuationAfterNonFinalFrame" xml:space="preserve">
+    <value>The WebSocket expected a continuation frame after having received a previous non-final frame.</value>
+  </data>
+  <data name="net_Websockets_InvalidControlMessage" xml:space="preserve">
+    <value>The WebSocket received an invalid control message.</value>
+  </data>
+  <data name="net_Websockets_UnknownOpcode" xml:space="preserve">
+    <value>The WebSocket received a frame with an unknown opcode: '0x{0}'.</value>
+  </data>
 </root>
index ce25b82..a4f630e 100644 (file)
   <data name="net_Websockets_AlreadyOneOutstandingOperation" xml:space="preserve">
     <value>There is already one outstanding '{0}' call for this WebSocket instance. ReceiveAsync and SendAsync can be called simultaneously, but at most one outstanding operation for each of them is allowed at the same time.</value>
   </data>
+  <data name="net_Websockets_ReservedBitsSet" xml:space="preserve">
+    <value>The WebSocket received a frame with one or more reserved bits set.</value>
+  </data>
+  <data name="net_Websockets_ClientReceivedMaskedFrame" xml:space="preserve">
+    <value>The WebSocket server sent a masked frame.</value>
+  </data>
+  <data name="net_Websockets_ContinuationFromFinalFrame" xml:space="preserve">
+    <value>The WebSocket received a continuation frame from a previous final message.</value>
+  </data>
+  <data name="net_Websockets_NonContinuationAfterNonFinalFrame" xml:space="preserve">
+    <value>The WebSocket expected a continuation frame after having received a previous non-final frame.</value>
+  </data>
+  <data name="net_Websockets_InvalidControlMessage" xml:space="preserve">
+    <value>The WebSocket received an invalid control message.</value>
+  </data>
+  <data name="net_Websockets_UnknownOpcode" xml:space="preserve">
+    <value>The WebSocket received a frame with an unknown opcode: '0x{0}'.</value>
+  </data>
   <data name="NotReadableStream" xml:space="preserve">
     <value>The base stream is not readable.</value>
   </data>