[release/6.0] [QUIC] Add QuicStream.WaitForWriteCompletionAsync (#58415)
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Wed, 1 Sep 2021 14:37:45 +0000 (08:37 -0600)
committerGitHub <noreply@github.com>
Wed, 1 Sep 2021 14:37:45 +0000 (08:37 -0600)
src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockConnection.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs

index 790ad3f..be0b43d 100644 (file)
@@ -109,6 +109,7 @@ namespace System.Net.Quic
         public override void SetLength(long value) { }
         public void Shutdown() { }
         public System.Threading.Tasks.ValueTask ShutdownCompleted(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
+        public System.Threading.Tasks.ValueTask WaitForWriteCompletionAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
         public override void Write(byte[] buffer, int offset, int count) { }
         public override void Write(System.ReadOnlySpan<byte> buffer) { }
         public System.Threading.Tasks.ValueTask WriteAsync(System.Buffers.ReadOnlySequence<byte> buffers, bool endStream, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
index 7487a95..e409b96 100644 (file)
@@ -8,6 +8,8 @@ using System.Runtime.ExceptionServices;
 using System.Threading;
 using System.Threading.Channels;
 using System.Threading.Tasks;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
 
 namespace System.Net.Quic.Implementations.Mock
 {
@@ -244,6 +246,9 @@ namespace System.Net.Quic.Implementations.Mock
             }
 
             MockStream.StreamState streamState = new MockStream.StreamState(streamId, bidirectional);
+            // TODO Streams are never removed from a connection. Consider cleaning up in the future.
+            state._streams[streamState._streamId] = streamState;
+
             Channel<MockStream.StreamState> streamChannel = _isClient ? state._clientInitiatedStreamChannel : state._serverInitiatedStreamChannel;
             streamChannel.Writer.TryWrite(streamState);
 
@@ -320,6 +325,12 @@ namespace System.Net.Quic.Implementations.Mock
                     state._serverErrorCode = errorCode;
                     DrainAcceptQueue(errorCode, -1);
                 }
+
+                foreach (KeyValuePair<long, MockStream.StreamState> kvp in state._streams)
+                {
+                    kvp.Value._outboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
+                    kvp.Value._inboundWritesCompletedTcs.TrySetException(new QuicConnectionAbortedException(errorCode));
+                }
             }
 
             Dispose();
@@ -474,8 +485,9 @@ namespace System.Net.Quic.Implementations.Mock
         internal sealed class ConnectionState
         {
             public readonly SslApplicationProtocol _applicationProtocol;
-            public Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
-            public Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
+            public readonly Channel<MockStream.StreamState> _clientInitiatedStreamChannel;
+            public readonly Channel<MockStream.StreamState> _serverInitiatedStreamChannel;
+            public readonly ConcurrentDictionary<long, MockStream.StreamState> _streams;
 
             public PeerStreamLimit? _clientStreamLimit;
             public PeerStreamLimit? _serverStreamLimit;
@@ -490,6 +502,7 @@ namespace System.Net.Quic.Implementations.Mock
                 _clientInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
                 _serverInitiatedStreamChannel = Channel.CreateUnbounded<MockStream.StreamState>();
                 _clientErrorCode = _serverErrorCode = -1;
+                _streams = new ConcurrentDictionary<long, MockStream.StreamState>();
             }
         }
     }
index 588da85..fbace75 100644 (file)
@@ -164,6 +164,7 @@ namespace System.Net.Quic.Implementations.Mock
             if (endStream)
             {
                 streamBuffer.EndWrite();
+                WritesCompletedTcs.TrySetResult();
             }
         }
 
@@ -206,10 +207,12 @@ namespace System.Net.Quic.Implementations.Mock
             if (_isInitiator)
             {
                 _streamState._outboundWriteErrorCode = errorCode;
+                _streamState._inboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
             }
             else
             {
                 _streamState._inboundWriteErrorCode = errorCode;
+                _streamState._outboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
             }
 
             ReadStreamBuffer?.AbortRead();
@@ -220,10 +223,12 @@ namespace System.Net.Quic.Implementations.Mock
             if (_isInitiator)
             {
                 _streamState._outboundReadErrorCode = errorCode;
+                _streamState._outboundWritesCompletedTcs.TrySetException(new QuicStreamAbortedException(errorCode));
             }
             else
             {
                 _streamState._inboundReadErrorCode = errorCode;
+                _streamState._inboundWritesCompletedTcs.TrySetException(new QuicOperationAbortedException());
             }
 
             WriteStreamBuffer?.EndWrite();
