From 38ae4d3729f4fe1069b362309cc1b9bd5241e997 Mon Sep 17 00:00:00 2001 From: Jackson Schuster <36744439+jtschuster@users.noreply.github.com> Date: Tue, 8 Aug 2023 17:37:37 -0500 Subject: [PATCH] PinnedMarshal after marshalling in unmanaged to managed (#90117) PinnedMarshal has the "FromManaged" call in stateful marshallers, so that needs to happen before marshal in unmanaged to managed stubs. --- .../UnmanagedToManagedStubGenerator.cs | 2 +- .../RcwAroundCcwTests.cs | 18 +++++ .../UnmanagedToManagedCustomMarshallingTests.cs | 5 +- .../ComInterfaces/IStatefulMarshalling.cs | 79 ++++++++++++++++++---- .../IStatelessCollectionStatelessElement.cs | 4 +- .../ComInterfaces/IStatelessMarshalling.cs | 4 +- 6 files changed, 90 insertions(+), 22 deletions(-) diff --git a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs index cc76d85..f1d0a9b 100644 --- a/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs +++ b/src/libraries/System.Runtime.InteropServices/gen/ComInterfaceGenerator/UnmanagedToManagedStubGenerator.cs @@ -72,8 +72,8 @@ namespace Microsoft.Interop tryStatements.Add(statements.InvokeStatement); tryStatements.AddRange(statements.NotifyForSuccessfulInvoke); - tryStatements.AddRange(statements.PinnedMarshal); tryStatements.AddRange(statements.Marshal); + tryStatements.AddRange(statements.PinnedMarshal); List allStatements = setupStatements; List finallyStatements = new(); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs index 98390d9..5050200 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/RcwAroundCcwTests.cs @@ -195,6 +195,24 @@ namespace ComInterfaceGenerator.Tests } [Fact] + public void StatefulMarshalling() + { + var obj = CreateWrapper(); + var data = new StatefulType() { i = 42 }; + + obj.Method(data); + Assert.Equal(42, data.i); + obj.MethodIn(in data); + Assert.Equal(42, data.i); + var oldData = data; + obj.MethodRef(ref data); + Assert.True(oldData == data); // We want reference equality here + obj.MethodOut(out data); + Assert.Equal(1, data.i); + Assert.Equal(1, obj.Return().i); + Assert.Equal(1, obj.ReturnPreserveSig().i); + } + public void ICollectionMarshallingFails() { Type hrExceptionType = SystemFindsComCalleeException() ? typeof(MarshallingFailureException) : typeof(Exception); diff --git a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs index 54e96f3..c88031d 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/ComInterfaceGenerator.Tests/UnmanagedToManagedCustomMarshallingTests.cs @@ -417,6 +417,7 @@ namespace ComInterfaceGenerator.Tests private TUnmanaged* _originalUnmanaged; private TUnmanaged* _unmanaged; private TManaged[] _managed; + private TUnmanaged* Unmanaged => _unmanaged == null ? _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length) : _unmanaged; public void FromUnmanaged(TUnmanaged* unmanaged) { @@ -451,7 +452,7 @@ namespace ComInterfaceGenerator.Tests public TUnmanaged* ToUnmanaged() { - return _unmanaged = (TUnmanaged*)Marshal.AllocCoTaskMem(sizeof(TUnmanaged) * _managed.Length); + return Unmanaged; } public ReadOnlySpan GetManagedValuesSource() @@ -461,7 +462,7 @@ namespace ComInterfaceGenerator.Tests public Span GetUnmanagedValuesDestination() { - return new(_unmanaged, _managed.Length); + return new(Unmanaged, _managed.Length); } } } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulMarshalling.cs index a1df4de..69f7e09 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatefulMarshalling.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using System.Runtime.InteropServices; using System.Runtime.InteropServices.Marshalling; @@ -19,9 +20,26 @@ namespace SharedTypes.ComInterfaces StatefulType ReturnPreserveSig(); } + [GeneratedComClass] + internal partial class StatefulMarshalling : IStatefulMarshalling + { + public void Method(StatefulType param) => param.i++; + public void MethodIn(in StatefulType param) => param.i++; + public void MethodOut(out StatefulType param) => param = new StatefulType() { i = 1 }; + public void MethodRef(ref StatefulType param) { } + public StatefulType Return() => new StatefulType() { i = 1 }; + public StatefulType ReturnPreserveSig() => new StatefulType() { i = 1 }; + } + [NativeMarshalling(typeof(StatefulTypeMarshaller))] internal class StatefulType { + public int i; + } + + internal struct StatefulNative + { + public int i; } [CustomMarshaller(typeof(StatefulType), MarshalMode.ManagedToUnmanagedIn, typeof(ManagedToUnmanaged))] @@ -34,29 +52,46 @@ namespace SharedTypes.ComInterfaces { internal struct Bidirectional { + public static int FreeCount { get; private set; } + StatefulType? _managed; + bool _hasManaged; + StatefulNative _unmanaged; + bool _hasUnmanaged; + public void FromManaged(StatefulType managed) { - throw new System.NotImplementedException(); + _hasManaged = true; + _managed = managed; } - public nint ToUnmanaged() + public StatefulNative ToUnmanaged() { - throw new System.NotImplementedException(); + if (!_hasManaged) throw new InvalidOperationException(); + return new StatefulNative() { i = _managed.i }; } - public void FromUnmanaged(nint unmanaged) + public void FromUnmanaged(StatefulNative unmanaged) { - throw new System.NotImplementedException(); + _hasUnmanaged = true; + _unmanaged = unmanaged; } public StatefulType ToManaged() { - throw new System.NotImplementedException(); + if (!_hasUnmanaged) + { + throw new InvalidOperationException(); + } + if (_hasManaged && _managed.i == _unmanaged.i) + { + return _managed; + } + return new StatefulType() { i = _unmanaged.i }; } public void Free() { - throw new System.NotImplementedException(); + FreeCount++; } public void OnInvoked() { } @@ -64,19 +99,24 @@ namespace SharedTypes.ComInterfaces internal struct ManagedToUnmanaged { + public static int FreeCount { get; private set; } + StatefulType? _managed; + bool _hasManaged; public void FromManaged(StatefulType managed) { - throw new System.NotImplementedException(); + _hasManaged = true; + _managed = managed; } - public nint ToUnmanaged() + public StatefulNative ToUnmanaged() { - throw new System.NotImplementedException(); + if (!_hasManaged) throw new InvalidOperationException(); + return new StatefulNative() { i = _managed.i }; } public void Free() { - throw new System.NotImplementedException(); + FreeCount++; } public void OnInvoked() { } @@ -84,19 +124,28 @@ namespace SharedTypes.ComInterfaces internal struct UnmanagedToManaged { - public void FromUnmanaged(nint unmanaged) + public static int FreeCount { get; private set; } + StatefulNative _unmanaged; + bool _hasUnmanaged; + + public void FromUnmanaged(StatefulNative unmanaged) { - throw new System.NotImplementedException(); + _hasUnmanaged = true; + _unmanaged = unmanaged; } public StatefulType ToManaged() { - throw new System.NotImplementedException(); + if (!_hasUnmanaged) + { + throw new InvalidOperationException(); + } + return new StatefulType() { i = _unmanaged.i }; } public void Free() { - throw new System.NotImplementedException(); + FreeCount++; } public void OnInvoked() { } diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCollectionStatelessElement.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCollectionStatelessElement.cs index 0b03eee..b295a9a 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCollectionStatelessElement.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessCollectionStatelessElement.cs @@ -47,12 +47,12 @@ namespace SharedTypes.ComInterfaces [ContiguousCollectionMarshaller] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ManagedToUnmanagedIn, typeof(StatelessCollectionMarshaller<,>.ManagedToUnmanaged))] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.UnmanagedToManagedOut, typeof(StatelessCollectionMarshaller<,>.ManagedToUnmanaged))] - [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementIn, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ManagedToUnmanagedOut, typeof(StatelessCollectionMarshaller<,>.UnmanagedToManaged))] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.UnmanagedToManagedIn, typeof(StatelessCollectionMarshaller<,>.UnmanagedToManaged))] - [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementOut, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.UnmanagedToManagedRef, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ManagedToUnmanagedRef, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] + [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementIn, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] + [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementOut, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] [CustomMarshaller(typeof(StatelessCollection<>), MarshalMode.ElementRef, typeof(StatelessCollectionMarshaller<,>.Bidirectional))] internal static unsafe class StatelessCollectionMarshaller where TUnmanagedElement : unmanaged { diff --git a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs index 66ccc0c..e063e77 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/TestAssets/SharedTypes/ComInterfaces/IStatelessMarshalling.cs @@ -39,13 +39,13 @@ namespace SharedTypes.ComInterfaces [CustomMarshaller(typeof(StatelessType), MarshalMode.ManagedToUnmanagedIn, typeof(ManagedToUnmanaged))] [CustomMarshaller(typeof(StatelessType), MarshalMode.UnmanagedToManagedOut, typeof(ManagedToUnmanaged))] - [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementIn, typeof(Bidirectional))] [CustomMarshaller(typeof(StatelessType), MarshalMode.ManagedToUnmanagedOut, typeof(UnmanagedToManaged))] [CustomMarshaller(typeof(StatelessType), MarshalMode.UnmanagedToManagedIn, typeof(UnmanagedToManaged))] [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementOut, typeof(Bidirectional))] + [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementIn, typeof(Bidirectional))] + [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementRef, typeof(Bidirectional))] [CustomMarshaller(typeof(StatelessType), MarshalMode.UnmanagedToManagedRef, typeof(Bidirectional))] [CustomMarshaller(typeof(StatelessType), MarshalMode.ManagedToUnmanagedRef, typeof(Bidirectional))] - [CustomMarshaller(typeof(StatelessType), MarshalMode.ElementRef, typeof(Bidirectional))] internal static class StatelessTypeMarshaller { public static int AllFreeCount => Bidirectional.FreeCount + UnmanagedToManaged.FreeCount + ManagedToUnmanaged.FreeCount; -- 2.7.4