Fix service accessor scope validation for the emit-based version
authorPavel Ivanov <ivanovpavelalex45@gmail.com>
Thu, 20 Jul 2023 13:57:55 +0000 (18:57 +0500)
committerGitHub <noreply@github.com>
Thu, 20 Jul 2023 13:57:55 +0000 (08:57 -0500)
src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs

index 613305c..6705f7d 100644 (file)
@@ -21,14 +21,14 @@ namespace Microsoft.Extensions.DependencyInjection
     {
         private readonly CallSiteValidator? _callSiteValidator;
 
-        private readonly Func<ServiceIdentifier, Func<ServiceProviderEngineScope, object?>> _createServiceAccessor;
+        private readonly Func<ServiceIdentifier, ServiceAccessor> _createServiceAccessor;
 
         // Internal for testing
         internal ServiceProviderEngine _engine;
 
         private bool _disposed;
 
-        private readonly ConcurrentDictionary<ServiceIdentifier, Func<ServiceProviderEngineScope, object?>> _realizedServices;
+        private readonly ConcurrentDictionary<ServiceIdentifier, ServiceAccessor> _serviceAccessors;
 
         internal CallSiteFactory CallSiteFactory { get; }
 
@@ -50,7 +50,7 @@ namespace Microsoft.Extensions.DependencyInjection
             Root = new ServiceProviderEngineScope(this, isRootScope: true);
             _engine = GetEngine();
             _createServiceAccessor = CreateServiceAccessor;
-            _realizedServices = new ConcurrentDictionary<ServiceIdentifier, Func<ServiceProviderEngineScope, object?>>();
+            _serviceAccessors = new ConcurrentDictionary<ServiceIdentifier, ServiceAccessor>();
 
             CallSiteFactory = new CallSiteFactory(serviceDescriptors);
             // The list of built in services that aren't part of the list of service descriptors
@@ -137,9 +137,12 @@ namespace Microsoft.Extensions.DependencyInjection
             _callSiteValidator?.ValidateCallSite(callSite);
         }
 
-        private void OnResolve(ServiceCallSite callSite, IServiceScope scope)
+        private void OnResolve(ServiceCallSite? callSite, IServiceScope scope)
         {
-            _callSiteValidator?.ValidateResolution(callSite, scope, Root);
+            if (callSite != null)
+            {
+                _callSiteValidator?.ValidateResolution(callSite, scope, Root);
+            }
         }
 
         internal object? GetService(ServiceIdentifier serviceIdentifier, ServiceProviderEngineScope serviceProviderEngineScope)
@@ -148,9 +151,10 @@ namespace Microsoft.Extensions.DependencyInjection
             {
                 ThrowHelper.ThrowObjectDisposedException();
             }
-
-            Func<ServiceProviderEngineScope, object?> realizedService = _realizedServices.GetOrAdd(serviceIdentifier, _createServiceAccessor);
-            var result = realizedService.Invoke(serviceProviderEngineScope);
+            ServiceAccessor serviceAccessor = _serviceAccessors.GetOrAdd(serviceIdentifier, _createServiceAccessor);
+            OnResolve(serviceAccessor.CallSite, serviceProviderEngineScope);
+            DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType);
+            object? result = serviceAccessor.RealizedService?.Invoke(serviceProviderEngineScope);
             System.Diagnostics.Debug.Assert(result is null || CallSiteFactory.IsService(serviceIdentifier));
             return result;
         }
@@ -176,7 +180,7 @@ namespace Microsoft.Extensions.DependencyInjection
             }
         }
 
-        private Func<ServiceProviderEngineScope, object?> CreateServiceAccessor(ServiceIdentifier serviceIdentifier)
+        private ServiceAccessor CreateServiceAccessor(ServiceIdentifier serviceIdentifier)
         {
             ServiceCallSite? callSite = CallSiteFactory.GetCallSite(serviceIdentifier, new CallSiteChain());
             if (callSite != null)
@@ -188,28 +192,22 @@ namespace Microsoft.Extensions.DependencyInjection
                 if (callSite.Cache.Location == CallSiteResultCacheLocation.Root)
                 {
                     object? value = CallSiteRuntimeResolver.Instance.Resolve(callSite, Root);
-                    return scope =>
-                    {
-                        DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType);
-                        return value;
-                    };
+                    return new ServiceAccessor { CallSite = callSite, RealizedService = scope => value };
                 }
 
                 Func<ServiceProviderEngineScope, object?> realizedService = _engine.RealizeService(callSite);
-                return scope =>
-                {
-                    OnResolve(callSite, scope);
-                    DependencyInjectionEventSource.Log.ServiceResolved(this, serviceIdentifier.ServiceType);
-                    return realizedService(scope);
-                };
+                return new ServiceAccessor { CallSite = callSite, RealizedService = realizedService };
             }
-
-            return _ => null;
+            return new ServiceAccessor { CallSite = callSite, RealizedService = _ => null };
         }
 
         internal void ReplaceServiceAccessor(ServiceCallSite callSite, Func<ServiceProviderEngineScope, object?> accessor)
         {
-            _realizedServices[new ServiceIdentifier(callSite.Key, callSite.ServiceType)] = accessor;
+            _serviceAccessors[new ServiceIdentifier(callSite.Key, callSite.ServiceType)] = new ServiceAccessor
+            {
+                CallSite = callSite,
+                RealizedService = accessor
+            };
         }
 
         internal IServiceScope CreateScope()
@@ -262,5 +260,11 @@ namespace Microsoft.Extensions.DependencyInjection
             public bool Disposed => _serviceProvider.Root.Disposed;
             public bool IsScope => !_serviceProvider.Root.IsRootScope;
         }
+
+        private sealed class ServiceAccessor
+        {
+            public ServiceCallSite? CallSite { get; set; }
+            public Func<ServiceProviderEngineScope, object?>? RealizedService { get; set; }
+        }
     }
 }
index c6f1834..8780312 100644 (file)
@@ -4,6 +4,7 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using System.Threading.Tasks;
 using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
 using Xunit;
 
@@ -86,6 +87,30 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
         }
 
         [Fact]
+        public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
+        {
+            // Arrange
+            var serviceCollection = new ServiceCollection();
+            serviceCollection.AddScoped<IBar, Bar>();
+            var serviceProvider = serviceCollection.BuildServiceProvider(validateScopes: true);
+
+            // Act + Assert
+            using (var scope = serviceProvider.CreateScope())
+            {
+                // Switch to an emit-based version which is triggered in the background after 2 calls to GetService.
+                scope.ServiceProvider.GetRequiredService(typeof(IBar));
+                scope.ServiceProvider.GetRequiredService(typeof(IBar));
+
+                // Give the background thread time to generate the emit version.
+                await Task.Delay(100);
+
+                // Ensure the emit-based version has the correct scope checks.
+                var exception = Assert.Throws<InvalidOperationException>(serviceProvider.GetRequiredService<IBar>);
+                Assert.Equal($"Cannot resolve scoped service '{typeof(IBar)}' from root provider.", exception.Message);
+            }
+        }
+
+        [Fact]
         public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRootViaTransient()
         {
             // Arrange