Make Process.WaitForExitAsync wait for redirected output reads (#42585)
authorEirik Tsarpalis <eirik.tsarpalis@gmail.com>
Mon, 28 Sep 2020 16:22:44 +0000 (17:22 +0100)
committerGitHub <noreply@github.com>
Mon, 28 Sep 2020 16:22:44 +0000 (17:22 +0100)
* Make Process.WaitForExitAsync wait for output reads

Addresses an issue where Process.WaitForExitAsync
doesn't wait for background redirected output reads,
a behaviour which diverges from the sync method
equivalent. Fixes #42556.

* address feedback

* pass cancellation token to AsyncStreamReader waiter

* address feedback

* use tcs.TrySetResult

* Update src/libraries/Common/src/System/Threading/Tasks/TaskTimeoutExtensions.cs

Co-authored-by: Stephen Toub <stoub@microsoft.com>
Co-authored-by: Stephen Toub <stoub@microsoft.com>
src/libraries/Common/src/System/Threading/Tasks/TaskTimeoutExtensions.cs [new file with mode: 0644]
src/libraries/System.Diagnostics.Process/src/System.Diagnostics.Process.csproj
src/libraries/System.Diagnostics.Process/src/System/Diagnostics/AsyncStreamReader.cs
src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.Unix.cs
src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.Windows.cs
src/libraries/System.Diagnostics.Process/src/System/Diagnostics/Process.cs
src/libraries/System.Diagnostics.Process/tests/ProcessWaitingTests.cs
src/libraries/System.Diagnostics.Process/tests/RemotelyInvokable.cs

diff --git a/src/libraries/Common/src/System/Threading/Tasks/TaskTimeoutExtensions.cs b/src/libraries/Common/src/System/Threading/Tasks/TaskTimeoutExtensions.cs
new file mode 100644 (file)
index 0000000..d485dec
--- /dev/null
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+#nullable enable
+
+namespace System.Threading.Tasks
+{
+    /// <summary>
+    /// Task timeout helper based on https://devblogs.microsoft.com/pfxteam/crafting-a-task-timeoutafter-method/
+    /// </summary>
+    internal static class TaskTimeoutExtensions
+    {
+        public static Task WithCancellation(this Task task, CancellationToken cancellationToken)
+        {
+            if (task is null)
+            {
+                throw new ArgumentNullException(nameof(task));
+            }
+
+            if (task.IsCompleted || !cancellationToken.CanBeCanceled)
+            {
+                return task;
+            }
+
+            if (cancellationToken.IsCancellationRequested)
+            {
+                return Task.FromCanceled(cancellationToken);
+            }
+
+            return WithCancellationCore(task, cancellationToken);
+
+            static async Task WithCancellationCore(Task task, CancellationToken cancellationToken)
+            {
+                var tcs = new TaskCompletionSource();
+                using CancellationTokenRegistration _ = cancellationToken.UnsafeRegister(static s => ((TaskCompletionSource)s!).TrySetResult(), tcs);
+
+                if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false))
+                {
+                    throw new TaskCanceledException(Task.FromCanceled(cancellationToken));
+                }
+
+                task.GetAwaiter().GetResult(); // already completed; propagate any exception
+            }
+        }
+    }
+}
index 947d0f2..a2132a7 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <DefineConstants>$(DefineConstants);FEATURE_REGISTRY</DefineConstants>
     <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
@@ -39,6 +39,8 @@
              Link="Common\Interop\Windows\Interop.Errors.cs" />
     <Compile Include="$(CommonPath)System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs"
              Link="Common\System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs" />
+    <Compile Include="$(CommonPath)System\Threading\Tasks\TaskTimeoutExtensions.cs"
+         Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
   </ItemGroup>
   <ItemGroup Condition=" '$(TargetsWindows)' == 'true'">
     <Compile Include="$(CommonPath)Interop\Windows\Kernel32\Interop.EnumProcessModules.cs"
index 63f6f39..83640ff 100644 (file)
@@ -253,15 +253,24 @@ namespace System.Diagnostics
 
         // Wait until we hit EOF. This is called from Process.WaitForExit
         // We will lose some information if we don't do this.
-        internal void WaitUtilEOF()
+        internal void WaitUntilEOF()
         {
-            if (_readToBufferTask != null)
+            if (_readToBufferTask is Task task)
             {
-                _readToBufferTask.GetAwaiter().GetResult();
-                _readToBufferTask = null;
+                task.GetAwaiter().GetResult();
             }
         }
 
