Make WindowsServiceLifetime gracefully stop (#83892)
authorEric StJohn <ericstj@microsoft.com>
Thu, 6 Apr 2023 04:42:17 +0000 (21:42 -0700)
committerGitHub <noreply@github.com>
Thu, 6 Apr 2023 04:42:17 +0000 (21:42 -0700)
* Make WindowsServiceLifetime gracefully stop

WindowsServiceLifetime was not waiting for ServiceBase to stop the service.  As a result
we would sometimes end the process before notifying service control manager that the service
had stopped -- resulting in an error in the eventlog and sometimes a service restart.

We also were permitting multiple calls to Stop to occur - through SCM callbacks, and through
public API.  We must not call SetServiceStatus again once the service is marked as stopped.

* Alternate approach to ensuring we only ever set STATE_STOPPED once.

* Avoid calling ServiceBase.Stop on stopped service

I fixed double-calling STATE_STOPPED in ServiceBase, but this fix will
not be present on .NETFramework.  Workaround that by avoiding calling
ServiceBase.Stop when the service has already been stopped by SCM.

* Add tests for WindowsServiceLifetime

These tests leverage RemoteExecutor to avoid creating a separate service
assembly.

* Respond to feedback and add more tests.

This better integrates with the RemoteExecutor component as well,
by hooking up the service process and fetching its handle.

This gives us the correct logging and exitcode handling from
RemoteExecutor.

* Honor Cancellation in StopAsync

* Fix bindingRedirects in RemoteExecutor

* Use Async lambdas for service testing

* Fix issue on Win7 where duplicate service descriptions are disallowed

* Respond to feedback

* Fix comment and add timeout

eng/testing/xunit/xunit.targets
src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs [new file with mode: 0644]
src/libraries/Common/src/Interop/Windows/Interop.Errors.cs
src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs
src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj
src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/UseWindowsServiceTests.cs
src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceLifetimeTests.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceTester.cs [new file with mode: 0644]
src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs

index 6b048e6..e72ebd4 100644 (file)
@@ -6,6 +6,11 @@
                       Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'" />
   </ItemGroup>
 
+  <PropertyGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETFramework'">
+    <AutoGenerateBindingRedirects Condition="'$(AutoGenerateBindingRedirects)' == ''">true</AutoGenerateBindingRedirects>
+    <GenerateBindingRedirectsOutputType Condition="'$(GenerateBindingRedirectsOutputType)' == ''">true</GenerateBindingRedirectsOutputType>
+  </PropertyGroup>
+
   <!-- Run target (F5) support. -->
   <PropertyGroup>
     <RunWorkingDirectory Condition="'$(RunWorkingDirectory)' == ''">$(OutDir)</RunWorkingDirectory>
diff --git a/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs b/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs
new file mode 100644 (file)
index 0000000..8c38dec
--- /dev/null
@@ -0,0 +1,34 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using Microsoft.Win32.SafeHandles;
+using System;
+using System.Runtime.InteropServices;
+
+internal static partial class Interop
+{
+    internal static partial class Advapi32
+    {
+        [StructLayout(LayoutKind.Sequential)]
+        internal struct SERVICE_STATUS_PROCESS
+        {
+            public int dwServiceType;
+            public int dwCurrentState;
+            public int dwControlsAccepted;
+            public int dwWin32ExitCode;
+            public int dwServiceSpecificExitCode;
+            public int dwCheckPoint;
+            public int dwWaitHint;
+            public int dwProcessId;
+            public int dwServiceFlags;
+        }
+
+        private const int SC_STATUS_PROCESS_INFO = 0;
+
+        [LibraryImport(Libraries.Advapi32, SetLastError = true)]
+        [return: MarshalAs(UnmanagedType.Bool)]
+        private static unsafe partial bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, int InfoLevel, SERVICE_STATUS_PROCESS* pStatus, int cbBufSize, out int pcbBytesNeeded);
+
+        internal static unsafe bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, SERVICE_STATUS_PROCESS* pStatus) => QueryServiceStatusEx(serviceHandle, SC_STATUS_PROCESS_INFO, pStatus, sizeof(SERVICE_STATUS_PROCESS), out _);
+    }
+}
index cde3ae0..c810603 100644 (file)
@@ -64,6 +64,8 @@ internal static partial class Interop
         internal const int ERROR_IO_PENDING = 0x3E5;
         internal const int ERROR_NO_TOKEN = 0x3f0;
         internal const int ERROR_SERVICE_DOES_NOT_EXIST = 0x424;
