add AcceptAsync cancellation overloads (#53340)
authorGeoff Kizer <geoffrek@microsoft.com>
Sat, 29 May 2021 19:10:10 +0000 (12:10 -0700)
committerGitHub <noreply@github.com>
Sat, 29 May 2021 19:10:10 +0000 (12:10 -0700)
* add AcceptAsync cancellation overloads

* pass cancellationToken to AcceptAsync in Unix NamedPipe impl

* add TcpListener overloads too

* enable pipe cancellation test on Unix

Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
13 files changed:
src/libraries/System.IO.Pipes/src/System/IO/Pipes/NamedPipeServerStream.Unix.cs
src/libraries/System.IO.Pipes/tests/PipeStreamConformanceTests.cs
src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Unix.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.Windows.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/TCPListener.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/Accept.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SendFile.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketTestHelper.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/TcpListenerTest.cs

index 4628fe4..8a309dc 100644 (file)
@@ -78,7 +78,7 @@ namespace System.IO.Pipes
                 WaitForConnectionAsyncCore();
 
             async Task WaitForConnectionAsyncCore() =>
-               HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync().ConfigureAwait(false));
+               HandleAcceptedSocket(await _instance!.ListeningSocket.AcceptAsync(cancellationToken).ConfigureAwait(false));
         }
 
         private void HandleAcceptedSocket(Socket acceptedSocket)
index b4962b6..805bc29 100644 (file)
@@ -201,14 +201,10 @@ namespace System.IO.Pipes.Tests
 
             var ctx = new CancellationTokenSource();
 
-            if (OperatingSystem.IsWindows()) // cancellation token after the operation has been initiated
-            {
-                Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token);
-                ctx.Cancel();
-                await Assert.ThrowsAnyAsync<OperationCanceledException>(() => serverWaitTimeout);
-            }
-
+            Task serverWaitTimeout = server.WaitForConnectionAsync(ctx.Token);
             ctx.Cancel();
+            await Assert.ThrowsAnyAsync<OperationCanceledException>(() => serverWaitTimeout);
+
             Assert.True(server.WaitForConnectionAsync(ctx.Token).IsCanceled);
         }
 
index b128e5c..0ea7f6b 100644 (file)
@@ -297,7 +297,9 @@ namespace System.Net.Sockets
         public bool UseOnlyOverlappedIO { get { throw null; } set { } }
         public System.Net.Sockets.Socket Accept() { throw null; }
         public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync() { throw null; }
+        public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
         public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket) { throw null; }
+        public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket, System.Threading.CancellationToken cancellationToken) { throw null; }
         public bool AcceptAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
         public System.IAsyncResult BeginAccept(System.AsyncCallback? callback, object? state) { throw null; }
         public System.IAsyncResult BeginAccept(int receiveSize, System.AsyncCallback? callback, object? state) { throw null; }
@@ -691,8 +693,10 @@ namespace System.Net.Sockets
         public System.Net.Sockets.Socket Server { get { throw null; } }
         public System.Net.Sockets.Socket AcceptSocket() { throw null; }
         public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptSocketAsync() { throw null; }
+        public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptSocketAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
         public System.Net.Sockets.TcpClient AcceptTcpClient() { throw null; }
         public System.Threading.Tasks.Task<System.Net.Sockets.TcpClient> AcceptTcpClientAsync() { throw null; }
