Add marshallers for GeneratedComInterface-based types (#86177)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Sat, 20 May 2023 01:59:29 +0000 (18:59 -0700)
committerGitHub <noreply@github.com>
Sat, 20 May 2023 01:59:29 +0000 (18:59 -0700)
21 files changed:
docs/design/libraries/LibraryImportGenerator/Compatibility.md
src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/Marshalling/NativeMarshallingAttribute.cs
src/libraries/System.Runtime.InteropServices/gen/Common/DefaultMarshallingInfoParser.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ComInterfaceMarshallingInfoProvider.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/MarshalAsAttributeParser.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/StringMarshallingInfoProvider.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs
src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs
src/libraries/System.Runtime.InteropServices/src/System.Runtime.InteropServices.csproj
src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/ComInterfaceMarshaller.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/StrategyBasedComWrappers.cs
src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetAndSetIntTests.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/CodeSnippets.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/Compiles.cs
src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/CodeSnippets.cs
src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/Compiles.cs
src/libraries/System.Runtime.InteropServices/tests/LibraryImportGenerator.UnitTests/ConvertToLibraryImportAnalyzerTests.cs
src/libraries/System.Runtime/ref/System.Runtime.cs
src/libraries/apicompat/ApiCompatBaseline.NetCoreAppLatestStable.xml

index cd2e31f..63d7053 100644 (file)
@@ -8,6 +8,12 @@ Documentation on compatibility guidance and the current state. The version headi
 
 Due to trimming issues with NativeAOT's implementation of `Activator.CreateInstance`, we have decided to change our recommendation of providing a public parameterless constructor for `ref`, `out`, and return scenarios to a requirement. We already required a parameterless constructor of some visibility, so changing to a requirement matches our design principles of taking breaking changes to make interop more understandable and enforce more of our best practices instead of going out of our way to provide backward compatibility at increasing costs.
 
+### `UnmanagedType.Interface`
+
+Support for `MarshalAs(UnmanagedType.Interface)` is added to the interop source generators. `UnmanagedType.Interface` will marshal a parameter/return value of a type `T` to a COM interface pointer the `ComInterfaceMarshaller<T>` type. It will not support marshalling through the built-in COM interop subsystem.
+
+The `ComInterfaceMarshaller<T>` type has the following general behavior: An unmanaged pointer is marshalled to a managed object through `GetOrCreateObjectForComInstance` on a shared `StrategyBasedComWrappers` instance. A managed object is marshalled to an unmanaged pointer through that same shared instance with the `GetOrCreateComInterfaceForObject` method and then calling `QueryInterface` on the returned `IUnknown*` to get the pointer for the unmanaged interface with the IID from the managed type as defined by our default interface details strategy (or the IID of `IUnknown` if the managed type has no IID).
+
 ## Version 2 (.NET 7 Release)
 
 The focus of version 2 is to support all repos that make up the .NET Product, including ASP.NET Core and Windows Forms, as well as all packages in dotnet/runtime.
index 96bdeeb..8e39ad6 100644 (file)
@@ -12,7 +12,7 @@ namespace System.Runtime.InteropServices.Marshalling
     /// </remarks>
     /// <seealso cref="LibraryImportAttribute" />
     /// <seealso cref="CustomMarshallerAttribute" />
-    [AttributeUsage(AttributeTargets.Struct | AttributeTargets.Class | AttributeTargets.Enum | AttributeTargets.Delegate)]
+    [AttributeUsage(AttributeTargets.Struct | AttributeTargets.Class | AttributeTargets.Enum | AttributeTargets.Interface | AttributeTargets.Delegate)]
     public sealed class NativeMarshallingAttribute : Attribute
     {
         /// <summary>
index 2150eac..7149fc4 100644 (file)
@@ -38,14 +38,15 @@ namespace Microsoft.Interop
                 diagnostics,
                 new MethodSignatureElementInfoProvider(env.Compilation, diagnostics, method, useSiteAttributeParsers),
                 useSiteAttributeParsers,
-            ImmutableArray.Create<IMarshallingInfoAttributeParser>(
+                ImmutableArray.Create<IMarshallingInfoAttributeParser>(
                     new MarshalAsAttributeParser(env.Compilation, diagnostics, defaultInfo),
                     new MarshalUsingAttributeParser(env.Compilation, diagnostics),
-                    new NativeMarshallingAttributeParser(env.Compilation, diagnostics)),
+                    new NativeMarshallingAttributeParser(env.Compilation, diagnostics),
+                    new ComInterfaceMarshallingInfoProvider(env.Compilation)),
                 ImmutableArray.Create<ITypeBasedMarshallingInfoProvider>(
                     new SafeHandleMarshallingInfoProvider(env.Compilation, method.ContainingType),
                     new ArrayMarshallingInfoProvider(env.Compilation),
-            new CharMarshallingInfoProvider(defaultInfo),
+                    new CharMarshallingInfoProvider(defaultInfo),
                     new StringMarshallingInfoProvider(env.Compilation, diagnostics, unparsedAttributeData, defaultInfo),
                     new BooleanMarshallingInfoProvider(),
                     new BlittableTypeMarshallingInfoProvider(env.Compilation)));
diff --git a/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ComInterfaceMarshallingInfoProvider.cs b/src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/ComInterfaceMarshallingInfoProvider.cs
new file mode 100644 (file)
index 0000000..2386acc
--- /dev/null
@@ -0,0 +1,54 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+
+namespace Microsoft.Interop
+{
+    /// <summary>
+    /// This class supports generating marshalling info for types with the <c>System.Runtime.InteropServices.Marshalling.GeneratedComInterfaceAttribute</c> attribute.
+    /// </summary>
+    public class ComInterfaceMarshallingInfoProvider : IMarshallingInfoAttributeParser
+    {
+        private readonly Compilation _compilation;
+
+        public ComInterfaceMarshallingInfoProvider(Compilation compilation)
+        {
+            _compilation = compilation;
+        }
+
+        public bool CanParseAttributeType(INamedTypeSymbol attributeType) => attributeType.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute;
+
+        public MarshallingInfo? ParseAttribute(AttributeData attributeData, ITypeSymbol type, int indirectionDepth, UseSiteAttributeProvider useSiteAttributes, GetMarshallingInfoCallback marshallingInfoCallback)
+        {
+            return CreateComInterfaceMarshallingInfo(_compilation, type);
+        }
+
+        public static MarshallingInfo CreateComInterfaceMarshallingInfo(
+            Compilation compilation,
+            ITypeSymbol interfaceType)
+        {
+            INamedTypeSymbol? comInterfaceMarshaller = compilation.GetTypeByMetadataName(TypeNames.System_Runtime_InteropServices_Marshalling_ComInterfaceMarshaller_Metadata);
+            if (comInterfaceMarshaller is null)
+                return new MissingSupportMarshallingInfo();
+
+            comInterfaceMarshaller = comInterfaceMarshaller.Construct(interfaceType);
+
+            if (ManualTypeMarshallingHelper.HasEntryPointMarshallerAttribute(comInterfaceMarshaller))
+            {
+                if (ManualTypeMarshallingHelper.TryGetValueMarshallersFromEntryType(comInterfaceMarshaller, interfaceType, compilation, out CustomTypeMarshallers? marshallers))
+                {
+                    return new NativeMarshallingAttributeInfo(
+                        EntryPointType: ManagedTypeInfo.CreateTypeInfoForTypeSymbol(comInterfaceMarshaller),
+                        Marshallers: marshallers.Value);
+                }
+            }
+
+            return new MissingSupportMarshallingInfo();
+        }
+    }
+}
index 4294f11..aa9b253 100644 (file)
@@ -4,7 +4,6 @@
 using System;
 using System.Collections.Generic;
 using System.Collections.Immutable;
-using System.Diagnostics;
 using System.Runtime.InteropServices;
 using Microsoft.CodeAnalysis;
 
@@ -98,6 +97,18 @@ namespace Microsoft.Interop
                 }
             }
 
