From aefb0fc63b6fceef35039c3d5bf61b7d03fd99a0 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Sat, 6 Nov 2021 12:02:48 -0700 Subject: [PATCH] Don't create a COM weak reference if the object is an aggregated COMWrappers RCW. (#61267) * Don't create a COM weak reference if the object is an aggregated COMWrappers RCW. * Add test for weak reference + aggregation with native weak reference impl. * Apply suggestions from code review Co-authored-by: Aaron Robinson Co-authored-by: Aaron Robinson --- src/coreclr/vm/interoplibinterface.h | 2 +- src/coreclr/vm/interoplibinterface_comwrappers.cpp | 12 ++- src/coreclr/vm/weakreferencenative.cpp | 16 ++- .../WeakReference/WeakReferenceNative.cpp | 114 +++++++++++++++++++++ .../ComWrappers/WeakReference/WeakReferenceTest.cs | 113 ++++++++++++++------ 5 files changed, 218 insertions(+), 39 deletions(-) diff --git a/src/coreclr/vm/interoplibinterface.h b/src/coreclr/vm/interoplibinterface.h index ad5c3c0..dd35bde 100644 --- a/src/coreclr/vm/interoplibinterface.h +++ b/src/coreclr/vm/interoplibinterface.h @@ -34,7 +34,7 @@ public: // COM activation static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe); public: // Unwrapping support - static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId); + static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated); static bool HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isActive); public: // GC interaction diff --git a/src/coreclr/vm/interoplibinterface_comwrappers.cpp b/src/coreclr/vm/interoplibinterface_comwrappers.cpp index 5e73d97..8c3b988 100644 --- a/src/coreclr/vm/interoplibinterface_comwrappers.cpp +++ b/src/coreclr/vm/interoplibinterface_comwrappers.cpp @@ -53,6 +53,9 @@ namespace // The EOC is "detached" and no longer used to map between identity and a managed object. // This will only be set if the EOC was inserted into the cache. Flags_Detached = 8, + + // This EOC is an aggregated instance + Flags_Aggregated = 16 }; DWORD Flags; @@ -900,7 +903,11 @@ namespace : ExternalObjectContext::Flags_None) | (uniqueInstance ? ExternalObjectContext::Flags_None - : ExternalObjectContext::Flags_InCache); + : ExternalObjectContext::Flags_InCache) | + ((flags & CreateObjectFlags::CreateObjectFlags_Aggregated) != 0 + ? ExternalObjectContext::Flags_Aggregated + : ExternalObjectContext::Flags_None); + ExternalObjectContext::Construct( resultHolder.GetContext(), identity, @@ -1774,7 +1781,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance( objRef); } -IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId) +IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated) { CONTRACTL { @@ -1807,6 +1814,7 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE { ExternalObjectContext* context = reinterpret_cast(contextMaybe); *wrapperId = context->WrapperId; + *isAggregated = context->IsSet(ExternalObjectContext::Flags_Aggregated); IUnknown* identity = reinterpret_cast(context->Identity); GCX_PREEMP(); diff --git a/src/coreclr/vm/weakreferencenative.cpp b/src/coreclr/vm/weakreferencenative.cpp index ab2b6f9..c61467b 100644 --- a/src/coreclr/vm/weakreferencenative.cpp +++ b/src/coreclr/vm/weakreferencenative.cpp @@ -108,6 +108,7 @@ private: // // In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must: // * be an RCW +// * not be an aggregated RCW // * respond to a QI for IWeakReferenceSource // * succeed when asked for an IWeakReference* // @@ -149,7 +150,14 @@ NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject) #endif { #ifdef FEATURE_COMWRAPPERS - pWeakReferenceSource = reinterpret_cast(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId)); + bool isAggregated = false; + pWeakReferenceSource = reinterpret_cast(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId, &isAggregated)); + if (isAggregated) + { + // If the RCW is an aggregated RCW, then the managed object cannot be recreated from the IUnknown as the outer IUnknown wraps the managed object. + // In this case, don't create a weak reference backed by a COM weak reference. + pWeakReferenceSource = nullptr; + } #endif } @@ -448,7 +456,7 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob _ASSERTE(gc.pThis->GetMethodTable()->CanCastToClass(pWeakReferenceMT)); // Create the handle. -#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) +#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) NativeComWeakHandleInfo *comWeakHandleInfo = nullptr; if (gc.pTarget != NULL) { @@ -690,7 +698,7 @@ FCIMPL1(Object *, WeakReferenceNative::GetTarget, WeakReferenceObject * pThisUNS OBJECTREF pTarget = GetWeakReferenceTarget(pThis); -#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) +#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) // If we found an object, or we're not a native COM weak reference, then we're done. Othewrise // we can try to create a new RCW to the underlying native COM object if it's still alive. if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle)) @@ -718,7 +726,7 @@ FCIMPL1(Object *, WeakReferenceOfTNative::GetTarget, WeakReferenceObject * pThis OBJECTREF pTarget = GetWeakReferenceTarget(pThis); -#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) +#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS) // If we found an object, or we're not a native COM weak reference, then we're done. Othewrise // we can try to create a new RCW to the underlying native COM object if it's still alive. if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle)) diff --git a/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceNative.cpp b/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceNative.cpp index 89a0b99..1b52136 100644 --- a/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceNative.cpp +++ b/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceNative.cpp @@ -167,8 +167,122 @@ namespace return UnknownImpl::DoRelease(); } }; + + struct WeakReferenceSource : public IWeakReferenceSource, public IInspectable + { + private: + IUnknown* _outerUnknown; + ComSmartPtr _weakReference; + public: + WeakReferenceSource(IUnknown* outerUnknown) + :_outerUnknown(outerUnknown), + _weakReference(new WeakReference(this, 1)) + { + } + + STDMETHOD(GetWeakReference)(IWeakReference** ppWeakReference) + { + _weakReference->AddRef(); + *ppWeakReference = _weakReference; + return S_OK; + } + + STDMETHOD(QueryInterface)( + /* [in] */ REFIID riid, + /* [iid_is][out] */ void ** ppvObject) + { + if (riid == __uuidof(IWeakReferenceSource)) + { + *ppvObject = static_cast(this); + _weakReference->AddStrongRef(); + return S_OK; + } + return _outerUnknown->QueryInterface(riid, ppvObject); + } + STDMETHOD_(ULONG, AddRef)(void) + { + return _weakReference->AddStrongRef(); + } + STDMETHOD_(ULONG, Release)(void) + { + return _weakReference->ReleaseStrongRef(); + } + + STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName) + { + return E_NOTIMPL; + } + + STDMETHOD(GetIids)( + ULONG *iidCount, + IID **iids) + { + return E_NOTIMPL; + } + + STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel) + { + *trustLevel = FullTrust; + return S_OK; + } + }; + + struct AggregatedWeakReferenceSource : IInspectable + { + private: + IUnknown* _outerUnknown; + ComSmartPtr _weakReference; + public: + AggregatedWeakReferenceSource(IUnknown* outerUnknown) + :_outerUnknown(outerUnknown), + _weakReference(new WeakReferenceSource(outerUnknown)) + { + } + + STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName) + { + return E_NOTIMPL; + } + + STDMETHOD(GetIids)( + ULONG *iidCount, + IID **iids) + { + return E_NOTIMPL; + } + + STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel) + { + *trustLevel = FullTrust; + return S_OK; + } + + STDMETHOD(QueryInterface)( + /* [in] */ REFIID riid, + /* [iid_is][out] */ void ** ppvObject) + { + if (riid == __uuidof(IWeakReferenceSource)) + { + return _weakReference->QueryInterface(riid, ppvObject); + } + return _outerUnknown->QueryInterface(riid, ppvObject); + } + STDMETHOD_(ULONG, AddRef)(void) + { + return _outerUnknown->AddRef(); + } + STDMETHOD_(ULONG, Release)(void) + { + return _outerUnknown->Release(); + } + }; } extern "C" DLL_EXPORT WeakReferencableObject* STDMETHODCALLTYPE CreateWeakReferencableObject() { return new WeakReferencableObject(); } + +extern "C" DLL_EXPORT AggregatedWeakReferenceSource* STDMETHODCALLTYPE CreateAggregatedWeakReferenceObject(IUnknown* pOuter) +{ + return new AggregatedWeakReferenceSource(pOuter); +} diff --git a/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs b/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs index 38e40d1..1d2ebe8 100644 --- a/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs +++ b/src/tests/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs @@ -14,6 +14,9 @@ namespace ComWrappersTests { [DllImport(nameof(WeakReferenceNative))] public static extern IntPtr CreateWeakReferencableObject(); + + [DllImport(nameof(WeakReferenceNative))] + public static extern IntPtr CreateAggregatedWeakReferenceObject(IntPtr outer); } public struct VtblPtr @@ -28,71 +31,96 @@ namespace ComWrappersTests Marshalling, } - public class WeakReferenceableWrapper + public unsafe class WeakReferenceableWrapper { private struct Vtbl { - public IntPtr QueryInterface; - public _AddRef AddRef; - public _Release Release; + public delegate* unmanaged QueryInterface; + public delegate* unmanaged AddRef; + public delegate* unmanaged Release; } - private delegate int _AddRef(IntPtr This); - private delegate int _Release(IntPtr This); - private readonly IntPtr instance; private readonly Vtbl vtable; + private readonly bool releaseInFinalizer; public WrapperRegistration Registration { get; } - public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg) + public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg, bool releaseInFinalizer = true) { var inst = Marshal.PtrToStructure(instance); this.vtable = Marshal.PtrToStructure(inst.Vtbl); this.instance = instance; + this.releaseInFinalizer = releaseInFinalizer; Registration = reg; } + public int QueryInterface(Guid iid, out IntPtr ptr) + { + fixed(IntPtr* ppv = &ptr) + { + return this.vtable.QueryInterface(this.instance, &iid, ppv); + } + } + ~WeakReferenceableWrapper() { - if (this.instance != IntPtr.Zero) + if (this.instance != IntPtr.Zero && this.releaseInFinalizer) { this.vtable.Release(this.instance); } } } - class Program + class DerivedObject : ICustomQueryInterface { - class TestComWrappers : ComWrappers + private WeakReferenceableWrapper inner; + public DerivedObject(TestComWrappers comWrappersInstance) { - public WrapperRegistration Registration { get; } + IntPtr innerInstance = WeakReferenceNative.CreateAggregatedWeakReferenceObject( + comWrappersInstance.GetOrCreateComInterfaceForObject(this, CreateComInterfaceFlags.None)); + inner = new WeakReferenceableWrapper(innerInstance, comWrappersInstance.Registration, releaseInFinalizer: false); + comWrappersInstance.GetOrRegisterObjectForComInstance(innerInstance, CreateObjectFlags.Aggregation, this); + } - public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local) - { - Registration = reg; - } + public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv) + { + return inner.QueryInterface(iid, out ppv) == 0 ? CustomQueryInterfaceResult.Handled : CustomQueryInterfaceResult.Failed; + } + } - protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) - { - count = 0; - return null; - } + class TestComWrappers : ComWrappers + { + public WrapperRegistration Registration { get; } - protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag) - { - Marshal.AddRef(externalComObject); - return new WeakReferenceableWrapper(externalComObject, Registration); - } + public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local) + { + Registration = reg; + } - protected override void ReleaseObjects(IEnumerable objects) - { - } + protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) + { + count = 0; + return null; + } + + protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag) + { + Marshal.AddRef(externalComObject); + return new WeakReferenceableWrapper(externalComObject, Registration); + } - public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport); - public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling); + protected override void ReleaseObjects(IEnumerable objects) + { } + public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport); + public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling); + } + + class Program + { + private static void ValidateWeakReferenceState(WeakReference wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null) { WeakReferenceableWrapper target; @@ -135,7 +163,7 @@ namespace ComWrappersTests // a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak // reference should be dead and stay dead once the RCW is collected. bool supportsRehydration = cw.Registration != WrapperRegistration.Local; - + Console.WriteLine($" -- Validate RCW recreation"); ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw); @@ -221,6 +249,26 @@ namespace ComWrappersTests Assert.IsNull(weakRef.Target); } + static void ValidateAggregatedWeakReference() + { + Console.WriteLine("Validate weak reference with aggregation."); + var (handle, weakRef) = GetWeakReference(); + + GC.Collect(); + GC.WaitForPendingFinalizers(); + + Assert.IsNull(handle.Target); + Assert.IsFalse(weakRef.TryGetTarget(out _)); + + static (GCHandle handle, WeakReference) GetWeakReference() + { + DerivedObject obj = new DerivedObject(TestComWrappers.TrackerSupportInstance); + // We use an explicit weak GC handle here to enable us to validate that we are using "weak" GCHandle + // semantics with the weak reference. + return (GCHandle.Alloc(obj, GCHandleType.Weak), new WeakReference(obj)); + } + } + static int Main(string[] doNotUse) { try @@ -235,6 +283,7 @@ namespace ComWrappersTests ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance); ValidateGlobalInstanceTrackerSupport(); + ValidateAggregatedWeakReference(); ValidateLocalInstance(); } -- 2.7.4