From f976f748723a874988bd63d1f790152baccf21dc Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 9 Apr 2018 19:52:25 -0400 Subject: [PATCH] Avoid Unsafe.As usage in ValueTask that can break type safety (#17471) Unsafe.As yields a performance improvement over using a normal cast, and it's fine when ValueTask is used correctly, but if the ValueTask instance were to be stored into a field and multiple threads incorrectly raced to access it, a torn read/write could result in violating type safety due to ObjectIsTask reading the wrong value for the associated object. This commit changes the implementation to only use the single object field to determine which paths to take, rather than factoring in a second field that may not be in sync. --- .../ConfiguredValueTaskAwaitable.cs | 90 +++-- .../Runtime/CompilerServices/ValueTaskAwaiter.cs | 70 ++-- .../shared/System/Threading/Tasks/ValueTask.cs | 412 ++++++++++++--------- 3 files changed, 338 insertions(+), 234 deletions(-) diff --git a/src/mscorlib/shared/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs b/src/mscorlib/shared/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs index 11e6215..8f7b0c8 100644 --- a/src/mscorlib/shared/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs +++ b/src/mscorlib/shared/System/Runtime/CompilerServices/ConfiguredValueTaskAwaitable.cs @@ -59,55 +59,64 @@ namespace System.Runtime.CompilerServices /// Schedules the continuation action for the . public void OnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + t.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().OnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + Unsafe.As(obj).OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.FlowExecutionContext | - (_value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None)); + (_value._continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None)); } else { - ValueTask.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + ValueTask.CompletedTask.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().OnCompleted(continuation); } } /// Schedules the continuation action for the . public void UnsafeOnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + t.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, - _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + Unsafe.As(obj).OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + _value._continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); } else { - ValueTask.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + ValueTask.CompletedTask.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); } } #if CORECLR void IStateMachineBoxAwareAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeGetTask(), box, _value.ContinueOnCapturedContext); + TaskAwaiter.UnsafeOnCompletedInternal(t, box, _value._continueOnCapturedContext); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, - _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + Unsafe.As(obj).OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, + _value._continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); } else { - TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, _value.ContinueOnCapturedContext); + TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, _value._continueOnCapturedContext); } } #endif @@ -161,55 +170,64 @@ namespace System.Runtime.CompilerServices /// Schedules the continuation action for the . public void OnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + t.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().OnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + Unsafe.As>(obj).OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.FlowExecutionContext | - (_value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None)); + (_value._continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None)); } else { - ValueTask.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().OnCompleted(continuation); + ValueTask.CompletedTask.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().OnCompleted(continuation); } } /// Schedules the continuation action for the . public void UnsafeOnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + t.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, - _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + Unsafe.As>(obj).OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, + _value._continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); } else { - ValueTask.CompletedTask.ConfigureAwait(_value.ContinueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); + ValueTask.CompletedTask.ConfigureAwait(_value._continueOnCapturedContext).GetAwaiter().UnsafeOnCompleted(continuation); } } #if CORECLR void IStateMachineBoxAwareAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeGetTask(), box, _value.ContinueOnCapturedContext); + TaskAwaiter.UnsafeOnCompletedInternal(t, box, _value._continueOnCapturedContext); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, - _value.ContinueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); + Unsafe.As>(obj).OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, + _value._continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None); } else { - TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, _value.ContinueOnCapturedContext); + TaskAwaiter.UnsafeOnCompletedInternal(Task.CompletedTask, box, _value._continueOnCapturedContext); } } #endif diff --git a/src/mscorlib/shared/System/Runtime/CompilerServices/ValueTaskAwaiter.cs b/src/mscorlib/shared/System/Runtime/CompilerServices/ValueTaskAwaiter.cs index 31955b6..02b5910 100644 --- a/src/mscorlib/shared/System/Runtime/CompilerServices/ValueTaskAwaiter.cs +++ b/src/mscorlib/shared/System/Runtime/CompilerServices/ValueTaskAwaiter.cs @@ -6,6 +6,10 @@ using System.Diagnostics; using System.Threading.Tasks; using System.Threading.Tasks.Sources; +#if !netstandard +using Internal.Runtime.CompilerServices; +#endif + namespace System.Runtime.CompilerServices { /// Provides an awaiter for a . @@ -48,13 +52,16 @@ namespace System.Runtime.CompilerServices /// Schedules the continuation action for this ValueTask. public void OnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().GetAwaiter().OnCompleted(continuation); + t.GetAwaiter().OnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext | ValueTaskSourceOnCompletedFlags.FlowExecutionContext); + Unsafe.As(obj).OnCompleted(s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext | ValueTaskSourceOnCompletedFlags.FlowExecutionContext); } else { @@ -65,13 +72,16 @@ namespace System.Runtime.CompilerServices /// Schedules the continuation action for this ValueTask. public void UnsafeOnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().GetAwaiter().UnsafeOnCompleted(continuation); + t.GetAwaiter().UnsafeOnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + Unsafe.As(obj).OnCompleted(s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); } else { @@ -82,13 +92,16 @@ namespace System.Runtime.CompilerServices #if CORECLR void IStateMachineBoxAwareAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeGetTask(), box, continueOnCapturedContext: true); + TaskAwaiter.UnsafeOnCompletedInternal(t, box, continueOnCapturedContext: true); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(s_invokeAsyncStateMachineBox, box, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + Unsafe.As(obj).OnCompleted(s_invokeAsyncStateMachineBox, box, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); } else { @@ -139,13 +152,16 @@ namespace System.Runtime.CompilerServices /// Schedules the continuation action for this ValueTask. public void OnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().GetAwaiter().OnCompleted(continuation); + t.GetAwaiter().OnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext | ValueTaskSourceOnCompletedFlags.FlowExecutionContext); + Unsafe.As>(obj).OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext | ValueTaskSourceOnCompletedFlags.FlowExecutionContext); } else { @@ -156,13 +172,16 @@ namespace System.Runtime.CompilerServices /// Schedules the continuation action for this ValueTask. public void UnsafeOnCompleted(Action continuation) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - _value.UnsafeGetTask().GetAwaiter().UnsafeOnCompleted(continuation); + t.GetAwaiter().UnsafeOnCompleted(continuation); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + Unsafe.As>(obj).OnCompleted(ValueTaskAwaiter.s_invokeActionDelegate, continuation, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); } else { @@ -173,13 +192,16 @@ namespace System.Runtime.CompilerServices #if CORECLR void IStateMachineBoxAwareAwaiter.AwaitUnsafeOnCompleted(IAsyncStateMachineBox box) { - if (_value.ObjectIsTask) + object obj = _value._obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj is Task t) { - TaskAwaiter.UnsafeOnCompletedInternal(_value.UnsafeGetTask(), box, continueOnCapturedContext: true); + TaskAwaiter.UnsafeOnCompletedInternal(t, box, continueOnCapturedContext: true); } - else if (_value._obj != null) + else if (obj != null) { - _value.UnsafeGetValueTaskSource().OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); + Unsafe.As>(obj).OnCompleted(ValueTaskAwaiter.s_invokeAsyncStateMachineBox, box, _value._token, ValueTaskSourceOnCompletedFlags.UseSchedulingContext); } else { diff --git a/src/mscorlib/shared/System/Threading/Tasks/ValueTask.cs b/src/mscorlib/shared/System/Threading/Tasks/ValueTask.cs index 56d5f54..53f746b 100644 --- a/src/mscorlib/shared/System/Threading/Tasks/ValueTask.cs +++ b/src/mscorlib/shared/System/Threading/Tasks/ValueTask.cs @@ -14,6 +14,22 @@ using Internal.Runtime.CompilerServices; namespace System.Threading.Tasks { + // TYPE SAFETY WARNING: + // This code uses Unsafe.As to cast _obj. This is done in order to minimize the costs associated with + // casting _obj to a variety of different types that can be stored in a ValueTask, e.g. Task + // vs IValueTaskSource. Previous attempts at this were faulty due to using a separate field + // to store information about the type of the object in _obj; this is faulty because if the ValueTask + // is stored into a field, concurrent read/writes can result in tearing the _obj from the type information + // stored in a separate field. This means we can rely only on the _obj field to determine how to handle + // it. As such, the pattern employed is to copy _obj into a local obj, and then check it for null and + // type test against Task/Task. Since the ValueTask can only be constructed with null, Task, + // or IValueTaskSource, we can then be confident in knowing that if it doesn't match one of those values, + // it must be an IValueTaskSource, and we can use Unsafe.As. This could be defeated by other unsafe means, + // like private reflection or using Unsafe.As manually, but at that point you're already doing things + // that can violate type safety; we only care about getting correct behaviors when using "safe" code. + // There are still other race conditions in user's code that can result in errors, but such errors don't + // cause ValueTask to violate type safety. + /// Provides an awaitable result of an asynchronous operation. /// /// s are meant to be directly awaited. To do more complicated operations with them, a @@ -42,10 +58,11 @@ namespace System.Threading.Tasks /// null if representing a successful synchronous completion, otherwise a or a . internal readonly object _obj; - /// Flags providing additional details about the ValueTask's contents and behavior. - internal readonly ValueTaskFlags _flags; /// Opaque value passed through to the . internal readonly short _token; + /// true to continue on the capture context; otherwise, true. + /// Stored in the rather than in the configured awaiter to utilize otherwise padding space. + internal readonly bool _continueOnCapturedContext; // An instance created with the default ctor (a zero init'd struct) represents a synchronously, successfully completed operation. @@ -61,7 +78,7 @@ namespace System.Threading.Tasks _obj = task; - _flags = ValueTaskFlags.ObjectIsTask; + _continueOnCapturedContext = true; _token = 0; } @@ -79,51 +96,15 @@ namespace System.Threading.Tasks _obj = source; _token = token; - _flags = 0; + _continueOnCapturedContext = true; } - /// Non-verified initialization of the struct to the specified values. - /// The object. - /// The token. - /// The flags. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private ValueTask(object obj, short token, ValueTaskFlags flags) + private ValueTask(object obj, short token, bool continueOnCapturedContext) { _obj = obj; _token = token; - _flags = flags; - } - - /// Gets whether the contination should be scheduled to the current context. - internal bool ContinueOnCapturedContext - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => (_flags & ValueTaskFlags.AvoidCapturedContext) == 0; - } - - /// Gets whether the object in the field is a . - internal bool ObjectIsTask - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => (_flags & ValueTaskFlags.ObjectIsTask) != 0; - } - - /// Returns the stored in . This uses . - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal Task UnsafeGetTask() - { - Debug.Assert(ObjectIsTask); - Debug.Assert(_obj is Task); - return Unsafe.As(_obj); - } - - /// Returns the stored in . This uses . - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal IValueTaskSource UnsafeGetValueTaskSource() - { - Debug.Assert(!ObjectIsTask); - Debug.Assert(_obj is IValueTaskSource); - return Unsafe.As(_obj); + _continueOnCapturedContext = continueOnCapturedContext; } /// Returns the hash code for this instance. @@ -152,18 +133,26 @@ namespace System.Threading.Tasks /// It will either return the wrapped task object if one exists, or it'll /// manufacture a new task object to represent the result. /// - public Task AsTask() => - _obj == null ? ValueTask.CompletedTask : - ObjectIsTask ? UnsafeGetTask() : - GetTaskForValueTaskSource(); + public Task AsTask() + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + return + obj == null ? CompletedTask : + obj as Task ?? + GetTaskForValueTaskSource(Unsafe.As(obj)); + } /// Gets a that may be used at any point in the future. public ValueTask Preserve() => _obj == null ? this : new ValueTask(AsTask()); /// Creates a to represent the . - private Task GetTaskForValueTaskSource() + /// + /// The is passed in rather than reading and casting + /// so that the caller can pass in an object it's already validated. + /// + private Task GetTaskForValueTaskSource(IValueTaskSource t) { - IValueTaskSource t = UnsafeGetValueTaskSource(); ValueTaskSourceStatus status = t.GetStatus(_token); if (status != ValueTaskSourceStatus.Pending) { @@ -172,7 +161,7 @@ namespace System.Threading.Tasks // Propagate any exceptions that may have occurred, then return // an already successfully completed task. t.GetResult(_token); - return ValueTask.CompletedTask; + return CompletedTask; // If status is Faulted or Canceled, GetResult should throw. But // we can't guarantee every implementation will do the "right thing". @@ -206,7 +195,7 @@ namespace System.Threading.Tasks } } - var m = new ValueTaskSourceTask(t, _token); + var m = new ValueTaskSourceAsTask(t, _token); return #if netstandard m.Task; @@ -216,7 +205,7 @@ namespace System.Threading.Tasks } /// Type used to create a to represent a . - private sealed class ValueTaskSourceTask : + private sealed class ValueTaskSourceAsTask : #if netstandard TaskCompletionSource #else @@ -225,7 +214,7 @@ namespace System.Threading.Tasks { private static readonly Action s_completionAction = state => { - if (!(state is ValueTaskSourceTask vtst) || + if (!(state is ValueTaskSourceAsTask vtst) || !(vtst._source is IValueTaskSource source)) { // This could only happen if the IValueTaskSource passed the wrong state @@ -271,7 +260,7 @@ namespace System.Threading.Tasks /// The token to pass through to operations on private readonly short _token; - public ValueTaskSourceTask(IValueTaskSource source, short token) + public ValueTaskSourceAsTask(IValueTaskSource source, short token) { _token = token; _source = source; @@ -283,30 +272,73 @@ namespace System.Threading.Tasks public bool IsCompleted { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => _obj == null || (ObjectIsTask ? UnsafeGetTask().IsCompleted : UnsafeGetValueTaskSource().GetStatus(_token) != ValueTaskSourceStatus.Pending); + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return true; + } + + if (obj is Task t) + { + return t.IsCompleted; + } + + return Unsafe.As(obj).GetStatus(_token) != ValueTaskSourceStatus.Pending; + } } /// Gets whether the represents a successfully completed operation. public bool IsCompletedSuccessfully { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => - _obj == null || - (ObjectIsTask ? + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return true; + } + + if (obj is Task t) + { + return #if netstandard - UnsafeGetTask().Status == TaskStatus.RanToCompletion : + t.Status == TaskStatus.RanToCompletion; #else - UnsafeGetTask().IsCompletedSuccessfully : + t.IsCompletedSuccessfully; #endif - UnsafeGetValueTaskSource().GetStatus(_token) == ValueTaskSourceStatus.Succeeded); + } + + return Unsafe.As(obj).GetStatus(_token) == ValueTaskSourceStatus.Succeeded; + } } /// Gets whether the represents a failed operation. public bool IsFaulted { - get => - _obj != null && - (ObjectIsTask ? UnsafeGetTask().IsFaulted : UnsafeGetValueTaskSource().GetStatus(_token) == ValueTaskSourceStatus.Faulted); + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return false; + } + + if (obj is Task t) + { + return t.IsFaulted; + } + + return Unsafe.As(obj).GetStatus(_token) == ValueTaskSourceStatus.Faulted; + } } /// Gets whether the represents a canceled operation. @@ -317,9 +349,23 @@ namespace System.Threading.Tasks /// public bool IsCanceled { - get => - _obj != null && - (ObjectIsTask ? UnsafeGetTask().IsCanceled : UnsafeGetValueTaskSource().GetStatus(_token) == ValueTaskSourceStatus.Canceled); + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return false; + } + + if (obj is Task t) + { + return t.IsCanceled; + } + + return Unsafe.As(obj).GetStatus(_token) == ValueTaskSourceStatus.Canceled; + } } /// Throws the exception that caused the to fail. If it completed successfully, nothing is thrown. @@ -327,19 +373,22 @@ namespace System.Threading.Tasks [StackTraceHidden] internal void ThrowIfCompletedUnsuccessfully() { - if (_obj != null) + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj != null) { - if (ObjectIsTask) + if (obj is Task t) { #if netstandard - UnsafeGetTask().GetAwaiter().GetResult(); + t.GetAwaiter().GetResult(); #else - TaskAwaiter.ValidateEnd(UnsafeGetTask()); + TaskAwaiter.ValidateEnd(t); #endif } else { - UnsafeGetValueTaskSource().GetResult(_token); + Unsafe.As(obj).GetResult(_token); } } } @@ -352,12 +401,8 @@ namespace System.Threading.Tasks /// true to attempt to marshal the continuation back to the captured context; otherwise, false. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) - { - // TODO: Simplify once https://github.com/dotnet/coreclr/pull/16138 is fixed. - bool avoidCapture = !continueOnCapturedContext; - return new ConfiguredValueTaskAwaitable(new ValueTask(_obj, _token, _flags | Unsafe.As(ref avoidCapture))); - } + public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) => + new ConfiguredValueTaskAwaitable(new ValueTask(_obj, _token, continueOnCapturedContext)); } /// Provides a value type that can represent a synchronously available value or a task object. @@ -378,10 +423,11 @@ namespace System.Threading.Tasks internal readonly object _obj; /// The result to be used if the operation completed successfully synchronously. internal readonly TResult _result; - /// Flags providing additional details about the ValueTask's contents and behavior. - internal readonly ValueTaskFlags _flags; /// Opaque value passed through to the . internal readonly short _token; + /// true to continue on the captured context; otherwise, false. + /// Stored in the rather than in the configured awaiter to utilize otherwise padding space. + internal readonly bool _continueOnCapturedContext; // An instance created with the default ctor (a zero init'd struct) represents a synchronously, successfully completed operation // with a result of default(TResult). @@ -394,7 +440,7 @@ namespace System.Threading.Tasks _result = result; _obj = null; - _flags = 0; + _continueOnCapturedContext = true; _token = 0; } @@ -411,7 +457,7 @@ namespace System.Threading.Tasks _obj = task; _result = default; - _flags = ValueTaskFlags.ObjectIsTask; + _continueOnCapturedContext = true; _token = 0; } @@ -430,54 +476,23 @@ namespace System.Threading.Tasks _token = token; _result = default; - _flags = 0; + _continueOnCapturedContext = true; } /// Non-verified initialization of the struct to the specified values. /// The object. /// The result. /// The token. - /// The flags. + /// true to continue on captured context; otherwise, false. [MethodImpl(MethodImplOptions.AggressiveInlining)] - private ValueTask(object obj, TResult result, short token, ValueTaskFlags flags) + private ValueTask(object obj, TResult result, short token, bool continueOnCapturedContext) { _obj = obj; _result = result; _token = token; - _flags = flags; + _continueOnCapturedContext = continueOnCapturedContext; } - /// Gets whether the contination should be scheduled to the current context. - internal bool ContinueOnCapturedContext - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => (_flags & ValueTaskFlags.AvoidCapturedContext) == 0; - } - - /// Gets whether the object in the field is a . - internal bool ObjectIsTask - { - [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => (_flags & ValueTaskFlags.ObjectIsTask) != 0; - } - - /// Returns the stored in . This uses . - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal Task UnsafeGetTask() - { - Debug.Assert(ObjectIsTask); - Debug.Assert(_obj is Task); - return Unsafe.As>(_obj); - } - - /// Returns the stored in . This uses . - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal IValueTaskSource UnsafeGetValueTaskSource() - { - Debug.Assert(!ObjectIsTask); - Debug.Assert(_obj is IValueTaskSource); - return Unsafe.As>(_obj); - } /// Returns the hash code for this instance. public override int GetHashCode() => @@ -511,23 +526,39 @@ namespace System.Threading.Tasks /// It will either return the wrapped task object if one exists, or it'll /// manufacture a new task object to represent the result. /// - public Task AsTask() => - _obj == null ? + public Task AsTask() + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return #if netstandard - Task.FromResult(_result) : + Task.FromResult(_result); #else - AsyncTaskMethodBuilder.GetTaskForResult(_result) : + AsyncTaskMethodBuilder.GetTaskForResult(_result); #endif - ObjectIsTask ? UnsafeGetTask() : - GetTaskForValueTaskSource(); + } + + if (obj is Task t) + { + return t; + } + + return GetTaskForValueTaskSource(Unsafe.As>(obj)); + } /// Gets a that may be used at any point in the future. public ValueTask Preserve() => _obj == null ? this : new ValueTask(AsTask()); /// Creates a to represent the . - private Task GetTaskForValueTaskSource() + /// + /// The is passed in rather than reading and casting + /// so that the caller can pass in an object it's already validated. + /// + private Task GetTaskForValueTaskSource(IValueTaskSource t) { - IValueTaskSource t = UnsafeGetValueTaskSource(); ValueTaskSourceStatus status = t.GetStatus(_token); if (status != ValueTaskSourceStatus.Pending) { @@ -588,7 +619,7 @@ namespace System.Threading.Tasks } } - var m = new ValueTaskSourceTask(t, _token); + var m = new ValueTaskSourceAsTask(t, _token); return #if netstandard m.Task; @@ -598,7 +629,7 @@ namespace System.Threading.Tasks } /// Type used to create a to represent a . - private sealed class ValueTaskSourceTask : + private sealed class ValueTaskSourceAsTask : #if netstandard TaskCompletionSource #else @@ -607,7 +638,7 @@ namespace System.Threading.Tasks { private static readonly Action s_completionAction = state => { - if (!(state is ValueTaskSourceTask vtst) || + if (!(state is ValueTaskSourceAsTask vtst) || !(vtst._source is IValueTaskSource source)) { // This could only happen if the IValueTaskSource passed the wrong state @@ -652,7 +683,7 @@ namespace System.Threading.Tasks /// The token to pass through to operations on private readonly short _token; - public ValueTaskSourceTask(IValueTaskSource source, short token) + public ValueTaskSourceAsTask(IValueTaskSource source, short token) { _source = source; _token = token; @@ -664,30 +695,73 @@ namespace System.Threading.Tasks public bool IsCompleted { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => _obj == null || (ObjectIsTask ? UnsafeGetTask().IsCompleted : UnsafeGetValueTaskSource().GetStatus(_token) != ValueTaskSourceStatus.Pending); + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return true; + } + + if (obj is Task t) + { + return t.IsCompleted; + } + + return Unsafe.As>(obj).GetStatus(_token) != ValueTaskSourceStatus.Pending; + } } /// Gets whether the represents a successfully completed operation. public bool IsCompletedSuccessfully { [MethodImpl(MethodImplOptions.AggressiveInlining)] - get => - _obj == null || - (ObjectIsTask ? + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return true; + } + + if (obj is Task t) + { + return #if netstandard - UnsafeGetTask().Status == TaskStatus.RanToCompletion : + t.Status == TaskStatus.RanToCompletion; #else - UnsafeGetTask().IsCompletedSuccessfully : + t.IsCompletedSuccessfully; #endif - UnsafeGetValueTaskSource().GetStatus(_token) == ValueTaskSourceStatus.Succeeded); + } + + return Unsafe.As>(obj).GetStatus(_token) == ValueTaskSourceStatus.Succeeded; + } } /// Gets whether the represents a failed operation. public bool IsFaulted { - get => - _obj != null && - (ObjectIsTask ? UnsafeGetTask().IsFaulted : UnsafeGetValueTaskSource().GetStatus(_token) == ValueTaskSourceStatus.Faulted); + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return false; + } + + if (obj is Task t) + { + return t.IsFaulted; + } + + return Unsafe.As>(obj).GetStatus(_token) == ValueTaskSourceStatus.Faulted; + } } /// Gets whether the represents a canceled operation. @@ -698,9 +772,23 @@ namespace System.Threading.Tasks /// public bool IsCanceled { - get => - _obj != null && - (ObjectIsTask ? UnsafeGetTask().IsCanceled : UnsafeGetValueTaskSource().GetStatus(_token) == ValueTaskSourceStatus.Canceled); + get + { + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) + { + return false; + } + + if (obj is Task t) + { + return t.IsCanceled; + } + + return Unsafe.As>(obj).GetStatus(_token) == ValueTaskSourceStatus.Canceled; + } } /// Gets the result. @@ -709,23 +797,25 @@ namespace System.Threading.Tasks [MethodImpl(MethodImplOptions.AggressiveInlining)] get { - if (_obj == null) + object obj = _obj; + Debug.Assert(obj == null || obj is Task || obj is IValueTaskSource); + + if (obj == null) { return _result; } - if (ObjectIsTask) + if (obj is Task t) { #if netstandard - return UnsafeGetTask().GetAwaiter().GetResult(); + return t.GetAwaiter().GetResult(); #else - Task t = UnsafeGetTask(); TaskAwaiter.ValidateEnd(t); return t.ResultOnSuccess; #endif } - return UnsafeGetValueTaskSource().GetResult(_token); + return Unsafe.As>(obj).GetResult(_token); } } @@ -738,12 +828,8 @@ namespace System.Threading.Tasks /// true to attempt to marshal the continuation back to the captured context; otherwise, false. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) - { - // TODO: Simplify once https://github.com/dotnet/coreclr/pull/16138 is fixed. - bool avoidCapture = !continueOnCapturedContext; - return new ConfiguredValueTaskAwaitable(new ValueTask(_obj, _result, _token, _flags | Unsafe.As(ref avoidCapture))); - } + public ConfiguredValueTaskAwaitable ConfigureAwait(bool continueOnCapturedContext) => + new ConfiguredValueTaskAwaitable(new ValueTask(_obj, _result, _token, continueOnCapturedContext)); /// Gets a string-representation of this . public override string ToString() @@ -760,26 +846,4 @@ namespace System.Threading.Tasks return string.Empty; } } - - /// Internal flags used in the implementation of and . - [Flags] - internal enum ValueTaskFlags : byte - { - /// - /// Indicates that context (e.g. SynchronizationContext) should not be captured when adding - /// a continuation. - /// - /// - /// The value here must be 0x1, to match the value of a true Boolean reinterpreted as a byte. - /// This only has meaning when awaiting a ValueTask, with ConfigureAwait creating a new - /// ValueTask setting or not setting this flag appropriately. - /// - AvoidCapturedContext = 0x1, - - /// - /// Indicates that the ValueTask's object field stores a Task. This is used to avoid - /// a type check on whatever is stored in the object field. - /// - ObjectIsTask = 0x2 - } } -- 2.7.4