Fix SslStreamStreamToStreamTest to exercise correct overloads (dotnet/corefx#36065)
authorStephen Toub <stoub@microsoft.com>
Fri, 15 Mar 2019 19:59:31 +0000 (15:59 -0400)
committerGitHub <noreply@github.com>
Fri, 15 Mar 2019 19:59:31 +0000 (15:59 -0400)
The SslStreamStreamToStreamTest is set up as a base class from which three test classes derive, one for each of Async, Begin/End, and Sync.  But the base class isn't actually deferring to the derived types to customize most of the functionality being executed, namely read/write methods.  This PR fixes that, so that the base class properly exercises the relevant methods, customized to the base type.

Commit migrated from https://github.com/dotnet/corefx/commit/6a9fe42b8d236f9cb0e5cdede2f06dc7537dc933

src/libraries/System.Net.Security/tests/FunctionalTests/LoggingTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs
src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs

index 7d5f970..31fedb4 100644 (file)
@@ -40,7 +40,7 @@ namespace System.Net.Security.Tests
                         // Invoke tests that'll cause some events to be generated
                         var test = new SslStreamStreamToStreamTest_Async();
                         test.SslStream_StreamToStream_Authentication_Success().GetAwaiter().GetResult();
-                        test.SslStream_StreamToStream_Successive_ClientWrite_Sync_Success().GetAwaiter().GetResult();
+                        test.SslStream_StreamToStream_Successive_ClientWrite_Success().GetAwaiter().GetResult();
                     });
                     Assert.DoesNotContain(events, ev => ev.EventId == 0); // errors from the EventSource itself
                     Assert.InRange(events.Count, 1, int.MaxValue);
index 160765c..a8730a5 100644 (file)
@@ -3,11 +3,9 @@
 // See the LICENSE file in the project root for more information.
 
 using System.Net.Sockets;
-using System.Net.Test.Common;
 using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using System.Text;
-using System.Threading;
 using System.Threading.Tasks;
 
 using Xunit;
index 2e578e6..2712309 100644 (file)
@@ -8,6 +8,7 @@ using System.Net.Test.Common;
 using System.Security.Authentication;
 using System.Security.Cryptography.X509Certificates;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 
 using Xunit;
@@ -22,6 +23,10 @@ namespace System.Net.Security.Tests
 
         protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream);
 
+        protected abstract Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+
+        protected abstract Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+
         [Fact]
         public async Task SslStream_StreamToStream_Authentication_Success()
         {
@@ -81,41 +86,6 @@ namespace System.Net.Security.Tests
         }
 
         [Fact]
