[release/8.0] Fix support of FromKeyedServicesAttribute in ActivatorUtilities.CreateF...
authorgithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Sat, 16 Sep 2023 01:26:40 +0000 (18:26 -0700)
committerGitHub <noreply@github.com>
Sat, 16 Sep 2023 01:26:40 +0000 (18:26 -0700)
* Fix support of FromKeyedServicesAttribute in ActivatorUtilities.CreateFactory

* Addressing comment and adding a test

---------

Co-authored-by: Benjamin Petit <bpetit@microsoft.com>
src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs
src/libraries/Microsoft.Extensions.DependencyInjection.Specification.Tests/src/KeyedDependencyInjectionSpecificationTests.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs

index 6a42373..3ba3072 100644 (file)
@@ -36,7 +36,7 @@ namespace Microsoft.Extensions.DependencyInjection
 #endif
 
         private static readonly MethodInfo GetServiceInfo =
-            GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?>>((sp, t, r, c) => GetService(sp, t, r, c));
+            GetMethodInfo<Func<IServiceProvider, Type, Type, bool, object?, object?>>((sp, t, r, c, k) => GetService(sp, t, r, c, k));
 
         /// <summary>
         /// Instantiate a type with constructor arguments provided directly and/or from an <see cref="IServiceProvider"/>.
@@ -324,9 +324,9 @@ namespace Microsoft.Extensions.DependencyInjection
             return mc.Method;
         }
 
-        private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue)
+        private static object? GetService(IServiceProvider sp, Type type, Type requiredBy, bool hasDefaultValue, object? key)
         {
-            object? service = sp.GetService(type);
+            object? service = key == null ? sp.GetService(type) : GetKeyedService(sp, type, key);
             if (service is null && !hasDefaultValue)
             {
                 ThrowHelperUnableToResolveService(type, requiredBy);
@@ -361,10 +361,12 @@ namespace Microsoft.Extensions.DependencyInjection
                 }
                 else
                 {
+                    var keyAttribute = (FromKeyedServicesAttribute?) Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
                     var parameterTypeExpression = new Expression[] { serviceProvider,
                         Expression.Constant(parameterType, typeof(Type)),
                         Expression.Constant(constructor.DeclaringType, typeof(Type)),
-                        Expression.Constant(hasDefaultValue) };
+                        Expression.Constant(hasDefaultValue),
+                        Expression.Constant(keyAttribute?.Key) };
                     constructorArguments[i] = Expression.Call(GetServiceInfo, parameterTypeExpression);
                 }
 
@@ -435,10 +437,10 @@ namespace Microsoft.Extensions.DependencyInjection
             if (matchedArgCount == 0)
             {
                 // All injected; use a fast path.
-                Type[] types = GetParameterTypes();
+                FactoryParameterContext[] parameters = GetFactoryParameterContext();
                 return useFixedValues ?
-                    (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, types, declaringType, serviceProvider) :
-                    (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, types, declaringType, serviceProvider);
+                    (serviceProvider, arguments) => ReflectionFactoryServiceOnlyFixed(invoker, parameters, declaringType, serviceProvider) :
+                    (serviceProvider, arguments) => ReflectionFactoryServiceOnlySpan(invoker, parameters, declaringType, serviceProvider);
             }
 
             if (matchedArgCount == constructorParameters.Length)
@@ -456,16 +458,6 @@ namespace Microsoft.Extensions.DependencyInjection
                     (serviceProvider, arguments) => ReflectionFactoryCanonicalFixed(invoker, parameters, declaringType, serviceProvider, arguments) :
                     (serviceProvider, arguments) => ReflectionFactoryCanonicalSpan(invoker, parameters, declaringType, serviceProvider, arguments);
             }
-
-            Type[] GetParameterTypes()
-            {
-                Type[] types = new Type[constructorParameters.Length];
-                for (int i = 0; i < constructorParameters.Length; i++)
-                {
-                    types[i] = constructorParameters[i].ParameterType;
-                }
-                return types;
-            }
 #else
             ParameterInfo[] constructorParameters = constructor.GetParameters();
             if (constructorParameters.Length == 0)
