Add reflection path for ActivatorUtilities.CreateFactory (#81262)
authorJames Newton-King <james@newtonking.com>
Mon, 30 Jan 2023 08:57:02 +0000 (16:57 +0800)
committerGitHub <noreply@github.com>
Mon, 30 Jan 2023 08:57:02 +0000 (16:57 +0800)
Co-authored-by: Stephen Toub <stoub@microsoft.com>
Co-authored-by: Jan Kotas <jkotas@microsoft.com>
Fixes https://github.com/dotnet/runtime/issues/81258

src/libraries/Microsoft.Extensions.DependencyInjection.Abstractions/src/ActivatorUtilities.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ActivatorUtilitiesTests.cs

index 45b1cdc..5a42e12 100644 (file)
@@ -6,6 +6,7 @@ using System.Diagnostics;
 using System.Diagnostics.CodeAnalysis;
 using System.Linq.Expressions;
 using System.Reflection;
+using System.Runtime.CompilerServices;
 using System.Runtime.ExceptionServices;
 using Microsoft.Extensions.Internal;
 
@@ -127,6 +128,15 @@ namespace Microsoft.Extensions.DependencyInjection
             [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType,
             Type[] argumentTypes)
         {
+#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP
+            if (!RuntimeFeature.IsDynamicCodeSupported)
+            {
+                // Create a reflection-based factory when dynamic code isn't supported, e.g. app is published with NativeAOT.
+                // Reflection-based factory is faster than interpreted expressions and doesn't pull in System.Linq.Expressions dependency.
+                return CreateFactoryReflection(instanceType, argumentTypes);
+            }
+#endif
+
             CreateFactoryInternal(instanceType, argumentTypes, out ParameterExpression provider, out ParameterExpression argumentArray, out Expression factoryExpressionBody);
 
             var factoryLambda = Expression.Lambda<Func<IServiceProvider, object?[]?, object>>(
@@ -152,6 +162,16 @@ namespace Microsoft.Extensions.DependencyInjection
             CreateFactory<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] T>(
                 Type[] argumentTypes)
         {
+#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP
+            if (!RuntimeFeature.IsDynamicCodeSupported)
+            {
+                // Create a reflection-based factory when dynamic code isn't supported, e.g. app is published with NativeAOT.
+                // Reflection-based factory is faster than interpreted expressions and doesn't pull in System.Linq.Expressions dependency.
+                var factory = CreateFactoryReflection(typeof(T), argumentTypes);
+                return (serviceProvider, arguments) => (T)factory(serviceProvider, arguments);
+            }
+#endif
+
             CreateFactoryInternal(typeof(T), argumentTypes, out ParameterExpression provider, out ParameterExpression argumentArray, out Expression factoryExpressionBody);
 
             var factoryLambda = Expression.Lambda<Func<IServiceProvider, object?[]?, T>>(
@@ -264,6 +284,67 @@ namespace Microsoft.Extensions.DependencyInjection
             return Expression.New(constructor, constructorArguments);
         }
 
+#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP
+        private static ObjectFactory CreateFactoryReflection(
+            [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType,
+            Type?[] argumentTypes)
+        {
+            FindApplicableConstructor(instanceType, argumentTypes, out ConstructorInfo constructor, out int?[] parameterMap);
+
+            ParameterInfo[] constructorParameters = constructor.GetParameters();
+            if (constructorParameters.Length == 0)
+            {
+                return (IServiceProvider serviceProvider, object?[]? arguments) =>
+                    constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, parameters: null, culture: null);
+            }
+
+            FactoryParameterContext[] parameters = new FactoryParameterContext[constructorParameters.Length];
+            for (int i = 0; i < constructorParameters.Length; i++)
+            {
+                ParameterInfo constructorParameter = constructorParameters[i];
+                bool hasDefaultValue = ParameterDefaultValue.TryGetDefaultValue(constructorParameter, out object? defaultValue);
+
+                parameters[i] = new FactoryParameterContext(constructorParameter.ParameterType, hasDefaultValue, defaultValue, parameterMap[i] ?? -1);
+            }
+            Type declaringType = constructor.DeclaringType!;
+
+            return (IServiceProvider serviceProvider, object?[]? arguments) =>
+            {
+                object?[] constructorArguments = new object?[parameters.Length];
+                for (int i = 0; i < parameters.Length; i++)
+                {
+                    ref FactoryParameterContext parameter = ref parameters[i];
+                    constructorArguments[i] = ((parameter.ArgumentIndex != -1)
+                        // Throws an NullReferenceException if arguments is null. Consistent with expression-based factory.
+                        ? arguments![parameter.ArgumentIndex]
+                        : GetService(
+                            serviceProvider,
+                            parameter.ParameterType,
+                            declaringType,
+                            parameter.HasDefaultValue)) ?? parameter.DefaultValue;
+                }
+
+                return constructor.Invoke(BindingFlags.DoNotWrapExceptions, binder: null, constructorArguments, culture: null);
+            };
+        }
+
+        private readonly struct FactoryParameterContext
+        {
+            public FactoryParameterContext(Type parameterType, bool hasDefaultValue, object? defaultValue, int argumentIndex)
+            {
+                ParameterType = parameterType;
+                HasDefaultValue = hasDefaultValue;
+                DefaultValue = defaultValue;
+                ArgumentIndex = argumentIndex;
+            }
+
+            public Type ParameterType { get; }
+            public bool HasDefaultValue { get; }
+            public object? DefaultValue { get; }
+            public int ArgumentIndex { get; }
+        }
+#endif
+
         private static void FindApplicableConstructor(
             [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type instanceType,
             Type?[] argumentTypes,
index 47af6f2..f867f01 100644 (file)
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using Microsoft.DotNet.RemoteExecutor;
 using Xunit;
 using static Microsoft.Extensions.DependencyInjection.Tests.AsyncServiceScopeTests;
 
@@ -191,6 +192,159 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
             Assert.IsType<ObjectFactory<ClassWithABCS>>(factory2);
             Assert.IsType<ClassWithABCS>(item2);
         }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_RemoteExecutor_CreatesFactoryMethod(bool isDynamicCodeSupported)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!isDynamicCodeSupported)
+            {
+                options.RuntimeConfigurationOptions.Add("System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported", "false");
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory1 = ActivatorUtilities.CreateFactory(typeof(ClassWithABCS), new Type[] { typeof(B) });
+                var factory2 = ActivatorUtilities.CreateFactory<ClassWithABCS>(new Type[] { typeof(B) });
+
+                var services = new ServiceCollection();
+                services.AddSingleton(new A());
+                services.AddSingleton(new C());
+                services.AddSingleton(new S());
+                using var provider = services.BuildServiceProvider();
+                object item1 = factory1(provider, new[] { new B() });
+                var item2 = factory2(provider, new[] { new B() });
+
+                Assert.IsType<ObjectFactory>(factory1);
+                Assert.IsType<ClassWithABCS>(item1);
+
+                Assert.IsType<ObjectFactory<ClassWithABCS>>(factory2);
+                Assert.IsType<ClassWithABCS>(item2);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_RemoteExecutor_NullArguments_Throws(bool isDynamicCodeSupported)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!isDynamicCodeSupported)
+            {
+                options.RuntimeConfigurationOptions.Add("System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported", "false");
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory1 = ActivatorUtilities.CreateFactory(typeof(ClassWithA), new Type[] { typeof(A) });
+
+                var services = new ServiceCollection();
+                using var provider = services.BuildServiceProvider();
+                Assert.Throws<NullReferenceException>(() => factory1(provider, null));
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_RemoteExecutor_NoArguments_UseNullDefaultValue(bool isDynamicCodeSupported)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!isDynamicCodeSupported)
+            {
+                options.RuntimeConfigurationOptions.Add("System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported", "false");
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory1 = ActivatorUtilities.CreateFactory(typeof(ClassWithADefaultValue), new Type[0]);
+
+                var services = new ServiceCollection();
+                using var provider = services.BuildServiceProvider();
+                var item = (ClassWithADefaultValue)factory1(provider, null);
+                Assert.Null(item.A);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_RemoteExecutor_NoArguments_ThrowRequiredValue(bool isDynamicCodeSupported)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!isDynamicCodeSupported)
+            {
+                options.RuntimeConfigurationOptions.Add("System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported", "false");
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory1 = ActivatorUtilities.CreateFactory(typeof(ClassWithA), new Type[0]);
+
+                var services = new ServiceCollection();
+                using var provider = services.BuildServiceProvider();
+                var ex = Assert.Throws<InvalidOperationException>(() => factory1(provider, null));
+                Assert.Equal($"Unable to resolve service for type '{typeof(A).FullName}' while attempting to activate '{typeof(ClassWithA).FullName}'.", ex.Message);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_RemoteExecutor_NullArgument_UseDefaultValue(bool isDynamicCodeSupported)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!isDynamicCodeSupported)
+            {
+                options.RuntimeConfigurationOptions.Add("System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported", "false");
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory1 = ActivatorUtilities.CreateFactory(typeof(ClassWithStringDefaultValue), new[] { typeof(string) });
+
+                var services = new ServiceCollection();
+                using var provider = services.BuildServiceProvider();
+                var item = (ClassWithStringDefaultValue)factory1(provider, new object[] { null });
+                Assert.Equal("DEFAULT", item.Text);
+            }, options);
+        }
+
+        [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
+        [InlineData(true)]
+#if NETCOREAPP
+        [InlineData(false)]
+#endif
+        public void CreateFactory_RemoteExecutor_NoParameters_Success(bool isDynamicCodeSupported)
+        {
+            var options = new RemoteInvokeOptions();
+            if (!isDynamicCodeSupported)
+            {
+                options.RuntimeConfigurationOptions.Add("System.Runtime.CompilerServices.RuntimeFeature.IsDynamicCodeSupported", "false");
+            }
+
+            using var remoteHandle = RemoteExecutor.Invoke(static () =>
+            {
+                var factory1 = ActivatorUtilities.CreateFactory(typeof(A), new Type[0]);
+
+                var services = new ServiceCollection();
+                using var provider = services.BuildServiceProvider();
+                var item = (A)factory1(provider, null);
+                Assert.NotNull(item);
+            }, options);
+        }
     }
 
     internal class A { }
@@ -265,6 +419,15 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
         }
     }
 
+    internal class ClassWithADefaultValue
+    {
+        public A A { get; }
+        public ClassWithADefaultValue(A a = null)
+        {
+            A = a;
+        }
+    }
+
     internal class ABCS
     {
         public A A { get; }
@@ -354,4 +517,13 @@ namespace Microsoft.Extensions.DependencyInjection.Tests
         public ClassWithABC_DefaultConstructorLast(A a) : base(a) { }
         public ClassWithABC_DefaultConstructorLast() : base() { }
     }
+
+    internal class ClassWithStringDefaultValue
+    {
+        public string Text { get; set; }
+        public ClassWithStringDefaultValue(string text = "DEFAULT")
+        {
+            Text = text;
+        }
+    }
 }