-        public async Task SslStream_StreamToStream_Successive_ClientWrite_Sync_Success()
-        {
-            byte[] recvBuf = new byte[_sampleMsg.Length];
-            VirtualNetwork network = new VirtualNetwork();
-
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new VirtualNetworkStream(network, isServer: true))
-            using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate))
-            using (var serverSslStream = new SslStream(serverStream))
-            {
-                await DoHandshake(clientSslStream, serverSslStream);
-                
-                clientSslStream.Write(_sampleMsg);
-
-                int bytesRead = 0;
-                while (bytesRead < _sampleMsg.Length)
-                {
-                    bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
-                }
-
-                Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected.");
-
-                clientSslStream.Write(_sampleMsg);
-
-                bytesRead = 0;
-                while (bytesRead < _sampleMsg.Length)
-                {
-                    bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
-                }
-
-                Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected.");
-            }
-        }
-
-        [Fact]
         public async Task SslStream_StreamToStream_Successive_ClientWrite_WithZeroBytes_Success()
         {
             byte[] recvBuf = new byte[_sampleMsg.Length];
@@ -128,26 +98,24 @@ namespace System.Net.Security.Tests
             {
                 await DoHandshake(clientSslStream, serverSslStream);
                 
-                clientSslStream.Write(Array.Empty<byte>());
-                await clientSslStream.WriteAsync(Array.Empty<byte>(), 0, 0);
-                clientSslStream.Write(_sampleMsg);
+                await WriteAsync(clientSslStream, Array.Empty<byte>(), 0, 0);
+                await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
 
                 int bytesRead = 0;
                 while (bytesRead < _sampleMsg.Length)
                 {
-                    bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
+                    bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
                 }
 
                 Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected.");
 
-                clientSslStream.Write(_sampleMsg);
-                await clientSslStream.WriteAsync(Array.Empty<byte>(), 0, 0);
-                clientSslStream.Write(Array.Empty<byte>());
+                await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
+                await WriteAsync(clientSslStream, Array.Empty<byte>(), 0, 0);
 
                 bytesRead = 0;
                 while (bytesRead < _sampleMsg.Length)
                 {
-                    bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
+                    bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
                 }
                 Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected.");
             }
@@ -156,7 +124,7 @@ namespace System.Net.Security.Tests
         [Theory]
         [InlineData(false)]
         [InlineData(true)]
-        public async Task SslStream_StreamToStream_LargeWrites_Sync_Success(bool randomizedData)
+        public async Task SslStream_StreamToStream_LargeWrites_Success(bool randomizedData)
         {
             VirtualNetwork network = new VirtualNetwork();
 
@@ -182,10 +150,10 @@ namespace System.Net.Security.Tests
                 byte[] receivedLargeMsg = new byte[largeMsg.Length];
 
                 // First do a large write and read blocks at a time
-                clientSslStream.Write(largeMsg);
+                await WriteAsync(clientSslStream, largeMsg, 0, largeMsg.Length);
                 int bytesRead = 0, totalRead = 0;
                 while (totalRead < largeMsg.Length &&
-                    (bytesRead = serverSslStream.Read(receivedLargeMsg, totalRead, receivedLargeMsg.Length - totalRead)) != 0)
+                    (bytesRead = await ReadAsync(serverSslStream, receivedLargeMsg, totalRead, receivedLargeMsg.Length - totalRead)) != 0)
                 {
                     totalRead += bytesRead;
                 }
@@ -193,7 +161,7 @@ namespace System.Net.Security.Tests
                 Assert.Equal(largeMsg, receivedLargeMsg);
 
                 // Then write again and read bytes at a time
-                clientSslStream.Write(largeMsg);
+                await WriteAsync(clientSslStream, largeMsg, 0, largeMsg.Length);
                 foreach (byte b in largeMsg)
                 {
                     Assert.Equal(b, serverSslStream.ReadByte());
@@ -202,7 +170,7 @@ namespace System.Net.Security.Tests
         }
 
         [Fact]
-        public async Task SslStream_StreamToStream_Successive_ClientWrite_Async_Success()
+        public async Task SslStream_StreamToStream_Successive_ClientWrite_Success()
         {
             byte[] recvBuf = new byte[_sampleMsg.Length];
             VirtualNetwork network = new VirtualNetwork();
@@ -214,25 +182,25 @@ namespace System.Net.Security.Tests
             {
                 await DoHandshake(clientSslStream, serverSslStream);
                                 
-                await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length)
+                await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length)
                     .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
 
                 int bytesRead = 0;
                 while (bytesRead < _sampleMsg.Length)
                 {
-                    bytesRead += await serverSslStream.ReadAsync(recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
+                    bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
                         .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
                 }
 
                 Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected.");
 
-                await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length)
+                await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length)
                     .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
 
                 bytesRead = 0;
                 while (bytesRead < _sampleMsg.Length)
                 {
-                    bytesRead += await serverSslStream.ReadAsync(recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
+                    bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
                         .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
                 }
 
@@ -254,7 +222,7 @@ namespace System.Net.Security.Tests
                 
                 for (int i = 0; i < 3; i++)
                 {
-                    clientSslStream.Write(_sampleMsg);
+                    await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
                     foreach (byte b in _sampleMsg)
                     {
                         Assert.Equal(b, serverSslStream.ReadByte());
@@ -277,7 +245,7 @@ namespace System.Net.Security.Tests
                 
                 for (int i = 0; i < 3; i++)
                 {
-                    await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false);
+                    await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false);
                     foreach (byte b in _sampleMsg)
                     {
                         Assert.Equal(b, serverSslStream.ReadByte());
@@ -304,7 +272,7 @@ namespace System.Net.Security.Tests
                 {
                     tcs.TrySetResult(null);
                 };
-                Task readTask = serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length);
+                Task readTask = ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length);
 
                 // Since the sequence of calls that ends in serverStream.Read() is sync, by now
                 // the read task will have acquired the semaphore shared by Stream.BeginReadInternal()
@@ -313,36 +281,32 @@ namespace System.Net.Security.Tests
                 await tcs.Task.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
 
                 // Should not hang
-                await serverSslStream.WriteAsync(new byte[] { 1 }, 0, 1)
+                await WriteAsync(serverSslStream, new byte[] { 1 }, 0, 1)
                     .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
 
                 // Read in client
                 var clientBuffer = new byte[1];
-                await clientSslStream.ReadAsync(clientBuffer, 0, clientBuffer.Length);
+                await ReadAsync(clientSslStream, clientBuffer, 0, clientBuffer.Length);
                 Assert.Equal(1, clientBuffer[0]);
 
                 // Complete server read task
-                await clientSslStream.WriteAsync(new byte[] { 2 }, 0, 1);
+                await WriteAsync(clientSslStream, new byte[] { 2 }, 0, 1);
                 await readTask;
                 Assert.Equal(2, serverBuffer[0]);
             }
         }
 
