Rename VirtualMethodTableManagedImplementation and add manual CCW to test generated...
authorJackson Schuster <36744439+jtschuster@users.noreply.github.com>
Fri, 17 Mar 2023 16:43:24 +0000 (09:43 -0700)
committerGitHub <noreply@github.com>
Fri, 17 Mar 2023 16:43:24 +0000 (09:43 -0700)
src/libraries/System.Private.CoreLib/src/System.Private.CoreLib.Shared.projitems
src/libraries/System.Runtime.InteropServices/gen/Microsoft.Interop.SourceGeneration/GeneratedStatements.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedAttribute.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/IUnknownDerivedDetails.cs
src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs [new file with mode: 0644]

index 35dbe01..77ba3fe 100644 (file)
     <Compile Include="$(MSBuildThisFileDirectory)System\Numerics\IUnaryPlusOperators.cs" />
     <Compile Include="$(MSBuildThisFileDirectory)System\Numerics\IUnsignedNumber.cs" />
   </ItemGroup>
-</Project>
+</Project>
\ No newline at end of file
index 3bc1ecc..e8e3de0 100644 (file)
@@ -2,11 +2,8 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
-using System.Collections.Generic;
 using System.Collections.Immutable;
-using System.Diagnostics;
 using System.Linq;
-using System.Text;
 using Microsoft.CodeAnalysis;
 using Microsoft.CodeAnalysis.CSharp;
 using Microsoft.CodeAnalysis.CSharp.Syntax;
index ded6f03..8624e00 100644 (file)
@@ -21,6 +21,6 @@ namespace System.Runtime.InteropServices.Marshalling
         public Type Implementation => typeof(TImpl);
 
         /// <inheritdoc />
-        public unsafe void* VirtualMethodTableManagedImplementation => T.ManagedVirtualMethodTable;
+        public unsafe void** ManagedVirtualMethodTable => T.ManagedVirtualMethodTable;
     }
 }
index ecd5a54..0c6a9e2 100644 (file)
@@ -26,7 +26,7 @@ namespace System.Runtime.InteropServices.Marshalling
         /// <summary>
         /// A pointer to the virtual method table to enable unmanaged callers to call a managed implementation of the interface.
         /// </summary>
