Don't create a COM weak reference if the object is an aggregated COMWrappers RCW...
authorJeremy Koritzinsky <jekoritz@microsoft.com>
Sat, 6 Nov 2021 19:02:48 +0000 (12:02 -0700)
committerGitHub <noreply@github.com>
Sat, 6 Nov 2021 19:02:48 +0000 (12:02 -0700)
* Don't create a COM weak reference if the object is an aggregated COMWrappers RCW.

* Add test for weak reference + aggregation with native weak reference impl.

* Apply suggestions from code review

Co-authored-by: Aaron Robinson <arobins@microsoft.com>
Co-authored-by: Aaron Robinson <arobins@microsoft.com>
src/coreclr/vm/interoplibinterface.h
src/coreclr/vm/interoplibinterface_comwrappers.cpp
src/coreclr/vm/weakreferencenative.cpp
src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceNative.cpp
src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs

index ad5c3c0..dd35bde 100644 (file)
@@ -34,7 +34,7 @@ public: // COM activation
     static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);
 
 public: // Unwrapping support
-    static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId);
+    static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated);
     static bool HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isActive);
 
 public: // GC interaction
index 5e73d97..8c3b988 100644 (file)
@@ -53,6 +53,9 @@ namespace
             // The EOC is "detached" and no longer used to map between identity and a managed object.
             // This will only be set if the EOC was inserted into the cache.
             Flags_Detached = 8,
+
+            // This EOC is an aggregated instance
+            Flags_Aggregated = 16
         };
         DWORD Flags;
 
@@ -900,7 +903,11 @@ namespace
                                 : ExternalObjectContext::Flags_None) |
                             (uniqueInstance
                                 ? ExternalObjectContext::Flags_None
-                                : ExternalObjectContext::Flags_InCache);
+                                : ExternalObjectContext::Flags_InCache) |
+                            ((flags & CreateObjectFlags::CreateObjectFlags_Aggregated) != 0
+                                ? ExternalObjectContext::Flags_Aggregated
+                                : ExternalObjectContext::Flags_None);
+
                 ExternalObjectContext::Construct(
                     resultHolder.GetContext(),
                     identity,
@@ -1774,7 +1781,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
         objRef);
 }
 
-IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId)
+IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated)
 {
     CONTRACTL
     {
@@ -1807,6 +1814,7 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
     {
         ExternalObjectContext* context = reinterpret_cast<ExternalObjectContext*>(contextMaybe);
         *wrapperId = context->WrapperId;
+        *isAggregated = context->IsSet(ExternalObjectContext::Flags_Aggregated);
 
         IUnknown* identity = reinterpret_cast<IUnknown*>(context->Identity);
         GCX_PREEMP();
index ab2b6f9..c61467b 100644 (file)
@@ -108,6 +108,7 @@ private:
 //
 // In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must:
 //  * be an RCW
+//  * not be an aggregated RCW
 //  * respond to a QI for IWeakReferenceSource
 //  * succeed when asked for an IWeakReference*
 //
@@ -149,7 +150,14 @@ NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject)
 #endif
     {
 #ifdef FEATURE_COMWRAPPERS
-        pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId));
+        bool isAggregated = false;
+        pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId, &isAggregated));
+        if (isAggregated)
+        {
+            // If the RCW is an aggregated RCW, then the managed object cannot be recreated from the IUnknown as the outer IUnknown wraps the managed object.
+            // In this case, don't create a weak reference backed by a COM weak reference.
+            pWeakReferenceSource = nullptr;
+        }
 #endif
     }
 
@@ -448,7 +456,7 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob
     _ASSERTE(gc.pThis->GetMethodTable()->CanCastToClass(pWeakReferenceMT));
 
     // Create the handle.