@@ -251,6 +256,8 @@ namespace System.Net.Quic.Implementations.Mock
             {
                 _connection.LocalStreamLimit!.Bidirectional.Decrement();
             }
+
+            WritesCompletedTcs.TrySetResult();
         }
 
         private void CheckDisposed()
@@ -283,6 +290,17 @@ namespace System.Net.Quic.Implementations.Mock
             return default;
         }
 
+        internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
+        {
+            CheckDisposed();
+
+            return new ValueTask(WritesCompletedTcs.Task);
+        }
+
+        private TaskCompletionSource WritesCompletedTcs => _isInitiator
+            ? _streamState._outboundWritesCompletedTcs
+            : _streamState._inboundWritesCompletedTcs;
+
         internal sealed class StreamState
         {
             public readonly long _streamId;
@@ -292,6 +310,8 @@ namespace System.Net.Quic.Implementations.Mock
             public long _inboundReadErrorCode;
             public long _outboundWriteErrorCode;
             public long _inboundWriteErrorCode;
+            public TaskCompletionSource _outboundWritesCompletedTcs;
+            public TaskCompletionSource _inboundWritesCompletedTcs;
 
             private const int InitialBufferSize =
 #if DEBUG
@@ -310,6 +330,8 @@ namespace System.Net.Quic.Implementations.Mock
                 _streamId = streamId;
                 _outboundStreamBuffer = new StreamBuffer(initialBufferSize: InitialBufferSize, maxBufferSize: MaxBufferSize);
                 _inboundStreamBuffer = (bidirectional ? new StreamBuffer() : null);
+                _outboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+                _inboundWritesCompletedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
             }
         }
     }
index 83ccf08..54f0ed3 100644 (file)
@@ -69,6 +69,11 @@ namespace System.Net.Quic.Implementations.MsQuic
             // Resettable completions to be used for multiple calls to send.
             public readonly ResettableCompletionSource<uint> SendResettableCompletionSource = new ResettableCompletionSource<uint>();
 
+            public ShutdownWriteState ShutdownWriteState;
+
+            // Set once writes have been shutdown.
+            public readonly TaskCompletionSource ShutdownWriteCompletionSource = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
             public ShutdownState ShutdownState;
             // The value makes sure that we release the handles only once.
             public int ShutdownDone;
@@ -577,12 +582,26 @@ namespace System.Net.Quic.Implementations.MsQuic
                 return;
             }
 
+            bool shouldComplete = false;
+
             lock (_state)
             {
                 if (_state.SendState < SendState.Aborted)
                 {
                     _state.SendState = SendState.Aborted;
                 }
+
+                if (_state.ShutdownWriteState == ShutdownWriteState.None)
+                {
+                    _state.ShutdownWriteState = ShutdownWriteState.Canceled;
+                    shouldComplete = true;
+                }
+            }
+
+            if (shouldComplete)
+            {
+                _state.ShutdownWriteCompletionSource.SetException(
+                    ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
             }
 
             StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
@@ -629,6 +648,23 @@ namespace System.Net.Quic.Implementations.MsQuic
             await _state.ShutdownCompletionSource.Task.ConfigureAwait(false);
         }
 