+        public System.Threading.Tasks.ValueTask<System.Net.Sockets.TcpClient> AcceptTcpClientAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
         [System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
         public void AllowNatTraversal(bool allowed) { }
         public System.IAsyncResult BeginAcceptSocket(System.AsyncCallback? callback, object? state) { throw null; }
index fb2c911..46fecd7 100644 (file)
@@ -15,12 +15,9 @@ namespace System.Net.Sockets
 {
     public partial class Socket
     {
-        /// <summary>Cached instance for accept operations.</summary>
-        private TaskSocketAsyncEventArgs<Socket>? _acceptEventArgs;
-
         /// <summary>Cached instance for receive operations that return <see cref="ValueTask{Int32}"/>. Also used for ConnectAsync operations.</summary>
         private AwaitableSocketAsyncEventArgs? _singleBufferReceiveEventArgs;
-        /// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>.</summary>
+        /// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>. Also used for AcceptAsync operations.</summary>
         private AwaitableSocketAsyncEventArgs? _singleBufferSendEventArgs;
 
         /// <summary>Cached instance for receive operations that return <see cref="Task{Int32}"/>.</summary>
@@ -32,54 +29,44 @@ namespace System.Net.Sockets
         /// Accepts an incoming connection.
         /// </summary>
         /// <returns>An asynchronous task that completes with the accepted Socket.</returns>
-        public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null);
+        public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null, CancellationToken.None).AsTask();
 
         /// <summary>
         /// Accepts an incoming connection.
         /// </summary>
-        /// <param name="acceptSocket">The socket to use for accepting the connection.</param>
+        /// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
         /// <returns>An asynchronous task that completes with the accepted Socket.</returns>
-        public Task<Socket> AcceptAsync(Socket? acceptSocket)
-        {
-            // Get any cached SocketAsyncEventArg we may have.
-            TaskSocketAsyncEventArgs<Socket>? saea = Interlocked.Exchange(ref _acceptEventArgs, null);
-            if (saea is null)
-            {
-                saea = new TaskSocketAsyncEventArgs<Socket>();
-                saea.Completed += (s, e) => CompleteAccept((Socket)s!, (TaskSocketAsyncEventArgs<Socket>)e);
-            }
+        public ValueTask<Socket> AcceptAsync(CancellationToken cancellationToken) => AcceptAsync((Socket?)null, cancellationToken);
 
-            // Configure the SAEA.
-            saea.AcceptSocket = acceptSocket;
+        /// <summary>
+        /// Accepts an incoming connection.
+        /// </summary>
+        /// <param name="acceptSocket">The socket to use for accepting the connection.</param>
+        /// <returns>An asynchronous task that completes with the accepted Socket.</returns>
+        public Task<Socket> AcceptAsync(Socket? acceptSocket) => AcceptAsync(acceptSocket, CancellationToken.None).AsTask();
 
-            // Initiate the accept operation.
-            Task<Socket> t;
-            if (AcceptAsync(saea))
+        /// <summary>
+        /// Accepts an incoming connection.
+        /// </summary>
+        /// <param name="acceptSocket">The socket to use for accepting the connection.</param>
+        /// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
+        /// <returns>An asynchronous task that completes with the accepted Socket.</returns>
+        public ValueTask<Socket> AcceptAsync(Socket? acceptSocket, CancellationToken cancellationToken)
+        {
+            if (cancellationToken.IsCancellationRequested)
             {
-                // The operation is completing asynchronously (it may have already completed).
-                // Get the task for the operation, with appropriate synchronization to coordinate
-                // with the async callback that'll be completing the task.
-                bool responsibleForReturningToPool;
-                t = saea.GetCompletionResponsibility(out responsibleForReturningToPool).Task;
-                if (responsibleForReturningToPool)
-                {
-                    // We're responsible for returning it only if the callback has already been invoked
-                    // and gotten what it needs from the SAEA; otherwise, the callback will return it.
-                    ReturnSocketAsyncEventArgs(saea);
-                }
+                return ValueTask.FromCanceled<Socket>(cancellationToken);
             }
-            else
-            {
-                // The operation completed synchronously.  Get a task for it.
-                t = saea.SocketError == SocketError.Success ?
-                    Task.FromResult(saea.AcceptSocket!) :
-                    Task.FromException<Socket>(GetException(saea.SocketError));
 
-                // There won't be a callback, and we're done with the SAEA, so return it to the pool.
-                ReturnSocketAsyncEventArgs(saea);
-            }
+            AwaitableSocketAsyncEventArgs saea =
+                Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
+                new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);
 
-            return t;
+            Debug.Assert(saea.BufferList == null);
+            saea.SetBuffer(null, 0, 0);
+            saea.AcceptSocket = acceptSocket;
+            saea.WrapExceptionsForNetworkStream = false;
+            return saea.AcceptAsync(this, cancellationToken);
         }
 
         /// <summary>
@@ -739,34 +726,6 @@ namespace System.Net.Sockets
         }
 
         /// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