-        public unsafe void* VirtualMethodTableManagedImplementation { get; }
+        public unsafe void** ManagedVirtualMethodTable { get; }
 
         internal static IUnknownDerivedDetails? GetFromAttribute(RuntimeTypeHandle handle)
         {
diff --git a/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs b/src/libraries/System.Runtime.InteropServices/tests/Ancillary.Interop/StrategyBasedComWrappers.cs
new file mode 100644 (file)
index 0000000..e6e3442
--- /dev/null
@@ -0,0 +1,48 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+// Types that are only needed for the VTable source generator or to provide abstract concepts that the COM generator would use under the hood.
+// These are types that we can exclude from the API proposals and either inline into the generated code, provide as file-scoped types, or not provide publicly (indicated by comments on each type).
+
+using System.Collections;
+
+namespace System.Runtime.InteropServices.Marshalling
+{
+    public abstract class StrategyBasedComWrappers : InteropServices.ComWrappers
+    {
+        public static IIUnknownInterfaceDetailsStrategy DefaultIUnknownInterfaceDetailsStrategy { get; } = Marshalling.DefaultIUnknownInterfaceDetailsStrategy.Instance;
+
+        public static IIUnknownStrategy DefaultIUnknownStrategy { get; } = FreeThreadedStrategy.Instance;
+
+        protected static IIUnknownCacheStrategy CreateDefaultCacheStrategy() => new DefaultCaching();
+
+        protected virtual IIUnknownInterfaceDetailsStrategy GetOrCreateInterfaceDetailsStrategy() => DefaultIUnknownInterfaceDetailsStrategy;
+
+        protected virtual IIUnknownStrategy GetOrCreateIUnknownStrategy() => DefaultIUnknownStrategy;
+
+        protected virtual IIUnknownCacheStrategy CreateCacheStrategy() => CreateDefaultCacheStrategy();
+
+        protected override sealed unsafe object CreateObject(nint externalComObject, CreateObjectFlags flags)
+        {
+            if (flags.HasFlag(CreateObjectFlags.TrackerObject)
+                || flags.HasFlag(CreateObjectFlags.Aggregation))
+            {
+                throw new NotSupportedException();
+            }
+
+            var rcw = new ComObject(GetOrCreateInterfaceDetailsStrategy(), GetOrCreateIUnknownStrategy(), CreateCacheStrategy(), (void*)externalComObject);
+            if (flags.HasFlag(CreateObjectFlags.UniqueInstance))
+            {
+                // Set value on MyComObject to enable the FinalRelease option.
+                // This could also be achieved through an internal factory
+                // function on ComObject type.
+            }
+            return rcw;
+        }
+
+        protected override sealed void ReleaseObjects(IEnumerable objects)
+        {
+            throw new NotImplementedException();
+        }
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwTests.cs
new file mode 100644 (file)
index 0000000..0841e6e
--- /dev/null
@@ -0,0 +1,52 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Diagnostics;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using Xunit;
+using Xunit.Sdk;
+
+namespace ComInterfaceGenerator.Tests;
+
+[GeneratedComInterface]
+[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
+[Guid("2c3f9903-b586-46b1-881b-adfce9af47b1")]
+public partial interface IComInterface1
+{
+    int GetData();
+    void SetData(int n);
+}
+
+internal sealed unsafe partial class MyGeneratedComWrappers : StrategyBasedComWrappers
+{
+    protected sealed override unsafe ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) => throw new UnreachableException("Not creating CCWs yet");
+}
+
+public static unsafe partial class Native
+{
+    [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "get_com_object")]
+    public static partial void* NewNativeObject();
+}
+
+
+public class RcwTests
+{
+    [Fact]
+    public unsafe void CallRcwFromGeneratedComInterface()
+    {
+        var ptr = Native.NewNativeObject(); // new_native_object
+        var cw = new MyGeneratedComWrappers();
+        var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+
+        var intObj = (IComInterface1)obj;
+        Assert.Equal(0, intObj.GetData());
+        intObj.SetData(2);
+        Assert.Equal(2, intObj.GetData());
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaces.cs
new file mode 100644 (file)
index 0000000..52db26b
--- /dev/null
@@ -0,0 +1,167 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.ObjectiveC;
+using System.Text;
+using System.Threading.Tasks;
+using static System.Runtime.InteropServices.ComWrappers;
+
+namespace NativeExports;
+
+public static unsafe class ComInterfaceGeneratorExports
+{
+    interface IComInterface1
+    {
+        public int GetData();
+
+        public void SetData(int x);
+
+        public static Guid IID = new Guid("2c3f9903-b586-46b1-881b-adfce9af47b1");
+    }
+
+    // Call from another assembly to get a ptr to make an RCW
+    [UnmanagedCallersOnly(EntryPoint = "get_com_object")]
+    public static void* CreateComObject()
+    {
+        MyComWrapper cw = new();
+        var myObject = new MyObject();
+        nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+
+        return (void*)ptr;
+    }
+
+    class MyComWrapper : System.Runtime.InteropServices.ComWrappers
+    {
+        static void* _s_comInterface1VTable = null;
+        static void* s_comInterface1VTable
+        {
+            get
+            {
+                if (MyComWrapper._s_comInterface1VTable != null)
+                    return _s_comInterface1VTable;
+                void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ComInterfaceGeneratorExports), sizeof(void*) * 5);
+                GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
+                vtable[0] = (void*)fpQueryInterface;
+                vtable[1] = (void*)fpAddReference;
+                vtable[2] = (void*)fpRelease;
+                vtable[3] = (delegate* unmanaged<void*, int*, int>)&MyObject.ABI.GetData;
+                vtable[4] = (delegate* unmanaged<void*, int, int>)&MyObject.ABI.SetData;
+                _s_comInterface1VTable = vtable;
+                return _s_comInterface1VTable;
+            }
+        }
+        protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+        {
+            if (obj is MyObject)
+            {
+                ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(MyObject), sizeof(ComInterfaceEntry));
+                comInterfaceEntry->IID = IComInterface1.IID;
+                comInterfaceEntry->Vtable = (nint)s_comInterface1VTable;
+                count = 1;
+                return comInterfaceEntry;
+            }
+            count = 0;
+            return null;
+        }
+        protected override object CreateObject(nint ptr, CreateObjectFlags flags)
+        {
+            int hr = Marshal.QueryInterface(ptr, ref IComInterface1.IID, out IntPtr IComInterfaceImpl);
+            if (hr != 0)
+            {
+                return null;
+            }
+            return new IComInterface1Impl(ptr);
+        }
+
+        protected override void ReleaseObjects(IEnumerable objects) { }
+    }
+
+    // Wrapper for calling CCWs from the ComInterfaceGenerator
+    class IComInterface1Impl : IComInterface1
+    {
+        nint _ptr;
+
+        public IComInterface1Impl(nint @this)
+        {
+            _ptr = @this;
+        }
+
+        int GetData(nint inst)
+        {
+            int value;
+            int hr = ((delegate* unmanaged<nint, int*, int>)(*(*(void***)inst + 3)))(inst, &value);
+            if (hr != 0)
+            {
+                Marshal.GetExceptionForHR(hr);
+            }
+            return value;
+        }
+
+        void SetData(nint inst, int newValue)
+        {
+            int hr = ((delegate* unmanaged<nint, int, int>)(*(*(void***)inst + 4)))(inst, newValue);
+            if (hr != 0)
+            {
+                Marshal.GetExceptionForHR(hr);
+            }
+        }
+
+        int IComInterface1.GetData() => GetData(_ptr);
+
+        void IComInterface1.SetData(int newValue) => SetData(_ptr, newValue);
+    }
+
+    class MyObject : IComInterface1
+    {
+        int _data = 0;
+
+        int IComInterface1.GetData()
+        {
+            return _data;
+        }
+        void IComInterface1.SetData(int x)
+        {
+            _data = x;
+        }
+
+        // Provides function pointers in the COM format to use in COM VTables
+        public static class ABI
+        {
+
+            [UnmanagedCallersOnly]
+            public static int GetData(void* @this, int* value)
+            {
+                try
+                {
+                    *value = ComInterfaceDispatch.GetInstance<IComInterface1>((ComInterfaceDispatch*)@this).GetData();
+                    return 0;
+                }
+                catch (Exception e)
+                {
+                    return e.HResult;
+                }
+            }
+
+            [UnmanagedCallersOnly]
+            public static int SetData(void* @this, int newValue)
+            {
+                try
+                {
+                    ComInterfaceDispatch.GetInstance<IComInterface1>((ComInterfaceDispatch*)@this).SetData(newValue);
+                    return 0;
+                }
+                catch (Exception e)
+                {
+                    return e.HResult;
+                }
+            }
+        }
+    }
+}