fix buffer handling in Tls handshake (#32267)
authorTomas Weinfurt <tweinfurt@yahoo.com>
Fri, 14 Feb 2020 23:09:47 +0000 (15:09 -0800)
committerGitHub <noreply@github.com>
Fri, 14 Feb 2020 23:09:47 +0000 (15:09 -0800)
* fix buffer handling in Tls handshake

* feedback from review

src/libraries/System.Net.Security/src/System.Net.Security.csproj
src/libraries/System.Net.Security/src/System/Net/Security/SecureChannel.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.Adapters.cs
src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Implementation.cs

index 13c4150..abc2ce2 100644 (file)
       <Link>Common\System\Net\DebugCriticalHandleZeroOrMinusOneIsInvalid.cs</Link>
     </Compile>
     <!-- System.Net common -->
+    <Compile Include="$(CommonPath)System\Net\ArrayBuffer.cs">
+       <Link>Common\System\Net\ArrayBuffer.cs</Link>
+    </Compile>
     <Compile Include="$(CommonPath)System\Net\ExceptionCheck.cs">
       <Link>Common\System\Net\ExceptionCheck.cs</Link>
     </Compile>
     <Compile Include="$(CommonPath)System\Net\LazyAsyncResult.cs">
       <Link>Common\System\Net\LazyAsyncResult.cs</Link>
     </Compile>
-    <Compile Include="$(CommonPath)System\Net\UriScheme.cs">
-      <Link>Common\System\Net\UriScheme.cs</Link>
-    </Compile>
     <Compile Include="$(CommonPath)System\Net\SecurityProtocol.cs">
       <Link>Common\System\Net\SecurityProtocol.cs</Link>
     </Compile>
+    <Compile Include="$(CommonPath)System\Net\UriScheme.cs">
+      <Link>Common\System\Net\UriScheme.cs</Link>
+    </Compile>
     <!-- Common -->
     <Compile Include="$(CommonPath)System\NotImplemented.cs">
       <Link>Common\System\NotImplemented.cs</Link>
     <Reference Include="System.Security.Cryptography.OpenSsl" />
     <Reference Include="System.Security.Cryptography.Primitives" />
   </ItemGroup>
-</Project>
\ No newline at end of file
+</Project>
index 8f3140e..164ea48 100644 (file)
@@ -717,13 +717,13 @@ namespace System.Net.Security
         }
 
         //