-        private static void CompleteAccept(Socket s, TaskSocketAsyncEventArgs<Socket> saea)
-        {
-            // Pull the relevant state off of the SAEA
-            SocketError error = saea.SocketError;
-            Socket? acceptSocket = saea.AcceptSocket;
-
-            // Synchronize with the initiating thread. If the synchronous caller already got what
-            // it needs from the SAEA, then we can return it to the pool now. Otherwise, it'll be
-            // responsible for returning it once it's gotten what it needs from it.
-            bool responsibleForReturningToPool;
-            AsyncTaskMethodBuilder<Socket> builder = saea.GetCompletionResponsibility(out responsibleForReturningToPool);
-            if (responsibleForReturningToPool)
-            {
-                s.ReturnSocketAsyncEventArgs(saea);
-            }
-
-            // Complete the builder/task with the results.
-            if (error == SocketError.Success)
-            {
-                builder.SetResult(acceptSocket!);
-            }
-            else
-            {
-                builder.SetException(GetException(error));
-            }
-        }
-
-        /// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
         private static void CompleteSendReceive(Socket s, TaskSocketAsyncEventArgs<int> saea, bool isReceive)
         {
             // Pull the relevant state off of the SAEA
@@ -824,29 +783,9 @@ namespace System.Net.Sockets
             }
         }
 
-        /// <summary>Returns a <see cref="TaskSocketAsyncEventArgs{TResult}"/> instance for reuse.</summary>
-        /// <param name="saea">The instance to return.</param>
-        private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<Socket> saea)
-        {
-            // Reset state on the SAEA before returning it.  But do not reset buffer state.  That'll be done
-            // if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse
-            // and the costs associated with changing them.
-            saea.AcceptSocket = null;
-            saea._accessed = false;
-            saea._builder = default;
-
-            // Write this instance back as a cached instance, only if there isn't currently one cached.
-            if (Interlocked.CompareExchange(ref _acceptEventArgs, saea, null) != null)
-            {
-                // Couldn't return it, so dispose it.
-                saea.Dispose();
-            }
-        }
-
         /// <summary>Dispose of any cached <see cref="TaskSocketAsyncEventArgs{TResult}"/> instances.</summary>
         private void DisposeCachedTaskSocketAsyncEventArgs()
         {
-            Interlocked.Exchange(ref _acceptEventArgs, null)?.Dispose();
             Interlocked.Exchange(ref _multiBufferReceiveEventArgs, null)?.Dispose();
             Interlocked.Exchange(ref _multiBufferSendEventArgs, null)?.Dispose();
             Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null)?.Dispose();
@@ -907,7 +846,7 @@ namespace System.Net.Sockets
         }
 
         /// <summary>A SocketAsyncEventArgs that can be awaited to get the result of an operation.</summary>
-        internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
+        internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<Socket>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
         {
             private static readonly Action<object?> s_completedSentinel = new Action<object?>(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel))));
             /// <summary>The owning socket.</summary>
@@ -987,6 +926,28 @@ namespace System.Net.Sockets
                 }
             }
 
