Avoid Attribute.GetCustomAttributes() returning null for open generic type (#65237)
authormadelson <1269046+madelson@users.noreply.github.com>
Fri, 4 Mar 2022 04:19:52 +0000 (23:19 -0500)
committerGitHub <noreply@github.com>
Fri, 4 Mar 2022 04:19:52 +0000 (20:19 -0800)
* Avoid Attribute.GetCustomAttributes() returning null for open generic type.

Fix #64335

src/coreclr/System.Private.CoreLib/src/System/Attribute.CoreCLR.cs
src/coreclr/System.Private.CoreLib/src/System/Reflection/RuntimeCustomAttributeData.cs
src/coreclr/System.Private.CoreLib/src/System/Reflection/RuntimeParameterInfo.cs
src/coreclr/nativeaot/System.Private.CoreLib/src/System/Reflection/Attribute.CoreRT.cs
src/libraries/System.Reflection/tests/ParameterInfoTests.cs
src/libraries/System.Runtime/tests/System/Attributes.cs
src/tests/reflection/GenericAttribute/GenericAttributeMetadata.cs
src/tests/reflection/GenericAttribute/GenericAttributeTests.cs

index f33849d..c84e918 100644 (file)
@@ -434,10 +434,8 @@ namespace System
                 SR.Format(SR.Format_AttributeUsage, type));
         }
 
-        private static Attribute[] CreateAttributeArrayHelper(Type elementType, int elementCount)
-        {
-            return (Attribute[])Array.CreateInstance(elementType, elementCount);
-        }
+        private static Attribute[] CreateAttributeArrayHelper(Type elementType, int elementCount) =>
+            elementType.ContainsGenericParameters ? new Attribute[elementCount] : (Attribute[])Array.CreateInstance(elementType, elementCount);
         #endregion
 
         #endregion
@@ -459,7 +457,7 @@ namespace System
             {
                 MemberTypes.Property => InternalGetCustomAttributes((PropertyInfo)element, attributeType, inherit),
                 MemberTypes.Event => InternalGetCustomAttributes((EventInfo)element, attributeType, inherit),
-                _ => (element.GetCustomAttributes(attributeType, inherit) as Attribute[])!,
+                _ => (Attribute[])element.GetCustomAttributes(attributeType, inherit)
             };
         }
 
@@ -474,7 +472,7 @@ namespace System
             {
                 MemberTypes.Property => InternalGetCustomAttributes((PropertyInfo)element, typeof(Attribute), inherit),
                 MemberTypes.Event => InternalGetCustomAttributes((EventInfo)element, typeof(Attribute), inherit),
-                _ => (element.GetCustomAttributes(typeof(Attribute), inherit) as Attribute[])!,
+                _ => (Attribute[])element.GetCustomAttributes(typeof(Attribute), inherit)
             };
         }
 
@@ -536,12 +534,11 @@ namespace System
             if (element.Member == null)
                 throw new ArgumentException(SR.Argument_InvalidParameterInfo, nameof(element));
 
-
             MemberInfo member = element.Member;
             if (member.MemberType == MemberTypes.Method && inherit)
                 return InternalParamGetCustomAttributes(element, attributeType, inherit);
 
-            return (element.GetCustomAttributes(attributeType, inherit) as Attribute[])!;
+            return (Attribute[])element.GetCustomAttributes(attributeType, inherit);
         }
 
         public static Attribute[] GetCustomAttributes(ParameterInfo element!!, bool inherit)
@@ -549,12 +546,11 @@ namespace System
             if (element.Member == null)
                 throw new ArgumentException(SR.Argument_InvalidParameterInfo, nameof(element));
 
-
             MemberInfo member = element.Member;
             if (member.MemberType == MemberTypes.Method && inherit)
                 return InternalParamGetCustomAttributes(element, null, inherit);
 
-            return (element.GetCustomAttributes(typeof(Attribute), inherit) as Attribute[])!;
+            return (Attribute[])element.GetCustomAttributes(typeof(Attribute), inherit);
         }
 
         public static bool IsDefined(ParameterInfo element, Type attributeType)
