Provide support for exposing .NET classes to COM through source generation (#83755)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Tue, 28 Mar 2023 18:01:18 +0000 (11:01 -0700)
committerGitHub <noreply@github.com>
Tue, 28 Mar 2023 18:01:18 +0000 (11:01 -0700)
25 files changed:
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComInterfaceGenerator.cs
src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/IncrementalValuesProviderExtensions.cs
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/TypeNames.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComExposedClassAttribute.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComObject.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultCaching.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/DefaultIUnknownInterfaceDetailsStrategy.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComClassAttribute.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComInterfaceAttribute.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComWrappersBase.cs [deleted file]
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedClass.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedDetails.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownCacheStrategy.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownDerivedDetails.cs [moved from src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedDetails.cs with 82% similarity]
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceDetailsStrategy.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IIUnknownInterfaceType.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedAttribute.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/GeneratedComInterfaceAnalyzerTests.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs

diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/ComClassGenerator.cs
new file mode 100644 (file)
index 0000000..5085cbc
--- /dev/null
@@ -0,0 +1,231 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.IO;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using Microsoft.CodeAnalysis.CSharp;
+
+namespace Microsoft.Interop
+{
+    [Generator]
+    public class ComClassGenerator : IIncrementalGenerator
+    {
+        private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
+        public void Initialize(IncrementalGeneratorInitializationContext context)
+        {
+            // Get all types with the [GeneratedComClassAttribute] attribute.
+            var attributedClasses = context.SyntaxProvider
+                .ForAttributeWithMetadataName(
+                    TypeNames.GeneratedComClassAttribute,
+                    static (node, ct) => node is ClassDeclarationSyntax,
+                    static (context, ct) =>
+                    {
+                        var type = (INamedTypeSymbol)context.TargetSymbol;
+                        var syntax = (ClassDeclarationSyntax)context.TargetNode;
+                        ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
+                        foreach (INamedTypeSymbol iface in type.AllInterfaces)
+                        {
+                            if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute))
+                            {
+                                names.Add(iface.ToDisplayString());
+                            }
+                        }
+                        return new ComClassInfo(
+                            type.ToDisplayString(),
+                            new ContainingSyntaxContext(syntax),
+                            new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
+                            new(names.ToImmutable()));
+                    });
+
+            var className = attributedClasses.Select(static (info, ct) => info.ClassName);
+
+            var classInfoType = attributedClasses
+                .Select(static (info, ct) => new { info.ClassName, info.ImplementedInterfacesNames })
+                .Select(static (info, ct) => GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace());
+
+            var attribute = attributedClasses
+                .Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax })
+                .Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace());
+
+            context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
+            {
+                var ((className, classInfoType), attribute) = classInfo;
+                StringWriter writer = new();
+                writer.WriteLine(classInfoType.ToFullString());
+                writer.WriteLine();
+                writer.WriteLine(attribute);
+                context.AddSource(className, writer.ToString());
+            });
+        }
+
+        private const string ClassInfoTypeName = "ComClassInformation";
+
+        private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
+            Attribute(
+                GenericName(TypeNames.ComExposedClassAttribute)
+                    .AddTypeArgumentListArguments(
+                        IdentifierName(ClassInfoTypeName)));
+        private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
+            containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
+                TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
+                    .WithModifiers(classSyntax.Modifiers)
+                    .WithTypeParameterList(classSyntax.TypeParameters)
+                    .AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
+        private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
+        {
+            const string vtablesField = "s_vtables";
+            const string vtablesLocal = "vtables";
+            const string detailsTempLocal = "details";
+            const string countIdentifier = "count";
+            var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
+                .AddModifiers(
+                    Token(SyntaxKind.FileKeyword),
+                    Token(SyntaxKind.SealedKeyword),
+                    Token(SyntaxKind.UnsafeKeyword))
+                .AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass)))
+                .AddMembers(
+                    FieldDeclaration(
+                        VariableDeclaration(
+                            PointerType(
+                                ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+                            SingletonSeparatedList(VariableDeclarator(vtablesField))))
+                    .AddModifiers(
+                        Token(SyntaxKind.PrivateKeyword),
+                        Token(SyntaxKind.StaticKeyword),
+                        Token(SyntaxKind.VolatileKeyword)));
+            List<StatementSyntax> vtableInitializationBlock = new()
+            {
+                // ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
+                LocalDeclarationStatement(
+                    VariableDeclaration(
+                            PointerType(
+                                ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+                            SingletonSeparatedList(
+                                VariableDeclarator(vtablesLocal)
+                                    .WithInitializer(EqualsValueClause(
+                                        CastExpression(
+                                            PointerType(
+                                                ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+                                        InvocationExpression(
+                                            MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                                ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers),
+                                                IdentifierName("AllocateTypeAssociatedMemory")))
+                                        .AddArgumentListArguments(
+                                            Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
+                                            Argument(
+                                                BinaryExpression(
+                                                    SyntaxKind.MultiplyExpression,
+                                                    SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+                                                    LiteralExpression(
+                                                        SyntaxKind.NumericLiteralExpression,
+                                                        Literal(implementedInterfaces.Length))))))))))),
+                // IIUnknownDerivedDetails details;
+                LocalDeclarationStatement(
+                    VariableDeclaration(
+                        ParseTypeName(TypeNames.IIUnknownDerivedDetails),
+                        SingletonSeparatedList(
+                            VariableDeclarator(detailsTempLocal))))
+            };
+            for (int i = 0; i < implementedInterfaces.Length; i++)
+            {
+                string ifaceName = implementedInterfaces[i];
+
+                // details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
+                vtableInitializationBlock.Add(
+                    ExpressionStatement(
+                        AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+                            IdentifierName(detailsTempLocal),
+                            InvocationExpression(
+                                MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                    MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                        ParseTypeName(TypeNames.StrategyBasedComWrappers),
+                                        IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
+                                    IdentifierName("GetIUnknownDerivedDetails")),
+                                ArgumentList(
+                                    SingletonSeparatedList(
+                                        Argument(
+                                            MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+                                                TypeOfExpression(ParseName(ifaceName)),
+                                                IdentifierName("TypeHandle")))))))));
+                // vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
+                vtableInitializationBlock.Add(
+                    ExpressionStatement(
+                        AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+                            ElementAccessExpression(
+                                IdentifierName(vtablesLocal),
+                                BracketedArgumentList(
+                                    SingletonSeparatedList(
+                                        Argument(
+                                            LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))),
+                            ImplicitObjectCreationExpression(
+                                ArgumentList(),
+                                InitializerExpression(SyntaxKind.ObjectInitializerExpression,
+                                    SeparatedList(
+                                        new ExpressionSyntax[]
+                                        {
+                                            AssignmentExpression(
+                                                SyntaxKind.SimpleAssignmentExpression,
+                                                IdentifierName("IID"),
+                                                MemberAccessExpression(
+                                                    SyntaxKind.SimpleMemberAccessExpression,
+                                                    IdentifierName(detailsTempLocal),
+                                                    IdentifierName("Iid"))),
+                                            AssignmentExpression(
+                                                SyntaxKind.SimpleAssignmentExpression,
+                                                IdentifierName("Vtable"),
+                                                CastExpression(
+                                                    IdentifierName("nint"),
+                                                    MemberAccessExpression(
+                                                        SyntaxKind.SimpleMemberAccessExpression,
+                                                        IdentifierName(detailsTempLocal),
+                                                        IdentifierName("ManagedVirtualMethodTable"))))
+                                        }))))));
+            }
+
+            // s_vtable = vtable;
+            vtableInitializationBlock.Add(
+                ExpressionStatement(
+                    AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+                        IdentifierName(vtablesField),
+                        IdentifierName(vtablesLocal))));
+
+            BlockSyntax getComInterfaceEntriesMethodBody = Block(
+                // count = <count>;
+                ExpressionStatement(
+                    AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+                        IdentifierName(countIdentifier),
+                        LiteralExpression(SyntaxKind.NumericLiteralExpression,
+                            Literal(implementedInterfaces.Length)))),
+                // if (s_vtable == null)
+                //   { initializer block }
+                IfStatement(
+                    BinaryExpression(SyntaxKind.EqualsExpression,
+                        IdentifierName(vtablesField),
+                        LiteralExpression(SyntaxKind.NullLiteralExpression)),
+                    Block(vtableInitializationBlock)),
+                // return s_vtable;
+                ReturnStatement(IdentifierName(vtablesField)));
+
+            typeDeclaration = typeDeclaration.AddMembers(
+                // public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
+                // { body }
+                MethodDeclaration(
+                    PointerType(
+                        ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+                    "GetComInterfaceEntries")
+                    .AddParameterListParameters(
+                        Parameter(Identifier(countIdentifier))
+                            .WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
+                            .AddModifiers(Token(SyntaxKind.OutKeyword)))
+                    .WithBody(getComInterfaceEntriesMethodBody)
+                    .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));
+
+            return typeDeclaration;
+        }
+    }
+}
index f773acb..1770403 100644 (file)
@@ -45,7 +45,7 @@ namespace Microsoft.Interop
 
         public void Initialize(IncrementalGeneratorInitializationContext context)
         {
-            // Get all methods with the [GeneratedComInterface] attribute.
+            // Get all types with the [GeneratedComInterface] attribute.
             var attributedInterfaces = context.SyntaxProvider
                 .ForAttributeWithMetadataName(
                     TypeNames.GeneratedComInterfaceAttribute,
@@ -62,7 +62,7 @@ namespace Microsoft.Interop
                 return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
             });
 