-        internal ProtocolToken NextMessage(byte[] incoming, int offset, int count)
+        internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
         {
             if (NetEventSource.IsEnabled)
                 NetEventSource.Enter(this);
 
             byte[] nextmsg = null;
-            SecurityStatusPal status = GenerateToken(incoming, offset, count, ref nextmsg);
+            SecurityStatusPal status = GenerateToken(incomingBuffer, ref nextmsg);
 
             if (!_sslAuthenticationOptions.IsServer && status.ErrorCode == SecurityStatusPalErrorCode.CredentialsNeeded)
             {
@@ -731,7 +731,7 @@ namespace System.Net.Security
                     NetEventSource.Info(this, "NextMessage() returned SecurityStatusPal.CredentialsNeeded");
 
                 SetRefreshCredentialNeeded();
-                status = GenerateToken(incoming, offset, count, ref nextmsg);
+                status = GenerateToken(incomingBuffer, ref nextmsg);
             }
 
             ProtocolToken token = new ProtocolToken(nextmsg, status);
@@ -763,27 +763,14 @@ namespace System.Net.Security
             Return:
                 status - error information
         --*/
-        private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref byte[] output)
+        private SecurityStatusPal GenerateToken(ReadOnlySpan<byte> inputBuffer, ref byte[] output)
         {
             if (NetEventSource.IsEnabled) NetEventSource.Enter(this, $"_refreshCredentialNeeded = {_refreshCredentialNeeded}");
 
-            if (offset < 0 || offset > (input == null ? 0 : input.Length))
-            {
-                NetEventSource.Fail(this, "Argument 'offset' out of range.");
-                throw new ArgumentOutOfRangeException(nameof(offset));
-            }
-
-            if (count < 0 || count > (input == null ? 0 : input.Length - offset))
-            {
-                NetEventSource.Fail(this, "Argument 'count' out of range.");
-                throw new ArgumentOutOfRangeException(nameof(count));
-            }
-
             byte[] result = Array.Empty<byte>();
             SecurityStatusPal status = default;
             bool cachedCreds = false;
             byte[] thumbPrint = null;
-            ReadOnlySpan<byte> inputBuffer = new ReadOnlySpan<byte>(input, offset, count);
 
             //
             // Looping through ASC or ISC with potentially cached credential that could have been
@@ -1155,7 +1142,7 @@ namespace System.Net.Security
             byte[] nextmsg = null;
 
             SecurityStatusPal status;
-            status = GenerateToken(null, 0, 0, ref nextmsg);
+            status = GenerateToken(default, ref nextmsg);
 
             ProtocolToken token = new ProtocolToken(nextmsg, status);
 
index acc6bd6..982837e 100644 (file)
@@ -11,7 +11,7 @@ namespace System.Net.Security
     {
         private interface ISslIOAdapter
         {
-            ValueTask<int> ReadAsync(byte[] buffer, int offset, int count);
+            ValueTask<int> ReadAsync(Memory<byte> buffer);
             ValueTask<int> ReadLockAsync(Memory<byte> buffer);
             Task WriteLockAsync();
             ValueTask WriteAsync(byte[] buffer, int offset, int count);
@@ -29,7 +29,7 @@ namespace System.Net.Security
                 _sslStream = sslStream;
             }
 
-            public ValueTask<int> ReadAsync(byte[] buffer, int offset, int count) => _sslStream.InnerStream.ReadAsync(new Memory<byte>(buffer, offset, count), _cancellationToken);
+            public ValueTask<int> ReadAsync(Memory<byte> buffer) => _sslStream.InnerStream.ReadAsync(buffer, _cancellationToken);
 
             public ValueTask<int> ReadLockAsync(Memory<byte> buffer) => _sslStream.CheckEnqueueReadAsync(buffer);
 
@@ -46,7 +46,7 @@ namespace System.Net.Security
 
             public SyncSslIOAdapter(SslStream sslStream) => _sslStream = sslStream;
 
-            public ValueTask<int> ReadAsync(byte[] buffer, int offset, int count) => new ValueTask<int>(_sslStream.InnerStream.Read(buffer, offset, count));
+            public ValueTask<int> ReadAsync(Memory<byte> buffer) => new ValueTask<int>(_sslStream.InnerStream.Read(buffer.Span));
 
             public ValueTask<int> ReadLockAsync(Memory<byte> buffer) => new ValueTask<int>(_sslStream.CheckEnqueueRead(buffer));
 
index 34e6542..81fc3bf 100644 (file)
@@ -55,6 +55,8 @@ namespace System.Net.Security
 
         private const int FrameOverhead = 32;
         private const int ReadBufferSize = 4096 * 4 + FrameOverhead;         // We read in 16K chunks + headers.
+        private const int InitialHandshakeBufferSize = 4096 + FrameOverhead; // try to fit at least 4K ServerCertificate
+        private ArrayBuffer _handshakeBuffer;
 
         private int _lockWriteState;
         private int _lockReadState;
@@ -234,7 +236,7 @@ namespace System.Net.Security
 
             if (reAuthenticationData == null)
             {
-                // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation tansparently.
+                // prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently.
                 if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
                 {
                     throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, isApm ? "BeginAuthenticate" : "Authenticate", "authenticate"));
@@ -246,7 +248,7 @@ namespace System.Net.Security
 
                 if (!receiveFirst)
                 {
-                    message = _context.NextMessage(reAuthenticationData, 0, (reAuthenticationData == null ? 0 : reAuthenticationData.Length));
+                    message = _context.NextMessage(reAuthenticationData);
                     if (message.Size > 0)
                     {
                         await adapter.WriteAsync(message.Payload, 0, message.Size).ConfigureAwait(false);
@@ -259,6 +261,7 @@ namespace System.Net.Security
                     }
                 }
 
+                _handshakeBuffer = new ArrayBuffer(InitialHandshakeBufferSize);
                 do
                 {
                     message = await ReceiveBlobAsync(adapter).ConfigureAwait(false);
@@ -274,6 +277,14 @@ namespace System.Net.Security
                     }
                 } while (message.Status.ErrorCode != SecurityStatusPalErrorCode.OK);
 
+                if (_handshakeBuffer.ActiveLength > 0)
+                {
+                    // If we read more than we needed for handshake, move it to input buffer for further processing.
+                    ResetReadBuffer();
+                    _handshakeBuffer.ActiveSpan.CopyTo(_internalBuffer);
+                    _internalBufferCount = _handshakeBuffer.ActiveLength;
+                }
+
                 ProtocolToken alertToken = null;
                 if (!CompleteHandshake(ref alertToken))
                 {
@@ -282,6 +293,7 @@ namespace System.Net.Security
             }
             finally
             {
+                _handshakeBuffer.Dispose();
                 if (reAuthenticationData == null)
                 {
                     _nestedAuth = 0;
@@ -303,8 +315,7 @@ namespace System.Net.Security
         private async ValueTask<ProtocolToken> ReceiveBlobAsync<TIOAdapter>(TIOAdapter adapter)
                  where TIOAdapter : ISslIOAdapter
         {
-            ResetReadBuffer();
-            int readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
+            int readBytes = await FillHandshakeBufferAsync(adapter, SecureChannel.ReadHeaderSize).ConfigureAwait(false);
             if (readBytes == 0)
             {
                 throw new IOException(SR.net_io_eof);
@@ -312,29 +323,22 @@ namespace System.Net.Security
 
             if (_framing == Framing.Unified || _framing == Framing.Unknown)
             {
-                _framing = DetectFraming(_internalBuffer, readBytes);
+                _framing = DetectFraming(_handshakeBuffer.ActiveReadOnlySpan);
             }
 
-            int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes);
-            if (payloadBytes < 0)
+            int frameSize= GetFrameSize(_handshakeBuffer.ActiveReadOnlySpan);
+            if (frameSize < 0)
             {
                 throw new IOException(SR.net_frame_read_size);
             }
 
-            int frameSize = SecureChannel.ReadHeaderSize + payloadBytes;
-
-            if (readBytes < frameSize)
+            if (_handshakeBuffer.ActiveLength < frameSize)
             {
-                readBytes = await FillBufferAsync(adapter, frameSize).ConfigureAwait(false);
-                Debug.Assert(readBytes >= 0);
-                if (readBytes == 0)
-                {
-                    throw new IOException(SR.net_io_eof);
-                }
+                await FillHandshakeBufferAsync(adapter, frameSize).ConfigureAwait(false);
             }
 
-            ProtocolToken token = _context.NextMessage(_internalBuffer, _internalOffset, frameSize);
-            ConsumeBufferedBytes(frameSize);
+            ProtocolToken token = _context.NextMessage(_handshakeBuffer.ActiveReadOnlySpan.Slice(0, frameSize));
+            _handshakeBuffer.Discard(frameSize);
 
             return token;
         }
@@ -585,8 +589,7 @@ namespace System.Net.Security
             {
                 // Re-handshake status is not supported.
                 ArrayPool<byte>.Shared.Return(rentedBuffer);
-                ProtocolToken message = new ProtocolToken(null, status);
-                return new ValueTask(Task.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(SR.net_io_encrypt, message.GetException()))));
+                return new ValueTask(Task.FromException(ExceptionDispatchInfo.SetCurrentStackTrace(new IOException(SR.net_io_encrypt, SslStreamPal.GetException(status)))));
             }
 
             ValueTask t = writeAdapter.WriteAsync(outBuffer, 0, encryptedBytes);
@@ -700,13 +703,13 @@ namespace System.Net.Security
                         return 0;
                     }
 
