From b5c91e4c29359f160edcf7caf16530e48d9a4fb0 Mon Sep 17 00:00:00 2001 From: Geoff Kizer Date: Sat, 29 May 2021 12:10:10 -0700 Subject: [PATCH] add AcceptAsync cancellation overloads (#53340) * add AcceptAsync cancellation overloads * pass cancellationToken to AcceptAsync in Unix NamedPipe impl * add TcpListener overloads too * enable pipe cancellation test on Unix Co-authored-by: Geoffrey Kizer --- .../System/IO/Pipes/NamedPipeServerStream.Unix.cs | 2 +- .../tests/PipeStreamConformanceTests.cs | 10 +- .../System.Net.Sockets/ref/System.Net.Sockets.cs | 4 + .../src/System/Net/Sockets/Socket.Tasks.cs | 163 +++++++++------------ .../src/System/Net/Sockets/Socket.cs | 6 +- .../System/Net/Sockets/SocketAsyncContext.Unix.cs | 4 +- .../Net/Sockets/SocketAsyncEventArgs.Unix.cs | 4 +- .../Net/Sockets/SocketAsyncEventArgs.Windows.cs | 79 +++++----- .../src/System/Net/Sockets/TCPListener.cs | 16 +- .../tests/FunctionalTests/Accept.cs | 44 ++++++ .../tests/FunctionalTests/SendFile.cs | 9 +- .../tests/FunctionalTests/SocketTestHelper.cs | 4 +- .../tests/FunctionalTests/TcpListenerTest.cs | 5 +- 13 files changed, 198 insertions(+), 152 deletions(-) diff --git a/src/libraries/System.IO.Pipes/src/System/IO/Pipes/NamedPipeServerStream.Unix.cs b/src/libraries/System.IO.Pipes/src/System/IO/Pipes/NamedPipeServerStream.Unix.cs index 4628fe4..8a309dc 100644 --- a/src/libraries/System.IO.Pipes/src/System/IO/Pipes/NamedPipeServerStream.Unix.cs +++ b/src/libraries/System.IO.Pipes/src/System/IO/Pipes/NamedPipeServerStream.Unix.cs @@ -78,7 +78,7 @@ namespace System.IO.Pipes WaitForConnectionAsyncCore(); async Task WaitForConnectionAsyncCore() => - HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync().ConfigureAwait(false)); + HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync(cancellationToken).ConfigureAwait(false)); } private void HandleAcceptedSocket(Socket acceptedSocket) diff --git a/src/libraries/System.IO.Pipes/tests/PipeStreamConformanceTests.cs b/src/libraries/System.IO.Pipes/tests/PipeStreamConformanceTests.cs index b4962b6..805bc29 100644 --- a/src/libraries/System.IO.Pipes/tests/PipeStreamConformanceTests.cs +++ b/src/libraries/System.IO.Pipes/tests/PipeStreamConformanceTests.cs @@ -201,14 +201,10 @@ namespace System.IO.Pipes.Tests var ctx = new CancellationTokenSource(); - if (OperatingSystem.IsWindows()) // cancellation token after the operation has been initiated - { - Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token); - ctx.Cancel(); - await Assert.ThrowsAnyAsync(() => serverWaitTimeout); - } - + Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token); ctx.Cancel(); + await Assert.ThrowsAnyAsync(() => serverWaitTimeout); + Assert.True(server.WaitForConnectionAsync(ctx.Token).IsCanceled); } diff --git a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs index b128e5c..0ea7f6b 100644 --- a/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs +++ b/src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs @@ -297,7 +297,9 @@ namespace System.Net.Sockets public bool UseOnlyOverlappedIO { get { throw null; } set { } } public System.Net.Sockets.Socket Accept() { throw null; } public System.Threading.Tasks.Task AcceptAsync() { throw null; } + public System.Threading.Tasks.ValueTask AcceptAsync(System.Threading.CancellationToken cancellationToken) { throw null; } public System.Threading.Tasks.Task AcceptAsync(System.Net.Sockets.Socket? acceptSocket) { throw null; } + public System.Threading.Tasks.ValueTask AcceptAsync(System.Net.Sockets.Socket? acceptSocket, System.Threading.CancellationToken cancellationToken) { throw null; } public bool AcceptAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; } public System.IAsyncResult BeginAccept(System.AsyncCallback? callback, object? state) { throw null; } public System.IAsyncResult BeginAccept(int receiveSize, System.AsyncCallback? callback, object? state) { throw null; } @@ -691,8 +693,10 @@ namespace System.Net.Sockets public System.Net.Sockets.Socket Server { get { throw null; } } public System.Net.Sockets.Socket AcceptSocket() { throw null; } public System.Threading.Tasks.Task AcceptSocketAsync() { throw null; } + public System.Threading.Tasks.ValueTask AcceptSocketAsync(System.Threading.CancellationToken cancellationToken) { throw null; } public System.Net.Sockets.TcpClient AcceptTcpClient() { throw null; } public System.Threading.Tasks.Task AcceptTcpClientAsync() { throw null; } + public System.Threading.Tasks.ValueTask AcceptTcpClientAsync(System.Threading.CancellationToken cancellationToken) { throw null; } [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")] public void AllowNatTraversal(bool allowed) { } public System.IAsyncResult BeginAcceptSocket(System.AsyncCallback? callback, object? state) { throw null; } diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs index fb2c911..46fecd7 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs @@ -15,12 +15,9 @@ namespace System.Net.Sockets { public partial class Socket { - /// Cached instance for accept operations. - private TaskSocketAsyncEventArgs? _acceptEventArgs; - /// Cached instance for receive operations that return . Also used for ConnectAsync operations. private AwaitableSocketAsyncEventArgs? _singleBufferReceiveEventArgs; - /// Cached instance for send operations that return . + /// Cached instance for send operations that return . Also used for AcceptAsync operations. private AwaitableSocketAsyncEventArgs? _singleBufferSendEventArgs; /// Cached instance for receive operations that return . @@ -32,54 +29,44 @@ namespace System.Net.Sockets /// Accepts an incoming connection. /// /// An asynchronous task that completes with the accepted Socket. - public Task AcceptAsync() => AcceptAsync((Socket?)null); + public Task AcceptAsync() => AcceptAsync((Socket?)null, CancellationToken.None).AsTask(); /// /// Accepts an incoming connection. /// - /// The socket to use for accepting the connection. + /// A cancellation token that can be used to cancel the asynchronous operation. /// An asynchronous task that completes with the accepted Socket. - public Task AcceptAsync(Socket? acceptSocket) - { - // Get any cached SocketAsyncEventArg we may have. - TaskSocketAsyncEventArgs? saea = Interlocked.Exchange(ref _acceptEventArgs, null); - if (saea is null) - { - saea = new TaskSocketAsyncEventArgs(); - saea.Completed += (s, e) => CompleteAccept((Socket)s!, (TaskSocketAsyncEventArgs)e); - } + public ValueTask AcceptAsync(CancellationToken cancellationToken) => AcceptAsync((Socket?)null, cancellationToken); - // Configure the SAEA. - saea.AcceptSocket = acceptSocket; + /// + /// Accepts an incoming connection. + /// + /// The socket to use for accepting the connection. + /// An asynchronous task that completes with the accepted Socket. + public Task AcceptAsync(Socket? acceptSocket) => AcceptAsync(acceptSocket, CancellationToken.None).AsTask(); - // Initiate the accept operation. - Task t; - if (AcceptAsync(saea)) + /// + /// Accepts an incoming connection. + /// + /// The socket to use for accepting the connection. + /// A cancellation token that can be used to cancel the asynchronous operation. + /// An asynchronous task that completes with the accepted Socket. + public ValueTask AcceptAsync(Socket? acceptSocket, CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) { - // The operation is completing asynchronously (it may have already completed). - // Get the task for the operation, with appropriate synchronization to coordinate - // with the async callback that'll be completing the task. - bool responsibleForReturningToPool; - t = saea.GetCompletionResponsibility(out responsibleForReturningToPool).Task; - if (responsibleForReturningToPool) - { - // We're responsible for returning it only if the callback has already been invoked - // and gotten what it needs from the SAEA; otherwise, the callback will return it. - ReturnSocketAsyncEventArgs(saea); - } + return ValueTask.FromCanceled(cancellationToken); } - else - { - // The operation completed synchronously. Get a task for it. - t = saea.SocketError == SocketError.Success ? - Task.FromResult(saea.AcceptSocket!) : - Task.FromException(GetException(saea.SocketError)); - // There won't be a callback, and we're done with the SAEA, so return it to the pool. - ReturnSocketAsyncEventArgs(saea); - } + AwaitableSocketAsyncEventArgs saea = + Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ?? + new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false); - return t; + Debug.Assert(saea.BufferList == null); + saea.SetBuffer(null, 0, 0); + saea.AcceptSocket = acceptSocket; + saea.WrapExceptionsForNetworkStream = false; + return saea.AcceptAsync(this, cancellationToken); } /// @@ -739,34 +726,6 @@ namespace System.Net.Sockets } /// Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool. - private static void CompleteAccept(Socket s, TaskSocketAsyncEventArgs saea) - { - // Pull the relevant state off of the SAEA - SocketError error = saea.SocketError; - Socket? acceptSocket = saea.AcceptSocket; - - // Synchronize with the initiating thread. If the synchronous caller already got what - // it needs from the SAEA, then we can return it to the pool now. Otherwise, it'll be - // responsible for returning it once it's gotten what it needs from it. - bool responsibleForReturningToPool; - AsyncTaskMethodBuilder builder = saea.GetCompletionResponsibility(out responsibleForReturningToPool); - if (responsibleForReturningToPool) - { - s.ReturnSocketAsyncEventArgs(saea); - } - - // Complete the builder/task with the results. - if (error == SocketError.Success) - { - builder.SetResult(acceptSocket!); - } - else - { - builder.SetException(GetException(error)); - } - } - - /// Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool. private static void CompleteSendReceive(Socket s, TaskSocketAsyncEventArgs saea, bool isReceive) { // Pull the relevant state off of the SAEA @@ -824,29 +783,9 @@ namespace System.Net.Sockets } } - /// Returns a instance for reuse. - /// The instance to return. - private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs saea) - { - // Reset state on the SAEA before returning it. But do not reset buffer state. That'll be done - // if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse - // and the costs associated with changing them. - saea.AcceptSocket = null; - saea._accessed = false; - saea._builder = default; - - // Write this instance back as a cached instance, only if there isn't currently one cached. - if (Interlocked.CompareExchange(ref _acceptEventArgs, saea, null) != null) - { - // Couldn't return it, so dispose it. - saea.Dispose(); - } - } - /// Dispose of any cached instances. private void DisposeCachedTaskSocketAsyncEventArgs() { - Interlocked.Exchange(ref _acceptEventArgs, null)?.Dispose(); Interlocked.Exchange(ref _multiBufferReceiveEventArgs, null)?.Dispose(); Interlocked.Exchange(ref _multiBufferSendEventArgs, null)?.Dispose(); Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null)?.Dispose(); @@ -907,7 +846,7 @@ namespace System.Net.Sockets } /// A SocketAsyncEventArgs that can be awaited to get the result of an operation. - internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource + internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource, IValueTaskSource { private static readonly Action s_completedSentinel = new Action(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel)))); /// The owning socket. @@ -987,6 +926,28 @@ namespace System.Net.Sockets } } + /// Initiates an accept operation on the associated socket. + /// This instance. + public ValueTask AcceptAsync(Socket socket, CancellationToken cancellationToken) + { + Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use"); + + if (socket.AcceptAsync(this, cancellationToken)) + { + _cancellationToken = cancellationToken; + return new ValueTask(this, _token); + } + + Socket acceptSocket = AcceptSocket!; + SocketError error = SocketError; + + Release(); + + return error == SocketError.Success ? + new ValueTask(acceptSocket) : + ValueTask.FromException(CreateException(error)); + } + /// Initiates a receive operation on the associated socket. /// This instance. public ValueTask ReceiveAsync(Socket socket, CancellationToken cancellationToken) @@ -1288,7 +1249,7 @@ namespace System.Net.Sockets /// Unlike TaskAwaiter's GetResult, this does not block until the operation completes: it must only /// be used once the operation has completed. This is handled implicitly by await. /// - public int GetResult(short token) + int IValueTaskSource.GetResult(short token) { if (token != _token) { @@ -1326,6 +1287,26 @@ namespace System.Net.Sockets } } + Socket IValueTaskSource.GetResult(short token) + { + if (token != _token) + { + ThrowIncorrectTokenException(); + } + + SocketError error = SocketError; + Socket acceptSocket = AcceptSocket!; + CancellationToken cancellationToken = _cancellationToken; + + Release(); + + if (error != SocketError.Success) + { + ThrowException(error, cancellationToken); + } + return acceptSocket; + } + SocketReceiveFromResult IValueTaskSource.GetResult(short token) { if (token != _token) diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs index 9138e01..2bf4645 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs @@ -2656,7 +2656,9 @@ namespace System.Net.Sockets // Async methods // - public bool AcceptAsync(SocketAsyncEventArgs e) + public bool AcceptAsync(SocketAsyncEventArgs e) => AcceptAsync(e, CancellationToken.None); + + private bool AcceptAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -2689,7 +2691,7 @@ namespace System.Net.Sockets SocketError socketError; try { - socketError = e.DoOperationAccept(this, _handle, acceptHandle); + socketError = e.DoOperationAccept(this, _handle, acceptHandle, cancellationToken); } catch (Exception ex) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs index 4e1383c..a9bf3de 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs @@ -1433,7 +1433,7 @@ namespace System.Net.Sockets return operation.ErrorCode; } - public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action callback) + public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action callback, CancellationToken cancellationToken) { Debug.Assert(socketAddress != null, "Expected non-null socketAddress"); Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}"); @@ -1456,7 +1456,7 @@ namespace System.Net.Sockets operation.SocketAddress = socketAddress; operation.SocketAddressLen = socketAddressLen; - if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber)) + if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken)) { socketAddressLen = operation.SocketAddressLen; acceptedFd = operation.AcceptedFileDescriptor; diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs index a8be414..28d3016 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs @@ -51,7 +51,7 @@ namespace System.Net.Sockets _acceptAddressBufferCount = socketAddressSize; } - internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle) + internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken) { if (!_buffer.Equals(default)) { @@ -64,7 +64,7 @@ namespace System.Net.Sockets IntPtr acceptedFd; int socketAddressLen = _acceptAddressBufferCount / 2; - SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback); + SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken); if (socketError != SocketError.IOPending) { diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs index 73cb36f..d8d59e9 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs @@ -255,37 +255,38 @@ namespace System.Net.Sockets return socketError; } - internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle) + internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle, CancellationToken cancellationToken) { bool userBuffer = _count != 0; Debug.Assert(!userBuffer || (!_buffer.Equals(default) && _count >= _acceptAddressBufferCount)); Memory buffer = userBuffer ? _buffer : _acceptBuffer; - Debug.Assert(_asyncProcessingState == AsyncProcessingState.None); - NativeOverlapped* overlapped = AllocateNativeOverlapped(); - try + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer.Span)) { - _singleBufferHandle = buffer.Pin(); - _asyncProcessingState = AsyncProcessingState.Set; + NativeOverlapped* overlapped = AllocateNativeOverlapped(); + try + { + Debug.Assert(_asyncProcessingState == AsyncProcessingState.None, $"Expected None, got {_asyncProcessingState}"); + _asyncProcessingState = AsyncProcessingState.InProcess; - bool success = socket.AcceptEx( - handle, - acceptHandle, - userBuffer ? (IntPtr)((byte*)_singleBufferHandle.Pointer + _offset) : (IntPtr)_singleBufferHandle.Pointer, - userBuffer ? _count - _acceptAddressBufferCount : 0, - _acceptAddressBufferCount / 2, - _acceptAddressBufferCount / 2, - out int bytesTransferred, - overlapped); + bool success = socket.AcceptEx( + handle, + acceptHandle, + (IntPtr)(userBuffer ? (bufferPtr + _offset) : bufferPtr), + userBuffer ? _count - _acceptAddressBufferCount : 0, + _acceptAddressBufferCount / 2, + _acceptAddressBufferCount / 2, + out int bytesTransferred, + overlapped); - return ProcessIOCPResult(success, bytesTransferred, overlapped); - } - catch - { - _asyncProcessingState = AsyncProcessingState.None; - FreeNativeOverlapped(overlapped); - _singleBufferHandle.Dispose(); - throw; + return ProcessIOCPResultWithDeferredAsyncHandling(success, bytesTransferred, overlapped, buffer, cancellationToken); + } + catch + { + _asyncProcessingState = AsyncProcessingState.None; + FreeNativeOverlapped(overlapped); + throw; + } } } @@ -1088,20 +1089,26 @@ namespace System.Net.Sockets safeHandle.DangerousAddRef(ref refAdded); IntPtr handle = safeHandle.DangerousGetHandle(); - Debug.Assert(_asyncProcessingState == AsyncProcessingState.Set); - bool userBuffer = _count >= _acceptAddressBufferCount; - - _currentSocket.GetAcceptExSockaddrs( - userBuffer ? (IntPtr)((byte*)_singleBufferHandle.Pointer + _offset) : (IntPtr)_singleBufferHandle.Pointer, - _count != 0 ? _count - _acceptAddressBufferCount : 0, - _acceptAddressBufferCount / 2, - _acceptAddressBufferCount / 2, - out localAddr, - out localAddrLength, - out remoteAddr, - out remoteSocketAddress.InternalSize + // This matches the logic in DoOperationAccept + bool userBuffer = _count != 0; + Debug.Assert(!userBuffer || (!_buffer.Equals(default) && _count >= _acceptAddressBufferCount)); + Memory buffer = userBuffer ? _buffer : _acceptBuffer; + + fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer.Span)) + { + _currentSocket.GetAcceptExSockaddrs( + (IntPtr)(userBuffer ? (bufferPtr + _offset) : bufferPtr), + userBuffer ? _count - _acceptAddressBufferCount : 0, + _acceptAddressBufferCount / 2, + _acceptAddressBufferCount / 2, + out localAddr, + out localAddrLength, + out remoteAddr, + out remoteSocketAddress.InternalSize ); - Marshal.Copy(remoteAddr, remoteSocketAddress.Buffer, 0, remoteSocketAddress.Size); + + Marshal.Copy(remoteAddr, remoteSocketAddress.Buffer, 0, remoteSocketAddress.Size); + } socketError = Interop.Winsock.setsockopt( _acceptSocket!.SafeHandle, diff --git a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPListener.cs b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPListener.cs index e622257..61220e8 100644 --- a/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPListener.cs +++ b/src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPListener.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Threading; using System.Threading.Tasks; using System.Runtime.Versioning; using System.Diagnostics; @@ -220,25 +221,28 @@ namespace System.Net.Sockets public TcpClient EndAcceptTcpClient(IAsyncResult asyncResult) => EndAcceptCore(asyncResult); - public Task AcceptSocketAsync() + public Task AcceptSocketAsync() => AcceptSocketAsync(CancellationToken.None).AsTask(); + + public ValueTask AcceptSocketAsync(CancellationToken cancellationToken) { if (!_active) { throw new InvalidOperationException(SR.net_stopped); } - return _serverSocket!.AcceptAsync(); + return _serverSocket!.AcceptAsync(cancellationToken); } - public Task AcceptTcpClientAsync() + public Task AcceptTcpClientAsync() => AcceptTcpClientAsync(CancellationToken.None).AsTask(); + + public ValueTask AcceptTcpClientAsync(CancellationToken cancellationToken) { - return WaitAndWrap(AcceptSocketAsync()); + return WaitAndWrap(AcceptSocketAsync(cancellationToken)); - static async Task WaitAndWrap(Task task) => + static async ValueTask WaitAndWrap(ValueTask task) => new TcpClient(await task.ConfigureAwait(false)); } - // This creates a TcpListener that listens on both IPv4 and IPv6 on the given port. public static TcpListener Create(int port) { diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Accept.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Accept.cs index 1427708..73c9310 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/Accept.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/Accept.cs @@ -397,6 +397,50 @@ namespace System.Net.Sockets.Tests public AcceptTask(ITestOutputHelper output) : base(output) {} } + public sealed class AcceptCancellableTask : Accept + { + public AcceptCancellableTask(ITestOutputHelper output) : base(output) { } + + [Fact] + public async Task AcceptAsync_Precanceled_Throws() + { + using (Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + int port = listen.BindToAnonymousPort(IPAddress.Loopback); + listen.Listen(1); + + var cts = new CancellationTokenSource(); + cts.Cancel(); + + var acceptTask = listen.AcceptAsync(cts.Token); + Assert.True(acceptTask.IsCompleted); + + var oce = await Assert.ThrowsAnyAsync(async () => await acceptTask); + Assert.Equal(cts.Token, oce.CancellationToken); + } + } + + [Fact] + public async Task AcceptAsync_CanceledDuringOperation_Throws() + { + using (Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + int port = listen.BindToAnonymousPort(IPAddress.Loopback); + listen.Listen(1); + + var cts = new CancellationTokenSource(); + + var acceptTask = listen.AcceptAsync(cts.Token); + Assert.False(acceptTask.IsCompleted); + + cts.Cancel(); + + var oce = await Assert.ThrowsAnyAsync(async () => await acceptTask); + Assert.Equal(cts.Token, oce.CancellationToken); + } + } + } + public sealed class AcceptEap : Accept { public AcceptEap(ITestOutputHelper output) : base(output) {} diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs index 31f8f6c..cda5d75 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs @@ -421,9 +421,14 @@ namespace System.Net.Sockets.Tests public sealed class SendFile_Task : SendFile { public SendFile_Task(ITestOutputHelper output) : base(output) { } + } + + public sealed class SendFile_CancellableTask : SendFile + { + public SendFile_CancellableTask(ITestOutputHelper output) : base(output) { } [Fact] - public async Task Precanceled_Throws() + public async Task SendFileAsync_Precanceled_Throws() { using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) @@ -445,7 +450,7 @@ namespace System.Net.Sockets.Tests [Theory] [InlineData(false)] [InlineData(true)] - public async Task SendAsync_CanceledDuringOperation_Throws(bool ipv6) + public async Task SendFileAsync_CanceledDuringOperation_Throws(bool ipv6) { const int CancelAfter = 200; // ms const int NumOfSends = 100; diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs index 7c40f83..e75013e 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs @@ -263,11 +263,11 @@ namespace System.Net.Sockets.Tests public override bool ValidatesArrayArguments => false; public override Task AcceptAsync(Socket s) => - s.AcceptAsync(); + s.AcceptAsync(_cts.Token).AsTask(); public override Task<(Socket socket, byte[] buffer)> AcceptAsync(Socket s, int receiveSize) => throw new NotSupportedException(); public override Task AcceptAsync(Socket s, Socket acceptSocket) => - s.AcceptAsync(acceptSocket); + s.AcceptAsync(acceptSocket, _cts.Token).AsTask(); public override Task ConnectAsync(Socket s, EndPoint endPoint) => s.ConnectAsync(endPoint, _cts.Token).AsTask(); public override Task MultiConnectAsync(Socket s, IPAddress[] addresses, int port) => diff --git a/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpListenerTest.cs b/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpListenerTest.cs index c7311dd..b23ad8c 100644 --- a/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpListenerTest.cs +++ b/src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpListenerTest.cs @@ -127,7 +127,8 @@ namespace System.Net.Sockets.Tests [Theory] [InlineData(0)] // Sync [InlineData(1)] // Async - [InlineData(2)] // APM + [InlineData(2)] // Async with Cancellation + [InlineData(3)] // APM [ActiveIssue("https://github.com/dotnet/runtime/issues/51392", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)] public async Task Accept_AcceptsPendingSocketOrClient(int mode) { @@ -141,6 +142,7 @@ namespace System.Net.Sockets.Tests { 0 => listener.AcceptSocket(), 1 => await listener.AcceptSocketAsync(), + 2 => await listener.AcceptSocketAsync(CancellationToken.None), _ => await Task.Factory.FromAsync(listener.BeginAcceptSocket, listener.EndAcceptSocket, null), }) { @@ -156,6 +158,7 @@ namespace System.Net.Sockets.Tests { 0 => listener.AcceptTcpClient(), 1 => await listener.AcceptTcpClientAsync(), + 2 => await listener.AcceptTcpClientAsync(CancellationToken.None), _ => await Task.Factory.FromAsync(listener.BeginAcceptTcpClient, listener.EndAcceptTcpClient, null), }) { -- 2.7.4