-#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) 
+#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
     NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
     if (gc.pTarget != NULL)
     {
@@ -690,7 +698,7 @@ FCIMPL1(Object *, WeakReferenceNative::GetTarget, WeakReferenceObject * pThisUNS
 
     OBJECTREF pTarget = GetWeakReferenceTarget(pThis);
 
-#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) 
+#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
     // If we found an object, or we're not a native COM weak reference, then we're done.  Othewrise
     // we can try to create a new RCW to the underlying native COM object if it's still alive.
     if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
@@ -718,7 +726,7 @@ FCIMPL1(Object *, WeakReferenceOfTNative::GetTarget, WeakReferenceObject * pThis
     OBJECTREF pTarget = GetWeakReferenceTarget(pThis);
 
 
-#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) 
+#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
     // If we found an object, or we're not a native COM weak reference, then we're done.  Othewrise
     // we can try to create a new RCW to the underlying native COM object if it's still alive.
     if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
index 89a0b99..1b52136 100644 (file)
@@ -167,8 +167,122 @@ namespace
             return UnknownImpl::DoRelease();
         }
     };
+
+    struct WeakReferenceSource : public IWeakReferenceSource, public IInspectable
+    {
+    private:
+        IUnknown* _outerUnknown;
+        ComSmartPtr<WeakReference> _weakReference;
+    public:
+        WeakReferenceSource(IUnknown* outerUnknown)
+            :_outerUnknown(outerUnknown),
+            _weakReference(new WeakReference(this, 1))
+        {
+        }
+
+        STDMETHOD(GetWeakReference)(IWeakReference** ppWeakReference)
+        {
+            _weakReference->AddRef();
+            *ppWeakReference = _weakReference;
+            return S_OK;
+        }
+
+        STDMETHOD(QueryInterface)(
+            /* [in] */ REFIID riid,
+            /* [iid_is][out] */ void ** ppvObject)
+        {
+            if (riid == __uuidof(IWeakReferenceSource))
+            {
+                *ppvObject = static_cast<IWeakReferenceSource*>(this);
+                _weakReference->AddStrongRef();
+                return S_OK;
+            }
+            return _outerUnknown->QueryInterface(riid, ppvObject);
+        }
+        STDMETHOD_(ULONG, AddRef)(void)
+        {
+            return _weakReference->AddStrongRef();
+        }
+        STDMETHOD_(ULONG, Release)(void)
+        {
+            return _weakReference->ReleaseStrongRef();
+        }
+
+        STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
+        {
+            return E_NOTIMPL;
+        }
+
+        STDMETHOD(GetIids)(
+            ULONG *iidCount,
+            IID   **iids)
+        {
+            return E_NOTIMPL;
+        }
+
+        STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
+        {
+            *trustLevel = FullTrust;
+            return S_OK;
+        }
+    };
+
+    struct AggregatedWeakReferenceSource : IInspectable
+    {
+    private:
+        IUnknown* _outerUnknown;
+        ComSmartPtr<WeakReferenceSource> _weakReference;
+    public:
+        AggregatedWeakReferenceSource(IUnknown* outerUnknown)
+            :_outerUnknown(outerUnknown),
+            _weakReference(new WeakReferenceSource(outerUnknown))
+        {
+        }
+
+        STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
+        {
+            return E_NOTIMPL;
+        }
+
+        STDMETHOD(GetIids)(
+            ULONG *iidCount,
+            IID   **iids)
+        {
+            return E_NOTIMPL;
+        }
+
+        STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
+        {
+            *trustLevel = FullTrust;
+            return S_OK;
+        }
+
+        STDMETHOD(QueryInterface)(
+            /* [in] */ REFIID riid,
+            /* [iid_is][out] */ void ** ppvObject)
+        {
+            if (riid == __uuidof(IWeakReferenceSource))
+            {
+                return _weakReference->QueryInterface(riid, ppvObject);
+            }
+            return _outerUnknown->QueryInterface(riid, ppvObject);
+        }
+        STDMETHOD_(ULONG, AddRef)(void)
+        {
+            return _outerUnknown->AddRef();
+        }
+        STDMETHOD_(ULONG, Release)(void)
+        {
+            return _outerUnknown->Release();
+        }
+    };
 }
 extern "C" DLL_EXPORT WeakReferencableObject* STDMETHODCALLTYPE CreateWeakReferencableObject()
 {
     return new WeakReferencableObject();
 }
+
+extern "C" DLL_EXPORT AggregatedWeakReferenceSource* STDMETHODCALLTYPE CreateAggregatedWeakReferenceObject(IUnknown* pOuter)
+{
+    return new AggregatedWeakReferenceSource(pOuter);
+}
index 38e40d1..1d2ebe8 100644 (file)
@@ -14,6 +14,9 @@ namespace ComWrappersTests
     {
         [DllImport(nameof(WeakReferenceNative))]
         public static extern IntPtr CreateWeakReferencableObject();
+
+        [DllImport(nameof(WeakReferenceNative))]
+        public static extern IntPtr CreateAggregatedWeakReferenceObject(IntPtr outer);
     }
 
     public struct VtblPtr
@@ -28,71 +31,96 @@ namespace ComWrappersTests
         Marshalling,
     }
 
