Expose SocketAsyncEventArgs constructor for suppressing the execution context (#706)
authorMarco Rossignoli <marco.rossignoli@gmail.com>
Mon, 30 Dec 2019 11:51:09 +0000 (12:51 +0100)
committerCory Nelson <phrosty@gmail.com>
Mon, 30 Dec 2019 11:51:09 +0000 (03:51 -0800)
Make execution context suppressing ctor public. Resolves #937.

src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.Tasks.cs
src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncEventArgs.cs
src/libraries/System.Net.Sockets/tests/FunctionalTests/SocketAsyncEventArgsTest.cs

index ec7fc1c..2128018 100644 (file)
@@ -365,6 +365,7 @@ namespace System.Net.Sockets
     public partial class SocketAsyncEventArgs : System.EventArgs, System.IDisposable
     {
         public SocketAsyncEventArgs() { }
+        public SocketAsyncEventArgs(bool unsafeSuppressExecutionContextFlow) { }
         public System.Net.Sockets.Socket AcceptSocket { get { throw null; } set { } }
         public byte[] Buffer { get { throw null; } }
         public System.Collections.Generic.IList<System.ArraySegment<byte>> BufferList { get { throw null; } set { } }
index 304f10a..d049368 100644 (file)
@@ -782,7 +782,7 @@ namespace System.Net.Sockets
             internal bool _accessed = false;
 
             internal TaskSocketAsyncEventArgs() :
-                base(flowExecutionContext: false) // avoid flowing context at lower layers as we only expose Task, which handles it
+                base(unsafeSuppressExecutionContextFlow: true) // avoid flowing context at lower layers as we only expose Task, which handles it
             {
             }
 
@@ -836,7 +836,7 @@ namespace System.Net.Sockets
 
             /// <summary>Initializes the event args.</summary>
             public AwaitableSocketAsyncEventArgs() :
-                base(flowExecutionContext: false) // avoid flowing context at lower layers as we only expose ValueTask, which handles it
+                base(unsafeSuppressExecutionContextFlow: true) // avoid flowing context at lower layers as we only expose ValueTask, which handles it
             {
             }
 
index 2bfc7a1..60bde3e 100644 (file)
@@ -83,18 +83,18 @@ namespace System.Net.Sockets
 
         private MultipleConnectAsync _multipleConnect;
 
-        public SocketAsyncEventArgs() : this(flowExecutionContext: true)
+        public SocketAsyncEventArgs() : this(unsafeSuppressExecutionContextFlow: false)
         {
         }
 
         /// <summary>Initialize the SocketAsyncEventArgs</summary>
-        /// <param name="flowExecutionContext">
-        /// Whether to capture and flow ExecutionContext. ExecutionContext flow should only
+        /// <param name="unsafeSuppressExecutionContextFlow">
+        /// Whether to disable the capturing and flow of ExecutionContext. ExecutionContext flow should only
         /// be disabled if it's going to be handled by higher layers.
         /// </param>
-        internal SocketAsyncEventArgs(bool flowExecutionContext)
+        public SocketAsyncEventArgs(bool unsafeSuppressExecutionContextFlow)
         {
-            _flowExecutionContext = flowExecutionContext;
+            _flowExecutionContext = !unsafeSuppressExecutionContextFlow;
             InitializeInternals();
         }
 
index 6377e26..1999d67 100644 (file)
@@ -149,6 +149,58 @@ namespace System.Net.Sockets.Tests
         }
 
         [Fact]
+        public async Task ExecutionContext_SocketAsyncEventArgs_Ctor_Default_FlowIsNotSuppressed()
+        {
+            await ExecutionContext_SocketAsyncEventArgs_Ctors(() => new SocketAsyncEventArgs(), false);
+        }
+
+        [Theory]
+        [InlineData(true)]
+        [InlineData(false)]
+        public async Task ExecutionContext_SocketAsyncEventArgs_Ctor_UnsafeSuppressExecutionContextFlow(bool suppressed)
+        {
+            await ExecutionContext_SocketAsyncEventArgs_Ctors(() => new SocketAsyncEventArgs(suppressed), suppressed);
+        }
+
+        private async Task ExecutionContext_SocketAsyncEventArgs_Ctors(Func<SocketAsyncEventArgs> saeaFactory, bool suppressed)
+        {
+            using (var listen = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            using (var client = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
+            {
+                listen.Bind(new IPEndPoint(IPAddress.Loopback, 0));
+                listen.Listen(1);
+
+                Task<Socket> acceptTask = listen.AcceptAsync();
+                await Task.WhenAll(
+                    acceptTask,
+                    client.ConnectAsync(new IPEndPoint(IPAddress.Loopback, ((IPEndPoint)listen.LocalEndPoint).Port)));
+
+                using (Socket server = await acceptTask)
+                using (SocketAsyncEventArgs receiveSaea = saeaFactory())
+                {
+                    var local = new AsyncLocal<int>
+                    {
+                        Value = 42
+                    };
+                    int threadId = Environment.CurrentManagedThreadId;
+
+                    var mres = new ManualResetEventSlim();
+                    receiveSaea.SetBuffer(new byte[1], 0, 1);
+                    receiveSaea.Completed += delegate
+                    {
+                        Assert.NotEqual(threadId, Environment.CurrentManagedThreadId);
+                        Assert.Equal(suppressed ? 0 : 42, local.Value);
+                        mres.Set();
+                    };
+
+                    Assert.True(client.ReceiveAsync(receiveSaea));
+                    server.Send(new byte[1]);
+                    mres.Wait();
+                }
+            }
+        }
+
+        [Fact]
         public void SetBuffer_InvalidArgs_Throws()
         {
             using (var saea = new SocketAsyncEventArgs())
@@ -394,7 +446,7 @@ namespace System.Net.Sockets.Tests
 
                     var args = new SocketAsyncEventArgs();
                     args.SetBuffer(new byte[1024], 0, 1024);
-                    args.Completed += (_,__) => tcs.SetResult(true);
+                    args.Completed += (_, __) => tcs.SetResult(true);
 
                     for (int i = 1; i <= 10; i++)
                     {
@@ -443,14 +495,14 @@ namespace System.Net.Sockets.Tests
                     sendBufferList.Add(new ArraySegment<byte>(sendBuffer, 0, 1));
                     var sendArgs = new SocketAsyncEventArgs();
                     sendArgs.BufferList = sendBufferList;
-                    sendArgs.Completed += (_,__) => tcs.SetResult(true);
+                    sendArgs.Completed += (_, __) => tcs.SetResult(true);
 
                     var recvBuffer = new byte[64];
                     var recvBufferList = new List<ArraySegment<byte>>();
                     recvBufferList.Add(new ArraySegment<byte>(recvBuffer, 0, 1));
                     var recvArgs = new SocketAsyncEventArgs();
                     recvArgs.BufferList = recvBufferList;
-                    recvArgs.Completed += (_,__) => tcs.SetResult(true);
+                    recvArgs.Completed += (_, __) => tcs.SetResult(true);
 
                     for (int i = 1; i <= 10; i++)
                     {