+            // We'll support the UnmanagedType.Interface option, but we'll explicitly
+            // exclude ComImport types as they will not work as expected
+            // unless they are migrated to [GeneratedComInterface].
+            if (unmanagedType == UnmanagedType.Interface)
+            {
+                if (type is INamedTypeSymbol { IsComImport: true })
+                {
+                    return new MarshalAsInfo(unmanagedType, _defaultInfo.CharEncoding);
+                }
+                return ComInterfaceMarshallingInfoProvider.CreateComInterfaceMarshallingInfo(_compilation, type);
+            }
+
             if (isArrayType)
             {
                 if (type is not IArrayTypeSymbol { ElementType: ITypeSymbol elementType })
index e8c9aee..8eeb1b2 100644 (file)
@@ -8,7 +8,6 @@ using Microsoft.CodeAnalysis;
 
 namespace Microsoft.Interop
 {
-
     /// <summary>
     /// This class supports generating marshalling info for the <see cref="string"/> type.
     /// This includes support for the <c>System.Runtime.InteropServices.StringMarshalling</c> enum.
index 612a881..f91c947 100644 (file)
@@ -34,6 +34,8 @@ namespace Microsoft.Interop
 
         public const string UnmanagedCallersOnlyAttribute = "System.Runtime.InteropServices.UnmanagedCallersOnlyAttribute";
 
+        public const string System_Runtime_InteropServices_ComImportAttribute = "System.Runtime.InteropServices.ComImportAttribute";
+
         public const string VirtualMethodIndexAttribute = "System.Runtime.InteropServices.Marshalling.VirtualMethodIndexAttribute";
 
         public const string IUnmanagedVirtualMethodTableProvider = "System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider";
@@ -142,5 +144,7 @@ namespace Microsoft.Interop
         public const string UnreachableException = "System.Diagnostics.UnreachableException";
 
         public const string System_Runtime_InteropServices_Marshalling_SafeHandleMarshaller_Metadata = "System.Runtime.InteropServices.Marshalling.SafeHandleMarshaller`1";
+
+        public const string System_Runtime_InteropServices_Marshalling_ComInterfaceMarshaller_Metadata = "System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller`1";
     }
 }