@@ -484,8 +476,15 @@ namespace Microsoft.Extensions.DependencyInjection
                 for (int i = 0; i < constructorParameters.Length; i++)
                 {
                     ParameterInfo constructorParameter = constructorParameters[i];
+                    FromKeyedServicesAttribute? attr = (FromKeyedServicesAttribute?)
+                        Attribute.GetCustomAttribute(constructorParameter, typeof(FromKeyedServicesAttribute), inherit: false);
                     bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue);
-                    parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1);
+                    parameters[i] = new FactoryParameterContext(
+                        constructorParameter.ParameterType,
+                        hasDefaultValue,
+                        defaultValue,
+                        parameterMap[i] ?? -1,
+                        attr?.Key);
                 }
 
                 return parameters;
@@ -495,18 +494,20 @@ namespace Microsoft.Extensions.DependencyInjection
 
         private readonly struct FactoryParameterContext
         {
-            public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex)
+            public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex, object? serviceKey)
             {
                 ParameterType = parameterType;
                 HasDefaultValue = hasDefaultValue;
                 DefaultValue = defaultValue;
                 ArgumentIndex = argumentIndex;
+                ServiceKey = serviceKey;
             }
 
             public Type ParameterType { get; }
             public bool HasDefaultValue { get; }
             public object? DefaultValue { get; }
             public int ArgumentIndex { get; }
+            public object? ServiceKey { get; }
         }
 
         private static void FindApplicableConstructor(
@@ -825,39 +826,39 @@ namespace Microsoft.Extensions.DependencyInjection
 #if NET8_0_OR_GREATER // Use the faster ConstructorInvoker which also has alloc-free APIs when <= 4 parameters.
         private static object ReflectionFactoryServiceOnlyFixed(
             ConstructorInvoker invoker,
-            Type[] parameterTypes,
+            FactoryParameterContext[] parameters,
             Type declaringType,
             IServiceProvider serviceProvider)
         {
-            Debug.Assert(parameterTypes.Length >= 1 && parameterTypes.Length <= FixedArgumentThreshold);
+            Debug.Assert(parameters.Length >= 1 && parameters.Length <= FixedArgumentThreshold);
             Debug.Assert(FixedArgumentThreshold == 4);
 
             if (serviceProvider is null)
                 ThrowHelperArgumentNullExceptionServiceProvider();
 
-            switch (parameterTypes.Length)
+            switch (parameters.Length)
             {
                 case 1:
                     return invoker.Invoke(
-                        GetService(serviceProvider, parameterTypes[0], declaringType, false));
+                        GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey));
 
                 case 2:
                     return invoker.Invoke(
-                        GetService(serviceProvider, parameterTypes[0], declaringType, false),
-                        GetService(serviceProvider, parameterTypes[1], declaringType, false));
+                        GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
+                        GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey));
 
                 case 3:
                     return invoker.Invoke(
-                        GetService(serviceProvider, parameterTypes[0], declaringType, false),
-                        GetService(serviceProvider, parameterTypes[1], declaringType, false),
-                        GetService(serviceProvider, parameterTypes[2], declaringType, false));
+                        GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
+                        GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
+                        GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey));
 
                 case 4:
                     return invoker.Invoke(
-                        GetService(serviceProvider, parameterTypes[0], declaringType, false),
-                        GetService(serviceProvider, parameterTypes[1], declaringType, false),
-                        GetService(serviceProvider, parameterTypes[2], declaringType, false),
-                        GetService(serviceProvider, parameterTypes[3], declaringType, false));
+                        GetService(serviceProvider, parameters[0].ParameterType, declaringType, false, parameters[0].ServiceKey),
+                        GetService(serviceProvider, parameters[1].ParameterType, declaringType, false, parameters[1].ServiceKey),
+                        GetService(serviceProvider, parameters[2].ParameterType, declaringType, false, parameters[2].ServiceKey),
+                        GetService(serviceProvider, parameters[3].ParameterType, declaringType, false, parameters[3].ServiceKey));
             }
 
             return null!;