-            // Split the methods we want to generate and the ones we don't into two separate groups.
+            // Split the types we want to generate and the ones we don't into two separate groups.
             var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null);
             var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null);
 
@@ -726,7 +726,7 @@ namespace Microsoft.Interop
                         .WithExpressionBody(
                             ArrowExpressionClause(
                                 ConditionalExpression(
-                                    BinaryExpression(SyntaxKind.EqualsExpression,
+                                    BinaryExpression(SyntaxKind.NotEqualsExpression,
                                         IdentifierName(vtableFieldName),
                                         LiteralExpression(SyntaxKind.NullLiteralExpression)),
                                     IdentifierName(vtableFieldName),
index 870368e..fb0dd80 100644 (file)
@@ -31,6 +31,18 @@ namespace Microsoft.Interop
                 });
         }
 
+        /// <summary>
+        /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed.
+        /// </summary>
+        /// <typeparam name="TNode">A syntax node kind.</typeparam>
+        /// <param name="provider">The input nodes</param>
+        /// <returns>A provider of the formatted syntax nodes.</returns>
+        /// <remarks>
+        /// Normalizing whitespace is very expensive, so if a generator will have cases where the input information into the step
+        /// that creates <paramref name="provider"/> may change but the results of <paramref name="provider"/> will say the same,
+        /// using this method to format the code in a separate step will reduce the amount of work the generator repeats when the
+        /// output code will not change.
+        /// </remarks>
         public static IncrementalValuesProvider<TNode> SelectNormalized<TNode>(this IncrementalValuesProvider<TNode> provider)
             where TNode : SyntaxNode
         {
index c74ca0b..fa271c8 100644 (file)
@@ -117,13 +117,22 @@ namespace Microsoft.Interop
 
         public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch";
 
+        public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry";
+
+        public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers";
+
         public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
         public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute";
+        public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails";
 
         public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper";
         public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1";
 
         public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
         public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";
+
+        public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute";
+        public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute";
+        public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass";
     }
 }
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComExposedClassAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/ComExposedClassAttribute.cs
new file mode 100644 (file)
index 0000000..5fb0d07
--- /dev/null
@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    /// <summary>
+    /// An attribute to mark this class as a type whose instances should be exposed to COM.
+    /// </summary>
+    /// <typeparam name="T">The type that provides information about how to expose the attributed type to COM.</typeparam>
+    [AttributeUsage(AttributeTargets.Class, Inherited = false)]
+    public sealed class ComExposedClassAttribute<T> : Attribute, IComExposedDetails
+        where T : IComExposedClass
+    {
+        /// <inheritdoc />
+        public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count);
+    }
+}
index c66d4a7..2fecd06 100644 (file)
@@ -90,7 +90,7 @@ namespace System.Runtime.InteropServices.Marshalling
             qiHResult = 0;
             if (!CacheStrategy.TryGetTableInfo(handle, out result))
             {
-                IUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle);
+                IIUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle);
                 if (details is null)
                 {
                     return false;
index 07152e0..5b86d12 100644 (file)
@@ -11,7 +11,7 @@ namespace System.Runtime.InteropServices.Marshalling
         // [TODO] Implement some smart/thread-safe caching
         private readonly Dictionary<RuntimeTypeHandle, IIUnknownCacheStrategy.TableInfo> _cache = new();
 
-        IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails details, void* ptr)
+        IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails details, void* ptr)
         {
             var obj = (void***)ptr;
             return new IIUnknownCacheStrategy.TableInfo()
index b1c3ff0..33f8bc1 100644 (file)
@@ -7,9 +7,14 @@ namespace System.Runtime.InteropServices.Marshalling
     {
         public static readonly IIUnknownInterfaceDetailsStrategy Instance = new DefaultIUnknownInterfaceDetailsStrategy();
 
-        public IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type)
+        public IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type)
         {
-            return IUnknownDerivedDetails.GetFromAttribute(type);
+            return IComExposedDetails.GetFromAttribute(type);
+        }
+
+        public IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type)
+        {
+            return IIUnknownDerivedDetails.GetFromAttribute(type);
         }
     }
 }
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComClassAttribute.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComClassAttribute.cs
new file mode 100644 (file)
index 0000000..13fe839
--- /dev/null
@@ -0,0 +1,12 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    [AttributeUsage(AttributeTargets.Class)]
+    public sealed class GeneratedComClassAttribute : Attribute
+    {
+    }
+}
index 09c81e4..7d81e17 100644 (file)
@@ -3,8 +3,6 @@
 
 namespace System.Runtime.InteropServices.Marshalling
 {
-    public interface IComObjectWrapper<T> { }
-
     [AttributeUsage(AttributeTargets.Interface)]
     public class GeneratedComInterfaceAttribute : Attribute
     {
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComWrappersBase.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/GeneratedComWrappersBase.cs
deleted file mode 100644 (file)
index 7f3f420..0000000
+++ /dev/null
@@ -1,51 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-// This type is for the COM source generator and implements part of the COM-specific interactions.
-// This API need to be exposed to implement the COM source generator in one form or another.
-
-using System.Collections;
-
-namespace System.Runtime.InteropServices.Marshalling
-{
-    public abstract class GeneratedComWrappersBase : ComWrappers
-    {
-        protected virtual IIUnknownInterfaceDetailsStrategy CreateInterfaceDetailsStrategy() => DefaultIUnknownInterfaceDetailsStrategy.Instance;
-
-        protected virtual IIUnknownStrategy CreateIUnknownStrategy() => FreeThreadedStrategy.Instance;
-
-        protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => new DefaultCaching();
-
-        protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags)
-        {
-            if (flags.HasFlag(CreateObjectFlags.TrackerObject)
-                || flags.HasFlag(CreateObjectFlags.Aggregation))
-            {
-                throw new NotSupportedException();
-            }
-
-            var rcw = new ComObject(CreateInterfaceDetailsStrategy(), CreateIUnknownStrategy(), CreateCacheStrategy(), (void*)externalComObject);
-            if (flags.HasFlag(CreateObjectFlags.UniqueInstance))
-            {
-                // Set value on MyComObject to enable the FinalRelease option.
-                // This could also be achieved through an internal factory
-                // function on ComObject type.
-            }
-            return rcw;
-        }
-
-        protected override sealed void ReleaseObjects(IEnumerable objects)
-        {
-            throw new NotImplementedException();
-        }
-
-        public ComObject GetOrCreateUniqueObjectForComInstance(nint comInstance, CreateObjectFlags flags)
-        {
-            if (flags.HasFlag(CreateObjectFlags.Unwrap))
-            {
-                throw new ArgumentException("Cannot create a unique object if unwrapping a ComWrappers-based COM object is requested.", nameof(flags));
-            }
-            return (ComObject)GetOrCreateObjectForComInstance(comInstance, flags | CreateObjectFlags.UniqueInstance);
-        }
-    }
-}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedClass.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedClass.cs
new file mode 100644 (file)
index 0000000..070f4d7
--- /dev/null
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    /// <summary>
+    /// Type level information for managed class types exposed to COM.
+    /// </summary>
+    public unsafe interface IComExposedClass
+    {
+        /// <summary>
+        /// Get the COM interface information to provide to a <see cref="ComWrappers"/> instance to expose this type to COM.
+        /// </summary>
+        /// <param name="count">The number of COM interfaces this type implements.</param>
+        /// <returns>The interface entry information for the interfaces the type implements.</returns>
+        public static abstract ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count);
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedDetails.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IComExposedDetails.cs
new file mode 100644 (file)
index 0000000..4c7c419
--- /dev/null
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Reflection;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    /// <summary>
+    /// Details about a managed class type exposed to COM.
+    /// </summary>
+    public unsafe interface IComExposedDetails
+    {
+        /// <summary>
+        /// Get the COM interface information to provide to a <see cref="ComWrappers"/> instance to expose this type to COM.
+        /// </summary>
+        /// <param name="count">The number of COM interfaces this type implements.</param>
+        /// <returns>The interface entry information for the interfaces the type implements.</returns>
+        ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count);
+
+        internal static IComExposedDetails? GetFromAttribute(RuntimeTypeHandle handle)
+        {
+            var type = Type.GetTypeFromHandle(handle);
+            if (type is null)
+            {
+                return null;
+            }
+            return (IComExposedDetails?)type.GetCustomAttribute(typeof(ComExposedClassAttribute<>));
+        }
+    }
+}
index d043344..0784e6e 100644 (file)
@@ -25,7 +25,7 @@ namespace System.Runtime.InteropServices.Marshalling
         /// <param name="ptr">Pointer to the instance to query</param>
         /// <param name="info">A <see cref="TableInfo"/> instance</param>
         /// <returns>True if success, otherwise false.</returns>