+            /// <summary>Initiates an accept operation on the associated socket.</summary>
+            /// <returns>This instance.</returns>
+            public ValueTask<Socket> AcceptAsync(Socket socket, CancellationToken cancellationToken)
+            {
+                Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");
+
+                if (socket.AcceptAsync(this, cancellationToken))
+                {
+                    _cancellationToken = cancellationToken;
+                    return new ValueTask<Socket>(this, _token);
+                }
+
+                Socket acceptSocket = AcceptSocket!;
+                SocketError error = SocketError;
+
+                Release();
+
+                return error == SocketError.Success ?
+                    new ValueTask<Socket>(acceptSocket) :
+                    ValueTask.FromException<Socket>(CreateException(error));
+            }
+
             /// <summary>Initiates a receive operation on the associated socket.</summary>
             /// <returns>This instance.</returns>
             public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
@@ -1288,7 +1249,7 @@ namespace System.Net.Sockets
             /// Unlike TaskAwaiter's GetResult, this does not block until the operation completes: it must only
             /// be used once the operation has completed.  This is handled implicitly by await.
             /// </remarks>
-            public int GetResult(short token)
+            int IValueTaskSource<int>.GetResult(short token)
             {
                 if (token != _token)
                 {
@@ -1326,6 +1287,26 @@ namespace System.Net.Sockets
                 }
             }
 
+            Socket IValueTaskSource<Socket>.GetResult(short token)
+            {
+                if (token != _token)
+                {
+                    ThrowIncorrectTokenException();
+                }
+
+                SocketError error = SocketError;
+                Socket acceptSocket = AcceptSocket!;
+                CancellationToken cancellationToken = _cancellationToken;
+
+                Release();
+
+                if (error != SocketError.Success)
+                {
+                    ThrowException(error, cancellationToken);
+                }
+                return acceptSocket;
+            }
+
             SocketReceiveFromResult IValueTaskSource<SocketReceiveFromResult>.GetResult(short token)
             {
                 if (token != _token)
index 9138e01..2bf4645 100644 (file)
@@ -2656,7 +2656,9 @@ namespace System.Net.Sockets
         // Async methods
         //
 
-        public bool AcceptAsync(SocketAsyncEventArgs e)
+        public bool AcceptAsync(SocketAsyncEventArgs e) => AcceptAsync(e, CancellationToken.None);
+
+        private bool AcceptAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
         {
             ThrowIfDisposed();
 
@@ -2689,7 +2691,7 @@ namespace System.Net.Sockets
             SocketError socketError;
             try
             {
-                socketError = e.DoOperationAccept(this, _handle, acceptHandle);
+                socketError = e.DoOperationAccept(this, _handle, acceptHandle, cancellationToken);
             }
             catch (Exception ex)
             {
index 4e1383c..a9bf3de 100644 (file)
@@ -1433,7 +1433,7 @@ namespace System.Net.Sockets
             return operation.ErrorCode;
         }
 
-        public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback)
+        public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback, CancellationToken cancellationToken)
         {
             Debug.Assert(socketAddress != null, "Expected non-null socketAddress");
             Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}");
@@ -1456,7 +1456,7 @@ namespace System.Net.Sockets
             operation.SocketAddress = socketAddress;
             operation.SocketAddressLen = socketAddressLen;
 
-            if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
+            if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
             {
                 socketAddressLen = operation.SocketAddressLen;
                 acceptedFd = operation.AcceptedFileDescriptor;
index a8be414..28d3016 100644 (file)
@@ -51,7 +51,7 @@ namespace System.Net.Sockets
             _acceptAddressBufferCount = socketAddressSize;
         }
 
-        internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle)
+        internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken)
         {
             if (!_buffer.Equals(default))
             {
@@ -64,7 +64,7 @@ namespace System.Net.Sockets
 
             IntPtr acceptedFd;
             int socketAddressLen = _acceptAddressBufferCount / 2;
-            SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback);
+            SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken);
 
             if (socketError != SocketError.IOPending)
             {
index 73cb36f..d8d59e9 100644 (file)
@@ -255,37 +255,38 @@ namespace System.Net.Sockets
             return socketError;
         }
 