@@ -865,17 +866,17 @@ namespace Microsoft.Extensions.DependencyInjection
 
         private static object ReflectionFactoryServiceOnlySpan(
             ConstructorInvoker invoker,
-            Type[] parameterTypes,
+            FactoryParameterContext[] parameters,
             Type declaringType,
             IServiceProvider serviceProvider)
         {
             if (serviceProvider is null)
                 ThrowHelperArgumentNullExceptionServiceProvider();
 
-            object?[] arguments = new object?[parameterTypes.Length];
-            for (int i = 0; i < parameterTypes.Length; i++)
+            object?[] arguments = new object?[parameters.Length];
+            for (int i = 0; i < parameters.Length; i++)
             {
-                arguments[i] = GetService(serviceProvider, parameterTypes[i], declaringType, false);
+                arguments[i] = GetService(serviceProvider, parameters[i].ParameterType, declaringType, false, parameters[i].ServiceKey);
             }
 
             return invoker.Invoke(arguments.AsSpan());
@@ -907,7 +908,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                 serviceProvider,
                                 parameter1.ParameterType,
                                 declaringType,
-                                parameter1.HasDefaultValue)) ?? parameter1.DefaultValue);
+                                parameter1.HasDefaultValue,
+                                parameter1.ServiceKey)) ?? parameter1.DefaultValue);
                 case 2:
                     {
                         ref FactoryParameterContext parameter2 = ref parameters[1];
@@ -920,7 +922,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter1.ParameterType,
                                     declaringType,
-                                    parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
+                                    parameter1.HasDefaultValue,
+                                    parameter1.ServiceKey)) ?? parameter1.DefaultValue,
                              ((parameter2.ArgumentIndex != -1)
                                 // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
                                 ? arguments![parameter2.ArgumentIndex]
@@ -928,7 +931,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter2.ParameterType,
                                     declaringType,
-                                    parameter2.HasDefaultValue)) ?? parameter2.DefaultValue);
+                                    parameter2.HasDefaultValue,
+                                    parameter2.ServiceKey)) ?? parameter2.DefaultValue);
                     }
                 case 3:
                     {
@@ -943,7 +947,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter1.ParameterType,
                                     declaringType,
-                                    parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
+                                    parameter1.HasDefaultValue,
+                                    parameter1.ServiceKey)) ?? parameter1.DefaultValue,
                              ((parameter2.ArgumentIndex != -1)
                                 // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
                                 ? arguments![parameter2.ArgumentIndex]
@@ -951,7 +956,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter2.ParameterType,
                                     declaringType,
-                                    parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
+                                    parameter2.HasDefaultValue,
+                                    parameter2.ServiceKey)) ?? parameter2.DefaultValue,
                              ((parameter3.ArgumentIndex != -1)
                                 // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
                                 ? arguments![parameter3.ArgumentIndex]
@@ -959,7 +965,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter3.ParameterType,
                                     declaringType,
-                                    parameter3.HasDefaultValue)) ?? parameter3.DefaultValue);
+                                    parameter3.HasDefaultValue,
+                                    parameter3.ServiceKey)) ?? parameter3.DefaultValue);
                     }
                 case 4:
                     {
@@ -975,7 +982,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter1.ParameterType,
                                     declaringType,
-                                    parameter1.HasDefaultValue)) ?? parameter1.DefaultValue,
+                                    parameter1.HasDefaultValue,
+                                    parameter1.ServiceKey)) ?? parameter1.DefaultValue,
                              ((parameter2.ArgumentIndex != -1)
                                 // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
                                 ? arguments![parameter2.ArgumentIndex]
@@ -983,7 +991,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter2.ParameterType,
                                     declaringType,
-                                    parameter2.HasDefaultValue)) ?? parameter2.DefaultValue,
+                                    parameter2.HasDefaultValue,
+                                    parameter2.ServiceKey)) ?? parameter2.DefaultValue,
                              ((parameter3.ArgumentIndex != -1)
                                 // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
                                 ? arguments![parameter3.ArgumentIndex]
@@ -991,7 +1000,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter3.ParameterType,
                                     declaringType,
-                                    parameter3.HasDefaultValue)) ?? parameter3.DefaultValue,
+                                    parameter3.HasDefaultValue,
+                                    parameter3.ServiceKey)) ?? parameter3.DefaultValue,
                              ((parameter4.ArgumentIndex != -1)
                                 // Throws a NullReferenceException if arguments is null. Consistent with expression-based factory.
                                 ? arguments![parameter4.ArgumentIndex]
@@ -999,7 +1009,8 @@ namespace Microsoft.Extensions.DependencyInjection
                                     serviceProvider,
                                     parameter4.ParameterType,
                                     declaringType,
-                                    parameter4.HasDefaultValue)) ?? parameter4.DefaultValue);
+                                    parameter4.HasDefaultValue,
+                                    parameter4.ServiceKey)) ?? parameter4.DefaultValue);
                     }
 
             }
