Add Array marshalling tests for ComInterfaceGenerator (#84509)
authorJackson Schuster <36744439+jtschuster@users.noreply.github.com>
Thu, 13 Apr 2023 20:59:55 +0000 (15:59 -0500)
committerGitHub <noreply@github.com>
Thu, 13 Apr 2023 20:59:55 +0000 (13:59 -0700)
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/ComInterfaceGenerator.Tests.csproj
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/GeneratedComClassTests.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetAndSetIntTests.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetIntArrayTests.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/ArrayMarshalling.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/GetAndSetInt.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/NativeExports.csproj
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs [new file with mode: 0644]
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs [new file with mode: 0644]

index b894796..74a8534 100644 (file)
@@ -1,4 +1,4 @@
-<Project Sdk="Microsoft.NET.Sdk">
+<Project Sdk="Microsoft.NET.Sdk">
   <PropertyGroup>
     <TargetFramework>$(NetCoreAppCurrent)</TargetFramework>
     <IsPackable>false</IsPackable>
@@ -13,8 +13,9 @@
 
   <ItemGroup>
     <Compile Include="$(CommonPath)DisableRuntimeMarshalling.cs" Link="Common\DisableRuntimeMarshalling.cs" />
+    <Compile Include="..\TestAssets\SharedTypes\ComInterfaces\*.cs" Link="ComInterfaces\%(FileName).cs" />
   </ItemGroup>
-  
+
   <ItemGroup>
     <ProjectReference Include="..\..\gen\ComInterfaceGenerator\ComInterfaceGenerator.csproj" ReferenceOutputAssembly="false" OutputItemType="Analyzer" />
     <ProjectReference Include="..\Ancillary.Interop\Ancillary.Interop.csproj" />
index 5481817..737a9f9 100644 (file)
@@ -4,6 +4,7 @@
 using System;
 using System.Runtime.InteropServices;
 using System.Runtime.InteropServices.Marshalling;
+using SharedTypes.ComInterfaces;
 using Xunit;
 
 namespace ComInterfaceGenerator.Tests
@@ -18,11 +19,11 @@ namespace ComInterfaceGenerator.Tests
     }
 
     [GeneratedComClass]
-    partial class ManagedObjectExposedToCom : IComInterface1
+    partial class ManagedObjectExposedToCom : IGetAndSetInt
     {
         public int Data { get; set; }
-        int IComInterface1.GetData() => Data;
-        void IComInterface1.SetData(int n) => Data = n;
+        int IGetAndSetInt.GetInt() => Data;
+        void IGetAndSetInt.SetInt(int n) => Data = n;
     }
 
     [GeneratedComClass]
@@ -39,7 +40,7 @@ namespace ComInterfaceGenerator.Tests
             StrategyBasedComWrappers wrappers = new();
             nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
             Assert.NotEqual(0, ptr);
-            var iid = typeof(IComInterface1).GUID;
+            var iid = typeof(IGetAndSetInt).GUID;
             Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface));
             Assert.NotEqual(0, iComInterface);
             Marshal.Release(iComInterface);
@@ -53,7 +54,7 @@ namespace ComInterfaceGenerator.Tests
             StrategyBasedComWrappers wrappers = new();
             nint ptr = wrappers.GetOrCreateComInterfaceForObject(obj, CreateComInterfaceFlags.None);
             Assert.NotEqual(0, ptr);