index c218273..87d4310 100644 (file)
@@ -332,6 +332,17 @@ namespace System.Runtime.InteropServices.Marshalling
         public ComExposedClassAttribute() { }
         public unsafe System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) { throw null; }
     }
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("android")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("ios")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("tvos")]
+    [System.CLSCompliantAttribute(false)]
+    [System.Runtime.InteropServices.Marshalling.CustomMarshallerAttribute(typeof(System.Runtime.InteropServices.Marshalling.CustomMarshallerAttribute.GenericPlaceholder), System.Runtime.InteropServices.Marshalling.MarshalMode.Default, typeof(System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller<>))]
+    public static unsafe class ComInterfaceMarshaller<T>
+    {
+        public static void* ConvertToUnmanaged(T? managed) { throw null; }
+        public static T? ConvertToManaged(void* unmanaged) { throw null; }
+    }
     public sealed partial class ComObject : System.Runtime.InteropServices.IDynamicInterfaceCastable, System.Runtime.InteropServices.Marshalling.IUnmanagedVirtualMethodTableProvider
     {
         internal ComObject() { }
@@ -453,6 +464,17 @@ namespace System.Runtime.InteropServices.Marshalling
         protected virtual System.Runtime.InteropServices.Marshalling.IIUnknownStrategy GetOrCreateIUnknownStrategy() { throw null; }
         protected sealed override void ReleaseObjects(System.Collections.IEnumerable objects) { }
     }
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("android")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("ios")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("tvos")]
+    [System.CLSCompliantAttribute(false)]
+    [System.Runtime.InteropServices.Marshalling.CustomMarshallerAttribute(typeof(System.Runtime.InteropServices.Marshalling.CustomMarshallerAttribute.GenericPlaceholder), System.Runtime.InteropServices.Marshalling.MarshalMode.Default, typeof(System.Runtime.InteropServices.Marshalling.UniqueComInterfaceMarshaller<>))]
+    public static unsafe class UniqueComInterfaceMarshaller<T>
+    {
+        public static void* ConvertToUnmanaged(T? managed) { throw null; }
+        public static T? ConvertToManaged(void* unmanaged) { throw null; }
+    }
     [System.CLSCompliantAttribute(false)]
     public readonly partial struct VirtualMethodTableInfo
     {
@@ -690,10 +712,10 @@ namespace System.Runtime.InteropServices
         public ComSourceInterfacesAttribute(System.Type sourceInterface1, System.Type sourceInterface2, System.Type sourceInterface3, System.Type sourceInterface4) { }
         public string Value { get { throw null; } }
     }
