--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#nullable enable
+
+namespace System.Threading.Tasks
+{
+ /// <summary>
+ /// Task timeout helper based on https://devblogs.microsoft.com/pfxteam/crafting-a-task-timeoutafter-method/
+ /// </summary>
+ internal static class TaskTimeoutExtensions
+ {
+ public static Task WithCancellation(this Task task, CancellationToken cancellationToken)
+ {
+ if (task is null)
+ {
+ throw new ArgumentNullException(nameof(task));
+ }
+
+ if (task.IsCompleted || !cancellationToken.CanBeCanceled)
+ {
+ return task;
+ }
+
+ if (cancellationToken.IsCancellationRequested)
+ {
+ return Task.FromCanceled(cancellationToken);
+ }
+
+ return WithCancellationCore(task, cancellationToken);
+
+ static async Task WithCancellationCore(Task task, CancellationToken cancellationToken)
+ {
+ var tcs = new TaskCompletionSource();
+ using CancellationTokenRegistration _ = cancellationToken.UnsafeRegister(static s => ((TaskCompletionSource)s!).TrySetResult(), tcs);
+
+ if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false))
+ {
+ throw new TaskCanceledException(Task.FromCanceled(cancellationToken));
+ }
+
+ task.GetAwaiter().GetResult(); // already completed; propagate any exception
+ }
+ }
+ }
+}
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<DefineConstants>$(DefineConstants);FEATURE_REGISTRY</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
Link="Common\Interop\Windows\Interop.Errors.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs"
Link="Common\System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs" />
+ <Compile Include="$(CommonPath)System\Threading\Tasks\TaskTimeoutExtensions.cs"
+ Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
</ItemGroup>
<ItemGroup Condition=" '$(TargetsWindows)' == 'true'">
<Compile Include="$(CommonPath)Interop\Windows\Kernel32\Interop.EnumProcessModules.cs"
// Wait until we hit EOF. This is called from Process.WaitForExit
// We will lose some information if we don't do this.
- internal void WaitUtilEOF()
+ internal void WaitUntilEOF()
{
- if (_readToBufferTask != null)
+ if (_readToBufferTask is Task task)
{
- _readToBufferTask.GetAwaiter().GetResult();
- _readToBufferTask = null;
+ task.GetAwaiter().GetResult();
}
}
+ internal Task WaitUntilEOFAsync(CancellationToken cancellationToken)
+ {
+ if (_readToBufferTask is Task task)
+ {
+ return task.WithCancellation(cancellationToken);
+ }
+
+ return Task.CompletedTask;
+ }
+
public void Dispose()
{
_cts.Cancel();
{
if (_output != null)
{
- _output.WaitUtilEOF();
+ _output.WaitUntilEOF();
}
if (_error != null)
{
- _error.WaitUtilEOF();
+ _error.WaitUntilEOF();
}
}
{
// If we have a hard timeout, we cannot wait for the streams
if (_output != null && milliseconds == Timeout.Infinite)
- _output.WaitUtilEOF();
+ _output.WaitUntilEOF();
if (_error != null && milliseconds == Timeout.Infinite)
- _error.WaitUtilEOF();
+ _error.WaitUntilEOF();
handle?.Dispose();
}
// exception up to the user
if (HasExited)
{
+ await WaitUntilOutputEOF().ConfigureAwait(false);
return;
}
var tcs = new TaskCompletionSourceWithCancellation<bool>();
- EventHandler handler = (s, e) => tcs.TrySetResult(true);
+ EventHandler handler = (_, _) => tcs.TrySetResult(true);
Exited += handler;
try
if (HasExited)
{
// CASE 1.2 & CASE 3.2: Handle race where the process exits before registering the handler
- return;
+ }
+ else
+ {
+ // CASE 1.1 & CASE 3.1: Process exits or is canceled here
+ await tcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
}
- // CASE 1.1 & CASE 3.1: Process exits or is canceled here
- await tcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
+ // Wait until output streams have been drained
+ await WaitUntilOutputEOF().ConfigureAwait(false);
}
finally
{
Exited -= handler;
}
+
+ async ValueTask WaitUntilOutputEOF()
+ {
+ if (_output != null)
+ {
+ await _output.WaitUntilEOFAsync(cancellationToken).ConfigureAwait(false);
+ }
+
+ if (_error != null)
+ {
+ await _error.WaitUntilEOFAsync(cancellationToken).ConfigureAwait(false);
+ }
+ }
}
/// <devdoc>
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
+using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
}
[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+ public void WaitForExit_AfterProcessExit_ShouldConsumeOutputDataReceived()
+ {
+ const string message = "test";
+ using Process p = CreateProcessPortable(RemotelyInvokable.Echo, message);
+
+ int linesReceived = 0;
+ p.OutputDataReceived += (_, e) => { if (e.Data is not null) linesReceived++; };
+ p.StartInfo.RedirectStandardOutput = true;
+
+ Assert.True(p.Start());
+
+ // Give time for the process (cmd) to terminate
+ while (!p.HasExited)
+ {
+ Thread.Sleep(20);
+ }
+
+ p.BeginOutputReadLine();
+ p.WaitForExit();
+
+ Assert.Equal(1, linesReceived);
+ }
+
+ [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+ public async Task WaitForExitAsync_AfterProcessExit_ShouldConsumeOutputDataReceived()
+ {
+ const string message = "test";
+ using Process p = CreateProcessPortable(RemotelyInvokable.Echo, message);
+
+ int linesReceived = 0;
+ p.OutputDataReceived += (_, e) => { if (e.Data is not null) linesReceived++; };
+ p.StartInfo.RedirectStandardOutput = true;
+
+ Assert.True(p.Start());
+
+ // Give time for the process (cmd) to terminate
+ while (!p.HasExited)
+ {
+ Thread.Sleep(20);
+ }
+
+ p.BeginOutputReadLine();
+ await p.WaitForExitAsync();
+
+ Assert.Equal(1, linesReceived);
+ }
+
+ [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void WaitChain()
{
Process root = CreateProcess(() =>
return line == "Success" ? SuccessExitCode : SuccessExitCode + 1;
}
+ public static int Echo(string value)
+ {
+ Console.WriteLine(value);
+ return SuccessExitCode;
+ }
+
public static int ReadLineWriteIfNull()
{
string line = Console.ReadLine();