+        internal const int ERROR_EXCEPTION_IN_SERVICE = 0x428;
+        internal const int ERROR_PROCESS_ABORTED = 0x42B;
         internal const int ERROR_NO_UNICODE_TRANSLATION = 0x459;
         internal const int ERROR_DLL_INIT_FAILED = 0x45A;
         internal const int ERROR_COUNTER_TIMEOUT = 0x461;
index d39ddd8..7edd3ea 100644 (file)
@@ -18,8 +18,10 @@ namespace Microsoft.Extensions.Hosting.WindowsServices
     public class WindowsServiceLifetime : ServiceBase, IHostLifetime
     {
         private readonly TaskCompletionSource<object?> _delayStart = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
+        private readonly TaskCompletionSource<object?> _serviceDispatcherStopped = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
         private readonly ManualResetEventSlim _delayStop = new ManualResetEventSlim();
         private readonly HostOptions _hostOptions;
+        private bool _serviceStopRequested;
 
         /// <summary>
         /// Initializes a new <see cref="WindowsServiceLifetime"/> instance.
@@ -87,19 +89,30 @@ namespace Microsoft.Extensions.Hosting.WindowsServices
             {
                 Run(this); // This blocks until the service is stopped.
                 _delayStart.TrySetException(new InvalidOperationException("Stopped without starting"));
+                _serviceDispatcherStopped.TrySetResult(null);
             }
             catch (Exception ex)
             {
                 _delayStart.TrySetException(ex);
+                _serviceDispatcherStopped.TrySetException(ex);
             }
         }
 
-        public Task StopAsync(CancellationToken cancellationToken)
+        /// <summary>
+        /// Called from <see cref="IHost.StopAsync"/> to stop the service if not already stopped, and wait for the service dispatcher to exit.
+        /// Once this method returns the service is stopped and the process can be terminated at any time.
+        /// </summary>
+        public async Task StopAsync(CancellationToken cancellationToken)
         {
-            // Avoid deadlock where host waits for StopAsync before firing ApplicationStopped,
-            // and Stop waits for ApplicationStopped.
-            Task.Run(Stop, CancellationToken.None);
-            return Task.CompletedTask;
+            cancellationToken.ThrowIfCancellationRequested();
+
+            if (!_serviceStopRequested)
+            {
+                await Task.Run(Stop, cancellationToken).ConfigureAwait(false);
+            }
+
+            // When the underlying service is stopped this will cause the ServiceBase.Run method to complete and return, which completes _serviceDispatcherStopped.
+            await _serviceDispatcherStopped.Task.ConfigureAwait(false);
         }
 
         // Called by base.Run when the service is ready to start.
@@ -111,11 +124,13 @@ namespace Microsoft.Extensions.Hosting.WindowsServices
         }
 
         /// <summary>
-        /// Raises the Stop event to stop the <see cref="WindowsServiceLifetime"/>.
+        /// Executes when a Stop command is sent to the service by the Service Control Manager (SCM).
+        /// Triggers <see cref="IHostApplicationLifetime.ApplicationStopping"/> and waits for <see cref="IHostApplicationLifetime.ApplicationStopped"/>.
+        /// Shortly after this method returns, the Service will be marked as stopped in SCM and the process may exit at any point.
         /// </summary>
-        /// <remarks>This might be called multiple times by service Stop, ApplicationStopping, and StopAsync. That's okay because StopApplication uses a CancellationTokenSource and prevents any recursion.</remarks>
         protected override void OnStop()
         {
+            _serviceStopRequested = true;
             ApplicationLifetime.StopApplication();
             // Wait for the host to shutdown before marking service as stopped.
             _delayStop.Wait(_hostOptions.ShutdownTimeout);
@@ -123,10 +138,13 @@ namespace Microsoft.Extensions.Hosting.WindowsServices
         }
 
         /// <summary>