+        internal override ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default)
+        {
+            // TODO: What should happen if this is called for a unidirectional stream and there are no writes?
+
+            ThrowIfDisposed();
+
+            lock (_state)
+            {
+                if (_state.ShutdownWriteState == ShutdownWriteState.ConnectionClosed)
+                {
+                    throw GetConnectionAbortedException(_state);
+                }
+            }
+
+            return new ValueTask(_state.ShutdownWriteCompletionSource.Task.WaitAsync(cancellationToken));
+        }
+
         internal override void Shutdown()
         {
             ThrowIfDisposed();
@@ -861,6 +897,11 @@ namespace System.Net.Quic.Implementations.MsQuic
                     // Peer has stopped receiving data, don't send anymore.
                     case QUIC_STREAM_EVENT_TYPE.PEER_RECEIVE_ABORTED:
                         return HandleEventPeerRecvAborted(state, ref evt);
+                    // Occurs when shutdown is completed for the send side.
+                    // This only happens for shutdown on sending, not receiving
+                    // Receive shutdown can only be abortive.
+                    case QUIC_STREAM_EVENT_TYPE.SEND_SHUTDOWN_COMPLETE:
+                        return HandleEventSendShutdownComplete(state, ref evt);
                     // Shutdown for both sending and receiving is completed.
                     case QUIC_STREAM_EVENT_TYPE.SHUTDOWN_COMPLETE:
                         return HandleEventShutdownComplete(state, ref evt);
@@ -993,23 +1034,37 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private static uint HandleEventPeerRecvAborted(State state, ref StreamEvent evt)
         {
-            bool shouldComplete = false;
+            bool shouldSendComplete = false;
+            bool shouldShutdownWriteComplete = false;
             lock (state)
             {
                 if (state.SendState == SendState.None || state.SendState == SendState.Pending)
                 {
-                    shouldComplete = true;
+                    shouldSendComplete = true;
+                }
+
+                if (state.ShutdownWriteState == ShutdownWriteState.None)
+                {
+                    state.ShutdownWriteState = ShutdownWriteState.Canceled;
+                    shouldShutdownWriteComplete = true;
                 }
+
                 state.SendState = SendState.Aborted;
                 state.SendErrorCode = (long)evt.Data.PeerReceiveAborted.ErrorCode;
             }
 
-            if (shouldComplete)
+            if (shouldSendComplete)
             {
                 state.SendResettableCompletionSource.CompleteException(
                     ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
             }
 
+            if (shouldShutdownWriteComplete)
+            {
+                state.ShutdownWriteCompletionSource.SetException(
+                    ExceptionDispatchInfo.SetCurrentStackTrace(new QuicStreamAbortedException(state.SendErrorCode)));
+            }
+
             return MsQuicStatusCodes.Success;
         }
 
@@ -1021,6 +1076,38 @@ namespace System.Net.Quic.Implementations.MsQuic
             return MsQuicStatusCodes.Success;
         }
 
+        private static uint HandleEventSendShutdownComplete(State state, ref StreamEvent evt)
+        {
+            // Graceful will be false in three situations:
+            // 1. The peer aborted reads and the PEER_RECEIVE_ABORTED event was raised.
+            //    ShutdownWriteCompletionSource is already complete with an error.
+            // 2. We aborted writes.
+            //    ShutdownWriteCompletionSource is already complete with an error.
+            // 3. The connection was closed.
+            //    SHUTDOWN_COMPLETE event will be raised immediately after this event. It will handle completing with an error.
+            //
+            // Only use this event with sends gracefully completed.
+            if (evt.Data.SendShutdownComplete.Graceful != 0)
+            {
+                bool shouldComplete = false;
+                lock (state)
+                {
+                    if (state.ShutdownWriteState == ShutdownWriteState.None)
+                    {
+                        state.ShutdownWriteState = ShutdownWriteState.Finished;
+                        shouldComplete = true;
+                    }
+                }
+
+                if (shouldComplete)
+                {
+                    state.ShutdownWriteCompletionSource.SetResult();
+                }
+            }
+
+            return MsQuicStatusCodes.Success;
+        }
+
         private static uint HandleEventShutdownComplete(State state, ref StreamEvent evt)
         {
             StreamEventDataShutdownComplete shutdownCompleteEvent = evt.Data.ShutdownComplete;
@@ -1031,6 +1118,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             }
 
             bool shouldReadComplete = false;
+            bool shouldShutdownWriteComplete = false;
             bool shouldShutdownComplete = false;
 
             lock (state)
@@ -1040,6 +1128,15 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                 shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);
 
+                if (state.ShutdownWriteState == ShutdownWriteState.None)
+                {
+                    // TODO: We can get to this point if the stream is unidirectional and there are no writes.
+                    // Consider what is the best behavior here with write shutdown and the read side of
+                    // unidirecitonal streams in the future.
+                    state.ShutdownWriteState = ShutdownWriteState.Finished;
+                    shouldShutdownWriteComplete = true;
+                }
+
                 if (state.ShutdownState == ShutdownState.None)
                 {
                     state.ShutdownState = ShutdownState.Finished;
@@ -1052,6 +1149,11 @@ namespace System.Net.Quic.Implementations.MsQuic
                 state.ReceiveResettableCompletionSource.Complete(0);
             }
 
+            if (shouldShutdownWriteComplete)
+            {
+                state.ShutdownWriteCompletionSource.SetResult();
+            }
+
             if (shouldShutdownComplete)
             {
                 state.ShutdownCompletionSource.SetResult();
@@ -1361,6 +1463,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             bool shouldCompleteRead = false;
             bool shouldCompleteSend = false;
+            bool shouldCompleteShutdownWrite = false;
             bool shouldCompleteShutdown = false;
 
             lock (state)
@@ -1373,6 +1476,12 @@ namespace System.Net.Quic.Implementations.MsQuic
                 }
                 state.SendState = SendState.ConnectionClosed;
 
