Fix finding attribute data for syntax for assembly/module symbols (#72535)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Wed, 20 Jul 2022 21:05:24 +0000 (14:05 -0700)
committerGitHub <noreply@github.com>
Wed, 20 Jul 2022 21:05:24 +0000 (14:05 -0700)
src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/CustomMarshallerAttributeAnalyzer.cs
src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/NativeMarshallingAttributeAnalyzer.cs
src/libraries/System.Runtime.InteropServices/gen/LibraryImportGenerator/Analyzers/SyntaxExtensions.cs
src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/NativeMarshallingAttributeAnalyzerTests.cs

index 9e9f732..b3a4f8e 100644 (file)
@@ -262,8 +262,8 @@ namespace Microsoft.Interop.Analyzers
                 AttributeSyntax syntax = (AttributeSyntax)context.Node;
                 ISymbol attributedSymbol = context.ContainingSymbol!;
 
-                AttributeData attr = GetAttributeData(syntax, attributedSymbol);
-                if (attr.AttributeClass?.ToDisplayString() == TypeNames.CustomMarshallerAttribute
+                AttributeData? attr = syntax.FindAttributeData(attributedSymbol);
+                if (attr?.AttributeClass?.ToDisplayString() == TypeNames.CustomMarshallerAttribute
                     && attr.AttributeConstructor is not null)
                 {
                     DiagnosticReporter managedTypeReporter = DiagnosticReporter.CreateForLocation(syntax.FindArgumentWithNameOrArity("managedType", 0).FindTypeExpressionOrNullLocation(), context.ReportDiagnostic);
@@ -313,20 +313,6 @@ namespace Microsoft.Interop.Analyzers
             {
                 // TODO: Implement for the V2 shapes
             }
-
-            private static AttributeData GetAttributeData(AttributeSyntax syntax, ISymbol symbol)
-            {
-                if (syntax.FirstAncestorOrSelf<AttributeListSyntax>().Target?.Identifier.IsKind(SyntaxKind.ReturnKeyword) == true)
-                {
-                    return ((IMethodSymbol)symbol).GetReturnTypeAttributes().First(attributeSyntaxLocationMatches);
-                }
-                return symbol.GetAttributes().First(attributeSyntaxLocationMatches);
-
-                bool attributeSyntaxLocationMatches(AttributeData attrData)
-                {
-                    return attrData.ApplicationSyntaxReference!.SyntaxTree == syntax.SyntaxTree && attrData.ApplicationSyntaxReference.Span == syntax.Span;
-                }
-            }
         }
     }
 }
index b17e43a..bba752c 100644 (file)
@@ -86,8 +86,8 @@ namespace Microsoft.Interop.Analyzers
                 AttributeSyntax syntax = (AttributeSyntax)context.Node;
                 ISymbol attributedSymbol = context.ContainingSymbol!;
 
-                AttributeData attr = GetAttributeData(syntax, attributedSymbol);
-                if (attr.AttributeClass?.ToDisplayString() == TypeNames.NativeMarshallingAttribute
+                AttributeData? attr = syntax.FindAttributeData(attributedSymbol);
+                if (attr?.AttributeClass?.ToDisplayString() == TypeNames.NativeMarshallingAttribute
                     && attr.AttributeConstructor is not null)
                 {
                     INamedTypeSymbol? entryType = (INamedTypeSymbol?)attr.ConstructorArguments[0].Value;
@@ -163,20 +163,6 @@ namespace Microsoft.Interop.Analyzers
                 }
             }
 
-            private static AttributeData GetAttributeData(AttributeSyntax syntax, ISymbol symbol)
-            {
-                if (syntax.FirstAncestorOrSelf<AttributeListSyntax>().Target?.Identifier.IsKind(SyntaxKind.ReturnKeyword) == true)
-                {
-                    return ((IMethodSymbol)symbol).GetReturnTypeAttributes().First(attributeSyntaxLocationMatches);
-                }
-                return symbol.GetAttributes().First(attributeSyntaxLocationMatches);
-
-                bool attributeSyntaxLocationMatches(AttributeData attrData)
-                {
-                    return attrData.ApplicationSyntaxReference!.SyntaxTree == syntax.SyntaxTree && attrData.ApplicationSyntaxReference.Span == syntax.Span;
-                }
-            }
-
             private static ITypeSymbol GetSymbolType(ISymbol symbol)
             {
                 return symbol switch
index 09de03c..00f3d93 100644 (file)
@@ -25,6 +25,31 @@ namespace Microsoft.Interop.Analyzers
             return walker.TypeExpressionLocation;
         }
 
+        public static AttributeData? FindAttributeData(this AttributeSyntax syntax, ISymbol targetSymbol)
+        {
+            AttributeTargetSpecifierSyntax attributeTarget = syntax.FirstAncestorOrSelf<AttributeListSyntax>().Target;
+            if (attributeTarget is not null)
+            {
+                switch (attributeTarget.Identifier.Kind())
+                {
+                    case SyntaxKind.ReturnKeyword:
+                        return ((IMethodSymbol)targetSymbol).GetReturnTypeAttributes().First(attributeSyntaxLocationMatches);
+                    case SyntaxKind.AssemblyKeyword:
+                        return targetSymbol.ContainingAssembly.GetAttributes().First(attributeSyntaxLocationMatches);
+                    case SyntaxKind.ModuleKeyword:
+                        return targetSymbol.ContainingModule.GetAttributes().First(attributeSyntaxLocationMatches);
+                    default:
+                        return null;
+                }
+            }
+            return targetSymbol.GetAttributes().First(attributeSyntaxLocationMatches);
+
+            bool attributeSyntaxLocationMatches(AttributeData attrData)
+            {
+                return attrData.ApplicationSyntaxReference!.SyntaxTree == syntax.SyntaxTree && attrData.ApplicationSyntaxReference.Span == syntax.Span;
+            }
+        }
+
         private sealed class FindTypeLocationWalker : CSharpSyntaxWalker
         {
             public Location? TypeExpressionLocation { get; private set; }
index dd421e7..49587ad 100644 (file)
@@ -256,5 +256,19 @@ namespace LibraryImportGenerator.UnitTests
             await VerifyCS.VerifyAnalyzerAsync(source,
                 VerifyCS.Diagnostic(GenericEntryPointMarshallerTypeMustBeClosedOrMatchArityRule).WithLocation(0).WithArguments("MarshallerType<U, V, W>", "ManagedType<T>"));
         }
+
+        [Fact]
+        public async Task UnrelatedAssemblyOrModuleTargetDiagnostic_DoesNotCauseException()
+        {
+            string source = """
+                using System.Reflection;
+                using System.Runtime.CompilerServices;
+
+                [assembly:AssemblyMetadata("MyKey", "MyValue")]
+                [module:SkipLocalsInit]
+                """;
+
+            await VerifyCS.VerifyAnalyzerAsync(source);
+        }
     }
 }