diff --git a/eng/testing/xunit/xunit.targets b/eng/testing/xunit/xunit.targets index 8d4aa64986e4d..5cf2846d40d18 100644 --- a/eng/testing/xunit/xunit.targets +++ b/eng/testing/xunit/xunit.targets @@ -4,6 +4,11 @@ + + true + true + + $(OutDir) diff --git a/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs b/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs new file mode 100644 index 0000000000000..8c38dec4df8eb --- /dev/null +++ b/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Win32.SafeHandles; +using System; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static partial class Advapi32 + { + [StructLayout(LayoutKind.Sequential)] + internal struct SERVICE_STATUS_PROCESS + { + public int dwServiceType; + public int dwCurrentState; + public int dwControlsAccepted; + public int dwWin32ExitCode; + public int dwServiceSpecificExitCode; + public int dwCheckPoint; + public int dwWaitHint; + public int dwProcessId; + public int dwServiceFlags; + } + + private const int SC_STATUS_PROCESS_INFO = 0; + + [LibraryImport(Libraries.Advapi32, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + private static unsafe partial bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, int InfoLevel, SERVICE_STATUS_PROCESS* pStatus, int cbBufSize, out int pcbBytesNeeded); + + internal static unsafe bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, SERVICE_STATUS_PROCESS* pStatus) => QueryServiceStatusEx(serviceHandle, SC_STATUS_PROCESS_INFO, pStatus, sizeof(SERVICE_STATUS_PROCESS), out _); + } +} diff --git a/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs b/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs index d5f6d1637507f..424b57f7a8d7b 100644 --- a/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs +++ b/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs @@ -63,6 +63,8 @@ internal static partial class Errors internal const int ERROR_IO_PENDING = 0x3E5; internal const int ERROR_NO_TOKEN = 0x3f0; internal const int ERROR_SERVICE_DOES_NOT_EXIST = 0x424; + internal const int ERROR_EXCEPTION_IN_SERVICE = 0x428; + internal const int ERROR_PROCESS_ABORTED = 0x42B; internal const int ERROR_NO_UNICODE_TRANSLATION = 0x459; internal const int ERROR_DLL_INIT_FAILED = 0x45A; internal const int ERROR_COUNTER_TIMEOUT = 0x461; diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs index f374e3f0f7132..755098d61841a 100644 --- a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs @@ -12,9 +12,11 @@ namespace Microsoft.Extensions.Hosting.WindowsServices { public class WindowsServiceLifetime : ServiceBase, IHostLifetime { - private readonly TaskCompletionSource _delayStart = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _delayStart = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _serviceDispatcherStopped = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly ManualResetEventSlim _delayStop = new ManualResetEventSlim(); private readonly HostOptions _hostOptions; + private bool _serviceStopRequested; public WindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions optionsAccessor) : this(environment, applicationLifetime, loggerFactory, optionsAccessor, Options.Options.Create(new WindowsServiceLifetimeOptions())) @@ -73,19 +75,30 @@ private void Run() { Run(this); // This blocks until the service is stopped. _delayStart.TrySetException(new InvalidOperationException("Stopped without starting")); + _serviceDispatcherStopped.TrySetResult(null); } catch (Exception ex) { _delayStart.TrySetException(ex); + _serviceDispatcherStopped.TrySetException(ex); } } - public Task StopAsync(CancellationToken cancellationToken) + /// + /// Called from to stop the service if not already stopped, and wait for the service dispatcher to exit. + /// Once this method returns the service is stopped and the process can be terminated at any time. + /// + public async Task StopAsync(CancellationToken cancellationToken) { - // Avoid deadlock where host waits for StopAsync before firing ApplicationStopped, - // and Stop waits for ApplicationStopped. - Task.Run(Stop, CancellationToken.None); - return Task.CompletedTask; + cancellationToken.ThrowIfCancellationRequested(); + + if (!_serviceStopRequested) + { + await Task.Run(Stop, cancellationToken).ConfigureAwait(false); + } + + // When the underlying service is stopped this will cause the ServiceBase.Run method to complete and return, which completes _serviceDispatcherStopped. + await _serviceDispatcherStopped.Task.ConfigureAwait(false); } // Called by base.Run when the service is ready to start. @@ -95,18 +108,28 @@ protected override void OnStart(string[] args) base.OnStart(args); } - // Called by base.Stop. This may be called multiple times by service Stop, ApplicationStopping, and StopAsync. - // That's OK because StopApplication uses a CancellationTokenSource and prevents any recursion. + /// + /// Executes when a Stop command is sent to the service by the Service Control Manager (SCM). + /// Triggers and waits for . + /// Shortly after this method returns, the Service will be marked as stopped in SCM and the process may exit at any point. + /// protected override void OnStop() { + _serviceStopRequested = true; ApplicationLifetime.StopApplication(); // Wait for the host to shutdown before marking service as stopped. _delayStop.Wait(_hostOptions.ShutdownTimeout); base.OnStop(); } + /// + /// Executes when a Shutdown command is sent to the service by the Service Control Manager (SCM). + /// Triggers and waits for . + /// Shortly after this method returns, the Service will be marked as stopped in SCM and the process may exit at any point. + /// protected override void OnShutdown() { + _serviceStopRequested = true; ApplicationLifetime.StopApplication(); // Wait for the host to shutdown before marking service as stopped. _delayStop.Wait(_hostOptions.ShutdownTimeout); diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj index 67b1cb31ee521..ef7aa6d7801bd 100644 --- a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj @@ -1,12 +1,47 @@ - + - $(NetCoreAppCurrent);net461 + + $(NetCoreAppCurrent)-windows;$(NetFrameworkMinimum) true + true + true + true + + + + + + + + + + + + + + + + + + + diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/UseWindowsServiceTests.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/UseWindowsServiceTests.cs index 7e24685466d91..c894f58954f32 100644 --- a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/UseWindowsServiceTests.cs +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/UseWindowsServiceTests.cs @@ -1,14 +1,23 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; +using System.Reflection; +using System.ServiceProcess; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting.Internal; +using Microsoft.Extensions.Hosting.WindowsServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.EventLog; +using Microsoft.Extensions.Options; using Xunit; namespace Microsoft.Extensions.Hosting { public class UseWindowsServiceTests { + private static MethodInfo? _addWindowsServiceLifetimeMethod = null; + [Fact] public void DefaultsToOffOutsideOfService() { @@ -22,5 +31,25 @@ public void DefaultsToOffOutsideOfService() Assert.IsType(lifetime); } } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void CanCreateService() + { + using var serviceTester = WindowsServiceTester.Create(() => + { + using IHost host = new HostBuilder() + .UseWindowsService() + .Build(); + host.Run(); + }); + + serviceTester.Start(); + serviceTester.WaitForStatus(ServiceControllerStatus.Running); + serviceTester.Stop(); + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + } } } diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/WindowsServiceLifetimeTests.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/WindowsServiceLifetimeTests.cs new file mode 100644 index 0000000000000..06679b3c48459 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/WindowsServiceLifetimeTests.cs @@ -0,0 +1,338 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.IO; +using System.ServiceProcess; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting.Internal; +using Microsoft.Extensions.Hosting.WindowsServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.Extensions.Hosting +{ + public class WindowsServiceLifetimeTests + { + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ServiceStops() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + var applicationLifetime = new ApplicationLifetime(NullLogger.Instance); + using var lifetime = new WindowsServiceLifetime( + new HostingEnvironment(), + applicationLifetime, + NullLoggerFactory.Instance, + new OptionsWrapper(new HostOptions())); + + await lifetime.WaitForStartAsync(CancellationToken.None); + + // would normally occur here, but WindowsServiceLifetime does not depend on it. + // applicationLifetime.NotifyStarted(); + + // will be signaled by WindowsServiceLifetime when SCM stops the service. + applicationLifetime.ApplicationStopping.WaitHandle.WaitOne(); + + // required by WindowsServiceLifetime to identify that app has stopped. + applicationLifetime.NotifyStopped(); + + await lifetime.StopAsync(CancellationToken.None); + }); + + serviceTester.Start(); + serviceTester.WaitForStatus(ServiceControllerStatus.Running); + + var statusEx = serviceTester.QueryServiceStatusEx(); + var serviceProcess = Process.GetProcessById(statusEx.dwProcessId); + + serviceTester.Stop(); + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + serviceProcess.WaitForExit(); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework is missing the fix from https://github.com/dotnet/corefx/commit/3e68d791066ad0fdc6e0b81828afbd9df00dd7f8")] + public void ExceptionOnStartIsPropagated() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + using (var lifetime = ThrowingWindowsServiceLifetime.Create(throwOnStart: new Exception("Should be thrown"))) + { + Assert.Equal(lifetime.ThrowOnStart, + await Assert.ThrowsAsync(async () => + await lifetime.WaitForStartAsync(CancellationToken.None))); + } + }); + + serviceTester.Start(); + + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(Interop.Errors.ERROR_EXCEPTION_IN_SERVICE, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ExceptionOnStopIsPropagated() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + using (var lifetime = ThrowingWindowsServiceLifetime.Create(throwOnStop: new Exception("Should be thrown"))) + { + await lifetime.WaitForStartAsync(CancellationToken.None); + lifetime.ApplicationLifetime.NotifyStopped(); + Assert.Equal(lifetime.ThrowOnStop, + await Assert.ThrowsAsync( async () => + await lifetime.StopAsync(CancellationToken.None))); + } + }); + + serviceTester.Start(); + + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(Interop.Errors.ERROR_PROCESS_ABORTED, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void CancelStopAsync() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + var applicationLifetime = new ApplicationLifetime(NullLogger.Instance); + using var lifetime = new WindowsServiceLifetime( + new HostingEnvironment(), + applicationLifetime, + NullLoggerFactory.Instance, + new OptionsWrapper(new HostOptions())); + await lifetime.WaitForStartAsync(CancellationToken.None); + + await Assert.ThrowsAsync(async () => await lifetime.StopAsync(new CancellationToken(true))); + }); + + serviceTester.Start(); + + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(Interop.Errors.ERROR_PROCESS_ABORTED, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ServiceCanStopItself() + { + using (var serviceTester = WindowsServiceTester.Create(async () => + { + FileLogger.InitializeForTestCase(nameof(ServiceCanStopItself)); + using IHost host = new HostBuilder() + .ConfigureServices(services => + { + services.AddHostedService(); + services.AddSingleton(); + }) + .Build(); + + var applicationLifetime = host.Services.GetRequiredService(); + applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started")); + applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping")); + applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped")); + + FileLogger.Log("host.Start()"); + host.Start(); + + FileLogger.Log("host.Stop()"); + await host.StopAsync(); + FileLogger.Log("host.Stop() complete"); + })) + { + FileLogger.DeleteLog(nameof(ServiceCanStopItself)); + + // service should start cleanly + serviceTester.Start(); + + // service will proceed to stopped without any error + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + + } + + var logText = FileLogger.ReadLog(nameof(ServiceCanStopItself)); + Assert.Equal(""" + host.Start() + WindowsServiceLifetime.OnStart + BackgroundService.StartAsync + lifetime started + host.Stop() + lifetime stopping + BackgroundService.StopAsync + lifetime stopped + WindowsServiceLifetime.OnStop + host.Stop() complete + + """, logText); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ServiceSequenceIsCorrect() + { + using (var serviceTester = WindowsServiceTester.Create(() => + { + FileLogger.InitializeForTestCase(nameof(ServiceSequenceIsCorrect)); + using IHost host = new HostBuilder() + .ConfigureServices(services => + { + services.AddHostedService(); + services.AddSingleton(); + }) + .Build(); + + var applicationLifetime = host.Services.GetRequiredService(); + applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started")); + applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping")); + applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped")); + + FileLogger.Log("host.Run()"); + host.Run(); + FileLogger.Log("host.Run() complete"); + })) + { + + FileLogger.DeleteLog(nameof(ServiceSequenceIsCorrect)); + + serviceTester.Start(); + serviceTester.WaitForStatus(ServiceControllerStatus.Running); + + var statusEx = serviceTester.QueryServiceStatusEx(); + var serviceProcess = Process.GetProcessById(statusEx.dwProcessId); + + // Give a chance for all asynchronous "started" events to be raised, these happen after the service status changes to started + Thread.Sleep(1000); + + serviceTester.Stop(); + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + + } + + var logText = FileLogger.ReadLog(nameof(ServiceSequenceIsCorrect)); + Assert.Equal(""" + host.Run() + WindowsServiceLifetime.OnStart + BackgroundService.StartAsync + lifetime started + WindowsServiceLifetime.OnStop + lifetime stopping + BackgroundService.StopAsync + lifetime stopped + host.Run() complete + + """, logText); + + } + + public class LoggingWindowsServiceLifetime : WindowsServiceLifetime + { + public LoggingWindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions optionsAccessor) : + base(environment, applicationLifetime, loggerFactory, optionsAccessor) + { } + + protected override void OnStart(string[] args) + { + FileLogger.Log("WindowsServiceLifetime.OnStart"); + base.OnStart(args); + } + + protected override void OnStop() + { + FileLogger.Log("WindowsServiceLifetime.OnStop"); + base.OnStop(); + } + } + + public class ThrowingWindowsServiceLifetime : WindowsServiceLifetime + { + public static ThrowingWindowsServiceLifetime Create(Exception throwOnStart = null, Exception throwOnStop = null) => + new ThrowingWindowsServiceLifetime( + new HostingEnvironment(), + new ApplicationLifetime(NullLogger.Instance), + NullLoggerFactory.Instance, + new OptionsWrapper(new HostOptions())) + { + ThrowOnStart = throwOnStart, + ThrowOnStop = throwOnStop + }; + + public ThrowingWindowsServiceLifetime(IHostEnvironment environment, ApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions optionsAccessor) : + base(environment, applicationLifetime, loggerFactory, optionsAccessor) + { + ApplicationLifetime = applicationLifetime; + } + + public ApplicationLifetime ApplicationLifetime { get; } + + public Exception ThrowOnStart { get; set; } + protected override void OnStart(string[] args) + { + if (ThrowOnStart != null) + { + throw ThrowOnStart; + } + base.OnStart(args); + } + + public Exception ThrowOnStop { get; set; } + protected override void OnStop() + { + if (ThrowOnStop != null) + { + throw ThrowOnStop; + } + base.OnStop(); + } + } + + public class LoggingBackgroundService : BackgroundService + { +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + protected override async Task ExecuteAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.ExecuteAsync"); + public override async Task StartAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StartAsync"); + public override async Task StopAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StopAsync"); +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + } + + static class FileLogger + { + static string _fileName; + + public static void InitializeForTestCase(string testCaseName) + { + Assert.Null(_fileName); + _fileName = GetLogForTestCase(testCaseName); + } + + private static string GetLogForTestCase(string testCaseName) => Path.Combine(AppContext.BaseDirectory, $"{testCaseName}.log"); + public static void DeleteLog(string testCaseName) => File.Delete(GetLogForTestCase(testCaseName)); + public static string ReadLog(string testCaseName) => File.ReadAllText(GetLogForTestCase(testCaseName)); + public static void Log(string message) + { + Assert.NotNull(_fileName); + lock (_fileName) + { + File.AppendAllText(_fileName, message + Environment.NewLine); + } + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/WindowsServiceTester.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/WindowsServiceTester.cs new file mode 100644 index 0000000000000..895b4a87108eb --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/test/WindowsServiceTester.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ComponentModel; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.ServiceProcess; +using System.Threading.Tasks; +using Microsoft.DotNet.RemoteExecutor; +using Microsoft.Win32.SafeHandles; +using Xunit; + +namespace Microsoft.Extensions.Hosting +{ + public class WindowsServiceTester : ServiceController + { + private WindowsServiceTester(SafeServiceHandle serviceHandle, RemoteInvokeHandle remoteInvokeHandle, string serviceName) : base(serviceName) + { + _serviceHandle = serviceHandle; + _remoteInvokeHandle = remoteInvokeHandle; + } + + private SafeServiceHandle _serviceHandle; + private RemoteInvokeHandle _remoteInvokeHandle; + + public new void Start() + { + Start(Array.Empty()); + } + + public new void Start(string[] args) + { + base.Start(args); + + // get the process + _remoteInvokeHandle.Process.Dispose(); + _remoteInvokeHandle.Process = null; + + var statusEx = QueryServiceStatusEx(); + try + { + _remoteInvokeHandle.Process = Process.GetProcessById(statusEx.dwProcessId); + // fetch the process handle so that we can get the exit code later. + var _ = _remoteInvokeHandle.Process.SafeHandle; + } + catch (ArgumentException) + { } + } + + public TimeSpan WaitForStatusTimeout { get; set; } = TimeSpan.FromSeconds(30); + + public new void WaitForStatus(ServiceControllerStatus desiredStatus) => + WaitForStatus(desiredStatus, WaitForStatusTimeout); + + public new void WaitForStatus(ServiceControllerStatus desiredStatus, TimeSpan timeout) + { + base.WaitForStatus(desiredStatus, timeout); + + Assert.Equal(Status, desiredStatus); + } + + // the following overloads are necessary to ensure the compiler will produce the correct signature from a lambda. + public static WindowsServiceTester Create(Func serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + public static WindowsServiceTester Create(Func> serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + public static WindowsServiceTester Create(Func serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + public static WindowsServiceTester Create(Action serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + private static RemoteInvokeOptions remoteInvokeOptions = new RemoteInvokeOptions() { Start = false }; + + private static WindowsServiceTester Create(RemoteInvokeHandle remoteInvokeHandle, string serviceName) + { + // create remote executor commandline arguments + var startInfo = remoteInvokeHandle.Process.StartInfo; + string commandLine = startInfo.FileName + " " + startInfo.Arguments; + + // install the service + using (var serviceManagerHandle = new SafeServiceHandle(Interop.Advapi32.OpenSCManager(null, null, Interop.Advapi32.ServiceControllerOptions.SC_MANAGER_ALL))) + { + if (serviceManagerHandle.IsInvalid) + { + throw new InvalidOperationException(); + } + + // delete existing service if it exists + using (var existingServiceHandle = new SafeServiceHandle(Interop.Advapi32.OpenService(serviceManagerHandle, serviceName, Interop.Advapi32.ServiceAccessOptions.ACCESS_TYPE_ALL))) + { + if (!existingServiceHandle.IsInvalid) + { + Interop.Advapi32.DeleteService(existingServiceHandle); + } + } + + var serviceHandle = new SafeServiceHandle( + Interop.Advapi32.CreateService(serviceManagerHandle, + serviceName, + $"{nameof(WindowsServiceTester)} {serviceName} test service", + Interop.Advapi32.ServiceAccessOptions.ACCESS_TYPE_ALL, + Interop.Advapi32.ServiceTypeOptions.SERVICE_WIN32_OWN_PROCESS, + (int)ServiceStartMode.Manual, + Interop.Advapi32.ServiceStartErrorModes.ERROR_CONTROL_NORMAL, + commandLine, + loadOrderGroup: null, + pTagId: IntPtr.Zero, + dependencies: null, + servicesStartName: null, + password: null)); + + if (serviceHandle.IsInvalid) + { + throw new Win32Exception(); + } + + return new WindowsServiceTester(serviceHandle, remoteInvokeHandle, serviceName); + } + } + + internal unsafe Interop.Advapi32.SERVICE_STATUS QueryServiceStatus() + { + Interop.Advapi32.SERVICE_STATUS status = default; + bool success = Interop.Advapi32.QueryServiceStatus(_serviceHandle, &status); + if (!success) + { + throw new Win32Exception(); + } + return status; + } + + internal unsafe Interop.Advapi32.SERVICE_STATUS_PROCESS QueryServiceStatusEx() + { + Interop.Advapi32.SERVICE_STATUS_PROCESS status = default; + bool success = Interop.Advapi32.QueryServiceStatusEx(_serviceHandle, &status); + if (!success) + { + throw new Win32Exception(); + } + return status; + } + + protected override void Dispose(bool disposing) + { + if (_remoteInvokeHandle != null) + { + _remoteInvokeHandle.Dispose(); + } + + if (!_serviceHandle.IsInvalid) + { + // delete the temporary test service + Interop.Advapi32.DeleteService(_serviceHandle); + _serviceHandle.Close(); + } + } + } +} diff --git a/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs b/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs index c8dcaa55b8936..d0995486dcdeb 100644 --- a/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs +++ b/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs @@ -31,6 +31,7 @@ public class ServiceBase : Component private bool _commandPropsFrozen; // set to true once we've use the Can... properties. private bool _disposed; private bool _initialized; + private object _stopLock = new object(); private EventLog? _eventLog; /// @@ -505,27 +506,34 @@ private void DeferredSessionChange(int eventType, int sessionId) // This is a problem when multiple services are hosted in a single process. private unsafe void DeferredStop() { - fixed (SERVICE_STATUS* pStatus = &_status) + lock(_stopLock) { - int previousState = _status.currentState; - - _status.checkPoint = 0; - _status.waitHint = 0; - _status.currentState = ServiceControlStatus.STATE_STOP_PENDING; - SetServiceStatus(_statusHandle, pStatus); - try + // never call SetServiceStatus again after STATE_STOPPED is set. + if (_status.currentState != ServiceControlStatus.STATE_STOPPED) { - OnStop(); - WriteLogEntry(SR.StopSuccessful); - _status.currentState = ServiceControlStatus.STATE_STOPPED; - SetServiceStatus(_statusHandle, pStatus); - } - catch (Exception e) - { - _status.currentState = previousState; - SetServiceStatus(_statusHandle, pStatus); - WriteLogEntry(SR.Format(SR.StopFailed, e), EventLogEntryType.Error); - throw; + fixed (SERVICE_STATUS* pStatus = &_status) + { + int previousState = _status.currentState; + + _status.checkPoint = 0; + _status.waitHint = 0; + _status.currentState = ServiceControlStatus.STATE_STOP_PENDING; + SetServiceStatus(_statusHandle, pStatus); + try + { + OnStop(); + WriteLogEntry(SR.StopSuccessful); + _status.currentState = ServiceControlStatus.STATE_STOPPED; + SetServiceStatus(_statusHandle, pStatus); + } + catch (Exception e) + { + _status.currentState = previousState; + SetServiceStatus(_statusHandle, pStatus); + WriteLogEntry(SR.Format(SR.StopFailed, e), EventLogEntryType.Error); + throw; + } + } } } } @@ -537,14 +545,17 @@ private unsafe void DeferredShutdown() OnShutdown(); WriteLogEntry(SR.ShutdownOK); - if (_status.currentState == ServiceControlStatus.STATE_PAUSED || _status.currentState == ServiceControlStatus.STATE_RUNNING) + lock(_stopLock) { - fixed (SERVICE_STATUS* pStatus = &_status) + if (_status.currentState == ServiceControlStatus.STATE_PAUSED || _status.currentState == ServiceControlStatus.STATE_RUNNING) { - _status.checkPoint = 0; - _status.waitHint = 0; - _status.currentState = ServiceControlStatus.STATE_STOPPED; - SetServiceStatus(_statusHandle, pStatus); + fixed (SERVICE_STATUS* pStatus = &_status) + { + _status.checkPoint = 0; + _status.waitHint = 0; + _status.currentState = ServiceControlStatus.STATE_STOPPED; + SetServiceStatus(_statusHandle, pStatus); + } } } } @@ -672,7 +683,7 @@ private void Initialize(bool multipleServices) { if (!_initialized) { - //Cannot register the service with NT service manatger if the object has been disposed, since finalization has been suppressed. + //Cannot register the service with NT service manager if the object has been disposed, since finalization has been suppressed. if (_disposed) throw new ObjectDisposedException(GetType().Name); @@ -935,8 +946,14 @@ public unsafe void ServiceMainCallback(int argCount, IntPtr argPointer) if (!statusOK) { WriteLogEntry(SR.Format(SR.StartFailed, new Win32Exception().Message), EventLogEntryType.Error); - _status.currentState = ServiceControlStatus.STATE_STOPPED; - SetServiceStatus(_statusHandle, pStatus); + lock (_stopLock) + { + if (_status.currentState != ServiceControlStatus.STATE_STOPPED) + { + _status.currentState = ServiceControlStatus.STATE_STOPPED; + SetServiceStatus(_statusHandle, pStatus); + } + } } } }