-        [OuterLoop("Executes for several seconds")]
         [Fact]
         public async Task SslStream_ConcurrentBidirectionalReadsWrites_Success()
         {
             VirtualNetwork network = new VirtualNetwork();
-
-            using (var clientStream = new VirtualNetworkStream(network, isServer: false))
-            using (var serverStream = new NotifyReadVirtualNetworkStream(network, isServer: true))
-            using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate))
-            using (var serverSslStream = new SslStream(serverStream))
+            using (var clientSslStream = new SslStream(new VirtualNetworkStream(network, isServer: false), false, AllowAnyServerCertificate))
+            using (var serverSslStream = new SslStream(new NotifyReadVirtualNetworkStream(network, isServer: true)))
             {
                 await DoHandshake(clientSslStream, serverSslStream);
 
                 const int BytesPerSend = 100;
-                DateTime endTime = DateTime.UtcNow + TimeSpan.FromSeconds(5);
+                DateTime endTime = DateTime.UtcNow + TimeSpan.FromSeconds(3);
                 await new Task[]
                 {
                     Task.Run(async delegate
@@ -350,9 +314,9 @@ namespace System.Net.Security.Tests
                         var buffer = new byte[BytesPerSend];
                         while (DateTime.UtcNow < endTime)
                         {
-                            await clientStream.WriteAsync(buffer, 0, buffer.Length);
+                            await WriteAsync(clientSslStream, buffer, 0, buffer.Length);
                             int received = 0, bytesRead = 0;
-                            while (received < BytesPerSend && (bytesRead = await serverStream.ReadAsync(buffer, 0, buffer.Length)) != 0)
+                            while (received < BytesPerSend && (bytesRead = await ReadAsync(serverSslStream, buffer, 0, buffer.Length)) != 0)
                             {
                                 received += bytesRead;
                             }
@@ -364,9 +328,9 @@ namespace System.Net.Security.Tests
                         var buffer = new byte[BytesPerSend];
                         while (DateTime.UtcNow < endTime)
                         {
-                            await serverStream.WriteAsync(buffer, 0, buffer.Length);
+                            await WriteAsync(serverSslStream, buffer, 0, buffer.Length);
                             int received = 0, bytesRead = 0;
-                            while (received < BytesPerSend && (bytesRead = await clientStream.ReadAsync(buffer, 0, buffer.Length)) != 0)
+                            while (received < BytesPerSend && (bytesRead = await ReadAsync(clientSslStream, buffer, 0, buffer.Length)) != 0)
                             {
                                 received += bytesRead;
                             }
@@ -393,8 +357,8 @@ namespace System.Net.Security.Tests
                 await DoHandshake(clientSslStream, serverSslStream);
 
                 var serverBuffer = new byte[1];
-                Task serverReadTask = serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length);
-                await serverSslStream.WriteAsync(new byte[] { 1 }, 0, 1)
+                Task serverReadTask = ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length);
+                await WriteAsync(serverSslStream, new byte[] { 1 }, 0, 1)
                     .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
 
                 // Shouldn't throw, the context is diposed now.
@@ -403,12 +367,16 @@ namespace System.Net.Security.Tests
 
                 // Read in client
                 var clientBuffer = new byte[1];
-                await clientSslStream.ReadAsync(clientBuffer, 0, clientBuffer.Length);
+                await ReadAsync(clientSslStream, clientBuffer, 0, clientBuffer.Length);
                 Assert.Equal(1, clientBuffer[0]);
 
-                await clientSslStream.WriteAsync(new byte[] { 2 }, 0, 1);
+                await WriteAsync(clientSslStream, new byte[] { 2 }, 0, 1);
 
-                if (PlatformDetection.IsFullFramework)
+                // We're inconsistent as to whether the ObjectDisposedException is thrown directly
+                // or wrapped in an IOException.  For sync operations, it's never wrapped.  For
+                // Begin/End, it's always wrapped,  And for Async, it's only wrapped on netfx.
+                if (this is SslStreamStreamToStreamTest_BeginEnd ||
+                    (this is SslStreamStreamToStreamTest_Async && PlatformDetection.IsFullFramework))
                 {
                     await Assert.ThrowsAsync<ObjectDisposedException>(() => serverReadTask);
                 }
@@ -418,11 +386,11 @@ namespace System.Net.Security.Tests
                     Assert.IsType<ObjectDisposedException>(serverException.InnerException);
                 }
 
-                await Assert.ThrowsAsync<ObjectDisposedException>(() => serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length));
+                await Assert.ThrowsAsync<ObjectDisposedException>(() => ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length));
 
                 // Now, there is no pending read, so the internal buffer will be returned to ArrayPool.
                 serverSslStream.Dispose();
