avoid ArgumentOutOfRangeException while processing invalid or incomplete TLS frame...
authorTomas Weinfurt <tweinfurt@yahoo.com>
Fri, 31 Dec 2021 23:02:24 +0000 (15:02 -0800)
committerGitHub <noreply@github.com>
Fri, 31 Dec 2021 23:02:24 +0000 (15:02 -0800)
* avoid ArgumentOutOfRangeException while processing invalid or incomplete TLS frame

* feedback from review

src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs
src/libraries/System.Net.Security/tests/FunctionalTests/ServerAsyncAuthenticateTest.cs

index f553f14..1cc4fec 100644 (file)
@@ -459,12 +459,7 @@ namespace System.Net.Security
         private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter adapter)
                  where TIOAdapter : IReadWriteAdapter
         {
-            int readBytes = await FillHandshakeBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
-            if (readBytes == 0)
-            {
-                throw new IOException(SR.net_io_eof);
-            }
-
+            await FillHandshakeBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
             if (_framing == Framing.Unified || _framing == Framing.Unknown)
             {
                 _framing = DetectFraming(_handshakeBuffer.ActiveReadOnlySpan);
@@ -1061,13 +1056,13 @@ namespace System.Net.Security
 
         // This function tries to make sure buffer has at least minSize bytes available.
         // If we have enough data, it returns synchronously. If not, it will try to read
-        // remaining bytes from given stream.
-        private ValueTask<int> FillHandshakeBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
+        // remaining bytes from given stream. It will throw if unable to fulfill minSize.
+        private ValueTask FillHandshakeBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
              where TIOAdapter : IReadWriteAdapter
         {
             if (_handshakeBuffer.ActiveLength >= minSize)
             {
-                return new ValueTask<int>(minSize);
+                return ValueTask.CompletedTask;
             }
 
             int bytesNeeded = minSize - _handshakeBuffer.ActiveLength;
@@ -1083,15 +1078,15 @@ namespace System.Net.Security
                 int bytesRead = t.Result;
                 if (bytesRead == 0)
                 {
-                    return new ValueTask<int>(0);
+                    throw new IOException(SR.net_io_eof);
                 }
 
                 _handshakeBuffer.Commit(bytesRead);
             }
 
-            return new ValueTask<int>(minSize);
+            return ValueTask.CompletedTask;
 
-            async ValueTask<int> InternalFillHandshakeBufferAsync(TIOAdapter adap,  ValueTask<int> task, int minSize)
+            async ValueTask InternalFillHandshakeBufferAsync(TIOAdapter adap,  ValueTask<int> task, int minSize)
             {
                 while (true)
                 {
@@ -1104,7 +1099,7 @@ namespace System.Net.Security
                     _handshakeBuffer.Commit(bytesRead);
                     if (_handshakeBuffer.ActiveLength >= minSize)
                     {
-                        return minSize;
+                        return;
                     }
 
                     task = adap.ReadAsync(_handshakeBuffer.AvailableMemory);
@@ -1112,24 +1107,6 @@ namespace System.Net.Security
             }
         }
 
-        private async ValueTask FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int numBytesRequired)
-            where TIOAdapter : IReadWriteAdapter
-        {
-            Debug.Assert(_internalBufferCount > 0);
-            Debug.Assert(_internalBufferCount < numBytesRequired);
-
-            while (_internalBufferCount < numBytesRequired)
-            {
-                int bytesRead = await adapter.ReadAsync(_internalBuffer.AsMemory(_internalBufferCount)).ConfigureAwait(false);
-                if (bytesRead == 0)
-                {
-                    throw new IOException(SR.net_io_eof);
-                }
-
-                _internalBufferCount += bytesRead;
-            }
-        }
-
         private async ValueTask WriteAsyncInternal<TIOAdapter>(TIOAdapter writeAdapter, ReadOnlyMemory<byte> buffer)
             where TIOAdapter : struct, IReadWriteAdapter
         {
index 73d480a..6edcace 100644 (file)
@@ -285,6 +285,49 @@ namespace System.Net.Security.Tests
             }
         }
 
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ServerAsyncAuthenticate_InvalidHello_Throws(bool close)
+        {
+            (NetworkStream client, NetworkStream server) = TestHelper.GetConnectedTcpStreams();
+            using (client)
+            using (SslStream ssl = new SslStream(server))
+            {
+                byte[] buffer = new byte[182];
+                buffer[0] = 178;
+                buffer[1] = 0;
+                buffer[2] = 0;
+                buffer[3] = 1;
+                buffer[4] = 133;
+                buffer[5] = 166;
+
+                Task t1 = ssl.AuthenticateAsServerAsync(_serverCertificate, false, false);
+                Task t2 = client.WriteAsync(buffer).AsTask();
+                if (close)
+                {
+                    await t2.WaitAsync(TestConfiguration.PassingTestTimeout);
+                    client.Socket.Shutdown(SocketShutdown.Send);
+                }
+                else
+                {
+                    // Write enough data to full frame size
+                    buffer = new byte[13000];
+                    t2 = client.WriteAsync(buffer).AsTask();
+                    await t2.WaitAsync(TestConfiguration.PassingTestTimeout);
+                }
+
+                if (close)
+                {
+                    await Assert.ThrowsAsync<IOException>(() => t1);
+                }
+                else
+                {
+                    await Assert.ThrowsAsync<AuthenticationException>(() => t1);
+                }
+            }
+        }
+
         public static IEnumerable<object[]> ProtocolMismatchData()
         {
             if (PlatformDetection.SupportsSsl3)