Reduce Execution Context Save+Restore (#15629)
[platform/upstream/coreclr.git] / src / mscorlib / shared / System / Threading / ExecutionContext.cs
index 67857e9..9f27c6d 100644 (file)
@@ -12,7 +12,6 @@
 ===========================================================*/
 
 using System.Diagnostics;
-using System.Diagnostics.Contracts;
 using System.Runtime.ExceptionServices;
 using System.Runtime.Serialization;
 
@@ -22,41 +21,18 @@ namespace System.Threading
 {
     public delegate void ContextCallback(Object state);
 
-    internal struct ExecutionContextSwitcher
-    {
-        internal ExecutionContext m_ec;
-        internal SynchronizationContext m_sc;
-
-        internal void Undo(Thread currentThread)
-        {
-            Debug.Assert(currentThread == Thread.CurrentThread);
-
-            // The common case is that these have not changed, so avoid the cost of a write if not needed.
-            if (currentThread.SynchronizationContext != m_sc)
-            {
-                currentThread.SynchronizationContext = m_sc;
-            }
-
-            if (currentThread.ExecutionContext != m_ec)
-            {
-                ExecutionContext.Restore(currentThread, m_ec);
-            }
-        }
-    }
-
-    [Serializable]
     public sealed class ExecutionContext : IDisposable, ISerializable
     {
-        internal static readonly ExecutionContext Default = new ExecutionContext();
+        internal static readonly ExecutionContext Default = new ExecutionContext(isDefault: true);
 
         private readonly IAsyncLocalValueMap m_localValues;
         private readonly IAsyncLocal[] m_localChangeNotifications;
         private readonly bool m_isFlowSuppressed;
+        private readonly bool m_isDefault;
 
-        private ExecutionContext()
+        private ExecutionContext(bool isDefault)
         {
-            m_localValues = AsyncLocalValueMap.Empty;
-            m_localChangeNotifications = Array.Empty<IAsyncLocal>();
+            m_isDefault = isDefault;
         }
 
         private ExecutionContext(
@@ -71,15 +47,7 @@ namespace System.Threading
 
         public void GetObjectData(SerializationInfo info, StreamingContext context)
         {
-            if (info == null)
-            {
-                throw new ArgumentNullException(nameof(info));
-            }
-            Contract.EndContractBlock();
-        }
-
-        private ExecutionContext(SerializationInfo info, StreamingContext context)
-        {
+            throw new PlatformNotSupportedException();
         }
 
         public static ExecutionContext Capture()
@@ -96,12 +64,14 @@ namespace System.Threading
             Debug.Assert(isFlowSuppressed != m_isFlowSuppressed);
 
             if (!isFlowSuppressed &&
-                m_localValues == Default.m_localValues &&
-                m_localChangeNotifications == Default.m_localChangeNotifications)
+                (m_localValues == null ||
+                 m_localValues.GetType() == typeof(AsyncLocalValueMap.EmptyAsyncLocalValueMap))
+               )
             {
                 return null; // implies the default context
             }
-            return new ExecutionContext(m_localValues, m_localChangeNotifications, isFlowSuppressed);
+            // Flow suppressing a Default context will have null values, set them to Empty
+            return new ExecutionContext(m_localValues ?? AsyncLocalValueMap.Empty, m_localChangeNotifications ?? Array.Empty<IAsyncLocal>(), isFlowSuppressed);
         }
 
         public static AsyncFlowControl SuppressFlow()
@@ -112,7 +82,6 @@ namespace System.Threading
             {
                 throw new InvalidOperationException(SR.InvalidOperation_CannotSupressFlowMultipleTimes);
             }
-            Contract.EndContractBlock();
 
             executionContext = executionContext.ShallowClone(isFlowSuppressed: true);
             var asyncFlowControl = new AsyncFlowControl();
@@ -129,7 +98,6 @@ namespace System.Threading
             {
                 throw new InvalidOperationException(SR.InvalidOperation_CannotRestoreUnsupressedFlow);
             }
-            Contract.EndContractBlock();
 
             currentThread.ExecutionContext = executionContext.ShallowClone(isFlowSuppressed: false);
         }
@@ -140,134 +108,248 @@ namespace System.Threading
             return executionContext != null && executionContext.m_isFlowSuppressed;
         }
 