index bb642f9..d28edee 100644 (file)
@@ -899,7 +899,7 @@ namespace System.Reflection
             Debug.Assert(caType is not null);
 
             if (type.GetElementType() is not null)
-                return (caType.IsValueType) ? Array.Empty<object>() : CreateAttributeArrayHelper(caType, 0);
+                return CreateAttributeArrayHelper(caType, 0);
 
             if (type.IsGenericType && !type.IsGenericTypeDefinition)
                 type = (type.GetGenericTypeDefinition() as RuntimeType)!;
@@ -919,8 +919,6 @@ namespace System.Reflection
 
             RuntimeType.ListBuilder<object> result = default;
             bool mustBeInheritable = false;
-            bool useObjectArray = (caType.IsValueType || caType.ContainsGenericParameters);
-            RuntimeType arrayType = useObjectArray ? (RuntimeType)typeof(object) : caType;
 
             for (int i = 0; i < pcas.Count; i++)
                 result.Add(pcas[i]);
@@ -932,7 +930,7 @@ namespace System.Reflection
                 type = (type.BaseType as RuntimeType)!;
             }
 
-            object[] typedResult = CreateAttributeArrayHelper(arrayType, result.Count);
+            object[] typedResult = CreateAttributeArrayHelper(caType, result.Count);
             for (int i = 0; i < result.Count; i++)
             {
                 typedResult[i] = result[i];
@@ -963,8 +961,6 @@ namespace System.Reflection
 
             RuntimeType.ListBuilder<object> result = default;
             bool mustBeInheritable = false;
-            bool useObjectArray = (caType.IsValueType || caType.ContainsGenericParameters);
-            RuntimeType arrayType = useObjectArray ? (RuntimeType)typeof(object) : caType;
 
             for (int i = 0; i < pcas.Count; i++)
                 result.Add(pcas[i]);
@@ -976,7 +972,7 @@ namespace System.Reflection
                 method = method.GetParentDefinition()!;
             }
 
-            object[] typedResult = CreateAttributeArrayHelper(arrayType, result.Count);
+            object[] typedResult = CreateAttributeArrayHelper(caType, result.Count);
             for (int i = 0; i < result.Count; i++)
             {
                 typedResult[i] = result[i];
@@ -1123,16 +1119,13 @@ namespace System.Reflection
         }
 
         private static object[] GetCustomAttributes(
-            RuntimeModule decoratedModule, int decoratedMetadataToken, int pcaCount, RuntimeType? attributeFilterType)
+            RuntimeModule decoratedModule, int decoratedMetadataToken, int pcaCount, RuntimeType attributeFilterType)
         {
             RuntimeType.ListBuilder<object> attributes = default;
 
             AddCustomAttributes(ref attributes, decoratedModule, decoratedMetadataToken, attributeFilterType, false, default);
 
-            bool useObjectArray = attributeFilterType is null || attributeFilterType.IsValueType || attributeFilterType.ContainsGenericParameters;
-            RuntimeType arrayType = useObjectArray ? (RuntimeType)typeof(object) : attributeFilterType!;
-
-            object[] result = CreateAttributeArrayHelper(arrayType, attributes.Count + pcaCount);
+            object[] result = CreateAttributeArrayHelper(attributeFilterType, attributes.Count + pcaCount);
             for (int i = 0; i < attributes.Count; i++)
             {
                 result[i] = attributes[i];
@@ -1439,6 +1432,42 @@ namespace System.Reflection
 
             return attributeUsageAttribute ?? AttributeUsageAttribute.Default;
         }
+
+        internal static object[] CreateAttributeArrayHelper(RuntimeType caType, int elementCount)
+        {
+            bool useAttributeArray = false;
+            bool useObjectArray = false;
+
+            if (caType == typeof(Attribute))
+            {
+                useAttributeArray = true;
+            }
+            else if (caType.IsValueType)
+            {
+                useObjectArray = true;
+            }
+            else if (caType.ContainsGenericParameters)
+            {
+                if (caType.IsSubclassOf(typeof(Attribute)))
+                {
+                    useAttributeArray = true;
+                }
+                else
+                {
+                    useObjectArray = true;
+                }
+            }
+
+            if (useAttributeArray)
+            {
+                return elementCount == 0 ? Array.Empty<Attribute>() : new Attribute[elementCount];
+            }
+            if (useObjectArray)
+            {
+                return elementCount == 0 ? Array.Empty<object>() : new object[elementCount];
+            }
+            return elementCount == 0 ? caType.GetEmptyArray() : (object[])Array.CreateInstance(caType, elementCount);
+        }
         #endregion
 
         #region Private Static FCalls
@@ -1476,17 +1505,6 @@ namespace System.Reflection
                 module, &pBlobStart, (byte*)blobEnd, out name, out isProperty, out type, out value);
             blobStart = (IntPtr)pBlobStart;
         }
