Implement IKeyedServiceProvider interface on ServiceProviderEngineScope (#89509)
authorNatalia Kondratyeva <knatalia@microsoft.com>
Fri, 28 Jul 2023 16:53:51 +0000 (17:53 +0100)
committerGitHub <noreply@github.com>
Fri, 28 Jul 2023 16:53:51 +0000 (18:53 +0200)
* Implement IKeyedServiceProvider interface

* Add more tests

src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs
src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/ServiceProviderEngineScope.cs
src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceProvider.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderEngineScopeTests.cs

index 3aa3282..32112c3 100644 (file)
@@ -324,6 +324,114 @@ namespace Microsoft.Extensions.DependencyInjection.Specification
             Assert.NotSame(first, second);
         }
 
+        [Fact]
+        public void ResolveKeyedSingletonFromInjectedServiceProvider()
+        {
+            var serviceCollection = new ServiceCollection();
+            serviceCollection.AddKeyedSingleton<IService, Service>("key");
+            serviceCollection.AddSingleton<ServiceProviderAccessor>();
+
+            var provider = CreateServiceProvider(serviceCollection);
+            var accessor = provider.GetRequiredService<ServiceProviderAccessor>();
+
+            Assert.Null(accessor.ServiceProvider.GetService<IService>());
+
+            var service1 = accessor.ServiceProvider.GetKeyedService<IService>("key");
+            var service2 = accessor.ServiceProvider.GetKeyedService<IService>("key");
+
+            Assert.Same(service1, service2);
+        }
+
+        [Fact]
+        public void ResolveKeyedTransientFromInjectedServiceProvider()
+        {
+            var serviceCollection = new ServiceCollection();
+            serviceCollection.AddKeyedTransient<IService, Service>("key");
+            serviceCollection.AddSingleton<ServiceProviderAccessor>();
+
+            var provider = CreateServiceProvider(serviceCollection);
+            var accessor = provider.GetRequiredService<ServiceProviderAccessor>();
+
+            Assert.Null(accessor.ServiceProvider.GetService<IService>());
+
+            var service1 = accessor.ServiceProvider.GetKeyedService<IService>("key");
+            var service2 = accessor.ServiceProvider.GetKeyedService<IService>("key");
+
+            Assert.NotSame(service1, service2);
+        }
+
+        [Fact]
+        public void ResolveKeyedSingletonFromScopeServiceProvider()
+        {
+            var serviceCollection = new ServiceCollection();
+            serviceCollection.AddKeyedSingleton<IService, Service>("key");
+
+            var provider = CreateServiceProvider(serviceCollection);
+            var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
+            var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
+
+            Assert.Null(scopeA.ServiceProvider.GetService<IService>());
+            Assert.Null(scopeB.ServiceProvider.GetService<IService>());
+
+            var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
+            var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
+
+            var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
+            var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
+
+            Assert.Same(serviceA1, serviceA2);
+            Assert.Same(serviceB1, serviceB2);
+            Assert.Same(serviceA1, serviceB1);
+        }
+
+        [Fact]
+        public void ResolveKeyedScopedFromScopeServiceProvider()
+        {
+            var serviceCollection = new ServiceCollection();
+            serviceCollection.AddKeyedScoped<IService, Service>("key");
+
+            var provider = CreateServiceProvider(serviceCollection);
+            var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
+            var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
+
+            Assert.Null(scopeA.ServiceProvider.GetService<IService>());
+            Assert.Null(scopeB.ServiceProvider.GetService<IService>());
+
+            var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
+            var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
+
+            var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
+            var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
+
+            Assert.Same(serviceA1, serviceA2);
+            Assert.Same(serviceB1, serviceB2);
+            Assert.NotSame(serviceA1, serviceB1);
+        }
+
+        [Fact]
+        public void ResolveKeyedTransientFromScopeServiceProvider()
+        {
+            var serviceCollection = new ServiceCollection();
+            serviceCollection.AddKeyedTransient<IService, Service>("key");
+
+            var provider = CreateServiceProvider(serviceCollection);
+            var scopeA = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
+            var scopeB = provider.GetRequiredService<IServiceScopeFactory>().CreateScope();
+
+            Assert.Null(scopeA.ServiceProvider.GetService<IService>());
+            Assert.Null(scopeB.ServiceProvider.GetService<IService>());
+
+            var serviceA1 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
+            var serviceA2 = scopeA.ServiceProvider.GetKeyedService<IService>("key");
+
+            var serviceB1 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
+            var serviceB2 = scopeB.ServiceProvider.GetKeyedService<IService>("key");
+
+            Assert.NotSame(serviceA1, serviceA2);
+            Assert.NotSame(serviceB1, serviceB2);
+            Assert.NotSame(serviceA1, serviceB1);
+        }
+
         internal interface IService { }
 
         internal class Service : IService