@@ -1028,7 +1039,8 @@ namespace Microsoft.Extensions.DependencyInjection
                         serviceProvider,
                         parameter.ParameterType,
                         declaringType,
-                        parameter.HasDefaultValue)) ?? parameter.DefaultValue;
+                        parameter.HasDefaultValue,
+                        parameter.ServiceKey)) ?? parameter.DefaultValue;
             }
 
             return invoker.Invoke(constructorArguments.AsSpan());
@@ -1078,7 +1090,8 @@ namespace Microsoft.Extensions.DependencyInjection
                         serviceProvider,
                         parameter.ParameterType,
                         declaringType,
-                        parameter.HasDefaultValue)) ?? parameter.DefaultValue;
+                        parameter.HasDefaultValue,
+                        parameter.ServiceKey)) ?? parameter.DefaultValue;
             }
 
             return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null);
@@ -1099,5 +1112,17 @@ namespace Microsoft.Extensions.DependencyInjection
             }
         }
 #endif
+
+        private static object? GetKeyedService(IServiceProvider provider, Type type, object? serviceKey)
+        {
+            ThrowHelper.ThrowIfNull(provider);
+
+            if (provider is IKeyedServiceProvider keyedServiceProvider)
+            {
+                return keyedServiceProvider.GetKeyedService(type, serviceKey);
+            }
+
+            throw new InvalidOperationException(SR.KeyedServicesNotSupported);
+        }
     }
 }
index 32112c3..a0dc73d 100644 (file)
@@ -476,5 +476,41 @@ namespace Microsoft.Extensions.DependencyInjection.Specification
 
             public IServiceProvider ServiceProvider { get; }
         }
+
+            [Fact]
+            public void SimpleServiceKeyedResolution()
+            {
+                // Arrange
+                var services = new ServiceCollection();
+                services.AddKeyedTransient<ISimpleService, SimpleService>("simple");
+                services.AddKeyedTransient<ISimpleService, AnotherSimpleService>("another");
+                services.AddTransient<SimpleParentWithDynamicKeyedService>();
+                var provider = CreateServiceProvider(services);
+                var sut = provider.GetService<SimpleParentWithDynamicKeyedService>();
+
+                // Act
+                var result = sut!.GetService("simple");
+
+                // Assert
+                Assert.True(result.GetType() == typeof(SimpleService));
+            }
+
+        public class SimpleParentWithDynamicKeyedService
+        {
+            private readonly IServiceProvider _serviceProvider;
+
+            public SimpleParentWithDynamicKeyedService(IServiceProvider serviceProvider)
+            {
+                _serviceProvider = serviceProvider;
+            }
+
+            public ISimpleService GetService(string name) => _serviceProvider.GetKeyedService<ISimpleService>(name)!;
+        }
+
+        public interface ISimpleService { }
+
+        public class SimpleService : ISimpleService { }
+
+        public class AnotherSimpleService : ISimpleService { }
     }
 }
