Add SslStream test for "unlocking" after failure (dotnet/corefx#36094)
authorStephen Toub <stoub@microsoft.com>
Sat, 16 Mar 2019 19:54:15 +0000 (15:54 -0400)
committerGitHub <noreply@github.com>
Sat, 16 Mar 2019 19:54:15 +0000 (15:54 -0400)
* Add SslStream test for "unlocking" after failure

* Address PR feedback

Commit migrated from https://github.com/dotnet/corefx/commit/6a26d45c3a0f3571f57637c2a995fb40336378f5

src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamStreamToStreamTest.cs

index 2712309..128ab4b 100644 (file)
@@ -23,9 +23,9 @@ namespace System.Net.Security.Tests
 
         protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream);
 
-        protected abstract Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+        protected abstract Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
 
-        protected abstract Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+        protected abstract Task WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
 
         [Fact]
         public async Task SslStream_StreamToStream_Authentication_Success()
@@ -85,6 +85,63 @@ namespace System.Net.Security.Tests
             }
         }
 
+        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework does not recover well from underlying stream failures")]
+        [Fact]
+        public async Task Read_CorrectlyUnlocksAfterFailure()
+        {
+            var network = new VirtualNetwork();
+            var clientStream = new ThrowingDelegatingStream(new VirtualNetworkStream(network, isServer: false));
+            using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate))
+            using (var serverSslStream = new SslStream(new VirtualNetworkStream(network, isServer: true)))
+            {
+                await DoHandshake(clientSslStream, serverSslStream);
+
+                // Throw an exception from the wrapped stream's read operation
+                clientStream.ExceptionToThrow = new FormatException();
+                IOException thrown = await Assert.ThrowsAsync<IOException>(() => ReadAsync(clientSslStream, new byte[1], 0, 1));
+                Assert.Same(clientStream.ExceptionToThrow, thrown.InnerException);
+                clientStream.ExceptionToThrow = null;
+
+                // Validate that the SslStream continues to be usable
+                for (byte b = 42; b < 52; b++) // arbitrary test values
+                {
+                    await WriteAsync(serverSslStream, new byte[1] { b }, 0, 1);
+                    byte[] buffer = new byte[1];
+                    Assert.Equal(1, await ReadAsync(clientSslStream, buffer, 0, 1));
+                    Assert.Equal(b, buffer[0]);
+                }
+            }
+        }
+
+        [ActiveIssue(36076)]
+        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework does not recover well from underlying stream failures")]
+        [Fact]
+        public async Task Write_CorrectlyUnlocksAfterFailure()
+        {
+            var network = new VirtualNetwork();
+            var clientStream = new ThrowingDelegatingStream(new VirtualNetworkStream(network, isServer: false));
+            using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate))
+            using (var serverSslStream = new SslStream(new VirtualNetworkStream(network, isServer: true)))
+            {
+                await DoHandshake(clientSslStream, serverSslStream);
+
+                // Throw an exception from the wrapped stream's write operation
+                clientStream.ExceptionToThrow = new FormatException();
+                IOException thrown = await Assert.ThrowsAsync<IOException>(() => WriteAsync(clientSslStream, new byte[1], 0, 1));
+                Assert.Same(clientStream.ExceptionToThrow, thrown.InnerException);
+                clientStream.ExceptionToThrow = null;
+
+                // Validate that the SslStream continues to be usable
+                for (byte b = 42; b < 52; b++)
+                {
+                    await WriteAsync(clientSslStream, new byte[1] { b }, 0, 1);
+                    byte[] buffer = new byte[1];
+                    Assert.Equal(1, await ReadAsync(serverSslStream, buffer, 0, 1));
+                    Assert.Equal(b, buffer[0]);
+                }
+            }
+        }
+
         [Fact]
         public async Task SslStream_StreamToStream_Successive_ClientWrite_WithZeroBytes_Success()
         {
@@ -97,7 +154,7 @@ namespace System.Net.Security.Tests
             using (var serverSslStream = new SslStream(serverStream))
             {
                 await DoHandshake(clientSslStream, serverSslStream);
-                
+
                 await WriteAsync(clientSslStream, Array.Empty<byte>(), 0, 0);
                 await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
 
@@ -181,7 +238,7 @@ namespace System.Net.Security.Tests
             using (var serverSslStream = new SslStream(serverStream))
             {
                 await DoHandshake(clientSslStream, serverSslStream);
-                                
+
                 await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length)
                     .TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
 
@@ -219,7 +276,7 @@ namespace System.Net.Security.Tests
             using (var serverSslStream = new SslStream(serverStream))
             {
                 await DoHandshake(clientSslStream, serverSslStream);
-                
+
                 for (int i = 0; i < 3; i++)
                 {
                     await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
@@ -242,7 +299,7 @@ namespace System.Net.Security.Tests
             using (var serverSslStream = new SslStream(serverStream))
             {
                 await DoHandshake(clientSslStream, serverSslStream);
-                
+
                 for (int i = 0; i < 3; i++)
                 {
                     await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false);
@@ -265,7 +322,7 @@ namespace System.Net.Security.Tests
             using (var serverSslStream = new SslStream(serverStream))
             {
                 await DoHandshake(clientSslStream, serverSslStream);
-                
+
                 var serverBuffer = new byte[1];
                 var tcs = new TaskCompletionSource<object>();
                 serverStream.OnRead += (buffer, offset, count) =>
@@ -517,6 +574,73 @@ namespace System.Net.Security.Tests
                 return false;
             }
         }
+
+        private sealed class ThrowingDelegatingStream : Stream
+        {
+            private readonly Stream _stream;
+
+            public ThrowingDelegatingStream(Stream stream) => _stream = stream;
+
+            public override bool CanRead => _stream.CanRead;
+            public override bool CanWrite => _stream.CanWrite;
+            public override bool CanSeek => _stream.CanSeek;
+            protected override void Dispose(bool disposing) => _stream.Dispose();
+            public override long Length => _stream.Length;
+            public override long Position { get => _stream.Position; set => _stream.Position = value; }
+            public override void Flush() => _stream.Flush();
+            public override long Seek(long offset, SeekOrigin origin) => _stream.Seek(offset, origin);
+            public override void SetLength(long value) => _stream.SetLength(value);
+
+            public Exception ExceptionToThrow { get; set; }
+
+            public override int Read(byte[] buffer, int offset, int count)
+            {
+                ThrowIfNecessary();
+                return _stream.Read(buffer, offset, count);
+            }
+
+            public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
+            {
+                ThrowIfNecessary();
+                return _stream.BeginRead(buffer, offset, count, callback, state);
+            }
+
+            public override int EndRead(IAsyncResult asyncResult) => _stream.EndRead(asyncResult);
+
+            public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            {
+                ThrowIfNecessary();
+                return _stream.ReadAsync(buffer, offset, count, cancellationToken);
+            }
+
+            public override void Write(byte[] buffer, int offset, int count)
+            {
+                ThrowIfNecessary();
+                _stream.Write(buffer, offset, count);
+            }
+
+            public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
+            {
+                ThrowIfNecessary();
+                return _stream.BeginWrite(buffer, offset, count, callback, state);
+            }
+
+            public override void EndWrite(IAsyncResult asyncResult) => _stream.EndWrite(asyncResult);
+
+            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+            {
+                ThrowIfNecessary();
+                return _stream.WriteAsync(buffer, offset, count, cancellationToken);
+            }
+
+            private void ThrowIfNecessary()
+            {
+                if (ExceptionToThrow != null)
+                {
+                    throw ExceptionToThrow;
+                }
+            }
+        }
     }
 
     public sealed class SslStreamStreamToStreamTest_Async : SslStreamStreamToStreamTest
@@ -531,10 +655,10 @@ namespace System.Net.Security.Tests
             }
         }
 
-        protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+        protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
             stream.ReadAsync(buffer, offset, count, cancellationToken);
 
-        protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+        protected override Task WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
             stream.WriteAsync(buffer, offset, count, cancellationToken);
     }
 
@@ -550,12 +674,12 @@ namespace System.Net.Security.Tests
             }
         }
 
-        protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+        protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
             cancellationToken.IsCancellationRequested ?
                 Task.FromCanceled<int>(cancellationToken) :
                 Task.Factory.FromAsync(stream.BeginRead, stream.EndRead, buffer, offset, count, null);
 
-        protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+        protected override Task WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
             cancellationToken.IsCancellationRequested ?
                 Task.FromCanceled<int>(cancellationToken) :
                 Task.Factory.FromAsync(stream.BeginWrite, stream.EndWrite, buffer, offset, count, null);
@@ -573,10 +697,10 @@ namespace System.Net.Security.Tests
             }
         }
 
-        protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+        protected override Task<int> ReadAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
             Task.Run(() => stream.Read(buffer, offset, count), cancellationToken);
 
-        protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+        protected override Task WriteAsync(Stream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
             Task.Run(() => stream.Write(buffer, offset, count), cancellationToken);
 
     }