From: Stephen Toub Date: Fri, 15 Mar 2019 19:59:31 +0000 (-0400) Subject: Fix SslStreamStreamToStreamTest to exercise correct overloads (dotnet/corefx#36065) X-Git-Tag: submit/tizen/20210909.063632~11031^2~2170 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f6ec54a1f77b0b893fd2c01b9d0bef28794e4441;p=platform%2Fupstream%2Fdotnet%2Fruntime.git Fix SslStreamStreamToStreamTest to exercise correct overloads (dotnet/corefx#36065) The SslStreamStreamToStreamTest is set up as a base class from which three test classes derive, one for each of Async, Begin/End, and Sync. But the base class isn't actually deferring to the derived types to customize most of the functionality being executed, namely read/write methods. This PR fixes that, so that the base class properly exercises the relevant methods, customized to the base type. Commit migrated from https://github.com/dotnet/corefx/commit/6a9fe42b8d236f9cb0e5cdede2f06dc7537dc933 --- diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/LoggingTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/LoggingTest.cs index 7d5f970..31fedb4 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/LoggingTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/LoggingTest.cs @@ -40,7 +40,7 @@ namespace System.Net.Security.Tests // Invoke tests that'll cause some events to be generated var test = new SslStreamStreamToStreamTest_Async(); test.SslStream_StreamToStream_Authentication_Success().GetAwaiter().GetResult(); - test.SslStream_StreamToStream_Successive_ClientWrite_Sync_Success().GetAwaiter().GetResult(); + test.SslStream_StreamToStream_Successive_ClientWrite_Success().GetAwaiter().GetResult(); }); Assert.DoesNotContain(events, ev => ev.EventId == 0); // errors from the EventSource itself Assert.InRange(events.Count, 1, int.MaxValue); diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs index 160765c..a8730a5 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamNetworkStreamTest.cs @@ -3,11 +3,9 @@ // See the LICENSE file in the project root for more information. using System.Net.Sockets; -using System.Net.Test.Common; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; -using System.Threading; using System.Threading.Tasks; using Xunit; diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs index 2e578e6..2712309 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs @@ -8,6 +8,7 @@ using System.Net.Test.Common; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -22,6 +23,10 @@ namespace System.Net.Security.Tests protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream); + protected abstract Task ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default); + + protected abstract Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default); + [Fact] public async Task SslStream_StreamToStream_Authentication_Success() { @@ -81,41 +86,6 @@ namespace System.Net.Security.Tests } [Fact] - public async Task SslStream_StreamToStream_Successive_ClientWrite_Sync_Success() - { - byte[] recvBuf = new byte[_sampleMsg.Length]; - VirtualNetwork network = new VirtualNetwork(); - - using (var clientStream = new VirtualNetworkStream(network, isServer: false)) - using (var serverStream = new VirtualNetworkStream(network, isServer: true)) - using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate)) - using (var serverSslStream = new SslStream(serverStream)) - { - await DoHandshake(clientSslStream, serverSslStream); - - clientSslStream.Write(_sampleMsg); - - int bytesRead = 0; - while (bytesRead < _sampleMsg.Length) - { - bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead); - } - - Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected."); - - clientSslStream.Write(_sampleMsg); - - bytesRead = 0; - while (bytesRead < _sampleMsg.Length) - { - bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead); - } - - Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected."); - } - } - - [Fact] public async Task SslStream_StreamToStream_Successive_ClientWrite_WithZeroBytes_Success() { byte[] recvBuf = new byte[_sampleMsg.Length]; @@ -128,26 +98,24 @@ namespace System.Net.Security.Tests { await DoHandshake(clientSslStream, serverSslStream); - clientSslStream.Write(Array.Empty()); - await clientSslStream.WriteAsync(Array.Empty(), 0, 0); - clientSslStream.Write(_sampleMsg); + await WriteAsync(clientSslStream, Array.Empty(), 0, 0); + await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length); int bytesRead = 0; while (bytesRead < _sampleMsg.Length) { - bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead); + bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead); } Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected."); - clientSslStream.Write(_sampleMsg); - await clientSslStream.WriteAsync(Array.Empty(), 0, 0); - clientSslStream.Write(Array.Empty()); + await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length); + await WriteAsync(clientSslStream, Array.Empty(), 0, 0); bytesRead = 0; while (bytesRead < _sampleMsg.Length) { - bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead); + bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead); } Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected."); } @@ -156,7 +124,7 @@ namespace System.Net.Security.Tests [Theory] [InlineData(false)] [InlineData(true)] - public async Task SslStream_StreamToStream_LargeWrites_Sync_Success(bool randomizedData) + public async Task SslStream_StreamToStream_LargeWrites_Success(bool randomizedData) { VirtualNetwork network = new VirtualNetwork(); @@ -182,10 +150,10 @@ namespace System.Net.Security.Tests byte[] receivedLargeMsg = new byte[largeMsg.Length]; // First do a large write and read blocks at a time - clientSslStream.Write(largeMsg); + await WriteAsync(clientSslStream, largeMsg, 0, largeMsg.Length); int bytesRead = 0, totalRead = 0; while (totalRead < largeMsg.Length && - (bytesRead = serverSslStream.Read(receivedLargeMsg, totalRead, receivedLargeMsg.Length - totalRead)) != 0) + (bytesRead = await ReadAsync(serverSslStream, receivedLargeMsg, totalRead, receivedLargeMsg.Length - totalRead)) != 0) { totalRead += bytesRead; } @@ -193,7 +161,7 @@ namespace System.Net.Security.Tests Assert.Equal(largeMsg, receivedLargeMsg); // Then write again and read bytes at a time - clientSslStream.Write(largeMsg); + await WriteAsync(clientSslStream, largeMsg, 0, largeMsg.Length); foreach (byte b in largeMsg) { Assert.Equal(b, serverSslStream.ReadByte()); @@ -202,7 +170,7 @@ namespace System.Net.Security.Tests } [Fact] - public async Task SslStream_StreamToStream_Successive_ClientWrite_Async_Success() + public async Task SslStream_StreamToStream_Successive_ClientWrite_Success() { byte[] recvBuf = new byte[_sampleMsg.Length]; VirtualNetwork network = new VirtualNetwork(); @@ -214,25 +182,25 @@ namespace System.Net.Security.Tests { await DoHandshake(clientSslStream, serverSslStream); - await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length) + await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length) .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); int bytesRead = 0; while (bytesRead < _sampleMsg.Length) { - bytesRead += await serverSslStream.ReadAsync(recvBuf, bytesRead, _sampleMsg.Length - bytesRead) + bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead) .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); } Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected."); - await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length) + await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length) .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); bytesRead = 0; while (bytesRead < _sampleMsg.Length) { - bytesRead += await serverSslStream.ReadAsync(recvBuf, bytesRead, _sampleMsg.Length - bytesRead) + bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead) .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); } @@ -254,7 +222,7 @@ namespace System.Net.Security.Tests for (int i = 0; i < 3; i++) { - clientSslStream.Write(_sampleMsg); + await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length); foreach (byte b in _sampleMsg) { Assert.Equal(b, serverSslStream.ReadByte()); @@ -277,7 +245,7 @@ namespace System.Net.Security.Tests for (int i = 0; i < 3; i++) { - await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false); + await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false); foreach (byte b in _sampleMsg) { Assert.Equal(b, serverSslStream.ReadByte()); @@ -304,7 +272,7 @@ namespace System.Net.Security.Tests { tcs.TrySetResult(null); }; - Task readTask = serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length); + Task readTask = ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length); // Since the sequence of calls that ends in serverStream.Read() is sync, by now // the read task will have acquired the semaphore shared by Stream.BeginReadInternal() @@ -313,36 +281,32 @@ namespace System.Net.Security.Tests await tcs.Task.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); // Should not hang - await serverSslStream.WriteAsync(new byte[] { 1 }, 0, 1) + await WriteAsync(serverSslStream, new byte[] { 1 }, 0, 1) .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); // Read in client var clientBuffer = new byte[1]; - await clientSslStream.ReadAsync(clientBuffer, 0, clientBuffer.Length); + await ReadAsync(clientSslStream, clientBuffer, 0, clientBuffer.Length); Assert.Equal(1, clientBuffer[0]); // Complete server read task - await clientSslStream.WriteAsync(new byte[] { 2 }, 0, 1); + await WriteAsync(clientSslStream, new byte[] { 2 }, 0, 1); await readTask; Assert.Equal(2, serverBuffer[0]); } } - [OuterLoop("Executes for several seconds")] [Fact] public async Task SslStream_ConcurrentBidirectionalReadsWrites_Success() { VirtualNetwork network = new VirtualNetwork(); - - using (var clientStream = new VirtualNetworkStream(network, isServer: false)) - using (var serverStream = new NotifyReadVirtualNetworkStream(network, isServer: true)) - using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate)) - using (var serverSslStream = new SslStream(serverStream)) + using (var clientSslStream = new SslStream(new VirtualNetworkStream(network, isServer: false), false, AllowAnyServerCertificate)) + using (var serverSslStream = new SslStream(new NotifyReadVirtualNetworkStream(network, isServer: true))) { await DoHandshake(clientSslStream, serverSslStream); const int BytesPerSend = 100; - DateTime endTime = DateTime.UtcNow + TimeSpan.FromSeconds(5); + DateTime endTime = DateTime.UtcNow + TimeSpan.FromSeconds(3); await new Task[] { Task.Run(async delegate @@ -350,9 +314,9 @@ namespace System.Net.Security.Tests var buffer = new byte[BytesPerSend]; while (DateTime.UtcNow < endTime) { - await clientStream.WriteAsync(buffer, 0, buffer.Length); + await WriteAsync(clientSslStream, buffer, 0, buffer.Length); int received = 0, bytesRead = 0; - while (received < BytesPerSend && (bytesRead = await serverStream.ReadAsync(buffer, 0, buffer.Length)) != 0) + while (received < BytesPerSend && (bytesRead = await ReadAsync(serverSslStream, buffer, 0, buffer.Length)) != 0) { received += bytesRead; } @@ -364,9 +328,9 @@ namespace System.Net.Security.Tests var buffer = new byte[BytesPerSend]; while (DateTime.UtcNow < endTime) { - await serverStream.WriteAsync(buffer, 0, buffer.Length); + await WriteAsync(serverSslStream, buffer, 0, buffer.Length); int received = 0, bytesRead = 0; - while (received < BytesPerSend && (bytesRead = await clientStream.ReadAsync(buffer, 0, buffer.Length)) != 0) + while (received < BytesPerSend && (bytesRead = await ReadAsync(clientSslStream, buffer, 0, buffer.Length)) != 0) { received += bytesRead; } @@ -393,8 +357,8 @@ namespace System.Net.Security.Tests await DoHandshake(clientSslStream, serverSslStream); var serverBuffer = new byte[1]; - Task serverReadTask = serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length); - await serverSslStream.WriteAsync(new byte[] { 1 }, 0, 1) + Task serverReadTask = ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length); + await WriteAsync(serverSslStream, new byte[] { 1 }, 0, 1) .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds); // Shouldn't throw, the context is diposed now. @@ -403,12 +367,16 @@ namespace System.Net.Security.Tests // Read in client var clientBuffer = new byte[1]; - await clientSslStream.ReadAsync(clientBuffer, 0, clientBuffer.Length); + await ReadAsync(clientSslStream, clientBuffer, 0, clientBuffer.Length); Assert.Equal(1, clientBuffer[0]); - await clientSslStream.WriteAsync(new byte[] { 2 }, 0, 1); + await WriteAsync(clientSslStream, new byte[] { 2 }, 0, 1); - if (PlatformDetection.IsFullFramework) + // We're inconsistent as to whether the ObjectDisposedException is thrown directly + // or wrapped in an IOException. For sync operations, it's never wrapped. For + // Begin/End, it's always wrapped, And for Async, it's only wrapped on netfx. + if (this is SslStreamStreamToStreamTest_BeginEnd || + (this is SslStreamStreamToStreamTest_Async && PlatformDetection.IsFullFramework)) { await Assert.ThrowsAsync(() => serverReadTask); } @@ -418,11 +386,11 @@ namespace System.Net.Security.Tests Assert.IsType(serverException.InnerException); } - await Assert.ThrowsAsync(() => serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length)); + await Assert.ThrowsAsync(() => ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length)); // Now, there is no pending read, so the internal buffer will be returned to ArrayPool. serverSslStream.Dispose(); - await Assert.ThrowsAsync(() => serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length)); + await Assert.ThrowsAsync(() => ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length)); } } @@ -469,6 +437,7 @@ namespace System.Net.Security.Tests canWriteFunc: () => true, canReadFunc: () => true, writeFunc: (buffer, offset, count) => serverNetworkStream.Write(buffer, offset, count), + writeAsyncFunc: (buffer, offset, count, token) => serverNetworkStream.WriteAsync(buffer, offset, count, token), readFunc: (buffer, offset, count) => { // Do normal reads as requested until the read mode is set @@ -487,16 +456,34 @@ namespace System.Net.Security.Tests { return 0; } + }, + readAsyncFunc: (buffer, offset, count, token) => + { + // Do normal reads as requested until the read mode is set + // to 1. Then do a single read of only 10 bytes to read only + // part of the message, and subsequently return EOF. + if (readMode == 0) + { + return serverNetworkStream.ReadAsync(buffer, offset, count); + } + else if (readMode == 1) + { + readMode = 2; + return serverNetworkStream.ReadAsync(buffer, offset, 10); // read at least header but less than full frame + } + else + { + return Task.FromResult(0); + } }); - using (var clientSslStream = new SslStream(clientNetworkStream, false, AllowAnyServerCertificate)) using (var serverSslStream = new SslStream(serverWrappedNetworkStream)) { await DoHandshake(clientSslStream, serverSslStream); - await clientSslStream.WriteAsync(new byte[20], 0, 20); + await WriteAsync(clientSslStream, new byte[20], 0, 20); readMode = 1; - await Assert.ThrowsAsync(() => serverSslStream.ReadAsync(new byte[1], 0, 1)); + await Assert.ThrowsAsync(() => ReadAsync(serverSslStream, new byte[1], 0, 1)); } } } @@ -543,6 +530,12 @@ namespace System.Net.Security.Tests await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); } } + + protected override Task ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) => + stream.ReadAsync(buffer, offset, count, cancellationToken); + + protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) => + stream.WriteAsync(buffer, offset, count, cancellationToken); } public sealed class SslStreamStreamToStreamTest_BeginEnd : SslStreamStreamToStreamTest @@ -556,6 +549,16 @@ namespace System.Net.Security.Tests await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); } } + + protected override Task ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) => + cancellationToken.IsCancellationRequested ? + Task.FromCanceled(cancellationToken) : + Task.Factory.FromAsync(stream.BeginRead, stream.EndRead, buffer, offset, count, null); + + protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) => + cancellationToken.IsCancellationRequested ? + Task.FromCanceled(cancellationToken) : + Task.Factory.FromAsync(stream.BeginWrite, stream.EndWrite, buffer, offset, count, null); } public sealed class SslStreamStreamToStreamTest_Sync : SslStreamStreamToStreamTest @@ -569,5 +572,12 @@ namespace System.Net.Security.Tests await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2); } } + + protected override Task ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) => + Task.Run(() => stream.Read(buffer, offset, count), cancellationToken); + + protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) => + Task.Run(() => stream.Write(buffer, offset, count), cancellationToken); + } }