-        /// Raises the Shutdown event.
+        /// Executes when a Shutdown command is sent to the service by the Service Control Manager (SCM).
+        /// Triggers <see cref="IHostApplicationLifetime.ApplicationStopping"/> and waits for <see cref="IHostApplicationLifetime.ApplicationStopped"/>.
+        /// Shortly after this method returns, the Service will be marked as stopped in SCM and the process may exit at any point.
         /// </summary>
         protected override void OnShutdown()
         {
+            _serviceStopRequested = true;
             ApplicationLifetime.StopApplication();
             // Wait for the host to shutdown before marking service as stopped.
             _delayStop.Wait(_hostOptions.ShutdownTimeout);
index 93be9b8..ee433d9 100644 (file)
@@ -4,12 +4,45 @@
     <!-- Use "$(NetCoreAppCurrent)-windows" to avoid PlatformNotSupportedExceptions from ServiceController. -->
     <TargetFrameworks>$(NetCoreAppCurrent)-windows;$(NetFrameworkMinimum)</TargetFrameworks> 
     <EnableDefaultItems>true</EnableDefaultItems>
+    <EnableLibraryImportGenerator>true</EnableLibraryImportGenerator>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+    <IncludeRemoteExecutor>true</IncludeRemoteExecutor>
   </PropertyGroup>
 
   <ItemGroup>
     <ProjectReference Include="..\src\Microsoft.Extensions.Hosting.WindowsServices.csproj" />
   </ItemGroup>
 
+  <ItemGroup>
+    <Compile Include="$(LibrariesProjectRoot)System.ServiceProcess.ServiceController\src\Microsoft\Win32\SafeHandles\SafeServiceHandle.cs"
+             Link="Microsoft\Win32\SafeHandles\SafeServiceHandle.cs" />
+    <Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs"
+             Link="Common\DisableRuntimeMarshalling.cs"
+             Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'" />
+    <Compile Include="$(CommonPath)Interop\Windows\Interop.Errors.cs"
+             Link="Common\Interop\Windows\Interop.Errors.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Interop.Libraries.cs"
+             Link="Common\Interop\Windows\Interop.Libraries.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.ServiceProcessOptions.cs"
+             Link="Common\Interop\Windows\Interop.ServiceProcessOptions.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.CloseServiceHandle.cs"
+             Link="Common\Interop\Windows\Interop.CloseServiceHandle.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.CreateService.cs"
+             Link="Common\Interop\Windows\Interop.CreateService.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.DeleteService.cs"
+             Link="Common\Interop\Windows\Interop.DeleteService.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.OpenService.cs"
+             Link="Common\Interop\Windows\Interop.OpenService.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.OpenSCManager.cs"
+             Link="Common\Interop\Windows\Interop.OpenSCManager.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.QueryServiceStatus.cs"
+             Link="Common\Interop\Windows\Interop.QueryServiceStatus.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.QueryServiceStatusEx.cs"
+             Link="Common\Interop\Windows\Interop.QueryServiceStatusEx.cs" />
+    <Compile Include="$(CommonPath)Interop\Windows\Advapi32\Interop.SERVICE_STATUS.cs"
+             Link="Common\Interop\Windows\Interop.SERVICE_STATUS.cs" />
+  </ItemGroup>
+
   <ItemGroup Condition="'$(TargetFrameworkIdentifier)' == '.NETFramework'">
     <Reference Include="System.ServiceProcess" />
   </ItemGroup>
index 1fb2ade..c18d503 100644 (file)
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.IO;
 using System.Reflection;
 using System.ServiceProcess;
 using Microsoft.Extensions.DependencyInjection;
@@ -30,6 +29,26 @@ namespace Microsoft.Extensions.Hosting
             Assert.IsType<ConsoleLifetime>(lifetime);
         }
 
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        public void CanCreateService()
+        {
+            using var serviceTester = WindowsServiceTester.Create(() =>
+            {
+                using IHost host = new HostBuilder()
+                    .UseWindowsService()
+                    .Build();
+                host.Run();
+            });
+
+            serviceTester.Start();
+            serviceTester.WaitForStatus(ServiceControllerStatus.Running);
+            serviceTester.Stop();
+            serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+
+            var status = serviceTester.QueryServiceStatus();
+            Assert.Equal(0, status.win32ExitCode);
+        }
+
         [Fact]
         public void ServiceCollectionExtensionMethodDefaultsToOffOutsideOfService()
         {
@@ -66,7 +85,7 @@ namespace Microsoft.Extensions.Hosting
             var builder = new HostApplicationBuilder(new HostApplicationBuilderSettings
             {
                 ApplicationName = appName,
-            }); 
+            });
 
             // Emulate calling builder.Services.AddWindowsService() from inside a Windows service.
             AddWindowsServiceLifetime(builder.Services);
