Fix NegotiateStream handling of EOF (#43739)
authorStephen Toub <stoub@microsoft.com>
Fri, 23 Oct 2020 13:10:15 +0000 (09:10 -0400)
committerGitHub <noreply@github.com>
Fri, 23 Oct 2020 13:10:15 +0000 (09:10 -0400)
In my refactoring of NegotiateStream to use async/await, I broke its handling of EOF, with it throwing an exception instead of returning 0.  This fixes it to correctly handle EOF.

src/libraries/System.Net.Security/src/System/Net/Security/NegotiateStream.cs
src/libraries/System.Net.Security/tests/FunctionalTests/NegotiateStreamStreamToStreamTest.cs

index 45ea9a0..2cab7ef 100644 (file)
@@ -362,7 +362,7 @@ namespace System.Net.Security
 
                 while (true)
                 {
-                    int readBytes = await ReadAllAsync(adapter, _readHeader).ConfigureAwait(false);
+                    int readBytes = await ReadAllAsync(adapter, _readHeader, allowZeroRead: true).ConfigureAwait(false);
                     if (readBytes == 0)
                     {
                         return 0;
@@ -386,12 +386,8 @@ namespace System.Net.Security
                     {
                         _readBuffer = new byte[readBytes];
                     }
-                    readBytes = await ReadAllAsync(adapter, new Memory<byte>(_readBuffer, 0, readBytes)).ConfigureAwait(false);
-                    if (readBytes == 0)
-                    {
-                        // We already checked that the frame body is bigger than 0 bytes. Hence, this is an EOF.
-                        throw new IOException(SR.net_io_eof);
-                    }
+
+                    readBytes = await ReadAllAsync(adapter, new Memory<byte>(_readBuffer, 0, readBytes), allowZeroRead: false).ConfigureAwait(false);
 
                     // Decrypt into internal buffer, change "readBytes" to count now _Decrypted Bytes_
                     // Decrypted data start from zero offset, the size can be shrunk after decryption.
@@ -423,16 +419,16 @@ namespace System.Net.Security
                 _readInProgress = 0;
             }
 
-            static async ValueTask<int> ReadAllAsync(TAdapter adapter, Memory<byte> buffer)
+            static async ValueTask<int> ReadAllAsync(TAdapter adapter, Memory<byte> buffer, bool allowZeroRead)
             {
-                int length = buffer.Length;
+                int read = 0;
 
                 do
                 {
                     int bytes = await adapter.ReadAsync(buffer).ConfigureAwait(false);
                     if (bytes == 0)
                     {
-                        if (!buffer.IsEmpty)
+                        if (read != 0 || !allowZeroRead)
                         {
                             throw new IOException(SR.net_io_eof);
                         }
@@ -440,10 +436,11 @@ namespace System.Net.Security
                     }
 
                     buffer = buffer.Slice(bytes);
+                    read += bytes;
                 }
                 while (!buffer.IsEmpty);
 
-                return length;
+                return read;
             }
         }
 
index 3ee034a..c8c10e6 100644 (file)
@@ -368,6 +368,29 @@ namespace System.Net.Security.Tests
                 await Assert.ThrowsAnyAsync<OperationCanceledException>(() => t);
             }
         }
+
+        [ConditionalFact(nameof(IsNtlmInstalled))]
+        public async Task NegotiateStream_ReadToEof_Returns0()
+        {
+            (Stream stream1, Stream stream2) = TestHelper.GetConnectedStreams();
+            using (var client = new NegotiateStream(stream1))
+            using (var server = new NegotiateStream(stream2))
+            {
+                await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
+                    AuthenticateAsClientAsync(client, CredentialCache.DefaultNetworkCredentials, string.Empty),
+                    AuthenticateAsServerAsync(server));
+
+                client.Write(Encoding.UTF8.GetBytes("hello"));
+                client.Dispose();
+
+                Assert.Equal('h', server.ReadByte());
+                Assert.Equal('e', server.ReadByte());
+                Assert.Equal('l', server.ReadByte());
+                Assert.Equal('l', server.ReadByte());
+                Assert.Equal('o', server.ReadByte());
+                Assert.Equal(-1, server.ReadByte());
+            }
+        }
     }
 
     public sealed class NegotiateStreamStreamToStreamTest_Async_Array : NegotiateStreamStreamToStreamTest