Change CreateObjectFlags.Unwrap to be ComWrappers-instance-specific (#86195)
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Mon, 15 May 2023 18:43:05 +0000 (11:43 -0700)
committerGitHub <noreply@github.com>
Mon, 15 May 2023 18:43:05 +0000 (11:43 -0700)
src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs
src/coreclr/vm/interoplibinterface_comwrappers.cpp
src/tests/Interop/COM/ComWrappers/API/Program.cs
src/tests/Interop/COM/ComWrappers/Common.cs

index c348418..0bb2fb3 100644 (file)
@@ -706,11 +706,35 @@ namespace System.Runtime.InteropServices
 
             if (flags.HasFlag(CreateObjectFlags.Unwrap))
             {
-                var comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject);
+                ComInterfaceDispatch* comInterfaceDispatch = TryGetComInterfaceDispatch(externalComObject);
                 if (comInterfaceDispatch != null)
                 {
-                    retValue = ComInterfaceDispatch.GetInstance<object>(comInterfaceDispatch);
-                    return true;
+                    // If we found a managed object wrapper in this ComWrappers instance
+                    // and it's has the same identity pointer as the one we're creating a NativeObjectWrapper for,
+                    // unwrap it. We don't AddRef the wrapper as we don't take a reference to it.
+                    //
+                    // A managed object can have multiple managed object wrappers, with a max of one per context.
+                    // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the
+                    // managed object wrappers for A created with C1 and C2 respectively.
+                    // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance,
+                    // we will create a new wrapper. In this scenario, we'll only unwrap B2.
+                    object unwrapped = ComInterfaceDispatch.GetInstance<object>(comInterfaceDispatch);
+                    if (_ccwTable.TryGetValue(unwrapped, out ManagedObjectWrapperHolder? unwrappedWrapperInThisContext))
+                    {
+                        // The unwrapped object has a CCW in this context. Get the IUnknown for the externalComObject
+                        // so we can see if it's the CCW for the unwrapped object in this context.
+                        Guid iid = IID_IUnknown;
+                        int hr = Marshal.QueryInterface(externalComObject, ref iid, out IntPtr externalIUnknown);
+                        Debug.Assert(hr == 0); // An external COM object that came from a ComWrappers instance
+                                               // will always be well-formed.
+                        if (unwrappedWrapperInThisContext.ComIp == externalIUnknown)
+                        {
+                            Marshal.Release(externalIUnknown);
+                            retValue = unwrapped;
+                            return true;
+                        }
+                        Marshal.Release(externalIUnknown);
+                    }
                 }
             }
 
index 4108fdc..93fa6a4 100644 (file)
@@ -866,12 +866,42 @@ namespace
         }
         else if (handle != NULL)
         {
-            // We have an object handle from the COM instance which is a CCW. Use that object.
-            // This allows for the round-trip from object -> COM instance -> object.
+            // We have an object handle from the COM instance which is a CCW.
             ::OBJECTHANDLE objectHandle = static_cast<::OBJECTHANDLE>(handle);
-            gc.objRefMaybe = ObjectFromHandle(objectHandle);
+
+            // Now we need to check if this object is a CCW from the same ComWrappers instance
+            // as the one creating the EOC. If it is not, we need to create a new EOC for it.
+            // Otherwise, use it. This allows for the round-trip from object -> COM instance -> object.
+            OBJECTREF objRef = NULL;
+            GCPROTECT_BEGIN(objRef);
+            objRef = ObjectFromHandle(objectHandle);
+
+            SyncBlock* syncBlock = objRef->GetSyncBlock();
+            InteropSyncBlockInfo* interopInfo = syncBlock->GetInteropInfo();
+
+            // If we found a managed object wrapper in this ComWrappers instance
+            // and it's the same identity pointer as the one we're creating an EOC for,
+            // unwrap it. We don't AddRef the wrapper as we don't take a reference to it.
+            //
+            // A managed object can have multiple managed object wrappers, with a max of one per context.
+            // Let's say we have a managed object A and ComWrappers instances C1 and C2. Let B1 and B2 be the
+            // managed object wrappers for A created with C1 and C2 respectively.
+            // If we are asked to create an EOC for B1 with the unwrap flag on the C2 ComWrappers instance,
+            // we will create a new wrapper. In this scenario, we'll only unwrap B2.
+            void* wrapperRawMaybe = NULL;
+            if (interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe)
+                && wrapperRawMaybe == identity)
+            {
+                gc.objRefMaybe = objRef;
+            }
+            else
+            {
+                STRESS_LOG2(LF_INTEROP, LL_INFO1000, "Not unwrapping handle (0x%p) because the object's MOW in this ComWrappers instance (if any) (0x%p) is not the provided identity\n", handle, wrapperRawMaybe);
+            }
+            GCPROTECT_END();
         }
