ExecutionContext.Restore (#40322)
authorBen Adams <thundercat@illyriad.co.uk>
Wed, 12 Aug 2020 15:05:18 +0000 (16:05 +0100)
committerGitHub <noreply@github.com>
Wed, 12 Aug 2020 15:05:18 +0000 (08:05 -0700)
Add ExecutionContext.Restore

src/libraries/System.Private.CoreLib/src/System/Threading/ExecutionContext.cs
src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs
src/libraries/System.Threading/ref/System.Threading.cs
src/libraries/System.Threading/tests/ExecutionContextTests.cs

index 40b6ca1..9f5b9cd 100644 (file)
@@ -67,6 +67,21 @@ namespace System.Threading
             return executionContext;
         }
 
+        // Allows capturing asynclocals for a FlowSuppressed ExecutionContext rather than returning null.
+        internal static ExecutionContext? CaptureForRestore()
+        {
+            // This is a short cut for:
+            //
+            // ExecutionContext.RestoreFlow()
+            // var ec = ExecutionContext.Capture()
+            // ExecutionContext.SuppressFlow();
+            // ...
+            // ExecutionContext.Restore(ec)
+            // ExecutionContext.SuppressFlow();
+
+            return Thread.CurrentThread._executionContext;
+        }
+
         private ExecutionContext? ShallowClone(bool isFlowSuppressed)
         {
             Debug.Assert(isFlowSuppressed != m_isFlowSuppressed);
@@ -199,7 +214,25 @@ namespace System.Threading
             edi?.Throw();
         }
 
-        internal static void Restore(ExecutionContext? executionContext)
+        /// <summary>
+        /// Restores a captured execution context to on the current thread.
+        /// </summary>
+        /// <remarks>
+        /// To revert to the current execution context; capture it before Restore, and Restore it again.
+        /// It will not automatically be reverted unlike <seealso cref="ExecutionContext.Run"/>.
+        /// </remarks>
+        /// <param name="executionContext">The ExecutionContext to set.</param>
+        public static void Restore(ExecutionContext executionContext)
+        {
+            if (executionContext == null)
+            {
+                ThrowNullContext();
+            }
+
+            RestoreInternal(executionContext);
+        }
+
+        internal static void RestoreInternal(ExecutionContext? executionContext)
         {
             Thread currentThread = Thread.CurrentThread;
 
index b9caa53..647e4d7 100644 (file)
@@ -244,7 +244,7 @@ namespace System.Threading.Tasks.Sources
             Debug.Assert(_continuation != null);
             Debug.Assert(_executionContext != null);
 
-            ExecutionContext? currentContext = ExecutionContext.Capture();
+            ExecutionContext? currentContext = ExecutionContext.CaptureForRestore();
             // Restore the captured ExecutionContext before executing anything.
             ExecutionContext.Restore(_executionContext);
 
@@ -259,7 +259,7 @@ namespace System.Threading.Tasks.Sources
                     finally
                     {
                         // Restore the current ExecutionContext.
-                        ExecutionContext.Restore(currentContext);
+                        ExecutionContext.RestoreInternal(currentContext);
                     }
                 }
                 else
@@ -284,7 +284,7 @@ namespace System.Threading.Tasks.Sources
                         // Set sync context back to what it was prior to coming in
                         SynchronizationContext.SetSynchronizationContext(syncContext);
                         // Restore the current ExecutionContext.
-                        ExecutionContext.Restore(currentContext);
+                        ExecutionContext.RestoreInternal(currentContext);
                     }
 
                     // Now rethrow the exception; if there is one.
@@ -301,7 +301,7 @@ namespace System.Threading.Tasks.Sources
             finally
             {
                 // Restore the current ExecutionContext.
-                ExecutionContext.Restore(currentContext);
+                ExecutionContext.RestoreInternal(currentContext);
             }
         }
 
index 888abda..d2f60ac 100644 (file)
@@ -128,6 +128,7 @@ namespace System.Threading
         public void Dispose() { }
         public void GetObjectData(System.Runtime.Serialization.SerializationInfo info, System.Runtime.Serialization.StreamingContext context) { }
         public static bool IsFlowSuppressed() { throw null; }
+        public static void Restore(System.Threading.ExecutionContext executionContext) { }
         public static void RestoreFlow() { }
         public static void Run(System.Threading.ExecutionContext executionContext, System.Threading.ContextCallback callback, object? state) { }
         public static System.Threading.AsyncFlowControl SuppressFlow() { throw null; }
index 6db03a6..96261fb 100644 (file)
@@ -31,6 +31,29 @@ namespace System.Threading.Tests
         }
 
         [Fact]
+        public static void RestoreTest()
+        {
+            ExecutionContext defaultEC = ExecutionContext.Capture();
+            var asyncLocal = new AsyncLocal<int>();
+            Assert.Equal(0, asyncLocal.Value);
+
+            asyncLocal.Value = 1;
+            ExecutionContext oneEC = ExecutionContext.Capture();
+            Assert.Equal(1, asyncLocal.Value);
+
+            ExecutionContext.Restore(defaultEC);
+            Assert.Equal(0, asyncLocal.Value);
+
+            ExecutionContext.Restore(oneEC);
+            Assert.Equal(1, asyncLocal.Value);
+
+            ExecutionContext.Restore(defaultEC);
+            Assert.Equal(0, asyncLocal.Value);
+
+            Assert.Throws<InvalidOperationException>(() => ExecutionContext.Restore(null!));
+        }
+
+        [Fact]
         public static void DisposeTest()
         {
             ExecutionContext executionContext = ExecutionContext.Capture();