Add check for unloaded types in GetAppDomainStaticAddress (#34677)
authorDavid Mason <davmason@microsoft.com>
Fri, 10 Apr 2020 18:26:00 +0000 (11:26 -0700)
committerGitHub <noreply@github.com>
Fri, 10 Apr 2020 18:26:00 +0000 (11:26 -0700)
Fixes #33367

When a module is unloaded, the managed objectref is collected during a GC before the profiler is notified that the module is unloading. That means if you call in to GetAppDomainStaticAddress between when the object is collected and when you are notified about the module unload (GarbageCollectionFinished is a place that will hit this every time) it will cause an AV from trying to use the null objectref.

This fix prevents this AV by checking to see if the object's loaderheap or managed loaderheap object ref are invalid first.

13 files changed:
src/coreclr/src/vm/proftoeeinterfaceimpl.cpp
src/coreclr/tests/src/profiler/native/CMakeLists.txt
src/coreclr/tests/src/profiler/native/classfactory.cpp
src/coreclr/tests/src/profiler/native/getappdomainstaticaddress/getappdomainstaticaddress.cpp [new file with mode: 0644]
src/coreclr/tests/src/profiler/native/getappdomainstaticaddress/getappdomainstaticaddress.h [new file with mode: 0644]
src/coreclr/tests/src/profiler/native/metadatagetdispenser/metadatagetdispenser.cpp [moved from src/coreclr/tests/src/profiler/native/unittestprofiler/unittestprofiler.cpp with 90% similarity]
src/coreclr/tests/src/profiler/native/metadatagetdispenser/metadatagetdispenser.h [moved from src/coreclr/tests/src/profiler/native/unittestprofiler/unittestprofiler.h with 87% similarity]
src/coreclr/tests/src/profiler/native/profilerstring.h
src/coreclr/tests/src/profiler/unittest/getappdomainstaticaddress.cs [new file with mode: 0644]
src/coreclr/tests/src/profiler/unittest/getappdomainstaticaddress.csproj [new file with mode: 0644]
src/coreclr/tests/src/profiler/unittest/metadatagetdispenser.cs
src/coreclr/tests/src/profiler/unittest/unloadlibrary.cs [new file with mode: 0644]
src/coreclr/tests/src/profiler/unittest/unloadlibrary.csproj [new file with mode: 0644]

index 385ff4a..4801655 100644 (file)
@@ -3145,6 +3145,12 @@ HRESULT ProfToEEInterfaceImpl::GetAppDomainStaticAddress(ClassID classId,
         return CORPROF_E_DATAINCOMPLETE;
     }
 
+    if (typeHandle.GetModule()->GetLoaderAllocator() == NULL ||
+        typeHandle.GetModule()->GetLoaderAllocator()->GetExposedObject() == NULL)
+    {
+        return CORPROF_E_DATAINCOMPLETE;
+    }
+
     //
     // Get the field descriptor object
     //
index 9405df3..3de1e8b 100644 (file)
@@ -5,9 +5,10 @@ project(Profiler)
 set(GCBASIC_SOURCES gcbasicprofiler/gcbasicprofiler.cpp)
 set(REJIT_SOURCES rejitprofiler/rejitprofiler.cpp rejitprofiler/ilrewriter.cpp rejitprofiler/sigparse.cpp)
 set(EVENTPIPE_SOURCES eventpipeprofiler/eventpipeprofiler.cpp)
-set(UNITTEST_SOURCES unittestprofiler/unittestprofiler.cpp)
+set(METADATAGETDISPENSER_SOURCES metadatagetdispenser/metadatagetdispenser.cpp)
+set(GETAPPDOMAINSTATICADDRESS_SOURCES getappdomainstaticaddress/getappdomainstaticaddress.cpp)
 
-set(SOURCES ${GCBASIC_SOURCES} ${REJIT_SOURCES} ${EVENTPIPE_SOURCES} ${UNITTEST_SOURCES} profiler.def profiler.cpp classfactory.cpp dllmain.cpp guids.cpp)
+set(SOURCES ${GCBASIC_SOURCES} ${REJIT_SOURCES} ${EVENTPIPE_SOURCES} ${METADATAGETDISPENSER_SOURCES} ${GETAPPDOMAINSTATICADDRESS_SOURCES} profiler.def profiler.cpp classfactory.cpp dllmain.cpp guids.cpp)
 
 include_directories(../../../../src/pal/prebuilt/inc)
 
index 535fa0a..a3f0f6a 100644 (file)
@@ -6,7 +6,8 @@
 #include "gcbasicprofiler/gcbasicprofiler.h"
 #include "rejitprofiler/rejitprofiler.h"
 #include "eventpipeprofiler/eventpipeprofiler.h"
-#include "unittestprofiler/unittestprofiler.h"
+#include "metadatagetdispenser/metadatagetdispenser.h"
+#include "getappdomainstaticaddress/getappdomainstaticaddress.h"
 
 ClassFactory::ClassFactory(REFCLSID clsid) : refCount(0), clsid(clsid)
 {
@@ -58,7 +59,8 @@ HRESULT STDMETHODCALLTYPE ClassFactory::CreateInstance(IUnknown *pUnkOuter, REFI
                new GCBasicProfiler(),
         new ReJITProfiler(),
         new EventPipeProfiler(),
-        new UnitTestProfiler()
+        new MetaDataGetDispenser(),
+        new GetAppDomainStaticAddress()
                // add new profilers here
        };
 
diff --git a/src/coreclr/tests/src/profiler/native/getappdomainstaticaddress/getappdomainstaticaddress.cpp b/src/coreclr/tests/src/profiler/native/getappdomainstaticaddress/getappdomainstaticaddress.cpp
new file mode 100644 (file)
index 0000000..598bc6a
--- /dev/null
@@ -0,0 +1,560 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#include "getappdomainstaticaddress.h"
+#include <string>
+#include <assert.h>
+#include <inttypes.h>
+#include <sstream>
+
+
+using std::thread;
+using std::shared_ptr;
+using std::map;
+using std::make_pair;
+using std::mutex;
+using std::lock_guard;
+using std::wstring;
+using std::vector;
+
+// Prints a lot to the console for easier tracking
+#define DEBUG_OUT false
+
+GetAppDomainStaticAddress::GetAppDomainStaticAddress() :
+    refCount(0),
+    failures(0),
+    successes(0),
+    jitEventCount(0),
+    gcTriggerThread(),
+    gcWaitEvent(),
+    classADMap(),
+    classADMapLock()
+{
+
+}
+
+GetAppDomainStaticAddress::~GetAppDomainStaticAddress()
+{
+
+}
+
+GUID GetAppDomainStaticAddress::GetClsid()
+{
+    // {604D76F0-2AF2-48E0-B196-80C972F6AFB7}
+    GUID clsid = { 0x604D76F0, 0x2AF2, 0x48E0, {0xB1, 0x96, 0x80, 0xC9, 0x72, 0xF6, 0xAF, 0xB7 } };
+    return clsid;
+}
+
+HRESULT GetAppDomainStaticAddress::Initialize(IUnknown *pICorProfilerInfoUnk)
+{
+    printf("Initialize profiler!\n");
+
+    HRESULT hr = pICorProfilerInfoUnk->QueryInterface(IID_ICorProfilerInfo10, (void**)&pCorProfilerInfo);
+    if (hr != S_OK)
+    {
+        printf("Got HR %X from QI for ICorProfilerInfo4", hr);
+        ++failures;
+        return E_FAIL;
+    }
+
+    pCorProfilerInfo->SetEventMask2(COR_PRF_MONITOR_GC |
+                                   COR_PRF_MONITOR_CLASS_LOADS |
+                                   COR_PRF_MONITOR_MODULE_LOADS |
+                                   COR_PRF_MONITOR_JIT_COMPILATION |
+                                   COR_PRF_DISABLE_ALL_NGEN_IMAGES, 0);
+
+    auto gcTriggerLambda = [&]()
+    {
+        pCorProfilerInfo->InitializeCurrentThread();
+
+        while (true)
+        {
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+            gcWaitEvent.Wait();
+
+            if (!IsRuntimeExecutingManagedCode())
+            {
+                if (DEBUG_OUT)
+                {
+                    printf("Runtime has not started executing managed code yet.\n");
+                }
+                continue;
+            }
+
+            printf("Forcing GC\n");
+            HRESULT hr = pCorProfilerInfo->ForceGC();
+            if (FAILED(hr))
+            {
+                printf("Error forcing GC... hr=0x%x \n", hr);
+                ++failures;
+                continue;
+            }
+        }
+    };
+
+    gcTriggerThread = thread(gcTriggerLambda);
+    gcWaitEvent.Signal();
+
+    return S_OK;
+}
+
+HRESULT GetAppDomainStaticAddress::Shutdown()
+{
+    Profiler::Shutdown();
+
+    gcWaitEvent.Reset();
+
+    if (this->pCorProfilerInfo != nullptr)
+    {
+        this->pCorProfilerInfo->Release();
+        this->pCorProfilerInfo = nullptr;
+    }
+
+    if(failures == 0 && successes > 0)
+    {
+        printf("PROFILER TEST PASSES\n");
+    }
+    else
+    {
+        printf("Test failed number of failures=%d successes=%d\n", failures.load(), successes.load());
+    }
+    fflush(stdout);
+
+    return S_OK;
+}
+
+HRESULT GetAppDomainStaticAddress::ModuleLoadFinished(ModuleID moduleId, HRESULT hrStatus)
+{
+    constexpr size_t nameLen = 1024;
+    WCHAR name[nameLen];
+    HRESULT hr = pCorProfilerInfo->GetModuleInfo2(moduleId,
+                                                 NULL,
+                                                nameLen,
+                                                NULL,
+                                                name,
+                                                NULL,
+                                                NULL);
+    if (FAILED(hr))
+    {
+        printf("GetModuleInfo2 failed with hr=0x%x\n", hr);
+        ++failures;
+    }
+
+    if (DEBUG_OUT)
+    {
+        wprintf(L"Module 0x%" PRIxPTR " (%s) loaded\n", moduleId, name);
+    }
+
+    return S_OK;
+}
+
+HRESULT GetAppDomainStaticAddress::ModuleUnloadStarted(ModuleID moduleId)
+{
+    lock_guard<mutex> guard(classADMapLock);
+    constexpr size_t nameLen = 1024;
+    WCHAR name[nameLen];
+    HRESULT hr = pCorProfilerInfo->GetModuleInfo2(moduleId,
+                                                 NULL,
+                                                nameLen,
+                                                NULL,
+                                                name,
+                                                NULL,
+                                                NULL);
+    if (FAILED(hr))
+    {
+        printf("GetModuleInfo2 failed with hr=0x%x\n", hr);
+        ++failures;
+        return E_FAIL;
+    }
+
+    if (DEBUG_OUT)
+    {
+        wprintf(L"Module 0x%" PRIxPTR " (%s) unload started\n", moduleId, name);
+    }
+
+    for (auto it = classADMap.begin(); it != classADMap.end(); )
+    {
+        ClassID classId = it->first;
+
+        ModuleID modId;
+        hr = pCorProfilerInfo->GetClassIDInfo(classId, &modId, NULL);
+        if (FAILED(hr))
+        {
+            printf("Failed to get ClassIDInfo hr=0x%x\n", hr);
+            ++failures;
+            return E_FAIL;
+        }
+
+        if (modId == moduleId)
+        {
+            if (DEBUG_OUT)
+            {
+                printf("ClassID 0x%" PRIxPTR " being removed due to parent module unloading\n", classId);
+            }
+
+            it = classADMap.erase(it);
+            continue;
+        }
+
+        // Now check the generic arguments
+        bool shouldEraseClassId = false;
+        vector<ClassID> genericTypes = GetGenericTypeArgs(classId);
+        for (auto genericIt = genericTypes.begin(); genericIt != genericTypes.end(); ++genericIt)
+        {
+            ClassID typeArg = *genericIt;
+            ModuleID typeArgModId;
+
+            if (DEBUG_OUT)
+            {
+                printf("Checking generic argument 0x%" PRIxPTR " of class 0x%" PRIxPTR "\n", typeArg, classId);
+            }
+
+            hr = pCorProfilerInfo->GetClassIDInfo(typeArg, &typeArgModId, NULL);
+            if (FAILED(hr))
+            {
+                printf("Failed to get ClassIDInfo hr=0x%x\n", hr);
+                ++failures;
+                return E_FAIL;
+            }
+
+            if (typeArgModId == moduleId)
+            {
+                if (DEBUG_OUT)
+                {
+                    wprintf(L"ClassID 0x%" PRIxPTR " (%s) being removed due to generic argument 0x%" PRIxPTR " (%s) belonging to the parent module 0x%" PRIxPTR " unloading\n",
+                            classId, GetClassIDName(classId).ToWString().c_str(), typeArg, GetClassIDName(typeArg).ToWString().c_str(), typeArgModId);
+                }
+
+                shouldEraseClassId = true;
+                break;
+            }
+        }
+
+        if (shouldEraseClassId)
+        {
+            it = classADMap.erase(it);
+        }
+        else
+        {
+            ++it;
+        }
+    }
+
+    return S_OK;
+}
+
+HRESULT GetAppDomainStaticAddress::ClassLoadFinished(ClassID classId, HRESULT hrStatus)
+{
+    HRESULT hr = S_OK;
+
+    ThreadID threadId = NULL;
+    AppDomainID appDomainId = NULL;
+    CorElementType baseElemType;
+    ClassID        baseClassId;
+    ULONG          cRank;
+
+    // We don't care about array classes, so skip them.
+
+    hr = pCorProfilerInfo->IsArrayClass(
+        classId,
+        &baseElemType,
+        &baseClassId,
+        &cRank);
+    if (hr == S_OK)
+    {
+        return S_OK;
+    }
+
+
+    hr = pCorProfilerInfo->GetCurrentThreadID(&threadId);
+    if (FAILED(hr))
+    {
+        printf("GetCurrentThreadID returned 0x%x\n", hr);
+        ++failures;
+        return hr;
+    }
+
+    hr = pCorProfilerInfo->GetThreadAppDomain(threadId, &appDomainId);
+    if (FAILED(hr))
+    {
+        printf("GetThreadAppDomain returned 0x%x for ThreadID 0x%" PRIxPTR "\n", hr, threadId);
+        ++failures;
+        return hr;
+    }
+
+    lock_guard<mutex> guard(classADMapLock);
+    classADMap.insert(make_pair(classId, appDomainId));
+
+    ModuleID modId;
+    hr = pCorProfilerInfo->GetClassIDInfo2(classId,
+                                          &modId,
+                                          NULL,
+                                          NULL,
+                                          NULL,
+                                          NULL,
+                                          NULL);
+    if (FAILED(hr))
+    {
+        printf("GetClassIDInfo2 returned 0x%x for ClassID 0x%" PRIxPTR "\n", hr, classId);
+        ++failures;
+    }
+
+    wstring name = GetClassIDName(classId).ToWString();
+
+    if (DEBUG_OUT)
+    {
+        wprintf(L"Class 0x%" PRIxPTR " (%s) loaded from module 0x%" PRIxPTR "\n", classId, name.c_str(), modId);
+    }
+
+    return hr;
+}
+
+HRESULT GetAppDomainStaticAddress::ClassUnloadStarted(ClassID classId)
+{
+    lock_guard<mutex> guard(classADMapLock);
+
+    mdTypeDef unloadClassToken;
+    HRESULT hr = pCorProfilerInfo->GetClassIDInfo2(classId,
+                                                  NULL,
+                                                  &unloadClassToken,
+                                                  NULL,
+                                                  0,
+                                                  NULL,
+                                                  NULL);
+    if (FAILED(hr))
+    {
+        printf("GetClassIDInfo2 failed with hr=0x%x\n", hr);
+        ++failures;
+    }
+
+    if (DEBUG_OUT)
+    {
+        wprintf(L"Class 0x%" PRIxPTR " (%s) unload started\n", classId, GetClassIDName(classId).ToWString().c_str());
+    }
+
+    for (auto it = classADMap.begin(); it != classADMap.end(); ++it)
+    {
+        ClassID mapClass = it->first;
+        mdTypeDef mapClassToken;
+        hr = pCorProfilerInfo->GetClassIDInfo2(mapClass,
+                                              NULL,
+                                              &mapClassToken,
+                                              NULL,
+                                              0,
+                                              NULL,
+                                              NULL);
+        if (mapClass == classId || mapClassToken == unloadClassToken)
+        {
+            it = classADMap.erase(it);
+        }
+    }
+
+    return S_OK;
+}
+
+HRESULT GetAppDomainStaticAddress::JITCompilationFinished(FunctionID functionId, HRESULT hrStatus, BOOL fIsSafeToBlock)
+{
+    ++jitEventCount;
+    return S_OK;
+}
+
+HRESULT GetAppDomainStaticAddress::GarbageCollectionFinished()
+{
+    HRESULT hr = S_OK;
+    lock_guard<mutex> guard(classADMapLock);
+
+    for (ClassAppDomainMap::iterator iCADM = classADMap.begin();
+            iCADM != classADMap.end();
+            iCADM++)
+    {
+        ClassID classId = iCADM->first;
+        AppDomainID appDomainId = iCADM->second;
+
+        if (DEBUG_OUT)
+        {
+            printf("Calling GetClassIDInfo2 on classId 0x%" PRIxPTR "\n", classId);
+            fflush(stdout);
+        }
+
+        ModuleID classModuleId = NULL;
+        hr = pCorProfilerInfo->GetClassIDInfo2(classId,
+                                    &classModuleId,
+                                    NULL,
+                                    NULL,
+                                    NULL,
+                                    NULL,
+                                    NULL);
+        if (FAILED(hr))
+        {
+            printf("GetClassIDInfo2 returned 0x%x for ClassID 0x%" PRIxPTR "\n", hr, classId);
+            ++failures;
+            continue;
+        }
+
+        COMPtrHolder<IMetaDataImport> pIMDImport;
+
+        hr = pCorProfilerInfo->GetModuleMetaData(classModuleId,
+                                        ofRead,
+                                        IID_IMetaDataImport,
+                                        (IUnknown **)&pIMDImport);
+        if (hr == CORPROF_E_DATAINCOMPLETE)
+        {
+            // Module is being unloaded...
+            continue;
+        }
+        if (FAILED(hr))
+        {
+            printf("GetModuleMetaData returned 0x%x  for ModuleID 0x%" PRIxPTR "\n", hr, classModuleId);
+            ++failures;
+            continue;
+        }
+
+        HCORENUM hEnum = NULL;
+        mdTypeDef token = NULL;
+        mdFieldDef fieldTokens[SHORT_LENGTH];
+        ULONG cTokens = NULL;
+
+        if (DEBUG_OUT)
+        {
+            printf("Calling GetClassIDInfo2 (again?) on classId 0x%" PRIxPTR "\n", classId);
+            fflush(stdout);
+        }
+
+        // Get class token to enum all field    s from MetaData.  (Needed for statics)
+        hr = pCorProfilerInfo->GetClassIDInfo2(classId,
+                                            NULL,
+                                            &token,
+                                            NULL,
+                                            NULL,
+                                            NULL,
+                                            NULL);
+        if (hr == CORPROF_E_DATAINCOMPLETE)
+        {
+            // Class load not complete.  We can not inspect yet.
+            continue;
+        }
+        if (FAILED(hr))
+        {
+            printf("GetClassIDInfo2returned 0x%x\n", hr);
+            ++failures;
+            continue;
+        }
+
+        // Enum all fields of the class from the MetaData
+        hr = pIMDImport->EnumFields(&hEnum,
+                                            token,
+                                            fieldTokens,
+                                            SHORT_LENGTH,
+                                            &cTokens);
+        if (FAILED(hr))
+        {
+            printf("IMetaDataImport::EnumFields returned 0x%x\n", hr);
+            ++failures;
+            continue;
+        }
+
+        for (ULONG i = 0; i < cTokens; i++)
+        {
+            mdTypeDef fieldClassToken = NULL;
+            WCHAR tokenName[256];
+            ULONG nameLength = NULL;
+            DWORD fieldAttributes = NULL;
+            PCCOR_SIGNATURE pvSig = NULL;
+            ULONG cbSig = NULL;
+            DWORD corElementType = NULL;
+
+            hr = pIMDImport->GetFieldProps(fieldTokens[i],
+                                            &fieldClassToken,
+                                            tokenName,
+                                            256,
+                                            &nameLength,
+                                            &fieldAttributes,
+                                            &pvSig,
+                                            &cbSig,
+                                            &corElementType,
+                                            NULL,
+                                            NULL);
+
+            if (FAILED(hr))
+            {
+                printf("GetFieldProps returned 0x%x for Field %d\n", hr, i);
+                ++failures;
+                continue;
+            }
+
+            if ((IsFdStatic(fieldAttributes)) && (!IsFdLiteral(fieldAttributes)))
+            {
+                COR_PRF_STATIC_TYPE fieldInfo = COR_PRF_FIELD_NOT_A_STATIC;
+                hr = pCorProfilerInfo->GetStaticFieldInfo(classId, fieldTokens[i], &fieldInfo);
+                if (FAILED(hr))
+                {
+                    wprintf(L"GetStaticFieldInfo returned HR=0x%x for field %x (%s)\n", hr, fieldTokens[i], tokenName);
+                    ++failures;
+                    continue;
+                }
+
+                if (fieldInfo & COR_PRF_FIELD_APP_DOMAIN_STATIC)
+                {
+                    PVOID staticOffSet = NULL;
+
+                    if (DEBUG_OUT)
+                    {
+                        printf("Calling GetAppDomainStaticAddress on classId=0x%" PRIxPTR "\n", classId);
+                        fflush(stdout);
+                    }
+
+                    hr = pCorProfilerInfo->GetAppDomainStaticAddress(classId,
+                                                fieldTokens[i],
+                                                appDomainId,
+                                                &staticOffSet);
+
+                    if (FAILED(hr) && (hr != CORPROF_E_DATAINCOMPLETE))
+                    {
+                        printf("GetAppDomainStaticAddress Failed HR 0x%x\n", hr);
+                        ++failures;
+                        continue;
+                    }
+                }
+            }
+        }
+    }
+
+    printf("Garbage collection finished\n");
+    ++successes;
+    return hr;
+}
+
+bool GetAppDomainStaticAddress::IsRuntimeExecutingManagedCode()
+{
+    return jitEventCount.load() > 0;
+}
+
+std::vector<ClassID> GetAppDomainStaticAddress::GetGenericTypeArgs(ClassID classId)
+{
+    HRESULT hr = S_OK;
+    constexpr size_t typeIdArgsLen = 10;
+    ClassID typeArgs[typeIdArgsLen];
+    ULONG32 typeArgsCount;
+    hr = pCorProfilerInfo->GetClassIDInfo2(classId,
+                                          NULL,
+                                          NULL,
+                                          NULL,
+                                          typeIdArgsLen,
+                                          &typeArgsCount,
+                                          typeArgs);
+    if (FAILED(hr))
+    {
+        printf("Error calling GetClassIDInfo2 hr=0x%x\n", hr);
+        ++failures;
+    }
+
+    vector<ClassID> types;
+    for (ULONG32 i = 0; i < typeArgsCount; ++i)
+    {
+        types.push_back(typeArgs[i]);
+    }
+
+    return types;
+}
diff --git a/src/coreclr/tests/src/profiler/native/getappdomainstaticaddress/getappdomainstaticaddress.h b/src/coreclr/tests/src/profiler/native/getappdomainstaticaddress/getappdomainstaticaddress.h
new file mode 100644 (file)
index 0000000..d0246da
--- /dev/null
@@ -0,0 +1,103 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+#pragma once
+
+#include "../profiler.h"
+
+#include <atomic>
+#include <memory>
+#include <set>
+#include <mutex>
+#include <vector>
+#include <map>
+#include <string>
+#include <thread>
+#include <chrono>
+#include <condition_variable>
+#include <functional>
+#include "cor.h"
+#include "corprof.h"
+
+typedef HRESULT (*GetDispenserFunc) (const CLSID &pClsid, const IID &pIid, void **ppv);
+
+class ManualEvent
+{
+private:
+    std::mutex m_mtx;
+    std::condition_variable m_cv;
+    bool m_set = false;
+
+    static void DoNothing()
+    {
+
+    }
+
+public:
+    ManualEvent() = default;
+    ~ManualEvent() = default;
+    ManualEvent(ManualEvent& other) = delete;
+    ManualEvent(ManualEvent&& other) = delete;
+    ManualEvent& operator= (ManualEvent& other) = delete;
+    ManualEvent& operator= (ManualEvent&& other) = delete;
+
+    void Wait(std::function<void()> spuriousCallback = DoNothing)
+    {
+        std::unique_lock<std::mutex> lock(m_mtx);
+        while (!m_set)
+        {
+            m_cv.wait(lock, [&]() { return m_set; });
+            if (!m_set)
+            {
+                spuriousCallback();
+            }
+        }
+    }
+
+    void Signal()
+    {
+        std::unique_lock<std::mutex> lock(m_mtx);
+        m_set = true;
+    }
+
+    void Reset()
+    {
+        std::unique_lock<std::mutex> lock(m_mtx);
+        m_set = false;
+    }
+};
+
+class GetAppDomainStaticAddress : public Profiler
+{
+public:
+    GetAppDomainStaticAddress();
+    virtual ~GetAppDomainStaticAddress();
+
+    virtual GUID GetClsid() override;
+    virtual HRESULT STDMETHODCALLTYPE Initialize(IUnknown* pICorProfilerInfoUnk) override;
+    virtual HRESULT STDMETHODCALLTYPE Shutdown() override;
+
+    virtual HRESULT STDMETHODCALLTYPE ModuleLoadFinished(ModuleID moduleId, HRESULT hrStatus) override;
+    virtual HRESULT STDMETHODCALLTYPE ModuleUnloadStarted(ModuleID moduleId) override;
+    virtual HRESULT STDMETHODCALLTYPE ClassLoadFinished(ClassID classId, HRESULT hrStatus) override;
+    virtual HRESULT STDMETHODCALLTYPE ClassUnloadStarted(ClassID classId) override;
+    virtual HRESULT STDMETHODCALLTYPE JITCompilationFinished(FunctionID functionId, HRESULT hrStatus, BOOL fIsSafeToBlock) override;
+    virtual HRESULT STDMETHODCALLTYPE GarbageCollectionFinished() override;
+
+private:
+    std::atomic<int> refCount;
+    std::atomic<ULONG32> failures;
+    std::atomic<ULONG32> successes;
+
+    std::atomic<int> jitEventCount;
+    std::thread gcTriggerThread;
+    ManualEvent gcWaitEvent;
+
+    typedef std::map<ClassID, AppDomainID>ClassAppDomainMap;
+    ClassAppDomainMap classADMap;
+    std::mutex classADMapLock;
+
+    bool IsRuntimeExecutingManagedCode();
+    std::vector<ClassID> GetGenericTypeArgs(ClassID classId);
+};
@@ -2,25 +2,25 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
-#include "unittestprofiler.h"
+#include "metadatagetdispenser.h"
 
 #ifdef WIN32
 #include <Windows.h>
 #else // WIN32
 #include <dlfcn.h>
-#ifdef __APPLE__ 
+#ifdef __APPLE__
 #include <mach-o/dyld.h>
 #endif // __APPLE__
 #endif // WIN32
 
-UnitTestProfiler::UnitTestProfiler() :
+MetaDataGetDispenser::MetaDataGetDispenser() :
     _dispenser(NULL),
     _failures(0)
 {
 
 }
 
-UnitTestProfiler::~UnitTestProfiler()
+MetaDataGetDispenser::~MetaDataGetDispenser()
 {
     if (_dispenser != NULL)
     {
@@ -29,14 +29,14 @@ UnitTestProfiler::~UnitTestProfiler()
     }
 }
 
-GUID UnitTestProfiler::GetClsid()
+GUID MetaDataGetDispenser::GetClsid()
 {
     // {7198FF3E-50E8-4AD1-9B89-CB15A1D6E740}
     GUID clsid = { 0x7198FF3E, 0x50E8, 0x4AD1, {0x9B, 0x89, 0xCB, 0x15, 0xA1, 0xD6, 0xE7, 0x40 } };
     return clsid;
 }
 
-HRESULT UnitTestProfiler::Initialize(IUnknown* pICorProfilerInfoUnk)
+HRESULT MetaDataGetDispenser::Initialize(IUnknown* pICorProfilerInfoUnk)
 {
     HRESULT hr = Profiler::Initialize(pICorProfilerInfoUnk);
     if (FAILED(hr))
@@ -62,8 +62,10 @@ HRESULT UnitTestProfiler::Initialize(IUnknown* pICorProfilerInfoUnk)
     return S_OK;
 }
 
-HRESULT UnitTestProfiler::Shutdown()
+HRESULT MetaDataGetDispenser::Shutdown()
 {
+    Profiler::Shutdown();
+
     if(_failures == 0)
     {
         printf("PROFILER TEST PASSES\n");
@@ -77,7 +79,7 @@ HRESULT UnitTestProfiler::Shutdown()
     return S_OK;
 }
 
-HRESULT UnitTestProfiler::ModuleLoadStarted(ModuleID moduleId)
+HRESULT MetaDataGetDispenser::ModuleLoadStarted(ModuleID moduleId)
 {
     COMPtrHolder<IMetaDataDispenserEx> pDispenser;
     HRESULT hr = GetDispenser(&pDispenser);
@@ -124,7 +126,7 @@ HRESULT UnitTestProfiler::ModuleLoadStarted(ModuleID moduleId)
 // typedef HRESULT (*GetDispenserFunc) (CLSID *pClsid, IID *pIid, void **ppv);
 #if WIN32
 
-HRESULT UnitTestProfiler::GetDispenser(IMetaDataDispenserEx **disp)
+HRESULT MetaDataGetDispenser::GetDispenser(IMetaDataDispenserEx **disp)
 {
     HMODULE coreclr = LoadLibrary("coreclr.dll");
     if (coreclr == NULL)
@@ -201,7 +203,7 @@ const char *GetCoreCLRPath()
 }
 #endif // __APPLE__
 
-HRESULT UnitTestProfiler::GetDispenser(IMetaDataDispenserEx **disp)
+HRESULT MetaDataGetDispenser::GetDispenser(IMetaDataDispenserEx **disp)
 {
 #ifdef __APPLE__
     const char *profilerName = GetCoreCLRPath();
@@ -238,5 +240,4 @@ HRESULT UnitTestProfiler::GetDispenser(IMetaDataDispenserEx **disp)
     printf("Got IMetaDataDispenserEx\n");
     return S_OK;
 }
-
-#endif // WIN32
\ No newline at end of file
+#endif // WIN32
 
 typedef HRESULT (*GetDispenserFunc) (const CLSID &pClsid, const IID &pIid, void **ppv);
 
-class UnitTestProfiler : public Profiler
+class MetaDataGetDispenser : public Profiler
 {
 public:
-    UnitTestProfiler();
-    virtual ~UnitTestProfiler();
+    MetaDataGetDispenser();
+    virtual ~MetaDataGetDispenser();
 
     virtual GUID GetClsid();
     virtual HRESULT STDMETHODCALLTYPE Initialize(IUnknown* pICorProfilerInfoUnk);
index 386c4c6..0f3a876 100644 (file)
@@ -7,8 +7,14 @@
 #include <iostream>
 #include <assert.h>
 #include <cstring>
+#include <string>
 
 #if defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__)
+
+// Definitely won't work for non-ascii characters so hopefully we never start using
+// them in the tests
+#define CAST_CHAR(ch) static_cast<wchar_t>(ch)
+
 // On linux the runtime uses 16 bit strings but the native platform wchar_t is 32 bit.
 // This means there aren't c runtime functions like wcslen for 16 bit strings. The idea
 // here is to provide the easy ones to avoid all the copying and transforming. If more complex
@@ -52,6 +58,7 @@ inline int wcscmp(const char16_t *lhs, const char16_t *rhs)
 
 #else // defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__)
 #define WCHAR(str) L##str
+#define CAST_CHAR(ch) ch
 #endif // defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__)
 
 // 16 bit string type that works cross plat and doesn't require changing widths
@@ -187,6 +194,22 @@ public:
         }
     }
 
+    std::wstring ToWString()
+    {
+        std::wstring temp;
+        for (size_t i = 0; i < bufferLen; ++i)
+        {
+            if (buffer[i] == 0)
+            {
+                break;
+            }
+
+            temp.push_back(CAST_CHAR(buffer[i]));
+        }
+
+        return temp;
+    }
+
     size_t Size() const
     {
         return wcslen(buffer);
@@ -202,11 +225,7 @@ inline std::wostream& operator<<(std::wostream& os, const String& obj)
             break;
         }
 
-#if defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__)
-        os << static_cast<wchar_t>(obj.buffer[i]);
-#else // defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__)
-        os << obj.buffer[i];
-#endif // defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__)
+        os << CAST_CHAR(obj.buffer[i]);
     }
 
     return os;
diff --git a/src/coreclr/tests/src/profiler/unittest/getappdomainstaticaddress.cs b/src/coreclr/tests/src/profiler/unittest/getappdomainstaticaddress.cs
new file mode 100644 (file)
index 0000000..18fdd6d
--- /dev/null
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+using System.Threading;
+using System.IO;
+using System.Runtime.Loader;
+
+namespace Profiler.Tests
+{
+    class Program
+    {
+        static readonly Guid GetAppDomainStaticAddressProfilerGuid = new Guid("604D76F0-2AF2-48E0-B196-80C972F6AFB7");
+
+        static int Main(string[] args)
+        {
+            if (args.Length == 1 && args[0].Equals("RunTest", StringComparison.OrdinalIgnoreCase))
+            {
+                return RunTest();
+            }
+
+            return ProfilerTestRunner.Run(profileePath: System.Reflection.Assembly.GetExecutingAssembly().Location,
+                                          testName: "UnitTestGetAppDomainStaticAddress",
+                                          profilerClsid: GetAppDomainStaticAddressProfilerGuid);
+        }
+
+        static int RunTest()
+        {
+            LoadCollectibleAssembly();
+
+            Thread.Sleep(TimeSpan.FromSeconds(3));
+
+            return 100;
+        }
+
+        private static void LoadCollectibleAssembly()
+        {
+            var collectibleContext = new AssemblyLoadContext("Collectible", true);
+
+            var asmDir = Path.GetDirectoryName(typeof(Program).Assembly.Location);
+            var dynamicLibrary = collectibleContext.LoadFromAssemblyPath(Path.Combine(asmDir, "unloadlibrary.dll"));
+            var testType = dynamicLibrary.GetType("UnloadLibrary.TestClass");
+
+            object instance = Activator.CreateInstance(testType);
+
+            Console.WriteLine(instance.GetHashCode());
+
+            collectibleContext.Unload();
+        }
+    }
+}
diff --git a/src/coreclr/tests/src/profiler/unittest/getappdomainstaticaddress.csproj b/src/coreclr/tests/src/profiler/unittest/getappdomainstaticaddress.csproj
new file mode 100644 (file)
index 0000000..3a10288
--- /dev/null
@@ -0,0 +1,17 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <TargetFrameworkIdentifier>.NETCoreApp</TargetFrameworkIdentifier>
+    <OutputType>exe</OutputType>
+    <CLRTestKind>BuildAndRun</CLRTestKind>
+    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+    <CLRTestPriority>0</CLRTestPriority>
+    <Optimize>true</Optimize>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+    <ProjectReference Include="$(TestSourceDir)Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
+    <ProjectReference Include="../common/profiler_common.csproj" />
+    <ProjectReference Include="unloadlibrary.csproj" />
+    <ProjectReference Include="$(MSBuildThisFileDirectory)/../native/CMakeLists.txt" />
+  </ItemGroup>
+</Project>
index bf77732..ff03ee5 100644 (file)
@@ -9,7 +9,7 @@ namespace Profiler.Tests
 {
     class MetadataGetDispenser //: ProfilerTest
     {
-        static readonly Guid UnitTestProfilerGuid = new Guid("7198FF3E-50E8-4AD1-9B89-CB15A1D6E740");
+        static readonly Guid MetaDataGetDispenserProfilerGuid = new Guid("7198FF3E-50E8-4AD1-9B89-CB15A1D6E740");
 
         public static int RunTest(string[] args)
         {
@@ -28,7 +28,7 @@ namespace Profiler.Tests
 
             return ProfilerTestRunner.Run(profileePath: System.Reflection.Assembly.GetExecutingAssembly().Location,
                                           testName: "UnitTestMetadataGetDispenser",
-                                          profilerClsid: UnitTestProfilerGuid);
+                                          profilerClsid: MetaDataGetDispenserProfilerGuid);
         }
     }
 }
diff --git a/src/coreclr/tests/src/profiler/unittest/unloadlibrary.cs b/src/coreclr/tests/src/profiler/unittest/unloadlibrary.cs
new file mode 100644 (file)
index 0000000..291d07a
--- /dev/null
@@ -0,0 +1,30 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+
+namespace UnloadLibrary
+{
+    public class TestClass : MarshalByRefObject
+    {
+        static string staticString = "A static string";
+        string[] instanceStrings;
+        static int s_count;
+
+        static List<TestClass> instances = new List<TestClass>();
+
+        public TestClass()
+        {
+            instanceStrings = new string[100];
+            for (int i = 0; i < instanceStrings.Length; i++)
+            {
+                instanceStrings[i] = staticString + (++s_count);
+            }
+            Console.WriteLine("Class1 constructed");
+
+            instances.Add(this);
+        }
+    }
+}
diff --git a/src/coreclr/tests/src/profiler/unittest/unloadlibrary.csproj b/src/coreclr/tests/src/profiler/unittest/unloadlibrary.csproj
new file mode 100644 (file)
index 0000000..51b3e4b
--- /dev/null
@@ -0,0 +1,10 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <TargetFrameworkIdentifier>.NETCoreApp</TargetFrameworkIdentifier>
+    <OutputType>Library</OutputType>
+    <Optimize>true</Optimize>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="$(MSBuildProjectName).cs" />
+  </ItemGroup>
+</Project>