-    [System.Runtime.Versioning.UnsupportedOSPlatform("android")]
-    [System.Runtime.Versioning.UnsupportedOSPlatform("browser")]
-    [System.Runtime.Versioning.UnsupportedOSPlatform("ios")]
-    [System.Runtime.Versioning.UnsupportedOSPlatform("tvos")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("android")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("browser")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("ios")]
+    [System.Runtime.Versioning.UnsupportedOSPlatformAttribute("tvos")]
     [System.CLSCompliantAttribute(false)]
     public abstract class ComWrappers
     {
index 7d91ce1..c92594f 100644 (file)
@@ -32,6 +32,8 @@
     <Compile Include="System\Runtime\InteropServices\ImportedFromTypeLibAttribute.cs" />
     <Compile Include="System\Runtime\InteropServices\ManagedToNativeComInteropStubAttribute.cs" />
     <Compile Include="System\Runtime\InteropServices\Marshalling\ComExposedClassAttribute.cs" />
+    <Compile Include="System\Runtime\InteropServices\Marshalling\UniqueComInterfaceMarshaller.cs" />
+    <Compile Include="System\Runtime\InteropServices\Marshalling\ComInterfaceMarshaller.cs" />
     <Compile Include="System\Runtime\InteropServices\Marshalling\ComObject.cs" />
     <Compile Include="System\Runtime\InteropServices\Marshalling\DefaultCaching.cs" />
     <Compile Include="System\Runtime\InteropServices\Marshalling\DefaultIUnknownInterfaceDetailsStrategy.cs" />
diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/ComInterfaceMarshaller.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/ComInterfaceMarshaller.cs
new file mode 100644 (file)
index 0000000..382540a
--- /dev/null
@@ -0,0 +1,64 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.Versioning;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    /// <summary>
+    /// COM interface marshaller using a StrategyBasedComWrappers instance
+    /// </summary>
+    /// <remarks>
+    /// This marshaller will always pass the <see cref="CreateObjectFlags.Unwrap"/> flag
+    /// to <see cref="ComWrappers.GetOrCreateObjectForComInstance(IntPtr, CreateObjectFlags)"/>.
+    /// </remarks>
+    [UnsupportedOSPlatform("android")]
+    [UnsupportedOSPlatform("browser")]
+    [UnsupportedOSPlatform("ios")]
+    [UnsupportedOSPlatform("tvos")]
+    [CLSCompliant(false)]
+    [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder), MarshalMode.Default, typeof(ComInterfaceMarshaller<>))]
+    public static unsafe class ComInterfaceMarshaller<T>
+    {
+        private static readonly Guid? TargetInterfaceIID = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(T).TypeHandle)?.Iid;
+
+        public static void* ConvertToUnmanaged(T? managed)
+        {
+            if (managed == null)
+            {
+                return null;
+            }
+            if (!ComWrappers.TryGetComInstance(managed, out nint unknown))
+            {
+                unknown = StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateComInterfaceForObject(managed, CreateComInterfaceFlags.None);
+            }
+            return CastIUnknownToInterfaceType(unknown);
+        }
+
+        public static T? ConvertToManaged(void* unmanaged)
+        {
+            if (unmanaged == null)
+            {
+                return default;
+            }
+            return (T)StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateObjectForComInstance((nint)unmanaged, CreateObjectFlags.Unwrap);
+        }
+
+        internal static void* CastIUnknownToInterfaceType(nint unknown)
+        {
+            if (TargetInterfaceIID is null)
+            {
+                // If the managed type isn't a GeneratedComInterface-attributed type, we'll marshal to an IUnknown*.
+                return (void*)unknown;
+            }
+            Guid iid = TargetInterfaceIID.Value;
+            if (Marshal.QueryInterface(unknown, ref iid, out nint interfacePointer) != 0)
+            {
+                Marshal.Release(unknown);
+                throw new InvalidCastException($"Unable to cast the provided managed object to a COM interface with ID '{iid:B}'");
+            }
+            Marshal.Release(unknown);
+            return (void*)interfacePointer;
+        }
+    }
+}
index d8074e1..e5ba351 100644 (file)
@@ -9,6 +9,8 @@ namespace System.Runtime.InteropServices.Marshalling
     [CLSCompliant(false)]
     public class StrategyBasedComWrappers : ComWrappers
     {
+        internal static StrategyBasedComWrappers DefaultMarshallingInstance { get; } = new();
+
         public static IIUnknownInterfaceDetailsStrategy DefaultIUnknownInterfaceDetailsStrategy { get; } = Marshalling.DefaultIUnknownInterfaceDetailsStrategy.Instance;
 
         public static IIUnknownStrategy DefaultIUnknownStrategy { get; } = FreeThreadedStrategy.Instance;
diff --git a/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs b/src/libraries/System.Runtime.InteropServices/src/System/Runtime/InteropServices/Marshalling/UniqueComInterfaceMarshaller.cs
new file mode 100644 (file)
index 0000000..1fdc5ea
--- /dev/null
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Runtime.Versioning;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    /// <summary>
+    /// COM interface marshaller using a StrategyBasedComWrappers instance
+    /// that will only create unique native object wrappers (RCW).
+    /// </summary>
+    /// <remarks>
+    /// This marshaller will always pass the <see cref="CreateObjectFlags.Unwrap"/> and <see cref="CreateObjectFlags.UniqueInstance"/> flags
+    /// to <see cref="ComWrappers.GetOrCreateObjectForComInstance(IntPtr, CreateObjectFlags)"/>.
+    /// </remarks>
+    [UnsupportedOSPlatform("android")]
+    [UnsupportedOSPlatform("browser")]
+    [UnsupportedOSPlatform("ios")]
+    [UnsupportedOSPlatform("tvos")]
+    [CLSCompliant(false)]
+    [CustomMarshaller(typeof(CustomMarshallerAttribute.GenericPlaceholder), MarshalMode.Default, typeof(UniqueComInterfaceMarshaller<>))]
+    public static unsafe class UniqueComInterfaceMarshaller<T>
+    {
+        public static void* ConvertToUnmanaged(T? managed)
+        {
+            if (managed == null)
+            {
+                return null;
+            }
+            if (!ComWrappers.TryGetComInstance(managed, out nint unknown))
+            {
+                unknown = StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateComInterfaceForObject(managed, CreateComInterfaceFlags.None);
+            }
+            return ComInterfaceMarshaller<T>.CastIUnknownToInterfaceType(unknown);
+        }
+
+        public static T? ConvertToManaged(void* unmanaged)
+        {
+            if (unmanaged == null)
+            {
+                return default;
+            }
+            return (T)StrategyBasedComWrappers.DefaultMarshallingInstance.GetOrCreateObjectForComInstance((nint)unmanaged, CreateObjectFlags.Unwrap | CreateObjectFlags.UniqueInstance);
+        }
+    }
+}
index 737a9f9..b8dbf34 100644 (file)
@@ -16,6 +16,12 @@ namespace ComInterfaceGenerator.Tests
 
         [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object_data")]
         public static partial int GetComObjectData(void* obj);
+
+        [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_com_object_data")]
+        public static partial void SetComObjectData(IGetAndSetInt obj, int data);
+
+        [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object_data")]
+        public static partial int GetComObjectData(IGetAndSetInt obj);
     }
 
     [GeneratedComClass]
@@ -62,31 +68,23 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
-        public void CallsToComInterfaceWriteChangesToManagedObject()
+        public void CallsToComInterfaceWithMarshallerWriteChangesToManagedObject()
         {
             ManagedObjectExposedToCom obj = new();
-            StrategyBasedComWrappers wrappers = new();
-            void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
-            Assert.NotEqual(0, (nint)ptr);
             obj.Data = 3;
             Assert.Equal(3, obj.Data);
-            NativeExportsNE.SetComObjectData(ptr, 42);
+            NativeExportsNE.SetComObjectData(obj, 42);
             Assert.Equal(42, obj.Data);
-            Marshal.Release((nint)ptr);
         }
 
         [Fact]
-        public void CallsToComInterfaceReadChangesFromManagedObject()
+        public void CallsToComInterfaceWithMarshallerReadChangesFromManagedObject()
         {
             ManagedObjectExposedToCom obj = new();
-            StrategyBasedComWrappers wrappers = new();
-            void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
-            Assert.NotEqual(0, (nint)ptr);
             obj.Data = 3;
             Assert.Equal(3, obj.Data);
             obj.Data = 12;
-            Assert.Equal(obj.Data, NativeExportsNE.GetComObjectData(ptr));
-            Marshal.Release((nint)ptr);
+            Assert.Equal(obj.Data, NativeExportsNE.GetComObjectData(obj));
         }
     }
 }