-                await Assert.ThrowsAsync<ObjectDisposedException>(() => serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length));
+                await Assert.ThrowsAsync<ObjectDisposedException>(() => ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length));
             }
         }
 
@@ -469,6 +437,7 @@ namespace System.Net.Security.Tests
                     canWriteFunc: () => true,
                     canReadFunc: () => true,
                     writeFunc: (buffer, offset, count) => serverNetworkStream.Write(buffer, offset, count),
+                    writeAsyncFunc: (buffer, offset, count, token) => serverNetworkStream.WriteAsync(buffer, offset, count, token),
                     readFunc: (buffer, offset, count) =>
                     {
                         // Do normal reads as requested until the read mode is set
@@ -487,16 +456,34 @@ namespace System.Net.Security.Tests
                         {
                             return 0;
                         }
+                    },
+                    readAsyncFunc: (buffer, offset, count, token) =>
+                    {
+                        // Do normal reads as requested until the read mode is set
+                        // to 1.  Then do a single read of only 10 bytes to read only
+                        // part of the message, and subsequently return EOF.
+                        if (readMode == 0)
+                        {
+                            return serverNetworkStream.ReadAsync(buffer, offset, count);
+                        }
+                        else if (readMode == 1)
+                        {
+                            readMode = 2;
+                            return serverNetworkStream.ReadAsync(buffer, offset, 10); // read at least header but less than full frame
+                        }
+                        else
+                        {
+                            return Task.FromResult(0);
+                        }
                     });
 
-
                 using (var clientSslStream = new SslStream(clientNetworkStream, false, AllowAnyServerCertificate))
                 using (var serverSslStream = new SslStream(serverWrappedNetworkStream))
                 {
                     await DoHandshake(clientSslStream, serverSslStream);
-                    await clientSslStream.WriteAsync(new byte[20], 0, 20);
+                    await WriteAsync(clientSslStream, new byte[20], 0, 20);
                     readMode = 1;
-                    await Assert.ThrowsAsync<IOException>(() => serverSslStream.ReadAsync(new byte[1], 0, 1));
+                    await Assert.ThrowsAsync<IOException>(() => ReadAsync(serverSslStream, new byte[1], 0, 1));
                 }
             }
         }
@@ -543,6 +530,12 @@ namespace System.Net.Security.Tests
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
             }
         }
+
+        protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+            stream.ReadAsync(buffer, offset, count, cancellationToken);
+
+        protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+            stream.WriteAsync(buffer, offset, count, cancellationToken);
     }
 
     public sealed class SslStreamStreamToStreamTest_BeginEnd : SslStreamStreamToStreamTest
@@ -556,6 +549,16 @@ namespace System.Net.Security.Tests
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
             }
         }
+
+        protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+            cancellationToken.IsCancellationRequested ?
+                Task.FromCanceled<int>(cancellationToken) :
+                Task.Factory.FromAsync(stream.BeginRead, stream.EndRead, buffer, offset, count, null);
+
+        protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+            cancellationToken.IsCancellationRequested ?
+                Task.FromCanceled<int>(cancellationToken) :
+                Task.Factory.FromAsync(stream.BeginWrite, stream.EndWrite, buffer, offset, count, null);
     }
 
     public sealed class SslStreamStreamToStreamTest_Sync : SslStreamStreamToStreamTest
@@ -569,5 +572,12 @@ namespace System.Net.Security.Tests
                 await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
             }
         }
+
+        protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+            Task.Run(() => stream.Read(buffer, offset, count), cancellationToken);
+
+        protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+            Task.Run(() => stream.Write(buffer, offset, count), cancellationToken);
+
     }
 }