implement cancellation support for SendFileAsync and DisconnectAsync (#53062)
authorGeoff Kizer <geoffrek@microsoft.com>
Thu, 27 May 2021 05:58:37 +0000 (22:58 -0700)
committerGitHub <noreply@github.com>
Thu, 27 May 2021 05:58:37 +0000 (22:58 -0700)
implement cancellation support for SendFileAsync and DisconnectAsync, and rework some internal async logic to support this and reduce code duplication

src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs

index ae67f70..fb2c911 100644 (file)
@@ -1098,8 +1098,7 @@ namespace System.Net.Sockets
             {
                 Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
 
-                // TODO: Support cancellation by passing cancellationToken down through SendPacketsAsync, etc.
-                if (socket.SendPacketsAsync(this))
+                if (socket.SendPacketsAsync(this, cancellationToken))
                 {
                     _cancellationToken = cancellationToken;
                     return new ValueTask(this, _token);
@@ -1379,7 +1378,10 @@ namespace System.Net.Sockets
 
             private void ThrowException(SocketError error, CancellationToken cancellationToken)
             {
-                if (error == SocketError.OperationAborted)
+                // Most operations will report OperationAborted when canceled.
+                // On Windows, SendFileAsync will report ConnectionAborted.
+                // There's a race here anyway, so there's no harm in also checking for ConnectionAborted in all cases.
+                if (error == SocketError.OperationAborted || error == SocketError.ConnectionAborted)
                 {
                     cancellationToken.ThrowIfCancellationRequested();
                 }
index a4d24c2..9138e01 100644 (file)
@@ -3060,7 +3060,9 @@ namespace System.Net.Sockets
             return socketError == SocketError.IOPending;
         }
 
-        public bool SendPacketsAsync(SocketAsyncEventArgs e)
+        public bool SendPacketsAsync(SocketAsyncEventArgs e) => SendPacketsAsync(e, default(CancellationToken));
+
+        private bool SendPacketsAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
         {
             ThrowIfDisposed();
 
@@ -3082,7 +3084,7 @@ namespace System.Net.Sockets
             SocketError socketError;
             try
             {
-                socketError = e.DoOperationSendPackets(this, _handle);
+                socketError = e.DoOperationSendPackets(this, _handle, cancellationToken);
             }
             catch (Exception)
             {
index 907170d..4e1383c 100644 (file)
@@ -2100,7 +2100,7 @@ namespace System.Net.Sockets
             return operation.ErrorCode;
         }
 
-        public SocketError SendFileAsync(SafeFileHandle fileHandle, long offset, long count, out long bytesSent, Action<long, SocketError> callback)
+        public SocketError SendFileAsync(SafeFileHandle fileHandle, long offset, long count, out long bytesSent, Action<long, SocketError> callback, CancellationToken cancellationToken = default)
         {
             SetHandleNonBlocking();
 
@@ -2122,7 +2122,7 @@ namespace System.Net.Sockets
                 BytesTransferred = bytesSent
             };
 
-            if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
+            if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
             {
                 bytesSent = operation.BytesTransferred;
                 return operation.ErrorCode;
index 8812024..a8be414 100644 (file)
@@ -244,7 +244,7 @@ namespace System.Net.Sockets
             return errorCode;
         }
 
-        internal SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle)
+        internal SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
         {
             Debug.Assert(_sendPacketsElements != null);
             SendPacketsElement[] elements = (SendPacketsElement[])_sendPacketsElements.Clone();
@@ -288,7 +288,7 @@ namespace System.Net.Sockets
                 throw;
             }
 
-            SocketPal.SendPacketsAsync(socket, SendPacketsFlags, elements, files, (bytesTransferred, error) =>
+            SocketPal.SendPacketsAsync(socket, SendPacketsFlags, elements, files, cancellationToken, (bytesTransferred, error) =>
             {
                 if (error == SocketError.Success)
                 {
index d79d340..73cb36f 100644 (file)
@@ -13,44 +13,48 @@ namespace System.Net.Sockets
 {
     public partial class SocketAsyncEventArgs : EventArgs, IDisposable
     {
-        // Single buffer
-        private MemoryHandle _singleBufferHandle;
-
-        /// <summary>The state of the <see cref="_singleBufferHandle"/> to track whether and when it requires disposal to unpin memory.</summary>
+        /// <summary>Tracks whether and when whether to perform cleanup of async processing state, specifically <see cref="_singleBufferHandle"/> and <see cref="_registrationToCancelPendingIO"/>.</summary>
         /// <remarks>
-        /// Pinning via a GCHandle (the mechanism used by Memory) has measurable overhead, and for operations like
-        /// send and receive that we want to optimize and that frequently complete synchronously, we
-        /// want to avoid such GCHandle interactions whenever possible.  To achieve that, we used `fixed`
-        /// to pin the relevant state while starting the async operation, and then only if the operation
-        /// is seen to be pending do we use Pin to create the GCHandle; this is done while the `fixed` is
-        /// still in scope, to ensure that throughout the whole operation the buffer remains pinned, while
-        /// using the much cheaper `fixed` only for the fast path.  <see cref="_singleBufferHandle"/> starts
-        /// life as None, transitions to InProcess prior to initiating the async operation, and then transitions
-        /// either back to None if the buffer never needed to be pinned, or to Set once it has been pinned. This
-        /// ensures that asynchronous completion racing with the code that is still setting up the operation
-        /// can properly clean up after pinned memory, even if it needs to wait momentarily to do so.
+        /// We sometimes want to or need to defer the initialization of some aspects of async processing until
+        /// the operation pends, that is, the OS async call returns and indicates that it will complete asynchronously.
+        ///
+        /// We do this in two cases: to optimize buffer pinning in certain cases, and to handle cancellation.
+        ///
+        /// We optimize buffer pinning for operations that frequently complete synchronously, like Send and Receive,
+        /// when they use a single buffer. (Optimizing for multiple buffers is more complicated, so we don't currently do it.)
+        /// In these cases, we used the much-cheaper `fixed` keyword to pin the buffer while the OS async call is in progress.
+        /// If the operation pends, we will then pin using MemoryHandle.Pin() before exiting the `fixed` scope,
+        /// thus ensuring that the buffer remains pinned throughout the whole operation.
         ///
-        /// Currently, only the operations that use <see cref="_singleBufferHandle"/> and <see cref="_singleBufferHandleState"/>
-        /// are cancelable, and as such <see cref="_singleBufferHandleState"/> is also used to guard the cleanup
-        /// of <see cref="_registrationToCancelPendingIO"/>.
+        /// For cancellation, we always need to defer cancellation registration until after the OS call has pended.
+        ///
+        /// To coordinate with the completion callback from the IOCP, we set <see cref="_asyncProcessingState"/> to InProcess
+        /// before performing the OS async call, and then transition it to Set once the operation has pended
+        /// and the relevant async setup (pinning, cancellation registration) has occurred.
+        /// This ensures that the completion callback will only clean up the async state (unpin, unregister cancellation)
+        /// once it has been fully constructed by the original calling thread, even if it needs to wait momentarily to do so.
+        ///
+        /// For some operations, like Connect and multi-buffer Send/Receive, we do not do any deferred async processing.
+        /// Instead, we perform any necessary setup and set <see cref="_asyncProcessingState"/> to Set before making the call.
+        /// The cleanup logic will be invoked as always, but in this case will never need to wait on the InProcess state.
+        /// If an operation does not require any cleanup of this state, it can simply leave <see cref="_asyncProcessingState"/> as None.
         /// </remarks>
-        private volatile SingleBufferHandleState _singleBufferHandleState;
+        private volatile AsyncProcessingState _asyncProcessingState;
 
-        /// <summary>Defines possible states for <see cref="_singleBufferHandleState"/> in order to faciliate correct cleanup of any pinned state.</summary>
-        private enum SingleBufferHandleState : byte
+        /// <summary>Defines possible states for <see cref="_asyncProcessingState"/> in order to faciliate correct cleanup of asynchronous processing state.</summary>
+        private enum AsyncProcessingState : byte
         {
-            /// <summary>No operation using <see cref="_singleBufferHandle"/> is in flight, and no cleanup of <see cref="_singleBufferHandle"/> is required.</summary>
+            /// <summary>No cleanup is required, either because no operation is in flight or the current operation does not require cleanup.</summary>
             None,
-            /// <summary>
-            /// An operation potentially using <see cref="_singleBufferHandle"/> is in flight, but the field hasn't yet been initialized.
-            /// It's possible <see cref="_singleBufferHandle"/> will transition to <see cref="Set"/>, and thus code needs to wait for the
-            /// value to no longer be <see cref="InProcess"/> before <see cref="_singleBufferHandle"/> can be disposed.
-            /// </summary>
+            /// <summary>An operation is in flight but async processing state has not yet been initialized.</summary>
             InProcess,
-            /// <summary>The <see cref="_singleBufferHandle"/> field has been initialized and requires disposal.  It is safe to dispose of when the operation no longer needs it.</summary>
+            /// <summary>An operation is in flight and async processing state is fully initialized and ready to be cleaned up.</summary>
             Set
         }
 
+        // Single buffer pin handle
+        private MemoryHandle _singleBufferHandle;
+
         // BufferList property variables.
         // Note that these arrays are allocated and then grown as necessary, but never shrunk.
         // Thus the actual in-use length is defined by _bufferListInternal.Count, not the length of these arrays.
@@ -123,10 +127,10 @@ namespace System.Net.Sockets
 
         private unsafe void RegisterToCancelPendingIO(NativeOverlapped* overlapped, CancellationToken cancellationToken)
         {
-            Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.InProcess, "An operation must be declared in-flight in order to register to cancel it.");
+            Debug.Assert(_asyncProcessingState == AsyncProcessingState.InProcess, "An operation must be declared in-flight in order to register to cancel it.");
             Debug.Assert(_pendingOverlappedForCancellation == null);
             _pendingOverlappedForCancellation = overlapped;
-            _registrationToCancelPendingIO = cancellationToken.UnsafeRegister(s =>
+            _registrationToCancelPendingIO = cancellationToken.UnsafeRegister(static s =>
             {
                 // Try to cancel the I/O.  We ignore the return value (other than for logging), as cancellation
                 // is opportunistic and we don't want to fail the operation because we couldn't cancel it.
@@ -161,120 +165,94 @@ namespace System.Net.Sockets
             _strongThisRef.Value = this;
         }
 
-        /// <summary>Handles the result of an IOCP operation.</summary>
-        /// <param name="success">true if the operation completed synchronously and successfully; otherwise, false.</param>
-        /// <param name="bytesTransferred">The number of bytes transferred, if the operation completed synchronously and successfully.</param>
-        /// <param name="overlapped">The overlapped to be freed if the operation completed synchronously.</param>
-        /// <returns>The result status of the operation.</returns>
-        private unsafe SocketError ProcessIOCPResult(bool success, int bytesTransferred, NativeOverlapped* overlapped)
+        /// <summary>Gets the result of an IOCP operation and determines how it should be handled (synchronously or asynchronously).</summary>
+        /// <param name="success">true if the IOCP operation indicated synchronous success; otherwise, false.</param>
+        /// <param name="overlapped">The overlapped that was used for this operation. Will be freed if the operation result will be handled synchronously.</param>
+        /// <returns>The SocketError for the operation. This will be SocketError.IOPending if the operation will be handled asynchronously.</returns>
+        private unsafe SocketError GetIOCPResult(bool success, NativeOverlapped* overlapped)
         {
-            // Note: We need to dispose of the overlapped iff the operation completed synchronously,
-            // and if we do, we must do so before we mark the operation as completed.
+            // Note: We need to dispose of the overlapped iff the operation result will be handled synchronously.
 
             if (success)
             {
                 // Synchronous success.
                 if (_currentSocket!.SafeHandle.SkipCompletionPortOnSuccess)
                 {
-                    // The socket handle is configured to skip completion on success,
-                    // so we can set the results right now.
+                    // The socket handle is configured to skip completion on success, so we can handle the result synchronously.
                     FreeNativeOverlapped(overlapped);
-                    FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None);
-
-                    if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry();
-
                     return SocketError.Success;
                 }
 
                 // Completed synchronously, but the handle wasn't marked as skip completion port on success,
-                // so we still need to fall through and behave as if the IO was pending.
+                // so we still need to behave as if the IO was pending and wait for the completion to come through on the IOCP.
+                return SocketError.IOPending;
             }
             else
             {
                 // Get the socket error (which may be IOPending)
                 SocketError socketError = SocketPal.GetLastSocketError();
+                Debug.Assert(socketError != SocketError.Success);
                 if (socketError != SocketError.IOPending)
                 {
                     // Completed synchronously with a failure.
+                    // No IOCP completion will occur.
                     FreeNativeOverlapped(overlapped);
-                    FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None);
-
-                    if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry();
-
                     return socketError;
                 }
 
-                // Fall through to IOPending handling for asynchronous completion.
+                // The completion will arrive on the IOCP when the operation is done.
+                return SocketError.IOPending;
             }
-
-            // Socket handle is going to post a completion to the completion port (may have done so already).
-            // Return pending and we will continue in the completion port callback.
-            return SocketError.IOPending;
         }
 
         /// <summary>Handles the result of an IOCP operation.</summary>
-        /// <param name="socketError">The result status of the operation, as returned from the API call.</param>
+        /// <param name="success">true if the IOCP operation indicated synchronous success; otherwise, false.</param>
         /// <param name="bytesTransferred">The number of bytes transferred, if the operation completed synchronously and successfully.</param>
-        /// <param name="overlapped">The overlapped to be freed if the operation completed synchronously.</param>
-        /// <param name="cancellationToken">The cancellation token to use to cancel the operation.</param>
+        /// <param name="overlapped">The overlapped that was used for this operation. Will be freed if the operation result will be handled synchronously.</param>
         /// <returns>The result status of the operation.</returns>
-        private unsafe SocketError ProcessIOCPResultWithSingleBufferHandle(SocketError socketError, int bytesTransferred, NativeOverlapped* overlapped, CancellationToken cancellationToken = default)
+        private unsafe SocketError ProcessIOCPResult(bool success, int bytesTransferred, NativeOverlapped* overlapped)
         {
-            // Note: We need to dispose of the overlapped iff the operation completed synchronously,
-            // and if we do, we must do so before we mark the operation as completed.
-
-            if (socketError == SocketError.Success)
-            {
-                // Synchronous success.
-                if (_currentSocket!.SafeHandle.SkipCompletionPortOnSuccess)
-                {
-                    // The socket handle is configured to skip completion on success,
-                    // so we can set the results right now.
-                    _singleBufferHandleState = SingleBufferHandleState.None;
-                    FreeNativeOverlapped(overlapped);
-                    FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None);
+            Debug.Assert(_asyncProcessingState != AsyncProcessingState.InProcess);
 
-                    if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry();
+            SocketError socketError = GetIOCPResult(success, overlapped);
 
-                    return SocketError.Success;
-                }
-
-                // Completed synchronously, but the handle wasn't marked as skip completion port on success,
-                // so we still need to fall through and behave as if the IO was pending.
-            }
-            else
+            if (socketError != SocketError.IOPending)
             {
-                // Get the socket error (which may be IOPending)
-                socketError = SocketPal.GetLastSocketError();
-                if (socketError != SocketError.IOPending)
-                {
-                    // Completed synchronously with a failure.
-                    _singleBufferHandleState = SingleBufferHandleState.None;
-                    FreeNativeOverlapped(overlapped);
-                    FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None);
+                FinishOperationSync(socketError, bytesTransferred, SocketFlags.None);
+            }
 
-                    if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry();
+            return socketError;
+        }
 
-                    return socketError;
-                }
+        /// <summary>Handles the result of an IOCP operation for which we have deferred async processing logic (buffer pinning or cancellation).</summary>
+        /// <param name="success">true if the IOCP operation indicated synchronous success; otherwise, false.</param>
+        /// <param name="bytesTransferred">The number of bytes transferred, if the operation completed synchronously and successfully.</param>
+        /// <param name="overlapped">The overlapped that was used for this operation. Will be freed if the operation result will be handled synchronously.</param>
+        /// <param name="bufferToPin">The buffer to pin. May be Memory.Empty if no buffer should be pinned.
+        ///     Note this buffer (if not empty) should already be pinned locally using `fixed` prior to the OS async call and until after this method returns.</param>
+        /// <param name="cancellationToken">The cancellation token to use to cancel the operation.</param>
+        /// <returns>The result status of the operation.</returns>
+        private unsafe SocketError ProcessIOCPResultWithDeferredAsyncHandling(bool success, int bytesTransferred, NativeOverlapped* overlapped, Memory<byte> bufferToPin, CancellationToken cancellationToken = default)
+        {
+            Debug.Assert(_asyncProcessingState == AsyncProcessingState.InProcess);
 
-                // Fall through to IOPending handling for asynchronous completion.
-            }
+            SocketError socketError = GetIOCPResult(success, overlapped);
 
-            // Socket handle is going to post a completion to the completion port (may have done so already).
-            // Return pending and we will continue in the completion port callback.
-            if (_singleBufferHandleState == SingleBufferHandleState.InProcess)
+            if (socketError == SocketError.IOPending)
             {
-                // Register for cancellation.  This must happen before we change state to Set, as once it's Set,
-                // the operation completing asynchronously could invoke cleanup, which includes disposing of the
-                // cancellation registration, and thus the registration needs to be stored prior to setting Set.
                 RegisterToCancelPendingIO(overlapped, cancellationToken);
 
-                _singleBufferHandle = _buffer.Pin();
-                _singleBufferHandleState = SingleBufferHandleState.Set;
+                _singleBufferHandle = bufferToPin.Pin();
+
+                _asyncProcessingState = AsyncProcessingState.Set;
+            }
+            else
+            {
+                _asyncProcessingState = AsyncProcessingState.None;
+                FinishOperationSync(socketError, bytesTransferred, SocketFlags.None);
             }
 
-            return SocketError.IOPending;
+            return socketError;
         }
 
         internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle)
@@ -282,13 +260,13 @@ namespace System.Net.Sockets
             bool userBuffer = _count != 0;
             Debug.Assert(!userBuffer || (!_buffer.Equals(default) && _count >= _acceptAddressBufferCount));
             Memory<byte> buffer = userBuffer ? _buffer : _acceptBuffer;
-            Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
+            Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
 
             NativeOverlapped* overlapped = AllocateNativeOverlapped();
             try
             {
                 _singleBufferHandle = buffer.Pin();
-                _singleBufferHandleState = SingleBufferHandleState.Set;
+                _asyncProcessingState = AsyncProcessingState.Set;
 
                 bool success = socket.AcceptEx(
                     handle,
@@ -304,7 +282,7 @@ namespace System.Net.Sockets
             }
             catch
             {
-                _singleBufferHandleState = SingleBufferHandleState.None;
+                _asyncProcessingState = AsyncProcessingState.None;
                 FreeNativeOverlapped(overlapped);
                 _singleBufferHandle.Dispose();
                 throw;
@@ -329,9 +307,9 @@ namespace System.Net.Sockets
             NativeOverlapped* overlapped = AllocateNativeOverlapped();
             try
             {
-                Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
+                Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
                 _singleBufferHandle = _buffer.Pin();
-                _singleBufferHandleState = SingleBufferHandleState.Set;
+                _asyncProcessingState = AsyncProcessingState.Set;
 
                 bool success = socket.ConnectEx(
                     handle,
@@ -346,7 +324,7 @@ namespace System.Net.Sockets
             }
             catch
             {
-                _singleBufferHandleState = SingleBufferHandleState.None;
+                _asyncProcessingState = AsyncProcessingState.None;
                 FreeNativeOverlapped(overlapped);
                 _singleBufferHandle.Dispose();
                 throw;
@@ -355,22 +333,23 @@ namespace System.Net.Sockets
 
         internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
         {
-            // Note: CancellationToken is ignored for now.
-            // See https://github.com/dotnet/runtime/issues/51452
-
             NativeOverlapped* overlapped = AllocateNativeOverlapped();
             try
             {
+                Debug.Assert(_asyncProcessingState == AsyncProcessingState.None, $"Expected None, got {_asyncProcessingState}");
+                _asyncProcessingState = AsyncProcessingState.InProcess;
+
                 bool success = socket.DisconnectEx(
                     handle,
                     overlapped,
                     (int)(DisconnectReuseSocket ? TransmitFileOptions.ReuseSocket : 0),
                     0);
 
-                return ProcessIOCPResult(success, 0, overlapped);
+                return ProcessIOCPResultWithDeferredAsyncHandling(success, 0, overlapped, Memory<byte>.Empty, cancellationToken);
             }
             catch
             {
+                _asyncProcessingState = AsyncProcessingState.None;
                 FreeNativeOverlapped(overlapped);
                 throw;
             }
@@ -387,8 +366,8 @@ namespace System.Net.Sockets
                 NativeOverlapped* overlapped = AllocateNativeOverlapped();
                 try
                 {
-                    Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None, $"Expected None, got {_singleBufferHandleState}");
-                    _singleBufferHandleState = SingleBufferHandleState.InProcess;
+                    Debug.Assert(_asyncProcessingState == AsyncProcessingState.None, $"Expected None, got {_asyncProcessingState}");
+                    _asyncProcessingState = AsyncProcessingState.InProcess;
                     var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) };
 
                     SocketFlags flags = _socketFlags;
@@ -401,11 +380,11 @@ namespace System.Net.Sockets
                         overlapped,
                         IntPtr.Zero);
 
-                    return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
+                    return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken);
                 }
                 catch
                 {
-                    _singleBufferHandleState = SingleBufferHandleState.None;
+                    _asyncProcessingState = AsyncProcessingState.None;
                     FreeNativeOverlapped(overlapped);
                     throw;
                 }
@@ -457,8 +436,8 @@ namespace System.Net.Sockets
                 NativeOverlapped* overlapped = AllocateNativeOverlapped();
                 try
                 {
-                    Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
-                    _singleBufferHandleState = SingleBufferHandleState.InProcess;
+                    Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
+                    _asyncProcessingState = AsyncProcessingState.InProcess;
                     var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) };
 
                     SocketFlags flags = _socketFlags;
@@ -473,11 +452,11 @@ namespace System.Net.Sockets
                         overlapped,
                         IntPtr.Zero);
 
-                    return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
+                    return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken);
                 }
                 catch
                 {
-                    _singleBufferHandleState = SingleBufferHandleState.None;
+                    _asyncProcessingState = AsyncProcessingState.None;
                     FreeNativeOverlapped(overlapped);
                     throw;
                 }
@@ -552,8 +531,8 @@ namespace System.Net.Sockets
 
                 fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span))
                 {
-                    Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
-                    _singleBufferHandleState = SingleBufferHandleState.InProcess;
+                    Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
+                    _asyncProcessingState = AsyncProcessingState.InProcess;
 
                     _wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)bufferPtr + _offset;
                     _wsaRecvMsgWSABufferArrayPinned[0].Length = _count;