-        internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle)
+        internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle acceptHandle, CancellationToken cancellationToken)
         {
             bool userBuffer = _count != 0;
             Debug.Assert(!userBuffer || (!_buffer.Equals(default) && _count >= _acceptAddressBufferCount));
             Memory<byte> buffer = userBuffer ? _buffer : _acceptBuffer;
-            Debug.Assert(_asyncProcessingState == AsyncProcessingState.None);
 
-            NativeOverlapped* overlapped = AllocateNativeOverlapped();
-            try
+            fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer.Span))
             {
-                _singleBufferHandle = buffer.Pin();
-                _asyncProcessingState = AsyncProcessingState.Set;
+                NativeOverlapped* overlapped = AllocateNativeOverlapped();
+                try
+                {
+                    Debug.Assert(_asyncProcessingState == AsyncProcessingState.None, $"Expected None, got {_asyncProcessingState}");
+                    _asyncProcessingState = AsyncProcessingState.InProcess;
 
-                bool success = socket.AcceptEx(
-                    handle,
-                    acceptHandle,
-                    userBuffer ? (IntPtr)((byte*)_singleBufferHandle.Pointer + _offset) : (IntPtr)_singleBufferHandle.Pointer,
-                    userBuffer ? _count - _acceptAddressBufferCount : 0,
-                    _acceptAddressBufferCount / 2,
-                    _acceptAddressBufferCount / 2,
-                    out int bytesTransferred,
-                    overlapped);
+                    bool success = socket.AcceptEx(
+                        handle,
+                        acceptHandle,
+                        (IntPtr)(userBuffer ? (bufferPtr + _offset) : bufferPtr),
+                        userBuffer ? _count - _acceptAddressBufferCount : 0,
+                        _acceptAddressBufferCount / 2,
+                        _acceptAddressBufferCount / 2,
+                        out int bytesTransferred,
+                        overlapped);
 
-                return ProcessIOCPResult(success, bytesTransferred, overlapped);
-            }
-            catch
-            {
-                _asyncProcessingState = AsyncProcessingState.None;
-                FreeNativeOverlapped(overlapped);
-                _singleBufferHandle.Dispose();
-                throw;
+                    return ProcessIOCPResultWithDeferredAsyncHandling(success, bytesTransferred, overlapped, buffer, cancellationToken);
+                }
+                catch
+                {
+                    _asyncProcessingState = AsyncProcessingState.None;
+                    FreeNativeOverlapped(overlapped);
+                    throw;
+                }
             }
         }
 
@@ -1088,20 +1089,26 @@ namespace System.Net.Sockets
                 safeHandle.DangerousAddRef(ref refAdded);
                 IntPtr handle = safeHandle.DangerousGetHandle();
 
