public override void SetLength(long value) { }
public void Shutdown() { }
public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+ public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public override void Write(byte[] buffer, int offset, int count) { }
public override void Write(System.ReadOnlySpan<byte> buffer) { }
public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
namespace System.Net.Quic.Implementations.Mock
{
}
MockStream.StreamState streamState = new MockStream.StreamState(streamId, bidirectional);
+ // TODO Streams are never removed from a connection. Consider cleaning up in the future.
+ state._streams[streamState._streamId] = streamState;
+
Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
streamChannel.Writer.TryWrite(streamState);
state._serverErrorCode = errorCode;
DrainAcceptQueue(errorCode, -1);
}
+
+ foreach (KeyValuePair<long, MockStream.StreamState> kvp in state._streams)
+ {
+ kvp.Value._outboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
+ kvp.Value._inboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
+ }
}
Dispose();
internal sealed class ConnectionState
{
public readonly SslApplicationProtocol _applicationProtocol;
- public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
- public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
+ public readonly Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
+ public readonly Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
+ public readonly ConcurrentDictionary<long, MockStream.StreamState> _streams;
public PeerStreamLimit? _clientStreamLimit;
public PeerStreamLimit? _serverStreamLimit;
_clientInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
_serverInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
_clientErrorCode = _serverErrorCode = -1;
+ _streams = new ConcurrentDictionary<long, MockStream.StreamState>();
}
}
}
if (endStream)
{
streamBuffer.EndWrite();
+ WritesCompletedTcs.TrySetResult();
}
}
if (_isInitiator)
{
_streamState._outboundWriteErrorCode = errorCode;
+ _streamState._inboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
}
else
{
_streamState._inboundWriteErrorCode = errorCode;
+ _streamState._outboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
}
ReadStreamBuffer?.AbortRead();
if (_isInitiator)
{
_streamState._outboundReadErrorCode = errorCode;
+ _streamState._outboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
}
else
{
_streamState._inboundReadErrorCode = errorCode;
+ _streamState._inboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
}
WriteStreamBuffer?.EndWrite();
{
_connection.LocalStreamLimit!.Bidirectional.Decrement();
}
+
+ WritesCompletedTcs.TrySetResult();
}
private void CheckDisposed()
return default;
}
+ internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
+ {
+ CheckDisposed();
+
+ return new ValueTask(WritesCompletedTcs.Task);
+ }
+
+ private TaskCompletionSource WritesCompletedTcs => _isInitiator
+ ? _streamState._outboundWritesCompletedTcs
+ : _streamState._inboundWritesCompletedTcs;
+
internal sealed class StreamState
{
public readonly long _streamId;
public long _inboundReadErrorCode;
public long _outboundWriteErrorCode;
public long _inboundWriteErrorCode;
+ public TaskCompletionSource _outboundWritesCompletedTcs;
+ public TaskCompletionSource _inboundWritesCompletedTcs;
private const int InitialBufferSize =
#if DEBUG
_streamId = streamId;
_outboundStreamBuffer = new StreamBuffer(initialBufferSize: InitialBufferSize, maxBufferSize: MaxBufferSize);
_inboundStreamBuffer = (bidirectional ? new StreamBuffer() : null);
+ _outboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+ _inboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
}
}
}
// Resettable completions to be used for multiple calls to send.
public readonly ResettableCompletionSource<uint> SendResettableCompletionSource = new ResettableCompletionSource<uint>();
+ public ShutdownWriteState ShutdownWriteState;
+
+ // Set once writes have been shutdown.
+ public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
public ShutdownState ShutdownState;
// The value makes sure that we release the handles only once.
public int ShutdownDone;
return;
}
+ bool shouldComplete = false;
+
lock (_state)
{
if (_state.SendState < SendState.Aborted)
{
_state.SendState = SendState.Aborted;
}
+
+ if (_state.ShutdownWriteState == ShutdownWriteState.None)
+ {
+ _state.ShutdownWriteState = ShutdownWriteState.Canceled;
+ shouldComplete = true;
+ }
+ }
+
+ if (shouldComplete)
+ {
+ _state.ShutdownWriteCompletionSource.SetException(
+ ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
}
StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
await _state.ShutdownCompletionSource.Task.ConfigureAwait(false);
}
+ internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
+ {
+ // TODO: What should happen if this is called for a unidirectional stream and there are no writes?
+
+ ThrowIfDisposed();
+
+ lock (_state)
+ {
+ if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed)
+ {
+ throw GetConnectionAbortedException(_state);
+ }
+ }
+
+ return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken));
+ }
+
internal override void Shutdown()
{
ThrowIfDisposed();
// Peer has stopped receiving data, don't send anymore.
case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED:
return HandleEventPeerRecvAborted(state, ref evt);
+ // Occurs when shutdown is completed for the send side.
+ // This only happens for shutdown on sending, not receiving
+ // Receive shutdown can only be abortive.
+ case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE:
+ return HandleEventSendShutdownComplete(state, ref evt);
// Shutdown for both sending and receiving is completed.
case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE:
return HandleEventShutdownComplete(state, ref evt);
private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
{
- bool shouldComplete = false;
+ bool shouldSendComplete = false;
+ bool shouldShutdownWriteComplete = false;
lock (state)
{
if (state.SendState == SendState.None || state.SendState == SendState.Pending)
{
- shouldComplete = true;
+ shouldSendComplete = true;
+ }
+
+ if (state.ShutdownWriteState == ShutdownWriteState.None)
+ {
+ state.ShutdownWriteState = ShutdownWriteState.Canceled;
+ shouldShutdownWriteComplete = true;
}
+
state.SendState = SendState.Aborted;
state.SendErrorCode = (long)evt.Data.PeerReceiveAborted.ErrorCode;
}
- if (shouldComplete)
+ if (shouldSendComplete)
{
state.SendResettableCompletionSource.CompleteException(
ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
}
+ if (shouldShutdownWriteComplete)
+ {
+ state.ShutdownWriteCompletionSource.SetException(
+ ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
+ }
+
return MsQuicStatusCodes.Success;
}
return MsQuicStatusCodes.Success;
}
+ private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt)
+ {
+ // Graceful will be false in three situations:
+ // 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised.
+ // ShutdownWriteCompletionSource is already complete with an error.
+ // 2. We aborted writes.
+ // ShutdownWriteCompletionSource is already complete with an error.
+ // 3. The connection was closed.
+ // SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error.
+ //
+ // Only use this event with sends gracefully completed.
+ if (evt.Data.SendShutdownComplete.Graceful != 0)
+ {
+ bool shouldComplete = false;
+ lock (state)
+ {
+ if (state.ShutdownWriteState == ShutdownWriteState.None)
+ {
+ state.ShutdownWriteState = ShutdownWriteState.Finished;
+ shouldComplete = true;
+ }
+ }
+
+ if (shouldComplete)
+ {
+ state.ShutdownWriteCompletionSource.SetResult();
+ }
+ }
+
+ return MsQuicStatusCodes.Success;
+ }
+
private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt)
{
StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete;
}
bool shouldReadComplete = false;
+ bool shouldShutdownWriteComplete = false;
bool shouldShutdownComplete = false;
lock (state)
shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);
+ if (state.ShutdownWriteState == ShutdownWriteState.None)
+ {
+ // TODO: We can get to this point if the stream is unidirectional and there are no writes.
+ // Consider what is the best behavior here with write shutdown and the read side of
+ // unidirecitonal streams in the future.
+ state.ShutdownWriteState = ShutdownWriteState.Finished;
+ shouldShutdownWriteComplete = true;
+ }
+
if (state.ShutdownState == ShutdownState.None)
{
state.ShutdownState = ShutdownState.Finished;
state.ReceiveResettableCompletionSource.Complete(0);
}
+ if (shouldShutdownWriteComplete)
+ {
+ state.ShutdownWriteCompletionSource.SetResult();
+ }
+
if (shouldShutdownComplete)
{
state.ShutdownCompletionSource.SetResult();
bool shouldCompleteRead = false;
bool shouldCompleteSend = false;
+ bool shouldCompleteShutdownWrite = false;
bool shouldCompleteShutdown = false;
lock (state)
}
state.SendState = SendState.ConnectionClosed;
+ if (state.ShutdownWriteState == ShutdownWriteState.None)
+ {
+ shouldCompleteShutdownWrite = true;
+ }
+ state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed;
+
if (state.ShutdownState == ShutdownState.None)
{
shouldCompleteShutdown = true;
ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
}
+ if (shouldCompleteShutdownWrite)
+ {
+ state.ShutdownWriteCompletionSource.SetException(
+ ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
+ }
+
if (shouldCompleteShutdown)
{
state.ShutdownCompletionSource.SetException(
Closed
}
+ private enum ShutdownWriteState
+ {
+ None = 0,
+ Canceled,
+ Finished,
+ ConnectionClosed
+ }
+
private enum ShutdownState
{
None = 0,
internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default);
+ internal abstract ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default);
+
internal abstract void Shutdown();
internal abstract void Flush();
public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken);
+ public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken);
+
public void Shutdown() => _provider.Shutdown();
protected override void Dispose(bool disposing)
}
);
}
+
+ [Fact]
+ public async Task WaitForWriteCompletionAsync_ClientReadAborted_Throws()
+ {
+ const int ExpectedErrorCode = 0xfffffff;
+
+ TaskCompletionSource<long> waitForAbortTcs = new TaskCompletionSource<long>(TaskCreationOptions.RunContinuationsAsynchronously);
+ SemaphoreSlim sem = new SemaphoreSlim(0);
+
+ await RunBidirectionalClientServer(
+ async clientStream =>
+ {
+ await clientStream.WriteAsync(new byte[1], endStream: true);
+
+ // Wait for server to read data
+ await sem.WaitAsync();
+
+ clientStream.AbortRead(ExpectedErrorCode);
+ },
+ async serverStream =>
+ {
+ var writeCompletionTask = ReleaseOnWriteCompletionAsync();
+
+ int received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(1, received);
+ received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(0, received);
+
+ Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+ // Tell client that data has been read and it can abort its reads.
+ sem.Release();
+
+ long sendAbortErrorCode = await waitForAbortTcs.Task;
+ Assert.Equal(ExpectedErrorCode, sendAbortErrorCode);
+
+ await writeCompletionTask;
+
+ async ValueTask ReleaseOnWriteCompletionAsync()
+ {
+ try
+ {
+ await serverStream.WaitForWriteCompletionAsync();
+ waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted."));
+ }
+ catch (QuicStreamAbortedException ex)
+ {
+ waitForAbortTcs.SetResult(ex.ErrorCode);
+ }
+ catch (Exception ex)
+ {
+ waitForAbortTcs.SetException(ex);
+ }
+ };
+ });
+ }
+
+ [Fact]
+ public async Task WaitForWriteCompletionAsync_ServerWriteAborted_Throws()
+ {
+ const int ExpectedErrorCode = 0xfffffff;
+
+ TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ await RunBidirectionalClientServer(
+ async clientStream =>
+ {
+ await clientStream.WriteAsync(new byte[1], endStream: true);
+ },
+ async serverStream =>
+ {
+ var writeCompletionTask = ReleaseOnWriteCompletionAsync();
+
+ int received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(1, received);
+ received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(0, received);
+
+ Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+ serverStream.AbortWrite(ExpectedErrorCode);
+
+ await waitForAbortTcs.Task;
+ await writeCompletionTask;
+
+ async ValueTask ReleaseOnWriteCompletionAsync()
+ {
+ try
+ {
+ await serverStream.WaitForWriteCompletionAsync();
+ waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted."));
+ }
+ catch (QuicOperationAbortedException)
+ {
+ waitForAbortTcs.SetResult();
+ }
+ catch (Exception ex)
+ {
+ waitForAbortTcs.SetException(ex);
+ }
+ };
+ });
+ }
+
+ [Fact]
+ public async Task WaitForWriteCompletionAsync_ServerShutdown_Success()
+ {
+ await RunBidirectionalClientServer(
+ async clientStream =>
+ {
+ await clientStream.WriteAsync(new byte[1], endStream: true);
+
+ int readCount = await clientStream.ReadAsync(new byte[1]);
+ Assert.Equal(1, readCount);
+
+ readCount = await clientStream.ReadAsync(new byte[1]);
+ Assert.Equal(0, readCount);
+ },
+ async serverStream =>
+ {
+ var writeCompletionTask = serverStream.WaitForWriteCompletionAsync();
+
+ int received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(1, received);
+ received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(0, received);
+
+ await serverStream.WriteAsync(new byte[1]);
+
+ Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+ serverStream.Shutdown();
+
+ await writeCompletionTask;
+ });
+ }
+
+ [Fact]
+ public async Task WaitForWriteCompletionAsync_GracefulShutdown_Success()
+ {
+ await RunBidirectionalClientServer(
+ async clientStream =>
+ {
+ await clientStream.WriteAsync(new byte[1], endStream: true);
+
+ int readCount = await clientStream.ReadAsync(new byte[1]);
+ Assert.Equal(1, readCount);
+
+ readCount = await clientStream.ReadAsync(new byte[1]);
+ Assert.Equal(0, readCount);
+ },
+ async serverStream =>
+ {
+ var writeCompletionTask = serverStream.WaitForWriteCompletionAsync();
+
+ int received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(1, received);
+ received = await serverStream.ReadAsync(new byte[1]);
+ Assert.Equal(0, received);
+
+ Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+ await serverStream.WriteAsync(new byte[1], endStream: true);
+
+ await writeCompletionTask;
+ });
+ }
+
+ [Fact]
+ public async Task WaitForWriteCompletionAsync_ConnectionClosed_Throws()
+ {
+ const int ExpectedErrorCode = 0xfffffff;
+
+ using SemaphoreSlim sem = new SemaphoreSlim(0);
+ TaskCompletionSource<long> waitForAbortTcs = new TaskCompletionSource<long>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+ await RunClientServer(
+ serverFunction: async connection =>
+ {
+ await using QuicStream stream = await connection.AcceptStreamAsync();
+
+ var writeCompletionTask = ReleaseOnWriteCompletionAsync();
+
+ int received = await stream.ReadAsync(new byte[1]);
+ Assert.Equal(1, received);
+ received = await stream.ReadAsync(new byte[1]);
+ Assert.Equal(0, received);
+
+ // Signal that the server has read data
+ sem.Release();
+
+ long closeErrorCode = await waitForAbortTcs.Task;
+ Assert.Equal(ExpectedErrorCode, closeErrorCode);
+
+ await writeCompletionTask;
+
+ async ValueTask ReleaseOnWriteCompletionAsync()
+ {
+ try
+ {
+ await stream.WaitForWriteCompletionAsync();
+ waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw connection aborted."));
+ }
+ catch (QuicConnectionAbortedException ex)
+ {
+ waitForAbortTcs.SetResult(ex.ErrorCode);
+ }
+ };
+ },
+ clientFunction: async connection =>
+ {
+ await using QuicStream stream = connection.OpenBidirectionalStream();
+
+ await stream.WriteAsync(new byte[1], endStream: true);
+
+ await stream.WaitForWriteCompletionAsync();
+
+ // Wait for the server to read data before closing the connection
+ await sem.WaitAsync();
+
+ await connection.CloseAsync(ExpectedErrorCode);
+ }
+ );
+ }
}
public sealed class QuicStreamTests_MockProvider : QuicStreamTests<MockProviderFactory>