Add PreserveSig support to ComInterfaceGenerator (#85941)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Wed, 10 May 2023 18:09:44 +0000 (11:09 -0700)
committerGitHub <noreply@github.com>
Wed, 10 May 2023 18:09:44 +0000 (11:09 -0700)
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/TargetSignatureTests.cs [moved from src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CallingConventionForwarding.cs with 56% similarity]
src/libraries/System.Runtime.InteropServices/tests/Common/Verifiers/CSharpSourceGeneratorVerifier.cs

index 77b40b5..0690bc0 100644 (file)
@@ -7,6 +7,7 @@ using System.Collections.Immutable;
 using System.Diagnostics;
 using System.IO;
 using System.Linq;
+using System.Reflection;
 using System.Threading;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
@@ -448,46 +449,49 @@ namespace Microsoft.Interop
             // Create the stub.
             var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly);
 
-            // Search for the element information for the managed return value.
-            // We need to transform it such that any return type is converted to an out parameter at the end of the parameter list.
-            ImmutableArray<TypePositionInfo> returnSwappedSignatureElements = signatureContext.ElementTypeInformation;
-            for (int i = 0; i < returnSwappedSignatureElements.Length; ++i)
+            if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig))
             {
-                if (returnSwappedSignatureElements[i].IsManagedReturnPosition)
+                // Search for the element information for the managed return value.
+                // We need to transform it such that any return type is converted to an out parameter at the end of the parameter list.
+                ImmutableArray<TypePositionInfo> returnSwappedSignatureElements = signatureContext.ElementTypeInformation;
+                for (int i = 0; i < returnSwappedSignatureElements.Length; ++i)
                 {
-                    if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void)
+                    if (returnSwappedSignatureElements[i].IsManagedReturnPosition)
                     {
-                        // Return type is void, just remove the element from the signature list.
-                        // We don't introduce an out parameter.
-                        returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i);
-                    }
-                    else
-                    {
-                        // Convert the current element into an out parameter on the native signature
-                        // while keeping it at the return position in the managed signature.
-                        var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with
+                        if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void)
                         {
-                            RefKind = RefKind.Out,
-                            RefKindSyntax = SyntaxKind.OutKeyword,
-                            ManagedIndex = TypePositionInfo.ReturnIndex,
-                            NativeIndex = symbol.Parameters.Length
-                        };
-                        returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut);
+                            // Return type is void, just remove the element from the signature list.
+                            // We don't introduce an out parameter.
+                            returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i);
+                        }
+                        else
+                        {
+                            // Convert the current element into an out parameter on the native signature
+                            // while keeping it at the return position in the managed signature.
+                            var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with
+                            {
+                                RefKind = RefKind.Out,
+                                RefKindSyntax = SyntaxKind.OutKeyword,
+                                ManagedIndex = TypePositionInfo.ReturnIndex,
+                                NativeIndex = symbol.Parameters.Length
+                            };
+                            returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut);
+                        }
+                        break;
                     }
-                    break;
                 }
-            }
 
-            signatureContext = signatureContext with
-            {
-                // Add the HRESULT return value in the native signature.
-                // This element does not have any influence on the managed signature, so don't assign a managed index.
-                ElementTypeInformation = returnSwappedSignatureElements.Add(
-                    new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo())
-                    {
-                        NativeIndex = TypePositionInfo.ReturnIndex
-                    })
-            };
+                signatureContext = signatureContext with
+                {
+                    // Add the HRESULT return value in the native signature.
+                    // This element does not have any influence on the managed signature, so don't assign a managed index.
+                    ElementTypeInformation = returnSwappedSignatureElements.Add(
+                        new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo())
+                        {
+                            NativeIndex = TypePositionInfo.ReturnIndex
+                        })
+                };
+            }
 
             var containingSyntaxContext = new ContainingSyntaxContext(syntax);
 
@@ -2,19 +2,21 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Collections.Generic;
 using System.Linq;
 using System.Reflection.Metadata;
 using System.Threading.Tasks;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.Operations;
 using Microsoft.CodeAnalysis.Testing;