+                if (state.ShutdownWriteState == ShutdownWriteState.None)
+                {
+                    shouldCompleteShutdownWrite = true;
+                }
+                state.ShutdownWriteState = ShutdownWriteState.ConnectionClosed;
+
                 if (state.ShutdownState == ShutdownState.None)
                 {
                     shouldCompleteShutdown = true;
@@ -1392,6 +1501,12 @@ namespace System.Net.Quic.Implementations.MsQuic
                     ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
             }
 
+            if (shouldCompleteShutdownWrite)
+            {
+                state.ShutdownWriteCompletionSource.SetException(
+                    ExceptionDispatchInfo.SetCurrentStackTrace(GetConnectionAbortedException(state)));
+            }
+
             if (shouldCompleteShutdown)
             {
                 state.ShutdownCompletionSource.SetException(
@@ -1493,6 +1608,14 @@ namespace System.Net.Quic.Implementations.MsQuic
             Closed
         }
 
+        private enum ShutdownWriteState
+        {
+            None = 0,
+            Canceled,
+            Finished,
+            ConnectionClosed
+        }
+
         private enum ShutdownState
         {
             None = 0,
index 66c9a8b..215ce13 100644 (file)
@@ -47,6 +47,8 @@ namespace System.Net.Quic.Implementations
 
         internal abstract ValueTask ShutdownCompleted(CancellationToken cancellationToken = default);
 
+        internal abstract ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default);
+
         internal abstract void Shutdown();
 
         internal abstract void Flush();
index 8a6dbe4..912d32c 100644 (file)
@@ -117,6 +117,8 @@ namespace System.Net.Quic
 
         public ValueTask ShutdownCompleted(CancellationToken cancellationToken = default) => _provider.ShutdownCompleted(cancellationToken);
 
+        public ValueTask WaitForWriteCompletionAsync(CancellationToken cancellationToken = default) => _provider.WaitForWriteCompletionAsync(cancellationToken);
+
         public void Shutdown() => _provider.Shutdown();
 
         protected override void Dispose(bool disposing)
index 098bd4d..ed37126 100644 (file)
@@ -769,6 +769,230 @@ namespace System.Net.Quic.Tests
                 }
             );
         }
