From a2127cd09a0a0b59526dc16ca28b111c7ed463f4 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Mon, 15 May 2023 11:43:05 -0700 Subject: [PATCH] Change CreateObjectFlags.Unwrap to be ComWrappers-instance-specific (#86195) --- .../InteropServices/ComWrappers.NativeAot.cs | 30 +++++++- src/coreclr/vm/interoplibinterface_comwrappers.cpp | 38 ++++++++- src/tests/Interop/COM/ComWrappers/API/Program.cs | 90 ++++++++++++++++++++-- src/tests/Interop/COM/ComWrappers/Common.cs | 24 ++++++ 4 files changed, 170 insertions(+), 12 deletions(-) diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs index c348418..0bb2fb3 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs @@ -706,11 +706,35 @@ namespace System.Runtime.InteropServices if (flags.HasFlag(CreateObjectFlags.Unwrap)) { - var comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject); + ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject); if (comInterfaceDispatch != null) { - retValue = ComInterfaceDispatch.GetInstance(comInterfaceDispatch); - return true; + // If we found a managed object wrapper in this ComWrappers instance + // and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for, + // unwrap it. We don't AddRef the wrapper as we don't take a reference to it. + // + // A managed object can have multiple managed object wrappers, with a max of one per context. + // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the + // managed object wrappers for A created with C1 and C2 respectively. + // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance, + // we will create a new wrapper. In this scenario, we'll only unwrap B2. + object unwrapped = ComInterfaceDispatch.GetInstance(comInterfaceDispatch); + if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext)) + { + // The unwrapped object has a CCW in this context. Get the IUnknown for the externalComObject + // so we can see if it's the CCW for the unwrapped object in this context. + Guid iid = IID_IUnknown; + int hr = Marshal.QueryInterface(externalComObject, ref iid, out IntPtr externalIUnknown); + Debug.Assert(hr == 0); // An external COM object that came from a ComWrappers instance + // will always be well-formed. + if (unwrappedWrapperInThisContext.ComIp == externalIUnknown) + { + Marshal.Release(externalIUnknown); + retValue = unwrapped; + return true; + } + Marshal.Release(externalIUnknown); + } } } diff --git a/src/coreclr/vm/interoplibinterface_comwrappers.cpp b/src/coreclr/vm/interoplibinterface_comwrappers.cpp index 4108fdc..93fa6a4 100644 --- a/src/coreclr/vm/interoplibinterface_comwrappers.cpp +++ b/src/coreclr/vm/interoplibinterface_comwrappers.cpp @@ -866,12 +866,42 @@ namespace } else if (handle != NULL) { - // We have an object handle from the COM instance which is a CCW. Use that object. - // This allows for the round-trip from object -> COM instance -> object. + // We have an object handle from the COM instance which is a CCW. ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle); - gc.objRefMaybe = ObjectFromHandle(objectHandle); + + // Now we need to check if this object is a CCW from the same ComWrappers instance + // as the one creating the EOC. If it is not, we need to create a new EOC for it. + // Otherwise, use it. This allows for the round-trip from object -> COM instance -> object. + OBJECTREF objRef = NULL; + GCPROTECT_BEGIN(objRef); + objRef = ObjectFromHandle(objectHandle); + + SyncBlock* syncBlock = objRef->GetSyncBlock(); + InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo(); + + // If we found a managed object wrapper in this ComWrappers instance + // and it's the same identity pointer as the one we're creating an EOC for, + // unwrap it. We don't AddRef the wrapper as we don't take a reference to it. + // + // A managed object can have multiple managed object wrappers, with a max of one per context. + // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the + // managed object wrappers for A created with C1 and C2 respectively. + // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance, + // we will create a new wrapper. In this scenario, we'll only unwrap B2. + void* wrapperRawMaybe = NULL; + if (interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe) + && wrapperRawMaybe == identity) + { + gc.objRefMaybe = objRef; + } + else + { + STRESS_LOG2(LF_INTEROP, LL_INFO1000, "Not unwrapping handle (0x%p) because the object's MOW in this ComWrappers instance (if any) (0x%p) is not the provided identity\n", handle, wrapperRawMaybe); + } + GCPROTECT_END(); } - else + + if (gc.objRefMaybe == NULL) { // Create context instance for the possibly new external object. ExternalWrapperResultHolder resultHolder; diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index 5982a1b..56c2748 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -84,12 +84,23 @@ namespace ComWrappersTests protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag) { - var iid = typeof(ITrackerObject).GUID; + var iTrackerObjectIid = typeof(ITrackerObject).GUID; IntPtr iTrackerComObject; - int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTrackerComObject); - Assert.Equal(0, hr); + int hr = Marshal.QueryInterface(externalComObject, ref iTrackerObjectIid, out iTrackerComObject); + if (hr == 0) + { + return new ITrackerObjectWrapper(iTrackerComObject); + } + var iTestIid = typeof(ITest).GUID; + IntPtr iTest; + hr = Marshal.QueryInterface(externalComObject, ref iTestIid, out iTest); + if (hr == 0) + { + return new ITestObjectWrapper(iTest); + } - return new ITrackerObjectWrapper(iTrackerComObject); + Assert.Fail("The COM object should support ITrackerObject or ITest for all tests in this test suite."); + return null; } public const int ReleaseObjectsCallAck = unchecked((int)-1); @@ -175,13 +186,82 @@ namespace ComWrappersTests Assert.NotEqual(IntPtr.Zero, comWrapper); var testObjUnwrapped = wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap); - Assert.Equal(testObj, testObjUnwrapped); + Assert.Same(testObj, testObjUnwrapped); // Release the wrapper int count = Marshal.Release(comWrapper); Assert.Equal(0, count); } + [MethodImpl(MethodImplOptions.NoInlining)] + [Fact] + public void ValidateComInterfaceUnwrapWrapperSpecific() + { + Console.WriteLine($"Running {nameof(ValidateComInterfaceUnwrapWrapperSpecific)}..."); + + var testObj = new Test(); + + var wrappers = new TestComWrappers(); + + // Allocate a wrapper for the object + IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None); + Assert.NotEqual(IntPtr.Zero, comWrapper); + + // Make sure that unwrapping the wrapper in the same ComWrappers context gets back the same object + var testObjUnwrapped = GetUnwrappedObjectHandleForComInstance(wrappers, comWrapper); + AssertSameInstanceAndFreeHandle(testObj, testObjUnwrapped); + + // Make sure that unwrapping the wrapper in a different ComWrappers context gets back a different object + var wrappers2 = new TestComWrappers(); + var testObjWrapper2 = GetUnwrappedObjectHandleForComInstance(wrappers2, comWrapper); + AssertNotSameInstanceAndFreeHandle(testObj, testObjWrapper2); + + // Make sure that unwrapping a wrapper from a different ComWrappers context in a context that has created a CCW + // for the object only unwraps the wrapper from that context, not from any context. + var wrappers3 = new TestComWrappers(); + IntPtr comWrapper3 = wrappers3.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None); + + Assert.NotEqual(IntPtr.Zero, comWrapper3); + Assert.NotEqual(comWrapper, comWrapper3); + + var testObjWrapper3 = GetUnwrappedObjectHandleForComInstance(wrappers3, comWrapper); + AssertNotSameInstanceAndFreeHandle(testObj, testObjWrapper3); + AssertSameInstanceAndFreeHandle(testObj, GetUnwrappedObjectHandleForComInstance(wrappers3, comWrapper3)); + + // Force a GC to release the new managed object wrappers we made + ForceGC(); + + // Release the COM wrappers + int count = Marshal.Release(comWrapper); + count = Marshal.Release(comWrapper3); + Assert.Equal(0, count); + + // Make sure that all possible references to the CCW over the RCW are never on the same frame + // as the rest of the test (to ensure that the GC does collect it). + [MethodImpl(MethodImplOptions.NoInlining)] + static GCHandle GetUnwrappedObjectHandleForComInstance(ComWrappers wrapper, nint comWrapper) + { + var obj = wrapper.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap); + return GCHandle.Alloc(obj); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void AssertSameInstanceAndFreeHandle(object obj, GCHandle handle) + { + Assert.True(handle.IsAllocated); + Assert.Same(obj, handle.Target); + handle.Free(); + } + + [MethodImpl(MethodImplOptions.NoInlining)] + static void AssertNotSameInstanceAndFreeHandle(object obj, GCHandle handle) + { + Assert.True(handle.IsAllocated); + Assert.NotSame(obj, handle.Target); + handle.Free(); + } + } + [Fact] public void ValidateComObjectExtendsManagedLifetime() { diff --git a/src/tests/Interop/COM/ComWrappers/Common.cs b/src/tests/Interop/COM/ComWrappers/Common.cs index 0567185..41fdf34 100644 --- a/src/tests/Interop/COM/ComWrappers/Common.cs +++ b/src/tests/Interop/COM/ComWrappers/Common.cs @@ -85,6 +85,30 @@ namespace ComWrappersTests.Common } } + public class ITestObjectWrapper : ITest + { + private readonly ITestVtbl._SetValue _setValue; + private readonly IntPtr _ptr; + + public ITestObjectWrapper(IntPtr ptr) + { + _ptr = ptr; + VtblPtr inst = Marshal.PtrToStructure(ptr); + ITestVtbl _vtbl = Marshal.PtrToStructure(inst.Vtbl); + _setValue = Marshal.GetDelegateForFunctionPointer(_vtbl.SetValue); + } + + ~ITestObjectWrapper() + { + if (_ptr != IntPtr.Zero) + { + Marshal.Release(_ptr); + } + } + + public void SetValue(int i) => _setValue(_ptr, i); + } + // // Native interface definition with managed wrapper for tracker object // -- 2.7.4