-
-        private static object[] CreateAttributeArrayHelper(RuntimeType elementType, int elementCount)
-        {
-            // If we have 0 elements, don't allocate a new array
-            if (elementCount == 0)
-            {
-                return elementType.GetEmptyArray();
-            }
-
-            return (object[])Array.CreateInstance(elementType, elementCount);
-        }
         #endregion
     }
 
index 33777c7..f889133 100644 (file)
@@ -509,12 +509,12 @@ namespace System.Reflection
 
         public override object[] GetCustomAttributes(Type attributeType!!, bool inherit)
         {
-            if (MdToken.IsNullToken(m_tkParamDef))
-                return Array.Empty<object>();
-
             if (attributeType.UnderlyingSystemType is not RuntimeType attributeRuntimeType)
                 throw new ArgumentException(SR.Arg_MustBeType, nameof(attributeType));
 
+            if (MdToken.IsNullToken(m_tkParamDef))
+                return CustomAttribute.CreateAttributeArrayHelper(attributeRuntimeType, 0);
+
             return CustomAttribute.GetCustomAttributes(this, attributeRuntimeType);
         }
 
index d7f9bb6..f760461 100644 (file)
@@ -141,19 +141,9 @@ namespace System
                 attributes.Add(instantiatedAttribute);
             }
             int count = attributes.Count;
-            Attribute[] result;
-            try
-            {
-                result = (Attribute[])Array.CreateInstance(actualElementType, count);
-            }
-            catch (NotSupportedException) when (actualElementType.ContainsGenericParameters)
-            {
-                // This is here for desktop compatibility (using try-catch as control flow to avoid slowing down the mainline case.)
-                // GetCustomAttributes() normally returns an array of the exact attribute type requested except when
-                // the requested type is an open type. Its ICustomAttributeProvider counterpart would return an Object[] array but that's
-                // not possible with this api's return type so it returns null instead.
-                return null;
-            }
+            Attribute[] result = actualElementType.ContainsGenericParameters
+                ? new Attribute[count]
+                : (Attribute[])Array.CreateInstance(actualElementType, count);
             attributes.CopyTo(result, 0);
             return result;
         }
index f628554..26d2d65 100644 (file)
@@ -246,6 +246,14 @@ namespace System.Reflection.Tests
         }
 
         [Fact]