-                    int payloadBytes = GetRemainingFrameSize(_internalBuffer, _internalOffset, readBytes);
+                    int payloadBytes = GetFrameSize(new ReadOnlySpan<byte>(_internalBuffer, _internalOffset, readBytes));
                     if (payloadBytes < 0)
                     {
                         throw new IOException(SR.net_frame_read_size);
                     }
 
-                    readBytes = await FillBufferAsync(adapter, SecureChannel.ReadHeaderSize + payloadBytes).ConfigureAwait(false);
+                    readBytes = await FillBufferAsync(adapter, payloadBytes).ConfigureAwait(false);
                     Debug.Assert(readBytes >= 0);
                     if (readBytes == 0)
                     {
@@ -717,7 +720,7 @@ namespace System.Net.Security
                     // Set _decrytpedBytesOffset/Count to the current frame we have (including header)
                     // DecryptData will decrypt in-place and modify these to point to the actual decrypted data, which may be smaller.
                     _decryptedBytesOffset = _internalOffset;
-                    _decryptedBytesCount = readBytes;
+                    _decryptedBytesCount = payloadBytes;
                     SecurityStatusPal status = DecryptData();
 
                     // Treat the bytes we just decrypted as consumed
@@ -735,11 +738,10 @@ namespace System.Net.Security
                             _decryptedBytesCount = 0;
                         }
 
