Merge pull request #14178 from stephentoub/async_avoid_delegate
[platform/upstream/coreclr.git] / src / mscorlib / shared / System / Threading / ExecutionContext.cs
1 // Licensed to the .NET Foundation under one or more agreements.
2 // The .NET Foundation licenses this file to you under the MIT license.
3 // See the LICENSE file in the project root for more information.
4
5 /*============================================================
6 **
7 **
8 **
9 ** Purpose: Capture execution context for a thread
10 **
11 **
12 ===========================================================*/
13
14 using System.Diagnostics;
15 using System.Runtime.ExceptionServices;
16 using System.Runtime.Serialization;
17
18 using Thread = Internal.Runtime.Augments.RuntimeThread;
19
20 namespace System.Threading
21 {
/// <summary>
/// Callback invoked by <c>ExecutionContext.Run</c>, which passes through the caller-supplied state object.
/// </summary>
/// <param name="state">The state object that was handed to <c>ExecutionContext.Run</c>.</param>
public delegate void ContextCallback(object state);
23
/// <summary>
/// Snapshot of a thread's execution and synchronization contexts. The snapshot is filled in by
/// <c>ExecutionContext.EstablishCopyOnWriteScope</c> and put back on the thread by <see cref="Undo"/>.
/// </summary>
internal struct ExecutionContextSwitcher
{
    // Execution context saved from the thread (may be null, which stands for the default context).
    internal ExecutionContext m_ec;

    // Synchronization context saved from the thread.
    internal SynchronizationContext m_sc;

    /// <summary>
    /// Puts the saved synchronization and execution contexts back on <paramref name="currentThread"/>.
    /// Each one is written only when it differs from what the thread currently has.
    /// </summary>
    /// <param name="currentThread">Must be the calling thread.</param>
    internal void Undo(Thread currentThread)
    {
        Debug.Assert(currentThread == Thread.CurrentThread);

        // Usually nothing changed inside the scope, so skip the writes when the values already match.
        SynchronizationContext savedSyncContext = m_sc;
        if (currentThread.SynchronizationContext != savedSyncContext)
        {
            currentThread.SynchronizationContext = savedSyncContext;
        }

        ExecutionContext savedExecutionContext = m_ec;
        if (currentThread.ExecutionContext == savedExecutionContext)
        {
            return;
        }

        // Restore (rather than a plain assignment) so async-local change notifications are raised.
        ExecutionContext.Restore(currentThread, savedExecutionContext);
    }
}
45
/// <summary>
/// Ambient state that flows with a logical flow of execution. It holds the values of async locals,
/// the async locals that asked to be notified when their value changes, and whether flow is
/// currently suppressed.
/// </summary>
/// <remarks>
/// Instances are immutable. Every change (setting an async local, suppressing or restoring flow)
/// installs a new instance on the thread. A thread whose ExecutionContext is null behaves as if it
/// were running in <c>Default</c>.
/// </remarks>
public sealed class ExecutionContext : IDisposable, ISerializable
{
    /// <summary>The empty context: no async-local values, no notifications, flow not suppressed.</summary>
    internal static readonly ExecutionContext Default = new ExecutionContext();

    // Async-local values visible in this context.
    private readonly IAsyncLocalValueMap m_localValues;

    // Async locals registered for change notifications; consulted when the thread switches contexts.
    // Arrays are shared between contexts and never mutated in place.
    private readonly IAsyncLocal[] m_localChangeNotifications;

    // Set on contexts installed by SuppressFlow; Capture() returns null while it is true.
    private readonly bool m_isFlowSuppressed;

    private ExecutionContext()
    {
        m_localValues = AsyncLocalValueMap.Empty;
        m_localChangeNotifications = Array.Empty<IAsyncLocal>();
    }

    private ExecutionContext(
        IAsyncLocalValueMap localValues,
        IAsyncLocal[] localChangeNotifications,
        bool isFlowSuppressed)
    {
        m_localValues = localValues;
        m_localChangeNotifications = localChangeNotifications;
        m_isFlowSuppressed = isFlowSuppressed;
    }

    /// <summary>Not supported: execution contexts cannot be serialized.</summary>
    /// <exception cref="PlatformNotSupportedException">Always thrown.</exception>
    public void GetObjectData(SerializationInfo info, StreamingContext context)
    {
        throw new PlatformNotSupportedException();
    }

    /// <summary>Captures the execution context of the current thread.</summary>
    /// <returns>
    /// One of three results:
    /// the thread's current context;
    /// the default (empty) context if the thread has none;
    /// or null if flow is suppressed on the current thread.
    /// </returns>
    public static ExecutionContext Capture()
    {
        ExecutionContext executionContext = Thread.CurrentThread.ExecutionContext;
        return
            executionContext == null ? Default :
            executionContext.m_isFlowSuppressed ? null :
            executionContext;
    }

    /// <summary>
    /// Returns a context that has the same values and notifications as this one but the given
    /// flow-suppression state.
    /// </summary>
    /// <param name="isFlowSuppressed">The new suppression state; must differ from this context's.</param>
    /// <returns>
    /// Null, meaning the default context, when flow is being restored and the values and notifications
    /// are the default ones. Otherwise a new instance.
    /// </returns>
    private ExecutionContext ShallowClone(bool isFlowSuppressed)
    {
        Debug.Assert(isFlowSuppressed != m_isFlowSuppressed);

        if (!isFlowSuppressed &&
            m_localValues == Default.m_localValues &&
            m_localChangeNotifications == Default.m_localChangeNotifications)
        {
            return null; // implies the default context
        }
        return new ExecutionContext(m_localValues, m_localChangeNotifications, isFlowSuppressed);
    }

    /// <summary>Suppresses the flow of the execution context on the current thread.</summary>
    /// <returns>A handle whose Undo/Dispose restores flow on this thread.</returns>
    /// <exception cref="InvalidOperationException">Flow is already suppressed on the current thread.</exception>
    public static AsyncFlowControl SuppressFlow()
    {
        Thread currentThread = Thread.CurrentThread;
        ExecutionContext executionContext = currentThread.ExecutionContext ?? Default;
        if (executionContext.m_isFlowSuppressed)
        {
            throw new InvalidOperationException(SR.InvalidOperation_CannotSupressFlowMultipleTimes);
        }

        executionContext = executionContext.ShallowClone(isFlowSuppressed: true);
        var asyncFlowControl = new AsyncFlowControl();
        currentThread.ExecutionContext = executionContext;
        asyncFlowControl.Initialize(currentThread);
        return asyncFlowControl;
    }

    /// <summary>Restores the flow of the execution context on the current thread.</summary>
    /// <exception cref="InvalidOperationException">Flow is not currently suppressed on the current thread.</exception>
    public static void RestoreFlow()
    {
        Thread currentThread = Thread.CurrentThread;
        ExecutionContext executionContext = currentThread.ExecutionContext;
        if (executionContext == null || !executionContext.m_isFlowSuppressed)
        {
            throw new InvalidOperationException(SR.InvalidOperation_CannotRestoreUnsupressedFlow);
        }

        currentThread.ExecutionContext = executionContext.ShallowClone(isFlowSuppressed: false);
    }

    /// <summary>Returns whether the flow of the execution context is suppressed on the current thread.</summary>
    public static bool IsFlowSuppressed()
    {
        ExecutionContext executionContext = Thread.CurrentThread.ExecutionContext;
        return executionContext != null && executionContext.m_isFlowSuppressed;
    }

    /// <summary>
    /// Runs <paramref name="callback"/> on the current thread with <paramref name="executionContext"/> as the
    /// current context. Afterwards the thread's previous execution and synchronization contexts are restored,
    /// whether the callback returns normally or throws.
    /// </summary>
    /// <param name="executionContext">The context to run under; must not be null.</param>
    /// <param name="callback">The callback to invoke.</param>
    /// <param name="state">The state object passed to <paramref name="callback"/>.</param>
    /// <exception cref="InvalidOperationException"><paramref name="executionContext"/> is null.</exception>
    public static void Run(ExecutionContext executionContext, ContextCallback callback, object state)
    {
        if (executionContext == null)
        {
            throw new InvalidOperationException(SR.InvalidOperation_NullContext);
        }

        Thread currentThread = Thread.CurrentThread;
        ExecutionContextSwitcher ecsw = default(ExecutionContextSwitcher);
        try
        {
            EstablishCopyOnWriteScope(currentThread, ref ecsw);
            Restore(currentThread, executionContext);
            callback(state);
        }
        catch
        {
            // Note: we have a "catch" rather than a "finally" because we want
            // to stop the first pass of EH here.  That way we can restore the previous
            // context before any of our callers' EH filters run.  That means we need to
            // end the scope separately in the non-exceptional case below.
            ecsw.Undo(currentThread);
            throw;
        }
        ecsw.Undo(currentThread);
    }

    /// <summary>
    /// Installs <paramref name="executionContext"/> on <paramref name="currentThread"/> (null means the default
    /// context). Raises async-local change notifications if the effective context changed.
    /// </summary>
    /// <param name="currentThread">Must be the calling thread.</param>
    /// <param name="executionContext">The context to install; may be null.</param>
    internal static void Restore(Thread currentThread, ExecutionContext executionContext)
    {
        Debug.Assert(currentThread == Thread.CurrentThread);

        ExecutionContext previous = currentThread.ExecutionContext ?? Default;
        currentThread.ExecutionContext = executionContext;

        // The context being restored may be null, e.g. when ExecutionContextSwitcher.Undo restores a thread
        // that originally had no context. For change detection, null is equivalent to the default context.
        executionContext = executionContext ?? Default;

        if (previous != executionContext)
        {
            OnContextChanged(previous, executionContext);
        }
    }

    /// <summary>
    /// Saves the current thread's execution and synchronization contexts into <paramref name="ecsw"/>
    /// so that <c>ecsw.Undo</c> can put them back later.
    /// </summary>
    /// <param name="currentThread">Must be the calling thread.</param>
    /// <param name="ecsw">Receives the saved contexts.</param>
    internal static void EstablishCopyOnWriteScope(Thread currentThread, ref ExecutionContextSwitcher ecsw)
    {
        Debug.Assert(currentThread == Thread.CurrentThread);

        ecsw.m_ec = currentThread.ExecutionContext;
        ecsw.m_sc = currentThread.SynchronizationContext;
    }

    /// <summary>
    /// Notifies registered async locals whose value differs between <paramref name="previous"/> and
    /// <paramref name="current"/>. Any exception thrown by a notification terminates the process.
    /// </summary>
    private static void OnContextChanged(ExecutionContext previous, ExecutionContext current)
    {
        Debug.Assert(previous != null);
        Debug.Assert(current != null);
        Debug.Assert(previous != current);

        // By the time this runs, Restore has already installed the new context on the thread. A callback
        // that threw here would leave the remaining locals un-notified, with no caller able to fix that up.
        // Any exception from either pass is therefore treated as fatal.
        try
        {
            // Pass 1: every local that was registered for notifications in the previous context.
            foreach (IAsyncLocal local in previous.m_localChangeNotifications)
            {
                object previousValue;
                object currentValue;
                previous.m_localValues.TryGetValue(local, out previousValue);
                current.m_localValues.TryGetValue(local, out currentValue);

                if (previousValue != currentValue)
                {
                    local.OnValueChanged(previousValue, currentValue, true);
                }
            }

            // Pass 2: locals in the new context's notification list. This pass is unnecessary when both
            // contexts share the same array.
            if (current.m_localChangeNotifications != previous.m_localChangeNotifications)
            {
                foreach (IAsyncLocal local in current.m_localChangeNotifications)
                {
                    // If the local has a value in the previous context, we already fired the event for that local
                    // in the code above.
                    object previousValue;
                    if (!previous.m_localValues.TryGetValue(local, out previousValue))
                    {
                        object currentValue;
                        current.m_localValues.TryGetValue(local, out currentValue);

                        if (previousValue != currentValue)
                        {
                            local.OnValueChanged(previousValue, currentValue, true);
                        }
                    }
                }
            }
        }
        catch (Exception ex)
        {
            Environment.FailFast(
                SR.ExecutionContext_ExceptionInAsyncLocalNotification,
                ex);
        }
    }

    /// <summary>
    /// Returns the current thread's value for <paramref name="local"/>. Returns null when the thread has
    /// no context or the local has no value.
    /// </summary>
    internal static object GetLocalValue(IAsyncLocal local)
    {
        ExecutionContext current = Thread.CurrentThread.ExecutionContext;
        if (current == null)
        {
            return null;
        }

        object value;
        current.m_localValues.TryGetValue(local, out value);
        return value;
    }

    /// <summary>
    /// Sets <paramref name="local"/> to <paramref name="newValue"/> by installing a new context on the current
    /// thread.
    /// </summary>
    /// <param name="local">The async local being set.</param>
    /// <param name="newValue">The new value; if it is reference-equal to the current value, nothing happens.</param>
    /// <param name="needChangeNotifications">
    /// When true, the local is registered for context-switch notifications (if not already registered) and is
    /// notified of this change after the new context is installed.
    /// </param>
    internal static void SetLocalValue(IAsyncLocal local, object newValue, bool needChangeNotifications)
    {
        // Look the thread up once; it is needed both to read and to replace the context.
        Thread currentThread = Thread.CurrentThread;
        ExecutionContext current = currentThread.ExecutionContext ?? Default;

        object previousValue;
        bool hadPreviousValue = current.m_localValues.TryGetValue(local, out previousValue);

        // Reference equality: re-setting the same object creates no new context and raises no notification.
        if (previousValue == newValue)
        {
            return;
        }

        IAsyncLocalValueMap newValues = current.m_localValues.Set(local, newValue);

        // Reuse the existing notification array unless this local needs notifications and is not registered yet.
        // Array.Resize allocates a new array, so contexts that still hold the old array are unaffected.
        IAsyncLocal[] newChangeNotifications = current.m_localChangeNotifications;
        if (needChangeNotifications)
        {
            if (hadPreviousValue)
            {
                Debug.Assert(Array.IndexOf(newChangeNotifications, local) >= 0);
            }
            else
            {
                int newNotificationIndex = newChangeNotifications.Length;
                Array.Resize(ref newChangeNotifications, newNotificationIndex + 1);
                newChangeNotifications[newNotificationIndex] = local;
            }
        }

        currentThread.ExecutionContext =
            new ExecutionContext(newValues, newChangeNotifications, current.m_isFlowSuppressed);

        if (needChangeNotifications)
        {
            local.OnValueChanged(previousValue, newValue, false);
        }
    }

    /// <summary>Returns this instance; contexts are immutable, so no copy is needed.</summary>
    public ExecutionContext CreateCopy()
    {
        return this;
    }

    /// <summary>No-op; present for compatibility with the .NET Framework API.</summary>
    public void Dispose()
    {
        // For CLR compat only
    }
}
286
/// <summary>
/// Handle returned by <c>ExecutionContext.SuppressFlow</c>. Calling Undo (or Dispose) on it restores
/// execution-context flow on the thread that suppressed it.
/// </summary>
public struct AsyncFlowControl : IDisposable
{
    // Thread that suppressed flow. It is null for a default instance and after the handle has been used.
    private Thread _thread;

    internal void Initialize(Thread currentThread)
    {
        Debug.Assert(currentThread == Thread.CurrentThread);
        _thread = currentThread;
    }

    /// <summary>Restores the execution-context flow that was suppressed when this handle was created.</summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown in any of these cases:
    /// the handle was already used or was never initialized;
    /// it is used on a thread other than the one that suppressed flow;
    /// flow is no longer suppressed on the current thread.
    /// </exception>
    public void Undo()
    {
        Thread owner = _thread;
        if (owner == null)
        {
            throw new InvalidOperationException(SR.InvalidOperation_CannotUseAFCMultiple);
        }

        if (Thread.CurrentThread != owner)
        {
            throw new InvalidOperationException(SR.InvalidOperation_CannotUseAFCOtherThread);
        }

        // Undo is only valid while a suppressed-flow context is still current.
        // - Capture() returns null for a suppressed-flow context, so any context applied since then
        //   (e.g. through ExecutionContext.Run) has flow enabled.
        // - The desktop framework detects that case by comparing context instances.
        // - Here the instance changes whenever an async local is set, so the check is instead that the
        //   current context still has flow suppressed.
        if (!ExecutionContext.IsFlowSuppressed())
        {
            throw new InvalidOperationException(SR.InvalidOperation_AsyncFlowCtrlCtxMismatch);
        }

        // Clear the owner before restoring, so a second Undo/Dispose on this copy reports reuse.
        _thread = null;
        ExecutionContext.RestoreFlow();
    }

    /// <summary>Same as <see cref="Undo"/>.</summary>
    public void Dispose() => Undo();

    public override bool Equals(object obj)
    {
        if (!(obj is AsyncFlowControl))
        {
            return false;
        }

        return Equals((AsyncFlowControl)obj);
    }

    /// <summary>Two handles are equal when they refer to the same owning thread (or both to none).</summary>
    public bool Equals(AsyncFlowControl obj) => _thread == obj._thread;

    public override int GetHashCode()
    {
        Thread owner = _thread;
        return owner == null ? 0 : owner.GetHashCode();
    }

    public static bool operator ==(AsyncFlowControl a, AsyncFlowControl b) => a.Equals(b);

    public static bool operator !=(AsyncFlowControl a, AsyncFlowControl b) => !a.Equals(b);
}
357 }