+        internal bool HasChangeNotifications => m_localChangeNotifications != null;
+
+        internal bool IsDefault => m_isDefault;
+
         public static void Run(ExecutionContext executionContext, ContextCallback callback, Object state)
         {
+            // Note: ExecutionContext.Run is an extremely hot function and used by every await, ThreadPool execution, etc.
             if (executionContext == null)
-                throw new InvalidOperationException(SR.InvalidOperation_NullContext);
+            {
+                ThrowNullContext();
+            }
 
-            Thread currentThread = Thread.CurrentThread;
-            ExecutionContextSwitcher ecsw = default(ExecutionContextSwitcher);
+            RunInternal(executionContext, callback, state);
+        }
+
+        internal static void RunInternal(ExecutionContext executionContext, ContextCallback callback, Object state)
+        {
+            // Note: ExecutionContext.RunInternal is an extremely hot function and used by every await, ThreadPool execution, etc.
+            // Note: Manual enregistering may be addressed by "Exception Handling Write Through Optimization"
+            //       https://github.com/dotnet/coreclr/blob/master/Documentation/design-docs/eh-writethru.md
+
+            // Enregister variables with 0 post-fix so they can be used in registers without EH forcing them to stack
+            // Capture references to Thread Contexts
+            Thread currentThread0 = Thread.CurrentThread;
+            Thread currentThread = currentThread0;
+            ExecutionContext previousExecutionCtx0 = currentThread0.ExecutionContext;
+
+            // Store current ExecutionContext and SynchronizationContext as "previousXxx".
+            // This allows us to restore them and undo any Context changes made in callback.Invoke
+            // so that they won't "leak" back into caller.
+            // These variables will cross EH so be forced to stack
+            ExecutionContext previousExecutionCtx = previousExecutionCtx0;
+            SynchronizationContext previousSyncCtx = currentThread0.SynchronizationContext;
+
+            if (executionContext != null && executionContext.m_isDefault)
+            {
+                // Default is a null ExecutionContext internally
+                executionContext = null;
+            }
+
+            if (previousExecutionCtx0 != executionContext)
+            {
+                // Restore changed ExecutionContext
+                currentThread0.ExecutionContext = executionContext;
+                if ((executionContext != null && executionContext.HasChangeNotifications) ||
+                    (previousExecutionCtx0 != null && previousExecutionCtx0.HasChangeNotifications))
+                {
+                    // There are change notifications; trigger any affected
+                    OnValuesChanged(previousExecutionCtx0, executionContext);
+                }
+            }
+
+            ExceptionDispatchInfo edi = null;
             try
             {
-                EstablishCopyOnWriteScope(currentThread, ref ecsw);
-                ExecutionContext.Restore(currentThread, executionContext);
-                callback(state);
+                callback.Invoke(state);
             }
-            catch
+            catch (Exception ex)
             {
                 // Note: we have a "catch" rather than a "finally" because we want
                 // to stop the first pass of EH here.  That way we can restore the previous
-                // context before any of our callers' EH filters run.  That means we need to
-                // end the scope separately in the non-exceptional case below.
-                ecsw.Undo(currentThread);
-                throw;
+                // context before any of our callers' EH filters run.
+                edi = ExceptionDispatchInfo.Capture(ex);
             }
-            ecsw.Undo(currentThread);
-        }
-
-        internal static void Restore(Thread currentThread, ExecutionContext executionContext)
-        {
-            Debug.Assert(currentThread == Thread.CurrentThread);
-
-            ExecutionContext previous = currentThread.ExecutionContext ?? Default;
-            currentThread.ExecutionContext = executionContext;
 
-            // New EC could be null if that's what ECS.Undo saved off.
-            // For the purposes of dealing with context change, treat this as the default EC
-            executionContext = executionContext ?? Default;
-
-            if (previous != executionContext)
+            // Re-enregistrer variables post EH with 1 post-fix so they can be used in registers rather than from stack
+            SynchronizationContext previousSyncCtx1 = previousSyncCtx;
+            Thread currentThread1 = currentThread;
+            // The common case is that these have not changed, so avoid the cost of a write barrier if not needed.
+            if (currentThread1.SynchronizationContext != previousSyncCtx1)
             {
-                OnContextChanged(previous, executionContext);
+                // Restore changed SynchronizationContext back to previous
+                currentThread1.SynchronizationContext = previousSyncCtx1;
             }
-        }
 
