--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.IO;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
+using Microsoft.CodeAnalysis.CSharp;
+
+namespace Microsoft.Interop
+{
+ [Generator]
+ public class ComClassGenerator : IIncrementalGenerator
+ {
+ private sealed record ComClassInfo(string ClassName, ContainingSyntaxContext ContainingSyntaxContext, ContainingSyntax ClassSyntax, SequenceEqualImmutableArray<string> ImplementedInterfacesNames);
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ // Get all types with the [GeneratedComClassAttribute] attribute.
+ var attributedClasses = context.SyntaxProvider
+ .ForAttributeWithMetadataName(
+ TypeNames.GeneratedComClassAttribute,
+ static (node, ct) => node is ClassDeclarationSyntax,
+ static (context, ct) =>
+ {
+ var type = (INamedTypeSymbol)context.TargetSymbol;
+ var syntax = (ClassDeclarationSyntax)context.TargetNode;
+ ImmutableArray<string>.Builder names = ImmutableArray.CreateBuilder<string>();
+ foreach (INamedTypeSymbol iface in type.AllInterfaces)
+ {
+ if (iface.GetAttributes().Any(attr => attr.AttributeClass?.ToDisplayString() == TypeNames.GeneratedComInterfaceAttribute))
+ {
+ names.Add(iface.ToDisplayString());
+ }
+ }
+ return new ComClassInfo(
+ type.ToDisplayString(),
+ new ContainingSyntaxContext(syntax),
+ new ContainingSyntax(syntax.Modifiers, syntax.Kind(), syntax.Identifier, syntax.TypeParameterList),
+ new(names.ToImmutable()));
+ });
+
+ var className = attributedClasses.Select(static (info, ct) => info.ClassName);
+
+ var classInfoType = attributedClasses
+ .Select(static (info, ct) => new { info.ClassName, info.ImplementedInterfacesNames })
+ .Select(static (info, ct) => GenerateClassInfoType(info.ImplementedInterfacesNames.Array).NormalizeWhitespace());
+
+ var attribute = attributedClasses
+ .Select(static (info, ct) => new { info.ContainingSyntaxContext, info.ClassSyntax })
+ .Select(static (info, ct) => GenerateClassInfoAttributeOnUserType(info.ContainingSyntaxContext, info.ClassSyntax).NormalizeWhitespace());
+
+ context.RegisterSourceOutput(className.Zip(classInfoType).Zip(attribute), static (context, classInfo) =>
+ {
+ var ((className, classInfoType), attribute) = classInfo;
+ StringWriter writer = new();
+ writer.WriteLine(classInfoType.ToFullString());
+ writer.WriteLine();
+ writer.WriteLine(attribute);
+ context.AddSource(className, writer.ToString());
+ });
+ }
+
+ private const string ClassInfoTypeName = "ComClassInformation";
+
+ private static readonly AttributeSyntax s_comExposedClassAttributeTemplate =
+ Attribute(
+ GenericName(TypeNames.ComExposedClassAttribute)
+ .AddTypeArgumentListArguments(
+ IdentifierName(ClassInfoTypeName)));
+ private static MemberDeclarationSyntax GenerateClassInfoAttributeOnUserType(ContainingSyntaxContext containingSyntaxContext, ContainingSyntax classSyntax) =>
+ containingSyntaxContext.WrapMemberInContainingSyntaxWithUnsafeModifier(
+ TypeDeclaration(classSyntax.TypeKind, classSyntax.Identifier)
+ .WithModifiers(classSyntax.Modifiers)
+ .WithTypeParameterList(classSyntax.TypeParameters)
+ .AddAttributeLists(AttributeList(SingletonSeparatedList(s_comExposedClassAttributeTemplate))));
+ private static ClassDeclarationSyntax GenerateClassInfoType(ImmutableArray<string> implementedInterfaces)
+ {
+ const string vtablesField = "s_vtables";
+ const string vtablesLocal = "vtables";
+ const string detailsTempLocal = "details";
+ const string countIdentifier = "count";
+ var typeDeclaration = ClassDeclaration(ClassInfoTypeName)
+ .AddModifiers(
+ Token(SyntaxKind.FileKeyword),
+ Token(SyntaxKind.SealedKeyword),
+ Token(SyntaxKind.UnsafeKeyword))
+ .AddBaseListTypes(SimpleBaseType(ParseTypeName(TypeNames.IComExposedClass)))
+ .AddMembers(
+ FieldDeclaration(
+ VariableDeclaration(
+ PointerType(
+ ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+ SingletonSeparatedList(VariableDeclarator(vtablesField))))
+ .AddModifiers(
+ Token(SyntaxKind.PrivateKeyword),
+ Token(SyntaxKind.StaticKeyword),
+ Token(SyntaxKind.VolatileKeyword)));
+ List<StatementSyntax> vtableInitializationBlock = new()
+ {
+ // ComInterfaceEntry* vtables = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(<ClassInfoTypeName>), sizeof(ComInterfaceEntry) * <numInterfaces>);
+ LocalDeclarationStatement(
+ VariableDeclaration(
+ PointerType(
+ ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+ SingletonSeparatedList(
+ VariableDeclarator(vtablesLocal)
+ .WithInitializer(EqualsValueClause(
+ CastExpression(
+ PointerType(
+ ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+ InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ ParseTypeName(TypeNames.System_Runtime_CompilerServices_RuntimeHelpers),
+ IdentifierName("AllocateTypeAssociatedMemory")))
+ .AddArgumentListArguments(
+ Argument(TypeOfExpression(IdentifierName(ClassInfoTypeName))),
+ Argument(
+ BinaryExpression(
+ SyntaxKind.MultiplyExpression,
+ SizeOfExpression(ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+ LiteralExpression(
+ SyntaxKind.NumericLiteralExpression,
+ Literal(implementedInterfaces.Length))))))))))),
+ // IIUnknownDerivedDetails details;
+ LocalDeclarationStatement(
+ VariableDeclaration(
+ ParseTypeName(TypeNames.IIUnknownDerivedDetails),
+ SingletonSeparatedList(
+ VariableDeclarator(detailsTempLocal))))
+ };
+ for (int i = 0; i < implementedInterfaces.Length; i++)
+ {
+ string ifaceName = implementedInterfaces[i];
+
+ // details = StrategyBasedComWrappers.DefaultIUnknownInterfaceDetailsStrategy.GetIUnknownDerivedDetails(typeof(<ifaceName>).TypeHandle);
+ vtableInitializationBlock.Add(
+ ExpressionStatement(
+ AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName(detailsTempLocal),
+ InvocationExpression(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ ParseTypeName(TypeNames.StrategyBasedComWrappers),
+ IdentifierName("DefaultIUnknownInterfaceDetailsStrategy")),
+ IdentifierName("GetIUnknownDerivedDetails")),
+ ArgumentList(
+ SingletonSeparatedList(
+ Argument(
+ MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression,
+ TypeOfExpression(ParseName(ifaceName)),
+ IdentifierName("TypeHandle")))))))));
+ // vtable[i] = new() { IID = details.Iid, Vtable = details.ManagedVirtualMethodTable };
+ vtableInitializationBlock.Add(
+ ExpressionStatement(
+ AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+ ElementAccessExpression(
+ IdentifierName(vtablesLocal),
+ BracketedArgumentList(
+ SingletonSeparatedList(
+ Argument(
+ LiteralExpression(SyntaxKind.NumericLiteralExpression, Literal(i)))))),
+ ImplicitObjectCreationExpression(
+ ArgumentList(),
+ InitializerExpression(SyntaxKind.ObjectInitializerExpression,
+ SeparatedList(
+ new ExpressionSyntax[]
+ {
+ AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName("IID"),
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName(detailsTempLocal),
+ IdentifierName("Iid"))),
+ AssignmentExpression(
+ SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName("Vtable"),
+ CastExpression(
+ IdentifierName("nint"),
+ MemberAccessExpression(
+ SyntaxKind.SimpleMemberAccessExpression,
+ IdentifierName(detailsTempLocal),
+ IdentifierName("ManagedVirtualMethodTable"))))
+ }))))));
+ }
+
+ // s_vtable = vtable;
+ vtableInitializationBlock.Add(
+ ExpressionStatement(
+ AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName(vtablesField),
+ IdentifierName(vtablesLocal))));
+
+ BlockSyntax getComInterfaceEntriesMethodBody = Block(
+ // count = <count>;
+ ExpressionStatement(
+ AssignmentExpression(SyntaxKind.SimpleAssignmentExpression,
+ IdentifierName(countIdentifier),
+ LiteralExpression(SyntaxKind.NumericLiteralExpression,
+ Literal(implementedInterfaces.Length)))),
+ // if (s_vtable == null)
+ // { initializer block }
+ IfStatement(
+ BinaryExpression(SyntaxKind.EqualsExpression,
+ IdentifierName(vtablesField),
+ LiteralExpression(SyntaxKind.NullLiteralExpression)),
+ Block(vtableInitializationBlock)),
+ // return s_vtable;
+ ReturnStatement(IdentifierName(vtablesField)));
+
+ typeDeclaration = typeDeclaration.AddMembers(
+ // public static unsafe ComWrappers.ComInterfaceDispatch* GetComInterfaceEntries(out int count)
+ // { body }
+ MethodDeclaration(
+ PointerType(
+ ParseTypeName(TypeNames.System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry)),
+ "GetComInterfaceEntries")
+ .AddParameterListParameters(
+ Parameter(Identifier(countIdentifier))
+ .WithType(PredefinedType(Token(SyntaxKind.IntKeyword)))
+ .AddModifiers(Token(SyntaxKind.OutKeyword)))
+ .WithBody(getComInterfaceEntriesMethodBody)
+ .AddModifiers(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)));
+
+ return typeDeclaration;
+ }
+ }
+}
public void Initialize(IncrementalGeneratorInitializationContext context)
{
- // Get all methods with the [GeneratedComInterface] attribute.
+ // Get all types with the [GeneratedComInterface] attribute.
var attributedInterfaces = context.SyntaxProvider
.ForAttributeWithMetadataName(
TypeNames.GeneratedComInterfaceAttribute,
return new { data.Syntax, data.Symbol, Diagnostic = diagnostic };
});
- // Split the methods we want to generate and the ones we don't into two separate groups.
+ // Split the types we want to generate and the ones we don't into two separate groups.
var interfacesToGenerate = interfacesWithDiagnostics.Where(static data => data.Diagnostic is null);
var invalidTypeDiagnostics = interfacesWithDiagnostics.Where(static data => data.Diagnostic is not null);
.WithExpressionBody(
ArrowExpressionClause(
ConditionalExpression(
- BinaryExpression(SyntaxKind.EqualsExpression,
+ BinaryExpression(SyntaxKind.NotEqualsExpression,
IdentifierName(vtableFieldName),
LiteralExpression(SyntaxKind.NullLiteralExpression)),
IdentifierName(vtableFieldName),
});
}
+ /// <summary>
+ /// Format the syntax nodes in the given provider such that we will not re-normalize if the input nodes have not changed.
+ /// </summary>
+ /// <typeparam name="TNode">A syntax node kind.</typeparam>
+ /// <param name="provider">The input nodes</param>
+ /// <returns>A provider of the formatted syntax nodes.</returns>
+ /// <remarks>
+ /// Normalizing whitespace is very expensive, so if a generator will have cases where the input information into the step
+ /// that creates <paramref name="provider"/> may change but the results of <paramref name="provider"/> will say the same,
+ /// using this method to format the code in a separate step will reduce the amount of work the generator repeats when the
+ /// output code will not change.
+ /// </remarks>
public static IncrementalValuesProvider<TNode> SelectNormalized<TNode>(this IncrementalValuesProvider<TNode> provider)
where TNode : SyntaxNode
{
public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceDispatch = "System.Runtime.InteropServices.ComWrappers.ComInterfaceDispatch";
+ public const string System_Runtime_InteropServices_ComWrappers_ComInterfaceEntry = "System.Runtime.InteropServices.ComWrappers.ComInterfaceEntry";
+
+ public const string StrategyBasedComWrappers = "System.Runtime.InteropServices.Marshalling.StrategyBasedComWrappers";
+
public const string IIUnknownInterfaceType = "System.Runtime.InteropServices.Marshalling.IIUnknownInterfaceType";
public const string IUnknownDerivedAttribute = "System.Runtime.InteropServices.Marshalling.IUnknownDerivedAttribute";
+ public const string IIUnknownDerivedDetails = "System.Runtime.InteropServices.Marshalling.IIUnknownDerivedDetails";
public const string ComWrappersUnwrapper = "System.Runtime.InteropServices.Marshalling.ComWrappersUnwrapper";
public const string UnmanagedObjectUnwrapperAttribute = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapperAttribute`1";
public const string IUnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.IUnmanagedObjectUnwrapper";
public const string UnmanagedObjectUnwrapper = "System.Runtime.InteropServices.Marshalling.UnmanagedObjectUnwrapper";
+
+ public const string GeneratedComClassAttribute = "System.Runtime.InteropServices.Marshalling.GeneratedComClassAttribute";
+ public const string ComExposedClassAttribute = "System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute";
+ public const string IComExposedClass = "System.Runtime.InteropServices.Marshalling.IComExposedClass";
}
}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+ /// <summary>
+ /// An attribute to mark this class as a type whose instances should be exposed to COM.
+ /// </summary>
+ /// <typeparam name="T">The type that provides information about how to expose the attributed type to COM.</typeparam>
+ [AttributeUsage(AttributeTargets.Class, Inherited = false)]
+ public sealed class ComExposedClassAttribute<T> : Attribute, IComExposedDetails
+ where T : IComExposedClass
+ {
+ /// <inheritdoc />
+ public unsafe ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count) => T.GetComInterfaceEntries(out count);
+ }
+}
qiHResult = 0;
if (!CacheStrategy.TryGetTableInfo(handle, out result))
{
- IUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle);
+ IIUnknownDerivedDetails? details = InterfaceDetailsStrategy.GetIUnknownDerivedDetails(handle);
if (details is null)
{
return false;
// [TODO] Implement some smart/thread-safe caching
private readonly Dictionary<RuntimeTypeHandle, IIUnknownCacheStrategy.TableInfo> _cache = new();
- IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails details, void* ptr)
+ IIUnknownCacheStrategy.TableInfo IIUnknownCacheStrategy.ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails details, void* ptr)
{
var obj = (void***)ptr;
return new IIUnknownCacheStrategy.TableInfo()
{
public static readonly IIUnknownInterfaceDetailsStrategy Instance = new DefaultIUnknownInterfaceDetailsStrategy();
- public IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type)
+ public IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type)
{
- return IUnknownDerivedDetails.GetFromAttribute(type);
+ return IComExposedDetails.GetFromAttribute(type);
+ }
+
+ public IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type)
+ {
+ return IIUnknownDerivedDetails.GetFromAttribute(type);
}
}
}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+ [AttributeUsage(AttributeTargets.Class)]
+ public sealed class GeneratedComClassAttribute : Attribute
+ {
+ }
+}
namespace System.Runtime.InteropServices.Marshalling
{
- public interface IComObjectWrapper<T> { }
-
[AttributeUsage(AttributeTargets.Interface)]
public class GeneratedComInterfaceAttribute : Attribute
{
+++ /dev/null
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-
-// This type is for the COM source generator and implements part of the COM-specific interactions.
-// This API need to be exposed to implement the COM source generator in one form or another.
-
-using System.Collections;
-
-namespace System.Runtime.InteropServices.Marshalling
-{
- public abstract class GeneratedComWrappersBase : ComWrappers
- {
- protected virtual IIUnknownInterfaceDetailsStrategy CreateInterfaceDetailsStrategy() => DefaultIUnknownInterfaceDetailsStrategy.Instance;
-
- protected virtual IIUnknownStrategy CreateIUnknownStrategy() => FreeThreadedStrategy.Instance;
-
- protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => new DefaultCaching();
-
- protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags)
- {
- if (flags.HasFlag(CreateObjectFlags.TrackerObject)
- || flags.HasFlag(CreateObjectFlags.Aggregation))
- {
- throw new NotSupportedException();
- }
-
- var rcw = new ComObject(CreateInterfaceDetailsStrategy(), CreateIUnknownStrategy(), CreateCacheStrategy(), (void*)externalComObject);
- if (flags.HasFlag(CreateObjectFlags.UniqueInstance))
- {
- // Set value on MyComObject to enable the FinalRelease option.
- // This could also be achieved through an internal factory
- // function on ComObject type.
- }
- return rcw;
- }
-
- protected override sealed void ReleaseObjects(IEnumerable objects)
- {
- throw new NotImplementedException();
- }
-
- public ComObject GetOrCreateUniqueObjectForComInstance(nint comInstance, CreateObjectFlags flags)
- {
- if (flags.HasFlag(CreateObjectFlags.Unwrap))
- {
- throw new ArgumentException("Cannot create a unique object if unwrapping a ComWrappers-based COM object is requested.", nameof(flags));
- }
- return (ComObject)GetOrCreateObjectForComInstance(comInstance, flags | CreateObjectFlags.UniqueInstance);
- }
- }
-}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+ /// <summary>
+ /// Type level information for managed class types exposed to COM.
+ /// </summary>
+ public unsafe interface IComExposedClass
+ {
+ /// <summary>
+ /// Get the COM interface information to provide to a <see cref="ComWrappers"/> instance to expose this type to COM.
+ /// </summary>
+ /// <param name="count">The number of COM interfaces this type implements.</param>
+ /// <returns>The interface entry information for the interfaces the type implements.</returns>
+ public static abstract ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count);
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Reflection;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+ /// <summary>
+ /// Details about a managed class type exposed to COM.
+ /// </summary>
+ public unsafe interface IComExposedDetails
+ {
+ /// <summary>
+ /// Get the COM interface information to provide to a <see cref="ComWrappers"/> instance to expose this type to COM.
+ /// </summary>
+ /// <param name="count">The number of COM interfaces this type implements.</param>
+ /// <returns>The interface entry information for the interfaces the type implements.</returns>
+ ComWrappers.ComInterfaceEntry* GetComInterfaceEntries(out int count);
+
+ internal static IComExposedDetails? GetFromAttribute(RuntimeTypeHandle handle)
+ {
+ var type = Type.GetTypeFromHandle(handle);
+ if (type is null)
+ {
+ return null;
+ }
+ return (IComExposedDetails?)type.GetCustomAttribute(typeof(ComExposedClassAttribute<>));
+ }
+ }
+}
/// <param name="ptr">Pointer to the instance to query</param>
/// <param name="info">A <see cref="TableInfo"/> instance</param>
/// <returns>True if success, otherwise false.</returns>
- TableInfo ConstructTableInfo(RuntimeTypeHandle handle, IUnknownDerivedDetails interfaceDetails, void* ptr);
+ TableInfo ConstructTableInfo(RuntimeTypeHandle handle, IIUnknownDerivedDetails interfaceDetails, void* ptr);
/// <summary>
/// Get associated <see cref="TableInfo"/>.
/// <summary>
/// Details for the IUnknown derived interface.
/// </summary>
- public interface IUnknownDerivedDetails
+ public interface IIUnknownDerivedDetails
{
/// <summary>
/// Interface ID.
/// </summary>
public unsafe void** ManagedVirtualMethodTable { get; }
- internal static IUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle)
+ internal static IIUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle)
{
var type = Type.GetTypeFromHandle(handle);
if (type is null)
{
return null;
}
- return (IUnknownDerivedDetails?)type.GetCustomAttribute(typeof(IUnknownDerivedAttribute<,>));
+ return (IIUnknownDerivedDetails?)type.GetCustomAttribute(typeof(IUnknownDerivedAttribute<,>));
}
}
}
/// </summary>
/// <param name="type">RuntimeTypeHandle instance</param>
/// <returns>Details if type is known.</returns>
- IUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type);
+ IIUnknownDerivedDetails? GetIUnknownDerivedDetails(RuntimeTypeHandle type);
+
+ /// <summary>
+ /// Given a <see cref="RuntimeTypeHandle"/> get the details about the type that are exposed to COM.
+ /// </summary>
+ /// <param name="type">RuntimeTypeHandle instance</param>
+ /// <returns>Details if type is known.</returns>
+ IComExposedDetails? GetComExposedTypeDetails(RuntimeTypeHandle type);
}
}
namespace System.Runtime.InteropServices.Marshalling
{
+ /// <summary>
+ /// Type level information for an IUnknown-derived interface.
+ /// </summary>
public unsafe interface IIUnknownInterfaceType
{
+ /// <summary>
+ /// The Interface ID (IID) for the interface.
+ /// </summary>
public abstract static Guid Iid { get; }
+
+ /// <summary>
+ /// A pointer to the virtual method table to enable unmanaged callers to call a managed implementation of the interface.
+ /// </summary>
public abstract static void** ManagedVirtualMethodTable { get; }
}
}
namespace System.Runtime.InteropServices.Marshalling
{
- [AttributeUsage(AttributeTargets.Interface)]
- public class IUnknownDerivedAttribute<T, TImpl> : Attribute, IUnknownDerivedDetails
+ /// <summary>
+ /// An attribute to mark this interface as a managed representation of an IUnknown-derived interface.
+ /// </summary>
+ /// <typeparam name="T">The type that provides type-level information about the interface.</typeparam>
+ /// <typeparam name="TImpl">The type to use for calling from managed callers to unmanaged implementations of the interface.</typeparam>
+ [AttributeUsage(AttributeTargets.Interface, Inherited = false)]
+ public class IUnknownDerivedAttribute<T, TImpl> : Attribute, IIUnknownDerivedDetails
where T : IIUnknownInterfaceType
{
public IUnknownDerivedAttribute()
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-// Types that are only needed for the VTable source generator or to provide abstract concepts that the COM generator would use under the hood.
-// These are types that we can exclude from the API proposals and either inline into the generated code, provide as file-scoped types, or not provide publicly (indicated by comments on each type).
-
using System.Collections;
+using System.Reflection;
namespace System.Runtime.InteropServices.Marshalling
{
- public abstract class StrategyBasedComWrappers : InteropServices.ComWrappers
+ public class StrategyBasedComWrappers : ComWrappers
{
public static IIUnknownInterfaceDetailsStrategy DefaultIUnknownInterfaceDetailsStrategy { get; } = Marshalling.DefaultIUnknownInterfaceDetailsStrategy.Instance;
protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => CreateDefaultCacheStrategy();
+ protected override sealed unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+ {
+ if (obj.GetType().GetCustomAttribute(typeof(ComExposedClassAttribute<>)) is IComExposedDetails details)
+ {
+ return details.GetComInterfaceEntries(out count);
+ }
+ count = 0;
+ return null;
+ }
+
protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags)
{
if (flags.HasFlag(CreateObjectFlags.TrackerObject)
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using Xunit;
+
+namespace ComInterfaceGenerator.Tests
+{
+ unsafe partial class NativeExportsNE
+ {
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "set_com_object_data")]
+ public static partial void SetComObjectData(void* obj, int data);
+
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object_data")]
+ public static partial int GetComObjectData(void* obj);
+ }
+
+ [GeneratedComClass]
+ partial class ManagedObjectExposedToCom : IComInterface1
+ {
+ public int Data { get; set; }
+ int IComInterface1.GetData() => Data;
+ void IComInterface1.SetData(int n) => Data = n;
+ }
+
+ [GeneratedComClass]
+ partial class DerivedComObject : ManagedObjectExposedToCom
+ {
+ }
+
+ public unsafe class GeneratedComClassTests
+ {
+ [Fact]
+ public void ComInstanceProvidesInterfaceForDirectlyImplementedComInterface()
+ {
+ ManagedObjectExposedToCom obj = new();
+ StrategyBasedComWrappers wrappers = new();
+ nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(0, ptr);
+ var iid = typeof(IComInterface1).GUID;
+ Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface));
+ Assert.NotEqual(0, iComInterface);
+ Marshal.Release(iComInterface);
+ Marshal.Release(ptr);
+ }
+
+ [Fact]
+ public void ComInstanceProvidesInterfaceForIndirectlyImplementedComInterface()
+ {
+ DerivedComObject obj = new();
+ StrategyBasedComWrappers wrappers = new();
+ nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(0, ptr);
+ var iid = typeof(IComInterface1).GUID;
+ Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface));
+ Assert.NotEqual(0, iComInterface);
+ Marshal.Release(iComInterface);
+ Marshal.Release(ptr);
+ }
+
+ [Fact]
+ public void CallsToComInterfaceWriteChangesToManagedObject()
+ {
+ ManagedObjectExposedToCom obj = new();
+ StrategyBasedComWrappers wrappers = new();
+ void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(0, (nint)ptr);
+ obj.Data = 3;
+ Assert.Equal(3, obj.Data);
+ NativeExportsNE.SetComObjectData(ptr, 42);
+ Assert.Equal(42, obj.Data);
+ Marshal.Release((nint)ptr);
+ }
+
+ [Fact]
+ public void CallsToComInterfaceReadChangesFromManagedObject()
+ {
+ ManagedObjectExposedToCom obj = new();
+ StrategyBasedComWrappers wrappers = new();
+ void* ptr = (void*)wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
+ Assert.NotEqual(0, (nint)ptr);
+ obj.Data = 3;
+ Assert.Equal(3, obj.Data);
+ obj.Data = 12;
+ Assert.Equal(obj.Data, NativeExportsNE.GetComObjectData(ptr));
+ Marshal.Release((nint)ptr);
+ }
+ }
+}
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+
+namespace ComInterfaceGenerator.Tests
+{
+ [GeneratedComInterface]
+ [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+ [Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")]
+ public partial interface IComInterface1
+ {
+ int GetData();
+ void SetData(int n);
+ }
+}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
-using System;
using System.Collections;
using System.Diagnostics;
using System.Linq;
namespace ComInterfaceGenerator.Tests;
-[GeneratedComInterface]
-[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
-[Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")]
-public partial interface IComInterface1
+internal unsafe partial class NativeExportsNE
{
- int GetData();
- void SetData(int n);
-}
-
-internal sealed unsafe partial class MyGeneratedComWrappers : StrategyBasedComWrappers
-{
- protected sealed override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) => throw new UnreachableException("Not creating CCWs yet");
-}
-
-public static unsafe partial class Native
-{
- [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "get_com_object")]
+ [LibraryImport(NativeExportsNE_Binary, EntryPoint = "get_com_object")]
public static partial void* NewNativeObject();
}
[Fact]
public unsafe void CallRcwFromGeneratedComInterface()
{
- var ptr = Native.NewNativeObject(); // new_native_object
- var cw = new MyGeneratedComWrappers();
+ var ptr = NativeExportsNE.NewNativeObject(); // new_native_object
+ var cw = new StrategyBasedComWrappers();
var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
var intObj = (IComInterface1)obj;
--- /dev/null
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Linq;
+using System.Threading.Tasks;
+using Microsoft.CodeAnalysis;
+using Microsoft.Interop.UnitTests;
+using Xunit;
+
+namespace ComInterfaceGenerator.Unit.Tests
+{
+ public class ComClassGeneratorOutputShape
+ {
+ [Fact]
+ public async Task SingleComClass()
+ {
+ string source = """
+ using System.Runtime.InteropServices;
+ using System.Runtime.InteropServices.Marshalling;
+
+ [GeneratedComInterface]
+ partial interface INativeAPI
+ {
+ }
+
+ [GeneratedComClass]
+ partial class C : INativeAPI {}
+ """;
+ Compilation comp = await TestUtils.CreateCompilation(source);
+ TestUtils.AssertPreSourceGeneratorCompilation(comp);
+
+ var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator());
+ TestUtils.AssertPostSourceGeneratorCompilation(newComp);
+ // We'll create one syntax tree for the new interface.
+ Assert.Equal(comp.SyntaxTrees.Count() + 1, newComp.SyntaxTrees.Count());
+
+ VerifyShape(newComp, "C");
+ }
+
+ [Fact]
+ public async Task MultipleComClasses()
+ {
+ string source = $$"""
+ using System.Runtime.InteropServices;
+ using System.Runtime.InteropServices.Marshalling;
+
+ [GeneratedComInterface]
+ partial interface I
+ {
+ }
+ [GeneratedComInterface]
+ partial interface J
+ {
+ }
+
+ [GeneratedComClass]
+ partial class C : I, J
+ {
+ }
+
+ [GeneratedComClass]
+ partial class D : I, J
+ {
+ }
+
+ [GeneratedComClass]
+ partial class E : C
+ {
+ }
+ """;
+ Compilation comp = await TestUtils.CreateCompilation(source);
+ TestUtils.AssertPreSourceGeneratorCompilation(comp);
+
+ var newComp = TestUtils.RunGenerators(comp, out _, new Microsoft.Interop.ComClassGenerator());
+ TestUtils.AssertPostSourceGeneratorCompilation(newComp);
+ // We'll create one syntax tree per user-defined interface.
+ Assert.Equal(comp.SyntaxTrees.Count() + 3, newComp.SyntaxTrees.Count());
+
+ VerifyShape(newComp, "C");
+ VerifyShape(newComp, "D");
+ VerifyShape(newComp, "E");
+ }
+ private static void VerifyShape(Compilation comp, string userDefinedClassMetadataName)
+ {
+ INamedTypeSymbol? userDefinedClass = comp.Assembly.GetTypeByMetadataName(userDefinedClassMetadataName);
+ Assert.NotNull(userDefinedClass);
+
+ INamedTypeSymbol? comExposedClassAttribute = comp.GetTypeByMetadataName("System.Runtime.InteropServices.Marshalling.ComExposedClassAttribute`1");
+
+ Assert.NotNull(comExposedClassAttribute);
+
+ AttributeData iUnknownDerivedAttribute = Assert.Single(
+ userDefinedClass.GetAttributes(),
+ attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.OriginalDefinition, comExposedClassAttribute));
+
+ Assert.Collection(Assert.IsAssignableFrom<INamedTypeSymbol>(iUnknownDerivedAttribute.AttributeClass).TypeArguments,
+ infoType =>
+ {
+ Assert.True(Assert.IsAssignableFrom<INamedTypeSymbol>(infoType).IsFileLocal);
+ });
+ }
+ }
+}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
{
void Bar() {}
}
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(_usings + snippet);
}
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
[GeneratedComInterface]
partial interface IFoo { }
-
- public unsafe partial class MyComWrappers : GeneratedComWrappersBase
- {
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) {count = 0; return null;}
- }
-
""";
await VerifyCS.VerifyAnalyzerAsync(
_usings + snippet,
namespace NativeExports;
-public static unsafe class ComInterfaceGeneratorExports
+public static unsafe class ComInterfaces
{
interface IComInterface1
{
[UnmanagedCallersOnly(EntryPoint = "get_com_object")]
public static void* CreateComObject()
{
- MyComWrapper cw = new();
var myObject = new MyObject();
- nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+ nint ptr = ComWrappersInstance.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
return (void*)ptr;
}
+ [UnmanagedCallersOnly(EntryPoint = "set_com_object_data")]
+ public static void SetComObjectData(void* ptr, int value)
+ {
+ IComInterface1 obj = (IComInterface1)ComWrappersInstance.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+ obj.SetData(value);
+ }
+
+ [UnmanagedCallersOnly(EntryPoint = "get_com_object_data")]
+ public static int GetComObjectData(void* ptr)
+ {
+ IComInterface1 obj = (IComInterface1)ComWrappersInstance.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+ return obj.GetData();
+ }
+
+ private static readonly ComWrappers ComWrappersInstance = new MyComWrapper();
+
class MyComWrapper : System.Runtime.InteropServices.ComWrappers
{
- static void* _s_comInterface1VTable = null;
- static void* s_comInterface1VTable
+ static volatile void* s_comInterface1VTable = null;
+ static void* IComInterface1VTable
{
get
{
- if (MyComWrapper._s_comInterface1VTable != null)
- return _s_comInterface1VTable;
- void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaceGeneratorExports), sizeof(void*) * 5);
+ if (s_comInterface1VTable != null)
+ return s_comInterface1VTable;
+ void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaces), sizeof(void*) * 5);
GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
vtable[0] = (void*)fpQueryInterface;
vtable[1] = (void*)fpAddReference;
vtable[2] = (void*)fpRelease;
vtable[3] = (delegate* unmanaged<void*, int*, int>)&MyObject.ABI.GetData;
vtable[4] = (delegate* unmanaged<void*, int, int>)&MyObject.ABI.SetData;
- _s_comInterface1VTable = vtable;
- return _s_comInterface1VTable;
+ s_comInterface1VTable = vtable;
+ return s_comInterface1VTable;
}
}
- protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+
+ static volatile ComInterfaceEntry* s_myObjectComInterfaceEntries = null;
+ static ComInterfaceEntry* MyObjectComInterfaceEntries
{
- if (obj is MyObject)
+ get
{
+ if (s_myObjectComInterfaceEntries != null)
+ return s_myObjectComInterfaceEntries;
+
ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(MyObject), sizeof(ComInterfaceEntry));
comInterfaceEntry->IID = IComInterface1.IID;
- comInterfaceEntry->Vtable = (nint)s_comInterface1VTable;
+ comInterfaceEntry->Vtable = (nint)IComInterface1VTable;
+ s_myObjectComInterfaceEntries = comInterfaceEntry;
+ return s_myObjectComInterfaceEntries;
+ }
+ }
+ protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+ {
+ if (obj is MyObject)
+ {
count = 1;
- return comInterfaceEntry;
+ return MyObjectComInterfaceEntries;
}
count = 0;
return null;
{
return null;
}
- return new IComInterface1Impl(ptr);
+ return new IComInterface1Impl(IComInterfaceImpl);
}
protected override void ReleaseObjects(IEnumerable objects) { }
}
// Wrapper for calling CCWs from the ComInterfaceGenerator
- class IComInterface1Impl : IComInterface1
+ sealed class IComInterface1Impl : IComInterface1
{
nint _ptr;
_ptr = @this;
}
+ ~IComInterface1Impl()
+ {
+ int refCount = Marshal.Release(_ptr);
+ }
+
int GetData(nint inst)
{
int value;