@@ -82,7 +101,7 @@ namespace Microsoft.Extensions.Hosting
         [Fact]
         public void ServiceCollectionExtensionMethodCanBeCalledOnDefaultConfiguration()
         {
-            var builder = new HostApplicationBuilder(); 
+            var builder = new HostApplicationBuilder();
 
             // Emulate calling builder.Services.AddWindowsService() from inside a Windows service.
             AddWindowsServiceLifetime(builder.Services);
diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceLifetimeTests.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceLifetimeTests.cs
new file mode 100644 (file)
index 0000000..06679b3
--- /dev/null
@@ -0,0 +1,338 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Diagnostics;
+using System.IO;
+using System.ServiceProcess;
+using System.Threading;
+using System.Threading.Tasks;
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Hosting.Internal;
+using Microsoft.Extensions.Hosting.WindowsServices;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Options;
+using Xunit;
+
+namespace Microsoft.Extensions.Hosting
+{
+    public class WindowsServiceLifetimeTests
+    {
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        public void ServiceStops()
+        {
+            using var serviceTester = WindowsServiceTester.Create(async () =>
+            {
+                var applicationLifetime = new ApplicationLifetime(NullLogger<ApplicationLifetime>.Instance);
+                using var lifetime = new WindowsServiceLifetime(
+                    new HostingEnvironment(), 
+                    applicationLifetime,
+                    NullLoggerFactory.Instance,
+                    new OptionsWrapper<HostOptions>(new HostOptions()));
+
+                await lifetime.WaitForStartAsync(CancellationToken.None);
+                
+                // would normally occur here, but WindowsServiceLifetime does not depend on it.
+                // applicationLifetime.NotifyStarted();
+                
+                // will be signaled by WindowsServiceLifetime when SCM stops the service.
+                applicationLifetime.ApplicationStopping.WaitHandle.WaitOne();
+
+                // required by WindowsServiceLifetime to identify that app has stopped.
+                applicationLifetime.NotifyStopped();
+
+                await lifetime.StopAsync(CancellationToken.None);
+            });
+
+            serviceTester.Start();
+            serviceTester.WaitForStatus(ServiceControllerStatus.Running);
+
+            var statusEx = serviceTester.QueryServiceStatusEx();
+            var serviceProcess = Process.GetProcessById(statusEx.dwProcessId);
+
+            serviceTester.Stop();
+            serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+            
+            serviceProcess.WaitForExit();
+
+            var status = serviceTester.QueryServiceStatus();
+            Assert.Equal(0, status.win32ExitCode);
+        }
+
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework is missing the fix from https://github.com/dotnet/corefx/commit/3e68d791066ad0fdc6e0b81828afbd9df00dd7f8")]
+        public void ExceptionOnStartIsPropagated()
+        {
+            using var serviceTester = WindowsServiceTester.Create(async () =>
+            {
+                using (var lifetime = ThrowingWindowsServiceLifetime.Create(throwOnStart: new Exception("Should be thrown")))
+                {
+                    Assert.Equal(lifetime.ThrowOnStart,
+                            await Assert.ThrowsAsync<Exception>(async () => 
+                                await lifetime.WaitForStartAsync(CancellationToken.None)));
+                }
+            });
+
+            serviceTester.Start();
+
+            serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+            var status = serviceTester.QueryServiceStatus();
+            Assert.Equal(Interop.Errors.ERROR_EXCEPTION_IN_SERVICE, status.win32ExitCode);
+        }
+
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        public void ExceptionOnStopIsPropagated()
+        {
+            using var serviceTester = WindowsServiceTester.Create(async () =>
+            {
+                using (var lifetime = ThrowingWindowsServiceLifetime.Create(throwOnStop: new Exception("Should be thrown")))
+                {
+                    await lifetime.WaitForStartAsync(CancellationToken.None);
+                    lifetime.ApplicationLifetime.NotifyStopped();
+                    Assert.Equal(lifetime.ThrowOnStop,
+                            await Assert.ThrowsAsync<Exception>( async () => 
+                                await lifetime.StopAsync(CancellationToken.None)));
+                }
+            });
+
+            serviceTester.Start();
+
+            serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+            var status = serviceTester.QueryServiceStatus();
+            Assert.Equal(Interop.Errors.ERROR_PROCESS_ABORTED, status.win32ExitCode);
+        }
+
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        public void CancelStopAsync()
+        {
+            using var serviceTester = WindowsServiceTester.Create(async () =>
+            {
+                var applicationLifetime = new ApplicationLifetime(NullLogger<ApplicationLifetime>.Instance);
+                using var lifetime = new WindowsServiceLifetime(
+                    new HostingEnvironment(), 
+                    applicationLifetime,
+                    NullLoggerFactory.Instance,
+                    new OptionsWrapper<HostOptions>(new HostOptions()));
+                await lifetime.WaitForStartAsync(CancellationToken.None);
+                
+                await Assert.ThrowsAsync<OperationCanceledException>(async () => await lifetime.StopAsync(new CancellationToken(true)));
+            });
+
+            serviceTester.Start();
+
+            serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+            var status = serviceTester.QueryServiceStatus();
+            Assert.Equal(Interop.Errors.ERROR_PROCESS_ABORTED, status.win32ExitCode);
+        }
+
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        public void ServiceCanStopItself()
+        {
+            using (var serviceTester = WindowsServiceTester.Create(async () =>
+            {
+                FileLogger.InitializeForTestCase(nameof(ServiceCanStopItself));
+                using IHost host = new HostBuilder()
+                    .ConfigureServices(services =>
+                    {
+                        services.AddHostedService<LoggingBackgroundService>();
+                        services.AddSingleton<IHostLifetime, LoggingWindowsServiceLifetime>();
+                    })
+                    .Build();
+
+                var applicationLifetime = host.Services.GetRequiredService<IHostApplicationLifetime>();
+                applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started"));
+                applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping"));
+                applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped"));
+
+                FileLogger.Log("host.Start()");
+                host.Start();
+
+                FileLogger.Log("host.Stop()");
+                await host.StopAsync();
+                FileLogger.Log("host.Stop() complete");
+            }))
+            {
+                FileLogger.DeleteLog(nameof(ServiceCanStopItself));
+
+                // service should start cleanly
+                serviceTester.Start();
+                
+                // service will proceed to stopped without any error
+                serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+            
+                var status = serviceTester.QueryServiceStatus();
+                Assert.Equal(0, status.win32ExitCode);
+
+            }
+
+            var logText = FileLogger.ReadLog(nameof(ServiceCanStopItself));
+            Assert.Equal("""
+                host.Start()
+                WindowsServiceLifetime.OnStart
+                BackgroundService.StartAsync
+                lifetime started
+                host.Stop()
+                lifetime stopping
+                BackgroundService.StopAsync
+                lifetime stopped
+                WindowsServiceLifetime.OnStop
+                host.Stop() complete
+
+                """, logText);
+        }
+
+        [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))]
+        public void ServiceSequenceIsCorrect()
+        {
+            using (var serviceTester = WindowsServiceTester.Create(() =>
+            {
+                FileLogger.InitializeForTestCase(nameof(ServiceSequenceIsCorrect));
+                using IHost host = new HostBuilder()
+                    .ConfigureServices(services =>
+                    {
+                        services.AddHostedService<LoggingBackgroundService>();
+                        services.AddSingleton<IHostLifetime, LoggingWindowsServiceLifetime>();
+                    })
+                    .Build();
+
+                var applicationLifetime = host.Services.GetRequiredService<IHostApplicationLifetime>();
+                applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started"));
+                applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping"));
+                applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped"));
+
+                FileLogger.Log("host.Run()");
+                host.Run();
+                FileLogger.Log("host.Run() complete");
+            }))
+            {
+
+                FileLogger.DeleteLog(nameof(ServiceSequenceIsCorrect));
+
+                serviceTester.Start();
+                serviceTester.WaitForStatus(ServiceControllerStatus.Running);
+
+                var statusEx = serviceTester.QueryServiceStatusEx();
+                var serviceProcess = Process.GetProcessById(statusEx.dwProcessId);
+
+                // Give a chance for all asynchronous "started" events to be raised, these happen after the service status changes to started 
+                Thread.Sleep(1000);
+
+                serviceTester.Stop();
+                serviceTester.WaitForStatus(ServiceControllerStatus.Stopped);
+                
+                var status = serviceTester.QueryServiceStatus();
+                Assert.Equal(0, status.win32ExitCode);
+
+            }
+
+            var logText = FileLogger.ReadLog(nameof(ServiceSequenceIsCorrect));
+            Assert.Equal("""
+                host.Run()
+                WindowsServiceLifetime.OnStart
+                BackgroundService.StartAsync
+                lifetime started
+                WindowsServiceLifetime.OnStop
+                lifetime stopping
+                BackgroundService.StopAsync
+                lifetime stopped
+                host.Run() complete
+
+                """, logText);
+
+        }
+
+        public class LoggingWindowsServiceLifetime : WindowsServiceLifetime
+        {
+            public LoggingWindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions<HostOptions> optionsAccessor) :
+                base(environment, applicationLifetime, loggerFactory, optionsAccessor)
+            { }
+
+            protected override void OnStart(string[] args)
+            {
+                FileLogger.Log("WindowsServiceLifetime.OnStart");
+                base.OnStart(args);
+            }
+
+            protected override void OnStop()
+            {
+                FileLogger.Log("WindowsServiceLifetime.OnStop");
+                base.OnStop();
+            }
+        }
+
+        public class ThrowingWindowsServiceLifetime : WindowsServiceLifetime
+        {
+            public static ThrowingWindowsServiceLifetime Create(Exception throwOnStart = null, Exception throwOnStop = null) => 
+                    new ThrowingWindowsServiceLifetime(
+                        new HostingEnvironment(), 
+                        new ApplicationLifetime(NullLogger<ApplicationLifetime>.Instance),
+                        NullLoggerFactory.Instance,
+                        new OptionsWrapper<HostOptions>(new HostOptions()))
+                    {
+                        ThrowOnStart = throwOnStart,
+                        ThrowOnStop = throwOnStop
+                    };
+
+            public ThrowingWindowsServiceLifetime(IHostEnvironment environment, ApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions<HostOptions> optionsAccessor) :
+                base(environment, applicationLifetime, loggerFactory, optionsAccessor)
+            { 
+                ApplicationLifetime = applicationLifetime;
+            }
+
+            public ApplicationLifetime ApplicationLifetime { get; }
+
+            public Exception ThrowOnStart { get; set; }
+            protected override void OnStart(string[] args)
+            {
+                if (ThrowOnStart != null)
+                {
+                    throw ThrowOnStart;
+                }
+                base.OnStart(args);
+            }
+
+            public Exception ThrowOnStop { get; set; }
+            protected override void OnStop()
+            {
+                if (ThrowOnStop != null)
+                {
+                    throw ThrowOnStop;
+                }
+                base.OnStop();
+            }
+        }
+
+        public class LoggingBackgroundService : BackgroundService
+        {
+#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
+            protected override async Task ExecuteAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.ExecuteAsync");
+            public override async Task StartAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StartAsync");
+            public override async Task StopAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StopAsync");
+#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
+        }
+
+        static class FileLogger
+        {
+            static string _fileName;
+
+            public static void InitializeForTestCase(string testCaseName)
+            {
+                Assert.Null(_fileName);
+                _fileName = GetLogForTestCase(testCaseName);
+            }
+
+            private static string GetLogForTestCase(string testCaseName) => Path.Combine(AppContext.BaseDirectory, $"{testCaseName}.log");
+            public static void DeleteLog(string testCaseName) => File.Delete(GetLogForTestCase(testCaseName));
+            public static string ReadLog(string testCaseName) => File.ReadAllText(GetLogForTestCase(testCaseName));
+            public static void Log(string message)
+            {
+                Assert.NotNull(_fileName);
+                lock (_fileName)
+                {
+                    File.AppendAllText(_fileName, message + Environment.NewLine);
+                }
+            }
+        }
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceTester.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceTester.cs
new file mode 100644 (file)
index 0000000..895b4a8
--- /dev/null
@@ -0,0 +1,158 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.ComponentModel;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.ServiceProcess;
+using System.Threading.Tasks;
+using Microsoft.DotNet.RemoteExecutor;
+using Microsoft.Win32.SafeHandles;
+using Xunit;
+
+namespace Microsoft.Extensions.Hosting
+{
+    public class WindowsServiceTester : ServiceController
+    {
+        private WindowsServiceTester(SafeServiceHandle serviceHandle, RemoteInvokeHandle remoteInvokeHandle, string serviceName) : base(serviceName)
+        {
+            _serviceHandle = serviceHandle;
+            _remoteInvokeHandle = remoteInvokeHandle;
+        }
+
+        private SafeServiceHandle _serviceHandle;
+        private RemoteInvokeHandle _remoteInvokeHandle;
+
+        public new void Start()
+        {
+            Start(Array.Empty<string>());
+        }
+
+        public new void Start(string[] args)
+        {
+            base.Start(args);
+
+            // get the process
+            _remoteInvokeHandle.Process.Dispose();
+            _remoteInvokeHandle.Process = null;
+
+            var statusEx = QueryServiceStatusEx();
+            try
+            {
+                _remoteInvokeHandle.Process = Process.GetProcessById(statusEx.dwProcessId);
+                // fetch the process handle so that we can get the exit code later.
+                var _ = _remoteInvokeHandle.Process.SafeHandle;
+            }
+            catch (ArgumentException)
+            { }
+        }
+
+        public TimeSpan WaitForStatusTimeout { get; set; } = TimeSpan.FromSeconds(30);
+
+        public new void WaitForStatus(ServiceControllerStatus desiredStatus) =>
+            WaitForStatus(desiredStatus, WaitForStatusTimeout);
+
+        public new void WaitForStatus(ServiceControllerStatus desiredStatus, TimeSpan timeout)
+        {
+            base.WaitForStatus(desiredStatus, timeout);
+
+            Assert.Equal(Status, desiredStatus);
+        }
+
+        // the following overloads are necessary to ensure the compiler will produce the correct signature from a lambda.
+        public static WindowsServiceTester Create(Func<Task> serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName);
+        
+        public static WindowsServiceTester Create(Func<Task<int>> serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName);
+
+        public static WindowsServiceTester Create(Func<int> serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName);
+        
+        public static WindowsServiceTester Create(Action serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName);
+
+        private static RemoteInvokeOptions remoteInvokeOptions = new RemoteInvokeOptions() { Start = false };
+
+        private static WindowsServiceTester Create(RemoteInvokeHandle remoteInvokeHandle, string serviceName)
+        {
+            // create remote executor commandline arguments
+            var startInfo = remoteInvokeHandle.Process.StartInfo;
+            string commandLine = startInfo.FileName + " " + startInfo.Arguments;
+
+            // install the service
+            using (var serviceManagerHandle = new SafeServiceHandle(Interop.Advapi32.OpenSCManager(null, null, Interop.Advapi32.ServiceControllerOptions.SC_MANAGER_ALL)))
+            {
+                if (serviceManagerHandle.IsInvalid)
+                {
+                    throw new InvalidOperationException();
+                }
+
+                // delete existing service if it exists
+                using (var existingServiceHandle = new SafeServiceHandle(Interop.Advapi32.OpenService(serviceManagerHandle, serviceName, Interop.Advapi32.ServiceAccessOptions.ACCESS_TYPE_ALL)))
+                {
+                    if (!existingServiceHandle.IsInvalid)
+                    {
+                        Interop.Advapi32.DeleteService(existingServiceHandle);
+                    }
+                }
+
+                var serviceHandle = new SafeServiceHandle(
+                    Interop.Advapi32.CreateService(serviceManagerHandle,
+                    serviceName,
+                    $"{nameof(WindowsServiceTester)} {serviceName} test service",
+                    Interop.Advapi32.ServiceAccessOptions.ACCESS_TYPE_ALL,
+                    Interop.Advapi32.ServiceTypeOptions.SERVICE_WIN32_OWN_PROCESS,
+                    (int)ServiceStartMode.Manual,
+                    Interop.Advapi32.ServiceStartErrorModes.ERROR_CONTROL_NORMAL,
+                    commandLine,
+                    loadOrderGroup: null,
+                    pTagId: IntPtr.Zero,
+                    dependencies: null,
+                    servicesStartName: null,
+                    password: null));
+
+                if (serviceHandle.IsInvalid)
+                {
+                    throw new Win32Exception();
+                }
+
+                return new WindowsServiceTester(serviceHandle, remoteInvokeHandle, serviceName);
+            }
+        }
+
+        internal unsafe Interop.Advapi32.SERVICE_STATUS QueryServiceStatus()
+        {
+            Interop.Advapi32.SERVICE_STATUS status = default;
+            bool success = Interop.Advapi32.QueryServiceStatus(_serviceHandle, &status);
+            if (!success)
+            {
+                throw new Win32Exception();
+            }
+            return status;
+        }
+
+        internal unsafe Interop.Advapi32.SERVICE_STATUS_PROCESS QueryServiceStatusEx()
+        {
+            Interop.Advapi32.SERVICE_STATUS_PROCESS status = default;
+            bool success = Interop.Advapi32.QueryServiceStatusEx(_serviceHandle, &status);
+            if (!success)
+            {
+                throw new Win32Exception();
+            }
+            return status;
+        }
+
+        protected override void Dispose(bool disposing)
+        {
+            if (_remoteInvokeHandle != null)
+            {
+                _remoteInvokeHandle.Dispose();               
+            }
+
+            if (!_serviceHandle.IsInvalid)
+            {
+                // delete the temporary test service
+                Interop.Advapi32.DeleteService(_serviceHandle);
+                _serviceHandle.Close();
+            }
+        }
+    }
+}
index 870eff6..c745593 100644 (file)
@@ -31,6 +31,7 @@ namespace System.ServiceProcess
         private bool _commandPropsFrozen;  // set to true once we've use the Can... properties.
         private bool _disposed;
         private bool _initialized;
