PinnedMarshal after marshalling in unmanaged to managed (#90117)
authorJackson Schuster <36744439+jtschuster@users.noreply.github.com>
Tue, 8 Aug 2023 22:37:37 +0000 (17:37 -0500)
committerGitHub <noreply@github.com>
Tue, 8 Aug 2023 22:37:37 +0000 (15:37 -0700)
PinnedMarshal has the "FromManaged" call in stateful marshallers, so that needs to happen before marshal in unmanaged to managed stubs.

src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs
src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulMarshalling.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCollectionStatelessElement.cs
src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs

index cc76d85..f1d0a9b 100644 (file)
@@ -72,8 +72,8 @@ namespace Microsoft.Interop
             tryStatements.Add(statements.InvokeStatement);
 
             tryStatements.AddRange(statements.NotifyForSuccessfulInvoke);
-            tryStatements.AddRange(statements.PinnedMarshal);
             tryStatements.AddRange(statements.Marshal);
+            tryStatements.AddRange(statements.PinnedMarshal);
 
             List<StatementSyntax> allStatements = setupStatements;
             List<StatementSyntax> finallyStatements = new();
index 98390d9..5050200 100644 (file)
@@ -195,6 +195,24 @@ namespace ComInterfaceGenerator.Tests
         }
 
         [Fact]
+        public void StatefulMarshalling()
+        {
+            var obj = CreateWrapper<StatefulMarshalling, IStatefulMarshalling>();
+            var data = new StatefulType() { i = 42 };
+
+            obj.Method(data);
+            Assert.Equal(42, data.i);
+            obj.MethodIn(in data);
+            Assert.Equal(42, data.i);
+            var oldData = data;
+            obj.MethodRef(ref data);
+            Assert.True(oldData == data); // We want reference equality here
+            obj.MethodOut(out data);
+            Assert.Equal(1, data.i);
+            Assert.Equal(1, obj.Return().i);
+            Assert.Equal(1, obj.ReturnPreserveSig().i);
+        }
+
         public void ICollectionMarshallingFails()
         {
             Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception);
index 54e96f3..c88031d 100644 (file)
@@ -417,6 +417,7 @@ namespace ComInterfaceGenerator.Tests
                 private TUnmanaged* _originalUnmanaged;
                 private TUnmanaged* _unmanaged;
                 private TManaged[] _managed;
+                private TUnmanaged* Unmanaged => _unmanaged == null ? _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length) : _unmanaged;
 
                 public void FromUnmanaged(TUnmanaged* unmanaged)
                 {
@@ -451,7 +452,7 @@ namespace ComInterfaceGenerator.Tests
 
                 public TUnmanaged* ToUnmanaged()
                 {
-                    return _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length); 
+                    return Unmanaged;
                 }
 
                 public ReadOnlySpan<TManaged> GetManagedValuesSource()
@@ -461,7 +462,7 @@ namespace ComInterfaceGenerator.Tests
 
                 public Span<TUnmanaged> GetUnmanagedValuesDestination()
                 {
-                    return new(_unmanaged, _managed.Length);
+                    return new(Unmanaged, _managed.Length);
                 }
             }
         }
index a1df4de..69f7e09 100644 (file)
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System;
 using System.Runtime.InteropServices;
 using System.Runtime.InteropServices.Marshalling;
 
@@ -19,9 +20,26 @@ namespace SharedTypes.ComInterfaces
         StatefulType ReturnPreserveSig();
     }
 
+    [GeneratedComClass]
+    internal partial class StatefulMarshalling : IStatefulMarshalling
+    {
+        public void Method(StatefulType param) => param.i++;
+        public void MethodIn(in StatefulType param) => param.i++;
+        public void MethodOut(out StatefulType param) => param = new StatefulType() { i = 1 };
+        public void MethodRef(ref StatefulType param) { }
+        public StatefulType Return() => new StatefulType() { i = 1 };
+        public StatefulType ReturnPreserveSig() => new StatefulType() { i = 1 };
+    }
+
     [NativeMarshalling(typeof(StatefulTypeMarshaller))]
     internal class StatefulType
     {
+        public int i;
+    }
+
+    internal struct StatefulNative
+    {
+        public int i;
     }
 
     [CustomMarshaller(typeof(StatefulType), MarshalMode.ManagedToUnmanagedIn, typeof(ManagedToUnmanaged))]