@@ -608,12 +587,12 @@ namespace System.Net.Sockets
                         IntPtr.Zero);
 
                     return _bufferList == null ?
-                        ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken) :
+                        ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken) :
                         ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred, overlapped);
                 }
                 catch
                 {
-                    _singleBufferHandleState = SingleBufferHandleState.None;
+                    _asyncProcessingState = AsyncProcessingState.None;
                     FreeNativeOverlapped(overlapped);
                     throw;
                 }
@@ -631,8 +610,8 @@ namespace System.Net.Sockets
                 NativeOverlapped* overlapped = AllocateNativeOverlapped();
                 try
                 {
-                    Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
-                    _singleBufferHandleState = SingleBufferHandleState.InProcess;
+                    Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
+                    _asyncProcessingState = AsyncProcessingState.InProcess;
                     var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) };
 
                     SocketError socketError = Interop.Winsock.WSASend(
@@ -644,11 +623,11 @@ namespace System.Net.Sockets
                         overlapped,
                         IntPtr.Zero);
 
-                    return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
+                    return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken);
                 }
                 catch
                 {
-                    _singleBufferHandleState = SingleBufferHandleState.None;
+                    _asyncProcessingState = AsyncProcessingState.None;
                     FreeNativeOverlapped(overlapped);
                     throw;
                 }
