* Don't create a COM weak reference if the object is an aggregated COMWrappers RCW.
* Add test for weak reference + aggregation with native weak reference impl.
* Apply suggestions from code review
Co-authored-by: Aaron Robinson <arobins@microsoft.com>
Co-authored-by: Aaron Robinson <arobins@microsoft.com>
static void MarkWrapperAsComActivated(_In_ IUnknown* wrapperMaybe);
public: // Unwrapping support
- static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId);
+ static IUnknown* GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated);
static bool HasManagedObjectComWrapper(_In_ OBJECTREF object, _Out_ bool* isActive);
public: // GC interaction
// The EOC is "detached" and no longer used to map between identity and a managed object.
// This will only be set if the EOC was inserted into the cache.
Flags_Detached = 8,
+
+ // This EOC is an aggregated instance
+ Flags_Aggregated = 16
};
DWORD Flags;
: ExternalObjectContext::Flags_None) |
(uniqueInstance
? ExternalObjectContext::Flags_None
- : ExternalObjectContext::Flags_InCache);
+ : ExternalObjectContext::Flags_InCache) |
+ ((flags & CreateObjectFlags::CreateObjectFlags_Aggregated) != 0
+ ? ExternalObjectContext::Flags_Aggregated
+ : ExternalObjectContext::Flags_None);
+
ExternalObjectContext::Construct(
resultHolder.GetContext(),
identity,
objRef);
}
-IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId)
+IUnknown* ComWrappersNative::GetIdentityForObject(_In_ OBJECTREF* objectPROTECTED, _In_ REFIID riid, _Out_ INT64* wrapperId, _Out_ bool* isAggregated)
{
CONTRACTL
{
{
ExternalObjectContext* context = reinterpret_cast<ExternalObjectContext*>(contextMaybe);
*wrapperId = context->WrapperId;
+ *isAggregated = context->IsSet(ExternalObjectContext::Flags_Aggregated);
IUnknown* identity = reinterpret_cast<IUnknown*>(context->Identity);
GCX_PREEMP();
//
// In order to qualify to be used with a HNDTYPE_WEAK_NATIVE_COM, the incoming object must:
// * be an RCW
+// * not be an aggregated RCW
// * respond to a QI for IWeakReferenceSource
// * succeed when asked for an IWeakReference*
//
#endif
{
#ifdef FEATURE_COMWRAPPERS
- pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId));
+ bool isAggregated = false;
+ pWeakReferenceSource = reinterpret_cast<IWeakReferenceSource*>(ComWrappersNative::GetIdentityForObject(pObject, IID_IWeakReferenceSource, &wrapperId, &isAggregated));
+ if (isAggregated)
+ {
+ // If the RCW is an aggregated RCW, then the managed object cannot be recreated from the IUnknown as the outer IUnknown wraps the managed object.
+ // In this case, don't create a weak reference backed by a COM weak reference.
+ pWeakReferenceSource = nullptr;
+ }
#endif
}
_ASSERTE(gc.pThis->GetMethodTable()->CanCastToClass(pWeakReferenceMT));
// Create the handle.
-#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
+#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
NativeComWeakHandleInfo *comWeakHandleInfo = nullptr;
if (gc.pTarget != NULL)
{
OBJECTREF pTarget = GetWeakReferenceTarget(pThis);
-#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
+#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
// If we found an object, or we're not a native COM weak reference, then we're done. Othewrise
// we can try to create a new RCW to the underlying native COM object if it's still alive.
if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
OBJECTREF pTarget = GetWeakReferenceTarget(pThis);
-#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
+#if defined(FEATURE_COMINTEROP) || defined(FEATURE_COMWRAPPERS)
// If we found an object, or we're not a native COM weak reference, then we're done. Othewrise
// we can try to create a new RCW to the underlying native COM object if it's still alive.
if (pTarget != NULL || !IsNativeComWeakReferenceHandle(pThis->m_Handle))
return UnknownImpl::DoRelease();
}
};
+
+ struct WeakReferenceSource : public IWeakReferenceSource, public IInspectable
+ {
+ private:
+ IUnknown* _outerUnknown;
+ ComSmartPtr<WeakReference> _weakReference;
+ public:
+ WeakReferenceSource(IUnknown* outerUnknown)
+ :_outerUnknown(outerUnknown),
+ _weakReference(new WeakReference(this, 1))
+ {
+ }
+
+ STDMETHOD(GetWeakReference)(IWeakReference** ppWeakReference)
+ {
+ _weakReference->AddRef();
+ *ppWeakReference = _weakReference;
+ return S_OK;
+ }
+
+ STDMETHOD(QueryInterface)(
+ /* [in] */ REFIID riid,
+ /* [iid_is][out] */ void ** ppvObject)
+ {
+ if (riid == __uuidof(IWeakReferenceSource))
+ {
+ *ppvObject = static_cast<IWeakReferenceSource*>(this);
+ _weakReference->AddStrongRef();
+ return S_OK;
+ }
+ return _outerUnknown->QueryInterface(riid, ppvObject);
+ }
+ STDMETHOD_(ULONG, AddRef)(void)
+ {
+ return _weakReference->AddStrongRef();
+ }
+ STDMETHOD_(ULONG, Release)(void)
+ {
+ return _weakReference->ReleaseStrongRef();
+ }
+
+ STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
+ {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(GetIids)(
+ ULONG *iidCount,
+ IID **iids)
+ {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
+ {
+ *trustLevel = FullTrust;
+ return S_OK;
+ }
+ };
+
+ struct AggregatedWeakReferenceSource : IInspectable
+ {
+ private:
+ IUnknown* _outerUnknown;
+ ComSmartPtr<WeakReferenceSource> _weakReference;
+ public:
+ AggregatedWeakReferenceSource(IUnknown* outerUnknown)
+ :_outerUnknown(outerUnknown),
+ _weakReference(new WeakReferenceSource(outerUnknown))
+ {
+ }
+
+ STDMETHOD(GetRuntimeClassName)(HSTRING* pRuntimeClassName)
+ {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(GetIids)(
+ ULONG *iidCount,
+ IID **iids)
+ {
+ return E_NOTIMPL;
+ }
+
+ STDMETHOD(GetTrustLevel)(TrustLevel *trustLevel)
+ {
+ *trustLevel = FullTrust;
+ return S_OK;
+ }
+
+ STDMETHOD(QueryInterface)(
+ /* [in] */ REFIID riid,
+ /* [iid_is][out] */ void ** ppvObject)
+ {
+ if (riid == __uuidof(IWeakReferenceSource))
+ {
+ return _weakReference->QueryInterface(riid, ppvObject);
+ }
+ return _outerUnknown->QueryInterface(riid, ppvObject);
+ }
+ STDMETHOD_(ULONG, AddRef)(void)
+ {
+ return _outerUnknown->AddRef();
+ }
+ STDMETHOD_(ULONG, Release)(void)
+ {
+ return _outerUnknown->Release();
+ }
+ };
}
extern "C" DLL_EXPORT WeakReferencableObject* STDMETHODCALLTYPE CreateWeakReferencableObject()
{
return new WeakReferencableObject();
}
+
+extern "C" DLL_EXPORT AggregatedWeakReferenceSource* STDMETHODCALLTYPE CreateAggregatedWeakReferenceObject(IUnknown* pOuter)
+{
+ return new AggregatedWeakReferenceSource(pOuter);
+}
{
[DllImport(nameof(WeakReferenceNative))]
public static extern IntPtr CreateWeakReferencableObject();
+
+ [DllImport(nameof(WeakReferenceNative))]
+ public static extern IntPtr CreateAggregatedWeakReferenceObject(IntPtr outer);
}
public struct VtblPtr
Marshalling,
}
- public class WeakReferenceableWrapper
+ public unsafe class WeakReferenceableWrapper
{
private struct Vtbl
{
- public IntPtr QueryInterface;
- public _AddRef AddRef;
- public _Release Release;
+ public delegate* unmanaged<IntPtr, Guid*, IntPtr*, int> QueryInterface;
+ public delegate* unmanaged<IntPtr, int> AddRef;
+ public delegate* unmanaged<IntPtr, int> Release;
}
- private delegate int _AddRef(IntPtr This);
- private delegate int _Release(IntPtr This);
-
private readonly IntPtr instance;
private readonly Vtbl vtable;
+ private readonly bool releaseInFinalizer;
public WrapperRegistration Registration { get; }
- public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg)
+ public WeakReferenceableWrapper(IntPtr instance, WrapperRegistration reg, bool releaseInFinalizer = true)
{
var inst = Marshal.PtrToStructure<VtblPtr>(instance);
this.vtable = Marshal.PtrToStructure<Vtbl>(inst.Vtbl);
this.instance = instance;
+ this.releaseInFinalizer = releaseInFinalizer;
Registration = reg;
}
+ public int QueryInterface(Guid iid, out IntPtr ptr)
+ {
+ fixed(IntPtr* ppv = &ptr)
+ {
+ return this.vtable.QueryInterface(this.instance, &iid, ppv);
+ }
+ }
+
~WeakReferenceableWrapper()
{
- if (this.instance != IntPtr.Zero)
+ if (this.instance != IntPtr.Zero && this.releaseInFinalizer)
{
this.vtable.Release(this.instance);
}
}
}
- class Program
+ class DerivedObject : ICustomQueryInterface
{
- class TestComWrappers : ComWrappers
+ private WeakReferenceableWrapper inner;
+ public DerivedObject(TestComWrappers comWrappersInstance)
{
- public WrapperRegistration Registration { get; }
+ IntPtr innerInstance = WeakReferenceNative.CreateAggregatedWeakReferenceObject(
+ comWrappersInstance.GetOrCreateComInterfaceForObject(this, CreateComInterfaceFlags.None));
+ inner = new WeakReferenceableWrapper(innerInstance, comWrappersInstance.Registration, releaseInFinalizer: false);
+ comWrappersInstance.GetOrRegisterObjectForComInstance(innerInstance, CreateObjectFlags.Aggregation, this);
+ }
- public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
- {
- Registration = reg;
- }
+ public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
+ {
+ return inner.QueryInterface(iid, out ppv) == 0 ? CustomQueryInterfaceResult.Handled : CustomQueryInterfaceResult.Failed;
+ }
+ }
- protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
- {
- count = 0;
- return null;
- }
+ class TestComWrappers : ComWrappers
+ {
+ public WrapperRegistration Registration { get; }
- protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
- {
- Marshal.AddRef(externalComObject);
- return new WeakReferenceableWrapper(externalComObject, Registration);
- }
+ public TestComWrappers(WrapperRegistration reg = WrapperRegistration.Local)
+ {
+ Registration = reg;
+ }
- protected override void ReleaseObjects(IEnumerable objects)
- {
- }
+ protected unsafe override ComInterfaceEntry* ComputeVtables(object obj, CreateComInterfaceFlags flags, out int count)
+ {
+ count = 0;
+ return null;
+ }
+
+ protected override object CreateObject(IntPtr externalComObject, CreateObjectFlags flag)
+ {
+ Marshal.AddRef(externalComObject);
+ return new WeakReferenceableWrapper(externalComObject, Registration);
+ }
- public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
- public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
+ protected override void ReleaseObjects(IEnumerable objects)
+ {
}
+ public static readonly TestComWrappers TrackerSupportInstance = new TestComWrappers(WrapperRegistration.TrackerSupport);
+ public static readonly TestComWrappers MarshallingInstance = new TestComWrappers(WrapperRegistration.Marshalling);
+ }
+
+ class Program
+ {
+
private static void ValidateWeakReferenceState(WeakReference<WeakReferenceableWrapper> wr, bool expectedIsAlive, TestComWrappers sourceWrappers = null)
{
WeakReferenceableWrapper target;
// a global ComWrappers instance. If the RCW was created throug a local ComWrappers instance, the weak
// reference should be dead and stay dead once the RCW is collected.
bool supportsRehydration = cw.Registration != WrapperRegistration.Local;
-
+
Console.WriteLine($" -- Validate RCW recreation");
ValidateWeakReferenceState(weakRef, expectedIsAlive: supportsRehydration, cw);
Assert.IsNull(weakRef.Target);
}
+ static void ValidateAggregatedWeakReference()
+ {
+ Console.WriteLine("Validate weak reference with aggregation.");
+ var (handle, weakRef) = GetWeakReference();
+
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+
+ Assert.IsNull(handle.Target);
+ Assert.IsFalse(weakRef.TryGetTarget(out _));
+
+ static (GCHandle handle, WeakReference<DerivedObject>) GetWeakReference()
+ {
+ DerivedObject obj = new DerivedObject(TestComWrappers.TrackerSupportInstance);
+ // We use an explicit weak GC handle here to enable us to validate that we are using "weak" GCHandle
+ // semantics with the weak reference.
+ return (GCHandle.Alloc(obj, GCHandleType.Weak), new WeakReference<DerivedObject>(obj));
+ }
+ }
+
static int Main(string[] doNotUse)
{
try
ComWrappers.RegisterForTrackerSupport(TestComWrappers.TrackerSupportInstance);
ValidateGlobalInstanceTrackerSupport();
+ ValidateAggregatedWeakReference();
ValidateLocalInstance();
}