+
+        [Fact]
+        public async Task WaitForWriteCompletionAsync_ClientReadAborted_Throws()
+        {
+            const int ExpectedErrorCode = 0xfffffff;
+
+            TaskCompletionSource<long> waitForAbortTcs = new TaskCompletionSource<long>(TaskCreationOptions.RunContinuationsAsynchronously);
+            SemaphoreSlim sem = new SemaphoreSlim(0);
+
+            await RunBidirectionalClientServer(
+                async clientStream =>
+                {
+                    await clientStream.WriteAsync(new byte[1], endStream: true);
+
+                    // Wait for server to read data
+                    await sem.WaitAsync();
+
+                    clientStream.AbortRead(ExpectedErrorCode);
+                },
+                async serverStream =>
+                {
+                    var writeCompletionTask = ReleaseOnWriteCompletionAsync();
+
+                    int received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, received);
+                    received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, received);
+
+                    Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+                    // Tell client that data has been read and it can abort its reads.
+                    sem.Release();
+
+                    long sendAbortErrorCode = await waitForAbortTcs.Task;
+                    Assert.Equal(ExpectedErrorCode, sendAbortErrorCode);
+
+                    await writeCompletionTask;
+
+                    async ValueTask ReleaseOnWriteCompletionAsync()
+                    {
+                        try
+                        {
+                            await serverStream.WaitForWriteCompletionAsync();
+                            waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted."));
+                        }
+                        catch (QuicStreamAbortedException ex)
+                        {
+                            waitForAbortTcs.SetResult(ex.ErrorCode);
+                        }
+                        catch (Exception ex)
+                        {
+                            waitForAbortTcs.SetException(ex);
+                        }
+                    };
+                });
+        }
+
+        [Fact]
+        public async Task WaitForWriteCompletionAsync_ServerWriteAborted_Throws()
+        {
+            const int ExpectedErrorCode = 0xfffffff;
+
+            TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            await RunBidirectionalClientServer(
+                async clientStream =>
+                {
+                    await clientStream.WriteAsync(new byte[1], endStream: true);
+                },
+                async serverStream =>
+                {
+                    var writeCompletionTask = ReleaseOnWriteCompletionAsync();
+
+                    int received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, received);
+                    received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, received);
+
+                    Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+                    serverStream.AbortWrite(ExpectedErrorCode);
+
+                    await waitForAbortTcs.Task;
+                    await writeCompletionTask;
+
+                    async ValueTask ReleaseOnWriteCompletionAsync()
+                    {
+                        try
+                        {
+                            await serverStream.WaitForWriteCompletionAsync();
+                            waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw stream aborted."));
+                        }
+                        catch (QuicOperationAbortedException)
+                        {
+                            waitForAbortTcs.SetResult();
+                        }
+                        catch (Exception ex)
+                        {
+                            waitForAbortTcs.SetException(ex);
+                        }
+                    };
+                });
+        }
+
+        [Fact]
+        public async Task WaitForWriteCompletionAsync_ServerShutdown_Success()
+        {
+            await RunBidirectionalClientServer(
+                async clientStream =>
+                {
+                    await clientStream.WriteAsync(new byte[1], endStream: true);
+
+                    int readCount = await clientStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, readCount);
+
+                    readCount = await clientStream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, readCount);
+                },
+                async serverStream =>
+                {
+                    var writeCompletionTask = serverStream.WaitForWriteCompletionAsync();
+
+                    int received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, received);
+                    received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, received);
+
+                    await serverStream.WriteAsync(new byte[1]);
+
+                    Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+                    serverStream.Shutdown();
+
+                    await writeCompletionTask;
+                });
+        }
+
+        [Fact]
+        public async Task WaitForWriteCompletionAsync_GracefulShutdown_Success()
+        {
+            await RunBidirectionalClientServer(
+                async clientStream =>
+                {
+                    await clientStream.WriteAsync(new byte[1], endStream: true);
+
+                    int readCount = await clientStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, readCount);
+
+                    readCount = await clientStream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, readCount);
+                },
+                async serverStream =>
+                {
+                    var writeCompletionTask = serverStream.WaitForWriteCompletionAsync();
+
+                    int received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, received);
+                    received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, received);
+
+                    Assert.False(writeCompletionTask.IsCompleted, "Server is still writing.");
+
+                    await serverStream.WriteAsync(new byte[1], endStream: true);
+
+                    await writeCompletionTask;
+                });
+        }
+
+        [Fact]
+        public async Task WaitForWriteCompletionAsync_ConnectionClosed_Throws()
+        {
+            const int ExpectedErrorCode = 0xfffffff;
+
+            using SemaphoreSlim sem = new SemaphoreSlim(0);
+            TaskCompletionSource<long> waitForAbortTcs = new TaskCompletionSource<long>(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            await RunClientServer(
+                serverFunction: async connection =>
+                {
+                    await using QuicStream stream = await connection.AcceptStreamAsync();
+
+                    var writeCompletionTask = ReleaseOnWriteCompletionAsync();
+
+                    int received = await stream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, received);
+                    received = await stream.ReadAsync(new byte[1]);
+                    Assert.Equal(0, received);
+
+                    // Signal that the server has read data
+                    sem.Release();
+
+                    long closeErrorCode = await waitForAbortTcs.Task;
+                    Assert.Equal(ExpectedErrorCode, closeErrorCode);
+
+                    await writeCompletionTask;
+
+                    async ValueTask ReleaseOnWriteCompletionAsync()
+                    {
+                        try
+                        {
+                            await stream.WaitForWriteCompletionAsync();
+                            waitForAbortTcs.SetException(new Exception("WaitForWriteCompletionAsync didn't throw connection aborted."));
+                        }
+                        catch (QuicConnectionAbortedException ex)
+                        {
+                            waitForAbortTcs.SetResult(ex.ErrorCode);
+                        }
+                    };
+                },
+                clientFunction: async connection =>
+                {
+                    await using QuicStream stream = connection.OpenBidirectionalStream();
+
+                    await stream.WriteAsync(new byte[1], endStream: true);
+
+                    await stream.WaitForWriteCompletionAsync();
+
+                    // Wait for the server to read data before closing the connection
+                    await sem.WaitAsync();
+
+                    await connection.CloseAsync(ExpectedErrorCode);
+                }
+            );
+        }
     }
 
     public sealed class QuicStreamTests_MockProvider : QuicStreamTests<MockProviderFactory>