-        else
+
+        if (gc.objRefMaybe == NULL)
         {
             // Create context instance for the possibly new external object.
             ExternalWrapperResultHolder resultHolder;
index 5982a1b..56c2748 100644 (file)
@@ -84,12 +84,23 @@ namespace ComWrappersTests
 
             protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
             {
-                var iid = typeof(ITrackerObject).GUID;
+                var iTrackerObjectIid = typeof(ITrackerObject).GUID;
                 IntPtr iTrackerComObject;
-                int hr = Marshal.QueryInterface(externalComObject, ref iid, out iTrackerComObject);
-                Assert.Equal(0, hr);
+                int hr = Marshal.QueryInterface(externalComObject, ref iTrackerObjectIid, out iTrackerComObject);
+                if (hr == 0)
+                {
+                    return new ITrackerObjectWrapper(iTrackerComObject);
+                }
+                var iTestIid = typeof(ITest).GUID;
+                IntPtr iTest;
+                hr = Marshal.QueryInterface(externalComObject, ref iTestIid, out iTest);
+                if (hr == 0)
+                {
+                    return new ITestObjectWrapper(iTest);
+                }
 
-                return new ITrackerObjectWrapper(iTrackerComObject);
+                Assert.Fail("The COM object should support ITrackerObject or ITest for all tests in this test suite.");
+                return null;
             }
 
             public const int ReleaseObjectsCallAck = unchecked((int)-1);
@@ -175,13 +186,82 @@ namespace ComWrappersTests
             Assert.NotEqual(IntPtr.Zero, comWrapper);
 
             var testObjUnwrapped = wrappers.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap);
-            Assert.Equal(testObj, testObjUnwrapped);
+            Assert.Same(testObj, testObjUnwrapped);
 
             // Release the wrapper
             int count = Marshal.Release(comWrapper);
             Assert.Equal(0, count);
         }
 
+        [MethodImpl(MethodImplOptions.NoInlining)]
+        [Fact]
+        public void ValidateComInterfaceUnwrapWrapperSpecific()
+        {
+            Console.WriteLine($"Running {nameof(ValidateComInterfaceUnwrapWrapperSpecific)}...");
+
+            var testObj = new Test();
+
+            var wrappers = new TestComWrappers();
+
+            // Allocate a wrapper for the object
+            IntPtr comWrapper = wrappers.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+            Assert.NotEqual(IntPtr.Zero, comWrapper);
+
+            // Make sure that unwrapping the wrapper in the same ComWrappers context gets back the same object
+            var testObjUnwrapped = GetUnwrappedObjectHandleForComInstance(wrappers, comWrapper);
+            AssertSameInstanceAndFreeHandle(testObj, testObjUnwrapped);
+
+            // Make sure that unwrapping the wrapper in a different ComWrappers context gets back a different object
+            var wrappers2 = new TestComWrappers();
+            var testObjWrapper2 = GetUnwrappedObjectHandleForComInstance(wrappers2, comWrapper);
+            AssertNotSameInstanceAndFreeHandle(testObj, testObjWrapper2);
+
+            // Make sure that unwrapping a wrapper from a different ComWrappers context in a context that has created a CCW
+            // for the object only unwraps the wrapper from that context, not from any context.
+            var wrappers3 = new TestComWrappers();
+            IntPtr comWrapper3 = wrappers3.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.None);
+
+            Assert.NotEqual(IntPtr.Zero, comWrapper3);
+            Assert.NotEqual(comWrapper, comWrapper3);
+
+            var testObjWrapper3 = GetUnwrappedObjectHandleForComInstance(wrappers3, comWrapper);
+            AssertNotSameInstanceAndFreeHandle(testObj, testObjWrapper3);
+            AssertSameInstanceAndFreeHandle(testObj, GetUnwrappedObjectHandleForComInstance(wrappers3, comWrapper3));
+
+            // Force a GC to release the new managed object wrappers we made
+            ForceGC();
+
+            // Release the COM wrappers
+            int count = Marshal.Release(comWrapper);
+            count = Marshal.Release(comWrapper3);
+            Assert.Equal(0, count);
+
+            // Make sure that all possible references to the CCW over the RCW are never on the same frame
+            // as the rest of the test (to ensure that the GC does collect it).
+            [MethodImpl(MethodImplOptions.NoInlining)]
+            static GCHandle GetUnwrappedObjectHandleForComInstance(ComWrappers wrapper, nint comWrapper)
+            {
+                var obj = wrapper.GetOrCreateObjectForComInstance(comWrapper, CreateObjectFlags.Unwrap);
+                return GCHandle.Alloc(obj);
+            }
+
+            [MethodImpl(MethodImplOptions.NoInlining)]
+            static void AssertSameInstanceAndFreeHandle(object obj, GCHandle handle)
+            {
+                Assert.True(handle.IsAllocated);
+                Assert.Same(obj, handle.Target);
+                handle.Free();
+            }
+
+            [MethodImpl(MethodImplOptions.NoInlining)]
+            static void AssertNotSameInstanceAndFreeHandle(object obj, GCHandle handle)
+            {
+                Assert.True(handle.IsAllocated);
+                Assert.NotSame(obj, handle.Target);
+                handle.Free();
+            }
+        }
+
         [Fact]
         public void ValidateComObjectExtendsManagedLifetime()
         {
index 0567185..41fdf34 100644 (file)
@@ -85,6 +85,30 @@ namespace ComWrappersTests.Common
         }
     }
 
+    public class ITestObjectWrapper : ITest
+    {
+        private readonly ITestVtbl._SetValue _setValue;
+        private readonly IntPtr _ptr;
+
+        public ITestObjectWrapper(IntPtr ptr)
+        {
+            _ptr = ptr;
+            VtblPtr inst = Marshal.PtrToStructure<VtblPtr>(ptr);
+            ITestVtbl _vtbl = Marshal.PtrToStructure<ITestVtbl>(inst.Vtbl);
+            _setValue = Marshal.GetDelegateForFunctionPointer<ITestVtbl._SetValue>(_vtbl.SetValue);
+        }
+
+        ~ITestObjectWrapper()
+        {
+            if (_ptr != IntPtr.Zero)
+            {
+                Marshal.Release(_ptr);
+            }
+        }
+
+        public void SetValue(int i) => _setValue(_ptr, i);
+    }
+
     //
     // Native interface definition with managed wrapper for tracker object
     //