+        internal Task WaitUntilEOFAsync(CancellationToken cancellationToken)
+        {
+            if (_readToBufferTask is Task task)
+            {
+                return task.WithCancellation(cancellationToken);
+            }
+
+            return Task.CompletedTask;
+        }
+
         public void Dispose()
         {
             _cts.Cancel();
index 5096a2a..1fb5535 100644 (file)
@@ -207,11 +207,11 @@ namespace System.Diagnostics
             {
                 if (_output != null)
                 {
-                    _output.WaitUtilEOF();
+                    _output.WaitUntilEOF();
                 }
                 if (_error != null)
                 {
-                    _error.WaitUtilEOF();
+                    _error.WaitUntilEOF();
                 }
             }
 
index b0ab947..44323f9 100644 (file)
@@ -180,10 +180,10 @@ namespace System.Diagnostics
             {
                 // If we have a hard timeout, we cannot wait for the streams
                 if (_output != null && milliseconds == Timeout.Infinite)
-                    _output.WaitUtilEOF();
+                    _output.WaitUntilEOF();
 
                 if (_error != null && milliseconds == Timeout.Infinite)
-                    _error.WaitUtilEOF();
+                    _error.WaitUntilEOF();
 
                 handle?.Dispose();
             }
index 5422827..99184a2 100644 (file)
@@ -1452,6 +1452,7 @@ namespace System.Diagnostics
                 // exception up to the user
                 if (HasExited)
                 {
+                    await WaitUntilOutputEOF().ConfigureAwait(false);
                     return;
                 }
 
@@ -1460,7 +1461,7 @@ namespace System.Diagnostics
 
             var tcs = new TaskCompletionSourceWithCancellation<bool>();
 
-            EventHandler handler = (s, e) => tcs.TrySetResult(true);
+            EventHandler handler = (_, _) => tcs.TrySetResult(true);
             Exited += handler;
 
             try
@@ -1468,16 +1469,33 @@ namespace System.Diagnostics
                 if (HasExited)
                 {
                     // CASE 1.2 & CASE 3.2: Handle race where the process exits before registering the handler
-                    return;
+                }
+                else
+                {
+                    // CASE 1.1 & CASE 3.1: Process exits or is canceled here
+                    await tcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
                 }
 
-                // CASE 1.1 & CASE 3.1: Process exits or is canceled here
-                await tcs.WaitWithCancellationAsync(cancellationToken).ConfigureAwait(false);
+                // Wait until output streams have been drained
+                await WaitUntilOutputEOF().ConfigureAwait(false);
             }
             finally
             {
                 Exited -= handler;
             }
+
+            async ValueTask WaitUntilOutputEOF()
+            {
+                if (_output != null)
+                {
+                    await _output.WaitUntilEOFAsync(cancellationToken).ConfigureAwait(false);
+                }
+
+                if (_error != null)
+                {
+                    await _error.WaitUntilEOFAsync(cancellationToken).ConfigureAwait(false);
+                }
+            }
         }
 
         /// <devdoc>
index 13935f8..bb0edda 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Threading;
@@ -493,6 +494,54 @@ namespace System.Diagnostics.Tests
         }
 
         [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        public void WaitForExit_AfterProcessExit_ShouldConsumeOutputDataReceived()
+        {
+            const string message = "test";
+            using Process p = CreateProcessPortable(RemotelyInvokable.Echo, message);
+
+            int linesReceived = 0;
+            p.OutputDataReceived += (_, e) => { if (e.Data is not null) linesReceived++; };
+            p.StartInfo.RedirectStandardOutput = true;
+
+            Assert.True(p.Start());
+
+            // Give time for the process (cmd) to terminate
+            while (!p.HasExited)
+            {
+                Thread.Sleep(20);
+            }
+
+            p.BeginOutputReadLine();
+            p.WaitForExit();
+
+            Assert.Equal(1, linesReceived);
+        }
+
+        [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        public async Task WaitForExitAsync_AfterProcessExit_ShouldConsumeOutputDataReceived()
+        {
+            const string message = "test";
+            using Process p = CreateProcessPortable(RemotelyInvokable.Echo, message);
+
+            int linesReceived = 0;
+            p.OutputDataReceived += (_, e) => { if (e.Data is not null) linesReceived++; };
+            p.StartInfo.RedirectStandardOutput = true;
+
+            Assert.True(p.Start());
+
+            // Give time for the process (cmd) to terminate
+            while (!p.HasExited)
+            {
+                Thread.Sleep(20);
+            }
+
+            p.BeginOutputReadLine();
+            await p.WaitForExitAsync();
+
+            Assert.Equal(1, linesReceived);
+        }
+
+        [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
         public void WaitChain()
         {
             Process root = CreateProcess(() =>
index ede15af..39b1536 100644 (file)
@@ -66,6 +66,12 @@ namespace System.Diagnostics.Tests
             return line == "Success" ? SuccessExitCode : SuccessExitCode + 1;
         }
 
+        public static int Echo(string value)
+        {
+            Console.WriteLine(value);
+            return SuccessExitCode;
+        }
+
         public static int ReadLineWriteIfNull()
         {
             string line = Console.ReadLine();