Adding support for constrained open generics to DI (#34393)
authorJimmy Bogard <jimmy.bogard@gmail.com>
Thu, 16 Jul 2020 23:08:35 +0000 (18:08 -0500)
committerGitHub <noreply@github.com>
Thu, 16 Jul 2020 23:08:35 +0000 (16:08 -0700)
20 files changed:
src/libraries/Microsoft.Extensions.DependencyInjection/src/Resources/Strings.resx
src/libraries/Microsoft.Extensions.DependencyInjection/src/ServiceLookup/CallSiteFactory.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.External.Tests/Autofac.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.External.Tests/StashBox.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/DependencyInjectionSpecificationTests.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/AbstractClass.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIComparable.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIEnumerable.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassInheritingAbstractClass.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithAbstractClassConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithClassConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithInterfaceConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNewConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNoConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithSelfReferencingConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithStructConstraint.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ConstrainedFakeOpenGenericService.cs [new file with mode: 0644]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/IFakeOpenGenericService.cs
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Fakes/AbstractClass.cs [deleted file]
src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceLookup/CallSiteFactoryTest.cs

index 354fe5fccbfa5ae1b7a60cc4183c3e368c388941..134d98fdd448c84651ce080b6867128833a41929 100644 (file)
   <data name="AsyncDisposableServiceDispose" xml:space="preserve">
     <value>'{0}' type only implements IAsyncDisposable. Use DisposeAsync to dispose the container.</value>
   </data>
+  <data name="GenericConstraintViolation" xml:space="preserve">
+    <value>Generic constraints violated for type '{0}' while attempting to activate '{1}'.</value>
+  </data>
 </root>
\ No newline at end of file
index 2bf5ee16fc8e15d372cf8d07b3689fe8bd71fa2d..9adc47de00d87df357cd58f6089c7749d8f8769b 100644 (file)
@@ -118,7 +118,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
             if (serviceType.IsConstructedGenericType
                 && _descriptorLookup.TryGetValue(serviceType.GetGenericTypeDefinition(), out ServiceDescriptorCacheItem descriptor))
             {
-                return TryCreateOpenGeneric(descriptor.Last, serviceType, callSiteChain, DefaultSlot);
+                return TryCreateOpenGeneric(descriptor.Last, serviceType, callSiteChain, DefaultSlot, true);
             }
 
             return null;
@@ -164,7 +164,7 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
                         {
                             ServiceDescriptor descriptor = _descriptors[i];
                             ServiceCallSite callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot) ??
-                                           TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot);
+                                           TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot, false);
 
                             if (callSite != null)
                             {
@@ -230,14 +230,28 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
             return null;
         }
 