-        TableInfo ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails interfaceDetails, void* ptr);
+        TableInfo ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails interfaceDetails, void* ptr);
 
         /// <summary>
         /// Get associated <see cref="TableInfo"/>.
@@ -11,7 +11,7 @@ namespace System.Runtime.InteropServices.Marshalling
     /// <summary>
     /// Details for the IUnknown derived interface.
     /// </summary>
-    public interface IUnknownDerivedDetails
+    public interface IIUnknownDerivedDetails
     {
         /// <summary>
         /// Interface ID.
@@ -28,14 +28,14 @@ namespace System.Runtime.InteropServices.Marshalling
         /// </summary>
         public unsafe void** ManagedVirtualMethodTable { get; }
 
-        internal static IUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle)
+        internal static IIUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle)
         {
             var type = Type.GetTypeFromHandle(handle);
             if (type is null)
             {
                 return null;
             }
-            return (IUnknownDerivedDetails?)type.GetCustomAttribute(typeof(IUnknownDerivedAttribute<,>));
+            return (IIUnknownDerivedDetails?)type.GetCustomAttribute(typeof(IUnknownDerivedAttribute<,>));
         }
     }
 }
index cd9f849..4492117 100644 (file)
@@ -16,6 +16,13 @@ namespace System.Runtime.InteropServices.Marshalling
         /// </summary>
         /// <param name="type">RuntimeTypeHandle instance</param>
         /// <returns>Details if type is known.</returns>