-        internal static void EstablishCopyOnWriteScope(Thread currentThread, ref ExecutionContextSwitcher ecsw)
-        {
-            Debug.Assert(currentThread == Thread.CurrentThread);
+            ExecutionContext previousExecutionCtx1 = previousExecutionCtx;
+            ExecutionContext currentExecutionCtx1 = currentThread1.ExecutionContext;
+            if (currentExecutionCtx1 != previousExecutionCtx1)
+            {
+                // Restore changed ExecutionContext back to previous
+                currentThread1.ExecutionContext = previousExecutionCtx1;
+                if ((currentExecutionCtx1 != null && currentExecutionCtx1.HasChangeNotifications) ||
+                    (previousExecutionCtx1 != null && previousExecutionCtx1.HasChangeNotifications))
+                {
+                    // There are change notifications; trigger any affected
+                    OnValuesChanged(currentExecutionCtx1, previousExecutionCtx1);
+                }
+            }
 
-            ecsw.m_ec = currentThread.ExecutionContext;
-            ecsw.m_sc = currentThread.SynchronizationContext;
+            // If exception was thrown by callback, rethrow it now original contexts are restored
+            edi?.Throw();
         }
 
-        private static void OnContextChanged(ExecutionContext previous, ExecutionContext current)
+        internal static void OnValuesChanged(ExecutionContext previousExecutionCtx, ExecutionContext nextExecutionCtx)
         {
-            Debug.Assert(previous != null);
-            Debug.Assert(current != null);
-            Debug.Assert(previous != current);
+            Debug.Assert(previousExecutionCtx != nextExecutionCtx);
 
-            foreach (IAsyncLocal local in previous.m_localChangeNotifications)
-            {
-                object previousValue;
-                object currentValue;
-                previous.m_localValues.TryGetValue(local, out previousValue);
-                current.m_localValues.TryGetValue(local, out currentValue);
+            // Collect Change Notifications 
+            IAsyncLocal[] previousChangeNotifications = previousExecutionCtx?.m_localChangeNotifications;
+            IAsyncLocal[] nextChangeNotifications = nextExecutionCtx?.m_localChangeNotifications;
 
-                if (previousValue != currentValue)
-                    local.OnValueChanged(previousValue, currentValue, true);
-            }
+            // At least one side must have notifications
+            Debug.Assert(previousChangeNotifications != null || nextChangeNotifications != null);
 
-            if (current.m_localChangeNotifications != previous.m_localChangeNotifications)
+            // Fire Change Notifications
+            try
             {
-                try
+                if (previousChangeNotifications != null && nextChangeNotifications != null)
                 {
-                    foreach (IAsyncLocal local in current.m_localChangeNotifications)
+                    // Notifications can't exist without values
+                    Debug.Assert(previousExecutionCtx.m_localValues != null);
+                    Debug.Assert(nextExecutionCtx.m_localValues != null);
+                    // Both contexts have change notifications, check previousExecutionCtx first
+                    foreach (IAsyncLocal local in previousChangeNotifications)
                     {
-                        // If the local has a value in the previous context, we already fired the event for that local
-                        // in the code above.
-                        object previousValue;
-                        if (!previous.m_localValues.TryGetValue(local, out previousValue))
+                        previousExecutionCtx.m_localValues.TryGetValue(local, out object previousValue);
+                        nextExecutionCtx.m_localValues.TryGetValue(local, out object currentValue);
+
+                        if (previousValue != currentValue)
                         {
-                            object currentValue;
-                            current.m_localValues.TryGetValue(local, out currentValue);
+                            local.OnValueChanged(previousValue, currentValue, contextChanged: true);
+                        }
+                    }
 
-                            if (previousValue != currentValue)
-                                local.OnValueChanged(previousValue, currentValue, true);
+                    if (nextChangeNotifications != previousChangeNotifications)
+                    {
+                        // Check for additional notifications in nextExecutionCtx
+                        foreach (IAsyncLocal local in nextChangeNotifications)
+                        {
+                            // If the local has a value in the previous context, we already fired the event 
+                            // for that local in the code above.
+                            if (!previousExecutionCtx.m_localValues.TryGetValue(local, out object previousValue))
+                            {
+                                nextExecutionCtx.m_localValues.TryGetValue(local, out object currentValue);
+                                if (previousValue != currentValue)
+                                {
+                                    local.OnValueChanged(previousValue, currentValue, contextChanged: true);
+                                }
+                            }
+                        }
+                    }
+                }
+                else if (previousChangeNotifications != null)
+                {
+                    // Notifications can't exist without values
+                    Debug.Assert(previousExecutionCtx.m_localValues != null);
+                    // No current values, so just check previous against null
+                    foreach (IAsyncLocal local in previousChangeNotifications)
+                    {
+                        previousExecutionCtx.m_localValues.TryGetValue(local, out object previousValue);
+                        if (previousValue != null)
+                        {
+                            local.OnValueChanged(previousValue, null, contextChanged: true);
                         }
                     }
                 }
-                catch (Exception ex)
+                else // Implied: nextChangeNotifications != null
                 {
-                    Environment.FailFast(
-                        SR.ExecutionContext_ExceptionInAsyncLocalNotification,
-                        ex);
+                    // Notifications can't exist without values
+                    Debug.Assert(nextExecutionCtx.m_localValues != null);
+                    // No previous values, so just check current against null
+                    foreach (IAsyncLocal local in nextChangeNotifications)
+                    {
+                        nextExecutionCtx.m_localValues.TryGetValue(local, out object currentValue);
+                        if (currentValue != null)
+                        {
+                            local.OnValueChanged(null, currentValue, contextChanged: true);
+                        }
+                    }
                 }
             }
+            catch (Exception ex)
+            {
+                Environment.FailFast(
+                    SR.ExecutionContext_ExceptionInAsyncLocalNotification,
+                    ex);
+            }
+        }
+
+        [StackTraceHidden]
+        private static void ThrowNullContext()
+        {
+            throw new InvalidOperationException(SR.InvalidOperation_NullContext);
         }
 
         internal static object GetLocalValue(IAsyncLocal local)
         {
             ExecutionContext current = Thread.CurrentThread.ExecutionContext;
             if (current == null)
+            {
                 return null;
+            }
 
-            object value;
-            current.m_localValues.TryGetValue(local, out value);
+            current.m_localValues.TryGetValue(local, out object value);
             return value;
         }
 
         internal static void SetLocalValue(IAsyncLocal local, object newValue, bool needChangeNotifications)
         {
-            ExecutionContext current = Thread.CurrentThread.ExecutionContext ?? ExecutionContext.Default;
+            ExecutionContext current = Thread.CurrentThread.ExecutionContext;
 
-            object previousValue;
-            bool hadPreviousValue = current.m_localValues.TryGetValue(local, out previousValue);
+            object previousValue = null;
+            bool hadPreviousValue = false;
+            if (current != null)
+            {
+                hadPreviousValue = current.m_localValues.TryGetValue(local, out previousValue);
+            }
 
             if (previousValue == newValue)
+            {
                 return;
+            }
 
-            IAsyncLocalValueMap newValues = current.m_localValues.Set(local, newValue);
+            IAsyncLocal[] newChangeNotifications = null;
+            IAsyncLocalValueMap newValues;
+            bool isFlowSuppressed = false;
+            if (current != null)
+            {
+                isFlowSuppressed = current.m_isFlowSuppressed;
+                newValues = current.m_localValues.Set(local, newValue);
+                newChangeNotifications = current.m_localChangeNotifications;
+            }
+            else
+            {
+                // First AsyncLocal
+                newValues = new AsyncLocalValueMap.OneElementAsyncLocalValueMap(local, newValue);
+            }
 
             //
             // Either copy the change notification array, or create a new one, depending on whether we need to add a new item.
             //
-            IAsyncLocal[] newChangeNotifications = current.m_localChangeNotifications;
             if (needChangeNotifications)
             {
                 if (hadPreviousValue)
                 {
+                    Debug.Assert(newChangeNotifications != null);
                     Debug.Assert(Array.IndexOf(newChangeNotifications, local) >= 0);
                 }
+                else if (newChangeNotifications == null)
+                {
+                    newChangeNotifications = new IAsyncLocal[1] { local };
+                }
                 else
                 {
                     int newNotificationIndex = newChangeNotifications.Length;
@@ -276,12 +358,14 @@ namespace System.Threading
                 }
             }
 
-            Thread.CurrentThread.ExecutionContext =
-                new ExecutionContext(newValues, newChangeNotifications, current.m_isFlowSuppressed);
+            Thread.CurrentThread.ExecutionContext = 
+                (!isFlowSuppressed && newValues.GetType() == typeof(AsyncLocalValueMap.EmptyAsyncLocalValueMap)) ?
+                null : // No values, return to Default context
+                new ExecutionContext(newValues, newChangeNotifications, isFlowSuppressed);
 
             if (needChangeNotifications)
             {
-                local.OnValueChanged(previousValue, newValue, false);
+                local.OnValueChanged(previousValue, newValue, contextChanged: false);
             }
         }
 
@@ -331,7 +415,6 @@ namespace System.Threading
             {
                 throw new InvalidOperationException(SR.InvalidOperation_AsyncFlowCtrlCtxMismatch);
             }
-            Contract.EndContractBlock();
 
             _thread = null;
             ExecutionContext.RestoreFlow();