-            var iid = typeof(IComInterface1).GUID;
+            var iid = typeof(IGetAndSetInt).GUID;
             Assert.Equal(0, Marshal.QueryInterface(ptr, ref iid, out nint iComInterface));
             Assert.NotEqual(0, iComInterface);
             Marshal.Release(iComInterface);
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetAndSetIntTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetAndSetIntTests.cs
new file mode 100644 (file)
index 0000000..322c459
--- /dev/null
@@ -0,0 +1,36 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using System.Diagnostics;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using SharedTypes.ComInterfaces;
+using Xunit;
+using Xunit.Sdk;
+
+namespace ComInterfaceGenerator.Tests
+{
+    public unsafe partial class IGetAndSetIntTests
+    {
+
+        [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_get_and_set_int")]
+        public static partial void* NewNativeObject();
+
+        [Fact]
+        public unsafe void CallRcwFromGeneratedComInterface()
+        {
+            var ptr = NewNativeObject(); // new_native_object
+            var cw = new StrategyBasedComWrappers();
+            var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+
+            var intObj = (IGetAndSetInt)obj;
+            Assert.Equal(0, intObj.GetInt());
+            intObj.SetInt(2);
+            Assert.Equal(2, intObj.GetInt());
+        }
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetIntArrayTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/IGetIntArrayTests.cs
new file mode 100644 (file)
index 0000000..1769ac7
--- /dev/null
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Collections;
+using System.Diagnostics;
+using System.Linq;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using SharedTypes.ComInterfaces;
+using Xunit;
+using Xunit.Sdk;
+
+namespace ComInterfaceGenerator.Tests
+{
+    public unsafe partial class IGetIntArrayTests
+    {
+        [LibraryImport(NativeExportsNE.NativeExportsNE_Binary, EntryPoint = "new_get_and_set_int_array")]
+        public static partial void* NewNativeObject();
+
+        [Fact]
+        public unsafe void CallRcwFromGeneratedComInterface()
+        {
+            var ptr = NewNativeObject(); // new_native_object
+            var cw = new StrategyBasedComWrappers();
+            var obj = cw.GetOrCreateObjectForComInstance((nint)ptr, CreateObjectFlags.None);
+
+            var intObj = (IGetIntArray)obj;
+            Assert.Equal<int>(new int[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }, intObj.GetInts());
+        }
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/ArrayMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/ArrayMarshalling.cs
new file mode 100644 (file)
index 0000000..c8f2292
--- /dev/null
@@ -0,0 +1,98 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Text;
+using System.Threading.Tasks;
+using SharedTypes.ComInterfaces;
+using static System.Runtime.InteropServices.ComWrappers;
+
+namespace NativeExports.ComInterfaceGenerator
+{
+
+    public static unsafe class ArrayMarshalling
+    {
+
+        [UnmanagedCallersOnly(EntryPoint = "new_get_and_set_int_array")]
+        public static void* CreateComObject()
+        {
+            MyComWrapper cw = new();
+            var myObject = new ImplementingObject();
+            nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+
+            return (void*)ptr;
+        }
+
+        class MyComWrapper : ComWrappers
+        {
+            static void* _s_comInterface1VTable = null;
+            static void* GetIntArrayVTable
+            {
+                get
+                {
+                    if (MyComWrapper._s_comInterface1VTable != null)
+                        return _s_comInterface1VTable;
+                    void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ImplementingObject), sizeof(void*) * 4);
+                    GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
+                    vtable[0] = (void*)fpQueryInterface;
+                    vtable[1] = (void*)fpAddReference;
+                    vtable[2] = (void*)fpRelease;
+                    vtable[3] = (delegate* unmanaged<void*, int**, int>)&ImplementingObject.ABI.GetInts;
+                    _s_comInterface1VTable = vtable;
+                    return _s_comInterface1VTable;
+                }
+            }
+            protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+            {
+                if (obj is ImplementingObject)
+                {
+                    ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ImplementingObject), sizeof(ComInterfaceEntry));
+                    comInterfaceEntry->IID = new Guid(IGetIntArray._guid);
+                    comInterfaceEntry->Vtable = (nint)GetIntArrayVTable;
+                    count = 1;
+                    return comInterfaceEntry;
+                }
+                count = 0;
+                return null;
+            }
+
+            protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException();
+            protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException();
+        }
+        class ImplementingObject : IGetIntArray
+        {
+            int[] _data = new int[10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
+
+            public int[] GetInts() => _data;
+
+            public static class ABI
+            {
+                [UnmanagedCallersOnly]
+                public static int GetInts(void* @this, int** values)
+                {
+
+                    try
+                    {
+                        int[] arr = ComInterfaceDispatch.GetInstance<IGetIntArray>((ComInterfaceDispatch*)@this).GetInts();
+                        *values = (int*)Marshal.AllocCoTaskMem(sizeof(int) * arr.Length);
+                        for (int i = 0; i < arr.Length; i++)
+                        {
+                            (*values)[i] = arr[i];
+                        }
+                        return 0;
+                    }
+                    catch (Exception e)
+                    {
+                        return e.HResult;
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/GetAndSetInt.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/NativeExports/ComInterfaceGenerator/GetAndSetInt.cs
new file mode 100644 (file)
index 0000000..66faa8c
--- /dev/null
@@ -0,0 +1,118 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Runtime.InteropServices.ObjectiveC;
+using System.Text;
+using System.Threading.Tasks;
+using SharedTypes.ComInterfaces;
+using static System.Runtime.InteropServices.ComWrappers;
+
+namespace NativeExports.ComInterfaceGenerator
+{
+    public static unsafe class GetAndSetInt
+    {
+        // Call from another assembly to get a ptr to make an RCW
+        [UnmanagedCallersOnly(EntryPoint = "new_get_and_set_int")]
+        public static void* CreateComObject()
+        {
+            MyComWrapper cw = new();
+            var myObject = new ImplementingObject();
+            nint ptr = cw.GetOrCreateComInterfaceForObject(myObject, CreateComInterfaceFlags.None);
+
+            return (void*)ptr;
+        }
+
+        class MyComWrapper : ComWrappers
+        {
+            static void* _s_comInterface1VTable = null;
+            static void* s_comInterface1VTable
+            {
+                get
+                {
+                    if (MyComWrapper._s_comInterface1VTable != null)
+                        return _s_comInterface1VTable;
+                    void** vtable = (void**)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(GetAndSetInt), sizeof(void*) * 5);
+                    GetIUnknownImpl(out var fpQueryInterface, out var fpAddReference, out var fpRelease);
+                    vtable[0] = (void*)fpQueryInterface;
+                    vtable[1] = (void*)fpAddReference;
+                    vtable[2] = (void*)fpRelease;
+                    vtable[3] = (delegate* unmanaged<void*, int*, int>)&ImplementingObject.ABI.GetInt;
+                    vtable[4] = (delegate* unmanaged<void*, int, int>)&ImplementingObject.ABI.SetInt;
+                    _s_comInterface1VTable = vtable;
+                    return _s_comInterface1VTable;
+                }
+            }
+            protected override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+            {
+                if (obj is ImplementingObject)
+                {
+                    ComInterfaceEntry* comInterfaceEntry = (ComInterfaceEntry*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(ImplementingObject), sizeof(ComInterfaceEntry));
+                    comInterfaceEntry->IID = new Guid(IGetAndSetInt._guid);
+                    comInterfaceEntry->Vtable = (nint)s_comInterface1VTable;
+                    count = 1;
+                    return comInterfaceEntry;
+                }
+                count = 0;
+                return null;
+            }
+
+            protected override object? CreateObject(nint externalComObject, CreateObjectFlags flags) => throw new NotImplementedException();
+            protected override void ReleaseObjects(IEnumerable objects) => throw new NotImplementedException();
+        }
+
+        class ImplementingObject : IGetAndSetInt
+        {
+            int _data = 0;
+
+            int IGetAndSetInt.GetInt()
+            {
+                return _data;
+            }
+            void IGetAndSetInt.SetInt(int x)
+            {
+                _data = x;
+            }
+
+            // Provides function pointers in the COM format to use in COM VTables
+            public static class ABI
+            {
+
+                [UnmanagedCallersOnly]
+                public static int GetInt(void* @this, int* value)
+                {
+                    try
+                    {
+                        *value = ComInterfaceDispatch.GetInstance<IGetAndSetInt>((ComInterfaceDispatch*)@this).GetInt();
+                        return 0;
+                    }
+                    catch (Exception e)
+                    {
+                        return e.HResult;
+                    }
+                }
+
+                [UnmanagedCallersOnly]
+                public static int SetInt(void* @this, int newValue)
+                {
+                    try
+                    {
+                        ComInterfaceDispatch.GetInstance<IGetAndSetInt>((ComInterfaceDispatch*)@this).SetInt(newValue);
+                        return 0;
+                    }
+                    catch (Exception e)
+                    {
+                        return e.HResult;
+                    }
+                }
+            }
+        }
+    }
+}
index e577711..592a6f0 100644 (file)
       '$(TargetOS)' == 'tvossimulator'">true</_TargetsAppleOS>
   </PropertyGroup>
 
+  <ItemGroup>
+    <Compile Include="..\..\TestAssets\SharedTypes\ComInterfaces\*.cs" Link="ComInterfaceGenerator\ComInterfaces\%(FileName).cs" />
+  </ItemGroup>
+
   <!-- Until we use the live app host, use a prebuilt from the 7.0 SDK.
        Issue: https://github.com/dotnet/runtime/issues/58109. -->
   <ItemGroup Condition="'$(UseLocalAppHostPack)' != 'true'">
@@ -36,6 +40,7 @@
 
   <ItemGroup>
     <ProjectReference Include="..\SharedTypes\SharedTypes.csproj" />
+    <ProjectReference Include="..\..\Ancillary.Interop\Ancillary.Interop.csproj" />
   </ItemGroup>
 
   <Target Name="GetUnixBuildArgumentsForDNNE" Condition="'$(OS)' == 'Unix'">
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetAndSetInt.cs
new file mode 100644 (file)
index 0000000..2d76ef4
--- /dev/null
@@ -0,0 +1,24 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace SharedTypes.ComInterfaces
+{
+    [GeneratedComInterface]
+    [Guid(_guid)]
+    partial interface IGetAndSetInt
+    {
+        int GetInt();
+
+        public void SetInt(int x);
+
+        public const string _guid = "2c3f9903-b586-46b1-881b-adfce9af47b1";
+    }
+}
diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IGetIntArray.cs
new file mode 100644 (file)
index 0000000..6b99c1d
--- /dev/null
@@ -0,0 +1,23 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Runtime.InteropServices.Marshalling;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace SharedTypes.ComInterfaces
+{
+    [GeneratedComInterface]
+    [Guid(_guid)]
+    partial interface IGetIntArray
+    {
+        [return: MarshalUsing(ConstantElementCount = 10)]
+        int[] GetInts();
+
+        public const string _guid = "7D802A0A-630A-4C8E-A21F-771CC9031FB9";
+    }
+}