-        IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type);
+        IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type);
+
+        /// <summary>
+        /// Given a <see cref="RuntimeTypeHandle"/> get the details about the type that are exposed to COM.
+        /// </summary>
+        /// <param name="type">RuntimeTypeHandle instance</param>
+        /// <returns>Details if type is known.</returns>
+        IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type);
     }
 }
index a600826..73b41fc 100644 (file)
@@ -6,9 +6,19 @@
 
 namespace System.Runtime.InteropServices.Marshalling
 {
+    /// <summary>
+    /// Type level information for an IUnknown-derived interface.
+    /// </summary>
     public unsafe interface IIUnknownInterfaceType
     {
+        /// <summary>
+        /// The Interface ID (IID) for the interface.
+        /// </summary>
         public abstract static Guid Iid { get; }
+
+        /// <summary>
+        /// A pointer to the virtual method table to enable unmanaged callers to call a managed implementation of the interface.
+        /// </summary>
         public abstract static void** ManagedVirtualMethodTable { get; }
     }
 }
index 8624e00..02f8d78 100644 (file)
@@ -6,8 +6,13 @@
 
 namespace System.Runtime.InteropServices.Marshalling
 {
-    [AttributeUsage(AttributeTargets.Interface)]
-    public class IUnknownDerivedAttribute<T, TImpl> : Attribute, IUnknownDerivedDetails
+    /// <summary>
+    /// An attribute to mark this interface as a managed representation of an IUnknown-derived interface.
+    /// </summary>
+    /// <typeparam name="T">The type that provides type-level information about the interface.</typeparam>
+    /// <typeparam name="TImpl">The type to use for calling from managed callers to unmanaged implementations of the interface.</typeparam>
+    [AttributeUsage(AttributeTargets.Interface, Inherited = false)]
+    public class IUnknownDerivedAttribute<T, TImpl> : Attribute, IIUnknownDerivedDetails
         where T : IIUnknownInterfaceType
     {
         public IUnknownDerivedAttribute()
index e6e3442..e06f50b 100644 (file)
@@ -1,14 +1,12 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-// Types that are only needed for the VTable source generator or to provide abstract concepts that the COM generator would use under the hood.
-// These are types that we can exclude from the API proposals and either inline into the generated code, provide as file-scoped types, or not provide publicly (indicated by comments on each type).
-
 using System.Collections;
+using System.Reflection;
 
 namespace System.Runtime.InteropServices.Marshalling
 {
-    public abstract class StrategyBasedComWrappers : InteropServices.ComWrappers
+    public class StrategyBasedComWrappers : ComWrappers
     {
         public static IIUnknownInterfaceDetailsStrategy DefaultIUnknownInterfaceDetailsStrategy { get; } = Marshalling.DefaultIUnknownInterfaceDetailsStrategy.Instance;
 
@@ -22,6 +20,16 @@ namespace System.Runtime.InteropServices.Marshalling
 
         protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => CreateDefaultCacheStrategy();
 
+        protected override sealed unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+        {
+            if (obj.GetType().GetCustomAttribute(typeof(ComExposedClassAttribute<>)) is IComExposedDetails details)
+            {
+                return details.GetComInterfaceEntries(out count);
+            }
+            count = 0;
+            return null;
+        }
+
         protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags)
         {
             if (flags.HasFlag(CreateObjectFlags.TrackerObject)
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs
new file mode 100644 (file)
index 0000000..5481817
--- /dev/null
@@ -0,0 +1,91 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using Xunit;
+
+namespace ComInterfaceGenerator.Tests
+{
+    unsafe partial class NativeExportsNE
+    {
+        [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_com_object_data")]
+        public static partial void SetComObjectData(void* obj, int data);
+
+        [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object_data")]
+        public static partial int GetComObjectData(void* obj);
+    }
+
+    [GeneratedComClass]
+    partial class ManagedObjectExposedToCom : IComInterface1
+    {
+        public int Data { get; set; }
+        int IComInterface1.GetData() => Data;
+        void IComInterface1.SetData(int n) => Data = n;
+    }
+
+    [GeneratedComClass]
+    partial class DerivedComObject : ManagedObjectExposedToCom
+    {
+    }
+
+    public unsafe class GeneratedComClassTests
+    {
+        [Fact]
+        public void ComInstanceProvidesInterfaceForDirectlyImplementedComInterface()
+        {
+            ManagedObjectExposedToCom obj = new();
+            StrategyBasedComWrappers wrappers = new();
+            nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+            Assert.NotEqual(0, ptr);
+            var iid = typeof(IComInterface1).GUID;
+            Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface));
+            Assert.NotEqual(0, iComInterface);
+            Marshal.Release(iComInterface);
+            Marshal.Release(ptr);
+        }
+
+        [Fact]
+        public void ComInstanceProvidesInterfaceForIndirectlyImplementedComInterface()
+        {
+            DerivedComObject obj = new();
+            StrategyBasedComWrappers wrappers = new();
+            nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+            Assert.NotEqual(0, ptr);
+            var iid = typeof(IComInterface1).GUID;
+            Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface));
+            Assert.NotEqual(0, iComInterface);
+            Marshal.Release(iComInterface);
+            Marshal.Release(ptr);
+        }
+
+        [Fact]
+        public void CallsToComInterfaceWriteChangesToManagedObject()
+        {
+            ManagedObjectExposedToCom obj = new();
+            StrategyBasedComWrappers wrappers = new();
+            void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+            Assert.NotEqual(0, (nint)ptr);
+            obj.Data = 3;
+            Assert.Equal(3, obj.Data);
+            NativeExportsNE.SetComObjectData(ptr, 42);
+            Assert.Equal(42, obj.Data);
+            Marshal.Release((nint)ptr);
+        }
+
+        [Fact]
+        public void CallsToComInterfaceReadChangesFromManagedObject()
+        {
+            ManagedObjectExposedToCom obj = new();
+            StrategyBasedComWrappers wrappers = new();
+            void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+            Assert.NotEqual(0, (nint)ptr);
+            obj.Data = 3;
+            Assert.Equal(3, obj.Data);
+            obj.Data = 12;
+            Assert.Equal(obj.Data, NativeExportsNE.GetComObjectData(ptr));
+            Marshal.Release((nint)ptr);
+        }
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IComInterface1.cs
new file mode 100644 (file)
index 0000000..bb464c0
--- /dev/null
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+
+namespace ComInterfaceGenerator.Tests
+{
+    [GeneratedComInterface]
+    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+    [Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")]
+    public partial interface IComInterface1
+    {
+        int GetData();
+        void SetData(int n);
+    }
+}
index 0841e6e..4f41e80 100644 (file)
@@ -1,7 +1,6 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
 using System.Collections;
 using System.Diagnostics;
 using System.Linq;
@@ -14,23 +13,9 @@ using Xunit.Sdk;
 
 namespace ComInterfaceGenerator.Tests;
 
-[GeneratedComInterface]
-[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
-[Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")]
-public partial interface IComInterface1
+internal unsafe partial class NativeExportsNE
 {
-    int GetData();
-    void SetData(int n);
-}
-
-internal sealed unsafe partial class MyGeneratedComWrappers : StrategyBasedComWrappers
-{
-    protected sealed override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) => throw new UnreachableException("Not creating CCWs yet");
-}
-
-public static unsafe partial class Native
-{
-    [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "get_com_object")]
+    [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object")]
     public static partial void* NewNativeObject();
 }
 
@@ -40,8 +25,8 @@ public class RcwTests
     [Fact]
     public unsafe void CallRcwFromGeneratedComInterface()
     {
-        var ptr = Native.NewNativeObject(); // new_native_object
-        var cw = new MyGeneratedComWrappers();
+        var ptr = NativeExportsNE.NewNativeObject(); // new_native_object
+        var cw = new StrategyBasedComWrappers();
         var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
 
         var intObj = (IComInterface1)obj;
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Unit.Tests/ComClassGeneratorOutputShape.cs
new file mode 100644 (file)
index 0000000..2bb8d5e
--- /dev/null
@@ -0,0 +1,104 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis;
+using Microsoft.Interop.UnitTests;
+using Xunit;
+
+namespace ComInterfaceGenerator.Unit.Tests
+{
+    public class ComClassGeneratorOutputShape
+    {
+        [Fact]
+        public async Task SingleComClass()
+        {
+            string source = """
+                using System.Runtime.InteropServices;
+                using System.Runtime.InteropServices.Marshalling;
+
+                [GeneratedComInterface]
+                partial interface INativeAPI
+                {
+                }
+
+                [GeneratedComClass]
+                partial class C : INativeAPI {}
+                """;
+            Compilation comp = await TestUtils.CreateCompilation(source);
+            TestUtils.AssertPreSourceGeneratorCompilation(comp);
+
+            var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator());
+            TestUtils.AssertPostSourceGeneratorCompilation(newComp);
+            // We'll create one syntax tree for the new interface.
+            Assert.Equal(comp.SyntaxTrees.Count() + 1, newComp.SyntaxTrees.Count());
+
+            VerifyShape(newComp, "C");
+        }
+
+        [Fact]
+        public async Task MultipleComClasses()
+        {
+            string source = $$"""
+                using System.Runtime.InteropServices;
+                using System.Runtime.InteropServices.Marshalling;
+                
+                [GeneratedComInterface]
+                partial interface I
+                {
+                }
+                [GeneratedComInterface]
+                partial interface J
+                {
+                }
+                
+                [GeneratedComClass]
+                partial class C : I, J
+                {
+                }
+
+                [GeneratedComClass]
+                partial class D : I, J
+                {
+                }
+
+                [GeneratedComClass]
+                partial class E : C
+                {
+                }
+                """;
+            Compilation comp = await TestUtils.CreateCompilation(source);
+            TestUtils.AssertPreSourceGeneratorCompilation(comp);
+
+            var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator());
+            TestUtils.AssertPostSourceGeneratorCompilation(newComp);
+            // We'll create one syntax tree per user-defined interface.
+            Assert.Equal(comp.SyntaxTrees.Count() + 3, newComp.SyntaxTrees.Count());
+
+            VerifyShape(newComp, "C");
+            VerifyShape(newComp, "D");
+            VerifyShape(newComp, "E");
+        }
+        private static void VerifyShape(Compilation comp, string userDefinedClassMetadataName)
+        {
+            INamedTypeSymbol? userDefinedClass = comp.Assembly.GetTypeByMetadataName(userDefinedClassMetadataName);
+            Assert.NotNull(userDefinedClass);
+
+            INamedTypeSymbol? comExposedClassAttribute = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute`1");
+
+            Assert.NotNull(comExposedClassAttribute);
+
+            AttributeData iUnknownDerivedAttribute = Assert.Single(
+                userDefinedClass.GetAttributes(),
+                attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute));
+
+            Assert.Collection(Assert.IsAssignableFrom<INamedTypeSymbol>(iUnknownDerivedAttribute.AttributeClass).TypeArguments,
+                infoType =>
+                {
+                    Assert.True(Assert.IsAssignableFrom<INamedTypeSymbol>(infoType).IsFileLocal);
+                });
+        }
+    }
+}
index e4291e6..a77166a 100644 (file)
@@ -34,12 +34,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                     {
                         void Bar() {}
                     }
-
-                    public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                    {
-                        protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                    }
-
                     """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -54,12 +48,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                     {
                         void Bar() {}
                     }
-
-                    public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                    {
-                        protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                    }
-
                     """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -75,13 +63,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -97,12 +78,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -118,12 +93,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -139,12 +108,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -160,12 +123,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -181,12 +138,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -206,12 +157,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -232,12 +177,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -254,12 +193,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -276,12 +209,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -302,12 +229,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -328,12 +249,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -354,12 +269,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -380,12 +289,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -406,12 +309,6 @@ namespace ComInterfaceGenerator.Unit.Tests
                 {
                     void Bar() {}
                 }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -438,12 +335,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -462,12 +353,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
             }
@@ -486,12 +371,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -514,12 +393,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -542,12 +415,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -570,12 +437,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -598,12 +459,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
@@ -626,12 +481,6 @@ namespace ComInterfaceGenerator.Unit.Tests
 
                 [GeneratedComInterface]
                 partial interface IFoo { }
-
-                public unsafe partial class MyComWrappers : GeneratedComWrappersBase
-                {
-                    protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
-                }
-
                 """;
                 await VerifyCS.VerifyAnalyzerAsync(
                     _usings + snippet,
index 52db26b..651f02b 100644 (file)
@@ -15,7 +15,7 @@ using static System.Runtime.InteropServices.ComWrappers;
 
 namespace NativeExports;
 
-public static unsafe class ComInterfaceGeneratorExports
+public static unsafe class ComInterfaces
 {
     interface IComInterface1
     {
@@ -30,42 +30,70 @@ public static unsafe class ComInterfaceGeneratorExports
     [UnmanagedCallersOnly(EntryPoint = "get_com_object")]
     public static void* CreateComObject()
     {
-        MyComWrapper cw = new();
         var myObject = new MyObject();
-        nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+        nint ptr = ComWrappersInstance.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
 
         return (void*)ptr;
     }
 
+    [UnmanagedCallersOnly(EntryPoint = "set_com_object_data")]
+    public static void SetComObjectData(void* ptr, int value)
+    {
+        IComInterface1 obj = (IComInterface1)ComWrappersInstance.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+        obj.SetData(value);
+    }
+
+    [UnmanagedCallersOnly(EntryPoint = "get_com_object_data")]
+    public static int GetComObjectData(void* ptr)
+    {
+        IComInterface1 obj = (IComInterface1)ComWrappersInstance.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+        return obj.GetData();
+    }
+
+    private static readonly ComWrappers ComWrappersInstance = new MyComWrapper();
+
     class MyComWrapper : System.Runtime.InteropServices.ComWrappers
     {
-        static void* _s_comInterface1VTable = null;
-        static void* s_comInterface1VTable
+        static volatile void* s_comInterface1VTable = null;
+        static void* IComInterface1VTable
         {
             get
             {
-                if (MyComWrapper._s_comInterface1VTable != null)
-                    return _s_comInterface1VTable;
-                void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaceGeneratorExports), sizeof(void*) * 5);
+                if (s_comInterface1VTable != null)
+                    return s_comInterface1VTable;
+                void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaces), sizeof(void*) * 5);
                 GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
                 vtable[0] = (void*)fpQueryInterface;
                 vtable[1] = (void*)fpAddReference;
                 vtable[2] = (void*)fpRelease;
                 vtable[3] = (delegate* unmanaged<void*, int*, int>)&MyObject.ABI.GetData;
                 vtable[4] = (delegate* unmanaged<void*, int, int>)&MyObject.ABI.SetData;
-                _s_comInterface1VTable = vtable;
-                return _s_comInterface1VTable;
+                s_comInterface1VTable = vtable;
+                return s_comInterface1VTable;
             }
         }
-        protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+
+        static volatile ComInterfaceEntry* s_myObjectComInterfaceEntries = null;
+        static ComInterfaceEntry* MyObjectComInterfaceEntries
         {
-            if (obj is MyObject)
+            get
             {
+                if (s_myObjectComInterfaceEntries != null)
+                    return s_myObjectComInterfaceEntries;
+
                 ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(MyObject), sizeof(ComInterfaceEntry));
                 comInterfaceEntry->IID = IComInterface1.IID;
-                comInterfaceEntry->Vtable = (nint)s_comInterface1VTable;
+                comInterfaceEntry->Vtable = (nint)IComInterface1VTable;
+                s_myObjectComInterfaceEntries = comInterfaceEntry;
+                return s_myObjectComInterfaceEntries;
+            }
+        }
+        protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+        {
+            if (obj is MyObject)
+            {
                 count = 1;
-                return comInterfaceEntry;
+                return MyObjectComInterfaceEntries;
             }
             count = 0;
             return null;
@@ -77,14 +105,14 @@ public static unsafe class ComInterfaceGeneratorExports
             {
                 return null;
             }
-            return new IComInterface1Impl(ptr);
+            return new IComInterface1Impl(IComInterfaceImpl);
         }
 
         protected override void ReleaseObjects(IEnumerable objects) { }
     }
 
     // Wrapper for calling CCWs from the ComInterfaceGenerator
-    class IComInterface1Impl : IComInterface1
+    sealed class IComInterface1Impl : IComInterface1
     {
         nint _ptr;
 
@@ -93,6 +121,11 @@ public static unsafe class ComInterfaceGeneratorExports
             _ptr = @this;
         }
 
+        ~IComInterface1Impl()
+        {
+            int refCount = Marshal.Release(_ptr);
+        }
+
         int GetData(nint inst)
         {
             int value;