@@ -358,5 +466,15 @@ namespace Microsoft.Extensions.DependencyInjection.Specification
 
             public ServiceWithIntKey([ServiceKey] int id) => _id = id;
         }
+
+        internal class ServiceProviderAccessor
+        {
+            public ServiceProviderAccessor(IServiceProvider serviceProvider)
+            {
+                ServiceProvider = serviceProvider;
+            }
+
+            public IServiceProvider ServiceProvider { get; }
+        }
     }
 }
index 5ef37ba..7582ed8 100644 (file)
@@ -12,7 +12,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
 {
     [DebuggerDisplay("{DebuggerToString(),nq}")]
     [DebuggerTypeProxy(typeof(ServiceProviderEngineScopeDebugView))]
-    internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IAsyncDisposable, IServiceScopeFactory
+    internal sealed class ServiceProviderEngineScope : IServiceScope, IServiceProvider, IKeyedServiceProvider, IAsyncDisposable, IServiceScopeFactory
     {
         // For testing and debugging only
         internal IList<object> Disposables => _disposables ?? (IList<object>)Array.Empty<object>();
@@ -50,6 +50,26 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
             return RootProvider.GetService(ServiceIdentifier.FromServiceType(serviceType), this);
         }
 
+        public object? GetKeyedService(Type serviceType, object? serviceKey)
+        {
+            if (_disposed)
+            {
+                ThrowHelper.ThrowObjectDisposedException();
+            }
+
+            return RootProvider.GetKeyedService(serviceType, serviceKey, this);
+        }
+
+        public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
+        {
+            if (_disposed)
+            {
+                ThrowHelper.ThrowObjectDisposedException();
+            }
+
+            return RootProvider.GetRequiredKeyedService(serviceType, serviceKey, this);
+        }
+
         public IServiceProvider ServiceProvider => this;
 
         public IServiceScope CreateScope() => RootProvider.CreateScope();
index 6705f7d..6fe894d 100644 (file)
@@ -98,11 +98,17 @@ namespace Microsoft.Extensions.DependencyInjection
         public object? GetService(Type serviceType) => GetService(ServiceIdentifier.FromServiceType(serviceType), Root);
 
         public object? GetKeyedService(Type serviceType, object? serviceKey)
-            => GetService(new ServiceIdentifier(serviceKey, serviceType), Root);
+            => GetKeyedService(serviceType, serviceKey, Root);
+
+        internal object? GetKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
+            => GetService(new ServiceIdentifier(serviceKey, serviceType), serviceProviderEngineScope);
 
         public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
+            => GetRequiredKeyedService(serviceType, serviceKey, Root);
+
+        internal object GetRequiredKeyedService(Type serviceType, object? serviceKey, ServiceProviderEngineScope serviceProviderEngineScope)
         {
-            object? service = GetKeyedService(serviceType, serviceKey);
+            object? service = GetKeyedService(serviceType, serviceKey, serviceProviderEngineScope);
             if (service == null)
             {
                 throw new InvalidOperationException(SR.Format(SR.NoServiceRegistered, serviceType));
index f3de857..e25174c 100644 (file)
@@ -2,8 +2,10 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Collections.Generic;
 using Microsoft.Extensions.DependencyInjection.Specification.Fakes;
 using Xunit;
+using Xunit.Abstractions;
 
 namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
 {
@@ -29,5 +31,15 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
 
             Assert.Throws<ObjectDisposedException>(() => sp.GetRequiredService<IServiceProvider>());
         }
+
+        [Fact]
+        public void ServiceProviderEngineScope_ImplementsAllServiceProviderInterfaces()
+        {
+            var engineScopeInterfaces = typeof(ServiceProviderEngineScope).GetInterfaces();
+            foreach (var serviceProviderInterface in typeof(ServiceProvider).GetInterfaces())
+            {
+                Assert.Contains(serviceProviderInterface, engineScopeInterfaces);
+            }
+        }
     }
 }