From bda2db238eec2f1ff81e25d709a03d966c8dd116 Mon Sep 17 00:00:00 2001
From: Elinor Fung <47805090+elinor-fung@users.noreply.github.com>
Date: Tue, 16 Jun 2020 20:10:20 -0700
Subject: [PATCH] ComWrappers isolation (#37861)
- Assign/track ID of ComWrappers used to create object/wrapper.
- Only reuse object/wrapper created by same instance for GetOrCreate*
- Rehydrate RCW on WeakReference for global ComWrapper instances only
---
.../System/Runtime/InteropServices/ComWrappers.cs | 33 +++--
src/coreclr/src/vm/appdomain.hpp | 4 +-
src/coreclr/src/vm/gchandleutilities.h | 25 ++--
src/coreclr/src/vm/interoplibinterface.cpp | 140 +++++++++++++-----
src/coreclr/src/vm/interoplibinterface.h | 15 +-
src/coreclr/src/vm/interoputil.cpp | 7 +-
src/coreclr/src/vm/syncblk.h | 66 +++++++--
src/coreclr/src/vm/weakreferencenative.cpp | 93 +++++++-----
.../src/Interop/COM/ComWrappers/API/Program.cs | 45 ++++++
.../ComWrappers/WeakReference/WeakReferenceTest.cs | 160 ++++++++++++++++++---
10 files changed, 456 insertions(+), 132 deletions(-)
diff --git a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
index 6bdd48d..b404dc6 100644
--- a/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
+++ b/src/coreclr/src/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.cs
@@ -126,12 +126,20 @@ namespace System.Runtime.InteropServices
///
private static ComWrappers? s_globalInstanceForMarshalling;
+ private static long s_instanceCounter;
+ private readonly long id = Interlocked.Increment(ref s_instanceCounter);
+
///
/// Create a COM representation of the supplied object that can be passed to a non-managed environment.
///
/// The managed object to expose outside the .NET runtime.
/// Flags used to configure the generated interface.
/// The generated COM interface that can be passed outside the .NET runtime.
+ ///
+ /// If a COM representation was previously created for the specified using
+ /// this instance, the previously created COM interface will be returned.
+ /// If not, a new one will be created.
+ ///
public IntPtr GetOrCreateComInterfaceForObject(object instance, CreateComInterfaceFlags flags)
{
IntPtr ptr;
@@ -152,16 +160,16 @@ namespace System.Runtime.InteropServices
///
/// If is null, the global instance (if registered) will be used.
///
- private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers? impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
+ private static bool TryGetOrCreateComInterfaceForObjectInternal(ComWrappers impl, object instance, CreateComInterfaceFlags flags, out IntPtr retValue)
{
if (instance == null)
throw new ArgumentNullException(nameof(instance));
- return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), ObjectHandleOnStack.Create(ref instance), flags, out retValue);
+ return TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack.Create(ref impl), impl.id, ObjectHandleOnStack.Create(ref instance), flags, out retValue);
}
[DllImport(RuntimeHelpers.QCall)]
- private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue);
+ private static extern bool TryGetOrCreateComInterfaceForObjectInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, ObjectHandleOnStack instance, CreateComInterfaceFlags flags, out IntPtr retValue);
///
/// Compute the desired Vtable for respecting the values of .
@@ -211,6 +219,11 @@ namespace System.Runtime.InteropServices
/// Object to import for usage into the .NET runtime.
/// Flags used to describe the external object.
/// Returns a managed object associated with the supplied external COM object.
+ ///
+ /// If a managed object was previously created for the specified
+ /// using this instance, the previously created object will be returned.
+ /// If not, a new one will be created.
+ ///
public object GetOrCreateObjectForComInstance(IntPtr externalComObject, CreateObjectFlags flags)
{
object? obj;
@@ -288,18 +301,18 @@ namespace System.Runtime.InteropServices
///
/// If is null, the global instance (if registered) will be used.
///
- private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers? impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue)
+ private static bool TryGetOrCreateObjectForComInstanceInternal(ComWrappers impl, IntPtr externalComObject, CreateObjectFlags flags, object? wrapperMaybe, out object? retValue)
{
if (externalComObject == IntPtr.Zero)
throw new ArgumentNullException(nameof(externalComObject));
object? wrapperMaybeLocal = wrapperMaybe;
retValue = null;
- return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
+ return TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack.Create(ref impl), impl.id, externalComObject, flags, ObjectHandleOnStack.Create(ref wrapperMaybeLocal), ObjectHandleOnStack.Create(ref retValue));
}
[DllImport(RuntimeHelpers.QCall)]
- private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
+ private static extern bool TryGetOrCreateObjectForComInstanceInternal(ObjectHandleOnStack comWrappersImpl, long wrapperId, IntPtr externalComObject, CreateObjectFlags flags, ObjectHandleOnStack wrapper, ObjectHandleOnStack retValue);
///
/// Called when a request is made for a collection of objects to be released outside of normal object or COM interface lifetime.
@@ -332,13 +345,13 @@ namespace System.Runtime.InteropServices
throw new InvalidOperationException(SR.InvalidOperation_ResetGlobalComWrappersInstance);
}
- SetGlobalInstanceRegisteredForTrackerSupport();
+ SetGlobalInstanceRegisteredForTrackerSupport(instance.id);
}
[DllImport(RuntimeHelpers.QCall)]
[SuppressGCTransition]
- private static extern void SetGlobalInstanceRegisteredForTrackerSupport();
+ private static extern void SetGlobalInstanceRegisteredForTrackerSupport(long id);
///
/// Register a instance to be used as the global instance for marshalling in the runtime.
@@ -366,12 +379,12 @@ namespace System.Runtime.InteropServices
// Indicate to the runtime that a global instance has been registered for marshalling.
// This allows the native runtime know to call into the managed ComWrappers only if a
// global instance is registered for marshalling.
- SetGlobalInstanceRegisteredForMarshalling();
+ SetGlobalInstanceRegisteredForMarshalling(instance.id);
}
[DllImport(RuntimeHelpers.QCall)]
[SuppressGCTransition]
- private static extern void SetGlobalInstanceRegisteredForMarshalling();
+ private static extern void SetGlobalInstanceRegisteredForMarshalling(long id);
///
/// Get the runtime provided IUnknown implementation.
diff --git a/src/coreclr/src/vm/appdomain.hpp b/src/coreclr/src/vm/appdomain.hpp
index afdee41..6b1393f 100644
--- a/src/coreclr/src/vm/appdomain.hpp
+++ b/src/coreclr/src/vm/appdomain.hpp
@@ -1099,10 +1099,10 @@ public:
return ::CreateRefcountedHandle(m_handleStore, object);
}
- OBJECTHANDLE CreateNativeComWeakHandle(OBJECTREF object, IWeakReference* pComWeakReference)
+ OBJECTHANDLE CreateNativeComWeakHandle(OBJECTREF object, NativeComWeakHandleInfo* pComWeakHandleInfo)
{
WRAPPER_NO_CONTRACT;
- return ::CreateNativeComWeakHandle(m_handleStore, object, pComWeakReference);
+ return ::CreateNativeComWeakHandle(m_handleStore, object, pComWeakHandleInfo);
}
#endif // FEATURE_COMINTEROP
diff --git a/src/coreclr/src/vm/gchandleutilities.h b/src/coreclr/src/vm/gchandleutilities.h
index d90a2f0..039df61 100644
--- a/src/coreclr/src/vm/gchandleutilities.h
+++ b/src/coreclr/src/vm/gchandleutilities.h
@@ -201,9 +201,16 @@ inline OBJECTHANDLE CreateGlobalRefcountedHandle(OBJECTREF object)
// Special handle creation convenience functions
#ifdef FEATURE_COMINTEROP
-inline OBJECTHANDLE CreateNativeComWeakHandle(IGCHandleStore* store, OBJECTREF object, IWeakReference* pComWeakReference)
+
+struct NativeComWeakHandleInfo
+{
+ IWeakReference *WeakReference;
+ INT64 WrapperId;
+};
+
+inline OBJECTHANDLE CreateNativeComWeakHandle(IGCHandleStore* store, OBJECTREF object, NativeComWeakHandleInfo* pComWeakHandleInfo)
{
- OBJECTHANDLE hnd = store->CreateHandleWithExtraInfo(OBJECTREFToObject(object), HNDTYPE_WEAK_NATIVE_COM, (void*)pComWeakReference);
+ OBJECTHANDLE hnd = store->CreateHandleWithExtraInfo(OBJECTREFToObject(object), HNDTYPE_WEAK_NATIVE_COM, (void*)pComWeakHandleInfo);
if (!hnd)
{
COMPlusThrowOM();
@@ -374,14 +381,16 @@ inline void DestroyNativeComWeakHandle(OBJECTHANDLE handle)
}
CONTRACTL_END;
- // Release the WinRT weak reference if we have one. We're assuming that this will not reenter the
- // runtime, since if we are pointing at a managed object, we should not be using HNDTYPE_WEAK_NATIVE_COM
- // but rather HNDTYPE_WEAK_SHORT or HNDTYPE_WEAK_LONG.
+ // Delete the COM info and release the weak reference if we have one. We're assuming that
+ // this will not reenter the runtime, since if we are pointing at a managed object, we should
+ // not be using HNDTYPE_WEAK_NATIVE_COM but rather HNDTYPE_WEAK_SHORT or HNDTYPE_WEAK_LONG.
void* pExtraInfo = GCHandleUtilities::GetGCHandleManager()->GetExtraInfoFromHandle(handle);
- IWeakReference* pWinRTWeakReference = reinterpret_cast(pExtraInfo);
- if (pWinRTWeakReference != nullptr)
+ NativeComWeakHandleInfo* comWeakHandleInfo = reinterpret_cast(pExtraInfo);
+ if (comWeakHandleInfo != nullptr)
{
- pWinRTWeakReference->Release();
+ _ASSERTE(comWeakHandleInfo->WeakReference != nullptr);
+ comWeakHandleInfo->WeakReference->Release();
+ delete comWeakHandleInfo;
}
DiagHandleDestroyed(handle);
diff --git a/src/coreclr/src/vm/interoplibinterface.cpp b/src/coreclr/src/vm/interoplibinterface.cpp
index a17bf75..ffd8d4f 100644
--- a/src/coreclr/src/vm/interoplibinterface.cpp
+++ b/src/coreclr/src/vm/interoplibinterface.cpp
@@ -26,6 +26,7 @@ namespace
void* Identity;
void* ThreadContext;
DWORD SyncBlockIndex;
+ INT64 WrapperId;
enum
{
@@ -41,6 +42,7 @@ namespace
_In_ IUnknown* identity,
_In_opt_ void* threadContext,
_In_ DWORD syncBlockIndex,
+ _In_ INT64 wrapperId,
_In_ DWORD flags)
{
CONTRACTL
@@ -57,6 +59,7 @@ namespace
cxt->Identity = (void*)identity;
cxt->ThreadContext = threadContext;
cxt->SyncBlockIndex = syncBlockIndex;
+ cxt->WrapperId = wrapperId;
cxt->Flags = flags;
}
@@ -91,6 +94,40 @@ namespace
_ASSERTE(IsActive());
return ObjectToOBJECTREF(g_pSyncTable[SyncBlockIndex].m_Object);
}
+
+ struct Key
+ {
+ public:
+ Key(void* identity, INT64 wrapperId)
+ : _identity { identity }
+ , _wrapperId { wrapperId }
+ {
+ _ASSERTE(identity != NULL);
+ _ASSERTE(wrapperId != ComWrappersNative::InvalidWrapperId);
+ }
+
+ DWORD Hash() const
+ {
+ DWORD hash = (_wrapperId >> 32) ^ (_wrapperId & 0xFFFFFFFF);
+#if POINTER_BITS == 32
+ return hash ^ (DWORD)_identity;
+#else
+ INT64 identityInt64 = (INT64)_identity;
+ return hash ^ (identityInt64 >> 32) ^ (identityInt64 & 0xFFFFFFFF);
+#endif
+ }
+
+ bool operator==(const Key & rhs) const { return _identity == rhs._identity && _wrapperId == rhs._wrapperId; }
+
+ private:
+ void* _identity;
+ INT64 _wrapperId;
+ };
+
+ Key GetKey() const
+ {
+ return Key(Identity, WrapperId);
+ }
};
const DWORD ExternalObjectContext::InvalidSyncBlockIndex = 0; // See syncblk.h
@@ -171,10 +208,10 @@ namespace
class Traits : public DefaultSHashTraits
{
public:
- using key_t = void*;
- static const key_t GetKey(_In_ element_t e) { LIMITED_METHOD_CONTRACT; return (key_t)e->Identity; }
- static count_t Hash(_In_ key_t key) { LIMITED_METHOD_CONTRACT; return (count_t)(size_t)key; }
- static bool Equals(_In_ key_t lhs, _In_ key_t rhs) { LIMITED_METHOD_CONTRACT; return (lhs == rhs); }
+ using key_t = ExternalObjectContext::Key;
+ static const key_t GetKey(_In_ element_t e) { LIMITED_METHOD_CONTRACT; return e->GetKey(); }
+ static count_t Hash(_In_ key_t key) { LIMITED_METHOD_CONTRACT; return (count_t)key.Hash(); }
+ static bool Equals(_In_ key_t lhs, _In_ key_t rhs) { LIMITED_METHOD_CONTRACT; return lhs == rhs; }
};
// Alias some useful types
@@ -311,7 +348,7 @@ namespace
RETURN gc.arrRef;
}
- ExternalObjectContext* Find(_In_ IUnknown* instance)
+ ExternalObjectContext* Find(_In_ const ExternalObjectContext::Key& key)
{
CONTRACT(ExternalObjectContext*)
{
@@ -319,7 +356,6 @@ namespace
GC_NOTRIGGER;
MODE_COOPERATIVE;
PRECONDITION(IsLockHeld());
- PRECONDITION(instance != NULL);
POSTCONDITION(CheckPointer(RETVAL, NULL_OK));
}
CONTRACT_END;
@@ -327,7 +363,7 @@ namespace
// Forbid the GC from messing with the hash table.
GCX_FORBID();
- RETURN _hashMap.Lookup(instance);
+ RETURN _hashMap.Lookup(key);
}
ExternalObjectContext* Add(_In_ ExternalObjectContext* cxt)
@@ -340,7 +376,7 @@ namespace
PRECONDITION(IsLockHeld());
PRECONDITION(!Traits::IsNull(cxt) && !Traits::IsDeleted(cxt));
PRECONDITION(cxt->Identity != NULL);
- PRECONDITION(Find(static_cast(cxt->Identity)) == NULL);
+ PRECONDITION(Find(cxt->GetKey()) == NULL);
POSTCONDITION(RETVAL == cxt);
}
CONTRACT_END;
@@ -349,7 +385,7 @@ namespace
RETURN cxt;
}
- ExternalObjectContext* FindOrAdd(_In_ IUnknown* key, _In_ ExternalObjectContext* newCxt)
+ ExternalObjectContext* FindOrAdd(_In_ const ExternalObjectContext::Key& key, _In_ ExternalObjectContext* newCxt)
{
CONTRACT(ExternalObjectContext*)
{
@@ -357,9 +393,8 @@ namespace
GC_NOTRIGGER;
MODE_COOPERATIVE;
PRECONDITION(IsLockHeld());
- PRECONDITION(key != NULL);
PRECONDITION(!Traits::IsNull(newCxt) && !Traits::IsDeleted(newCxt));
- PRECONDITION(key == newCxt->Identity);
+ PRECONDITION(key == newCxt->GetKey());
POSTCONDITION(CheckPointer(RETVAL));
}
CONTRACT_END;
@@ -392,7 +427,7 @@ namespace
}
CONTRACTL_END;
- _hashMap.Remove(cxt->Identity);
+ _hashMap.Remove(cxt->GetKey());
}
};
@@ -400,8 +435,8 @@ namespace
Volatile ExtObjCxtCache::g_Instance;
// Indicator for if a ComWrappers implementation is globally registered
- bool g_IsGlobalComWrappersRegisteredForMarshalling;
- bool g_IsGlobalComWrappersRegisteredForTrackerSupport;
+ INT64 g_marshallingGlobalInstanceId = ComWrappersNative::InvalidWrapperId;
+ INT64 g_trackerSupportGlobalInstanceId = ComWrappersNative::InvalidWrapperId;
// Defined handle types for the specific object uses.
const HandleType InstanceHandleType{ HNDTYPE_STRONG };
@@ -522,6 +557,7 @@ namespace
bool TryGetOrCreateComInterfaceForObjectInternal(
_In_opt_ OBJECTREF impl,
+ _In_ INT64 wrapperId,
_In_ OBJECTREF instance,
_In_ CreateComInterfaceFlags flags,
_In_ ComWrappersScenario scenario,
@@ -533,7 +569,8 @@ namespace
MODE_COOPERATIVE;
PRECONDITION(instance != NULL);
PRECONDITION(wrapperRaw != NULL);
- PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance)|| (impl == NULL && scenario != ComWrappersScenario::Instance));
+ PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance) || (impl == NULL && scenario != ComWrappersScenario::Instance));
+ PRECONDITION(wrapperId != ComWrappersNative::InvalidWrapperId);
}
CONTRACT_END;
@@ -559,7 +596,7 @@ namespace
_ASSERTE(syncBlock->IsPrecious());
// Query the associated InteropSyncBlockInfo for an existing managed object wrapper.
- if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe))
+ if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe))
{
// Compute VTables for the new existing COM object using the supplied COM Wrappers implementation.
//
@@ -570,7 +607,7 @@ namespace
void* vtables = CallComputeVTables(scenario, &gc.implRef, &gc.instRef, flags, &vtableCount);
// Re-query the associated InteropSyncBlockInfo for an existing managed object wrapper.
- if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe)
+ if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe)
&& ((vtables != nullptr && vtableCount > 0) || (vtableCount == 0)))
{
OBJECTHANDLE instHandle = GetAppDomain()->CreateTypedHandle(gc.instRef, InstanceHandleType);
@@ -590,14 +627,14 @@ namespace
_ASSERTE(!newWrapper.IsNull());
// Try setting the newly created managed object wrapper on the InteropSyncBlockInfo.
- if (!interopInfo->TrySetManagedObjectComWrapper(newWrapper))
+ if (!interopInfo->TrySetManagedObjectComWrapper(wrapperId, newWrapper))
{
// The new wrapper couldn't be set which means a wrapper already exists.
newWrapper.Release();
// If the managed object wrapper couldn't be set, then
// it should be possible to get the current one.
- if (!interopInfo->TryGetManagedObjectComWrapper(&wrapperRawMaybe))
+ if (!interopInfo->TryGetManagedObjectComWrapper(wrapperId, &wrapperRawMaybe))
{
UNREACHABLE();
}
@@ -638,6 +675,7 @@ namespace
bool TryGetOrCreateObjectForComInstanceInternal(
_In_opt_ OBJECTREF impl,
+ _In_ INT64 wrapperId,
_In_ IUnknown* identity,
_In_ CreateObjectFlags flags,
_In_ ComWrappersScenario scenario,
@@ -651,6 +689,7 @@ namespace
PRECONDITION(identity != NULL);
PRECONDITION(objRef != NULL);
PRECONDITION((impl != NULL && scenario == ComWrappersScenario::Instance) || (impl == NULL && scenario != ComWrappersScenario::Instance));
+ PRECONDITION(wrapperId != ComWrappersNative::InvalidWrapperId);
}
CONTRACT_END;
@@ -672,13 +711,15 @@ namespace
ExtObjCxtCache* cache = ExtObjCxtCache::GetInstance();
InteropLib::OBJECTHANDLE handle = NULL;
+ ExternalObjectContext::Key cacheKey(identity, wrapperId);
+
// Check if the user requested a unique instance.
bool uniqueInstance = !!(flags & CreateObjectFlags::CreateObjectFlags_UniqueInstance);
if (!uniqueInstance)
{
// Query the external object cache
ExtObjCxtCache::LockHolder lock(cache);
- extObjCxt = cache->Find(identity);
+ extObjCxt = cache->Find(cacheKey);
// If is no object found in the cache, check if the object COM instance is actually the CCW
// representing a managed object. For the scenario of marshalling through a global instance,
@@ -750,6 +791,7 @@ namespace
identity,
GetCurrentCtxCookie(),
gc.objRefMaybe->GetSyncBlockIndex(),
+ wrapperId,
flags);
if (uniqueInstance)
@@ -760,7 +802,7 @@ namespace
{
// Attempt to insert the new context into the cache.
ExtObjCxtCache::LockHolder lock(cache);
- extObjCxt = cache->FindOrAdd(identity, resultHolder.GetContext());
+ extObjCxt = cache->FindOrAdd(cacheKey, resultHolder.GetContext());
}
// If the returned context matches the new context it means the
@@ -1039,6 +1081,7 @@ namespace InteropLibImports
// Get wrapper for external object
bool success = TryGetOrCreateObjectForComInstanceInternal(
gc.implRef,
+ g_trackerSupportGlobalInstanceId,
externalComObject,
externalObjectFlags,
ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1051,6 +1094,7 @@ namespace InteropLibImports
// Get wrapper for managed object
success = TryGetOrCreateComInterfaceForObjectInternal(
gc.implRef,
+ g_trackerSupportGlobalInstanceId,
gc.objRef,
trackerTargetFlags,
ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1220,6 +1264,7 @@ namespace InteropLibImports
BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
+ _In_ INT64 wrapperId,
_In_ QCall::ObjectHandleOnStack instance,
_In_ INT32 flags,
_Outptr_ void** wrapper)
@@ -1236,6 +1281,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
GCX_COOP();
success = TryGetOrCreateComInterfaceForObjectInternal(
ObjectToOBJECTREF(*comWrappersImpl.m_ppObject),
+ wrapperId,
ObjectToOBJECTREF(*instance.m_ppObject),
(CreateComInterfaceFlags)flags,
ComWrappersScenario::Instance,
@@ -1249,6 +1295,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateComInterfaceForObject(
BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
+ _In_ INT64 wrapperId,
_In_ void* ext,
_In_ INT32 flags,
_In_ QCall::ObjectHandleOnStack wrapperMaybe,
@@ -1278,6 +1325,7 @@ BOOL QCALLTYPE ComWrappersNative::TryGetOrCreateObjectForComInstance(
OBJECTREF newObj;
success = TryGetOrCreateObjectForComInstanceInternal(
ObjectToOBJECTREF(*comWrappersImpl.m_ppObject),
+ wrapperId,
identity,
(CreateObjectFlags)flags,
ComWrappersScenario::Instance,
@@ -1387,19 +1435,25 @@ void ComWrappersNative::MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe)
_ASSERTE(SUCCEEDED(hr) || hr == E_INVALIDARG);
}
-void QCALLTYPE GlobalComWrappersForMarshalling::SetGlobalInstanceRegisteredForMarshalling()
+void QCALLTYPE GlobalComWrappersForMarshalling::SetGlobalInstanceRegisteredForMarshalling(INT64 id)
{
QCALL_CONTRACT_NO_GC_TRANSITION;
- _ASSERTE(!g_IsGlobalComWrappersRegisteredForMarshalling);
- g_IsGlobalComWrappersRegisteredForMarshalling = true;
+ _ASSERTE(g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId && id != ComWrappersNative::InvalidWrapperId);
+ g_marshallingGlobalInstanceId = id;
+}
+
+bool GlobalComWrappersForMarshalling::IsRegisteredInstance(INT64 id)
+{
+ return g_marshallingGlobalInstanceId != ComWrappersNative::InvalidWrapperId
+ && g_marshallingGlobalInstanceId == id;
}
bool GlobalComWrappersForMarshalling::TryGetOrCreateComInterfaceForObject(
_In_ OBJECTREF instance,
_Outptr_ void** wrapperRaw)
{
- if (!g_IsGlobalComWrappersRegisteredForMarshalling)
+ if (g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
return false;
// Switch to Cooperative mode since object references
@@ -1412,6 +1466,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateComInterfaceForObject(
// Passing NULL as the ComWrappers implementation indicates using the globally registered instance
return TryGetOrCreateComInterfaceForObjectInternal(
NULL,
+ g_marshallingGlobalInstanceId,
instance,
flags,
ComWrappersScenario::MarshallingGlobalInstance,
@@ -1424,7 +1479,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(
_In_ INT32 objFromComIPFlags,
_Out_ OBJECTREF* objRef)
{
- if (!g_IsGlobalComWrappersRegisteredForMarshalling)
+ if (g_marshallingGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
return false;
// Determine the true identity of the object
@@ -1448,6 +1503,7 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(
// Passing NULL as the ComWrappers implementation indicates using the globally registered instance
return TryGetOrCreateObjectForComInstanceInternal(
NULL /*comWrappersImpl*/,
+ g_marshallingGlobalInstanceId,
identity,
(CreateObjectFlags)flags,
ComWrappersScenario::MarshallingGlobalInstance,
@@ -1456,12 +1512,18 @@ bool GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(
}
}
-void QCALLTYPE GlobalComWrappersForTrackerSupport::SetGlobalInstanceRegisteredForTrackerSupport()
+void QCALLTYPE GlobalComWrappersForTrackerSupport::SetGlobalInstanceRegisteredForTrackerSupport(INT64 id)
{
QCALL_CONTRACT_NO_GC_TRANSITION;
- _ASSERTE(!g_IsGlobalComWrappersRegisteredForTrackerSupport);
- g_IsGlobalComWrappersRegisteredForTrackerSupport = true;
+ _ASSERTE(g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId && id != ComWrappersNative::InvalidWrapperId);
+ g_trackerSupportGlobalInstanceId = id;
+}
+
+bool GlobalComWrappersForTrackerSupport::IsRegisteredInstance(INT64 id)
+{
+ return g_trackerSupportGlobalInstanceId != ComWrappersNative::InvalidWrapperId
+ && g_trackerSupportGlobalInstanceId == id;
}
bool GlobalComWrappersForTrackerSupport::TryGetOrCreateComInterfaceForObject(
@@ -1475,12 +1537,13 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateComInterfaceForObject(
}
CONTRACTL_END;
- if (!g_IsGlobalComWrappersRegisteredForTrackerSupport)
+ if (g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
return false;
// Passing NULL as the ComWrappers implementation indicates using the globally registered instance
return TryGetOrCreateComInterfaceForObjectInternal(
NULL,
+ g_trackerSupportGlobalInstanceId,
instance,
CreateComInterfaceFlags::CreateComInterfaceFlags_TrackerSupport,
ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1498,7 +1561,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
}
CONTRACTL_END;
- if (!g_IsGlobalComWrappersRegisteredForTrackerSupport)
+ if (g_trackerSupportGlobalInstanceId == ComWrappersNative::InvalidWrapperId)
return false;
// Determine the true identity of the object
@@ -1513,6 +1576,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
// Passing NULL as the ComWrappers implementation indicates using the globally registered instance
return TryGetOrCreateObjectForComInstanceInternal(
NULL /*comWrappersImpl*/,
+ g_trackerSupportGlobalInstanceId,
identity,
CreateObjectFlags::CreateObjectFlags_TrackerObject,
ComWrappersScenario::TrackerSupportGlobalInstance,
@@ -1520,7 +1584,7 @@ bool GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(
objRef);
}
-IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid)
+IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId)
{
CONTRACTL
{
@@ -1528,11 +1592,14 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
GC_TRIGGERS;
MODE_COOPERATIVE;
PRECONDITION(CheckPointer(objectPROTECTED));
+ PRECONDITION(CheckPointer(wrapperId));
}
CONTRACTL_END;
ASSERT_PROTECTED(objectPROTECTED);
+ *wrapperId = ComWrappersNative::InvalidWrapperId;
+
SyncBlock* syncBlock = (*objectPROTECTED)->PassiveGetSyncBlock();
if (syncBlock == nullptr)
{
@@ -1545,10 +1612,13 @@ IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTE
return nullptr;
}
- void* context;
- if (interopInfo->TryGetExternalComObjectContext(&context))
+ void* contextMaybe;
+ if (interopInfo->TryGetExternalComObjectContext(&contextMaybe))
{
- IUnknown* identity = reinterpret_cast(reinterpret_cast(context)->Identity);
+ ExternalObjectContext* context = reinterpret_cast(contextMaybe);
+ *wrapperId = context->WrapperId;
+
+ IUnknown* identity = reinterpret_cast(context->Identity);
GCX_PREEMP();
IUnknown* result;
if (SUCCEEDED(identity->QueryInterface(riid, (void**)&result)))
diff --git a/src/coreclr/src/vm/interoplibinterface.h b/src/coreclr/src/vm/interoplibinterface.h
index 3e5abda..bae00fb 100644
--- a/src/coreclr/src/vm/interoplibinterface.h
+++ b/src/coreclr/src/vm/interoplibinterface.h
@@ -11,6 +11,9 @@
// Native calls for the managed ComWrappers API
class ComWrappersNative
{
+public:
+ static const INT64 InvalidWrapperId = 0;
+
public: // Native QCalls for the abstract ComWrappers managed type.
static void QCALLTYPE GetIUnknownImpl(
_Out_ void** fpQueryInterface,
@@ -19,12 +22,14 @@ public: // Native QCalls for the abstract ComWrappers managed type.
static BOOL QCALLTYPE TryGetOrCreateComInterfaceForObject(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
+ _In_ INT64 wrapperId,
_In_ QCall::ObjectHandleOnStack instance,
_In_ INT32 flags,
_Outptr_ void** wrapperRaw);
static BOOL QCALLTYPE TryGetOrCreateObjectForComInstance(
_In_ QCall::ObjectHandleOnStack comWrappersImpl,
+ _In_ INT64 wrapperId,
_In_ void* externalComObject,
_In_ INT32 flags,
_In_ QCall::ObjectHandleOnStack wrapperMaybe,
@@ -39,7 +44,7 @@ public: // COM activation
static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);
public: // Unwrapping support
- static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid);
+ static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId);
};
class GlobalComWrappersForMarshalling
@@ -48,9 +53,11 @@ public:
// Native QCall for the ComWrappers managed type to indicate a global instance
// is registered for marshalling. This should be set if the private static member
// representing the global instance for marshalling on ComWrappers is non-null.
- static void QCALLTYPE SetGlobalInstanceRegisteredForMarshalling();
+ static void QCALLTYPE SetGlobalInstanceRegisteredForMarshalling(_In_ INT64 id);
public: // Functions operating on a registered global instance for marshalling
+ static bool IsRegisteredInstance(_In_ INT64 id);
+
static bool TryGetOrCreateComInterfaceForObject(
_In_ OBJECTREF instance,
_Outptr_ void** wrapperRaw);
@@ -68,9 +75,11 @@ public:
// Native QCall for the ComWrappers managed type to indicate a global instance
// is registered for tracker support. This should be set if the private static member
// representing the global instance for tracker support on ComWrappers is non-null.
- static void QCALLTYPE SetGlobalInstanceRegisteredForTrackerSupport();
+ static void QCALLTYPE SetGlobalInstanceRegisteredForTrackerSupport(_In_ INT64 id);
public: // Functions operating on a registered global instance for tracker support
+ static bool IsRegisteredInstance(_In_ INT64 id);
+
static bool TryGetOrCreateComInterfaceForObject(
_In_ OBJECTREF instance,
_Outptr_ void** wrapperRaw);
diff --git a/src/coreclr/src/vm/interoputil.cpp b/src/coreclr/src/vm/interoputil.cpp
index 33cad06..a2f6846 100644
--- a/src/coreclr/src/vm/interoputil.cpp
+++ b/src/coreclr/src/vm/interoputil.cpp
@@ -1321,12 +1321,7 @@ void CleanupSyncBlockComData(InteropSyncBlockInfo* pInteropInfo)
}
#ifdef FEATURE_COMWRAPPERS
- void* mocw;
- if (pInteropInfo->TryGetManagedObjectComWrapper(&mocw))
- {
- (void)pInteropInfo->TrySetManagedObjectComWrapper(NULL, mocw);
- ComWrappersNative::DestroyManagedObjectComWrapper(mocw);
- }
+ pInteropInfo->ClearManagedObjectComWrappers(&ComWrappersNative::DestroyManagedObjectComWrapper);
void* eoc;
if (pInteropInfo->TryGetExternalComObjectContext(&eoc))
diff --git a/src/coreclr/src/vm/syncblk.h b/src/coreclr/src/vm/syncblk.h
index e0364df..08a4fee 100644
--- a/src/coreclr/src/vm/syncblk.h
+++ b/src/coreclr/src/vm/syncblk.h
@@ -602,6 +602,8 @@ class ComClassFactory;
struct RCW;
class RCWHolder;
typedef DPTR(class ComCallWrapper) PTR_ComCallWrapper;
+
+#include "shash.h"
#endif // FEATURE_COMINTEROP
class InteropSyncBlockInfo
@@ -791,22 +793,65 @@ public:
#endif
public:
- bool TryGetManagedObjectComWrapper(_Out_ void** mocw)
+ bool TryGetManagedObjectComWrapper(_In_ INT64 wrapperId, _Out_ void** mocw)
{
LIMITED_METHOD_DAC_CONTRACT;
- *mocw = m_managedObjectComWrapper;
- return (*mocw != NULL);
+
+ *mocw = NULL;
+ if (m_managedObjectComWrapperMap == NULL)
+ return false;
+
+ CrstHolder lock(&m_managedObjectComWrapperLock);
+ return m_managedObjectComWrapperMap->Lookup(wrapperId, mocw);
}
#ifndef DACCESS_COMPILE
- bool TrySetManagedObjectComWrapper(_In_ void* mocw, _In_ void* curr = NULL)
+ bool TrySetManagedObjectComWrapper(_In_ INT64 wrapperId, _In_ void* mocw, _In_ void* curr = NULL)
{
LIMITED_METHOD_CONTRACT;
- return (FastInterlockCompareExchangePointer(
- &m_managedObjectComWrapper,
- mocw,
- curr) == curr);
+ if (m_managedObjectComWrapperMap == NULL)
+ {
+ NewHolder map = new ManagedObjectComWrapperByIdMap();
+ if (FastInterlockCompareExchangePointer((ManagedObjectComWrapperByIdMap**)&m_managedObjectComWrapperMap, (ManagedObjectComWrapperByIdMap *)map, NULL) == NULL)
+ {
+ map.SuppressRelease();
+ m_managedObjectComWrapperLock.Init(CrstLeafLock);
+ }
+
+ _ASSERTE(m_managedObjectComWrapperMap != NULL);
+ }
+
+ CrstHolder lock(&m_managedObjectComWrapperLock);
+
+ if (m_managedObjectComWrapperMap->LookupPtr(wrapperId) != curr)
+ return false;
+
+ m_managedObjectComWrapperMap->Add(wrapperId, mocw);
+ return true;
+ }
+
+ using EnumWrappersCallback = void(void* mocw);
+ void ClearManagedObjectComWrappers(EnumWrappersCallback* callback)
+ {
+ LIMITED_METHOD_CONTRACT;
+
+ if (m_managedObjectComWrapperMap == NULL)
+ return;
+
+ CrstHolder lock(&m_managedObjectComWrapperLock);
+
+ if (callback != NULL)
+ {
+ ManagedObjectComWrapperByIdMap::Iterator iter = m_managedObjectComWrapperMap->Begin();
+ while (iter != m_managedObjectComWrapperMap->End())
+ {
+ callback(iter->Value());
+ ++iter;
+ }
+ }
+
+ m_managedObjectComWrapperMap->RemoveAll();
}
#endif // !DACCESS_COMPILE
@@ -831,8 +876,11 @@ public:
private:
// See InteropLib API for usage.
- void* m_managedObjectComWrapper;
void* m_externalComObjectContext;
+
+ using ManagedObjectComWrapperByIdMap = MapSHash;
+ CrstExplicitInit m_managedObjectComWrapperLock;
+ NewHolder m_managedObjectComWrapperMap;
#endif // FEATURE_COMINTEROP
};
diff --git a/src/coreclr/src/vm/weakreferencenative.cpp b/src/coreclr/src/vm/weakreferencenative.cpp
index 2cdb83c..bc52c61 100644
--- a/src/coreclr/src/vm/weakreferencenative.cpp
+++ b/src/coreclr/src/vm/weakreferencenative.cpp
@@ -103,9 +103,9 @@ private:
#ifdef FEATURE_COMINTEROP
-// Get a native COM weak reference for the object underlying an RCW if applicable. If the incoming object cannot
-// use a native COM weak reference, nullptr is returned. Otherwise, an AddRef-ed IWeakReference* for the COM
-// object underlying the RCW is returned.
+// Get the native COM information for the object underlying an RCW if applicable. If the incoming object cannot
+// use a native COM weak reference, nullptr is returned. Otherwise, a new NativeComWeakHandleInfo containing an
+// AddRef-ed IWeakReference* for the COM object underlying the RCW is returned.
//
// In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must:
// * be an RCW
@@ -113,7 +113,7 @@ private:
// * succeed when asked for an IWeakReference*
//
// Note that *pObject should be GC protected on the way into this method
-IWeakReference* GetComWeakReference(OBJECTREF* pObject)
+NativeComWeakHandleInfo* GetComWeakReferenceInfo(OBJECTREF* pObject)
{
CONTRACTL
{
@@ -134,6 +134,7 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject)
MethodTable* pMT = (*pObject)->GetMethodTable();
SafeComHolder pWeakReferenceSource(nullptr);
+ INT64 wrapperId = ComWrappersNative::InvalidWrapperId;
// If the object is not an RCW, then we do not want to use a native COM weak reference to it
// If the object is a managed type deriving from a COM type, then we also do not want to use a native COM
@@ -147,7 +148,7 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject)
#ifdef FEATURE_COMWRAPPERS
else
{
- pWeakReferenceSource = reinterpret_cast(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource));
+ pWeakReferenceSource = reinterpret_cast(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId));
}
#endif
@@ -163,7 +164,9 @@ IWeakReference* GetComWeakReference(OBJECTREF* pObject)
return nullptr;
}
- return pWeakReference.Extract();
+ NewHolder info = new NativeComWeakHandleInfo { pWeakReference.GetValue(), wrapperId };
+ pWeakReference.SuppressRelease();
+ return info.Extract();
}
// Given an object handle that stores a native COM weak reference, attempt to create an RCW
@@ -205,6 +208,7 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type
// Since we're acquiring and releasing the lock multiple times, we need to check the handle state each time we
// reacquire the lock to make sure that another thread hasn't reassigned the target of the handle or finalized it
SafeComHolder pComWeakReference = nullptr;
+ INT64 wrapperId = ComWrappersNative::InvalidWrapperId;
{
WeakHandleSpinLockHolder handle(AcquireWeakHandleSpinLock(gc.weakReference), &gc.weakReference);
GCX_NOTRIGGER();
@@ -233,9 +237,11 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type
_ASSERTE(pComWeakReference.IsNull());
CONTRACT_VIOLATION(GCViolation);
IGCHandleManager *mgr = GCHandleUtilities::GetGCHandleManager();
- pComWeakReference = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle));
- if (!pComWeakReference.IsNull())
+ NativeComWeakHandleInfo* comWeakHandleInfo = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle));
+ if (comWeakHandleInfo != nullptr)
{
+ wrapperId = comWeakHandleInfo->WrapperId;
+ pComWeakReference = comWeakHandleInfo->WeakReference;
pComWeakReference->AddRef();
}
}
@@ -268,9 +274,21 @@ NOINLINE Object* LoadComWeakReferenceTarget(WEAKREFERENCEREF weakReference, Type
// If we were able to get an IUnkown identity for the object, then we can find or create an associated RCW for it.
if (!pTargetIdentity.IsNull())
{
- // Try the global COM wrappers first before falling back to the built-in system.
- if (!GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(pTargetIdentity, &gc.rcw))
+ if (wrapperId != ComWrappersNative::InvalidWrapperId)
{
+ // Try the global COM wrappers
+ if (GlobalComWrappersForTrackerSupport::IsRegisteredInstance(wrapperId))
+ {
+ (void)GlobalComWrappersForTrackerSupport::TryGetOrCreateObjectForComInstance(pTargetIdentity, &gc.rcw);
+ }
+ else if (GlobalComWrappersForMarshalling::IsRegisteredInstance(wrapperId))
+ {
+ (void)GlobalComWrappersForMarshalling::TryGetOrCreateObjectForComInstance(pTargetIdentity, ObjFromComIP::NONE, &gc.rcw);
+ }
+ }
+ else
+ {
+ // If the original RCW was not created through ComWrappers, fall back to the built-in system.
GetObjectRefFromComIP(&gc.rcw, pTargetIdentity);
}
}
@@ -428,21 +446,21 @@ FCIMPL3(void, WeakReferenceNative::Create, WeakReferenceObject * pThisUNSAFE, Ob
// Create the handle.
#ifdef FEATURE_COMINTEROP
- IWeakReference* pRawComWeakReference = nullptr;
+ NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
if (gc.pTarget != NULL)
{
- SyncBlock* pSyncBlock = gc.pTarget->PassiveGetSyncBlock();
+ SyncBlock *pSyncBlock = gc.pTarget->PassiveGetSyncBlock();
if (pSyncBlock != nullptr && pSyncBlock->GetInteropInfoNoCreate() != nullptr)
{
- pRawComWeakReference = GetComWeakReference(&gc.pTarget);
+ comWeakHandleInfo = GetComWeakReferenceInfo(&gc.pTarget);
}
}
- if (pRawComWeakReference != nullptr)
+ if (comWeakHandleInfo != nullptr)
{
- SafeComHolder pComWeakReferenceHolder(pRawComWeakReference);
- gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, pComWeakReferenceHolder));
- pComWeakReferenceHolder.SuppressRelease();
+ NewHolder infoHolder(comWeakHandleInfo);
+ gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, infoHolder));
+ infoHolder.SuppressRelease();
}
else
#endif // FEATURE_COMINTEROP
@@ -478,23 +496,22 @@ FCIMPL3(void, WeakReferenceOfTNative::Create, WeakReferenceObject * pThisUNSAFE,
_ASSERTE(gc.pThis->GetMethodTable()->GetCanonicalMethodTable() == pWeakReferenceOfTCanonMT);
- // Create the handle.
#ifdef FEATURE_COMINTEROP
- IWeakReference* pRawComWeakReference = nullptr;
+ NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
if (gc.pTarget != NULL)
{
- SyncBlock* pSyncBlock = gc.pTarget->PassiveGetSyncBlock();
+ SyncBlock *pSyncBlock = gc.pTarget->PassiveGetSyncBlock();
if (pSyncBlock != nullptr && pSyncBlock->GetInteropInfoNoCreate() != nullptr)
{
- pRawComWeakReference = GetComWeakReference(&gc.pTarget);
+ comWeakHandleInfo = GetComWeakReferenceInfo(&gc.pTarget);
}
}
- if (pRawComWeakReference != nullptr)
+ if (comWeakHandleInfo != nullptr)
{
- SafeComHolder pComWeakReferenceHolder(pRawComWeakReference);
- gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, pComWeakReferenceHolder));
- pComWeakReferenceHolder.SuppressRelease();
+ NewHolder infoHolder(comWeakHandleInfo);
+ gc.pThis->m_Handle = SetNativeComWeakReferenceHandle(GetAppDomain()->CreateNativeComWeakHandle(gc.pTarget, infoHolder));
+ infoHolder.SuppressRelease();
}
else
#endif // FEATURE_COMINTEROP
@@ -747,7 +764,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
HELPER_METHOD_FRAME_BEGIN_ATTRIB_2(Frame::FRAME_ATTR_EXACT_DEPTH|Frame::FRAME_ATTR_CAPTURE_DEPTH_2, target, weakReference);
#ifdef FEATURE_COMINTEROP
- SafeComHolder pTargetWeakReference(GetComWeakReference(&target));
+ NewHolder comWeakHandleInfo(GetComWeakReferenceInfo(&target));
#endif // FEATURE_COMINTEROP
@@ -778,21 +795,23 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
if (IsNativeComWeakReferenceHandle(handle.RawHandle))
{
- // If the existing reference is a native COM weak reference, we need to release its IWeakReference pointer
- // and update it with the new weak reference pointer. If the incoming object is not an RCW that can
- // use IWeakReference, then pTargetWeakReference will be null. Therefore, no matter what the incoming
- // object type is, we can unconditionally store pTargetWeakReference to the object handle's extra data.
+ // If the existing reference is a native COM weak reference, we need to delete its native COM info
+ // and update it with the new native COM info. If the incoming object is not an RCW that can use
+ // IWeakReference, then comWeakHandleInfo will be null. Therefore, no matter what the incoming
+ // object type is, we can unconditionally store comWeakHandleInfo to the object handle's extra data.
IGCHandleManager *mgr = GCHandleUtilities::GetGCHandleManager();
- IWeakReference* pExistingWeakReference = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle));
- mgr->SetExtraInfoForHandle(handle.Handle, HNDTYPE_WEAK_NATIVE_COM, reinterpret_cast(pTargetWeakReference.GetValue()));
+ NativeComWeakHandleInfo* existingInfo = reinterpret_cast(mgr->GetExtraInfoFromHandle(handle.Handle));
+ mgr->SetExtraInfoForHandle(handle.Handle, HNDTYPE_WEAK_NATIVE_COM, reinterpret_cast(comWeakHandleInfo.GetValue()));
StoreObjectInHandle(handle.Handle, target);
- if (pExistingWeakReference != nullptr)
+ if (existingInfo != nullptr)
{
- pExistingWeakReference->Release();
+ _ASSERTE(existingInfo->WeakReference != nullptr);
+ existingInfo->WeakReference->Release();
+ delete existingInfo;
}
}
- else if (pTargetWeakReference != nullptr)
+ else if (comWeakHandleInfo != nullptr)
{
// The existing handle is not a native COM weak reference, but we need to store the new object in
// a native COM weak reference. Therefore we need to destroy the old handle and create a new native COM
@@ -801,7 +820,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
_ASSERTE(!IsNativeComWeakReferenceHandle(handle.RawHandle));
OBJECTHANDLE previousHandle = handle.RawHandle;
- handle.Handle = GetAppDomain()->CreateNativeComWeakHandle(target, pTargetWeakReference);
+ handle.Handle = GetAppDomain()->CreateNativeComWeakHandle(target, comWeakHandleInfo);
handle.RawHandle = SetNativeComWeakReferenceHandle(handle.Handle);
DestroyTypedHandle(previousHandle);
@@ -813,7 +832,7 @@ NOINLINE void SetWeakReferenceTarget(WEAKREFERENCEREF weakReference, OBJECTREF t
}
#ifdef FEATURE_COMINTEROP
- pTargetWeakReference.SuppressRelease();
+ comWeakHandleInfo.SuppressRelease();
#endif // FEATURE_COMINTEROP
HELPER_METHOD_FRAME_END();
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
index 481eafd..2e76a65 100644
--- a/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/API/Program.cs
@@ -159,6 +159,50 @@ namespace ComWrappersTests
Assert.AreNotEqual(trackerObj1, trackerObj3);
}
+ static void ValidateWrappersInstanceIsolation()
+ {
+ Console.WriteLine($"Running {nameof(ValidateWrappersInstanceIsolation)}...");
+
+ var cw1 = new TestComWrappers();
+ var cw2 = new TestComWrappers();
+
+ var testObj = new Test();
+
+ // Allocate a wrapper for the object
+ IntPtr comWrapper1 = cw1.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+ IntPtr comWrapper2 = cw2.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+ Assert.AreNotEqual(comWrapper1, IntPtr.Zero);
+ Assert.AreNotEqual(comWrapper2, IntPtr.Zero);
+ Assert.AreNotEqual(comWrapper1, comWrapper2);
+
+ IntPtr comWrapper3 = cw1.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+ IntPtr comWrapper4 = cw2.GetOrCreateComInterfaceForObject(testObj, CreateComInterfaceFlags.TrackerSupport);
+ Assert.AreNotEqual(comWrapper3, comWrapper4);
+ Assert.AreEqual(comWrapper1, comWrapper3);
+ Assert.AreEqual(comWrapper2, comWrapper4);
+
+ Marshal.Release(comWrapper1);
+ Marshal.Release(comWrapper2);
+ Marshal.Release(comWrapper3);
+ Marshal.Release(comWrapper4);
+
+ // Get an object from a tracker runtime.
+ IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject();
+
+ // Create objects for the COM instance
+ var trackerObj1 = (ITrackerObjectWrapper)cw1.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+ var trackerObj2 = (ITrackerObjectWrapper)cw2.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+ Assert.AreNotEqual(trackerObj1, trackerObj2);
+
+ var trackerObj3 = (ITrackerObjectWrapper)cw1.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+ var trackerObj4 = (ITrackerObjectWrapper)cw2.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.TrackerObject);
+ Assert.AreNotEqual(trackerObj3, trackerObj4);
+ Assert.AreEqual(trackerObj1, trackerObj3);
+ Assert.AreEqual(trackerObj2, trackerObj4);
+
+ Marshal.Release(trackerObjRaw);
+ }
+
static void ValidatePrecreatedExternalWrapper()
{
Console.WriteLine($"Running {nameof(ValidatePrecreatedExternalWrapper)}...");
@@ -357,6 +401,7 @@ namespace ComWrappersTests
ValidateComInterfaceCreation();
ValidateFallbackQueryInterface();
ValidateCreateObjectCachingScenario();
+ ValidateWrappersInstanceIsolation();
ValidatePrecreatedExternalWrapper();
ValidateIUnknownImpls();
ValidateBadComWrapperImpl();
diff --git a/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs b/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs
index f85138f..64d32d6 100644
--- a/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs
+++ b/src/coreclr/tests/src/Interop/COM/ComWrappers/WeakReference/WeakReferenceTest.cs
@@ -22,7 +22,14 @@ namespace ComWrappersTests
public IntPtr Vtbl;
}
- public class WeakReferencableWrapper
+ public enum WrapperRegistration
+ {
+ Local,
+ TrackerSupport,
+ Marshalling,
+ }
+
+ public class WeakReferenceableWrapper
{
private struct Vtbl
{
@@ -37,14 +44,17 @@ namespace ComWrappersTests
private readonly IntPtr instance;
private readonly Vtbl vtable;
- public WeakReferencableWrapper(IntPtr instance)
+ public WrapperRegistration Registration { get; }
+
+ public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg)
{
var inst = Marshal.PtrToStructure(instance);
this.vtable = Marshal.PtrToStructure(inst.Vtbl);
this.instance = instance;
+ Registration = reg;
}
- ~WeakReferencableWrapper()
+ ~WeakReferenceableWrapper()
{
if (this.instance != IntPtr.Zero)
{
@@ -57,6 +67,13 @@ namespace ComWrappersTests
{
class TestComWrappers : ComWrappers
{
+ public WrapperRegistration Registration { get; }
+
+ public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
+ {
+ Registration = reg;
+ }
+
protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
{
count = 0;
@@ -66,59 +83,158 @@ namespace ComWrappersTests
protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
{
Marshal.AddRef(externalComObject);
- return new WeakReferencableWrapper(externalComObject);
+ return new WeakReferenceableWrapper(externalComObject, Registration);
}
protected override void ReleaseObjects(IEnumerable objects)
{
}
- public static readonly ComWrappers Instance = new TestComWrappers();
+ public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
+ public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
}
- static void ValidateNativeWeakReference()
+ private static void ValidateWeakReferenceState(WeakReference wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null)
{
- Console.WriteLine($"Running {nameof(ValidateNativeWeakReference)}...");
+ WeakReferenceableWrapper target;
+ bool isAlive = wr.TryGetTarget(out target);
+ Assert.AreEqual(expectedIsAlive, isAlive);
- static (WeakReference, IntPtr) GetWeakReference()
- {
- var cw = new TestComWrappers();
+ if (isAlive && sourceWrappers != null)
+ Assert.AreEqual(sourceWrappers.Registration, target.Registration);
+ }
- IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+ private static (WeakReference, IntPtr) GetWeakReference(TestComWrappers cw)
+ {
+ IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+ var obj = (WeakReferenceableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None);
+ var wr = new WeakReference(obj);
+ ValidateWeakReferenceState(wr, expectedIsAlive: true, cw);
+ return (wr, objRaw);
+ }
+
+ private static IntPtr SetWeakReferenceTarget(WeakReference wr, TestComWrappers cw)
+ {
+ IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+ var obj = (WeakReferenceableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None);
+ wr.SetTarget(obj);
+ ValidateWeakReferenceState(wr, expectedIsAlive: true, cw);
+ return objRaw;
+ }
+
+ private static void ValidateNativeWeakReference(TestComWrappers cw)
+ {
+ Console.WriteLine($" -- Validate weak reference creation");
+ var (weakRef, nativeRef) = GetWeakReference(cw);
+
+ // Make sure RCW is collected
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+
+ // Non-globally registered ComWrappers instances do not support rehydration.
+ // A weak reference to an RCW wrapping an IWeakReference can stay alive if the RCW was created through
+ // a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak
+ // reference should be dead and stay dead once the RCW is collected.
+ bool supportsRehydration = cw.Registration != WrapperRegistration.Local;
+
+ Console.WriteLine($" -- Validate RCW recreation");
+ ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);
+
+ // Release the last native reference.
+ Marshal.Release(nativeRef);
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+
+ // After all native references die and the RCW is collected, the weak reference should be dead and stay dead.
+ Console.WriteLine($" -- Validate release");
+ ValidateWeakReferenceState(weakRef, expectedIsAlive: false);
+
+ // Reset the weak reference target
+ Console.WriteLine($" -- Validate target reset");
+ nativeRef = SetWeakReferenceTarget(weakRef, cw);
- var obj = (WeakReferencableWrapper)cw.GetOrCreateObjectForComInstance(objRaw, CreateObjectFlags.None);
+ // Make sure RCW is collected
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+
+ Console.WriteLine($" -- Validate RCW recreation");
+ ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);
+
+ // Release the last native reference.
+ Marshal.Release(nativeRef);
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+
+ // After all native references die and the RCW is collected, the weak reference should be dead and stay dead.
+ Console.WriteLine($" -- Validate release");
+ ValidateWeakReferenceState(weakRef, expectedIsAlive: false);
+ }
+
+ static void ValidateGlobalInstanceTrackerSupport()
+ {
+ Console.WriteLine($"Running {nameof(ValidateGlobalInstanceTrackerSupport)}...");
+ ValidateNativeWeakReference(TestComWrappers.TrackerSupportInstance);
+ }
+
+ static void ValidateGlobalInstanceMarshalling()
+ {
+ Console.WriteLine($"Running {nameof(ValidateGlobalInstanceMarshalling)}...");
+ ValidateNativeWeakReference(TestComWrappers.MarshallingInstance);
+ }
+
+ static void ValidateLocalInstance()
+ {
+ Console.WriteLine($"Running {nameof(ValidateLocalInstance)}...");
+ ValidateNativeWeakReference(new TestComWrappers());
+ }
+
+ static void ValidateNonComWrappers()
+ {
+ Console.WriteLine($"Running {nameof(ValidateNonComWrappers)}...");
- return (new WeakReference(obj), objRaw);
+ (WeakReference, IntPtr) GetWeakReference()
+ {
+ IntPtr objRaw = WeakReferenceNative.CreateWeakReferencableObject();
+ var obj = Marshal.GetObjectForIUnknown(objRaw);
+ return (new WeakReference(obj), objRaw);
}
- static bool CheckIfWeakReferenceIsAlive(WeakReference wr)
+ bool HasTarget(WeakReference wr)
{
- return wr.TryGetTarget(out _);
+ return wr.Target != null;
}
var (weakRef, nativeRef) = GetWeakReference();
GC.Collect();
GC.WaitForPendingFinalizers();
- // A weak reference to an RCW wrapping an IWeakReference should stay alive even after the RCW dies
- Assert.IsTrue(CheckIfWeakReferenceIsAlive(weakRef));
+
+ // A weak reference to an RCW wrapping an IWeakReference created throguh the built-in system
+ // should stay alive even after the RCW dies
+ Assert.IsFalse(weakRef.IsAlive);
+ Assert.IsTrue(HasTarget(weakRef));
// Release the last native reference.
Marshal.Release(nativeRef);
-
GC.Collect();
GC.WaitForPendingFinalizers();
// After all native references die and the RCW is collected, the weak reference should be dead and stay dead.
- Assert.IsFalse(CheckIfWeakReferenceIsAlive(weakRef));
-
+ Assert.IsNull(weakRef.Target);
}
static int Main(string[] doNotUse)
{
try
{
- ComWrappers.RegisterForTrackerSupport(TestComWrappers.Instance);
- ValidateNativeWeakReference();
+ ValidateNonComWrappers();
+
+ ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance);
+ ValidateGlobalInstanceTrackerSupport();
+
+ ComWrappers.RegisterForMarshalling(TestComWrappers.MarshallingInstance);
+ ValidateGlobalInstanceMarshalling();
+
+ ValidateLocalInstance();
}
catch (Exception e)
{
--
2.7.4