+        private object _stopLock = new object();
         private EventLog? _eventLog;
 
         /// <summary>
@@ -501,27 +502,34 @@ namespace System.ServiceProcess
         // This is a problem when multiple services are hosted in a single process.
         private unsafe void DeferredStop()
         {
-            fixed (SERVICE_STATUS* pStatus = &_status)
+            lock(_stopLock)
             {
-                int previousState = _status.currentState;
-
-                _status.checkPoint = 0;
-                _status.waitHint = 0;
-                _status.currentState = ServiceControlStatus.STATE_STOP_PENDING;
-                SetServiceStatus(_statusHandle, pStatus);
-                try
+                // never call SetServiceStatus again after STATE_STOPPED is set.
+                if (_status.currentState != ServiceControlStatus.STATE_STOPPED)
                 {
-                    OnStop();
-                    WriteLogEntry(SR.StopSuccessful);
-                    _status.currentState = ServiceControlStatus.STATE_STOPPED;
-                    SetServiceStatus(_statusHandle, pStatus);
-                }
-                catch (Exception e)
-                {
-                    _status.currentState = previousState;
-                    SetServiceStatus(_statusHandle, pStatus);
-                    WriteLogEntry(SR.Format(SR.StopFailed, e), EventLogEntryType.Error);
-                    throw;
+                    fixed (SERVICE_STATUS* pStatus = &_status)
+                    {
+                        int previousState = _status.currentState;
+
+                        _status.checkPoint = 0;
+                        _status.waitHint = 0;
+                        _status.currentState = ServiceControlStatus.STATE_STOP_PENDING;
+                        SetServiceStatus(_statusHandle, pStatus);
+                        try
+                        {
+                            OnStop();
+                            WriteLogEntry(SR.StopSuccessful);
+                            _status.currentState = ServiceControlStatus.STATE_STOPPED;
+                            SetServiceStatus(_statusHandle, pStatus);
+                        }
+                        catch (Exception e)
+                        {
+                            _status.currentState = previousState;
+                            SetServiceStatus(_statusHandle, pStatus);
+                            WriteLogEntry(SR.Format(SR.StopFailed, e), EventLogEntryType.Error);
+                            throw;
+                        }
+                    }
                 }
             }
         }
