using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
+using System.Threading;
using System.Threading.Tasks;
using Xunit;
protected abstract Task DoHandshake(SslStream clientSslStream, SslStream serverSslStream);
+ protected abstract Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+
+ protected abstract Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default);
+
[Fact]
public async Task SslStream_StreamToStream_Authentication_Success()
{
}
[Fact]
- public async Task SslStream_StreamToStream_Successive_ClientWrite_Sync_Success()
- {
- byte[] recvBuf = new byte[_sampleMsg.Length];
- VirtualNetwork network = new VirtualNetwork();
-
- using (var clientStream = new VirtualNetworkStream(network, isServer: false))
- using (var serverStream = new VirtualNetworkStream(network, isServer: true))
- using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate))
- using (var serverSslStream = new SslStream(serverStream))
- {
- await DoHandshake(clientSslStream, serverSslStream);
-
- clientSslStream.Write(_sampleMsg);
-
- int bytesRead = 0;
- while (bytesRead < _sampleMsg.Length)
- {
- bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
- }
-
- Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected.");
-
- clientSslStream.Write(_sampleMsg);
-
- bytesRead = 0;
- while (bytesRead < _sampleMsg.Length)
- {
- bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
- }
-
- Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected.");
- }
- }
-
- [Fact]
public async Task SslStream_StreamToStream_Successive_ClientWrite_WithZeroBytes_Success()
{
byte[] recvBuf = new byte[_sampleMsg.Length];
{
await DoHandshake(clientSslStream, serverSslStream);
- clientSslStream.Write(Array.Empty<byte>());
- await clientSslStream.WriteAsync(Array.Empty<byte>(), 0, 0);
- clientSslStream.Write(_sampleMsg);
+ await WriteAsync(clientSslStream, Array.Empty<byte>(), 0, 0);
+ await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
int bytesRead = 0;
while (bytesRead < _sampleMsg.Length)
{
- bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
+ bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
}
Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected.");
- clientSslStream.Write(_sampleMsg);
- await clientSslStream.WriteAsync(Array.Empty<byte>(), 0, 0);
- clientSslStream.Write(Array.Empty<byte>());
+ await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
+ await WriteAsync(clientSslStream, Array.Empty<byte>(), 0, 0);
bytesRead = 0;
while (bytesRead < _sampleMsg.Length)
{
- bytesRead += serverSslStream.Read(recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
+ bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead);
}
Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify second read data is as expected.");
}
[Theory]
[InlineData(false)]
[InlineData(true)]
- public async Task SslStream_StreamToStream_LargeWrites_Sync_Success(bool randomizedData)
+ public async Task SslStream_StreamToStream_LargeWrites_Success(bool randomizedData)
{
VirtualNetwork network = new VirtualNetwork();
byte[] receivedLargeMsg = new byte[largeMsg.Length];
// First do a large write and read blocks at a time
- clientSslStream.Write(largeMsg);
+ await WriteAsync(clientSslStream, largeMsg, 0, largeMsg.Length);
int bytesRead = 0, totalRead = 0;
while (totalRead < largeMsg.Length &&
- (bytesRead = serverSslStream.Read(receivedLargeMsg, totalRead, receivedLargeMsg.Length - totalRead)) != 0)
+ (bytesRead = await ReadAsync(serverSslStream, receivedLargeMsg, totalRead, receivedLargeMsg.Length - totalRead)) != 0)
{
totalRead += bytesRead;
}
Assert.Equal(largeMsg, receivedLargeMsg);
// Then write again and read bytes at a time
- clientSslStream.Write(largeMsg);
+ await WriteAsync(clientSslStream, largeMsg, 0, largeMsg.Length);
foreach (byte b in largeMsg)
{
Assert.Equal(b, serverSslStream.ReadByte());
}
[Fact]
- public async Task SslStream_StreamToStream_Successive_ClientWrite_Async_Success()
+ public async Task SslStream_StreamToStream_Successive_ClientWrite_Success()
{
byte[] recvBuf = new byte[_sampleMsg.Length];
VirtualNetwork network = new VirtualNetwork();
{
await DoHandshake(clientSslStream, serverSslStream);
- await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length)
+ await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length)
.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
int bytesRead = 0;
while (bytesRead < _sampleMsg.Length)
{
- bytesRead += await serverSslStream.ReadAsync(recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
+ bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
}
Assert.True(VerifyOutput(recvBuf, _sampleMsg), "verify first read data is as expected.");
- await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length)
+ await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length)
.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
bytesRead = 0;
while (bytesRead < _sampleMsg.Length)
{
- bytesRead += await serverSslStream.ReadAsync(recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
+ bytesRead += await ReadAsync(serverSslStream, recvBuf, bytesRead, _sampleMsg.Length - bytesRead)
.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
}
for (int i = 0; i < 3; i++)
{
- clientSslStream.Write(_sampleMsg);
+ await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length);
foreach (byte b in _sampleMsg)
{
Assert.Equal(b, serverSslStream.ReadByte());
for (int i = 0; i < 3; i++)
{
- await clientSslStream.WriteAsync(_sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false);
+ await WriteAsync(clientSslStream, _sampleMsg, 0, _sampleMsg.Length).ConfigureAwait(false);
foreach (byte b in _sampleMsg)
{
Assert.Equal(b, serverSslStream.ReadByte());
{
tcs.TrySetResult(null);
};
- Task readTask = serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length);
+ Task readTask = ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length);
// Since the sequence of calls that ends in serverStream.Read() is sync, by now
// the read task will have acquired the semaphore shared by Stream.BeginReadInternal()
await tcs.Task.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
// Should not hang
- await serverSslStream.WriteAsync(new byte[] { 1 }, 0, 1)
+ await WriteAsync(serverSslStream, new byte[] { 1 }, 0, 1)
.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
// Read in client
var clientBuffer = new byte[1];
- await clientSslStream.ReadAsync(clientBuffer, 0, clientBuffer.Length);
+ await ReadAsync(clientSslStream, clientBuffer, 0, clientBuffer.Length);
Assert.Equal(1, clientBuffer[0]);
// Complete server read task
- await clientSslStream.WriteAsync(new byte[] { 2 }, 0, 1);
+ await WriteAsync(clientSslStream, new byte[] { 2 }, 0, 1);
await readTask;
Assert.Equal(2, serverBuffer[0]);
}
}
- [OuterLoop("Executes for several seconds")]
[Fact]
public async Task SslStream_ConcurrentBidirectionalReadsWrites_Success()
{
VirtualNetwork network = new VirtualNetwork();
-
- using (var clientStream = new VirtualNetworkStream(network, isServer: false))
- using (var serverStream = new NotifyReadVirtualNetworkStream(network, isServer: true))
- using (var clientSslStream = new SslStream(clientStream, false, AllowAnyServerCertificate))
- using (var serverSslStream = new SslStream(serverStream))
+ using (var clientSslStream = new SslStream(new VirtualNetworkStream(network, isServer: false), false, AllowAnyServerCertificate))
+ using (var serverSslStream = new SslStream(new NotifyReadVirtualNetworkStream(network, isServer: true)))
{
await DoHandshake(clientSslStream, serverSslStream);
const int BytesPerSend = 100;
- DateTime endTime = DateTime.UtcNow + TimeSpan.FromSeconds(5);
+ DateTime endTime = DateTime.UtcNow + TimeSpan.FromSeconds(3);
await new Task[]
{
Task.Run(async delegate
var buffer = new byte[BytesPerSend];
while (DateTime.UtcNow < endTime)
{
- await clientStream.WriteAsync(buffer, 0, buffer.Length);
+ await WriteAsync(clientSslStream, buffer, 0, buffer.Length);
int received = 0, bytesRead = 0;
- while (received < BytesPerSend && (bytesRead = await serverStream.ReadAsync(buffer, 0, buffer.Length)) != 0)
+ while (received < BytesPerSend && (bytesRead = await ReadAsync(serverSslStream, buffer, 0, buffer.Length)) != 0)
{
received += bytesRead;
}
var buffer = new byte[BytesPerSend];
while (DateTime.UtcNow < endTime)
{
- await serverStream.WriteAsync(buffer, 0, buffer.Length);
+ await WriteAsync(serverSslStream, buffer, 0, buffer.Length);
int received = 0, bytesRead = 0;
- while (received < BytesPerSend && (bytesRead = await clientStream.ReadAsync(buffer, 0, buffer.Length)) != 0)
+ while (received < BytesPerSend && (bytesRead = await ReadAsync(clientSslStream, buffer, 0, buffer.Length)) != 0)
{
received += bytesRead;
}
await DoHandshake(clientSslStream, serverSslStream);
var serverBuffer = new byte[1];
- Task serverReadTask = serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length);
- await serverSslStream.WriteAsync(new byte[] { 1 }, 0, 1)
+ Task serverReadTask = ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length);
+ await WriteAsync(serverSslStream, new byte[] { 1 }, 0, 1)
.TimeoutAfter(TestConfiguration.PassingTestTimeoutMilliseconds);
// Shouldn't throw, the context is diposed now.
// Read in client
var clientBuffer = new byte[1];
- await clientSslStream.ReadAsync(clientBuffer, 0, clientBuffer.Length);
+ await ReadAsync(clientSslStream, clientBuffer, 0, clientBuffer.Length);
Assert.Equal(1, clientBuffer[0]);
- await clientSslStream.WriteAsync(new byte[] { 2 }, 0, 1);
+ await WriteAsync(clientSslStream, new byte[] { 2 }, 0, 1);
- if (PlatformDetection.IsFullFramework)
+ // We're inconsistent as to whether the ObjectDisposedException is thrown directly
+ // or wrapped in an IOException. For sync operations, it's never wrapped. For
+ // Begin/End, it's always wrapped, And for Async, it's only wrapped on netfx.
+ if (this is SslStreamStreamToStreamTest_BeginEnd ||
+ (this is SslStreamStreamToStreamTest_Async && PlatformDetection.IsFullFramework))
{
await Assert.ThrowsAsync<ObjectDisposedException>(() => serverReadTask);
}
Assert.IsType<ObjectDisposedException>(serverException.InnerException);
}
- await Assert.ThrowsAsync<ObjectDisposedException>(() => serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length));
+ await Assert.ThrowsAsync<ObjectDisposedException>(() => ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length));
// Now, there is no pending read, so the internal buffer will be returned to ArrayPool.
serverSslStream.Dispose();
- await Assert.ThrowsAsync<ObjectDisposedException>(() => serverSslStream.ReadAsync(serverBuffer, 0, serverBuffer.Length));
+ await Assert.ThrowsAsync<ObjectDisposedException>(() => ReadAsync(serverSslStream, serverBuffer, 0, serverBuffer.Length));
}
}
canWriteFunc: () => true,
canReadFunc: () => true,
writeFunc: (buffer, offset, count) => serverNetworkStream.Write(buffer, offset, count),
+ writeAsyncFunc: (buffer, offset, count, token) => serverNetworkStream.WriteAsync(buffer, offset, count, token),
readFunc: (buffer, offset, count) =>
{
// Do normal reads as requested until the read mode is set
{
return 0;
}
+ },
+ readAsyncFunc: (buffer, offset, count, token) =>
+ {
+ // Do normal reads as requested until the read mode is set
+ // to 1. Then do a single read of only 10 bytes to read only
+ // part of the message, and subsequently return EOF.
+ if (readMode == 0)
+ {
+ return serverNetworkStream.ReadAsync(buffer, offset, count);
+ }
+ else if (readMode == 1)
+ {
+ readMode = 2;
+ return serverNetworkStream.ReadAsync(buffer, offset, 10); // read at least header but less than full frame
+ }
+ else
+ {
+ return Task.FromResult(0);
+ }
});
-
using (var clientSslStream = new SslStream(clientNetworkStream, false, AllowAnyServerCertificate))
using (var serverSslStream = new SslStream(serverWrappedNetworkStream))
{
await DoHandshake(clientSslStream, serverSslStream);
- await clientSslStream.WriteAsync(new byte[20], 0, 20);
+ await WriteAsync(clientSslStream, new byte[20], 0, 20);
readMode = 1;
- await Assert.ThrowsAsync<IOException>(() => serverSslStream.ReadAsync(new byte[1], 0, 1));
+ await Assert.ThrowsAsync<IOException>(() => ReadAsync(serverSslStream, new byte[1], 0, 1));
}
}
}
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}
+
+ protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+ stream.ReadAsync(buffer, offset, count, cancellationToken);
+
+ protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+ stream.WriteAsync(buffer, offset, count, cancellationToken);
}
public sealed class SslStreamStreamToStreamTest_BeginEnd : SslStreamStreamToStreamTest
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}
+
+ protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+ cancellationToken.IsCancellationRequested ?
+ Task.FromCanceled<int>(cancellationToken) :
+ Task.Factory.FromAsync(stream.BeginRead, stream.EndRead, buffer, offset, count, null);
+
+ protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+ cancellationToken.IsCancellationRequested ?
+ Task.FromCanceled<int>(cancellationToken) :
+ Task.Factory.FromAsync(stream.BeginWrite, stream.EndWrite, buffer, offset, count, null);
}
public sealed class SslStreamStreamToStreamTest_Sync : SslStreamStreamToStreamTest
await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}
+
+ protected override Task<int> ReadAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+ Task.Run(() => stream.Read(buffer, offset, count), cancellationToken);
+
+ protected override Task WriteAsync(SslStream stream, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default(CancellationToken)) =>
+ Task.Run(() => stream.Write(buffer, offset, count), cancellationToken);
+
}
}