@@ -678,7 +657,7 @@ namespace System.Net.Sockets
             }
         }
 
-        internal unsafe SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle)
+        internal unsafe SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken)
         {
             // Cache copy to avoid problems with concurrent manipulation during the async operation.
             Debug.Assert(_sendPacketsElements != null);
@@ -757,6 +736,9 @@ namespace System.Net.Sockets
             NativeOverlapped* overlapped = AllocateNativeOverlapped();
             try
             {
+                Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
+                _asyncProcessingState = AsyncProcessingState.InProcess;
+
                 bool result = socket.TransmitPackets(
                     handle,
                     Marshal.UnsafeAddrOfPinnedArrayElement(sendPacketsDescriptorPinned, 0),
@@ -765,10 +747,11 @@ namespace System.Net.Sockets
                     overlapped,
                     _sendPacketsFlags);
 
-                return ProcessIOCPResult(result, 0, overlapped);
+                return ProcessIOCPResultWithDeferredAsyncHandling(result, 0, overlapped, Memory<byte>.Empty, cancellationToken);
             }
             catch
             {
+                _asyncProcessingState = AsyncProcessingState.None;
                 FreeNativeOverlapped(overlapped);
                 throw;
             }
@@ -796,8 +779,8 @@ namespace System.Net.Sockets
                 NativeOverlapped* overlapped = AllocateNativeOverlapped();
                 try
                 {
-                    Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None);
-                    _singleBufferHandleState = SingleBufferHandleState.InProcess;
+                    Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
+                    _asyncProcessingState = AsyncProcessingState.InProcess;
                     var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) };
 
                     SocketError socketError = Interop.Winsock.WSASendTo(
@@ -811,11 +794,11 @@ namespace System.Net.Sockets
                         overlapped,
                         IntPtr.Zero);
 