@@ -533,14 +541,17 @@ namespace System.ServiceProcess
                 OnShutdown();
                 WriteLogEntry(SR.ShutdownOK);
 
-                if (_status.currentState == ServiceControlStatus.STATE_PAUSED || _status.currentState == ServiceControlStatus.STATE_RUNNING)
+                lock(_stopLock)
                 {
-                    fixed (SERVICE_STATUS* pStatus = &_status)
+                    if (_status.currentState == ServiceControlStatus.STATE_PAUSED || _status.currentState == ServiceControlStatus.STATE_RUNNING)
                     {
-                        _status.checkPoint = 0;
-                        _status.waitHint = 0;
-                        _status.currentState = ServiceControlStatus.STATE_STOPPED;
-                        SetServiceStatus(_statusHandle, pStatus);
+                        fixed (SERVICE_STATUS* pStatus = &_status)
+                        {
+                            _status.checkPoint = 0;
+                            _status.waitHint = 0;
+                            _status.currentState = ServiceControlStatus.STATE_STOPPED;
+                            SetServiceStatus(_statusHandle, pStatus);
+                        }
                     }
                 }
             }
@@ -654,7 +665,7 @@ namespace System.ServiceProcess
         {
             if (!_initialized)
             {
-                //Cannot register the service with NT service manatger if the object has been disposed, since finalization has been suppressed.
+                //Cannot register the service with NT service manager if the object has been disposed, since finalization has been suppressed.
                 if (_disposed)
                     throw new ObjectDisposedException(GetType().Name);
 
@@ -923,8 +934,14 @@ namespace System.ServiceProcess
                 {
                     string errorMessage = new Win32Exception().Message;
                     WriteLogEntry(SR.Format(SR.StartFailed, errorMessage), EventLogEntryType.Error);
-                    _status.currentState = ServiceControlStatus.STATE_STOPPED;
-                    SetServiceStatus(_statusHandle, pStatus);
+                    lock (_stopLock)
+                    {
+                        if (_status.currentState != ServiceControlStatus.STATE_STOPPED)
+                        {
+                            _status.currentState = ServiceControlStatus.STATE_STOPPED;
+                            SetServiceStatus(_statusHandle, pStatus);
+                        }
+                    }
                 }
             }
         }