index 322c459..ac1c3b5 100644 (file)
@@ -20,6 +20,13 @@ namespace ComInterfaceGenerator.Tests
         [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_get_and_set_int")]
         public static partial void* NewNativeObject();
 
+        [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_get_and_set_int")]
+        internal static partial IGetAndSetInt NewNativeObjectWithMarshaller();
+
+        [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_get_and_set_int")]
+        [return:MarshalUsing(typeof(UniqueComInterfaceMarshaller<IGetAndSetInt>))]
+        internal static partial IGetAndSetInt NewNativeObjectWithUniqueMarshaller();
+
         [Fact]
         public unsafe void CallRcwFromGeneratedComInterface()
         {
@@ -32,5 +39,27 @@ namespace ComInterfaceGenerator.Tests
             intObj.SetInt(2);
             Assert.Equal(2, intObj.GetInt());
         }
+
+        [Fact]
+        public unsafe void CallRcwFromGeneratedComInterfaceConstructedByMarshaller()
+        {
+            var intObj = NewNativeObjectWithMarshaller(); // new_native_object
+
+            Assert.Equal(0, intObj.GetInt());
+            intObj.SetInt(2);
+            Assert.Equal(2, intObj.GetInt());
+        }
+
+        [Fact]
+        public unsafe void CallRcwFromGeneratedComInterfaceConstructedByUniqueMarshaller()
+        {
+            var intObj = NewNativeObjectWithUniqueMarshaller(); // new_native_object
+            var intObj2 = NewNativeObjectWithUniqueMarshaller(); // new_native_object
+            Assert.NotSame(intObj, intObj2);
+
+            Assert.Equal(0, intObj.GetInt());
+            intObj.SetInt(2);
+            Assert.Equal(2, intObj.GetInt());
+        }
     }
 }
index 9a1ab29..38f9f9f 100644 (file)
@@ -53,7 +53,7 @@ namespace ComInterfaceGenerator.Unit.Tests
         public string SpecifiedMethodIndexNoExplicitParameters => $$"""
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -65,17 +65,17 @@ namespace ComInterfaceGenerator.Unit.Tests
             """;
 
         public string SpecifiedMethodIndexNoExplicitParametersNoImplicitThis => $$"""
-            
+
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
             {
                 {{VirtualMethodIndex(0, ImplicitThisParameter: false)}}
                 void Method();
-            
+
             }
             {{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
             """;
@@ -84,29 +84,29 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.CompilerServices;
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
             {
-            
+
                 {{UnmanagedCallConv(CallConvs: new[] { typeof(CallConvCdecl) })}}
                 {{VirtualMethodIndex(0)}}
                 void Method();
                 {{UnmanagedCallConv(CallConvs: new[] { typeof(CallConvCdecl), typeof(CallConvMemberFunction) })}}
                 {{VirtualMethodIndex(1)}}
                 void Method1();
-            
+
                 [SuppressGCTransition]
                 {{UnmanagedCallConv(CallConvs: new[] { typeof(CallConvCdecl), typeof(CallConvMemberFunction) })}}
                 {{VirtualMethodIndex(2)}}
                 void Method2();
-            
+
                 [SuppressGCTransition]
                 {{UnmanagedCallConv()}}
                 {{VirtualMethodIndex(3)}}
                 void Method3();
-            
+
                 [SuppressGCTransition]
                 {{VirtualMethodIndex(4)}}
                 void Method4();
@@ -118,9 +118,9 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
             {{preDeclaration}}
-            
+
             [assembly:DisableRuntimeMarshalling]
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -136,9 +136,9 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
             {{preDeclaration}}
-            
+
             [assembly:DisableRuntimeMarshalling]
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -146,7 +146,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {{VirtualMethodIndex(0, Direction: MarshalDirection.ManagedToUnmanaged)}}
                 {{typeName}} Method({{typeName}} value, in {{typeName}} inValue, ref {{typeName}} refValue, out {{typeName}} outValue);
             }
-            
+
             {{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
             """;
         public string BasicParametersAndModifiers<T>() => BasicParametersAndModifiers(typeof(T).FullName!);
@@ -155,9 +155,9 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
             {{preDeclaration}}
-            
+
             [assembly:DisableRuntimeMarshalling]
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -165,7 +165,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {{VirtualMethodIndex(0)}}
                 {{typeName}} Method({{typeName}} value, in {{typeName}} inValue, out {{typeName}} outValue);
             }
-            
+
             {{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
             """;
 
@@ -173,7 +173,7 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.CompilerServices;
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -191,7 +191,7 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
             [assembly:DisableRuntimeMarshalling]
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -208,7 +208,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                     [MarshalUsing(CountElementName = "pOutSize")] out {{collectionType}} pOut,
                     out int pOutSize);
             }
-            
+
             {{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
             """;
 
@@ -217,7 +217,7 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
             {{preDeclaration}}
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -225,7 +225,7 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {{VirtualMethodIndex(0, ExceptionMarshalling: ExceptionMarshalling.Com)}}
                 {{typeName}} Method();
             }
-            
+
             {{_attributeProvider.AdditionalUserRequiredInterfaces("INativeAPI")}}
             """;
 
@@ -234,7 +234,7 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
             {{preDeclaration}}
-            
+
             {{UnmanagedObjectUnwrapper(typeof(UnmanagedObjectUnwrapper.TestUnwrapper))}}
             {{GeneratedComInterface}}
             partial interface INativeAPI
@@ -249,7 +249,7 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.CompilerServices;
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
-            
+
             {{GeneratedComInterface}}
             partial interface IComInterface
             {
@@ -265,7 +265,7 @@ namespace ComInterfaceGenerator.Unit.Tests
             using System.Runtime.CompilerServices;
             using System.Runtime.InteropServices;
             using System.Runtime.InteropServices.Marshalling;
-            
+
             {{GeneratedComInterface}}
             partial interface IComInterface
             {
@@ -282,6 +282,15 @@ namespace ComInterfaceGenerator.Unit.Tests
                 void Method2();
             }
             """;
+
+        public string ComInterfaceParameters => BasicParametersAndModifiers("IComInterface2") + $$"""
+            {{GeneratedComInterface}}
+            partial interface IComInterface2
+            {
+                void Method2();
+            }
+            """;
+
         public class ManagedToUnmanaged : IVirtualMethodIndexSignatureProvider
         {
             public MarshalDirection Direction => MarshalDirection.ManagedToUnmanaged;
index 89c1e84..c866713 100644 (file)
@@ -4,16 +4,13 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
-using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading.Tasks;
-using Microsoft.CodeAnalysis;
 using Microsoft.Interop.UnitTests;
 using Xunit;
 
 using VerifyVTableGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.VtableIndexStubGenerator>;
 using VerifyComInterfaceGenerator = Microsoft.Interop.UnitTests.Verifiers.CSharpSourceGeneratorVerifier<Microsoft.Interop.ComInterfaceGenerator>;
-using Microsoft.CodeAnalysis.Testing;
 
 namespace ComInterfaceGenerator.Unit.Tests
 {
@@ -327,7 +324,6 @@ namespace ComInterfaceGenerator.Unit.Tests
         [MemberData(nameof(UnmanagedToManagedCodeSnippetsToCompile), GeneratorKind.VTableIndexStubGenerator)]
         [MemberData(nameof(CustomCollectionsManagedToUnmanaged), GeneratorKind.VTableIndexStubGenerator)]
         [MemberData(nameof(CustomCollections), GeneratorKind.VTableIndexStubGenerator)]
-        [MemberData(nameof(CustomCollections), GeneratorKind.VTableIndexStubGenerator)]
         public async Task ValidateVTableIndexSnippets(string id, string source)
         {
             _ = id;
@@ -338,6 +334,7 @@ namespace ComInterfaceGenerator.Unit.Tests
         {
             CodeSnippets codeSnippets = new(new GeneratedComInterfaceAttributeProvider());
             yield return new object[] { ID(), codeSnippets.DerivedComInterfaceType };
+            yield return new object[] { ID(), codeSnippets.ComInterfaceParameters };
         }
 
         [Theory]
index 0eb9830..488be39 100644 (file)
@@ -426,7 +426,7 @@ namespace LibraryImportGenerator.UnitTests
                 static class Marshaller
                 {
                     public static nint ConvertToUnmanaged({{typeName}} s) => default;
-                
+
                     public static {{typeName}} ConvertToManaged(nint i) => default;
                 }
                 """;
@@ -754,13 +754,21 @@ namespace LibraryImportGenerator.UnitTests
             class MySafeHandle : SafeHandle
             {
                 {{(privateCtor ? "private" : "public")}} MySafeHandle() : base(System.IntPtr.Zero, true) { }
-            
+
                 public override bool IsInvalid => handle == System.IntPtr.Zero;
-            
+
                 protected override bool ReleaseHandle() => true;
             }
             """;
 
+        public static string GeneratedComInterface => BasicParametersAndModifiers("MyInterfaceType", "using System.Runtime.InteropServices.Marshalling;") + """
+            [GeneratedComInterface]
+            interface MyInterfaceType
+            {
+                void Method();
+            }
+            """;
+
         public static string PreprocessorIfAroundFullFunctionDefinition(string define) =>
             $$"""
             partial class Test
index 6a932ac..e231c36 100644 (file)
@@ -124,11 +124,20 @@ namespace LibraryImportGenerator.UnitTests
             yield return new[] { ID(), CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPUTF8Str) };
             yield return new[] { ID(), CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPStr) };
             yield return new[] { ID(), CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.BStr) };
+            yield return new[] { ID(), CodeSnippets.MarshalAsParametersAndModifiers<object>(UnmanagedType.Interface) };
+            // TODO: Do we want to limit support of UnmanagedType.Interface to a subset of types?
+            // TODO: Should we block delegate types as they use to have special COM interface marshalling that we have since
+            // blocked? Blocking it would help .NET Framework->.NET migration as there wouldn't be a silent behavior change.
+            yield return new[] { ID(), CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.Interface) };
+            yield return new[] { ID(), CodeSnippets.MarshalAsParametersAndModifiers<Action>(UnmanagedType.Interface) };
+
+            // MarshalAs with array element UnmanagedType
             yield return new[] { ID(), CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPWStr) };
             yield return new[] { ID(), CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPUTF8Str) };
             yield return new[] { ID(), CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPStr) };
             yield return new[] { ID(), CodeSnippets.MarshalAsArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.BStr) };
 
