[QUIC] Call SendResettableCompletionSource.CompleteException in AbortWrite (#67341)
authorRadek Zikmund <32671551+rzikm@users.noreply.github.com>
Tue, 5 Apr 2022 13:54:35 +0000 (15:54 +0200)
committerGitHub <noreply@github.com>
Tue, 5 Apr 2022 13:54:35 +0000 (15:54 +0200)
* Call SendResettableCompletionSource.CompleteException in AbortWrite

* Add test

* fixup! Add test

* Use loop to make the test more robust

src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicStreamTests.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/QuicTestBase.cs

index c33da0d..73ea871 100644 (file)
@@ -606,9 +606,15 @@ namespace System.Net.Quic.Implementations.MsQuic
             }
 
             bool shouldComplete = false;
+            bool shouldCompleteSends = false;
 
             lock (_state)
             {
+                if (_state.SendState == SendState.None || _state.SendState == SendState.Pending)
+                {
+                    shouldCompleteSends = true;
+                }
+
                 if (_state.SendState < SendState.Aborted)
                 {
                     _state.SendState = SendState.Aborted;
@@ -627,6 +633,12 @@ namespace System.Net.Quic.Implementations.MsQuic
                     ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
             }
 
+            if (shouldCompleteSends)
+            {
+                _state.SendResettableCompletionSource.CompleteException(
+                    ExceptionDispatchInfo.SetCurrentStackTrace(new QuicOperationAbortedException("Write was aborted.")));
+            }
+
             StartShutdown(QUIC_STREAM_SHUTDOWN_FLAGS.ABORT_SEND, errorCode);
         }
 
index cd9eca2..f1c2e19 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Buffers;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
@@ -819,6 +818,38 @@ namespace System.Net.Quic.Tests
                 });
         }
 
+
+        [Fact]
+        public async Task WriteAsync_LocalAbort_Throws()
+        {
+            if (IsMockProvider)
+            {
+                // Mock provider does not support aborting pending writes via AbortWrite
+                return;
+            }
+
+            const int ExpectedErrorCode = 0xfffffff;
+
+            TaskCompletionSource waitForAbortTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+
+            await RunBidirectionalClientServer(
+                clientStream =>
+                {
+                    return Task.CompletedTask;
+                },
+                async serverStream =>
+                {
+                    // It may happen, that the WriteAsync call finishes early (before the AbortWrite 
+                    // below), and we hit a check on the next iteration of the WriteForever.
+                    // But in most cases it will still exercise aborting the outstanding write task.
+
+                    var writeTask = WriteForever(serverStream, 1024 * 1024);
+                    serverStream.AbortWrite(ExpectedErrorCode);
+
+                    await Assert.ThrowsAsync<QuicOperationAbortedException>(() => writeTask.WaitAsync(TimeSpan.FromSeconds(3)));
+                });
+        }
+
         [Fact]
         public async Task WaitForWriteCompletionAsync_ServerWriteAborted_Throws()
         {
index 39e6be4..a9f36cc 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Buffers;
 using System.Collections.Generic;
 using System.Diagnostics;
 using System.Net.Quic.Implementations;
@@ -307,12 +308,19 @@ namespace System.Net.Quic.Tests
             return bytesRead;
         }
 
-        internal static async Task<int> WriteForever(QuicStream stream)
+        internal static async Task<int> WriteForever(QuicStream stream, int size = 1)
         {
-            Memory<byte> buffer = new byte[] { 123 };
-            while (true)
+            byte[] buffer = ArrayPool<byte>.Shared.Rent(size);
+            try
+            {
+                while (true)
+                {
+                    await stream.WriteAsync(buffer);
+                }
+            }
+            finally
             {
-                await stream.WriteAsync(buffer);
+                ArrayPool<byte>.Shared.Return(buffer);
             }
         }
     }