-    public class WeakReferenceableWrapper
+    public unsafe class WeakReferenceableWrapper
     {
         private struct Vtbl
         {
-            public IntPtr QueryInterface;
-            public _AddRef AddRef;
-            public _Release Release;
+            public delegate* unmanaged<IntPtr, Guid*, IntPtr*, int> QueryInterface;
+            public delegate* unmanaged<IntPtr, int> AddRef;
+            public delegate* unmanaged<IntPtr, int> Release;
         }
 
-        private delegate int _AddRef(IntPtr This);
-        private delegate int _Release(IntPtr This);
-
         private readonly IntPtr instance;
         private readonly Vtbl vtable;
+        private readonly bool releaseInFinalizer;
 
         public WrapperRegistration Registration { get; }
 
-        public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg)
+        public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg, bool releaseInFinalizer = true)
         {
             var inst = Marshal.PtrToStructure<VtblPtr>(instance);
             this.vtable = Marshal.PtrToStructure<Vtbl>(inst.Vtbl);
             this.instance = instance;
+            this.releaseInFinalizer = releaseInFinalizer;
             Registration = reg;
         }
 
+        public int QueryInterface(Guid iid, out IntPtr ptr)
+        {
+            fixed(IntPtr* ppv = &ptr)
+            {
+                return this.vtable.QueryInterface(this.instance, &iid, ppv);
+            }
+        }
+
         ~WeakReferenceableWrapper()
         {
-            if (this.instance != IntPtr.Zero)
+            if (this.instance != IntPtr.Zero && this.releaseInFinalizer)
             {
                 this.vtable.Release(this.instance);
             }
         }
     }
 
-    class Program
+    class DerivedObject : ICustomQueryInterface
     {
-        class TestComWrappers : ComWrappers
+        private WeakReferenceableWrapper inner;
+        public DerivedObject(TestComWrappers comWrappersInstance)
         {
-            public WrapperRegistration Registration { get; }
+            IntPtr innerInstance = WeakReferenceNative.CreateAggregatedWeakReferenceObject(
+                comWrappersInstance.GetOrCreateComInterfaceForObject(this, CreateComInterfaceFlags.None));
+            inner = new WeakReferenceableWrapper(innerInstance, comWrappersInstance.Registration, releaseInFinalizer: false);
+            comWrappersInstance.GetOrRegisterObjectForComInstance(innerInstance, CreateObjectFlags.Aggregation, this);
+        }
 
-            public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
-            {
-                Registration = reg;
-            }
+        public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
+        {
+            return inner.QueryInterface(iid, out ppv) == 0 ? CustomQueryInterfaceResult.Handled : CustomQueryInterfaceResult.Failed;
+        }
+    }
 
-            protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
-            {
-                count = 0;
-                return null;
-            }
+    class TestComWrappers : ComWrappers
+    {
+        public WrapperRegistration Registration { get; }
 
-            protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
-            {
-                Marshal.AddRef(externalComObject);
-                return new WeakReferenceableWrapper(externalComObject, Registration);
-            }
+        public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
+        {
+            Registration = reg;
+        }
 
-            protected override void ReleaseObjects(IEnumerable objects)
-            {
-            }
+        protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+        {
+            count = 0;
+            return null;
+        }
+
+        protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
+        {
+            Marshal.AddRef(externalComObject);
+            return new WeakReferenceableWrapper(externalComObject, Registration);
+        }
 
-            public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
-            public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
+        protected override void ReleaseObjects(IEnumerable objects)
+        {
         }
 
+        public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
+        public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
+    }
+
+    class Program
+    {
+
         private static void ValidateWeakReferenceState(WeakReference<WeakReferenceableWrapper> wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null)
         {
             WeakReferenceableWrapper target;
@@ -135,7 +163,7 @@ namespace ComWrappersTests
             // a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak
             // reference should be dead and stay dead once the RCW is collected.
             bool supportsRehydration = cw.Registration != WrapperRegistration.Local;
-            
+
             Console.WriteLine($"    -- Validate RCW recreation");
             ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);
 
@@ -221,6 +249,26 @@ namespace ComWrappersTests
             Assert.IsNull(weakRef.Target);
         }
 
+        static void ValidateAggregatedWeakReference()
+        {
+            Console.WriteLine("Validate weak reference with aggregation.");
+            var (handle, weakRef) = GetWeakReference();
+
+            GC.Collect();
+            GC.WaitForPendingFinalizers();
+
+            Assert.IsNull(handle.Target);
+            Assert.IsFalse(weakRef.TryGetTarget(out _));
+
+            static (GCHandle handle, WeakReference<DerivedObject>) GetWeakReference()
+            {
+                DerivedObject obj = new DerivedObject(TestComWrappers.TrackerSupportInstance);
+                // We use an explicit weak GC handle here to enable us to validate that we are using "weak" GCHandle
+                // semantics with the weak reference.
+                return (GCHandle.Alloc(obj, GCHandleType.Weak), new WeakReference<DerivedObject>(obj));
+            }
+        }
+
         static int Main(string[] doNotUse)
         {
             try
@@ -235,6 +283,7 @@ namespace ComWrappersTests
 
                 ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance);
                 ValidateGlobalInstanceTrackerSupport();
+                ValidateAggregatedWeakReference();
 
                 ValidateLocalInstance();
             }