index dda3caf..f6e7c2f 100644 (file)
@@ -245,6 +245,100 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
 #if NETCOREAPP
         [InlineData(false)]
 #endif
+        public void CreateFactory_CreatesFactoryMethod_KeyedParams(bool useDynamicCode)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!useDynamicCode)
+            {
+                DisableDynamicCode(options);
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory = ActivatorUtilities.CreateFactory<ClassWithAKeyedBKeyedC>(Type.EmptyTypes);
+
+                var services = new ServiceCollection();
+                services.AddSingleton(new A());
+                services.AddKeyedSingleton("b", new B());
+                services.AddKeyedSingleton("c", new C());
+                using var provider = services.BuildServiceProvider();
+                ClassWithAKeyedBKeyedC item = factory(provider, null);
+
+                Assert.IsType<ObjectFactory<ClassWithAKeyedBKeyedC>>(factory);
+                Assert.NotNull(item.A);
+                Assert.NotNull(item.B);
+                Assert.NotNull(item.C);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_CreatesFactoryMethod_KeyedParams_5Types(bool useDynamicCode)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!useDynamicCode)
+            {
+                DisableDynamicCode(options);
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory = ActivatorUtilities.CreateFactory<ClassWithAKeyedBKeyedCSZ>(Type.EmptyTypes);
+
+                var services = new ServiceCollection();
+                services.AddSingleton(new A());
+                services.AddKeyedSingleton("b", new B());
+                services.AddKeyedSingleton("c", new C());
+                services.AddSingleton(new S());
+                services.AddSingleton(new Z());
+                using var provider = services.BuildServiceProvider();
+                ClassWithAKeyedBKeyedCSZ item = factory(provider, null);
+
+                Assert.IsType<ObjectFactory<ClassWithAKeyedBKeyedCSZ>>(factory);
+                Assert.NotNull(item.A);
+                Assert.NotNull(item.B);
+                Assert.NotNull(item.C);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_CreatesFactoryMethod_KeyedParams_1Injected(bool useDynamicCode)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!useDynamicCode)
+            {
+                DisableDynamicCode(options);
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory = ActivatorUtilities.CreateFactory<ClassWithAKeyedBKeyedC>(new Type[] { typeof(A) });
+
+                var services = new ServiceCollection();
+                services.AddKeyedSingleton("b", new B());
+                services.AddKeyedSingleton("c", new C());
+                using var provider = services.BuildServiceProvider();
+                ClassWithAKeyedBKeyedC item = factory(provider, new object?[] { new A() });
+
+                Assert.IsType<ObjectFactory<ClassWithAKeyedBKeyedC>>(factory);
+                Assert.NotNull(item.A);
+                Assert.NotNull(item.B);
+                Assert.NotNull(item.C);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
         public void CreateFactory_RemoteExecutor_CreatesFactoryMethod(bool useDynamicCode)
         {
             var options = new RemoteInvokeOptions();
@@ -527,6 +621,13 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
     internal class S { }
     internal class Z { }
 
+    internal class ClassWithAKeyedBKeyedC : ClassWithABC
+    {
+        public ClassWithAKeyedBKeyedC(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c)
+            : base(a, b, c)
+        { }
+    }
+
     internal class ClassWithABCS : ClassWithABC
     {
         public S S { get; }
@@ -540,6 +641,13 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
         public ClassWithABCSZ(A a, B b, C c, S s, Z z) : base(a, b, c, s) { Z = z; }
     }
 
+    internal class ClassWithAKeyedBKeyedCSZ : ClassWithABCSZ
+    {
+        public ClassWithAKeyedBKeyedCSZ(A a, [FromKeyedServices("b")] B b, [FromKeyedServices("c")] C c, S s, Z z)
+            : base(a, b, c, s, z)
+        { }
+    }
+
     internal class ClassWithABC_FirstConstructorWithAttribute : ClassWithABC
     {
         [ActivatorUtilitiesConstructor]