+
             // [In, Out] attributes
             // By value non-blittable array
             yield return new[] { ID(), CodeSnippets.ByValueParameterWithModifier("S[]", "Out")
@@ -236,6 +245,9 @@ namespace LibraryImportGenerator.UnitTests
             yield return new[] { ID(), CodeSnippets.MaybeBlittableGenericTypeParametersAndModifiers<IntPtr>() };
             yield return new[] { ID(), CodeSnippets.MaybeBlittableGenericTypeParametersAndModifiers<UIntPtr>() };
             yield return new[] { ID(), CodeSnippets.GenericsStress };
+
+            // Type-level interop generator trigger attributes
+            yield return new[] { ID(), CodeSnippets.GeneratedComInterface };
         }
 
         public static IEnumerable<object[]> CustomCollections()
index d2bd400..3288160 100644 (file)
@@ -36,7 +36,7 @@ namespace LibraryImportGenerator.UnitTests
             new object[] { typeof(bool*) },
             new object[] { typeof(char*) },
             // See issue https://github.com/dotnet/runtime/issues/71891
-            // new object[] { typeof(delegate* <void>) }, 
+            // new object[] { typeof(delegate* <void>) },
             new object[] { typeof(IntPtr) },
             new object[] { typeof(ConsoleKey) }, // enum
         };
