From 38fed9edbdd8f35134aa3e1883d7ed6e7eccc3be Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Thu, 10 Aug 2023 11:13:51 -0500 Subject: [PATCH] [ComInterfaceGenerator] Add implementation and test for CallerAllocateBufferType parameters (#90263) --- .../RcwAroundCcwTests.cs | 18 +++++++ .../IStatelessCallerAllocateBufferMarshalling.cs | 60 +++++++++++++++------- 2 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs index 89ec569..95ab982 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs @@ -214,6 +214,24 @@ namespace ComInterfaceGenerator.Tests } [Fact] + public void StatelessCallerAllocatedBufferMarshalling() + { + var obj = CreateWrapper(); + var data = new StatelessCallerAllocatedBufferType() { I = 42 }; + + obj.Method(data); + Assert.Equal(42, data.I); + obj.MethodIn(in data); + Assert.Equal(42, data.I); + obj.MethodRef(ref data); + Assert.Equal(200, data.I); + obj.MethodOut(out data); + Assert.Equal(20, data.I); + Assert.Equal(201, obj.Return().I); + Assert.Equal(202, obj.ReturnPreserveSig().I); + } + + [Fact] public void ICollectionMarshallingFails() { Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception); diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCallerAllocateBufferMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCallerAllocateBufferMarshalling.cs index f5c3d4b..1b04704 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCallerAllocateBufferMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCallerAllocateBufferMarshalling.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; @@ -9,26 +11,26 @@ namespace SharedTypes.ComInterfaces { [GeneratedComInterface] [Guid("4732FA5D-C105-4A23-87A7-58DCEDD4A9B3")] - internal partial interface IStatelessCallerAllocateBufferMarshalling + internal partial interface IStatelessCallerAllocatedBufferMarshalling { - void Method([MarshalUsing(CountElementName = nameof(size))] StatelessCallerAllocatedBufferType param, int size); - void MethodIn([MarshalUsing(CountElementName = nameof(size))] in StatelessCallerAllocatedBufferType param, int size); - void MethodOut([MarshalUsing(CountElementName = nameof(size))] out StatelessCallerAllocatedBufferType param, int size); - void MethodRef([MarshalUsing(CountElementName = nameof(size))] ref StatelessCallerAllocatedBufferType param, int size); + void Method(StatelessCallerAllocatedBufferType param); + void MethodIn(in StatelessCallerAllocatedBufferType param); + void MethodOut(out StatelessCallerAllocatedBufferType param); + void MethodRef(ref StatelessCallerAllocatedBufferType param); StatelessCallerAllocatedBufferType Return(); [PreserveSig] StatelessCallerAllocatedBufferType ReturnPreserveSig(); } [GeneratedComClass] - internal partial class StatelessCallerAllocatedBufferMarshalling : IStatelessCallerAllocateBufferMarshalling + internal partial class StatelessCallerAllocatedBufferMarshalling : IStatelessCallerAllocatedBufferMarshalling { - public void Method([MarshalUsing(CountElementName = "size")] StatelessCallerAllocatedBufferType param, int size) { } - public void MethodIn([MarshalUsing(CountElementName = "size")] in StatelessCallerAllocatedBufferType param, int size) { } - public void MethodOut([MarshalUsing(CountElementName = "size")] out StatelessCallerAllocatedBufferType param, int size) { param = new StatelessCallerAllocatedBufferType { I = 42 }; } - public void MethodRef([MarshalUsing(CountElementName = "size")] ref StatelessCallerAllocatedBufferType param, int size) { param = new StatelessCallerAllocatedBufferType { I = 200 }; } - public StatelessCallerAllocatedBufferType Return() => throw new NotImplementedException(); - public StatelessCallerAllocatedBufferType ReturnPreserveSig() => throw new NotImplementedException(); + public void Method(StatelessCallerAllocatedBufferType param) { } + public void MethodIn(in StatelessCallerAllocatedBufferType param) { } + public void MethodOut(out StatelessCallerAllocatedBufferType param) { param = new StatelessCallerAllocatedBufferType { I = 20 }; } + public void MethodRef(ref StatelessCallerAllocatedBufferType param) { param = new StatelessCallerAllocatedBufferType { I = 200 }; } + public StatelessCallerAllocatedBufferType Return() => new StatelessCallerAllocatedBufferType() { I = 201 }; + public StatelessCallerAllocatedBufferType ReturnPreserveSig() => new StatelessCallerAllocatedBufferType() { I = 202 }; } [NativeMarshalling(typeof(StatelessCallerAllocatedBufferTypeMarshaller))] @@ -38,16 +40,38 @@ namespace SharedTypes.ComInterfaces } [CustomMarshaller(typeof(StatelessCallerAllocatedBufferType), MarshalMode.Default, typeof(StatelessCallerAllocatedBufferTypeMarshaller))] - internal static class StatelessCallerAllocatedBufferTypeMarshaller + internal static unsafe class StatelessCallerAllocatedBufferTypeMarshaller { + static HashSet _ptrs = new(); public static int FreeCount { get; private set; } - public static int BufferSize => 64; - public static nint ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed, Span buffer) => managed.I; + public static int BufferSize => sizeof(int); + public static int* ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed, Span buffer) + { + buffer[0] = managed.I; + return (int*)Unsafe.AsPointer(ref buffer[0]); + } - public static StatelessCallerAllocatedBufferType ConvertToManaged(nint unmanaged) => new StatelessCallerAllocatedBufferType { I = (int)unmanaged }; + public static StatelessCallerAllocatedBufferType ConvertToManaged(int* unmanaged) + { + return new StatelessCallerAllocatedBufferType() { I = *unmanaged }; + } - public static void Free(nint unmanaged) => FreeCount++; + public static void Free(int* unmanaged) + { + FreeCount++; + if (_ptrs.Contains((nint)unmanaged)) + { + Marshal.FreeHGlobal((nint)unmanaged); + _ptrs.Remove((nint)unmanaged); + } + } - public static nint ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed) => managed.I; + public static int* ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed) + { + nint ptr = Marshal.AllocHGlobal(sizeof(int)); + _ptrs.Add(ptr); + *(int*)ptr = managed.I; + return (int*)ptr; + } } } -- 2.7.4