[QUIC] QuicStream add ReadsCompleted (#57492)
authorNatalia Kondratyeva <knatalia@microsoft.com>
Tue, 17 Aug 2021 12:22:06 +0000 (14:22 +0200)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 12:22:06 +0000 (14:22 +0200)
Add ReadsCompleted API that exposes ReadState=ReadsCompleted, set ReadState to ReadsCompleted if FIN flag arrives in RECEIVE event, fix ReadState changing to final stauses, expand ReadState transition description

Fixes #55707

src/libraries/System.Net.Quic/ref/System.Net.Quic.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/Mock/MockStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/MsQuic/MsQuicStream.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/Implementations/QuicStreamProvider.cs
src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs
src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs

index 762d0df..790ad3f 100644 (file)
@@ -89,6 +89,7 @@ namespace System.Net.Quic
         public override bool CanTimeout { get { throw null; } }
         public override long Length { get { throw null; } }
         public override long Position { get { throw null; } set { } }
+        public bool ReadsCompleted { get { throw null; } }
         public long StreamId { get { throw null; } }
         public void AbortRead(long errorCode) { }
         public void AbortWrite(long errorCode) { }
index 1b58009..2c3d50a 100644 (file)
@@ -58,6 +58,8 @@ namespace System.Net.Quic.Implementations.Mock
 
         internal override bool CanRead => !_disposed && ReadStreamBuffer is not null;
 
+        internal override bool ReadsCompleted => ReadStreamBuffer?.IsComplete ?? false;
+
         internal override int Read(Span<byte> buffer)
         {
             CheckDisposed();
index 1aa31bd..bbaf9c4 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Buffers;
-using System.Collections.Generic;
 using System.Diagnostics;
 using System.IO;
 using System.Net.Quic.Implementations.MsQuic.Internal;
@@ -50,6 +49,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             public QuicBuffer[] ReceiveQuicBuffers = Array.Empty<QuicBuffer>();
             public int ReceiveQuicBuffersCount;
             public int ReceiveQuicBuffersTotalBytes;
+            public bool ReceiveIsFinal;
 
             // set when ReadState.PendingRead:
             public Memory<byte> ReceiveUserBuffer;
@@ -193,6 +193,8 @@ namespace System.Net.Quic.Implementations.MsQuic
 
         internal override bool CanWrite => _disposed == 0 && _canWrite;
 
+        internal override bool ReadsCompleted => _state.ReadState == ReadState.ReadsCompleted;
+
         internal override bool CanTimeout => true;
 
         private int _readTimeout = Timeout.Infinite;
@@ -415,13 +417,13 @@ namespace System.Net.Quic.Implementations.MsQuic
                 initialReadState = _state.ReadState;
                 abortError = _state.ReadErrorCode;
 
-                // Failure scenario: pre-canceled token. Transition: any -> Aborted
+                // Failure scenario: pre-canceled token. Transition: Any non-final -> Aborted
                 // PendingRead state indicates there is another concurrent read operation in flight
                 // which is forbidden, so it is handled separately
                 if (initialReadState != ReadState.PendingRead && cancellationToken.IsCancellationRequested)
                 {
                     initialReadState = ReadState.Aborted;
-                    _state.ReadState = ReadState.Aborted;
+                    CleanupReadStateAndCheckPending(_state, ReadState.Aborted);
                     preCanceled = true;
                 }
 
@@ -442,16 +444,14 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                     if (cancellationToken.CanBeCanceled)
                     {
+                        // Failure scenario: cancellation. Transition: Any non-final -> Aborted
                         _state.ReceiveCancellationRegistration = cancellationToken.UnsafeRegister(static (obj, token) =>
                         {
                             var state = (State)obj!;
                             bool completePendingRead;
                             lock (state)
                             {
-                                completePendingRead = state.ReadState == ReadState.PendingRead;
-                                state.Stream = null;
-                                state.ReceiveUserBuffer = null;
-                                state.ReadState = ReadState.Aborted;
+                                completePendingRead = CleanupReadStateAndCheckPending(state, ReadState.Aborted);
                             }
 
                             if (completePendingRead)
@@ -468,7 +468,8 @@ namespace System.Net.Quic.Implementations.MsQuic
                     return _state.ReceiveResettableCompletionSource.GetValueTask();
                 }
 
-                // Success scenario: data already available, completing synchronously. Transition IndividualReadComplete->None
+                // Success scenario: data already available, completing synchronously.
+                // Transition IndividualReadComplete->None, or IndividualReadComplete->ReadsCompleted, if it was the last message and we fully consumed it
                 if (initialReadState == ReadState.IndividualReadComplete)
                 {
                     _state.ReadState = ReadState.None;
@@ -481,6 +482,11 @@ namespace System.Net.Quic.Implementations.MsQuic
                         // Need to re-enable receives because MsQuic will pause them when we don't consume the entire buffer.
                         EnableReceive();
                     }
+                    else if (_state.ReceiveIsFinal)
+                    {
+                        // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN
+                        _state.ReadState = ReadState.ReadsCompleted;
+                    }
 
                     return new ValueTask<int>(taken);
                 }
@@ -512,7 +518,10 @@ namespace System.Net.Quic.Implementations.MsQuic
         /// <returns>The number of bytes copied.</returns>
         private static unsafe int CopyMsQuicBuffersToUserBuffer(ReadOnlySpan<QuicBuffer> sourceBuffers, Span<byte> destinationBuffer)
         {
-            Debug.Assert(sourceBuffers.Length != 0);
+            if (sourceBuffers.Length == 0)
+            {
+                return 0;
+            }
 
             int originalDestinationLength = destinationBuffer.Length;
             QuicBuffer nativeBuffer;
@@ -543,16 +552,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             bool shouldComplete = false;
             lock (_state)
             {
-                if (_state.ReadState == ReadState.PendingRead)
-                {
-                    shouldComplete = true;
-                    _state.Stream = null;
-                    _state.ReceiveUserBuffer = null;
-                }
-                if (_state.ReadState < ReadState.ReadsCompleted)
-                {
-                    _state.ReadState = ReadState.Aborted;
-                }
+                shouldComplete = CleanupReadStateAndCheckPending(_state, ReadState.Aborted);
             }
 
             if (shouldComplete)
@@ -754,9 +754,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 if (_state.ReadState < ReadState.ReadsCompleted || _state.ReadState == ReadState.Aborted)
                 {
                     abortRead = true;
-                    completeRead = _state.ReadState == ReadState.PendingRead;
-                    _state.Stream = null;
-                    _state.ReadState = ReadState.Aborted;
+                    completeRead = CleanupReadStateAndCheckPending(_state, ReadState.Aborted);
                 }
 
                 if (_state.ShutdownState == ShutdownState.None)
@@ -881,11 +879,9 @@ namespace System.Net.Quic.Implementations.MsQuic
         {
             ref StreamEventDataReceive receiveEvent = ref evt.Data.Receive;
 
-            if (receiveEvent.BufferCount == 0)
+            if (NetEventSource.Log.IsEnabled())
             {
-                // This is a 0-length receive that happens once reads are finished (via abort or otherwise).
-                // State changes for this are handled in PEER_SEND_SHUTDOWN / PEER_SEND_ABORT / SHUTDOWN_COMPLETE event handlers.
-                return MsQuicStatusCodes.Success;
+                NetEventSource.Info(state, $"{state.TraceId} Stream received {receiveEvent.TotalBufferLength} bytes{(receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN) ? " with FIN flag" : "")}");
             }
 
             int readLength;
@@ -922,8 +918,27 @@ namespace System.Net.Quic.Implementations.MsQuic
 
                         state.ReceiveQuicBuffersCount = (int)receiveEvent.BufferCount;
                         state.ReceiveQuicBuffersTotalBytes = checked((int)receiveEvent.TotalBufferLength);
-                        state.ReadState = ReadState.IndividualReadComplete;
-                        return MsQuicStatusCodes.Pending;
+                        state.ReceiveIsFinal = receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN);
+
+                        // 0-length receive can happens once reads are finished (gracefully or otherwise).
+                        if (state.ReceiveQuicBuffersTotalBytes == 0)
+                        {
+                            if (state.ReceiveIsFinal)
+                            {
+                                // We can complete the state without waiting for PEER_SEND_SHUTDOWN
+                                state.ReadState = ReadState.ReadsCompleted;
+                            }
+
+                            // if it was not a graceful shutdown, we defer aborting to PEER_SEND_ABORT event handler
+                            return MsQuicStatusCodes.Success;
+                        }
+                        else
+                        {
+                            // Normal RECEIVE - data will be buffered until user calls ReadAsync() and no new event will be issued until EnableReceive()
+                            state.ReadState = ReadState.IndividualReadComplete;
+                            return MsQuicStatusCodes.Pending;
+                        }
+
                     case ReadState.PendingRead:
                         // There is a pending ReadAsync().
 
@@ -933,8 +948,17 @@ namespace System.Net.Quic.Implementations.MsQuic
                         state.ReadState = ReadState.None;
 
                         readLength = CopyMsQuicBuffersToUserBuffer(new ReadOnlySpan<QuicBuffer>(receiveEvent.Buffers, (int)receiveEvent.BufferCount), state.ReceiveUserBuffer.Span);
+
+                        // This was a final message and we've consumed everything. We can complete the state without waiting for PEER_SEND_SHUTDOWN
+                        if (receiveEvent.Flags.HasFlag(QUIC_RECEIVE_FLAGS.FIN) && (uint)readLength == receiveEvent.TotalBufferLength)
+                        {
+                            state.ReadState = ReadState.ReadsCompleted;
+                        }
+                        // Else, if this was a final message, but we haven't consumed it fully, FIN flag will arrive again in the next RECEIVE event
+
                         state.ReceiveUserBuffer = null;
                         break;
+
                     default:
                         Debug.Assert(state.ReadState is ReadState.Aborted or ReadState.ConnectionClosed, $"Unexpected {nameof(ReadState)} '{state.ReadState}' in {nameof(HandleEventRecv)}.");
 
@@ -1008,16 +1032,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 // This event won't occur within the middle of a receive.
                 if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.TraceId} Stream completing resettable event source.");
 
-                if (state.ReadState == ReadState.PendingRead)
-                {
-                    shouldReadComplete = true;
-                    state.Stream = null;
-                    state.ReceiveUserBuffer = null;
-                }
-                if (state.ReadState < ReadState.ReadsCompleted)
-                {
-                    state.ReadState = ReadState.ReadsCompleted;
-                }
+                shouldReadComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);
 
                 if (state.ShutdownState == ShutdownState.None)
                 {
@@ -1051,13 +1066,7 @@ namespace System.Net.Quic.Implementations.MsQuic
             bool shouldComplete = false;
             lock (state)
             {
-                if (state.ReadState == ReadState.PendingRead)
-                {
-                    shouldComplete = true;
-                    state.Stream = null;
-                    state.ReceiveUserBuffer = null;
-                }
-                state.ReadState = ReadState.Aborted;
+                shouldComplete = CleanupReadStateAndCheckPending(state, ReadState.Aborted);
                 state.ReadErrorCode = (long)evt.Data.PeerSendAborted.ErrorCode;
             }
 
@@ -1079,16 +1088,7 @@ namespace System.Net.Quic.Implementations.MsQuic
                 // This event won't occur within the middle of a receive.
                 if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"{state.TraceId} Stream completing resettable event source.");
 
-                if (state.ReadState == ReadState.PendingRead)
-                {
-                    shouldComplete = true;
-                    state.Stream = null;
-                    state.ReceiveUserBuffer = null;
-                }
-                if (state.ReadState < ReadState.ReadsCompleted)
-                {
-                    state.ReadState = ReadState.ReadsCompleted;
-                }
+                shouldComplete = CleanupReadStateAndCheckPending(state, ReadState.ReadsCompleted);
             }
 
             if (shouldComplete)
@@ -1378,11 +1378,7 @@ namespace System.Net.Quic.Implementations.MsQuic
 
             lock (state)
             {
-                shouldCompleteRead = state.ReadState == ReadState.PendingRead;
-                if (state.ReadState < ReadState.ReadsCompleted)
-                {
-                    state.ReadState = ReadState.ConnectionClosed;
-                }
+                shouldCompleteRead = CleanupReadStateAndCheckPending(state, ReadState.ConnectionClosed);
 
                 if (state.SendState == SendState.None || state.SendState == SendState.Pending)
                 {
@@ -1428,15 +1424,47 @@ namespace System.Net.Quic.Implementations.MsQuic
         private static Exception GetConnectionAbortedException(State state) =>
             ThrowHelper.GetConnectionAbortedException(state.ConnectionState.AbortErrorCode);
 
+        private static bool CleanupReadStateAndCheckPending(State state, ReadState finalState)
+        {
+            Debug.Assert(finalState >= ReadState.ReadsCompleted, $"Expected final read state, got {finalState}");
+            Debug.Assert(Monitor.IsEntered(state));
+
+            bool shouldComplete = false;
+            if (state.ReadState == ReadState.PendingRead)
+            {
+                shouldComplete = true;
+                state.Stream = null;
+                state.ReceiveUserBuffer = null;
+                state.ReceiveCancellationRegistration.Unregister();
+            }
+            if (state.ReadState < ReadState.ReadsCompleted)
+            {
+                state.ReadState = finalState;
+            }
+            return shouldComplete;
+        }
+
         // Read state transitions:
         //
-        // None  --(data arrives in event RECV)->  IndividualReadComplete  --(user calls ReadAsync() & completes syncronously)->  None
-        // None  --(user calls ReadAsync() & waits)->  PendingRead  --(data arrives in event RECV & completes user's ReadAsync())->  None
+        // None  --(data arrives in event RECV)->  IndividualReadComplete
+        // None  --(data arrives in event RECV with FIN flag)->  IndividualReadComplete(+FIN)
+        // None  --(0-byte data arrives in event RECV with FIN flag)->  ReadsCompleted
+        // None  --(user calls ReadAsync() & waits)->  PendingRead
+        //
+        // IndividualReadComplete  --(user calls ReadAsync())->  None
+        // IndividualReadComplete(+FIN)  --(user calls ReadAsync() & consumes only partial data)->  None
+        // IndividualReadComplete(+FIN)  --(user calls ReadAsync() & consumes full data)->  ReadsCompleted
+        //
+        // PendingRead  --(data arrives in event RECV & completes user's ReadAsync())->  None
+        // PendingRead  --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with only partial data)->  None
+        // PendingRead  --(data arrives in event RECV with FIN flag & completes user's ReadAsync() with full data)->  ReadsCompleted
+        //
         // Any non-final state  --(event PEER_SEND_SHUTDOWN or SHUTDOWN_COMPLETED with ConnectionClosed=false)->  ReadsCompleted
         // Any non-final state  --(event PEER_SEND_ABORT)->  Aborted
         // Any non-final state  --(user calls AbortRead())->  Aborted
-        // Any state  --(CancellationToken's cancellation for ReadAsync())->  Aborted (TODO: should it be only for non-final as others?)
+        // Any non-final state  --(CancellationToken's cancellation for ReadAsync())->  Aborted
         // Any non-final state  --(event SHUTDOWN_COMPLETED with ConnectionClosed=true)->  ConnectionClosed
+        //
         // Closed - no transitions, set for Unidirectional write-only streams
         private enum ReadState
         {
index f011e56..66c9a8b 100644 (file)
@@ -15,6 +15,8 @@ namespace System.Net.Quic.Implementations
 
         internal abstract bool CanRead { get; }
 
+        internal abstract bool ReadsCompleted { get; }
+
         internal abstract int ReadTimeout { get; set; }
 
         internal abstract int Read(Span<byte> buffer);
index 55ba995..8a6dbe4 100644 (file)
@@ -71,6 +71,8 @@ namespace System.Net.Quic
 
         public override bool CanRead => _provider.CanRead;
 
+        public bool ReadsCompleted => _provider.ReadsCompleted;
+
         public override int Read(Span<byte> buffer) => _provider.Read(buffer);
 
         public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default) => _provider.ReadAsync(buffer, cancellationToken);
index 59833d8..9e3cc64 100644 (file)
@@ -4,6 +4,7 @@
 using System.Buffers;
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Diagnostics.Tracing;
 using System.Linq;
 using System.Net.Security;
 using System.Net.Sockets;
@@ -22,7 +23,7 @@ namespace System.Net.Quic.Tests
     [Collection("NoParallelTests")]
     public class MsQuicTests : QuicTestBase<MsQuicProviderFactory>
     {
-        private static ReadOnlyMemory<byte> s_data = Encoding.UTF8.GetBytes("Hello world!");
+        private static byte[] s_data = Encoding.UTF8.GetBytes("Hello world!");
 
         public MsQuicTests(ITestOutputHelper output) : base(output) { }
 
@@ -840,5 +841,69 @@ namespace System.Net.Quic.Tests
                 }
             }
         }
+
+        [Fact]
+        public async Task BasicTest_WithReadsCompletedCheck()
+        {
+            await RunClientServer(
+                iterations: 100,
+                serverFunction: async connection =>
+                {
+                    using QuicStream stream = await connection.AcceptStreamAsync();
+                    Assert.False(stream.ReadsCompleted);
+
+                    byte[] buffer = new byte[s_data.Length];
+                    int bytesRead = await ReadAll(stream, buffer);
+
+                    Assert.True(stream.ReadsCompleted);
+                    Assert.Equal(s_data.Length, bytesRead);
+                    Assert.Equal(s_data, buffer);
+
+                    await stream.WriteAsync(s_data, endStream: true);
+                    await stream.ShutdownCompleted();
+                },
+                clientFunction: async connection =>
+                {
+                    using QuicStream stream = connection.OpenBidirectionalStream();
+                    Assert.False(stream.ReadsCompleted);
+
+                    await stream.WriteAsync(s_data, endStream: true);
+
+                    byte[] buffer = new byte[s_data.Length];
+                    int bytesRead = await ReadAll(stream, buffer);
+
+                    Assert.True(stream.ReadsCompleted);
+                    Assert.Equal(s_data.Length, bytesRead);
+                    Assert.Equal(s_data, buffer);
+
+                    await stream.ShutdownCompleted();
+                }
+            );
+        }
+
+        [Fact]
+        public async Task Read_ReadsCompleted_ReportedBeforeReturning0()
+        {
+            await RunBidirectionalClientServer(
+                async clientStream =>
+                {
+                    await clientStream.WriteAsync(new byte[1], endStream: true);
+                },
+                async serverStream =>
+                {
+                    Assert.False(serverStream.ReadsCompleted);
+
+                    var received = await serverStream.ReadAsync(new byte[1]);
+                    Assert.Equal(1, received);
+                    Assert.True(serverStream.ReadsCompleted);
+
+                    var task = serverStream.ReadAsync(new byte[1]);
+                    Assert.True(task.IsCompleted);
+
+                    received = await task;
+                    Assert.Equal(0, received);
+                    Assert.True(serverStream.ReadsCompleted);
+                });
+        }
     }
 }