@@ -34,29 +52,46 @@ namespace SharedTypes.ComInterfaces
     {
         internal struct Bidirectional
         {
+            public static int FreeCount { get; private set; }
+            StatefulType? _managed;
+            bool _hasManaged;
+            StatefulNative _unmanaged;
+            bool _hasUnmanaged;
+
             public void FromManaged(StatefulType managed)
             {
-                throw new System.NotImplementedException();
+                _hasManaged = true;
+                _managed = managed;
             }
 
-            public nint ToUnmanaged()
+            public StatefulNative ToUnmanaged()
             {
-                throw new System.NotImplementedException();
+                if (!_hasManaged) throw new InvalidOperationException();
+                return new StatefulNative() { i = _managed.i };
             }
 
-            public void FromUnmanaged(nint unmanaged)
+            public void FromUnmanaged(StatefulNative unmanaged)
             {
-                throw new System.NotImplementedException();
+                _hasUnmanaged = true;
+                _unmanaged = unmanaged;
             }
 
             public StatefulType ToManaged()
             {
-                throw new System.NotImplementedException();
+                if (!_hasUnmanaged)
+                {
+                    throw new InvalidOperationException();
+                }
+                if (_hasManaged && _managed.i == _unmanaged.i)
+                {
+                    return _managed;
+                }
+                return new StatefulType() { i = _unmanaged.i };
             }
 
             public void Free()
             {
-                throw new System.NotImplementedException();
+                FreeCount++;
             }
 
             public void OnInvoked() { }
@@ -64,19 +99,24 @@ namespace SharedTypes.ComInterfaces
 
         internal struct ManagedToUnmanaged
         {
+            public static int FreeCount { get; private set; }
+            StatefulType? _managed;
+            bool _hasManaged;
             public void FromManaged(StatefulType managed)
             {
-                throw new System.NotImplementedException();
+                _hasManaged = true;
+                _managed = managed;
             }
 
-            public nint ToUnmanaged()
+            public StatefulNative ToUnmanaged()
             {
-                throw new System.NotImplementedException();
+                if (!_hasManaged) throw new InvalidOperationException();
+                return new StatefulNative() { i = _managed.i };
             }
 
             public void Free()
             {
-                throw new System.NotImplementedException();
+                FreeCount++;
             }
 
             public void OnInvoked() { }
@@ -84,19 +124,28 @@ namespace SharedTypes.ComInterfaces
 
         internal struct UnmanagedToManaged
         {
-            public void FromUnmanaged(nint unmanaged)
+            public static int FreeCount { get; private set; }
+            StatefulNative _unmanaged;
+            bool _hasUnmanaged;
+
+            public void FromUnmanaged(StatefulNative unmanaged)
             {
-                throw new System.NotImplementedException();
+                _hasUnmanaged = true;
+                _unmanaged = unmanaged;
             }
 
             public StatefulType ToManaged()
             {
-                throw new System.NotImplementedException();
+                if (!_hasUnmanaged)
+                {
+                    throw new InvalidOperationException();
+                }
+                return new StatefulType() { i = _unmanaged.i };
             }
 
             public void Free()
             {
-                throw new System.NotImplementedException();
+                FreeCount++;
             }
 
             public void OnInvoked() { }
index 0b03eee..b295a9a 100644 (file)
@@ -47,12 +47,12 @@ namespace SharedTypes.ComInterfaces
     [ContiguousCollectionMarshaller]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ManagedToUnmanagedIn, typeof(StatelessCollectionMarshaller<,>.ManagedToUnmanaged))]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.UnmanagedToManagedOut, typeof(StatelessCollectionMarshaller<,>.ManagedToUnmanaged))]
-    [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementIn, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ManagedToUnmanagedOut, typeof(StatelessCollectionMarshaller<,>.UnmanagedToManaged))]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.UnmanagedToManagedIn, typeof(StatelessCollectionMarshaller<,>.UnmanagedToManaged))]
-    [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementOut, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.UnmanagedToManagedRef, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ManagedToUnmanagedRef, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
+    [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementIn, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
+    [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementOut, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
     [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementRef, typeof(StatelessCollectionMarshaller<,>.Bidirectional))]
     internal static unsafe class StatelessCollectionMarshaller<T, TUnmanagedElement> where TUnmanagedElement : unmanaged
     {
index 66ccc0c..e063e77 100644 (file)
@@ -39,13 +39,13 @@ namespace SharedTypes.ComInterfaces
 
     [CustomMarshaller(typeof(StatelessType), MarshalMode.ManagedToUnmanagedIn, typeof(ManagedToUnmanaged))]
     [CustomMarshaller(typeof(StatelessType), MarshalMode.UnmanagedToManagedOut, typeof(ManagedToUnmanaged))]
-    [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementIn, typeof(Bidirectional))]
     [CustomMarshaller(typeof(StatelessType), MarshalMode.ManagedToUnmanagedOut, typeof(UnmanagedToManaged))]
     [CustomMarshaller(typeof(StatelessType), MarshalMode.UnmanagedToManagedIn, typeof(UnmanagedToManaged))]
     [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementOut, typeof(Bidirectional))]
+    [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementIn, typeof(Bidirectional))]
+    [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementRef, typeof(Bidirectional))]
     [CustomMarshaller(typeof(StatelessType), MarshalMode.UnmanagedToManagedRef, typeof(Bidirectional))]
     [CustomMarshaller(typeof(StatelessType), MarshalMode.ManagedToUnmanagedRef, typeof(Bidirectional))]
-    [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementRef, typeof(Bidirectional))]
     internal static class StatelessTypeMarshaller
     {
         public static int AllFreeCount => Bidirectional.FreeCount + UnmanagedToManaged.FreeCount + ManagedToUnmanaged.FreeCount;