+        public static void GetCustomAttributesOnParameterWithNullMetadataTokenReturnsArrayOfCorrectType()
+        {
+            var parameterWithNullMetadataToken = typeof(int[]).GetProperty(nameof(Array.Length)).GetMethod.ReturnParameter;
+            Assert.Equal(typeof(Attribute[]), Attribute.GetCustomAttributes(parameterWithNullMetadataToken).GetType());
+            Assert.Equal(typeof(MyAttribute[]), Attribute.GetCustomAttributes(parameterWithNullMetadataToken, typeof(MyAttribute)).GetType());
+        }
+
+        [Fact]
         public void VerifyGetCustomAttributesData()
         {
             ParameterInfo p = GetParameterInfo(typeof(ParameterInfoMetadata), "MethodWithCustomAttribute", 0);
index e219d30..3dda3f5 100644 (file)
@@ -20,7 +20,7 @@ using System.Diagnostics;
 using System.Runtime.InteropServices;
 using Xunit;
 
-[module:Debuggable(true,false)]
+[module: Debuggable(true, false)]
 namespace System.Tests
 {
     public class AttributeIsDefinedTests
@@ -218,6 +218,71 @@ namespace System.Tests
         {
             Assert.True(typeof(ExampleWithAttribute).GetCustomAttributes(typeof(INameable), inherit: false)[0] is NameableAttribute);
         }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForType()
+        {
+            GenericAttributesTestHelper<string>(t => Attribute.GetCustomAttributes(typeof(HasGenericAttribute), t));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForField()
+        {
+            FieldInfo field = typeof(HasGenericAttribute).GetField(nameof(HasGenericAttribute.Field), BindingFlags.NonPublic | BindingFlags.Instance);
+            GenericAttributesTestHelper<TimeSpan>(t => Attribute.GetCustomAttributes(field, t));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForConstructor()
+        {
+            ConstructorInfo method = typeof(HasGenericAttribute).GetConstructor(Type.EmptyTypes);
+            GenericAttributesTestHelper<Guid>(t => Attribute.GetCustomAttributes(method, t));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForMethod()
+        {
+            MethodInfo method = typeof(HasGenericAttribute).GetMethod(nameof(HasGenericAttribute.Method));
+            GenericAttributesTestHelper<long>(t => Attribute.GetCustomAttributes(method, t));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForParameter()
+        {
+            ParameterInfo parameter = typeof(HasGenericAttribute).GetMethod(nameof(HasGenericAttribute.Method)).GetParameters()[0];
+            GenericAttributesTestHelper<ulong>(t => Attribute.GetCustomAttributes(parameter, t));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForProperty()
+        {
+            PropertyInfo property = typeof(HasGenericAttribute).GetProperty(nameof(HasGenericAttribute.Property));
+            GenericAttributesTestHelper<List<object>>(t => Attribute.GetCustomAttributes(property, t));
+        }
+
+        [Fact]
+        [ActiveIssue("https://github.com/dotnet/runtime/issues/56887)", TestRuntimes.Mono)]
+        public static void GetCustomAttributesWorksWithOpenAndClosedGenericTypesForEvent()
+        {
+            EventInfo @event = typeof(HasGenericAttribute).GetEvent(nameof(HasGenericAttribute.Event));
+            GenericAttributesTestHelper<DateTime?>(t => Attribute.GetCustomAttributes(@event, t));
+        }
+
+        private static void GenericAttributesTestHelper<TGenericParameter>(Func<Type, Attribute[]> getCustomAttributes)
+        {
+            Attribute[] openGenericAttributes = getCustomAttributes(typeof(GenericAttribute<>));
+            Assert.Empty(openGenericAttributes);
+
+            Attribute[] closedGenericAttributes = getCustomAttributes(typeof(GenericAttribute<TGenericParameter>));
+            Assert.Equal(1, closedGenericAttributes.Length);
+            Assert.Equal(typeof(GenericAttribute<TGenericParameter>[]), closedGenericAttributes.GetType());
+        }
     }
 
     public static class GetCustomAttribute
@@ -226,7 +291,7 @@ namespace System.Tests
         [Fact]
         public static void customAttributeCount()
         {
-            List<CustomAttributeData> customAttributes =  typeof(GetCustomAttribute).Module.CustomAttributes.ToList();
+            List<CustomAttributeData> customAttributes = typeof(GetCustomAttribute).Module.CustomAttributes.ToList();
             // [System.Security.UnverifiableCodeAttribute()]
             // [TestAttributes.FooAttribute()]
             // [TestAttributes.ComplicatedAttribute((Int32)1, Stuff = 2)]
@@ -660,7 +725,7 @@ namespace System.Tests
     }
     public class BaseClass
     {
-        public virtual void TestMethod([ArgumentUsage("for test")]string[] strArray, params string[] strList)
+        public virtual void TestMethod([ArgumentUsage("for test")] string[] strArray, params string[] strList)
         {
         }
     }
@@ -816,7 +881,7 @@ namespace System.Tests
         string Name { get; }
     }
 
-    [AttributeUsage (AttributeTargets.All, AllowMultiple = true)]
+    [AttributeUsage(AttributeTargets.All, AllowMultiple = true)]
     public class NameableAttribute : Attribute, INameable
     {
         string INameable.Name => "Nameable";
@@ -824,4 +889,31 @@ namespace System.Tests
 
     [Nameable]
     public class ExampleWithAttribute { }
+
+    public class GenericAttribute<T> : Attribute
+    {
+    }
+
+    [GenericAttribute<string>]
+    public class HasGenericAttribute
+    {
+        [GenericAttribute<TimeSpan>]
+        internal bool Field;
+
+        [GenericAttribute<Guid>]
+        public HasGenericAttribute() { }
+
+        [GenericAttribute<long>]
+        public void Method([GenericAttribute<ulong>] int parameter)
+        {
+            this.Field = true;
+            this.Event += () => { };
+        }
+
+        [GenericAttribute<List<object>>]
+        public int Property { get; set; }
+
+        [GenericAttribute<DateTime?>]
+        public event Action Event;
+    }
 }
index d0ef635..3bea1e6 100644 (file)
@@ -17,6 +17,8 @@ using System.Runtime.CompilerServices;
 [assembly: MultiAttribute<bool>()]
 [assembly: MultiAttribute<bool>(true)]
 
+[module: SingleAttribute<long>()]
+
 [AttributeUsage(AttributeTargets.Assembly | AttributeTargets.Class | AttributeTargets.Property, AllowMultiple = false)]
 public class SingleAttribute<T> : Attribute
 {
index e7661d2..b1dcd6b 100644 (file)
@@ -18,9 +18,17 @@ class Program
         Assert(((ICustomAttributeProvider)assembly).IsDefined(typeof(SingleAttribute<int>), true));
         Assert(CustomAttributeExtensions.IsDefined(assembly, typeof(SingleAttribute<bool>)));
         Assert(((ICustomAttributeProvider)assembly).IsDefined(typeof(SingleAttribute<bool>), true));
-
+        Assert(!CustomAttributeExtensions.GetCustomAttributes(assembly, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
+        Assert(!CustomAttributeExtensions.GetCustomAttributes(assembly, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
 */
 
+        // Module module = programTypeInfo.Module;
+        // AssertAny(CustomAttributeExtensions.GetCustomAttributes(module), a => a is SingleAttribute<long>);
+        // Assert(CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<long>)).GetEnumerator().MoveNext());
+        // Assert(CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<long>)).GetEnumerator().MoveNext());
+        // Assert(!CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());
+        // Assert(!CustomAttributeExtensions.GetCustomAttributes(module, typeof(SingleAttribute<>)).GetEnumerator().MoveNext());       
+
         TypeInfo programTypeInfo = typeof(Class).GetTypeInfo();
         Assert(CustomAttributeExtensions.GetCustomAttribute<SingleAttribute<int>>(programTypeInfo) != null);
         Assert(((ICustomAttributeProvider)programTypeInfo).GetCustomAttributes(typeof(SingleAttribute<int>), true) != null);
@@ -152,8 +160,8 @@ class Program
         AssertAny(b10, a => (a as MultiAttribute<Type>)?.Value == typeof(Class));
         AssertAny(b10, a => (a as MultiAttribute<Type>)?.Value == typeof(Class.Derive));
 
-        Assert(CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), false) == null);
-        Assert(CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), true) == null);
+        Assert(!CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), false).GetEnumerator().MoveNext());
+        Assert(!CustomAttributeExtensions.GetCustomAttributes(programTypeInfo, typeof(MultiAttribute<>), true).GetEnumerator().MoveNext());
         Assert(!((ICustomAttributeProvider)programTypeInfo).GetCustomAttributes(typeof(MultiAttribute<>), true).GetEnumerator().MoveNext());
 
         // Test coverage for CustomAttributeData api surface