[ComInterfaceGenerator] Add implementation and test for CallerAllocateBufferType...
authorJackson Schuster <36744439+jtschuster@users.noreply.github.com>
Thu, 10 Aug 2023 16:13:51 +0000 (11:13 -0500)
committerGitHub <noreply@github.com>
Thu, 10 Aug 2023 16:13:51 +0000 (09:13 -0700)
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCallerAllocateBufferMarshalling.cs

index 89ec569..95ab982 100644 (file)
@@ -214,6 +214,24 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
+        public void StatelessCallerAllocatedBufferMarshalling()
+        {
+            var obj = CreateWrapper<StatelessCallerAllocatedBufferMarshalling, IStatelessCallerAllocatedBufferMarshalling>();
+            var data = new StatelessCallerAllocatedBufferType() { I = 42 };
+
+            obj.Method(data);
+            Assert.Equal(42, data.I);
+            obj.MethodIn(in data);
+            Assert.Equal(42, data.I);
+            obj.MethodRef(ref data);
+            Assert.Equal(200, data.I);
+            obj.MethodOut(out data);
+            Assert.Equal(20, data.I);
+            Assert.Equal(201, obj.Return().I);
+            Assert.Equal(202, obj.ReturnPreserveSig().I);
+        }
+
+        [Fact]
         public void ICollectionMarshallingFails()
         {
             Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception);
index f5c3d4b..1b04704 100644 (file)
@@ -2,6 +2,8 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Collections.Generic;
+using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using System.Runtime.InteropServices.Marshalling;
 
@@ -9,26 +11,26 @@ namespace SharedTypes.ComInterfaces
 {
     [GeneratedComInterface]
     [Guid("4732FA5D-C105-4A23-87A7-58DCEDD4A9B3")]
-    internal partial interface IStatelessCallerAllocateBufferMarshalling
+    internal partial interface IStatelessCallerAllocatedBufferMarshalling
     {
-        void Method([MarshalUsing(CountElementName = nameof(size))] StatelessCallerAllocatedBufferType param, int size);
-        void MethodIn([MarshalUsing(CountElementName = nameof(size))] in StatelessCallerAllocatedBufferType param, int size);
-        void MethodOut([MarshalUsing(CountElementName = nameof(size))] out StatelessCallerAllocatedBufferType param, int size);
-        void MethodRef([MarshalUsing(CountElementName = nameof(size))] ref StatelessCallerAllocatedBufferType param, int size);
+        void Method(StatelessCallerAllocatedBufferType param);
+        void MethodIn(in StatelessCallerAllocatedBufferType param);
+        void MethodOut(out StatelessCallerAllocatedBufferType param);
+        void MethodRef(ref StatelessCallerAllocatedBufferType param);
         StatelessCallerAllocatedBufferType Return();
         [PreserveSig]
         StatelessCallerAllocatedBufferType ReturnPreserveSig();
     }
 
     [GeneratedComClass]
-    internal partial class StatelessCallerAllocatedBufferMarshalling : IStatelessCallerAllocateBufferMarshalling
+    internal partial class StatelessCallerAllocatedBufferMarshalling : IStatelessCallerAllocatedBufferMarshalling
     {
-        public void Method([MarshalUsing(CountElementName = "size")] StatelessCallerAllocatedBufferType param, int size) { }
-        public void MethodIn([MarshalUsing(CountElementName = "size")] in StatelessCallerAllocatedBufferType param, int size) { }
-        public void MethodOut([MarshalUsing(CountElementName = "size")] out StatelessCallerAllocatedBufferType param, int size) { param = new StatelessCallerAllocatedBufferType { I = 42 }; }
-        public void MethodRef([MarshalUsing(CountElementName = "size")] ref StatelessCallerAllocatedBufferType param, int size) { param = new StatelessCallerAllocatedBufferType { I = 200 }; }
-        public StatelessCallerAllocatedBufferType Return() => throw new NotImplementedException();
-        public StatelessCallerAllocatedBufferType ReturnPreserveSig() => throw new NotImplementedException();
+        public void Method(StatelessCallerAllocatedBufferType param) { }
+        public void MethodIn(in StatelessCallerAllocatedBufferType param) { }
+        public void MethodOut(out StatelessCallerAllocatedBufferType param) { param = new StatelessCallerAllocatedBufferType { I = 20 }; }
+        public void MethodRef(ref StatelessCallerAllocatedBufferType param) { param = new StatelessCallerAllocatedBufferType { I = 200 }; }
+        public StatelessCallerAllocatedBufferType Return() => new StatelessCallerAllocatedBufferType() { I = 201 };
+        public StatelessCallerAllocatedBufferType ReturnPreserveSig() => new StatelessCallerAllocatedBufferType() { I = 202 };
     }
 
     [NativeMarshalling(typeof(StatelessCallerAllocatedBufferTypeMarshaller))]
@@ -38,16 +40,38 @@ namespace SharedTypes.ComInterfaces
     }
 
     [CustomMarshaller(typeof(StatelessCallerAllocatedBufferType), MarshalMode.Default, typeof(StatelessCallerAllocatedBufferTypeMarshaller))]
-    internal static class StatelessCallerAllocatedBufferTypeMarshaller
+    internal static unsafe class StatelessCallerAllocatedBufferTypeMarshaller
     {
+        static HashSet<nint> _ptrs = new();
         public static int FreeCount { get; private set; }
-        public static int BufferSize => 64;
-        public static nint ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed, Span<byte> buffer) => managed.I;
+        public static int BufferSize => sizeof(int);
+        public static int* ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed, Span<int> buffer)
+        {
+            buffer[0] = managed.I;
+            return (int*)Unsafe.AsPointer(ref buffer[0]);
+        }
 
-        public static StatelessCallerAllocatedBufferType ConvertToManaged(nint unmanaged) => new StatelessCallerAllocatedBufferType { I = (int)unmanaged };
+        public static StatelessCallerAllocatedBufferType ConvertToManaged(int* unmanaged)
+        {
+            return new StatelessCallerAllocatedBufferType() { I = *unmanaged };
+        }
 
-        public static void Free(nint unmanaged) => FreeCount++;
+        public static void Free(int* unmanaged)
+        {
+            FreeCount++;
+            if (_ptrs.Contains((nint)unmanaged))
+            {
+                Marshal.FreeHGlobal((nint)unmanaged);
+                _ptrs.Remove((nint)unmanaged);
+            }
+        }
 
-        public static nint ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed) => managed.I;
+        public static int* ConvertToUnmanaged(StatelessCallerAllocatedBufferType managed)
+        {
+            nint ptr = Marshal.AllocHGlobal(sizeof(int));
+            _ptrs.Add(ptr);
+            *(int*)ptr = managed.I;
+            return (int*)ptr;
+        }
     }
 }