@@ -77,10 +77,10 @@ namespace LibraryImportGenerator.UnitTests
                 {
                     [DllImport("DoesNotExist")]
                     public static extern void {|#0:Method_In|}(in {{typeName}} p);
-                
+
                     [DllImport("DoesNotExist")]
                     public static extern void {|#1:Method_Out|}(out {{typeName}} p);
-                
+
                     [DllImport("DoesNotExist")]
                     public static extern void {|#2:Method_Ref|}(ref {{typeName}} p);
                 }
@@ -132,7 +132,6 @@ namespace LibraryImportGenerator.UnitTests
         }
 
         [Theory]
-        [InlineData(UnmanagedType.Interface)]
         [InlineData(UnmanagedType.IDispatch)]
         [InlineData(UnmanagedType.IInspectable)]
         [InlineData(UnmanagedType.IUnknown)]
@@ -145,7 +144,7 @@ namespace LibraryImportGenerator.UnitTests
                 {
                     [DllImport("DoesNotExist")]
                     public static extern void Method_Parameter([MarshalAs(UnmanagedType.{{unmanagedType}}, MarshalType = "DNE")]int p);
-                
+
                     [DllImport("DoesNotExist")]
                     [return: MarshalAs(UnmanagedType.{{unmanagedType}}, MarshalType = "DNE")]
                     public static extern int Method_Return();
@@ -155,6 +154,32 @@ namespace LibraryImportGenerator.UnitTests
         }
 
         [Fact]
+        public async Task UnmanagedTypeInterfaceWithComImportType_NoDiagnostic()
+        {
+            string source = $$"""
+                using System.Runtime.InteropServices;
+
+                [ComImport]
+                [Guid("8509bcd0-45bc-4b04-bb45-f3cac0b4cabd")]
+                interface IFoo
+                {
+                    void Bar();
+                }
+
+                unsafe partial class Test
+                {
+                    [DllImport("DoesNotExist")]
+                    public static extern void Method_Parameter([MarshalAs(UnmanagedType.Interface)]IFoo p);
+
+                    [DllImport("DoesNotExist")]
+                    [return: MarshalAs(UnmanagedType.Interface, MarshalType = "DNE")]
+                    public static extern IFoo Method_Return();
+                }
+                """;
+            await VerifyCS.VerifyAnalyzerAsync(source);
+        }
+
+        [Fact]
         public async Task LibraryImport_NoDiagnostic()
         {
             string source = """
@@ -193,7 +218,7 @@ namespace LibraryImportGenerator.UnitTests
             {
                 [DllImport("DoesNotExist")]
                 public static extern void {|#0:Method_Parameter|}({{typeName}} p);
-            
+
                 [DllImport("DoesNotExist")]
                 public static extern {{typeName}} {|#1:Method_Return|}();
             }
index 5362c12..984f1b9 100644 (file)
@@ -13461,7 +13461,7 @@ namespace System.Runtime.InteropServices.Marshalling
         ElementRef = 8,
         ElementOut = 9
     }
-    [System.AttributeUsageAttribute(System.AttributeTargets.Struct | System.AttributeTargets.Class | System.AttributeTargets.Enum | System.AttributeTargets.Delegate)]
+    [System.AttributeUsageAttribute(System.AttributeTargets.Struct | System.AttributeTargets.Class | System.AttributeTargets.Enum | System.AttributeTargets.Interface | System.AttributeTargets.Delegate)]
     public sealed partial class NativeMarshallingAttribute : System.Attribute
     {
         public NativeMarshallingAttribute(System.Type nativeType) { }
index f744ff4..1349586 100644 (file)
   </Suppression>
   <Suppression>
     <DiagnosticId>CP0015</DiagnosticId>
+    <Target>T:System.Runtime.InteropServices.Marshalling.NativeMarshallingAttribute:[T:System.AttributeUsageAttribute]</Target>
+    <Left>net7.0/System.Runtime.dll</Left>
+    <Right>net8.0/System.Runtime.dll</Right>
+  </Suppression>
+  <Suppression>
+    <DiagnosticId>CP0015</DiagnosticId>
     <Target>M:System.Reflection.Metadata.MetadataUpdateHandlerAttribute.#ctor(System.Type)$0:[T:System.Diagnostics.CodeAnalysis.DynamicallyAccessedMembersAttribute]</Target>
     <Left>net7.0/System.Runtime.Loader.dll</Left>
     <Right>net8.0/System.Runtime.Loader.dll</Right>