[QUIC] Stream write cancellation (#53304)
authorNatalia Kondratyeva <knatalia@microsoft.com>
Fri, 4 Jun 2021 14:32:22 +0000 (16:32 +0200)
committerGitHub <noreply@github.com>
Fri, 4 Jun 2021 14:32:22 +0000 (16:32 +0200)
Add tests to check write cancellation behavior, fix pre-cancelled writes and fix mock stream.
Add throwing on msquic returning write canceled status.

Fixes #32077

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs

index 14ecead9a7f88e1480d471b10f94f7c2dffef123..68964fc30688d6ee170e8a0c354d9d9c3be335e2 100644 (file)
@@ -16,6 +16,7 @@ namespace System.Net.Quic.Implementations.Mock
         private readonly bool _isInitiator;
 
         private readonly StreamState _streamState;
+        private bool _writesCanceled;
 
         internal MockStream(StreamState streamState, bool isInitiator)
         {
@@ -84,6 +85,10 @@ namespace System.Net.Quic.Implementations.Mock
         internal override void Write(ReadOnlySpan<byte> buffer)
         {
             CheckDisposed();
+            if (Volatile.Read(ref _writesCanceled))
+            {
+                throw new OperationCanceledException();
+            }
 
             StreamBuffer? streamBuffer = WriteStreamBuffer;
             if (streamBuffer is null)
@@ -102,6 +107,11 @@ namespace System.Net.Quic.Implementations.Mock
         internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool endStream, CancellationToken cancellationToken = default)
         {
             CheckDisposed();
+            if (Volatile.Read(ref _writesCanceled))
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+                throw new OperationCanceledException();
+            }
 
             StreamBuffer? streamBuffer = WriteStreamBuffer;
             if (streamBuffer is null)
@@ -109,6 +119,12 @@ namespace System.Net.Quic.Implementations.Mock
                 throw new NotSupportedException();
             }
 
+            using var registration = cancellationToken.UnsafeRegister(static s =>
+            {
+                var stream = (MockStream)s!;
+                Volatile.Write(ref stream._writesCanceled, true);
+            }, this);
+
             await streamBuffer.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
 
             if (endStream)
index 6102235f7e1fe42e3fa6aa942c1639593d4a428a..bb1468cd5255867d587eac451c961309e4724694 100644 (file)
@@ -216,18 +216,14 @@ namespace System.Net.Quic.Implementations.MsQuic
                 throw new InvalidOperationException(SR.net_quic_writing_notallowed);
             }
 
-            lock (_state)
+            // Make sure start has completed
+            if (!_started)
             {
-                if (_state.SendState == SendState.Aborted)
-                {
-                    throw new OperationCanceledException(SR.net_quic_sending_aborted);
-                }
-                else if (_state.SendState == SendState.ConnectionClosed)
-                {
-                    throw GetConnectionAbortedException(_state);
-                }
+                await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false);
+                _started = true;
             }
 
+            // if token was already cancelled, this would execute syncronously
             CancellationTokenRegistration registration = cancellationToken.UnsafeRegister(static (s, token) =>
             {
                 var state = (State)s!;
@@ -248,11 +244,17 @@ namespace System.Net.Quic.Implementations.MsQuic
                 }
             }, _state);
 
-            // Make sure start has completed
-            if (!_started)
+            lock (_state)
             {
-                await _state.SendResettableCompletionSource.GetTypelessValueTask().ConfigureAwait(false);
-                _started = true;
+                if (_state.SendState == SendState.Aborted)
+                {
+                    cancellationToken.ThrowIfCancellationRequested();
+                    throw new OperationCanceledException(SR.net_quic_sending_aborted);
+                }
+                else if (_state.SendState == SendState.ConnectionClosed)
+                {
+                    throw GetConnectionAbortedException(_state);
+                }
             }
 
             return registration;
@@ -262,7 +264,7 @@ namespace System.Net.Quic.Implementations.MsQuic
         {
             lock (_state)
             {
-                if (_state.SendState == SendState.Finished || _state.SendState == SendState.Aborted)
+                if (_state.SendState == SendState.Finished)
                 {
                     _state.SendState = SendState.None;
                 }
@@ -827,6 +829,9 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         private static uint HandleEventSendComplete(State state, ref StreamEvent evt)
         {
+            StreamEventDataSendComplete sendCompleteEvent = evt.Data.SendComplete;
+            bool canceled = sendCompleteEvent.Canceled != 0;
+
             bool complete = false;
 
             lock (state)
@@ -836,13 +841,26 @@ namespace System.Net.Quic.Implementations.MsQuic
                     state.SendState = SendState.Finished;
                     complete = true;
                 }
+
+                if (canceled)
+                {
+                    state.SendState = SendState.Aborted;
+                }
             }
 
             if (complete)
             {
                 CleanupSendState(state);
-                // TODO throw if a write was canceled.
-                state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success);
+
+                if (!canceled)
+                {
+                    state.SendResettableCompletionSource.Complete(MsQuicStatusCodes.Success);
+                }
+                else
+                {
+                    state.SendResettableCompletionSource.CompleteException(
+                        ExceptionDispatchInfo.SetCurrentStackTrace(new OperationCanceledException("Write was canceled")));
+                }
             }
 
             return MsQuicStatusCodes.Success;
