using System.Diagnostics;
using System.IO;
using System.Linq;
+using System.Reflection;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
// Create the stub.
var signatureContext = SignatureContext.Create(symbol, DefaultMarshallingInfoParser.Create(environment, generatorDiagnostics, symbol, new InteropAttributeCompilationData(), generatedComAttribute), environment, typeof(VtableIndexStubGenerator).Assembly);
- // Search for the element information for the managed return value.
- // We need to transform it such that any return type is converted to an out parameter at the end of the parameter list.
- ImmutableArray<TypePositionInfo> returnSwappedSignatureElements = signatureContext.ElementTypeInformation;
- for (int i = 0; i < returnSwappedSignatureElements.Length; ++i)
+ if (!symbol.MethodImplementationFlags.HasFlag(MethodImplAttributes.PreserveSig))
{
- if (returnSwappedSignatureElements[i].IsManagedReturnPosition)
+ // Search for the element information for the managed return value.
+ // We need to transform it such that any return type is converted to an out parameter at the end of the parameter list.
+ ImmutableArray<TypePositionInfo> returnSwappedSignatureElements = signatureContext.ElementTypeInformation;
+ for (int i = 0; i < returnSwappedSignatureElements.Length; ++i)
{
- if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void)
+ if (returnSwappedSignatureElements[i].IsManagedReturnPosition)
{
- // Return type is void, just remove the element from the signature list.
- // We don't introduce an out parameter.
- returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i);
- }
- else
- {
- // Convert the current element into an out parameter on the native signature
- // while keeping it at the return position in the managed signature.
- var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with
+ if (returnSwappedSignatureElements[i].ManagedType == SpecialTypeInfo.Void)
{
- RefKind = RefKind.Out,
- RefKindSyntax = SyntaxKind.OutKeyword,
- ManagedIndex = TypePositionInfo.ReturnIndex,
- NativeIndex = symbol.Parameters.Length
- };
- returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut);
+ // Return type is void, just remove the element from the signature list.
+ // We don't introduce an out parameter.
+ returnSwappedSignatureElements = returnSwappedSignatureElements.RemoveAt(i);
+ }
+ else
+ {
+ // Convert the current element into an out parameter on the native signature
+ // while keeping it at the return position in the managed signature.
+ var managedSignatureAsNativeOut = returnSwappedSignatureElements[i] with
+ {
+ RefKind = RefKind.Out,
+ RefKindSyntax = SyntaxKind.OutKeyword,
+ ManagedIndex = TypePositionInfo.ReturnIndex,
+ NativeIndex = symbol.Parameters.Length
+ };
+ returnSwappedSignatureElements = returnSwappedSignatureElements.SetItem(i, managedSignatureAsNativeOut);
+ }
+ break;
}
- break;
}
- }
- signatureContext = signatureContext with
- {
- // Add the HRESULT return value in the native signature.
- // This element does not have any influence on the managed signature, so don't assign a managed index.
- ElementTypeInformation = returnSwappedSignatureElements.Add(
- new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo())
- {
- NativeIndex = TypePositionInfo.ReturnIndex
- })
- };
+ signatureContext = signatureContext with
+ {
+ // Add the HRESULT return value in the native signature.
+ // This element does not have any influence on the managed signature, so don't assign a managed index.
+ ElementTypeInformation = returnSwappedSignatureElements.Add(
+ new TypePositionInfo(SpecialTypeInfo.Int32, new ManagedHResultExceptionMarshallingInfo())
+ {
+ NativeIndex = TypePositionInfo.ReturnIndex
+ })
+ };
+ }
var containingSyntaxContext = new ContainingSyntaxContext(syntax);
// The .NET Foundation licenses this file to you under the MIT license.
using System;
+using System.Collections.Generic;
using System.Linq;
using System.Reflection.Metadata;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Testing;
+using Microsoft.Interop;
using Xunit;
-using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.VtableIndexStubGenerator>;
+using VerifyCS = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.CodeAnalysis.Testing.EmptySourceGeneratorProvider>;
namespace ComInterfaceGenerator.Unit.Tests
{
- public class CallingConventionForwarding
+ public class TargetSignatureTests
{
[Fact]
public async Task NoSpecifiedCallConvForwardsDefault()
}
""";
- await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
+ await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (compilation, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Empty(signature.UnmanagedCallingConventionTypes);
}
""";
- await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+ await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Equal(newComp.GetTypeByMetadataName("System.Runtime.CompilerServices.CallConvSuppressGCTransition"), Assert.Single(signature.UnmanagedCallingConventionTypes), SymbolEqualityComparer.Default);
}
""";
- await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
+ await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Empty(signature.UnmanagedCallingConventionTypes);
}
""";
- await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
+ await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (_, signature) =>
{
Assert.Equal(SignatureCallingConvention.CDecl, signature.CallingConvention);
Assert.Empty(signature.UnmanagedCallingConventionTypes);
}
""";
- await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+ await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Equal(new[]
}
""";
- await VerifySourceGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
+ await VerifyVirtualMethodIndexGeneratorAsync(source, "INativeAPI", "Method", (newComp, signature) =>
{
Assert.Equal(SignatureCallingConvention.Unmanaged, signature.CallingConvention);
Assert.Equal(new[]
});
}
- private static async Task VerifySourceGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+ [Fact]
+ public async Task ComInterfaceMethodFunctionPointerReturnsInt()
{
- CallingConventionForwardingTest test = new(interfaceName, methodName, signatureValidator)
+ string source = $$"""
+ using System.Runtime.CompilerServices;
+ using System.Runtime.InteropServices;
+ using System.Runtime.InteropServices.Marshalling;
+
+ [GeneratedComInterface]
+ [Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
+ partial interface IComInterface
+ {
+ void Method();
+ }
+ """;
+
+ await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
+ {
+ Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType);
+ });
+ }
+
+ [Fact]
+ public async Task ComInterfaceMethodFunctionPointerReturnTypeChangedToOutParameter()
+ {
+ string source = $$"""
+ using System.Runtime.CompilerServices;
+ using System.Runtime.InteropServices;
+ using System.Runtime.InteropServices.Marshalling;
+
+ [GeneratedComInterface]
+ [Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
+ partial interface IComInterface
+ {
+ long Method();
+ }
+ """;
+
+ await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
+ {
+ Assert.Equal(SpecialType.System_Int32, signature.ReturnType.SpecialType);
+ Assert.Equal(2, signature.Parameters.Length);
+ Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), signature.Parameters[0].Type, SymbolEqualityComparer.Default);
+ Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Int64)), signature.Parameters[^1].Type, SymbolEqualityComparer.Default);
+ });
+ }
+
+ [Fact]
+ public async Task ComInterfaceMethodPreserveSigFunctionPointerReturnTypePreserved()
+ {
+ string source = $$"""
+ using System.Runtime.CompilerServices;
+ using System.Runtime.InteropServices;
+ using System.Runtime.InteropServices.Marshalling;
+
+ [GeneratedComInterface]
+ [Guid("0A617667-4961-4F90-B74F-6DC368E98179")]
+ partial interface IComInterface
+ {
+ [PreserveSig]
+ long Method();
+ }
+ """;
+
+ await VerifyComInterfaceGeneratorAsync(source, "IComInterface", "Method", (newComp, signature) =>
+ {
+ Assert.Equal(SpecialType.System_Int64, signature.ReturnType.SpecialType);
+ Assert.Equal(newComp.CreatePointerTypeSymbol(newComp.GetSpecialType(SpecialType.System_Void)), Assert.Single(signature.Parameters).Type, SymbolEqualityComparer.Default);
+ });
+ }
+
+ private static async Task VerifyVirtualMethodIndexGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+ {
+ VirtualMethodIndexTargetSignatureTest test = new(interfaceName, methodName, signatureValidator)
{
TestCode = source,
TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
await test.RunAsync();
}
+ private static async Task VerifyComInterfaceGeneratorAsync(string source, string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+ {
+ ComInterfaceTargetSignatureTest test = new(interfaceName, methodName, signatureValidator)
+ {
+ TestCode = source,
+ TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck
+ };
- class CallingConventionForwardingTest : VerifyCS.Test
+ await test.RunAsync();
+ }
+
+ private abstract class TargetSignatureTestBase : VerifyCS.Test
{
private readonly Action<Compilation, IMethodSymbol> _signatureValidator;
private readonly string _interfaceName;
private readonly string _methodName;
- public CallingConventionForwardingTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+ protected TargetSignatureTestBase(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
: base(referenceAncillaryInterop: true)
{
_signatureValidator = signatureValidator;
{
_signatureValidator(compilation, FindFunctionPointerInvocationSignature(compilation));
}
+
+ protected abstract INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface);
private IMethodSymbol FindFunctionPointerInvocationSignature(Compilation compilation)
{
INamedTypeSymbol? userDefinedInterface = compilation.Assembly.GetTypeByMetadataName(_interfaceName);
Assert.NotNull(userDefinedInterface);
- INamedTypeSymbol generatedInterfaceImplementation = Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
+ INamedTypeSymbol generatedInterfaceImplementation = FindImplementationInterface(compilation, userDefinedInterface);
IMethodSymbol methodImplementation = Assert.Single(generatedInterfaceImplementation.GetMembers($"global::{_interfaceName}.{_methodName}").OfType<IMethodSymbol>());
return Assert.Single(body.Descendants().OfType<IFunctionPointerInvocationOperation>()).GetFunctionPointerSignature();
}
}
+
+ private sealed class VirtualMethodIndexTargetSignatureTest : TargetSignatureTestBase
+ {
+ public VirtualMethodIndexTargetSignatureTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator)
+ : base(interfaceName, methodName, signatureValidator)
+ {
+ }
+
+ protected override IEnumerable<Type> GetSourceGenerators() => new[] { typeof(VtableIndexStubGenerator) };
+
+ protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface) => Assert.Single(userDefinedInterface.GetTypeMembers("Native"));
+ }
+
+ private sealed class ComInterfaceTargetSignatureTest : TargetSignatureTestBase
+ {
+ public ComInterfaceTargetSignatureTest(string interfaceName, string methodName, Action<Compilation, IMethodSymbol> signatureValidator) : base(interfaceName, methodName, signatureValidator)
+ {
+ }
+ protected override IEnumerable<Type> GetSourceGenerators() => new[] { typeof(Microsoft.Interop.ComInterfaceGenerator) };
+
+ protected override INamedTypeSymbol FindImplementationInterface(Compilation compilation, INamedTypeSymbol userDefinedInterface)
+ {
+ INamedTypeSymbol? iUnknownDerivedAttributeType = compilation.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute`2");
+
+ Assert.NotNull(iUnknownDerivedAttributeType);
+
+ AttributeData iUnknownDerivedAttribute = Assert.Single(
+ userDefinedInterface.GetAttributes(),
+ attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, iUnknownDerivedAttributeType));
+
+ return (INamedTypeSymbol)iUnknownDerivedAttribute.AttributeClass!.TypeArguments[1];
+ }
+ }
}
}