-                    return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken);
+                    return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken);
                 }
                 catch
                 {
-                    _singleBufferHandleState = SingleBufferHandleState.None;
+                    _asyncProcessingState = AsyncProcessingState.None;
                     FreeNativeOverlapped(overlapped);
                     throw;
                 }
@@ -957,9 +940,9 @@ namespace System.Net.Sockets
         {
             _pinState = PinState.None;
 
-            if (_singleBufferHandleState != SingleBufferHandleState.None)
+            if (_asyncProcessingState != AsyncProcessingState.None)
             {
-                _singleBufferHandleState = SingleBufferHandleState.None;
+                _asyncProcessingState = AsyncProcessingState.None;
                 _singleBufferHandle.Dispose();
             }
 
@@ -1105,7 +1088,7 @@ namespace System.Net.Sockets
                 safeHandle.DangerousAddRef(ref refAdded);
                 IntPtr handle = safeHandle.DangerousGetHandle();
 
-                Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
+                Debug.Assert(_asyncProcessingState == AsyncProcessingState.Set);
                 bool userBuffer = _count >= _acceptAddressBufferCount;
 
                 _currentSocket.GetAcceptExSockaddrs(
@@ -1181,7 +1164,7 @@ namespace System.Net.Sockets
         {
             _strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially
 
-            if (_singleBufferHandleState != SingleBufferHandleState.None)
+            if (_asyncProcessingState != AsyncProcessingState.None)
             {
                 // If the state isn't None, then either it's Set, in which case there's state to cleanup,
                 // or it's InProcess, which can happen if the async operation was scheduled and actually
@@ -1201,30 +1184,27 @@ namespace System.Net.Sockets
                 // initiating it.  Wait until that initiation code has completed before
                 // we try to undo the state it configures.
                 SpinWait sw = default;
-                while (_singleBufferHandleState == SingleBufferHandleState.InProcess)
+                while (_asyncProcessingState == AsyncProcessingState.InProcess)
                 {
                     sw.SpinOnce();
                 }
 
-                Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
+                Debug.Assert(_asyncProcessingState == AsyncProcessingState.Set);
 
                 // Remove any cancellation registration.  First dispose the registration
                 // to ensure that cancellation will either never fine or will have completed
                 // firing before we continue.  Only then can we safely null out the overlapped.
                 _registrationToCancelPendingIO.Dispose();
+                _registrationToCancelPendingIO = default;
                 unsafe
                 {
                     _pendingOverlappedForCancellation = null;
                 }
 
                 // Release any GC handles.
-                Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set);
+                _singleBufferHandle.Dispose();
 
-                if (_singleBufferHandleState == SingleBufferHandleState.Set)
-                {
-                    _singleBufferHandleState = SingleBufferHandleState.None;
-                    _singleBufferHandle.Dispose();
-                }
+                _asyncProcessingState = AsyncProcessingState.None;
             }
         }
 
index 0e64fc1..7006ff6 100644 (file)
@@ -7,6 +7,7 @@ using System.Diagnostics;
 using System.IO;
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
+using System.Threading;
 using System.Threading.Tasks;
 using Microsoft.Win32.SafeHandles;
 
@@ -1906,10 +1907,10 @@ namespace System.Net.Sockets
             return GetSocketErrorForErrorCode(err);
         }
 