-                        ProtocolToken message = new ProtocolToken(null, status);
                         if (NetEventSource.IsEnabled)
-                            NetEventSource.Info(null, $"***Processing an error Status = {message.Status}");
+                            NetEventSource.Info(null, $"***Processing an error Status = {status}");
 
-                        if (message.Renegotiate)
+                        if (status.ErrorCode == SecurityStatusPalErrorCode.Renegotiate)
                         {
                             if (!_sslAuthenticationOptions.AllowRenegotiation)
                             {
@@ -752,12 +754,12 @@ namespace System.Net.Security
                             continue;
                         }
 
-                        if (message.CloseConnection)
+                        if (status.ErrorCode == SecurityStatusPalErrorCode.ContextExpired)
                         {
                             return 0;
                         }
 
-                        throw new IOException(SR.net_io_decrypt, message.GetException());
+                        throw new IOException(SR.net_io_decrypt, SslStreamPal.GetException(status));
                     }
                 }
             }
@@ -776,6 +778,60 @@ namespace System.Net.Security
             }
         }
 
+        // This function tries to make sure buffer has at least minSize bytes available.
+        // If we have enough data, it returns synchronously. If not, it will try to read
+        // remaining bytes from given stream.
+        private ValueTask<int> FillHandshakeBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
+             where TIOAdapter : ISslIOAdapter
+        {
+            if (_handshakeBuffer.ActiveLength >= minSize)
+            {
+                return new ValueTask<int>(minSize);
+            }
+
+            int bytesNeeded = minSize - _handshakeBuffer.ActiveLength;
+            _handshakeBuffer.EnsureAvailableSpace(bytesNeeded);
+
+            while (_handshakeBuffer.ActiveLength < minSize)
+            {
+                ValueTask<int> t = adapter.ReadAsync(_handshakeBuffer.AvailableMemory);
+                if (!t.IsCompletedSuccessfully)
+                {
+                    return InternalFillHandshakeBufferAsync(adapter, t, minSize);
+                }
+                int bytesRead = t.Result;
+                if (bytesRead == 0)
+                {
+                    return new ValueTask<int>(0);
+                }
+
+                _handshakeBuffer.Commit(bytesRead);
+            }
+
+            return new ValueTask<int>(minSize);
+
+            async ValueTask<int> InternalFillHandshakeBufferAsync(TIOAdapter adap,  ValueTask<int> task, int minSize)
+            {
+                while (true)
+                {
+                    int bytesRead = await task.ConfigureAwait(false);
+                    if (bytesRead == 0)
+                    {
+                        throw new IOException(SR.net_io_eof);
+                    }
+
+                    _handshakeBuffer.Commit(bytesRead);
+                    if (_handshakeBuffer.ActiveLength >= minSize)
+                    {
+                        return minSize;
+                    }
+
+                    int bytesNeeded = minSize - _handshakeBuffer.ActiveLength;
+                    task = adap.ReadAsync(_handshakeBuffer.AvailableMemory);
+                }
+            }
+        }
+
         private ValueTask<int> FillBufferAsync<TIOAdapter>(TIOAdapter adapter, int minSize)
             where TIOAdapter : ISslIOAdapter
         {
@@ -787,7 +843,7 @@ namespace System.Net.Security
             int initialCount = _internalBufferCount;
             do
             {
-                ValueTask<int> t = adapter.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount);
+                ValueTask<int> t = adapter.ReadAsync(new Memory<byte>(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount));
                 if (!t.IsCompletedSuccessfully)
                 {
                     return InternalFillBufferAsync(adapter, t, minSize, initialCount);
@@ -830,7 +886,7 @@ namespace System.Net.Security
                         return min;
                     }
 
-                    task = adap.ReadAsync(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount);
+                    task = adap.ReadAsync(new Memory<byte>(_internalBuffer, _internalBufferCount, _internalBuffer.Length - _internalBufferCount));
                 }
             }
         }