index 72243c3bdb723d4fd024c4ca291416042abacd06..4eee9b459d9fbb3bff386151351bcad0bc08c082 100644 (file)
@@ -6,6 +6,7 @@ using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
 
@@ -434,6 +435,138 @@ namespace System.Net.Quic.Tests
                 Assert.Equal(ExpectedErrorCode, ex.ErrorCode);
             }).WaitAsync(TimeSpan.FromSeconds(15));
         }
+
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/53530")]
+        [Fact]
+        public async Task StreamAbortedWithoutWriting_ReadThrows()
+        {
+            long expectedErrorCode = 1234;
+
+            await RunClientServer(
+                clientFunction: async connection =>
+                {
+                    await using QuicStream stream = connection.OpenUnidirectionalStream();
+                    stream.AbortWrite(expectedErrorCode);
+
+                    await stream.ShutdownCompleted();
+                },
+                serverFunction: async connection =>
+                {
+                    await using QuicStream stream = await connection.AcceptStreamAsync();
+
+                    byte[] buffer = new byte[1];
+
+                    QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadAll(stream, buffer));
+                    Assert.Equal(expectedErrorCode, ex.ErrorCode);
+
+                    await stream.ShutdownCompleted();
+                }
+            );
+        }
+
+        [Fact]
+        public async Task WritePreCanceled_Throws()
+        {
+            long expectedErrorCode = 1234;
+
+            await RunClientServer(
+                clientFunction: async connection =>
+                {
+                    await using QuicStream stream = connection.OpenUnidirectionalStream();
+
+                    CancellationTokenSource cts = new CancellationTokenSource();
+                    cts.Cancel();
+
+                    await Assert.ThrowsAsync<OperationCanceledException>(() => stream.WriteAsync(new byte[1], cts.Token).AsTask());
+
+                    // next write would also throw
+                    await Assert.ThrowsAsync<OperationCanceledException>(() => stream.WriteAsync(new byte[1]).AsTask());
+
+                    // manual write abort is still required
+                    stream.AbortWrite(expectedErrorCode);
+
+                    await stream.ShutdownCompleted();
+                },
+                serverFunction: async connection =>
+                {
+                    await using QuicStream stream = await connection.AcceptStreamAsync();
+
+                    byte[] buffer = new byte[1024 * 1024];
+
+                    // TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530
+                    //QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadAll(stream, buffer));
+                    try
+                    {
+                        await ReadAll(stream, buffer);
+                    }
+                    catch (QuicStreamAbortedException) { }
+
+                    await stream.ShutdownCompleted();
+                }
+            );
+        }
+
+        [Fact]
+        public async Task WriteCanceled_NextWriteThrows()
+        {
+            long expectedErrorCode = 1234;
+
+            await RunClientServer(
+                clientFunction: async connection =>
+                {
+                    await using QuicStream stream = connection.OpenUnidirectionalStream();
+
+                    CancellationTokenSource cts = new CancellationTokenSource(500);
+
+                    async Task WriteUntilCanceled()
+                    {
+                        var buffer = new byte[64 * 1024];
+                        while (true)
+                        {
+                            await stream.WriteAsync(buffer, cancellationToken: cts.Token);
+                        }
+                    }
+
+                    // a write would eventually be canceled
+                    await Assert.ThrowsAsync<OperationCanceledException>(() => WriteUntilCanceled().WaitAsync(TimeSpan.FromSeconds(3)));
+
+                    // next write would also throw
+                    await Assert.ThrowsAsync<OperationCanceledException>(() => stream.WriteAsync(new byte[1]).AsTask());
+
+                    // manual write abort is still required
+                    stream.AbortWrite(expectedErrorCode);
+
+                    await stream.ShutdownCompleted();
+                },
+                serverFunction: async connection =>
+                {
+                    await using QuicStream stream = await connection.AcceptStreamAsync();
+
+                    async Task ReadUntilAborted()  
+                    {
+                        var buffer = new byte[1024];
+                        while (true)
+                        {
+                            int res = await stream.ReadAsync(buffer);
+                            if (res == 0)
+                            {
+                                break;
+                            }
+                        }
+                    }
+
+                    // TODO: it should always throw QuicStreamAbortedException, but sometimes it does not https://github.com/dotnet/runtime/issues/53530
+                    //QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadUntilAborted());
+                    try
+                    {
+                        await ReadUntilAborted().WaitAsync(TimeSpan.FromSeconds(3));
+                    }
+                    catch (QuicStreamAbortedException) { }
+
+                    await stream.ShutdownCompleted();
+                }
+            );
+        }
     }
 
     public sealed class QuicStreamTests_MockProvider : QuicStreamTests<MockProviderFactory> { }