-                Debug.Assert(_asyncProcessingState == AsyncProcessingState.Set);
-                bool userBuffer = _count >= _acceptAddressBufferCount;
-
-                _currentSocket.GetAcceptExSockaddrs(
-                    userBuffer ? (IntPtr)((byte*)_singleBufferHandle.Pointer + _offset) : (IntPtr)_singleBufferHandle.Pointer,
-                    _count != 0 ? _count - _acceptAddressBufferCount : 0,
-                    _acceptAddressBufferCount / 2,
-                    _acceptAddressBufferCount / 2,
-                    out localAddr,
-                    out localAddrLength,
-                    out remoteAddr,
-                    out remoteSocketAddress.InternalSize
+                // This matches the logic in DoOperationAccept
+                bool userBuffer = _count != 0;
+                Debug.Assert(!userBuffer || (!_buffer.Equals(default) && _count >= _acceptAddressBufferCount));
+                Memory<byte> buffer = userBuffer ? _buffer : _acceptBuffer;
+
+                fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer.Span))
+                {
+                    _currentSocket.GetAcceptExSockaddrs(
+                        (IntPtr)(userBuffer ? (bufferPtr + _offset) : bufferPtr),
+                        userBuffer ? _count - _acceptAddressBufferCount : 0,
+                        _acceptAddressBufferCount / 2,
+                        _acceptAddressBufferCount / 2,
+                        out localAddr,
+                        out localAddrLength,
+                        out remoteAddr,
+                        out remoteSocketAddress.InternalSize
                     );
-                Marshal.Copy(remoteAddr, remoteSocketAddress.Buffer, 0, remoteSocketAddress.Size);
+
+                    Marshal.Copy(remoteAddr, remoteSocketAddress.Buffer, 0, remoteSocketAddress.Size);
+                }
 
                 socketError = Interop.Winsock.setsockopt(
                     _acceptSocket!.SafeHandle,
index e622257..61220e8 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Threading;
 using System.Threading.Tasks;
 using System.Runtime.Versioning;
 using System.Diagnostics;
@@ -220,25 +221,28 @@ namespace System.Net.Sockets
         public TcpClient EndAcceptTcpClient(IAsyncResult asyncResult) =>
             EndAcceptCore<TcpClient>(asyncResult);
 
-        public Task<Socket> AcceptSocketAsync()
+        public Task<Socket> AcceptSocketAsync() => AcceptSocketAsync(CancellationToken.None).AsTask();
+
+        public ValueTask<Socket> AcceptSocketAsync(CancellationToken cancellationToken)
         {
             if (!_active)
             {
                 throw new InvalidOperationException(SR.net_stopped);
             }
 
-            return _serverSocket!.AcceptAsync();
+            return _serverSocket!.AcceptAsync(cancellationToken);
         }
 
-        public Task<TcpClient> AcceptTcpClientAsync()
+        public Task<TcpClient> AcceptTcpClientAsync() => AcceptTcpClientAsync(CancellationToken.None).AsTask();
+
+        public ValueTask<TcpClient> AcceptTcpClientAsync(CancellationToken cancellationToken)
         {
-            return WaitAndWrap(AcceptSocketAsync());
+            return WaitAndWrap(AcceptSocketAsync(cancellationToken));
 
-            static async Task<TcpClient> WaitAndWrap(Task<Socket> task) =>
+            static async ValueTask<TcpClient> WaitAndWrap(ValueTask<Socket> task) =>
                 new TcpClient(await task.ConfigureAwait(false));
         }
 
-
         // This creates a TcpListener that listens on both IPv4 and IPv6 on the given port.
         public static TcpListener Create(int port)
         {
index 1427708..73c9310 100644 (file)
@@ -397,6 +397,50 @@ namespace System.Net.Sockets.Tests
         public AcceptTask(ITestOutputHelper output) : base(output) {}
     }
 
+    public sealed class AcceptCancellableTask : Accept<SocketHelperCancellableTask>
+    {
+        public AcceptCancellableTask(ITestOutputHelper output) : base(output) { }
+
+        [Fact]
+        public async Task AcceptAsync_Precanceled_Throws()
+        {
+            using (Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                int port = listen.BindToAnonymousPort(IPAddress.Loopback);
+                listen.Listen(1);
+
+                var cts = new CancellationTokenSource();
+                cts.Cancel();
+
+                var acceptTask = listen.AcceptAsync(cts.Token);
+                Assert.True(acceptTask.IsCompleted);
+
+                var oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await acceptTask);
+                Assert.Equal(cts.Token, oce.CancellationToken);
+            }
+        }
+
+        [Fact]
+        public async Task AcceptAsync_CanceledDuringOperation_Throws()
+        {
+            using (Socket listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                int port = listen.BindToAnonymousPort(IPAddress.Loopback);
+                listen.Listen(1);
+
+                var cts = new CancellationTokenSource();
+
+                var acceptTask = listen.AcceptAsync(cts.Token);
+                Assert.False(acceptTask.IsCompleted);
+
+                cts.Cancel();
+
+                var oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await acceptTask);
+                Assert.Equal(cts.Token, oce.CancellationToken);
+            }
+        }
+    }
+
     public sealed class AcceptEap : Accept<SocketHelperEap>
     {
         public AcceptEap(ITestOutputHelper output) : base(output) {}
index 31f8f6c..cda5d75 100644 (file)
@@ -421,9 +421,14 @@ namespace System.Net.Sockets.Tests
     public sealed class SendFile_Task : SendFile<SocketHelperTask>
     {
         public SendFile_Task(ITestOutputHelper output) : base(output) { }
+    }
+
+    public sealed class SendFile_CancellableTask : SendFile<SocketHelperCancellableTask>
+    {
+        public SendFile_CancellableTask(ITestOutputHelper output) : base(output) { }
 
         [Fact]
-        public async Task Precanceled_Throws()
+        public async Task SendFileAsync_Precanceled_Throws()
         {
             using (var listener = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
             using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
@@ -445,7 +450,7 @@ namespace System.Net.Sockets.Tests
         [Theory]
         [InlineData(false)]
         [InlineData(true)]
-        public async Task SendAsync_CanceledDuringOperation_Throws(bool ipv6)
+        public async Task SendFileAsync_CanceledDuringOperation_Throws(bool ipv6)
         {
             const int CancelAfter = 200; // ms
             const int NumOfSends = 100;
index 7c40f83..e75013e 100644 (file)
@@ -263,11 +263,11 @@ namespace System.Net.Sockets.Tests
         public override bool ValidatesArrayArguments => false;
 
         public override Task<Socket> AcceptAsync(Socket s) =>
-            s.AcceptAsync();
+            s.AcceptAsync(_cts.Token).AsTask();
         public override Task<(Socket socket, byte[] buffer)> AcceptAsync(Socket s, int receiveSize)
             => throw new NotSupportedException();
         public override Task<Socket> AcceptAsync(Socket s, Socket acceptSocket) =>
-            s.AcceptAsync(acceptSocket);
+            s.AcceptAsync(acceptSocket, _cts.Token).AsTask();
         public override Task ConnectAsync(Socket s, EndPoint endPoint) =>
             s.ConnectAsync(endPoint, _cts.Token).AsTask();
         public override Task MultiConnectAsync(Socket s, IPAddress[] addresses, int port) =>
index c7311dd..b23ad8c 100644 (file)
@@ -127,7 +127,8 @@ namespace System.Net.Sockets.Tests
         [Theory]
         [InlineData(0)] // Sync
         [InlineData(1)] // Async
-        [InlineData(2)] // APM
+        [InlineData(2)] // Async with Cancellation
+        [InlineData(3)] // APM
         [ActiveIssue("https://github.com/dotnet/runtime/issues/51392", TestPlatforms.iOS | TestPlatforms.tvOS | TestPlatforms.MacCatalyst)]
         public async Task Accept_AcceptsPendingSocketOrClient(int mode)
         {
@@ -141,6 +142,7 @@ namespace System.Net.Sockets.Tests
                 {
                     0 => listener.AcceptSocket(),
                     1 => await listener.AcceptSocketAsync(),
+                    2 => await listener.AcceptSocketAsync(CancellationToken.None),
                     _ => await Task.Factory.FromAsync(listener.BeginAcceptSocket, listener.EndAcceptSocket, null),
                 })
                 {
@@ -156,6 +158,7 @@ namespace System.Net.Sockets.Tests
                 {
                     0 => listener.AcceptTcpClient(),
                     1 => await listener.AcceptTcpClientAsync(),
+                    2 => await listener.AcceptTcpClientAsync(CancellationToken.None),
                     _ => await Task.Factory.FromAsync(listener.BeginAcceptTcpClient, listener.EndAcceptTcpClient, null),
                 })
                 {