@@ -936,7 +992,7 @@ namespace System.Net.Security
         }
 
         // We need at least 5 bytes to determine what we have.
-        private Framing DetectFraming(byte[] bytes, int length)
+        private Framing DetectFraming(ReadOnlySpan<byte> bytes)
         {
             /* PCTv1.0 Hello starts with
              * RECORD_LENGTH_MSB  (ignore)
@@ -1016,7 +1072,7 @@ namespace System.Net.Security
 
             int version = -1;
 
-            if ((bytes == null || bytes.Length <= 0))
+            if (bytes.Length == 0)
             {
                 NetEventSource.Fail(this, "Header buffer is not allocated.");
             }
@@ -1025,7 +1081,7 @@ namespace System.Net.Security
             if (bytes[0] == (byte)FrameType.Handshake || bytes[0] == (byte)FrameType.AppData
                 || bytes[0] == (byte)FrameType.Alert)
             {
-                if (length < 3)
+                if (bytes.Length < 3)
                 {
                     return Framing.Invalid;
                 }
@@ -1057,7 +1113,7 @@ namespace System.Net.Security
             }
 #endif
 
-            if (length < 3)
+            if (bytes.Length < 3)
             {
                 return Framing.Invalid;
             }
@@ -1069,14 +1125,14 @@ namespace System.Net.Security
 
             if (bytes[2] == 0x1)  // SSL_MT_CLIENT_HELLO
             {
-                if (length >= 5)
+                if (bytes.Length >= 5)
                 {
                     version = (bytes[3] << 8) | bytes[4];
                 }
             }
             else if (bytes[2] == 0x4) // SSL_MT_SERVER_HELLO
             {
-                if (length >= 7)
+                if (bytes.Length >= 7)
                 {
                     version = (bytes[5] << 8) | bytes[6];
                 }
@@ -1114,44 +1170,43 @@ namespace System.Net.Security
 
         //
         // This is called from SslStream class too.
-        private int GetRemainingFrameSize(byte[] buffer, int offset, int dataSize)
+        // Returns TLS Frame size.
+        //
+        private int GetFrameSize(ReadOnlySpan<byte> buffer)
         {
             if (NetEventSource.IsEnabled)
-                NetEventSource.Enter(this, buffer, offset, dataSize);
+                NetEventSource.Enter(this, buffer.Length);
 
             int payloadSize = -1;
             switch (_framing)
             {
                 case Framing.Unified:
                 case Framing.BeforeSSL3:
-                    if (dataSize < 2)
+                    if (buffer.Length < 2)
                     {
                         throw new System.IO.IOException(SR.net_ssl_io_frame);
                     }
                     // Note: Cannot detect version mismatch for <= SSL2
 
-                    if ((buffer[offset] & 0x80) != 0)
+                    if ((buffer[0] & 0x80) != 0)
                     {
                         // Two bytes
-                        payloadSize = (((buffer[offset] & 0x7f) << 8) | buffer[offset + 1]) + 2;
-                        payloadSize -= dataSize;
+                        payloadSize = (((buffer[0] & 0x7f) << 8) | buffer[1]) + 2;
                     }
                     else
                     {
                         // Three bytes
-                        payloadSize = (((buffer[offset] & 0x3f) << 8) | buffer[offset + 1]) + 3;
-                        payloadSize -= dataSize;
+                        payloadSize = (((buffer[0] & 0x3f) << 8) | buffer[1]) + 3;
                     }
 
                     break;
                 case Framing.SinceSSL3:
-                    if (dataSize < 5)
+                    if (buffer.Length < 5)
                     {
                         throw new System.IO.IOException(SR.net_ssl_io_frame);
                     }
 
-                    payloadSize = ((buffer[offset + 3] << 8) | buffer[offset + 4]) + 5;
-                    payloadSize -= dataSize;
+                    payloadSize = ((buffer[3] << 8) | buffer[4]) + 5;
                     break;
                 default:
                     break;
@@ -1159,6 +1214,7 @@ namespace System.Net.Security
 
             if (NetEventSource.IsEnabled)
                 NetEventSource.Exit(this, payloadSize);
+
             return payloadSize;
         }
     }