+using Microsoft.Interop;
 using Xunit;
 
-using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.VtableIndexStubGenerator>;
+using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.CodeAnalysis.Testing.EmptySourceGeneratorProvider>;
 
 namespace ComInterfaceGenerator.Unit.Tests
 {
-    public class CallingConventionForwarding
+    public class TargetSignatureTests
     {
         [Fact]
         public async Task NoSpecifiedCallConvForwardsDefault()
@@ -32,7 +34,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 }
                 """;
 
-            await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
+            await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
             {
                 Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
                 Assert.Empty(signature.UnmanagedCallingConventionTypes);
@@ -56,7 +58,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 }
                 """;
 
-            await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+            await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
             {
                 Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
                 Assert.Equal(newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"), Assert.Single(signature.UnmanagedCallingConventionTypes), SymbolEqualityComparer.Default);
@@ -80,7 +82,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 }
             """;
 
-            await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
+            await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
             {
                 Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
                 Assert.Empty(signature.UnmanagedCallingConventionTypes);
@@ -105,7 +107,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 }
                 """;
 
-            await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
+            await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
             {
                 Assert.Equal(SignatureCallingConvention.CDecl, signature.CallingConvention);
                 Assert.Empty(signature.UnmanagedCallingConventionTypes);
@@ -130,7 +132,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 }
                 """;
 
-            await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+            await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
             {
                 Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
                 Assert.Equal(new[]
@@ -162,7 +164,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 }
                 """;
 
-            await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+            await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
             {
                 Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
                 Assert.Equal(new[]
@@ -176,9 +178,80 @@ namespace ComInterfaceGenerator.Unit.Tests
             });
         }
 
-        private static async Task VerifySourceGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+        [Fact]
+        public async Task ComInterfaceMethodFunctionPointerReturnsInt()
         {
-            CallingConventionForwardingTest test = new(interfaceName, methodName, signatureValidator)
+            string source = $$"""
+                using System.Runtime.CompilerServices;
+                using System.Runtime.InteropServices;
+                using System.Runtime.InteropServices.Marshalling;
+
+                [GeneratedComInterface]
+                [Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
+                partial interface IComInterface
+                {
+                    void Method();
+                }
+                """;
+
+            await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
+            {
+                Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType);
+            });
+        }
+
+        [Fact]
+        public async Task ComInterfaceMethodFunctionPointerReturnTypeChangedToOutParameter()
+        {
+            string source = $$"""
+                using System.Runtime.CompilerServices;
+                using System.Runtime.InteropServices;
+                using System.Runtime.InteropServices.Marshalling;
+
+                [GeneratedComInterface]
+                [Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
+                partial interface IComInterface
+                {
+                    long Method();
+                }
+                """;
+
+            await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
+            {
+                Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType);
+                Assert.Equal(2, signature.Parameters.Length);
+                Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), signature.Parameters[0].Type, SymbolEqualityComparer.Default);
+                Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Int64)), signature.Parameters[^1].Type, SymbolEqualityComparer.Default);
+            });
+        }
+
+        [Fact]
+        public async Task ComInterfaceMethodPreserveSigFunctionPointerReturnTypePreserved()
+        {
+            string source = $$"""
+                using System.Runtime.CompilerServices;
+                using System.Runtime.InteropServices;
+                using System.Runtime.InteropServices.Marshalling;
+
+                [GeneratedComInterface]
+                [Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
+                partial interface IComInterface
+                {
+                    [PreserveSig]
+                    long Method();
+                }
+                """;
+
+            await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
+            {
+                Assert.Equal(SpecialType.System_Int64, signature.ReturnType.SpecialType);
+                Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), Assert.Single(signature.Parameters).Type, SymbolEqualityComparer.Default);
+            });
+        }
+
+        private static async Task VerifyVirtualMethodIndexGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+        {
+            VirtualMethodIndexTargetSignatureTest test = new(interfaceName, methodName, signatureValidator)
             {
                 TestCode = source,
                 TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
@@ -186,14 +259,24 @@ namespace ComInterfaceGenerator.Unit.Tests
 
             await test.RunAsync();
         }
+        private static async Task VerifyComInterfaceGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+        {
+            ComInterfaceTargetSignatureTest test = new(interfaceName, methodName, signatureValidator)
+            {
+                TestCode = source,
+                TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
+            };
 
-        class CallingConventionForwardingTest : VerifyCS.Test
+            await test.RunAsync();
+        }
+
+        private abstract class TargetSignatureTestBase : VerifyCS.Test
         {
             private readonly Action<Compilation, IMethodSymbol> _signatureValidator;
             private readonly string _interfaceName;
             private readonly string _methodName;
 
-            public CallingConventionForwardingTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+            protected TargetSignatureTestBase(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
                 : base(referenceAncillaryInterop: true)
             {
                 _signatureValidator = signatureValidator;
@@ -205,12 +288,14 @@ namespace ComInterfaceGenerator.Unit.Tests
             {
                 _signatureValidator(compilation, FindFunctionPointerInvocationSignature(compilation));
             }
+
+            protected abstract INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface);
             private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compilation)
             {
                 INamedTypeSymbol? userDefinedInterface = compilation.Assembly.GetTypeByMetadataName(_interfaceName);
                 Assert.NotNull(userDefinedInterface);
 
-                INamedTypeSymbol generatedInterfaceImplementation = Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
+                INamedTypeSymbol generatedInterfaceImplementation = FindImplementationInterface(compilation, userDefinedInterface);
 
                 IMethodSymbol methodImplementation = Assert.Single(generatedInterfaceImplementation.GetMembers($"global::{_interfaceName}.{_methodName}").OfType<IMethodSymbol>());
 
@@ -223,5 +308,38 @@ namespace ComInterfaceGenerator.Unit.Tests
                 return Assert.Single(body.Descendants().OfType<IFunctionPointerInvocationOperation>()).GetFunctionPointerSignature();
             }
         }
+
+        private sealed class VirtualMethodIndexTargetSignatureTest : TargetSignatureTestBase
+        {
+            public VirtualMethodIndexTargetSignatureTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+                : base(interfaceName, methodName, signatureValidator)
+            {
+            }
+
+            protected override IEnumerable<Type> GetSourceGenerators() => new[] { typeof(VtableIndexStubGenerator) };
+
+            protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface) => Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
+        }
+
+        private sealed class ComInterfaceTargetSignatureTest : TargetSignatureTestBase
+        {
+            public ComInterfaceTargetSignatureTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator) : base(interfaceName, methodName, signatureValidator)
+            {
+            }
+            protected override IEnumerable<Type> GetSourceGenerators() => new[] { typeof(Microsoft.Interop.ComInterfaceGenerator) };
+
+            protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface)
+            {
+                INamedTypeSymbol? iUnknownDerivedAttributeType = compilation.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute`2");
+
+                Assert.NotNull(iUnknownDerivedAttributeType);
+
+                AttributeData iUnknownDerivedAttribute = Assert.Single(
+                    userDefinedInterface.GetAttributes(),
+                    attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, iUnknownDerivedAttributeType));
+
+                return (INamedTypeSymbol)iUnknownDerivedAttribute.AttributeClass!.TypeArguments[1];
+            }
+        }
     }
 }
index 167f000..1af721d 100644 (file)
@@ -19,7 +19,7 @@ using Microsoft.CodeAnalysis.Testing.Verifiers;
 namespace Microsoft.Interop.UnitTests.Verifiers
 {
     public static class CSharpSourceGeneratorVerifier<TSourceGenerator>
-        where TSourceGenerator : IIncrementalGenerator, new()
+        where TSourceGenerator : new()
     {
         public static DiagnosticResult Diagnostic(string diagnosticId)
             => new DiagnosticResult(diagnosticId, DiagnosticSeverity.Error);