-        private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type serviceType, CallSiteChain callSiteChain, int slot)
+        private ServiceCallSite TryCreateOpenGeneric(ServiceDescriptor descriptor, Type serviceType, CallSiteChain callSiteChain, int slot, bool throwOnConstraintViolation)
         {
             if (serviceType.IsConstructedGenericType &&
                 serviceType.GetGenericTypeDefinition() == descriptor.ServiceType)
             {
                 Debug.Assert(descriptor.ImplementationType != null, "descriptor.ImplementationType != null");
                 var lifetime = new ResultCache(descriptor.Lifetime, serviceType, slot);
-                Type closedType = descriptor.ImplementationType.MakeGenericType(serviceType.GenericTypeArguments);
+                Type closedType;
+                try
+                {
+                    closedType = descriptor.ImplementationType.MakeGenericType(serviceType.GenericTypeArguments);
+                }
+                catch (ArgumentException ex)
+                {
+                    if (throwOnConstraintViolation)
+                    {
+                        throw new InvalidOperationException(SR.Format(SR.GenericConstraintViolation, serviceType, descriptor.ImplementationType), ex);
+                    }
+
+                    return null;
+                }
+
                 return CreateConstructorCallSite(lifetime, serviceType, closedType, callSiteChain);
             }
 
index 15295982782f122f2d511bd33d597a5f119575ee..ccb31e1105b192565270eccd0937a08696ecde97 100644 (file)
@@ -7,9 +7,14 @@ using Autofac.Extensions.DependencyInjection;
 
 namespace Microsoft.Extensions.DependencyInjection.Specification
 {
-    public class AutofacDependencyInjectionSpecificationTestsDependencyInjectionSpecificationTests
+    public class AutofacDependencyInjectionSpecificationTests : SkippableDependencyInjectionSpecificationTests
     {
-        protected override IServiceProvider CreateServiceProvider(IServiceCollection serviceCollection)
+        public override string[] SkippedTests => new[]
+        {
+            "PublicNoArgCtorConstrainedOpenGenericServicesCanBeResolved"
+        };
+
+        protected override IServiceProvider CreateServiceProviderImpl(IServiceCollection serviceCollection)
         {
             var builder = new ContainerBuilder();
             builder.Populate(serviceCollection);
index 28a27f028af5a402dc535be8157af0b0fc40b8e7..62986632a18abd057f1c3116ef8a4577226ac61c 100644 (file)
@@ -5,9 +5,16 @@ using System;
 
 namespace Microsoft.Extensions.DependencyInjection.Specification
 {
-    public class StashBoxDependencyInjectionSpecificationTestsDependencyInjectionSpecificationTests
+    public class StashBoxDependencyInjectionSpecificationTests : SkippableDependencyInjectionSpecificationTests
     {
-        protected override IServiceProvider CreateServiceProvider(IServiceCollection serviceCollection)
+        public override string[] SkippedTests => new[]
+        {
+            "PublicNoArgCtorConstrainedOpenGenericServicesCanBeResolved",
+            "SelfReferencingConstrainedOpenGenericServicesCanBeResolved",
+            "ClassConstrainedOpenGenericServicesCanBeResolved"
+        };
+
+        protected override IServiceProvider CreateServiceProviderImpl(IServiceCollection serviceCollection)
         {
             return serviceCollection.UseStashbox();
         }
index e50d9ca8b649fa9d66878b468840c3a2848ec7f4..9d82878dd3575bdbb48cd8a37eec1e78e98692f9 100644 (file)
@@ -591,6 +591,160 @@ namespace Microsoft.Extensions.DependencyInjection.Specification
             Assert.Same(singletonService, genericService.Value);
         }
 
+        [Fact]
+        public void ConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ConstrainedFakeOpenGenericService<>));
+            var poco = new PocoClass();
+            collection.AddSingleton(poco);
+            collection.AddSingleton<IFakeSingletonService, FakeService>();
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<IFakeSingletonService>>().ToList();
+            var singletonService = provider.GetService<IFakeSingletonService>();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Same(poco, allServices[0].Value);
+            Assert.Same(poco, allServices[1].Value);
+            Assert.Equal(1, constrainedServices.Count);
+            Assert.Same(singletonService, constrainedServices[0].Value);
+        }
+
+        [Fact]
+        public void ConstrainedOpenGenericServicesReturnsEmptyWithNoMatches()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ConstrainedFakeOpenGenericService<>));
+            collection.AddSingleton<IFakeSingletonService, FakeService>();
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<IFakeSingletonService>>().ToList();
+            // Assert
+            Assert.Equal(0, constrainedServices.Count);
+        }
+
+        [Fact]
+        public void InterfaceConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithInterfaceConstraint<>));
+            var enumerableVal = new ClassImplementingIEnumerable();
+            collection.AddSingleton(enumerableVal);
+            collection.AddSingleton<IFakeSingletonService, FakeService>();
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<ClassImplementingIEnumerable>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<IFakeSingletonService>>().ToList();
+            var singletonService = provider.GetService<IFakeSingletonService>();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Same(enumerableVal, allServices[0].Value);
+            Assert.Same(enumerableVal, allServices[1].Value);
+            Assert.Equal(1, constrainedServices.Count);
+            Assert.Same(singletonService, constrainedServices[0].Value);
+        }
+
+        [Fact]
+        public void PublicNoArgCtorConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNoConstraints<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNewConstraint<>));
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<ClassWithPrivateCtor>>().ToList();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Equal(1, constrainedServices.Count);
+        }
+
+        [Fact]
+        public void ClassConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNoConstraints<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithClassConstraint<>));
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<int>>().ToList();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Equal(1, constrainedServices.Count);
+        }
+
+        [Fact]
+        public void StructConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithNoConstraints<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithStructConstraint<>));
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<int>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToList();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Equal(1, constrainedServices.Count);
+        }
+
+        [Fact]
+        public void AbstractClassConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithAbstractClassConstraint<>));
+            var poco = new PocoClass();
+            collection.AddSingleton(poco);
+            var classInheritingClassInheritingAbstractClass = new ClassInheritingClassInheritingAbstractClass();
+            collection.AddSingleton(classInheritingClassInheritingAbstractClass);
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<ClassInheritingClassInheritingAbstractClass>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToList();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Same(classInheritingClassInheritingAbstractClass, allServices[0].Value);
+            Assert.Same(classInheritingClassInheritingAbstractClass, allServices[1].Value);
+            Assert.Equal(1, constrainedServices.Count);
+            Assert.Same(poco, constrainedServices[0].Value);
+        }
+
+        [Fact]
+        public void SelfReferencingConstrainedOpenGenericServicesCanBeResolved()
+        {
+            // Arrange
+            var collection = new TestServiceCollection();
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(FakeOpenGenericService<>));
+            collection.AddTransient(typeof(IFakeOpenGenericService<>), typeof(ClassWithSelfReferencingConstraint<>));
+            var poco = new PocoClass();
+            collection.AddSingleton(poco);
+            var selfComparable = new ClassImplementingIComparable();
+            collection.AddSingleton(selfComparable);
+            var provider = CreateServiceProvider(collection);
+            // Act
+            var allServices = provider.GetServices<IFakeOpenGenericService<ClassImplementingIComparable>>().ToList();
+            var constrainedServices = provider.GetServices<IFakeOpenGenericService<PocoClass>>().ToList();
+            // Assert
+            Assert.Equal(2, allServices.Count);
+            Assert.Same(selfComparable, allServices[0].Value);
+            Assert.Same(selfComparable, allServices[1].Value);
+            Assert.Equal(1, constrainedServices.Count);
+            Assert.Same(poco, constrainedServices[0].Value);
+        }
+
         [Fact]
         public void ClosedServicesPreferredOverOpenGenericServices()
         {
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/AbstractClass.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/AbstractClass.cs
new file mode 100644 (file)
index 0000000..412a042
--- /dev/null
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public abstract class AbstractClass
+    {
+        
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIComparable.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIComparable.cs
new file mode 100644 (file)
index 0000000..02ffe2e
--- /dev/null
@@ -0,0 +1,13 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassImplementingIComparable : IComparable<ClassImplementingIComparable>
+    {
+        public int CompareTo(ClassImplementingIComparable other) => 0;
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIEnumerable.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassImplementingIEnumerable.cs
new file mode 100644 (file)
index 0000000..20fe4e2
--- /dev/null
@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections;
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassImplementingIEnumerable : IEnumerable
+    {
+        public IEnumerator GetEnumerator() => throw new NotImplementedException();
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassInheritingAbstractClass.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassInheritingAbstractClass.cs
new file mode 100644 (file)
index 0000000..ec37f2b
--- /dev/null
@@ -0,0 +1,21 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassInheritingAbstractClass : AbstractClass
+    {
+
+    }
+
+    public class ClassAlsoInheritingAbstractClass : AbstractClass
+    {
+
+    }
+
+    public class ClassInheritingClassInheritingAbstractClass : ClassInheritingAbstractClass
+    {
+
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithAbstractClassConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithAbstractClassConstraint.cs
new file mode 100644 (file)
index 0000000..e551986
--- /dev/null
@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithAbstractClassConstraint<T> : IFakeOpenGenericService<T>
+        where T : AbstractClass
+    {
+        public ClassWithAbstractClassConstraint(T value) => Value = value;
+
+        public T Value { get; }
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithClassConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithClassConstraint.cs
new file mode 100644 (file)
index 0000000..b180203
--- /dev/null
@@ -0,0 +1,12 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithClassConstraint<T> : IFakeOpenGenericService<T>
+        where T : class
+    {
+        public T Value { get; } = default;
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithInterfaceConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithInterfaceConstraint.cs
new file mode 100644 (file)
index 0000000..efd2c9c
--- /dev/null
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections;
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithInterfaceConstraint<T> : IFakeOpenGenericService<T>
+        where T : IEnumerable
+    {
+        public ClassWithInterfaceConstraint(T value) => Value = value;
+
+        public T Value { get; }
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNewConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNewConstraint.cs
new file mode 100644 (file)
index 0000000..143986c
--- /dev/null
@@ -0,0 +1,12 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithNewConstraint<T> : IFakeOpenGenericService<T>
+        where T : new()
+    {
+        public T Value { get; } = new T();
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNoConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithNoConstraint.cs
new file mode 100644 (file)
index 0000000..848e755
--- /dev/null
@@ -0,0 +1,11 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithNoConstraints<T> : IFakeOpenGenericService<T>
+    {
+        public T Value { get; } = default;
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithSelfReferencingConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithSelfReferencingConstraint.cs
new file mode 100644 (file)
index 0000000..0e46429
--- /dev/null
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithSelfReferencingConstraint<T> : IFakeOpenGenericService<T>
+        where T : IComparable<T>
+    {
+        public ClassWithSelfReferencingConstraint(T value) => Value = value;
+
+        public T Value { get; }
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithStructConstraint.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ClassWithStructConstraint.cs
new file mode 100644 (file)
index 0000000..06355eb
--- /dev/null
@@ -0,0 +1,12 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ClassWithStructConstraint<T> : IFakeOpenGenericService<T>
+        where T : struct
+    {
+        public T Value { get; } = default;
+    }
+}
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ConstrainedFakeOpenGenericService.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Specification.Tests/Fakes/ConstrainedFakeOpenGenericService.cs
new file mode 100644 (file)
index 0000000..940e362
--- /dev/null
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
+{
+    public class ConstrainedFakeOpenGenericService<TVal> : IFakeOpenGenericService<TVal>
+        where TVal : PocoClass
+    {
+        public ConstrainedFakeOpenGenericService(TVal value)
+        {
+            Value = value;
+        }
+        public TVal Value { get; }
+    }
+}
index 71870357fd9d310577b514f12ae6a7815b962e2c..d556939012de31cefae8c4ed3be94bce26e2b21c 100644 (file)
@@ -3,7 +3,7 @@
 
 namespace Microsoft.Extensions.DependencyInjection.Specification.Fakes
 {
-    public interface IFakeOpenGenericService<TValue>
+    public interface IFakeOpenGenericService<out TValue>
     {
         TValue Value { get; }
     }
diff --git a/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Fakes/AbstractClass.cs b/src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/Fakes/AbstractClass.cs
deleted file mode 100644 (file)
index b2d1cff..0000000
+++ /dev/null
@@ -1,12 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-namespace Microsoft.Extensions.DependencyInjection.Tests.Fakes
-{
-    public abstract class AbstractClass
-    {
-        public AbstractClass()
-        {
-        }
-    }
-}
index bc0988f06d64d1e150911cad7906dbe240e93b57..b8d8e6a0fcd78459d0f1a14e7d42dafb8f897449 100644 (file)
@@ -111,6 +111,299 @@ namespace Microsoft.Extensions.DependencyInjection.ServiceLookup
             Assert.Empty(ctorCallSite.ParameterCallSites);
         }
 
+        [Fact]
+        public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyStructGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithStructConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<object>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesStructGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithStructConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var matchingType = typeof(IFakeOpenGenericService<int>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            Assert.NotNull(matchingCallSite);
+        }
+
+        [Fact]
+        public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyClassGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithClassConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<int>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesClassGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithClassConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var matchingType = typeof(IFakeOpenGenericService<object>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            Assert.NotNull(matchingCallSite);
+        }
+
+        [Fact]
+        public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyNewGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithNewConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<TypeWithNoPublicConstructors>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesNewGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithNewConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(TypeWithParameterlessPublicConstructor), new TypeWithParameterlessPublicConstructor()));
+            // Act
+            var matchingType = typeof(IFakeOpenGenericService<TypeWithParameterlessPublicConstructor>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            Assert.NotNull(matchingCallSite);
+        }
+
+        [Fact]
+        public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyInterfaceGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithInterfaceConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<int>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesInterfaceGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithInterfaceConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(string), ""));
+            // Act
+            var matchingType = typeof(IFakeOpenGenericService<string>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            Assert.NotNull(matchingCallSite);
+        }
+
+        [Fact]
+        public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfyAbstractClassGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithAbstractClassConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<object>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesAbstractClassGenericConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithAbstractClassConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(ClassInheritingAbstractClass), new ClassInheritingAbstractClass()));
+            // Act
+            var matchingType = typeof(IFakeOpenGenericService<ClassInheritingAbstractClass>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            Assert.NotNull(matchingCallSite);
+        }
+
+        [Fact]
+        public void CreateCallSite_Throws_IfClosedTypeDoesNotSatisfySelfReferencingConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithSelfReferencingConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<object>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_Throws_IfComplexClosedTypeDoesNotSatisfySelfReferencingConstraint()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithSelfReferencingConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor);
+            // Act
+            var nonMatchingType = typeof(IFakeOpenGenericService<int[]>);
+            // Assert
+            var ex = Assert.Throws<InvalidOperationException>(() => callSiteFactory(nonMatchingType));
+            Assert.Equal($"Generic constraints violated for type '{nonMatchingType}' while attempting to activate '{implementationType}'.", ex.Message);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsService_IfClosedTypeSatisfiesSelfReferencing()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var implementationType = typeof(ClassWithSelfReferencingConstraint<>);
+            var descriptor = new ServiceDescriptor(serviceType, implementationType, ServiceLifetime.Transient);
+            var callSiteFactory = GetCallSiteFactory(descriptor, new ServiceDescriptor(typeof(string), ""));
+            // Act
+            var matchingType = typeof(IFakeOpenGenericService<string>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            Assert.NotNull(matchingCallSite);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsEmpty_IfClosedTypeSatisfiesBaseClassConstraintButRegisteredTypeNotExactMatch()
+        {
+            // Arrange
+            var classInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint<ClassInheritingAbstractClass>);
+            var classInheritingAbstractClassDescriptor = new ServiceDescriptor(typeof(IFakeOpenGenericService<ClassInheritingAbstractClass>), classInheritingAbstractClassImplementationType, ServiceLifetime.Transient);
+            var classAlsoInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint<ClassAlsoInheritingAbstractClass>);
+            var classAlsoInheritingAbstractClassDescriptor = new ServiceDescriptor(typeof(IFakeOpenGenericService<ClassAlsoInheritingAbstractClass>), classAlsoInheritingAbstractClassImplementationType, ServiceLifetime.Transient);
+            var classInheritingClassInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint<ClassInheritingClassInheritingAbstractClass>);
+            var classInheritingClassInheritingAbstractClassDescriptor = new ServiceDescriptor(typeof(IFakeOpenGenericService<ClassInheritingClassInheritingAbstractClass>), classInheritingClassInheritingAbstractClassImplementationType, ServiceLifetime.Transient);
+            var notMatchingServiceType = typeof(IFakeOpenGenericService<PocoClass>);
+            var notMatchingType = typeof(FakeService);
+            var notMatchingDescriptor = new ServiceDescriptor(notMatchingServiceType, notMatchingType, ServiceLifetime.Transient);
+
+            var callSiteFactory = GetCallSiteFactory(classInheritingAbstractClassDescriptor, classAlsoInheritingAbstractClassDescriptor, classInheritingClassInheritingAbstractClassDescriptor, notMatchingDescriptor);
+            // Act
+            var matchingType = typeof(IEnumerable<IFakeOpenGenericService<AbstractClass>>);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            var enumerableCall = Assert.IsType<IEnumerableCallSite>(matchingCallSite);
+
+            Assert.Empty(enumerableCall.ServiceCallSites);
+        }
+
+        [Fact]
+        public void CreateCallSite_ReturnsMatchingTypes_IfClosedTypeSatisfiesBaseClassConstraintAndRegisteredType()
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<AbstractClass>);
+            var classInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint<ClassInheritingAbstractClass>);
+            var classInheritingAbstractClassDescriptor = new ServiceDescriptor(serviceType, classInheritingAbstractClassImplementationType, ServiceLifetime.Transient);
+            var classAlsoInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint<ClassAlsoInheritingAbstractClass>);
+            var classAlsoInheritingAbstractClassDescriptor = new ServiceDescriptor(serviceType, classAlsoInheritingAbstractClassImplementationType, ServiceLifetime.Transient);
+            var classInheritingClassInheritingAbstractClassImplementationType = typeof(ClassWithAbstractClassConstraint<ClassInheritingClassInheritingAbstractClass>);
+            var classInheritingClassInheritingAbstractClassDescriptor = new ServiceDescriptor(serviceType, classInheritingClassInheritingAbstractClassImplementationType, ServiceLifetime.Transient);
+            var notMatchingServiceType = typeof(IFakeOpenGenericService<PocoClass>);
+            var notMatchingType = typeof(FakeService);
+            var notMatchingDescriptor = new ServiceDescriptor(notMatchingServiceType, notMatchingType, ServiceLifetime.Transient);
+
+            var descriptors = new[]
+            {
+                classInheritingAbstractClassDescriptor,
+                new ServiceDescriptor(typeof(ClassInheritingAbstractClass), new ClassInheritingAbstractClass()),
+                classAlsoInheritingAbstractClassDescriptor,
+                new ServiceDescriptor(typeof(ClassAlsoInheritingAbstractClass), new ClassAlsoInheritingAbstractClass()),
+                classInheritingClassInheritingAbstractClassDescriptor,
+                new ServiceDescriptor(typeof(ClassInheritingClassInheritingAbstractClass), new ClassInheritingClassInheritingAbstractClass()),
+                notMatchingDescriptor
+            };
+            var callSiteFactory = GetCallSiteFactory(descriptors);
+            // Act
+            var matchingType = typeof(IEnumerable<>).MakeGenericType(serviceType);
+            var matchingCallSite = callSiteFactory(matchingType);
+            // Assert
+            var enumerableCall = Assert.IsType<IEnumerableCallSite>(matchingCallSite);
+
+            var matchingTypes = new[]
+            {
+                classInheritingAbstractClassImplementationType,
+                classAlsoInheritingAbstractClassImplementationType,
+                classInheritingClassInheritingAbstractClassImplementationType
+            };
+            Assert.Equal(matchingTypes.Length, enumerableCall.ServiceCallSites.Length);
+            Assert.Equal(matchingTypes, enumerableCall.ServiceCallSites.Select(scs => scs.ImplementationType).ToArray());
+        }
+
+        [Theory]
+        [InlineData(typeof(IFakeOpenGenericService<int>), default(int), new[] { typeof(FakeOpenGenericService<int>), typeof(ClassWithStructConstraint<int>), typeof(ClassWithNewConstraint<int>), typeof(ClassWithSelfReferencingConstraint<int>) })]
+        [InlineData(typeof(IFakeOpenGenericService<string>), "", new[] { typeof(FakeOpenGenericService<string>), typeof(ClassWithClassConstraint<string>), typeof(ClassWithInterfaceConstraint<string>), typeof(ClassWithSelfReferencingConstraint<string>) })]
+        [InlineData(typeof(IFakeOpenGenericService<int[]>), new[] { 1, 2, 3 }, new[] { typeof(FakeOpenGenericService<int[]>), typeof(ClassWithClassConstraint<int[]>), typeof(ClassWithInterfaceConstraint<int[]>) })]
+        public void CreateCallSite_ReturnsMatchingTypesThatMatchCorrectConstraints(Type closedServiceType, object value, Type[] matchingImplementationTypes)
+        {
+            // Arrange
+            var serviceType = typeof(IFakeOpenGenericService<>);
+            var noConstraintImplementationType = typeof(FakeOpenGenericService<>);
+            var noConstraintDescriptor = new ServiceDescriptor(serviceType, noConstraintImplementationType, ServiceLifetime.Transient);
+            var structImplementationType = typeof(ClassWithStructConstraint<>);
+            var structDescriptor = new ServiceDescriptor(serviceType, structImplementationType, ServiceLifetime.Transient);
+            var classImplementationType = typeof(ClassWithClassConstraint<>);
+            var classDescriptor = new ServiceDescriptor(serviceType, classImplementationType, ServiceLifetime.Transient);
+            var newImplementationType = typeof(ClassWithNewConstraint<>);
+            var newDescriptor = new ServiceDescriptor(serviceType, newImplementationType, ServiceLifetime.Transient);
+            var interfaceImplementationType = typeof(ClassWithInterfaceConstraint<>);
+            var interfaceDescriptor = new ServiceDescriptor(serviceType, interfaceImplementationType, ServiceLifetime.Transient);
+            var selfConstraintImplementationType = typeof(ClassWithSelfReferencingConstraint<>);
+            var selfConstraintDescriptor = new ServiceDescriptor(serviceType, selfConstraintImplementationType, ServiceLifetime.Transient);
+            var serviceValueType = closedServiceType.GenericTypeArguments[0];
+            var serviceValueDescriptor = new ServiceDescriptor(serviceValueType, value);
+            var callSiteFactory = GetCallSiteFactory(noConstraintDescriptor, structDescriptor, classDescriptor, newDescriptor, interfaceDescriptor, selfConstraintDescriptor, serviceValueDescriptor);
+            var collectionType = typeof(IEnumerable<>).MakeGenericType(closedServiceType);
+            // Act
+            var callSite = callSiteFactory(collectionType);
+            // Assert
+            var enumerableCall = Assert.IsType<IEnumerableCallSite>(callSite);
+            Assert.Equal(matchingImplementationTypes.Length, enumerableCall.ServiceCallSites.Length);
+            Assert.Equal(matchingImplementationTypes, enumerableCall.ServiceCallSites.Select(scs => scs.ImplementationType).ToArray());
+        }
+
         public static TheoryData CreateCallSite_PicksConstructorWithTheMostNumberOfResolvedParametersData =>
             new TheoryData<Type, Func<Type, ServiceCallSite>, Type[]>
             {