-        private static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, long offset, long count, Action<long, SocketError> callback)
+        private static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, long offset, long count, CancellationToken cancellationToken, Action<long, SocketError> callback)
         {
             long bytesSent;
-            SocketError socketError = handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, offset, count, out bytesSent, callback);
+            SocketError socketError = handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, offset, count, out bytesSent, callback, cancellationToken);
             if (socketError == SocketError.Success)
             {
                 callback(bytesSent, SocketError.Success);
@@ -1918,7 +1919,7 @@ namespace System.Net.Sockets
         }
 
         public static async void SendPacketsAsync(
-            Socket socket, TransmitFileOptions options, SendPacketsElement[] elements, FileStream[] files, Action<long, SocketError> callback)
+            Socket socket, TransmitFileOptions options, SendPacketsElement[] elements, FileStream[] files, CancellationToken cancellationToken, Action<long, SocketError> callback)
         {
             SocketError error = SocketError.Success;
             long bytesTransferred = 0;
@@ -1932,7 +1933,7 @@ namespace System.Net.Sockets
                     {
                         if (e.MemoryBuffer != null)
                         {
-                            bytesTransferred += await socket.SendAsync(e.MemoryBuffer.Value, SocketFlags.None).ConfigureAwait(false);
+                            bytesTransferred += await socket.SendAsync(e.MemoryBuffer.Value, SocketFlags.None, cancellationToken).ConfigureAwait(false);
                         }
                         else
                         {
@@ -1945,6 +1946,7 @@ namespace System.Net.Sockets
                             var tcs = new TaskCompletionSource<SocketError>();
                             error = SendFileAsync(socket.InternalSafeHandle, fs, e.OffsetLong,
                                 e.Count > 0 ? e.Count : fs.Length - e.OffsetLong,
+                                cancellationToken,
                                 (transferred, se) =>
                                 {
                                     bytesTransferred += transferred;
index 319403d..31f8f6c 100644 (file)
@@ -4,6 +4,7 @@
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Threading;
 using System.Threading.Tasks;
 
 using Xunit;
@@ -420,6 +421,60 @@ namespace System.Net.Sockets.Tests
     public sealed class SendFile_Task : SendFile<SocketHelperTask>
     {
         public SendFile_Task(ITestOutputHelper output) : base(output) { }
+
+        [Fact]
+        public async Task Precanceled_Throws()
+        {
+            using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                listener.BindToAnonymousPort(IPAddress.Loopback);
+                listener.Listen(1);
+
+                await client.ConnectAsync(listener.LocalEndPoint);
+                using (Socket server = await listener.AcceptAsync())
+                {
+                    var cts = new CancellationTokenSource();
+                    cts.Cancel();
+
+                    await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await server.SendFileAsync(null, ReadOnlyMemory<byte>.Empty, ReadOnlyMemory<byte>.Empty, TransmitFileOptions.UseDefaultWorkerThread, cts.Token));
+                }
+            }
+        }
+
+        [Theory]
+        [InlineData(false)]
+        [InlineData(true)]
+        public async Task SendAsync_CanceledDuringOperation_Throws(bool ipv6)
+        {
+            const int CancelAfter = 200; // ms
+            const int NumOfSends = 100;
+            const int SendBufferSize = 1024;
+
+            (Socket client, Socket server) = SocketTestExtensions.CreateConnectedSocketPair(ipv6);
+            byte[] buffer = new byte[1024 * 64];
+            using (client)
+            using (server)
+            {
+                client.SendBufferSize = SendBufferSize;
+                CancellationTokenSource cts = new CancellationTokenSource();
+
+                List<Task> tasks = new List<Task>();
+
+                // After flooding the socket with a high number of SendFile tasks,
+                // we assume some of them won't complete before the "CancelAfter" period expires.
+                for (int i = 0; i < NumOfSends; i++)
+                {
+                    var task = server.SendFileAsync(null, buffer, ReadOnlyMemory<byte>.Empty, TransmitFileOptions.UseDefaultWorkerThread, cts.Token).AsTask();
+                    tasks.Add(task);
+                }
+
+                cts.CancelAfter(CancelAfter);
+
+                // We shall see at least one cancellation amongst all the scheduled sends:
+                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => Task.WhenAll(tasks));
+            }
+        }
     }
 
     public sealed class SendFile_Apm : SendFile<SocketHelperApm>