From bb38de7bc6e24090cc673b20d7b910be7e0efd26 Mon Sep 17 00:00:00 2001 From: Geoff Kizer Date: Wed, 26 May 2021 22:58:37 -0700 Subject: [PATCH] implement cancellation support for SendFileAsync and DisconnectAsync (#53062) implement cancellation support for SendFileAsync and DisconnectAsync, and rework some internal async logic to support this and reduce code duplication --- .../src/System/Net/Sockets/Socket.Tasks.cs | 8 +- .../src/System/Net/Sockets/Socket.cs | 6 +- .../System/Net/Sockets/SocketAsyncContext.Unix.cs | 4 +- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 4 +- .../Net/Sockets/SocketAsyncEventArgs.Windows.cs | 270 ++++++++++----------- .../src/System/Net/Sockets/SocketPal.Unix.cs | 10 +- .../tests/FunctionalTests/SendFile.cs | 55 +++++ 7 files changed, 199 insertions(+), 158 deletions(-) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index ae67f70..fb2c911 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -1098,8 +1098,7 @@ namespace System.Net.Sockets { Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); - // TODO: Support cancellation by passing cancellationToken down through SendPacketsAsync, etc. - if (socket.SendPacketsAsync(this)) + if (socket.SendPacketsAsync(this, cancellationToken)) { _cancellationToken = cancellationToken; return new ValueTask(this, _token); @@ -1379,7 +1378,10 @@ namespace System.Net.Sockets private void ThrowException(SocketError error, CancellationToken cancellationToken) { - if (error == SocketError.OperationAborted) + // Most operations will report OperationAborted when canceled. + // On Windows, SendFileAsync will report ConnectionAborted. + // There's a race here anyway, so there's no harm in also checking for ConnectionAborted in all cases. + if (error == SocketError.OperationAborted || error == SocketError.ConnectionAborted) { cancellationToken.ThrowIfCancellationRequested(); } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index a4d24c2..9138e01 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -3060,7 +3060,9 @@ namespace System.Net.Sockets return socketError == SocketError.IOPending; } - public bool SendPacketsAsync(SocketAsyncEventArgs e) + public bool SendPacketsAsync(SocketAsyncEventArgs e) => SendPacketsAsync(e, default(CancellationToken)); + + private bool SendPacketsAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -3082,7 +3084,7 @@ namespace System.Net.Sockets SocketError socketError; try { - socketError = e.DoOperationSendPackets(this, _handle); + socketError = e.DoOperationSendPackets(this, _handle, cancellationToken); } catch (Exception) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 907170d..4e1383c 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -2100,7 +2100,7 @@ namespace System.Net.Sockets return operation.ErrorCode; } - public SocketError SendFileAsync(SafeFileHandle fileHandle, long offset, long count, out long bytesSent, Action callback) + public SocketError SendFileAsync(SafeFileHandle fileHandle, long offset, long count, out long bytesSent, Action callback, CancellationToken cancellationToken = default) { SetHandleNonBlocking(); @@ -2122,7 +2122,7 @@ namespace System.Net.Sockets BytesTransferred = bytesSent }; - if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) + if (!_sendQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { bytesSent = operation.BytesTransferred; return operation.ErrorCode; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index 8812024..a8be414 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -244,7 +244,7 @@ namespace System.Net.Sockets return errorCode; } - internal SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle) + internal SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken) { Debug.Assert(_sendPacketsElements != null); SendPacketsElement[] elements = (SendPacketsElement[])_sendPacketsElements.Clone(); @@ -288,7 +288,7 @@ namespace System.Net.Sockets throw; } - SocketPal.SendPacketsAsync(socket, SendPacketsFlags, elements, files, (bytesTransferred, error) => + SocketPal.SendPacketsAsync(socket, SendPacketsFlags, elements, files, cancellationToken, (bytesTransferred, error) => { if (error == SocketError.Success) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index d79d340..73cb36f 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -13,44 +13,48 @@ namespace System.Net.Sockets { public partial class SocketAsyncEventArgs : EventArgs, IDisposable { - // Single buffer - private MemoryHandle _singleBufferHandle; - - /// The state of the to track whether and when it requires disposal to unpin memory. + /// Tracks whether and when whether to perform cleanup of async processing state, specifically and . /// - /// Pinning via a GCHandle (the mechanism used by Memory) has measurable overhead, and for operations like - /// send and receive that we want to optimize and that frequently complete synchronously, we - /// want to avoid such GCHandle interactions whenever possible. To achieve that, we used `fixed` - /// to pin the relevant state while starting the async operation, and then only if the operation - /// is seen to be pending do we use Pin to create the GCHandle; this is done while the `fixed` is - /// still in scope, to ensure that throughout the whole operation the buffer remains pinned, while - /// using the much cheaper `fixed` only for the fast path. starts - /// life as None, transitions to InProcess prior to initiating the async operation, and then transitions - /// either back to None if the buffer never needed to be pinned, or to Set once it has been pinned. This - /// ensures that asynchronous completion racing with the code that is still setting up the operation - /// can properly clean up after pinned memory, even if it needs to wait momentarily to do so. + /// We sometimes want to or need to defer the initialization of some aspects of async processing until + /// the operation pends, that is, the OS async call returns and indicates that it will complete asynchronously. + /// + /// We do this in two cases: to optimize buffer pinning in certain cases, and to handle cancellation. + /// + /// We optimize buffer pinning for operations that frequently complete synchronously, like Send and Receive, + /// when they use a single buffer. (Optimizing for multiple buffers is more complicated, so we don't currently do it.) + /// In these cases, we used the much-cheaper `fixed` keyword to pin the buffer while the OS async call is in progress. + /// If the operation pends, we will then pin using MemoryHandle.Pin() before exiting the `fixed` scope, + /// thus ensuring that the buffer remains pinned throughout the whole operation. /// - /// Currently, only the operations that use and - /// are cancelable, and as such is also used to guard the cleanup - /// of . + /// For cancellation, we always need to defer cancellation registration until after the OS call has pended. + /// + /// To coordinate with the completion callback from the IOCP, we set to InProcess + /// before performing the OS async call, and then transition it to Set once the operation has pended + /// and the relevant async setup (pinning, cancellation registration) has occurred. + /// This ensures that the completion callback will only clean up the async state (unpin, unregister cancellation) + /// once it has been fully constructed by the original calling thread, even if it needs to wait momentarily to do so. + /// + /// For some operations, like Connect and multi-buffer Send/Receive, we do not do any deferred async processing. + /// Instead, we perform any necessary setup and set to Set before making the call. + /// The cleanup logic will be invoked as always, but in this case will never need to wait on the InProcess state. + /// If an operation does not require any cleanup of this state, it can simply leave as None. /// - private volatile SingleBufferHandleState _singleBufferHandleState; + private volatile AsyncProcessingState _asyncProcessingState; - /// Defines possible states for in order to faciliate correct cleanup of any pinned state. - private enum SingleBufferHandleState : byte + /// Defines possible states for in order to faciliate correct cleanup of asynchronous processing state. + private enum AsyncProcessingState : byte { - /// No operation using is in flight, and no cleanup of is required. + /// No cleanup is required, either because no operation is in flight or the current operation does not require cleanup. None, - /// - /// An operation potentially using is in flight, but the field hasn't yet been initialized. - /// It's possible will transition to , and thus code needs to wait for the - /// value to no longer be before can be disposed. - /// + /// An operation is in flight but async processing state has not yet been initialized. InProcess, - /// The field has been initialized and requires disposal. It is safe to dispose of when the operation no longer needs it. + /// An operation is in flight and async processing state is fully initialized and ready to be cleaned up. Set } + // Single buffer pin handle + private MemoryHandle _singleBufferHandle; + // BufferList property variables. // Note that these arrays are allocated and then grown as necessary, but never shrunk. // Thus the actual in-use length is defined by _bufferListInternal.Count, not the length of these arrays. @@ -123,10 +127,10 @@ namespace System.Net.Sockets private unsafe void RegisterToCancelPendingIO(NativeOverlapped* overlapped, CancellationToken cancellationToken) { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.InProcess, "An operation must be declared in-flight in order to register to cancel it."); + Debug.Assert(_asyncProcessingState == AsyncProcessingState.InProcess, "An operation must be declared in-flight in order to register to cancel it."); Debug.Assert(_pendingOverlappedForCancellation == null); _pendingOverlappedForCancellation = overlapped; - _registrationToCancelPendingIO = cancellationToken.UnsafeRegister(s => + _registrationToCancelPendingIO = cancellationToken.UnsafeRegister(static s => { // Try to cancel the I/O. We ignore the return value (other than for logging), as cancellation // is opportunistic and we don't want to fail the operation because we couldn't cancel it. @@ -161,120 +165,94 @@ namespace System.Net.Sockets _strongThisRef.Value = this; } - /// Handles the result of an IOCP operation. - /// true if the operation completed synchronously and successfully; otherwise, false. - /// The number of bytes transferred, if the operation completed synchronously and successfully. - /// The overlapped to be freed if the operation completed synchronously. - /// The result status of the operation. - private unsafe SocketError ProcessIOCPResult(bool success, int bytesTransferred, NativeOverlapped* overlapped) + /// Gets the result of an IOCP operation and determines how it should be handled (synchronously or asynchronously). + /// true if the IOCP operation indicated synchronous success; otherwise, false. + /// The overlapped that was used for this operation. Will be freed if the operation result will be handled synchronously. + /// The SocketError for the operation. This will be SocketError.IOPending if the operation will be handled asynchronously. + private unsafe SocketError GetIOCPResult(bool success, NativeOverlapped* overlapped) { - // Note: We need to dispose of the overlapped iff the operation completed synchronously, - // and if we do, we must do so before we mark the operation as completed. + // Note: We need to dispose of the overlapped iff the operation result will be handled synchronously. if (success) { // Synchronous success. if (_currentSocket!.SafeHandle.SkipCompletionPortOnSuccess) { - // The socket handle is configured to skip completion on success, - // so we can set the results right now. + // The socket handle is configured to skip completion on success, so we can handle the result synchronously. FreeNativeOverlapped(overlapped); - FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None); - - if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry(); - return SocketError.Success; } // Completed synchronously, but the handle wasn't marked as skip completion port on success, - // so we still need to fall through and behave as if the IO was pending. + // so we still need to behave as if the IO was pending and wait for the completion to come through on the IOCP. + return SocketError.IOPending; } else { // Get the socket error (which may be IOPending) SocketError socketError = SocketPal.GetLastSocketError(); + Debug.Assert(socketError != SocketError.Success); if (socketError != SocketError.IOPending) { // Completed synchronously with a failure. + // No IOCP completion will occur. FreeNativeOverlapped(overlapped); - FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); - - if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry(); - return socketError; } - // Fall through to IOPending handling for asynchronous completion. + // The completion will arrive on the IOCP when the operation is done. + return SocketError.IOPending; } - - // Socket handle is going to post a completion to the completion port (may have done so already). - // Return pending and we will continue in the completion port callback. - return SocketError.IOPending; } /// Handles the result of an IOCP operation. - /// The result status of the operation, as returned from the API call. + /// true if the IOCP operation indicated synchronous success; otherwise, false. /// The number of bytes transferred, if the operation completed synchronously and successfully. - /// The overlapped to be freed if the operation completed synchronously. - /// The cancellation token to use to cancel the operation. + /// The overlapped that was used for this operation. Will be freed if the operation result will be handled synchronously. /// The result status of the operation. - private unsafe SocketError ProcessIOCPResultWithSingleBufferHandle(SocketError socketError, int bytesTransferred, NativeOverlapped* overlapped, CancellationToken cancellationToken = default) + private unsafe SocketError ProcessIOCPResult(bool success, int bytesTransferred, NativeOverlapped* overlapped) { - // Note: We need to dispose of the overlapped iff the operation completed synchronously, - // and if we do, we must do so before we mark the operation as completed. - - if (socketError == SocketError.Success) - { - // Synchronous success. - if (_currentSocket!.SafeHandle.SkipCompletionPortOnSuccess) - { - // The socket handle is configured to skip completion on success, - // so we can set the results right now. - _singleBufferHandleState = SingleBufferHandleState.None; - FreeNativeOverlapped(overlapped); - FinishOperationSyncSuccess(bytesTransferred, SocketFlags.None); + Debug.Assert(_asyncProcessingState != AsyncProcessingState.InProcess); - if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry(); + SocketError socketError = GetIOCPResult(success, overlapped); - return SocketError.Success; - } - - // Completed synchronously, but the handle wasn't marked as skip completion port on success, - // so we still need to fall through and behave as if the IO was pending. - } - else + if (socketError != SocketError.IOPending) { - // Get the socket error (which may be IOPending) - socketError = SocketPal.GetLastSocketError(); - if (socketError != SocketError.IOPending) - { - // Completed synchronously with a failure. - _singleBufferHandleState = SingleBufferHandleState.None; - FreeNativeOverlapped(overlapped); - FinishOperationSyncFailure(socketError, bytesTransferred, SocketFlags.None); + FinishOperationSync(socketError, bytesTransferred, SocketFlags.None); + } - if (SocketsTelemetry.Log.IsEnabled() && !_disableTelemetry) AfterConnectAcceptTelemetry(); + return socketError; + } - return socketError; - } + /// Handles the result of an IOCP operation for which we have deferred async processing logic (buffer pinning or cancellation). + /// true if the IOCP operation indicated synchronous success; otherwise, false. + /// The number of bytes transferred, if the operation completed synchronously and successfully. + /// The overlapped that was used for this operation. Will be freed if the operation result will be handled synchronously. + /// The buffer to pin. May be Memory.Empty if no buffer should be pinned. + /// Note this buffer (if not empty) should already be pinned locally using `fixed` prior to the OS async call and until after this method returns. + /// The cancellation token to use to cancel the operation. + /// The result status of the operation. + private unsafe SocketError ProcessIOCPResultWithDeferredAsyncHandling(bool success, int bytesTransferred, NativeOverlapped* overlapped, Memory bufferToPin, CancellationToken cancellationToken = default) + { + Debug.Assert(_asyncProcessingState == AsyncProcessingState.InProcess); - // Fall through to IOPending handling for asynchronous completion. - } + SocketError socketError = GetIOCPResult(success, overlapped); - // Socket handle is going to post a completion to the completion port (may have done so already). - // Return pending and we will continue in the completion port callback. - if (_singleBufferHandleState == SingleBufferHandleState.InProcess) + if (socketError == SocketError.IOPending) { - // Register for cancellation. This must happen before we change state to Set, as once it's Set, - // the operation completing asynchronously could invoke cleanup, which includes disposing of the - // cancellation registration, and thus the registration needs to be stored prior to setting Set. RegisterToCancelPendingIO(overlapped, cancellationToken); - _singleBufferHandle = _buffer.Pin(); - _singleBufferHandleState = SingleBufferHandleState.Set; + _singleBufferHandle = bufferToPin.Pin(); + + _asyncProcessingState = AsyncProcessingState.Set; + } + else + { + _asyncProcessingState = AsyncProcessingState.None; + FinishOperationSync(socketError, bytesTransferred, SocketFlags.None); } - return SocketError.IOPending; + return socketError; } internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle) @@ -282,13 +260,13 @@ namespace System.Net.Sockets bool userBuffer = _count != 0; Debug.Assert(!userBuffer || (!_buffer.Equals(default) && _count >= _acceptAddressBufferCount)); Memory buffer = userBuffer ? _buffer : _acceptBuffer; - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { _singleBufferHandle = buffer.Pin(); - _singleBufferHandleState = SingleBufferHandleState.Set; + _asyncProcessingState = AsyncProcessingState.Set; bool success = socket.AcceptEx( handle, @@ -304,7 +282,7 @@ namespace System.Net.Sockets } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); _singleBufferHandle.Dispose(); throw; @@ -329,9 +307,9 @@ namespace System.Net.Sockets NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); _singleBufferHandle = _buffer.Pin(); - _singleBufferHandleState = SingleBufferHandleState.Set; + _asyncProcessingState = AsyncProcessingState.Set; bool success = socket.ConnectEx( handle, @@ -346,7 +324,7 @@ namespace System.Net.Sockets } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); _singleBufferHandle.Dispose(); throw; @@ -355,22 +333,23 @@ namespace System.Net.Sockets internal unsafe SocketError DoOperationDisconnect(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken) { - // Note: CancellationToken is ignored for now. - // See https://github.com/dotnet/runtime/issues/51452 - NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None, $"Expected None, got {_asyncProcessingState}"); + _asyncProcessingState = AsyncProcessingState.InProcess; + bool success = socket.DisconnectEx( handle, overlapped, (int)(DisconnectReuseSocket ? TransmitFileOptions.ReuseSocket : 0), 0); - return ProcessIOCPResult(success, 0, overlapped); + return ProcessIOCPResultWithDeferredAsyncHandling(success, 0, overlapped, Memory.Empty, cancellationToken); } catch { + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -387,8 +366,8 @@ namespace System.Net.Sockets NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None, $"Expected None, got {_singleBufferHandleState}"); - _singleBufferHandleState = SingleBufferHandleState.InProcess; + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None, $"Expected None, got {_asyncProcessingState}"); + _asyncProcessingState = AsyncProcessingState.InProcess; var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) }; SocketFlags flags = _socketFlags; @@ -401,11 +380,11 @@ namespace System.Net.Sockets overlapped, IntPtr.Zero); - return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken); + return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken); } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -457,8 +436,8 @@ namespace System.Net.Sockets NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); - _singleBufferHandleState = SingleBufferHandleState.InProcess; + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); + _asyncProcessingState = AsyncProcessingState.InProcess; var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) }; SocketFlags flags = _socketFlags; @@ -473,11 +452,11 @@ namespace System.Net.Sockets overlapped, IntPtr.Zero); - return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken); + return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken); } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -552,8 +531,8 @@ namespace System.Net.Sockets fixed (byte* bufferPtr = &MemoryMarshal.GetReference(_buffer.Span)) { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); - _singleBufferHandleState = SingleBufferHandleState.InProcess; + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); + _asyncProcessingState = AsyncProcessingState.InProcess; _wsaRecvMsgWSABufferArrayPinned[0].Pointer = (IntPtr)bufferPtr + _offset; _wsaRecvMsgWSABufferArrayPinned[0].Length = _count; @@ -608,12 +587,12 @@ namespace System.Net.Sockets IntPtr.Zero); return _bufferList == null ? - ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken) : + ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken) : ProcessIOCPResult(socketError == SocketError.Success, bytesTransferred, overlapped); } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -631,8 +610,8 @@ namespace System.Net.Sockets NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); - _singleBufferHandleState = SingleBufferHandleState.InProcess; + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); + _asyncProcessingState = AsyncProcessingState.InProcess; var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) }; SocketError socketError = Interop.Winsock.WSASend( @@ -644,11 +623,11 @@ namespace System.Net.Sockets overlapped, IntPtr.Zero); - return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken); + return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken); } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -678,7 +657,7 @@ namespace System.Net.Sockets } } - internal unsafe SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle) + internal unsafe SocketError DoOperationSendPackets(Socket socket, SafeSocketHandle handle, CancellationToken cancellationToken) { // Cache copy to avoid problems with concurrent manipulation during the async operation. Debug.Assert(_sendPacketsElements != null); @@ -757,6 +736,9 @@ namespace System.Net.Sockets NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); + _asyncProcessingState = AsyncProcessingState.InProcess; + bool result = socket.TransmitPackets( handle, Marshal.UnsafeAddrOfPinnedArrayElement(sendPacketsDescriptorPinned, 0), @@ -765,10 +747,11 @@ namespace System.Net.Sockets overlapped, _sendPacketsFlags); - return ProcessIOCPResult(result, 0, overlapped); + return ProcessIOCPResultWithDeferredAsyncHandling(result, 0, overlapped, Memory.Empty, cancellationToken); } catch { + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -796,8 +779,8 @@ namespace System.Net.Sockets NativeOverlapped* overlapped = AllocateNativeOverlapped(); try { - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.None); - _singleBufferHandleState = SingleBufferHandleState.InProcess; + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); + _asyncProcessingState = AsyncProcessingState.InProcess; var wsaBuffer = new WSABuffer { Length = _count, Pointer = (IntPtr)(bufferPtr + _offset) }; SocketError socketError = Interop.Winsock.WSASendTo( @@ -811,11 +794,11 @@ namespace System.Net.Sockets overlapped, IntPtr.Zero); - return ProcessIOCPResultWithSingleBufferHandle(socketError, bytesTransferred, overlapped, cancellationToken); + return ProcessIOCPResultWithDeferredAsyncHandling(socketError == SocketError.Success, bytesTransferred, overlapped, _buffer, cancellationToken); } catch { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; FreeNativeOverlapped(overlapped); throw; } @@ -957,9 +940,9 @@ namespace System.Net.Sockets { _pinState = PinState.None; - if (_singleBufferHandleState != SingleBufferHandleState.None) + if (_asyncProcessingState != AsyncProcessingState.None) { - _singleBufferHandleState = SingleBufferHandleState.None; + _asyncProcessingState = AsyncProcessingState.None; _singleBufferHandle.Dispose(); } @@ -1105,7 +1088,7 @@ namespace System.Net.Sockets safeHandle.DangerousAddRef(ref refAdded); IntPtr handle = safeHandle.DangerousGetHandle(); - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set); + Debug.Assert(_asyncProcessingState == AsyncProcessingState.Set); bool userBuffer = _count >= _acceptAddressBufferCount; _currentSocket.GetAcceptExSockaddrs( @@ -1181,7 +1164,7 @@ namespace System.Net.Sockets { _strongThisRef.Value = null; // null out this reference from the overlapped so this isn't kept alive artificially - if (_singleBufferHandleState != SingleBufferHandleState.None) + if (_asyncProcessingState != AsyncProcessingState.None) { // If the state isn't None, then either it's Set, in which case there's state to cleanup, // or it's InProcess, which can happen if the async operation was scheduled and actually @@ -1201,30 +1184,27 @@ namespace System.Net.Sockets // initiating it. Wait until that initiation code has completed before // we try to undo the state it configures. SpinWait sw = default; - while (_singleBufferHandleState == SingleBufferHandleState.InProcess) + while (_asyncProcessingState == AsyncProcessingState.InProcess) { sw.SpinOnce(); } - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set); + Debug.Assert(_asyncProcessingState == AsyncProcessingState.Set); // Remove any cancellation registration. First dispose the registration // to ensure that cancellation will either never fine or will have completed // firing before we continue. Only then can we safely null out the overlapped. _registrationToCancelPendingIO.Dispose(); + _registrationToCancelPendingIO = default; unsafe { _pendingOverlappedForCancellation = null; } // Release any GC handles. - Debug.Assert(_singleBufferHandleState == SingleBufferHandleState.Set); + _singleBufferHandle.Dispose(); - if (_singleBufferHandleState == SingleBufferHandleState.Set) - { - _singleBufferHandleState = SingleBufferHandleState.None; - _singleBufferHandle.Dispose(); - } + _asyncProcessingState = AsyncProcessingState.None; } } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs index 0e64fc1..7006ff6 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.IO; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Threading; using System.Threading.Tasks; using Microsoft.Win32.SafeHandles; @@ -1906,10 +1907,10 @@ namespace System.Net.Sockets return GetSocketErrorForErrorCode(err); } - private static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, long offset, long count, Action callback) + private static SocketError SendFileAsync(SafeSocketHandle handle, FileStream fileStream, long offset, long count, CancellationToken cancellationToken, Action callback) { long bytesSent; - SocketError socketError = handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, offset, count, out bytesSent, callback); + SocketError socketError = handle.AsyncContext.SendFileAsync(fileStream.SafeFileHandle, offset, count, out bytesSent, callback, cancellationToken); if (socketError == SocketError.Success) { callback(bytesSent, SocketError.Success); @@ -1918,7 +1919,7 @@ namespace System.Net.Sockets } public static async void SendPacketsAsync( - Socket socket, TransmitFileOptions options, SendPacketsElement[] elements, FileStream[] files, Action callback) + Socket socket, TransmitFileOptions options, SendPacketsElement[] elements, FileStream[] files, CancellationToken cancellationToken, Action callback) { SocketError error = SocketError.Success; long bytesTransferred = 0; @@ -1932,7 +1933,7 @@ namespace System.Net.Sockets { if (e.MemoryBuffer != null) { - bytesTransferred += await socket.SendAsync(e.MemoryBuffer.Value, SocketFlags.None).ConfigureAwait(false); + bytesTransferred += await socket.SendAsync(e.MemoryBuffer.Value, SocketFlags.None, cancellationToken).ConfigureAwait(false); } else { @@ -1945,6 +1946,7 @@ namespace System.Net.Sockets var tcs = new TaskCompletionSource(); error = SendFileAsync(socket.InternalSafeHandle, fs, e.OffsetLong, e.Count > 0 ? e.Count : fs.Length - e.OffsetLong, + cancellationToken, (transferred, se) => { bytesTransferred += transferred; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs index 319403d..31f8f6c 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Xunit; @@ -420,6 +421,60 @@ namespace System.Net.Sockets.Tests public sealed class SendFile_Task : SendFile { public SendFile_Task(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task Precanceled_Throws() + { + using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + listener.BindToAnonymousPort(IPAddress.Loopback); + listener.Listen(1); + + await client.ConnectAsync(listener.LocalEndPoint); + using (Socket server = await listener.AcceptAsync()) + { + var cts = new CancellationTokenSource(); + cts.Cancel(); + + await Assert.ThrowsAnyAsync(async () => await server.SendFileAsync(null, ReadOnlyMemory.Empty, ReadOnlyMemory.Empty, TransmitFileOptions.UseDefaultWorkerThread, cts.Token)); + } + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task SendAsync_CanceledDuringOperation_Throws(bool ipv6) + { + const int CancelAfter = 200; // ms + const int NumOfSends = 100; + const int SendBufferSize = 1024; + + (Socket client, Socket server) = SocketTestExtensions.CreateConnectedSocketPair(ipv6); + byte[] buffer = new byte[1024 * 64]; + using (client) + using (server) + { + client.SendBufferSize = SendBufferSize; + CancellationTokenSource cts = new CancellationTokenSource(); + + List tasks = new List(); + + // After flooding the socket with a high number of SendFile tasks, + // we assume some of them won't complete before the "CancelAfter" period expires. + for (int i = 0; i < NumOfSends; i++) + { + var task = server.SendFileAsync(null, buffer, ReadOnlyMemory.Empty, TransmitFileOptions.UseDefaultWorkerThread, cts.Token).AsTask(); + tasks.Add(task); + } + + cts.CancelAfter(CancelAfter); + + // We shall see at least one cancellation amongst all the scheduled sends: + await Assert.ThrowsAnyAsync(() => Task.WhenAll(tasks)); + } + } } public sealed class SendFile_Apm : SendFile -- 2.7.4