From bda2db238eec2f1ff81e25d709a03d966c8dd116 Mon Sep 17 00:00:00 2001 From: Elinor Fung <47805090+elinor-fung@users.noreply.github.com> Date: Tue, 16 Jun 2020 20:10:20 -0700 Subject: [PATCH] ComWrappers isolation (#37861) - Assign/track ID of ComWrappers used to create object/wrapper. - Only reuse object/wrapper created by same instance for GetOrCreate* - Rehydrate RCW on WeakReference for global ComWrapper instances only --- .../System/Runtime/InteropServices/ComWrappers.cs | 33 +++-- src/coreclr/src/vm/appdomain.hpp | 4 +- src/coreclr/src/vm/gchandleutilities.h | 25 ++-- src/coreclr/src/vm/interoplibinterface.cpp | 140 +++++++++++++----- src/coreclr/src/vm/interoplibinterface.h | 15 +- src/coreclr/src/vm/interoputil.cpp | 7 +- src/coreclr/src/vm/syncblk.h | 66 +++++++-- src/coreclr/src/vm/weakreferencenative.cpp | 93 +++++++----- .../src/Interop/COM/ComWrappers/API/Program.cs | 45 ++++++ .../ComWrappers/WeakReference/WeakReferenceTest.cs | 160 ++++++++++++++++++--- 10 files changed, 456 insertions(+), 132 deletions(-) diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs index 6bdd48d..b404dc6 100644 --- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs +++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs @@ -126,12 +126,20 @@ namespace System.Runtime.InteropServices /// private static ComWrappers? s_globalInstanceForMarshalling; + private static long s_instanceCounter; + private readonly long id = Interlocked.Increment(ref s_instanceCounter); + /// /// Create a COM representation of the supplied object that can be passed to a non-managed environment. /// /// The managed object to expose outside the .NET runtime. /// Flags used to configure the generated interface. /// The generated COM interface that can be passed outside the .NET runtime. + /// + /// If a COM representation was previously created for the specified using + /// this instance, the previously created COM interface will be returned. + /// If not, a new one will be created. + /// public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags) { IntPtr ptr; @@ -152,16 +160,16 @@ namespace System.Runtime.InteropServices /// /// If is null, the global instance (if registered) will be used. /// - private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue) + private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue) { if (instance == null) throw new ArgumentNullException(nameof(instance)); - return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue); + return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), impl.id, ObjectHandleOnStack.Create(ref instance), flags, out retValue); } [DllImport(RuntimeHelpers.QCall)] - private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue); + private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue); /// /// Compute the desired Vtable for respecting the values of . @@ -211,6 +219,11 @@ namespace System.Runtime.InteropServices /// Object to import for usage into the .NET runtime. /// Flags used to describe the external object. /// Returns a managed object associated with the supplied external COM object. + /// + /// If a managed object was previously created for the specified + /// using this instance, the previously created object will be returned. + /// If not, a new one will be created. + /// public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags) { object? obj; @@ -288,18 +301,18 @@ namespace System.Runtime.InteropServices /// /// If is null, the global instance (if registered) will be used. /// - private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue) + private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue) { if (externalComObject == IntPtr.Zero) throw new ArgumentNullException(nameof(externalComObject)); object? wrapperMaybeLocal = wrapperMaybe; retValue = null; - return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); + return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), impl.id, externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue)); } [DllImport(RuntimeHelpers.QCall)] - private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue); + private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue); /// /// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime. @@ -332,13 +345,13 @@ namespace System.Runtime.InteropServices throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance); } - SetGlobalInstanceRegisteredForTrackerSupport(); + SetGlobalInstanceRegisteredForTrackerSupport(instance.id); } [DllImport(RuntimeHelpers.QCall)] [SuppressGCTransition] - private static extern void SetGlobalInstanceRegisteredForTrackerSupport(); + private static extern void SetGlobalInstanceRegisteredForTrackerSupport(long id); /// /// Register a instance to be used as the global instance for marshalling in the runtime. @@ -366,12 +379,12 @@ namespace System.Runtime.InteropServices // Indicate to the runtime that a global instance has been registered for marshalling. // This allows the native runtime know to call into the managed ComWrappers only if a // global instance is registered for marshalling. - SetGlobalInstanceRegisteredForMarshalling(); + SetGlobalInstanceRegisteredForMarshalling(instance.id); } [DllImport(RuntimeHelpers.QCall)] [SuppressGCTransition] - private static extern void SetGlobalInstanceRegisteredForMarshalling(); + private static extern void SetGlobalInstanceRegisteredForMarshalling(long id); /// /// Get the runtime provided IUnknown implementation. diff --git a/src/coreclr/src/vm/appdomain.hpp b/src/coreclr/src/vm/appdomain.hpp index afdee41..6b1393f 100644 --- a/src/coreclr/src/vm/appdomain.hpp +++ b/src/coreclr/src/vm/appdomain.hpp @@ -1099,10 +1099,10 @@ public: return ::CreateRefcountedHandle(m_handleStore, object); } - OBJECTHANDLE CreateNativeComWeakHandle(OBJECTREF object, IWeakReference* pComWeakReference) + OBJECTHANDLE CreateNativeComWeakHandle(OBJECTREF object, NativeComWeakHandleInfo* pComWeakHandleInfo) { WRAPPER_NO_CONTRACT; - return ::CreateNativeComWeakHandle(m_handleStore, object, pComWeakReference); + return ::CreateNativeComWeakHandle(m_handleStore, object, pComWeakHandleInfo); } #endif // FEATURE_COMINTEROP diff --git a/src/coreclr/src/vm/gchandleutilities.h b/src/coreclr/src/vm/gchandleutilities.h index d90a2f0..039df61 100644 --- a/src/coreclr/src/vm/gchandleutilities.h +++ b/src/coreclr/src/vm/gchandleutilities.h @@ -201,9 +201,16 @@ inline OBJECTHANDLE CreateGlobalRefcountedHandle(OBJECTREF object) // Special handle creation convenience functions #ifdef FEATURE_COMINTEROP -inline OBJECTHANDLE CreateNativeComWeakHandle(IGCHandleStore* store, OBJECTREF object, IWeakReference* pComWeakReference) + +struct NativeComWeakHandleInfo +{ + IWeakReference *WeakReference; + INT64 WrapperId; +}; + +inline OBJECTHANDLE CreateNativeComWeakHandle(IGCHandleStore* store, OBJECTREF object, NativeComWeakHandleInfo* pComWeakHandleInfo) { - OBJECTHANDLE hnd = store->CreateHandleWithExtraInfo(OBJECTREFToObject(object), HNDTYPE_WEAK_NATIVE_COM, (void*)pComWeakReference); + OBJECTHANDLE hnd = store->CreateHandleWithExtraInfo(OBJECTREFToObject(object), HNDTYPE_WEAK_NATIVE_COM, (void*)pComWeakHandleInfo); if (!hnd) { COMPlusThrowOM(); @@ -374,14 +381,16 @@ inline void DestroyNativeComWeakHandle(OBJECTHANDLE handle) } CONTRACTL_END; - // Release the WinRT weak reference if we have one. We're assuming that this will not reenter the - // runtime, since if we are pointing at a managed object, we should not be using HNDTYPE_WEAK_NATIVE_COM - // but rather HNDTYPE_WEAK_SHORT or HNDTYPE_WEAK_LONG. + // Delete the COM info and release the weak reference if we have one. We're assuming that + // this will not reenter the runtime, since if we are pointing at a managed object, we should + // not be using HNDTYPE_WEAK_NATIVE_COM but rather HNDTYPE_WEAK_SHORT or HNDTYPE_WEAK_LONG. void* pExtraInfo = GCHandleUtilities::GetGCHandleManager()->GetExtraInfoFromHandle(handle); - IWeakReference* pWinRTWeakReference = reinterpret_cast(pExtraInfo); - if (pWinRTWeakReference != nullptr) + NativeComWeakHandleInfo* comWeakHandleInfo = reinterpret_cast(pExtraInfo); + if (comWeakHandleInfo != nullptr) { - pWinRTWeakReference->Release(); + _ASSERTE(comWeakHandleInfo->WeakReference != nullptr); + comWeakHandleInfo->WeakReference->Release(); + delete comWeakHandleInfo; } DiagHandleDestroyed(handle); diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp index a17bf75..ffd8d4f 100644 --- a/src/coreclr/src/vm/interoplibinterface.cpp +++ b/src/coreclr/src/vm/interoplibinterface.cpp @@ -26,6 +26,7 @@ namespace void* Identity; void* ThreadContext; DWORD SyncBlockIndex; + INT64 WrapperId; enum { @@ -41,6 +42,7 @@ namespace _In_ IUnknown* identity, _In_opt_ void* threadContext, _In_ DWORD syncBlockIndex, + _In_ INT64 wrapperId, _In_ DWORD flags) { CONTRACTL @@ -57,6 +59,7 @@ namespace cxt->Identity = (void*)identity; cxt->ThreadContext = threadContext; cxt->SyncBlockIndex = syncBlockIndex; + cxt->WrapperId = wrapperId; cxt->Flags = flags; } @@ -91,6 +94,40 @@ namespace _ASSERTE(IsActive()); return ObjectToOBJECTREF(g_pSyncTable[SyncBlockIndex].m_Object); } + + struct Key + { + public: + Key(void* identity, INT64 wrapperId) + : _identity { identity } + , _wrapperId { wrapperId } + { + _ASSERTE(identity != NULL); + _ASSERTE(wrapperId != ComWrappersNative::InvalidWrapperId); + } + + DWORD Hash() const + { + DWORD hash = (_wrapperId >> 32) ^ (_wrapperId & 0xFFFFFFFF); +#if POINTER_BITS == 32 + return hash ^ (DWORD)_identity; +#else + INT64 identityInt64 = (INT64)_identity; + return hash ^ (identityInt64 >> 32) ^ (identityInt64 & 0xFFFFFFFF); +#endif + } + + bool operator==(const Key & rhs) const { return _identity == rhs._identity && _wrapperId == rhs._wrapperId; } + + private: + void* _identity; + INT64 _wrapperId; + }; + + Key GetKey() const + { + return Key(Identity, WrapperId); + } }; const DWORD ExternalObjectContext::InvalidSyncBlockIndex = 0; // See syncblk.h @@ -171,10 +208,10 @@ namespace class Traits : public DefaultSHashTraits { public: - using key_t = void*; - static const key_t GetKey(_In_ element_t e) { LIMITED_METHOD_CONTRACT; return (key_t)e->Identity; } - static count_t Hash(_In_ key_t key) { LIMITED_METHOD_CONTRACT; return (count_t)(size_t)key; } - static bool Equals(_In_ key_t lhs, _In_ key_t rhs) { LIMITED_METHOD_CONTRACT; return (lhs == rhs); } + using key_t = ExternalObjectContext::Key; + static const key_t GetKey(_In_ element_t e) { LIMITED_METHOD_CONTRACT; return e->GetKey(); } + static count_t Hash(_In_ key_t key) { LIMITED_METHOD_CONTRACT; return (count_t)key.Hash(); } + static bool Equals(_In_ key_t lhs, _In_ key_t rhs) { LIMITED_METHOD_CONTRACT; return lhs == rhs; } }; // Alias some useful types @@ -311,7 +348,7 @@ namespace RETURN gc.arrRef; } - ExternalObjectContext* Find(_In_ IUnknown* instance) + ExternalObjectContext* Find(_In_ const ExternalObjectContext::Key& key) { CONTRACT(ExternalObjectContext*) { @@ -319,7 +356,6 @@ namespace GC_NOTRIGGER; MODE_COOPERATIVE; PRECONDITION(IsLockHeld()); - PRECONDITION(instance != NULL); POSTCONDITION(CheckPointer(RETVAL, NULL_OK)); } CONTRACT_END; @@ -327,7 +363,7 @@ namespace // Forbid the GC from messing with the hash table. GCX_FORBID(); - RETURN _hashMap.Lookup(instance); + RETURN _hashMap.Lookup(key); } ExternalObjectContext* Add(_In_ ExternalObjectContext* cxt) @@ -340,7 +376,7 @@ namespace PRECONDITION(IsLockHeld()); PRECONDITION(!Traits::IsNull(cxt) && !Traits::IsDeleted(cxt)); PRECONDITION(cxt->Identity != NULL); - PRECONDITION(Find(static_cast(cxt->Identity)) == NULL); + PRECONDITION(Find(cxt->GetKey()) == NULL); POSTCONDITION(RETVAL == cxt); } CONTRACT_END; @@ -349,7 +385,7 @@ namespace RETURN cxt; } - ExternalObjectContext* FindOrAdd(_In_ IUnknown* key, _In_ ExternalObjectContext* newCxt) + ExternalObjectContext* FindOrAdd(_In_ const ExternalObjectContext::Key& key, _In_ ExternalObjectContext* newCxt) { CONTRACT(ExternalObjectContext*) { @@ -357,9 +393,8 @@ namespace GC_NOTRIGGER; MODE_COOPERATIVE; PRECONDITION(IsLockHeld()); - PRECONDITION(key != NULL); PRECONDITION(!Traits::IsNull(newCxt) && !Traits::IsDeleted(newCxt)); - PRECONDITION(key == newCxt->Identity); + PRECONDITION(key == newCxt->GetKey()); POSTCONDITION(CheckPointer(RETVAL)); } CONTRACT_END; @@ -392,7 +427,7 @@ namespace } CONTRACTL_END; - _hashMap.Remove(cxt->Identity); + _hashMap.Remove(cxt->GetKey()); } }; @@ -400,8 +435,8 @@ namespace Volatile ExtObjCxtCache::g_Instance; // Indicator for if a ComWrappers implementation is globally registered - bool g_IsGlobalComWrappersRegisteredForMarshalling; - bool g_IsGlobalComWrappersRegisteredForTrackerSupport; + INT64 g_marshallingGlobalInstanceId = ComWrappersNative::InvalidWrapperId; + INT64 g_trackerSupportGlobalInstanceId = ComWrappersNative::InvalidWrapperId; // Defined handle types for the specific object uses. const HandleType InstanceHandleType{ HNDTYPE_STRONG }; @@ -522,6 +557,7 @@ namespace bool TryGetOrCreateComInterfaceForObjectInternal( _In_opt_ OBJECTREF impl, + _In_ INT64 wrapperId, _In_ OBJECTREF instance, _In_ CreateComInterfaceFlags flags, _In_ ComWrappersScenario scenario, @@ -533,7 +569,8 @@ namespace MODE_COOPERATIVE; PRECONDITION(instance != NULL); PRECONDITION(wrapperRaw != NULL); - PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance)|| (impl == NULL && scenario != ComWrappersScenario::Instance)); + PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance) || (impl == NULL && scenario != ComWrappersScenario::Instance)); + PRECONDITION(wrapperId != ComWrappersNative::InvalidWrapperId); } CONTRACT_END; @@ -559,7 +596,7 @@ namespace _ASSERTE(syncBlock->IsPrecious()); // Query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)) + if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe)) { // Compute VTables for the new existing COM object using the supplied COM Wrappers implementation. // @@ -570,7 +607,7 @@ namespace void* vtables = CallComputeVTables(scenario, &gc.implRef, &gc.instRef, flags, &vtableCount); // Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe) + if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe) && ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0))) { OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType); @@ -590,14 +627,14 @@ namespace _ASSERTE(!newWrapper.IsNull()); // Try setting the newly created managed object wrapper on the InteropSyncBlockInfo. - if (!interopInfo->TrySetManagedObjectComWrapper(newWrapper)) + if (!interopInfo->TrySetManagedObjectComWrapper(wrapperId, newWrapper)) { // The new wrapper couldn't be set which means a wrapper already exists. newWrapper.Release(); // If the managed object wrapper couldn't be set, then // it should be possible to get the current one. - if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)) + if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe)) { UNREACHABLE(); } @@ -638,6 +675,7 @@ namespace bool TryGetOrCreateObjectForComInstanceInternal( _In_opt_ OBJECTREF impl, + _In_ INT64 wrapperId, _In_ IUnknown* identity, _In_ CreateObjectFlags flags, _In_ ComWrappersScenario scenario, @@ -651,6 +689,7 @@ namespace PRECONDITION(identity != NULL); PRECONDITION(objRef != NULL); PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance) || (impl == NULL && scenario != ComWrappersScenario::Instance)); + PRECONDITION(wrapperId != ComWrappersNative::InvalidWrapperId); } CONTRACT_END; @@ -672,13 +711,15 @@ namespace ExtObjCxtCache* cache = ExtObjCxtCache::GetInstance(); InteropLib::OBJECTHANDLE handle = NULL; + ExternalObjectContext::Key cacheKey(identity, wrapperId); + // Check if the user requested a unique instance. bool uniqueInstance = !!(flags & CreateObjectFlags::CreateObjectFlags_UniqueInstance); if (!uniqueInstance) { // Query the external object cache ExtObjCxtCache::LockHolder lock(cache); - extObjCxt = cache->Find(identity); + extObjCxt = cache->Find(cacheKey); // If is no object found in the cache, check if the object COM instance is actually the CCW // representing a managed object. For the scenario of marshalling through a global instance, @@ -750,6 +791,7 @@ namespace identity, GetCurrentCtxCookie(), gc.objRefMaybe->GetSyncBlockIndex(), + wrapperId, flags); if (uniqueInstance) @@ -760,7 +802,7 @@ namespace { // Attempt to insert the new context into the cache. ExtObjCxtCache::LockHolder lock(cache); - extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext()); + extObjCxt = cache->FindOrAdd(cacheKey, resultHolder.GetContext()); } // If the returned context matches the new context it means the @@ -1039,6 +1081,7 @@ namespace InteropLibImports // Get wrapper for external object bool success = TryGetOrCreateObjectForComInstanceInternal( gc.implRef, + g_trackerSupportGlobalInstanceId, externalComObject, externalObjectFlags, ComWrappersScenario::TrackerSupportGlobalInstance, @@ -1051,6 +1094,7 @@ namespace InteropLibImports // Get wrapper for managed object success = TryGetOrCreateComInterfaceForObjectInternal( gc.implRef, + g_trackerSupportGlobalInstanceId, gc.objRef, trackerTargetFlags, ComWrappersScenario::TrackerSupportGlobalInstance, @@ -1220,6 +1264,7 @@ namespace InteropLibImports BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject( _In_ QCall::ObjectHandleOnStack comWrappersImpl, + _In_ INT64 wrapperId, _In_ QCall::ObjectHandleOnStack instance, _In_ INT32 flags, _Outptr_ void** wrapper) @@ -1236,6 +1281,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject( GCX_COOP(); success = TryGetOrCreateComInterfaceForObjectInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), + wrapperId, ObjectToOBJECTREF(*instance.m_ppObject), (CreateComInterfaceFlags)flags, ComWrappersScenario::Instance, @@ -1249,6 +1295,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject( BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( _In_ QCall::ObjectHandleOnStack comWrappersImpl, + _In_ INT64 wrapperId, _In_ void* ext, _In_ INT32 flags, _In_ QCall::ObjectHandleOnStack wrapperMaybe, @@ -1278,6 +1325,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance( OBJECTREF newObj; success = TryGetOrCreateObjectForComInstanceInternal( ObjectToOBJECTREF(*comWrappersImpl.m_ppObject), + wrapperId, identity, (CreateObjectFlags)flags, ComWrappersScenario::Instance, @@ -1387,19 +1435,25 @@ void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe) _ASSERTE(SUCCEEDED(hr) || hr == E_INVALIDARG); } -void QCALLTYPE GlobalComWrappersForMarshalling::SetGlobalInstanceRegisteredForMarshalling() +void QCALLTYPE GlobalComWrappersForMarshalling::SetGlobalInstanceRegisteredForMarshalling(INT64 id) { QCALL_CONTRACT_NO_GC_TRANSITION; - _ASSERTE(!g_IsGlobalComWrappersRegisteredForMarshalling); - g_IsGlobalComWrappersRegisteredForMarshalling = true; + _ASSERTE(g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId && id != ComWrappersNative::InvalidWrapperId); + g_marshallingGlobalInstanceId = id; +} + +bool GlobalComWrappersForMarshalling::IsRegisteredInstance(INT64 id) +{ + return g_marshallingGlobalInstanceId != ComWrappersNative::InvalidWrapperId + && g_marshallingGlobalInstanceId == id; } bool GlobalComWrappersForMarshalling::TryGetOrCreateComInterfaceForObject( _In_ OBJECTREF instance, _Outptr_ void** wrapperRaw) { - if (!g_IsGlobalComWrappersRegisteredForMarshalling) + if (g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId) return false; // Switch to Cooperative mode since object references @@ -1412,6 +1466,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateComInterfaceForObject( // Passing NULL as the ComWrappers implementation indicates using the globally registered instance return TryGetOrCreateComInterfaceForObjectInternal( NULL, + g_marshallingGlobalInstanceId, instance, flags, ComWrappersScenario::MarshallingGlobalInstance, @@ -1424,7 +1479,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance( _In_ INT32 objFromComIPFlags, _Out_ OBJECTREF* objRef) { - if (!g_IsGlobalComWrappersRegisteredForMarshalling) + if (g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId) return false; // Determine the true identity of the object @@ -1448,6 +1503,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance( // Passing NULL as the ComWrappers implementation indicates using the globally registered instance return TryGetOrCreateObjectForComInstanceInternal( NULL /*comWrappersImpl*/, + g_marshallingGlobalInstanceId, identity, (CreateObjectFlags)flags, ComWrappersScenario::MarshallingGlobalInstance, @@ -1456,12 +1512,18 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance( } } -void QCALLTYPE GlobalComWrappersForTrackerSupport::SetGlobalInstanceRegisteredForTrackerSupport() +void QCALLTYPE GlobalComWrappersForTrackerSupport::SetGlobalInstanceRegisteredForTrackerSupport(INT64 id) { QCALL_CONTRACT_NO_GC_TRANSITION; - _ASSERTE(!g_IsGlobalComWrappersRegisteredForTrackerSupport); - g_IsGlobalComWrappersRegisteredForTrackerSupport = true; + _ASSERTE(g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId && id != ComWrappersNative::InvalidWrapperId); + g_trackerSupportGlobalInstanceId = id; +} + +bool GlobalComWrappersForTrackerSupport::IsRegisteredInstance(INT64 id) +{ + return g_trackerSupportGlobalInstanceId != ComWrappersNative::InvalidWrapperId + && g_trackerSupportGlobalInstanceId == id; } bool GlobalComWrappersForTrackerSupport::TryGetOrCreateComInterfaceForObject( @@ -1475,12 +1537,13 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateComInterfaceForObject( } CONTRACTL_END; - if (!g_IsGlobalComWrappersRegisteredForTrackerSupport) + if (g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId) return false; // Passing NULL as the ComWrappers implementation indicates using the globally registered instance return TryGetOrCreateComInterfaceForObjectInternal( NULL, + g_trackerSupportGlobalInstanceId, instance, CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport, ComWrappersScenario::TrackerSupportGlobalInstance, @@ -1498,7 +1561,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance( } CONTRACTL_END; - if (!g_IsGlobalComWrappersRegisteredForTrackerSupport) + if (g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId) return false; // Determine the true identity of the object @@ -1513,6 +1576,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance( // Passing NULL as the ComWrappers implementation indicates using the globally registered instance return TryGetOrCreateObjectForComInstanceInternal( NULL /*comWrappersImpl*/, + g_trackerSupportGlobalInstanceId, identity, CreateObjectFlags::CreateObjectFlags_TrackerObject, ComWrappersScenario::TrackerSupportGlobalInstance, @@ -1520,7 +1584,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance( objRef); } -IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid) +IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId) { CONTRACTL { @@ -1528,11 +1592,14 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE GC_TRIGGERS; MODE_COOPERATIVE; PRECONDITION(CheckPointer(objectPROTECTED)); + PRECONDITION(CheckPointer(wrapperId)); } CONTRACTL_END; ASSERT_PROTECTED(objectPROTECTED); + *wrapperId = ComWrappersNative::InvalidWrapperId; + SyncBlock* syncBlock = (*objectPROTECTED)->PassiveGetSyncBlock(); if (syncBlock == nullptr) { @@ -1545,10 +1612,13 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE return nullptr; } - void* context; - if (interopInfo->TryGetExternalComObjectContext(&context)) + void* contextMaybe; + if (interopInfo->TryGetExternalComObjectContext(&contextMaybe)) { - IUnknown* identity = reinterpret_cast(reinterpret_cast(context)->Identity); + ExternalObjectContext* context = reinterpret_cast(contextMaybe); + *wrapperId = context->WrapperId; + + IUnknown* identity = reinterpret_cast(context->Identity); GCX_PREEMP(); IUnknown* result; if (SUCCEEDED(identity->QueryInterface(riid, (void**)&result))) diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h index 3e5abda..bae00fb 100644 --- a/src/coreclr/src/vm/interoplibinterface.h +++ b/src/coreclr/src/vm/interoplibinterface.h @@ -11,6 +11,9 @@ // Native calls for the managed ComWrappers API class ComWrappersNative { +public: + static const INT64 InvalidWrapperId = 0; + public: // Native QCalls for the abstract ComWrappers managed type. static void QCALLTYPE GetIUnknownImpl( _Out_ void** fpQueryInterface, @@ -19,12 +22,14 @@ public: // Native QCalls for the abstract ComWrappers managed type. static BOOL QCALLTYPE TryGetOrCreateComInterfaceForObject( _In_ QCall::ObjectHandleOnStack comWrappersImpl, + _In_ INT64 wrapperId, _In_ QCall::ObjectHandleOnStack instance, _In_ INT32 flags, _Outptr_ void** wrapperRaw); static BOOL QCALLTYPE TryGetOrCreateObjectForComInstance( _In_ QCall::ObjectHandleOnStack comWrappersImpl, + _In_ INT64 wrapperId, _In_ void* externalComObject, _In_ INT32 flags, _In_ QCall::ObjectHandleOnStack wrapperMaybe, @@ -39,7 +44,7 @@ public: // COM activation static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe); public: // Unwrapping support - static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid); + static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId); }; class GlobalComWrappersForMarshalling @@ -48,9 +53,11 @@ public: // Native QCall for the ComWrappers managed type to indicate a global instance // is registered for marshalling. This should be set if the private static member // representing the global instance for marshalling on ComWrappers is non-null. - static void QCALLTYPE SetGlobalInstanceRegisteredForMarshalling(); + static void QCALLTYPE SetGlobalInstanceRegisteredForMarshalling(_In_ INT64 id); public: // Functions operating on a registered global instance for marshalling + static bool IsRegisteredInstance(_In_ INT64 id); + static bool TryGetOrCreateComInterfaceForObject( _In_ OBJECTREF instance, _Outptr_ void** wrapperRaw); @@ -68,9 +75,11 @@ public: // Native QCall for the ComWrappers managed type to indicate a global instance // is registered for tracker support. This should be set if the private static member // representing the global instance for tracker support on ComWrappers is non-null. - static void QCALLTYPE SetGlobalInstanceRegisteredForTrackerSupport(); + static void QCALLTYPE SetGlobalInstanceRegisteredForTrackerSupport(_In_ INT64 id); public: // Functions operating on a registered global instance for tracker support + static bool IsRegisteredInstance(_In_ INT64 id); + static bool TryGetOrCreateComInterfaceForObject( _In_ OBJECTREF instance, _Outptr_ void** wrapperRaw); diff --git a/src/coreclr/src/vm/interoputil.cpp b/src/coreclr/src/vm/interoputil.cpp index 33cad06..a2f6846 100644 --- a/src/coreclr/src/vm/interoputil.cpp +++ b/src/coreclr/src/vm/interoputil.cpp @@ -1321,12 +1321,7 @@ void CleanupSyncBlockComData(InteropSyncBlockInfo* pInteropInfo) } #ifdef FEATURE_COMWRAPPERS - void* mocw; - if (pInteropInfo->TryGetManagedObjectComWrapper(&mocw)) - { - (void)pInteropInfo->TrySetManagedObjectComWrapper(NULL, mocw); - ComWrappersNative::DestroyManagedObjectComWrapper(mocw); - } + pInteropInfo->ClearManagedObjectComWrappers(&ComWrappersNative::DestroyManagedObjectComWrapper); void* eoc; if (pInteropInfo->TryGetExternalComObjectContext(&eoc)) diff --git a/src/coreclr/src/vm/syncblk.h b/src/coreclr/src/vm/syncblk.h index e0364df..08a4fee 100644 --- a/src/coreclr/src/vm/syncblk.h +++ b/src/coreclr/src/vm/syncblk.h @@ -602,6 +602,8 @@ class ComClassFactory; struct RCW; class RCWHolder; typedef DPTR(class ComCallWrapper) PTR_ComCallWrapper; + +#include "shash.h" #endif // FEATURE_COMINTEROP class InteropSyncBlockInfo @@ -791,22 +793,65 @@ public: #endif public: - bool TryGetManagedObjectComWrapper(_Out_ void** mocw) + bool TryGetManagedObjectComWrapper(_In_ INT64 wrapperId, _Out_ void** mocw) { LIMITED_METHOD_DAC_CONTRACT; - *mocw = m_managedObjectComWrapper; - return (*mocw != NULL); + + *mocw = NULL; + if (m_managedObjectComWrapperMap == NULL) + return false; + + CrstHolder lock(&m_managedObjectComWrapperLock); + return m_managedObjectComWrapperMap->Lookup(wrapperId, mocw); } #ifndef DACCESS_COMPILE - bool TrySetManagedObjectComWrapper(_In_ void* mocw, _In_ void* curr = NULL) + bool TrySetManagedObjectComWrapper(_In_ INT64 wrapperId, _In_ void* mocw, _In_ void* curr = NULL) { LIMITED_METHOD_CONTRACT; - return (FastInterlockCompareExchangePointer( - &m_managedObjectComWrapper, - mocw, - curr) == curr); + if (m_managedObjectComWrapperMap == NULL) + { + NewHolder map = new ManagedObjectComWrapperByIdMap(); + if (FastInterlockCompareExchangePointer((ManagedObjectComWrapperByIdMap**)&m_managedObjectComWrapperMap, (ManagedObjectComWrapperByIdMap *)map, NULL) == NULL) + { + map.SuppressRelease(); + m_managedObjectComWrapperLock.Init(CrstLeafLock); + } + + _ASSERTE(m_managedObjectComWrapperMap != NULL); + } + + CrstHolder lock(&m_managedObjectComWrapperLock); + + if (m_managedObjectComWrapperMap->LookupPtr(wrapperId) != curr) + return false; + + m_managedObjectComWrapperMap->Add(wrapperId, mocw); + return true; + } + + using EnumWrappersCallback = void(void* mocw); + void ClearManagedObjectComWrappers(EnumWrappersCallback* callback) + { + LIMITED_METHOD_CONTRACT; + + if (m_managedObjectComWrapperMap == NULL) + return; + + CrstHolder lock(&m_managedObjectComWrapperLock); + + if (callback != NULL) + { + ManagedObjectComWrapperByIdMap::Iterator iter = m_managedObjectComWrapperMap->Begin(); + while (iter != m_managedObjectComWrapperMap->End()) + { + callback(iter->Value()); + ++iter; + } + } + + m_managedObjectComWrapperMap->RemoveAll(); } #endif // !DACCESS_COMPILE @@ -831,8 +876,11 @@ public: private: // See InteropLib API for usage. - void* m_managedObjectComWrapper; void* m_externalComObjectContext; + + using ManagedObjectComWrapperByIdMap = MapSHash; + CrstExplicitInit m_managedObjectComWrapperLock; + NewHolder m_managedObjectComWrapperMap; #endif // FEATURE_COMINTEROP }; diff --git a/src/coreclr/src/vm/weakreferencenative.cpp b/src/coreclr/src/vm/weakreferencenative.cpp index 2cdb83c..bc52c61 100644 --- a/src/coreclr/src/vm/weakreferencenative.cpp +++ b/src/coreclr/src/vm/weakreferencenative.cpp @@ -103,9 +103,9 @@ private: #ifdef FEATURE_COMINTEROP -// Get a native COM weak reference for the object underlying an RCW if applicable. If the incoming object cannot -// use a native COM weak reference, nullptr is returned. Otherwise, an AddRef-ed IWeakReference* for the COM -// object underlying the RCW is returned. +// Get the native COM information for the object underlying an RCW if applicable. If the incoming object cannot +// use a native COM weak reference, nullptr is returned. Otherwise, a new NativeComWeakHandleInfo containing an +// AddRef-ed IWeakReference* for the COM object underlying the RCW is returned. // // In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must: // * be an RCW @@ -113,7 +113,7 @@ private: // * succeed when asked for an IWeakReference* // // Note that *pObject should be GC protected on the way into this method -IWeakReference* GetComWeakReference(OBJECTREF* pObject) +NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject) { CONTRACTL { @@ -134,6 +134,7 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject) MethodTable* pMT = (*pObject)->GetMethodTable(); SafeComHolder pWeakReferenceSource(nullptr); + INT64 wrapperId = ComWrappersNative::InvalidWrapperId; // If the object is not an RCW, then we do not want to use a native COM weak reference to it // If the object is a managed type deriving from a COM type, then we also do not want to use a native COM @@ -147,7 +148,7 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject) #ifdef FEATURE_COMWRAPPERS else { - pWeakReferenceSource = reinterpret_cast(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource)); + pWeakReferenceSource = reinterpret_cast(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId)); } #endif @@ -163,7 +164,9 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject) return nullptr; } - return pWeakReference.Extract(); + NewHolder info = new NativeComWeakHandleInfo { pWeakReference.GetValue(), wrapperId }; + pWeakReference.SuppressRelease(); + return info.Extract(); } // Given an object handle that stores a native COM weak reference, attempt to create an RCW @@ -205,6 +208,7 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type // Since we're acquiring and releasing the lock multiple times, we need to check the handle state each time we // reacquire the lock to make sure that another thread hasn't reassigned the target of the handle or finalized it SafeComHolder pComWeakReference = nullptr; + INT64 wrapperId = ComWrappersNative::InvalidWrapperId; { WeakHandleSpinLockHolder handle(AcquireWeakHandleSpinLock(gc.weakReference), &gc.weakReference); GCX_NOTRIGGER(); @@ -233,9 +237,11 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type _ASSERTE(pComWeakReference.IsNull()); CONTRACT_VIOLATION(GCViolation); IGCHandleManager *mgr = GCHandleUtilities::GetGCHandleManager(); - pComWeakReference = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle)); - if (!pComWeakReference.IsNull()) + NativeComWeakHandleInfo* comWeakHandleInfo = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle)); + if (comWeakHandleInfo != nullptr) { + wrapperId = comWeakHandleInfo->WrapperId; + pComWeakReference = comWeakHandleInfo->WeakReference; pComWeakReference->AddRef(); } } @@ -268,9 +274,21 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type // If we were able to get an IUnkown identity for the object, then we can find or create an associated RCW for it. if (!pTargetIdentity.IsNull()) { - // Try the global COM wrappers first before falling back to the built-in system. - if (!GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(pTargetIdentity, &gc.rcw)) + if (wrapperId != ComWrappersNative::InvalidWrapperId) { + // Try the global COM wrappers + if (GlobalComWrappersForTrackerSupport::IsRegisteredInstance(wrapperId)) + { + (void)GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(pTargetIdentity, &gc.rcw); + } + else if (GlobalComWrappersForMarshalling::IsRegisteredInstance(wrapperId)) + { + (void)GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(pTargetIdentity, ObjFromComIP::NONE, &gc.rcw); + } + } + else + { + // If the original RCW was not created through ComWrappers, fall back to the built-in system. GetObjectRefFromComIP(&gc.rcw, pTargetIdentity); } } @@ -428,21 +446,21 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob // Create the handle. #ifdef FEATURE_COMINTEROP - IWeakReference* pRawComWeakReference = nullptr; + NativeComWeakHandleInfo *comWeakHandleInfo = nullptr; if (gc.pTarget != NULL) { - SyncBlock* pSyncBlock = gc.pTarget->PassiveGetSyncBlock(); + SyncBlock *pSyncBlock = gc.pTarget->PassiveGetSyncBlock(); if (pSyncBlock != nullptr && pSyncBlock->GetInteropInfoNoCreate() != nullptr) { - pRawComWeakReference = GetComWeakReference(&gc.pTarget); + comWeakHandleInfo = GetComWeakReferenceInfo(&gc.pTarget); } } - if (pRawComWeakReference != nullptr) + if (comWeakHandleInfo != nullptr) { - SafeComHolder pComWeakReferenceHolder(pRawComWeakReference); - gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, pComWeakReferenceHolder)); - pComWeakReferenceHolder.SuppressRelease(); + NewHolder infoHolder(comWeakHandleInfo); + gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, infoHolder)); + infoHolder.SuppressRelease(); } else #endif // FEATURE_COMINTEROP @@ -478,23 +496,22 @@ FCIMPL3(void, WeakReferenceOfTNative::Create, WeakReferenceObject * pThisUNSAFE, _ASSERTE(gc.pThis->GetMethodTable()->GetCanonicalMethodTable() == pWeakReferenceOfTCanonMT); - // Create the handle. #ifdef FEATURE_COMINTEROP - IWeakReference* pRawComWeakReference = nullptr; + NativeComWeakHandleInfo *comWeakHandleInfo = nullptr; if (gc.pTarget != NULL) { - SyncBlock* pSyncBlock = gc.pTarget->PassiveGetSyncBlock(); + SyncBlock *pSyncBlock = gc.pTarget->PassiveGetSyncBlock(); if (pSyncBlock != nullptr && pSyncBlock->GetInteropInfoNoCreate() != nullptr) { - pRawComWeakReference = GetComWeakReference(&gc.pTarget); + comWeakHandleInfo = GetComWeakReferenceInfo(&gc.pTarget); } } - if (pRawComWeakReference != nullptr) + if (comWeakHandleInfo != nullptr) { - SafeComHolder pComWeakReferenceHolder(pRawComWeakReference); - gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, pComWeakReferenceHolder)); - pComWeakReferenceHolder.SuppressRelease(); + NewHolder infoHolder(comWeakHandleInfo); + gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, infoHolder)); + infoHolder.SuppressRelease(); } else #endif // FEATURE_COMINTEROP @@ -747,7 +764,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t HELPER_METHOD_FRAME_BEGIN_ATTRIB_2(Frame::FRAME_ATTR_EXACT_DEPTH|Frame::FRAME_ATTR_CAPTURE_DEPTH_2, target, weakReference); #ifdef FEATURE_COMINTEROP - SafeComHolder pTargetWeakReference(GetComWeakReference(&target)); + NewHolder comWeakHandleInfo(GetComWeakReferenceInfo(&target)); #endif // FEATURE_COMINTEROP @@ -778,21 +795,23 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t if (IsNativeComWeakReferenceHandle(handle.RawHandle)) { - // If the existing reference is a native COM weak reference, we need to release its IWeakReference pointer - // and update it with the new weak reference pointer. If the incoming object is not an RCW that can - // use IWeakReference, then pTargetWeakReference will be null. Therefore, no matter what the incoming - // object type is, we can unconditionally store pTargetWeakReference to the object handle's extra data. + // If the existing reference is a native COM weak reference, we need to delete its native COM info + // and update it with the new native COM info. If the incoming object is not an RCW that can use + // IWeakReference, then comWeakHandleInfo will be null. Therefore, no matter what the incoming + // object type is, we can unconditionally store comWeakHandleInfo to the object handle's extra data. IGCHandleManager *mgr = GCHandleUtilities::GetGCHandleManager(); - IWeakReference* pExistingWeakReference = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle)); - mgr->SetExtraInfoForHandle(handle.Handle, HNDTYPE_WEAK_NATIVE_COM, reinterpret_cast(pTargetWeakReference.GetValue())); + NativeComWeakHandleInfo* existingInfo = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle)); + mgr->SetExtraInfoForHandle(handle.Handle, HNDTYPE_WEAK_NATIVE_COM, reinterpret_cast(comWeakHandleInfo.GetValue())); StoreObjectInHandle(handle.Handle, target); - if (pExistingWeakReference != nullptr) + if (existingInfo != nullptr) { - pExistingWeakReference->Release(); + _ASSERTE(existingInfo->WeakReference != nullptr); + existingInfo->WeakReference->Release(); + delete existingInfo; } } - else if (pTargetWeakReference != nullptr) + else if (comWeakHandleInfo != nullptr) { // The existing handle is not a native COM weak reference, but we need to store the new object in // a native COM weak reference. Therefore we need to destroy the old handle and create a new native COM @@ -801,7 +820,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t _ASSERTE(!IsNativeComWeakReferenceHandle(handle.RawHandle)); OBJECTHANDLE previousHandle = handle.RawHandle; - handle.Handle = GetAppDomain()->CreateNativeComWeakHandle(target, pTargetWeakReference); + handle.Handle = GetAppDomain()->CreateNativeComWeakHandle(target, comWeakHandleInfo); handle.RawHandle = SetNativeComWeakReferenceHandle(handle.Handle); DestroyTypedHandle(previousHandle); @@ -813,7 +832,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t } #ifdef FEATURE_COMINTEROP - pTargetWeakReference.SuppressRelease(); + comWeakHandleInfo.SuppressRelease(); #endif // FEATURE_COMINTEROP HELPER_METHOD_FRAME_END(); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs index 481eafd..2e76a65 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs @@ -159,6 +159,50 @@ namespace ComWrappersTests Assert.AreNotEqual(trackerObj1, trackerObj3); } + static void ValidateWrappersInstanceIsolation() + { + Console.WriteLine($"Running {nameof(ValidateWrappersInstanceIsolation)}..."); + + var cw1 = new TestComWrappers(); + var cw2 = new TestComWrappers(); + + var testObj = new Test(); + + // Allocate a wrapper for the object + IntPtr comWrapper1 = cw1.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport); + IntPtr comWrapper2 = cw2.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport); + Assert.AreNotEqual(comWrapper1, IntPtr.Zero); + Assert.AreNotEqual(comWrapper2, IntPtr.Zero); + Assert.AreNotEqual(comWrapper1, comWrapper2); + + IntPtr comWrapper3 = cw1.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport); + IntPtr comWrapper4 = cw2.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport); + Assert.AreNotEqual(comWrapper3, comWrapper4); + Assert.AreEqual(comWrapper1, comWrapper3); + Assert.AreEqual(comWrapper2, comWrapper4); + + Marshal.Release(comWrapper1); + Marshal.Release(comWrapper2); + Marshal.Release(comWrapper3); + Marshal.Release(comWrapper4); + + // Get an object from a tracker runtime. + IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); + + // Create objects for the COM instance + var trackerObj1 = (ITrackerObjectWrapper)cw1.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject); + var trackerObj2 = (ITrackerObjectWrapper)cw2.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject); + Assert.AreNotEqual(trackerObj1, trackerObj2); + + var trackerObj3 = (ITrackerObjectWrapper)cw1.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject); + var trackerObj4 = (ITrackerObjectWrapper)cw2.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject); + Assert.AreNotEqual(trackerObj3, trackerObj4); + Assert.AreEqual(trackerObj1, trackerObj3); + Assert.AreEqual(trackerObj2, trackerObj4); + + Marshal.Release(trackerObjRaw); + } + static void ValidatePrecreatedExternalWrapper() { Console.WriteLine($"Running {nameof(ValidatePrecreatedExternalWrapper)}..."); @@ -357,6 +401,7 @@ namespace ComWrappersTests ValidateComInterfaceCreation(); ValidateFallbackQueryInterface(); ValidateCreateObjectCachingScenario(); + ValidateWrappersInstanceIsolation(); ValidatePrecreatedExternalWrapper(); ValidateIUnknownImpls(); ValidateBadComWrapperImpl(); diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs index f85138f..64d32d6 100644 --- a/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs +++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs @@ -22,7 +22,14 @@ namespace ComWrappersTests public IntPtr Vtbl; } - public class WeakReferencableWrapper + public enum WrapperRegistration + { + Local, + TrackerSupport, + Marshalling, + } + + public class WeakReferenceableWrapper { private struct Vtbl { @@ -37,14 +44,17 @@ namespace ComWrappersTests private readonly IntPtr instance; private readonly Vtbl vtable; - public WeakReferencableWrapper(IntPtr instance) + public WrapperRegistration Registration { get; } + + public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg) { var inst = Marshal.PtrToStructure(instance); this.vtable = Marshal.PtrToStructure(inst.Vtbl); this.instance = instance; + Registration = reg; } - ~WeakReferencableWrapper() + ~WeakReferenceableWrapper() { if (this.instance != IntPtr.Zero) { @@ -57,6 +67,13 @@ namespace ComWrappersTests { class TestComWrappers : ComWrappers { + public WrapperRegistration Registration { get; } + + public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local) + { + Registration = reg; + } + protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count) { count = 0; @@ -66,59 +83,158 @@ namespace ComWrappersTests protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag) { Marshal.AddRef(externalComObject); - return new WeakReferencableWrapper(externalComObject); + return new WeakReferenceableWrapper(externalComObject, Registration); } protected override void ReleaseObjects(IEnumerable objects) { } - public static readonly ComWrappers Instance = new TestComWrappers(); + public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport); + public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling); } - static void ValidateNativeWeakReference() + private static void ValidateWeakReferenceState(WeakReference wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null) { - Console.WriteLine($"Running {nameof(ValidateNativeWeakReference)}..."); + WeakReferenceableWrapper target; + bool isAlive = wr.TryGetTarget(out target); + Assert.AreEqual(expectedIsAlive, isAlive); - static (WeakReference, IntPtr) GetWeakReference() - { - var cw = new TestComWrappers(); + if (isAlive && sourceWrappers != null) + Assert.AreEqual(sourceWrappers.Registration, target.Registration); + } - IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject(); + private static (WeakReference, IntPtr) GetWeakReference(TestComWrappers cw) + { + IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject(); + var obj = (WeakReferenceableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None); + var wr = new WeakReference(obj); + ValidateWeakReferenceState(wr, expectedIsAlive: true, cw); + return (wr, objRaw); + } + + private static IntPtr SetWeakReferenceTarget(WeakReference wr, TestComWrappers cw) + { + IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject(); + var obj = (WeakReferenceableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None); + wr.SetTarget(obj); + ValidateWeakReferenceState(wr, expectedIsAlive: true, cw); + return objRaw; + } + + private static void ValidateNativeWeakReference(TestComWrappers cw) + { + Console.WriteLine($" -- Validate weak reference creation"); + var (weakRef, nativeRef) = GetWeakReference(cw); + + // Make sure RCW is collected + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // Non-globally registered ComWrappers instances do not support rehydration. + // A weak reference to an RCW wrapping an IWeakReference can stay alive if the RCW was created through + // a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak + // reference should be dead and stay dead once the RCW is collected. + bool supportsRehydration = cw.Registration != WrapperRegistration.Local; + + Console.WriteLine($" -- Validate RCW recreation"); + ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw); + + // Release the last native reference. + Marshal.Release(nativeRef); + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // After all native references die and the RCW is collected, the weak reference should be dead and stay dead. + Console.WriteLine($" -- Validate release"); + ValidateWeakReferenceState(weakRef, expectedIsAlive: false); + + // Reset the weak reference target + Console.WriteLine($" -- Validate target reset"); + nativeRef = SetWeakReferenceTarget(weakRef, cw); - var obj = (WeakReferencableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None); + // Make sure RCW is collected + GC.Collect(); + GC.WaitForPendingFinalizers(); + + Console.WriteLine($" -- Validate RCW recreation"); + ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw); + + // Release the last native reference. + Marshal.Release(nativeRef); + GC.Collect(); + GC.WaitForPendingFinalizers(); + + // After all native references die and the RCW is collected, the weak reference should be dead and stay dead. + Console.WriteLine($" -- Validate release"); + ValidateWeakReferenceState(weakRef, expectedIsAlive: false); + } + + static void ValidateGlobalInstanceTrackerSupport() + { + Console.WriteLine($"Running {nameof(ValidateGlobalInstanceTrackerSupport)}..."); + ValidateNativeWeakReference(TestComWrappers.TrackerSupportInstance); + } + + static void ValidateGlobalInstanceMarshalling() + { + Console.WriteLine($"Running {nameof(ValidateGlobalInstanceMarshalling)}..."); + ValidateNativeWeakReference(TestComWrappers.MarshallingInstance); + } + + static void ValidateLocalInstance() + { + Console.WriteLine($"Running {nameof(ValidateLocalInstance)}..."); + ValidateNativeWeakReference(new TestComWrappers()); + } + + static void ValidateNonComWrappers() + { + Console.WriteLine($"Running {nameof(ValidateNonComWrappers)}..."); - return (new WeakReference(obj), objRaw); + (WeakReference, IntPtr) GetWeakReference() + { + IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject(); + var obj = Marshal.GetObjectForIUnknown(objRaw); + return (new WeakReference(obj), objRaw); } - static bool CheckIfWeakReferenceIsAlive(WeakReference wr) + bool HasTarget(WeakReference wr) { - return wr.TryGetTarget(out _); + return wr.Target != null; } var (weakRef, nativeRef) = GetWeakReference(); GC.Collect(); GC.WaitForPendingFinalizers(); - // A weak reference to an RCW wrapping an IWeakReference should stay alive even after the RCW dies - Assert.IsTrue(CheckIfWeakReferenceIsAlive(weakRef)); + + // A weak reference to an RCW wrapping an IWeakReference created throguh the built-in system + // should stay alive even after the RCW dies + Assert.IsFalse(weakRef.IsAlive); + Assert.IsTrue(HasTarget(weakRef)); // Release the last native reference. Marshal.Release(nativeRef); - GC.Collect(); GC.WaitForPendingFinalizers(); // After all native references die and the RCW is collected, the weak reference should be dead and stay dead. - Assert.IsFalse(CheckIfWeakReferenceIsAlive(weakRef)); - + Assert.IsNull(weakRef.Target); } static int Main(string[] doNotUse) { try { - ComWrappers.RegisterForTrackerSupport(TestComWrappers.Instance); - ValidateNativeWeakReference(); + ValidateNonComWrappers(); + + ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance); + ValidateGlobalInstanceTrackerSupport(); + + ComWrappers.RegisterForMarshalling(TestComWrappers.MarshallingInstance); + ValidateGlobalInstanceMarshalling(); + + ValidateLocalInstance(); } catch (Exception e) { -- 2.7.4