Populate properties of AssemblyLoadStart/Stop events (dotnet/coreclr#27558)
authorElinor Fung <47805090+elinor-fung@users.noreply.github.com>
Wed, 6 Nov 2019 20:17:49 +0000 (12:17 -0800)
committerGitHub <noreply@github.com>
Wed, 6 Nov 2019 20:17:49 +0000 (12:17 -0800)
* Populate properties of AssemblyLoadStart/Stop events
* Add tests for basic start/stop binder tracing events
* Add property for ALC of requesting assembly

Commit migrated from https://github.com/dotnet/coreclr/commit/05cbf03037737c631534642f699cf7533a9ee250

16 files changed:
src/coreclr/src/System.Private.CoreLib/src/System/Runtime/Loader/AssemblyLoadContext.CoreCLR.cs
src/coreclr/src/binder/activitytracker.cpp
src/coreclr/src/binder/bindertracing.cpp
src/coreclr/src/binder/inc/bindertracing.h
src/coreclr/src/vm/ClrEtwAll.man
src/coreclr/src/vm/appdomain.cpp
src/coreclr/src/vm/assemblynative.cpp
src/coreclr/src/vm/assemblyspec.cpp
src/coreclr/src/vm/baseassemblyspec.h
src/coreclr/src/vm/coreassemblyspec.cpp
src/coreclr/tests/src/Loader/binding/tracing/AssemblyToLoad.cs [new file with mode: 0644]
src/coreclr/tests/src/Loader/binding/tracing/AssemblyToLoad.csproj [new file with mode: 0644]
src/coreclr/tests/src/Loader/binding/tracing/BinderEventListener.cs
src/coreclr/tests/src/Loader/binding/tracing/BinderTracingTest.cs
src/coreclr/tests/src/Loader/binding/tracing/BinderTracingTest.csproj
src/libraries/System.Private.CoreLib/src/System/Diagnostics/Tracing/ActivityTracker.cs

index 6de67a8..dfcbf13 100644 (file)
@@ -206,7 +206,11 @@ namespace System.Runtime.Loader
         /// </summary>
         private static void StartAssemblyLoad(ref Guid activityId, ref Guid relatedActivityId)
         {
-            ActivityTracker.Instance.OnStart(NativeRuntimeEventSource.Log.Name, AssemblyLoadName, 0, ref activityId, ref relatedActivityId, EventActivityOptions.Recursive);
+            // Make sure ActivityTracker is enabled
+            ActivityTracker.Instance.Enable();
+
+            // Don't use trace to TPL event source in ActivityTracker - that event source is a singleton and its instantiation may have triggered the load.
+            ActivityTracker.Instance.OnStart(NativeRuntimeEventSource.Log.Name, AssemblyLoadName, 0, ref activityId, ref relatedActivityId, EventActivityOptions.Recursive, useTplSource: false);
         }
 
         /// <summary>
@@ -214,7 +218,8 @@ namespace System.Runtime.Loader
         /// </summary>
         private static void StopAssemblyLoad(ref Guid activityId)
         {
-            ActivityTracker.Instance.OnStop(NativeRuntimeEventSource.Log.Name, AssemblyLoadName, 0, ref activityId);
+            // Don't use trace to TPL event source in ActivityTracker - that event source is a singleton and its instantiation may have triggered the load.
+            ActivityTracker.Instance.OnStop(NativeRuntimeEventSource.Log.Name, AssemblyLoadName, 0, ref activityId, useTplSource: false);
         }
     }
 }
index ce50872..ea5dbd1 100644 (file)
@@ -19,6 +19,8 @@ void ActivityTracker::Start(/*out*/ GUID *activityId, /*out*/ GUID *relatedActiv
 {
     GCX_COOP();
 
+    OVERRIDE_TYPE_LOAD_LEVEL_LIMIT(CLASS_LOADED);
+
     PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__START_ASSEMBLY_LOAD);
     DECLARE_ARGHOLDER_ARRAY(args, 2);
     args[ARGNUM_0] = PTR_TO_ARGHOLDER(activityId);
@@ -31,6 +33,8 @@ void ActivityTracker::Stop(/*out*/ GUID *activityId)
 {
     GCX_COOP();
 
+    OVERRIDE_TYPE_LOAD_LEVEL_LIMIT(CLASS_LOADED);
+
     PREPARE_NONVIRTUAL_CALLSITE(METHOD__ASSEMBLYLOADCONTEXT__STOP_ASSEMBLY_LOAD);
     DECLARE_ARGHOLDER_ARRAY(args, 1);
     args[ARGNUM_0] = PTR_TO_ARGHOLDER(activityId);
index f89c921..f5523f0 100644 (file)
@@ -41,12 +41,13 @@ namespace
             request.AssemblyPath,
             request.RequestingAssembly,
             request.AssemblyLoadContext,
+            request.RequestingAssemblyLoadContext,
             &activityId,
             &relatedActivityId);
 #endif // FEATURE_EVENT_TRACE
     }
 
-    void FireAssemblyLoadStop(const BinderTracing::AssemblyBindOperation::BindRequest &request, bool success, const WCHAR *resultName, const WCHAR *resultPath, bool cached)
+    void FireAssemblyLoadStop(const BinderTracing::AssemblyBindOperation::BindRequest &request, PEAssembly *resultAssembly, bool cached)
     {
 #ifdef FEATURE_EVENT_TRACE
         if (!EventEnabledAssemblyLoadStop())
@@ -55,12 +56,22 @@ namespace
         GUID activityId = GUID_NULL;
         ActivityTracker::Stop(&activityId);
 
+        SString resultName;
+        SString resultPath;
+        bool success = resultAssembly != nullptr;
+        if (success)
+        {
+            resultPath = resultAssembly->GetPath();
+            resultAssembly->GetDisplayName(resultName);
+        }
+
         FireEtwAssemblyLoadStop(
             GetClrInstanceId(),
             request.AssemblyName,
             request.AssemblyPath,
             request.RequestingAssembly,
             request.AssemblyLoadContext,
+            request.RequestingAssemblyLoadContext,
             success,
             resultName,
             resultPath,
@@ -68,6 +79,93 @@ namespace
             &activityId);
 #endif // FEATURE_EVENT_TRACE
     }
+
+    void GetAssemblyLoadContextNameFromBindContext(ICLRPrivBinder *bindContext, AppDomain *domain, /*out*/ SString &alcName)
+    {
+        _ASSERTE(bindContext != nullptr);
+
+        UINT_PTR binderID = 0;
+        HRESULT hr = bindContext->GetBinderID(&binderID);
+        _ASSERTE(SUCCEEDED(hr));
+        if (FAILED(hr))
+            return;
+
+        ICLRPrivBinder *binder = reinterpret_cast<ICLRPrivBinder *>(binderID);
+#ifdef FEATURE_COMINTEROP
+        if (AreSameBinderInstance(binder, domain->GetTPABinderContext()) || AreSameBinderInstance(binder, domain->GetWinRtBinder()))
+#else
+        if (AreSameBinderInstance(binder, domain->GetTPABinderContext()))
+#endif // FEATURE_COMINTEROP
+        {
+            alcName.Set(W("Default"));
+        }
+        else
+        {
+#ifdef CROSSGEN_COMPILE
+            alcName.Set(W("Custom"));
+#else // CROSSGEN_COMPILE
+            CLRPrivBinderAssemblyLoadContext * alcBinder = static_cast<CLRPrivBinderAssemblyLoadContext *>(binder);
+            OBJECTREF *alc = reinterpret_cast<OBJECTREF *>(alcBinder->GetManagedAssemblyLoadContext());
+
+            GCX_COOP();
+            struct _gc {
+                STRINGREF alcName;
+            } gc;
+            ZeroMemory(&gc, sizeof(gc));
+
+            GCPROTECT_BEGIN(gc);
+
+            PREPARE_VIRTUAL_CALLSITE(METHOD__OBJECT__TO_STRING, *alc);
+            DECLARE_ARGHOLDER_ARRAY(args, 1);
+            args[ARGNUM_0] = OBJECTREF_TO_ARGHOLDER(*alc);
+            CALL_MANAGED_METHOD_RETREF(gc.alcName, STRINGREF, args);
+            gc.alcName->GetSString(alcName);
+
+            GCPROTECT_END();
+#endif // CROSSGEN_COMPILE
+        }
+    }
+
+    void GetAssemblyLoadContextNameFromSpec(AssemblySpec *spec, /*out*/ SString &alcName)
+    {
+        _ASSERTE(spec != nullptr);
+
+        AppDomain *domain = spec->GetAppDomain();
+        ICLRPrivBinder* bindContext = spec->GetBindingContext();
+        if (bindContext == nullptr)
+            bindContext = spec->GetBindingContextFromParentAssembly(domain);
+
+        GetAssemblyLoadContextNameFromBindContext(bindContext, domain, alcName);
+    }
+
+    void PopulateBindRequest(/*inout*/ BinderTracing::AssemblyBindOperation::BindRequest &request)
+    {
+        AssemblySpec *spec = request.AssemblySpec;
+        _ASSERTE(spec != nullptr);
+
+        if (request.AssemblyPath.IsEmpty())
+            request.AssemblyPath = spec->GetCodeBase();
+
+        if (spec->GetName() != nullptr)
+            spec->GetDisplayName(ASM_DISPLAYF_VERSION | ASM_DISPLAYF_CULTURE | ASM_DISPLAYF_PUBLIC_KEY_TOKEN, request.AssemblyName);
+
+        DomainAssembly *parentAssembly = spec->GetParentAssembly();
+        if (parentAssembly != nullptr)
+        {
+            PEAssembly *peAssembly = parentAssembly->GetFile();
+            _ASSERTE(peAssembly != nullptr);
+            peAssembly->GetDisplayName(request.RequestingAssembly);
+
+            AppDomain *domain = parentAssembly->GetAppDomain();
+            ICLRPrivBinder *bindContext = peAssembly->GetBindingContext();
+            if (bindContext == nullptr)
+                bindContext = domain->GetTPABinderContext(); // System.Private.CoreLib returns null
+
+            GetAssemblyLoadContextNameFromBindContext(bindContext, domain, request.RequestingAssemblyLoadContext);
+        }
+
+        GetAssemblyLoadContextNameFromSpec(spec, request.AssemblyLoadContext);
+    }
 }
 
 bool BinderTracing::IsEnabled()
@@ -81,31 +179,55 @@ bool BinderTracing::IsEnabled()
 
 namespace BinderTracing
 {
-    AssemblyBindOperation::AssemblyBindOperation(AssemblySpec *assemblySpec)
-        : m_bindRequest { assemblySpec }
-        , m_success { false }
+    AssemblyBindOperation::AssemblyBindOperation(AssemblySpec *assemblySpec, const WCHAR *assemblyPath)
+        : m_bindRequest { assemblySpec, nullptr, assemblyPath }
+        , m_populatedBindRequest { false }
+        , m_checkedIgnoreBind { false }
+        , m_ignoreBind { false }
+        , m_resultAssembly { nullptr }
         , m_cached { false }
     {
         _ASSERTE(assemblySpec != nullptr);
 
-        // ActivityTracker or EventSource may have triggered the system satellite load.
-        // Don't track system satellite binding to avoid potential infinite recursion.
-        m_trackingBind = BinderTracing::IsEnabled() && !m_bindRequest.AssemblySpec->IsMscorlibSatellite();
-        if (m_trackingBind)
-        {
-            m_bindRequest.AssemblySpec->GetFileOrDisplayName(ASM_DISPLAYF_VERSION | ASM_DISPLAYF_CULTURE | ASM_DISPLAYF_PUBLIC_KEY_TOKEN, m_bindRequest.AssemblyName);
-            FireAssemblyLoadStart(m_bindRequest);
-        }
+        if (!BinderTracing::IsEnabled() || ShouldIgnoreBind())
+            return;
+
+        PopulateBindRequest(m_bindRequest);
+        m_populatedBindRequest = true;
+        FireAssemblyLoadStart(m_bindRequest);
     }
 
     AssemblyBindOperation::~AssemblyBindOperation()
     {
-        if (m_trackingBind)
-            FireAssemblyLoadStop(m_bindRequest, m_success, m_resultName.GetUnicode(), m_resultPath.GetUnicode(), m_cached);
+        if (!BinderTracing::IsEnabled() || ShouldIgnoreBind())
+            return;
+
+        // Make sure the bind request is populated. Tracing may have been enabled mid-bind.
+        if (!m_populatedBindRequest)
+            PopulateBindRequest(m_bindRequest);
+
+        FireAssemblyLoadStop(m_bindRequest, m_resultAssembly, m_cached);
     }
 
-    void AssemblyBindOperation::SetResult(PEAssembly *assembly)
+    void AssemblyBindOperation::SetResult(PEAssembly *assembly, bool cached)
     {
-        m_success = assembly != nullptr;
+        _ASSERTE(m_resultAssembly == nullptr);
+        m_resultAssembly = assembly;
+        if (m_resultAssembly != nullptr)
+            m_resultAssembly->AddRef();
+
+        m_cached = cached;
     }
-}
\ No newline at end of file
+
+    bool AssemblyBindOperation::ShouldIgnoreBind()
+    {
+        if (m_checkedIgnoreBind)
+            return m_ignoreBind;
+
+        // ActivityTracker or EventSource may have triggered the system satellite load.
+        // Don't track system satellite binding to avoid potential infinite recursion.
+        m_ignoreBind = m_bindRequest.AssemblySpec->IsMscorlibSatellite();
+        m_checkedIgnoreBind = true;
+        return m_ignoreBind;
+    }
+}
index 068d881..a259f37 100644 (file)
@@ -21,10 +21,10 @@ namespace BinderTracing
     {
     public:
         // This class assumes the assembly spec will have a longer lifetime than itself
-        AssemblyBindOperation(AssemblySpec *assemblySpec);
+        AssemblyBindOperation(AssemblySpec *assemblySpec, const WCHAR *assemblyPath = nullptr);
         ~AssemblyBindOperation();
 
-        void SetResult(PEAssembly *assembly);
+        void SetResult(PEAssembly *assembly, bool cached = false);
 
         struct BindRequest
         {
@@ -33,16 +33,20 @@ namespace BinderTracing
             SString AssemblyPath;
             SString RequestingAssembly;
             SString AssemblyLoadContext;
+            SString RequestingAssemblyLoadContext;
         };
 
     private:
+        bool ShouldIgnoreBind();
+
+    private:
         BindRequest m_bindRequest;
+        bool m_populatedBindRequest;
 
-        bool m_trackingBind;
+        bool m_checkedIgnoreBind;
+        bool m_ignoreBind;
 
-        bool m_success;
-        SString m_resultName;
-        SString m_resultPath;
+        ReleaseHolder<PEAssembly> m_resultAssembly;
         bool m_cached;
     };
 };
index 1e82a1b..286575d 100644 (file)
                         <data name="AssemblyPath" inType="win:UnicodeString" />
                         <data name="RequestingAssembly" inType="win:UnicodeString" />
                         <data name="AssemblyLoadContext" inType="win:UnicodeString" />
+                        <data name="RequestingAssemblyLoadContext" inType="win:UnicodeString" />
                         <UserData>
                             <AssemblyLoadStart xmlns="myNs">
                                 <ClrInstanceID> %1 </ClrInstanceID>
                                 <AssemblyPath> %3 </AssemblyPath>
                                 <RequestingAssembly> %4 </RequestingAssembly>
                                 <AssemblyLoadContext> %5 </AssemblyLoadContext>
+                                <RequestingAssemblyLoadContext> %6 </RequestingAssemblyLoadContext>
                             </AssemblyLoadStart>
                         </UserData>
                     </template>
                         <data name="AssemblyPath" inType="win:UnicodeString" />
                         <data name="RequestingAssembly" inType="win:UnicodeString" />
                         <data name="AssemblyLoadContext" inType="win:UnicodeString" />
+                        <data name="RequestingAssemblyLoadContext" inType="win:UnicodeString" />
                         <data name="Success" inType="win:Boolean" />
                         <data name="ResultAssemblyName" inType="win:UnicodeString" />
                         <data name="ResultAssemblyPath" inType="win:UnicodeString" />
                                 <AssemblyPath> %3 </AssemblyPath>
                                 <RequestingAssembly> %4 </RequestingAssembly>
                                 <AssemblyLoadContext> %5 </AssemblyLoadContext>
-                                <Success> %6 </Success>
-                                <ResultAssemblyName> %7 </ResultAssemblyName>
-                                <ResultAssemblyPath> %8 </ResultAssemblyPath>
-                                <Cached> %9 </Cached>
+                                <RequestingAssemblyLoadContext> %6 </RequestingAssemblyLoadContext>
+                                <Success> %7 </Success>
+                                <ResultAssemblyName> %8 </ResultAssemblyName>
+                                <ResultAssemblyPath> %9 </ResultAssemblyPath>
+                                <Cached> %10 </Cached>
                             </AssemblyLoadStop>
                         </UserData>
                     </template>
                 <string id="RuntimePublisher.AppDomainLoad_V1EventMessage" value="AppDomainID=%1;%nAppDomainFlags=%2;%nAppDomainName=%3;%nAppDomainIndex=%4;%nClrInstanceID=%5" />
                 <string id="RuntimePublisher.AppDomainUnloadEventMessage" value="AppDomainID=%1;%nAppDomainFlags=%2;%nAppDomainName=%3" />
                 <string id="RuntimePublisher.AppDomainUnload_V1EventMessage" value="AppDomainID=%1;%nAppDomainFlags=%2;%nAppDomainName=%3;%nAppDomainIndex=%4;%nClrInstanceID=%5" />
-                <string id="RuntimePublisher.AssemblyLoadStartEventMessage" value="ClrInstanceID=%1;%nAssemblyName=%2;%nAssemblyPath=%3;%nRequestingAssembly=%4;%nAssemblyLoadContext=%5" />
-                <string id="RuntimePublisher.AssemblyLoadStopEventMessage" value="ClrInstanceID=%1;%nAssemblyName=%2;%nAssemblyPath=%3;%nRequestingAssembly=%4;%nAssemblyLoadContext=%5;%nSuccess=%6;%nResultAssemblyName=%7;%nResultAssemblyPath=%8;%nCached=%9" />
+                <string id="RuntimePublisher.AssemblyLoadStartEventMessage" value="ClrInstanceID=%1;%nAssemblyName=%2;%nAssemblyPath=%3;%nRequestingAssembly=%4;%nAssemblyLoadContext=%5;%nRequestingAssemblyLoadContext=%6" />
+                <string id="RuntimePublisher.AssemblyLoadStopEventMessage" value="ClrInstanceID=%1;%nAssemblyName=%2;%nAssemblyPath=%3;%nRequestingAssembly=%4;%nAssemblyLoadContext=%5;%nRequestingAssemblyLoadContext=%6;%nSuccess=%7;%nResultAssemblyName=%8;%nResultAssemblyPath=%9;%nCached=%10" />
                 <string id="RuntimePublisher.StackEventMessage" value="ClrInstanceID=%1;%nReserved1=%2;%nReserved2=%3;%nFrameCount=%4;%nStack=%5" />
                 <string id="RuntimePublisher.AppDomainMemAllocatedEventMessage" value="AppDomainID=%1;%nAllocated=%2;%nClrInstanceID=%3" />
                 <string id="RuntimePublisher.AppDomainMemSurvivedEventMessage" value="AppDomainID=%1;%nSurvived=%2;%nProcessSurvived=%3;%nClrInstanceID=%4" />
index 81d7d89..9c85565 100644 (file)
@@ -4983,10 +4983,12 @@ EndTry2:;
     {
         HRESULT hrBindResult = S_OK;
         PEAssemblyHolder result;
-
+        
+        bool isCached = false;
         EX_TRY
         {
-            if (!IsCached(pSpec))
+            isCached = IsCached(pSpec);
+            if (!isCached)
             {
 
                 {
@@ -5147,7 +5149,7 @@ EndTry2:;
                 result->AddRef();
         }
 
-        bindOperation.SetResult(result.GetValue());
+        bindOperation.SetResult(result.GetValue(), isCached);
         return result.Extract();
     }
     else
index 7672fd0..341de82 100644 (file)
@@ -189,7 +189,7 @@ Assembly* AssemblyNative::LoadFromPEImage(ICLRPrivBinder* pBinderContext, PEImag
     spec.InitializeSpec(TokenFromRid(1, mdtAssembly), pImage->GetMDImport(), pCallersAssembly);
     spec.SetBindingContext(pBinderContext);
 
-    BinderTracing::AssemblyBindOperation bindOperation(&spec);
+    BinderTracing::AssemblyBindOperation bindOperation(&spec, pImage->GetPath());
 
     HRESULT hr = S_OK;
     PTR_AppDomain pCurDomain = GetAppDomain();
index 216f4fa..92ddcaf 100644 (file)
@@ -29,6 +29,8 @@
 #include "winrthelpers.h"
 #endif
 
+#include "../binder/inc/bindertracing.h"
+
 #ifdef _DEBUG
 // This debug-only wrapper for LookupAssembly is solely for the use of postconditions and
 // assertions. The problem is that the real LookupAssembly can throw an OOM
@@ -917,6 +919,9 @@ DomainAssembly *AssemblySpec::LoadDomainAssembly(FileLoadLevel targetLevel,
 
     if (pAssembly)
     {
+        BinderTracing::AssemblyBindOperation bindOperation(this);
+        bindOperation.SetResult(pAssembly->GetFile(), true /*cached*/);
+
         pDomain->LoadDomainFile(pAssembly, targetLevel);
         RETURN pAssembly;
     }
index 7cda267..69e615e 100644 (file)
@@ -175,6 +175,7 @@ public:
     static BOOL VerifyBindingString(LPCWSTR pwStr);
 
     void GetFileOrDisplayName(DWORD flags, SString &result) const;
+    void GetDisplayName(DWORD flags, SString &result) const;
 
     inline void GetPublicKey(
         PBYTE * ppbPublicKey,
@@ -214,9 +215,11 @@ public:
         return IsAfRetargetable(m_dwFlags);
     }
 
-
 protected:
     static BOOL CompareRefToDef(const BaseAssemblySpec *pRef, const BaseAssemblySpec *pDef);
+
+private:
+    void GetDisplayNameInternal(DWORD flags, SString &result) const;
 };
 
 #endif // __BASE_ASSEMBLY_SPEC_H__
index a88b0b6..d69c5ad 100644 (file)
@@ -102,7 +102,7 @@ VOID  AssemblySpec::Bind(AppDomain      *pAppDomain,
 
     if (m_wszCodeBase == NULL)
     {
-        GetFileOrDisplayName(0, assemblyDisplayName);
+        GetDisplayName(0, assemblyDisplayName);
     }
 
     // Have a default binding context setup
@@ -461,6 +461,26 @@ VOID BaseAssemblySpec::GetFileOrDisplayName(DWORD flags, SString &result) const
         return;
     }
 
+    GetDisplayNameInternal(flags, result);
+}
+
+VOID BaseAssemblySpec::GetDisplayName(DWORD flags, SString &result) const
+{
+    CONTRACTL
+    {
+        INSTANCE_CHECK;
+        THROWS;
+        INJECT_FAULT(ThrowOutOfMemory());
+        PRECONDITION(CheckValue(result));
+        PRECONDITION(result.IsEmpty());
+    }
+    CONTRACTL_END;
+
+    GetDisplayNameInternal(flags, result);
+}
+
+VOID BaseAssemblySpec::GetDisplayNameInternal(DWORD flags, SString &result) const
+{
     if (flags==0)
         flags=ASM_DISPLAYF_FULL;
 
diff --git a/src/coreclr/tests/src/Loader/binding/tracing/AssemblyToLoad.cs b/src/coreclr/tests/src/Loader/binding/tracing/AssemblyToLoad.cs
new file mode 100644 (file)
index 0000000..6d26ac7
--- /dev/null
@@ -0,0 +1,9 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace AssemblyToLoad
+{
+    public class Program
+    { }
+}
diff --git a/src/coreclr/tests/src/Loader/binding/tracing/AssemblyToLoad.csproj b/src/coreclr/tests/src/Loader/binding/tracing/AssemblyToLoad.csproj
new file mode 100644 (file)
index 0000000..bf4c48e
--- /dev/null
@@ -0,0 +1,9 @@
+<Project Sdk="Microsoft.NET.Sdk">
+  <PropertyGroup>
+    <OutputType>Library</OutputType>
+    <CLRTestKind>BuildOnly</CLRTestKind>
+  </PropertyGroup>
+  <ItemGroup>
+    <Compile Include="AssemblyToLoad.cs" />
+  </ItemGroup>
+</Project>
index e221a1e..5a8e173 100644 (file)
@@ -8,16 +8,23 @@ using System.Diagnostics.Tracing;
 using System.Linq;
 using System.Threading;
 using System.Reflection;
-using Xunit;
 
-using Assert = Xunit.Assert;
+using TestLibrary;
 
 namespace BinderTracingTests
 {
     internal class BindOperation
     {
         internal AssemblyName AssemblyName;
+        internal string AssemblyPath;
+        internal AssemblyName RequestingAssembly;
+        internal string AssemblyLoadContext;
+        internal string RequestingAssemblyLoadContext;
+
         internal bool Success;
+        internal AssemblyName ResultAssemblyName;
+        internal string ResultAssemblyPath;
+        internal bool Cached;
 
         internal Guid ActivityId;
         internal Guid ParentActivityId;
@@ -81,23 +88,39 @@ namespace BinderTracingTests
                 case "AssemblyLoadStart":
                     lock (eventsLock)
                     {
-                        Assert.True(!bindOperations.ContainsKey(data.ActivityId), "AssemblyLoadStart should not exist for same activity ID ");
+                        Assert.IsTrue(!bindOperations.ContainsKey(data.ActivityId), "AssemblyLoadStart should not exist for same activity ID ");
                         var bindOperation = new BindOperation()
                         {
                             AssemblyName = new AssemblyName(GetDataString("AssemblyName")),
+                            AssemblyPath = GetDataString("AssemblyPath"),
+                            AssemblyLoadContext = GetDataString("AssemblyLoadContext"),
+                            RequestingAssemblyLoadContext = GetDataString("RequestingAssemblyLoadContext"),
                             ActivityId = data.ActivityId,
                             ParentActivityId = data.RelatedActivityId,
                             Nested = bindOperations.ContainsKey(data.RelatedActivityId)
                         };
+                        string requestingAssembly = GetDataString("RequestingAssembly");
+                        if (!string.IsNullOrEmpty(requestingAssembly))
+                        {
+                            bindOperation.RequestingAssembly = new AssemblyName(requestingAssembly);
+                        }
                         bindOperations.Add(data.ActivityId, bindOperation);
                     }
                     break;
                 case "AssemblyLoadStop":
                     lock (eventsLock)
                     {
-                        Assert.True(bindOperations.ContainsKey(data.ActivityId), "AssemblyLoadStop should have a matching AssemblyLoadStart");
-                        bindOperations[data.ActivityId].Success = (bool)GetData("Success");
-                        bindOperations[data.ActivityId].Completed = true;
+                        Assert.IsTrue(bindOperations.ContainsKey(data.ActivityId), "AssemblyLoadStop should have a matching AssemblyBindStart");
+                        BindOperation bind = bindOperations[data.ActivityId];
+                        bind.Success = (bool)GetData("Success");
+                        string resultName = GetDataString("ResultAssemblyName");
+                        if (!string.IsNullOrEmpty(resultName))
+                        {
+                            bind.ResultAssemblyName = new AssemblyName(resultName);
+                        }
+                        bind.ResultAssemblyPath = GetDataString("ResultAssemblyPath");
+                        bind.Cached = (bool)GetData("Cached");
+                        bind.Completed = true;
                     }
                     break;
             }
index a450cd0..9aee2c4 100644 (file)
 // See the LICENSE file in the project root for more information.
 
 using System;
+using System.Diagnostics;
+using System.IO;
 using System.Linq;
 using System.Reflection;
 using System.Runtime.Loader;
 
-using Assert = Xunit.Assert;
+using TestLibrary;
 
 namespace BinderTracingTests
 {
+    [AttributeUsage(System.AttributeTargets.Method)]
+    class BinderTestAttribute : Attribute
+    {
+        public bool Isolate { get; private set; }
+        public string TestSetup { get; private set; }
+        public BinderTestAttribute(bool isolate = false, string testSetup = null)
+        {
+            Isolate = isolate;
+            TestSetup = testSetup;
+        }
+    }
+
     class BinderTracingTest
     {
-        public static void PlatformAssembly_DefaultALC()
+        public class CustomALC : AssemblyLoadContext
+        {
+            public CustomALC(string name) : base(name)
+            { }
+        }
+
+        private const string DefaultALC = "Default";
+        private const string DependentAssemblyName = "AssemblyToLoad";
+
+        [BinderTest]
+        public static BindOperation LoadFile()
+        {
+            var executingAssembly = Assembly.GetExecutingAssembly();
+            Assembly asm = Assembly.LoadFile(executingAssembly.Location);
+
+            return new BindOperation()
+            {
+                AssemblyName = executingAssembly.GetName(),
+                AssemblyPath = executingAssembly.Location,
+                AssemblyLoadContext = AssemblyLoadContext.GetLoadContext(asm).ToString(),
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest]
+        public static BindOperation LoadBytes()
+        {
+            var executingAssembly = Assembly.GetExecutingAssembly();
+            Byte[] bytes = File.ReadAllBytes(executingAssembly.Location);
+            Assembly asm = Assembly.Load(bytes);
+
+            return new BindOperation()
+            {
+                AssemblyName = executingAssembly.GetName(),
+                AssemblyLoadContext = AssemblyLoadContext.GetLoadContext(asm).ToString(),
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest]
+        public static BindOperation LoadFromStream()
+        {
+            var executingAssembly = Assembly.GetExecutingAssembly();
+            Stream stream = File.OpenRead(executingAssembly.Location);
+            CustomALC alc = new CustomALC(nameof(LoadFromStream));
+            Assembly asm = alc.LoadFromStream(stream);
+
+            return new BindOperation()
+            {
+                AssemblyName = executingAssembly.GetName(),
+                AssemblyLoadContext = alc.ToString(),
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest]
+        public static BindOperation LoadFromAssemblyPath()
+        {
+            CustomALC alc = new CustomALC(nameof(LoadFromAssemblyPath));
+            var executingAssembly = Assembly.GetExecutingAssembly();
+            Assembly asm = alc.LoadFromAssemblyPath(executingAssembly.Location);
+
+            return new BindOperation()
+            {
+                AssemblyName = executingAssembly.GetName(),
+                AssemblyPath = executingAssembly.Location,
+                AssemblyLoadContext = alc.ToString(),
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation LoadFromAssemblyName()
         {
-            Console.WriteLine($"Running {nameof(PlatformAssembly_DefaultALC)}...");
-            using (var listener = new BinderEventListener())
+            AssemblyName assemblyName = new AssemblyName("System.Xml");
+            CustomALC alc = new CustomALC(nameof(LoadFromAssemblyName));
+            Assembly asm = alc.LoadFromAssemblyName(assemblyName);
+
+            return new BindOperation()
+            {
+                AssemblyName = assemblyName,
+                AssemblyLoadContext = alc.ToString(),
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation LoadFrom()
+        {
+            var executingAssembly = Assembly.GetExecutingAssembly();
+            Assembly asm = Assembly.LoadFrom(executingAssembly.Location);
+
+            return new BindOperation()
+            {
+                AssemblyName = executingAssembly.GetName(),
+                AssemblyPath = executingAssembly.Location,
+                AssemblyLoadContext = DefaultALC,
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation PlatformAssembly()
+        {
+            string assemblyName = "System.Xml";
+            Assembly asm = Assembly.Load(assemblyName);
+
+            return new BindOperation()
+            {
+                AssemblyName = new AssemblyName(assemblyName),
+                AssemblyLoadContext = DefaultALC,
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        [BinderTest(isolate: true, testSetup: nameof(PlatformAssembly))]
+        public static BindOperation PlatformAssembly_Cached()
+        {
+            BindOperation bind = PlatformAssembly();
+            bind.Cached = true;
+            return bind;
+        }
+
+        [BinderTest]
+        public static BindOperation NonExistentAssembly()
+        {
+            string assemblyName = "DoesNotExist";
+            try
+            {
+                Assembly.Load(assemblyName);
+            }
+            catch { }
+
+            return new BindOperation()
+            {
+                AssemblyName = new AssemblyName(assemblyName),
+                AssemblyLoadContext = DefaultALC,
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = false,
+                Cached = false
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation Reflection()
+        {
+            var t = GetDependentAssemblyType();
+
+            return new BindOperation()
+            {
+                AssemblyName = new AssemblyName(DependentAssemblyName),
+                AssemblyLoadContext = DefaultALC,
+                RequestingAssemblyLoadContext = DefaultALC,
+                RequestingAssembly = Assembly.GetExecutingAssembly().GetName(),
+                Success = true,
+                ResultAssemblyName = t.Assembly.GetName(),
+                ResultAssemblyPath = t.Assembly.Location,
+                Cached = false,
+            };
+        }
+
+        [BinderTest(isolate: true, testSetup: nameof(Reflection))]
+        public static BindOperation Reflection_Cached()
+        {
+            BindOperation bind = Reflection();
+            bind.Cached = true;
+            return bind;
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation Reflection_CustomALC()
+        {
+            CustomALC alc = new CustomALC(nameof(Reflection_CustomALC));
+            Type testClass = LoadTestClassInALC(alc);
+            MethodInfo method = testClass.GetMethod(nameof(GetDependentAssemblyType), BindingFlags.NonPublic | BindingFlags.Static);
+            Type t = (Type)method.Invoke(null, new object[0]);
+
+            return new BindOperation()
+            {
+                AssemblyName = new AssemblyName(DependentAssemblyName),
+                AssemblyLoadContext = alc.ToString(),
+                RequestingAssembly = testClass.Assembly.GetName(),
+                RequestingAssemblyLoadContext = alc.ToString(),
+                Success = true,
+                ResultAssemblyName = t.Assembly.GetName(),
+                ResultAssemblyPath = t.Assembly.Location,
+                Cached = false,
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation ContextualReflection_DefaultToCustomALC()
+        {
+            Type t;
+            CustomALC alc = new CustomALC(nameof(ContextualReflection_DefaultToCustomALC));
+            using (alc.EnterContextualReflection())
+            {
+                t = GetDependentAssemblyType();
+            }
+
+            return new BindOperation()
             {
-                string assemblyName = "System.Xml";
-                Assembly asm = Assembly.Load(assemblyName);
+                AssemblyName = new AssemblyName(DependentAssemblyName),
+                AssemblyLoadContext = alc.ToString(),
+                RequestingAssembly = Assembly.GetExecutingAssembly().GetName(),
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = t.Assembly.GetName(),
+                ResultAssemblyPath = t.Assembly.Location,
+                Cached = false,
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation ContextualReflection_CustomToDefaultALC()
+        {
+            CustomALC alc = new CustomALC(nameof(ContextualReflection_CustomToDefaultALC));
+            Type testClass = LoadTestClassInALC(alc);
+            MethodInfo method = testClass.GetMethod(nameof(GetDependentAssemblyType), BindingFlags.NonPublic | BindingFlags.Static);
 
-                BindOperation[] binds = listener.WaitAndGetEventsForAssembly(assemblyName);
-                Assert.True(binds.Length == 1, $"Bind count for {assemblyName} - expected: 1, actual: {binds.Length}");
-                BindOperation bind = binds[0];
-                Assert.True(bind.Success, $"Expected bind for {assemblyName} to succeed");
+            Type t;
+            using (AssemblyLoadContext.Default.EnterContextualReflection())
+            {
+                t = (Type)method.Invoke(null, new object[0]);
             }
+
+            return new BindOperation()
+            {
+                AssemblyName = new AssemblyName(DependentAssemblyName),
+                AssemblyLoadContext = DefaultALC,
+                RequestingAssembly = testClass.Assembly.GetName(),
+                RequestingAssemblyLoadContext = alc.ToString(),
+                Success = true,
+                ResultAssemblyName = t.Assembly.GetName(),
+                ResultAssemblyPath = t.Assembly.Location,
+                Cached = false,
+            };
+        }
+
+        [BinderTest(isolate: true)]
+        public static BindOperation JITLoad()
+        {
+            Assembly asm = UseDependentAssembly();
+
+            return new BindOperation()
+            {
+                AssemblyName = asm.GetName(),
+                AssemblyLoadContext = DefaultALC,
+                RequestingAssembly = Assembly.GetExecutingAssembly().GetName(),
+                RequestingAssemblyLoadContext = DefaultALC,
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false,
+            };
         }
 
-        public static void NonExistentAssembly_DefaultALC()
+        [BinderTest(isolate: true)]
+        public static BindOperation JITLoad_CustomALC()
         {
-            Console.WriteLine($"Running {nameof(NonExistentAssembly_DefaultALC)}...");
-            using (var listener = new BinderEventListener())
+            CustomALC alc = new CustomALC(nameof(JITLoad_CustomALC));
+            Type testClass= LoadTestClassInALC(alc);
+            MethodInfo method = testClass.GetMethod(nameof(UseDependentAssembly), BindingFlags.NonPublic | BindingFlags.Static);
+            Assembly asm = (Assembly)method.Invoke(null, new object[0]);
+
+            return new BindOperation()
             {
-                string assemblyName = "DoesNotExist";
-                try
+                AssemblyName = asm.GetName(),
+                AssemblyLoadContext = alc.ToString(),
+                RequestingAssembly = testClass.Assembly.GetName(),
+                RequestingAssemblyLoadContext = alc.ToString(),
+                Success = true,
+                ResultAssemblyName = asm.GetName(),
+                ResultAssemblyPath = asm.Location,
+                Cached = false
+            };
+        }
+
+        public static bool RunAllTests()
+        {
+            MethodInfo[] methods = typeof(BinderTracingTest)
+                .GetMethods(BindingFlags.Public | BindingFlags.Static)
+                .Where(m => m.GetCustomAttribute<BinderTestAttribute>() != null && m.ReturnType == typeof(BindOperation))
+                .ToArray();
+
+            foreach (var method in methods)
+            {
+                BinderTestAttribute attribute = method.GetCustomAttribute<BinderTestAttribute>();
+                bool success = attribute.Isolate
+                    ? RunTestInSeparateProcess(method)
+                    : RunSingleTest(method);
+                if (!success)
                 {
-                    Assembly.Load(assemblyName);
+                    return false;
                 }
-                catch { }
-
-                BindOperation[] binds = listener.WaitAndGetEventsForAssembly(assemblyName);
-                Assert.True(binds.Length == 1, $"Bind event count for {assemblyName} - expected: 1, actual: {binds.Length}");
-                BindOperation bind = binds[0];
-                Assert.False(bind.Success, $"Expected bind for {assemblyName} to fail");
             }
+
+            return true;
         }
 
-        public static int Main(string[] unused)
+        public static int Main(string[] args)
         {
+            bool success;
             try
             {
-                PlatformAssembly_DefaultALC();
-                NonExistentAssembly_DefaultALC();
+                if (args.Length == 0)
+                {
+                    success = RunAllTests();
+                }
+                else
+                {
+                    // Run specific test - first argument should be the test method name
+                    MethodInfo method = typeof(BinderTracingTest)
+                        .GetMethod(args[0], BindingFlags.Public | BindingFlags.Static);
+                    Assert.IsTrue(method != null && method.GetCustomAttribute<BinderTestAttribute>() != null && method.ReturnType == typeof(BindOperation));
+                    success = RunSingleTest(method);
+                }
             }
             catch (Exception e)
             {
@@ -60,7 +378,114 @@ namespace BinderTracingTests
                 return 101;
             }
 
-            return 100;
+            return success ? 100 : 101;
+        }
+
+        private static Assembly UseDependentAssembly()
+        {
+            var p = new AssemblyToLoad.Program();
+            return Assembly.GetAssembly(p.GetType());
+        }
+
+        private static Type GetDependentAssemblyType()
+        {
+            return Type.GetType($"AssemblyToLoad.Program, {DependentAssemblyName}");
+        }
+
+        private static Type LoadTestClassInALC(AssemblyLoadContext alc)
+        {
+            Assembly asm = alc.LoadFromAssemblyPath(Assembly.GetExecutingAssembly().Location);
+            return asm.GetType(typeof(BinderTracingTest).FullName);
+        }
+
+        private static bool RunSingleTest(MethodInfo method)
+        {
+            Console.WriteLine($"Running {method.Name}...");
+            try
+            {
+                BinderTestAttribute attribute = method.GetCustomAttribute<BinderTestAttribute>();
+                if (!string.IsNullOrEmpty(attribute.TestSetup))
+                {
+                    MethodInfo setupMethod = method.DeclaringType
+                        .GetMethod(attribute.TestSetup, BindingFlags.Public | BindingFlags.Static);
+                    Assert.IsTrue(setupMethod != null);
+                    setupMethod.Invoke(null, new object[0]);
+                }
+
+                Func<BindOperation> func = (Func<BindOperation>)method.CreateDelegate(typeof(Func<BindOperation>));
+                using (var listener = new BinderEventListener())
+                {
+                    BindOperation expected = func();
+                    ValidateSingleBind(listener, expected.AssemblyName.Name, expected);
+                }
+            }
+            catch (Exception e)
+            {
+                Console.WriteLine($"Test {method.Name} failed: {e}");
+                return false;
+            }
+
+            return true;
+        }
+
+        private static bool RunTestInSeparateProcess(MethodInfo method)
+        {
+            var startInfo = new ProcessStartInfo()
+            {
+                FileName = Process.GetCurrentProcess().MainModule.FileName,
+                Arguments = $"{Assembly.GetExecutingAssembly().Location} {method.Name}",
+                UseShellExecute = false,
+                RedirectStandardOutput = true,
+                RedirectStandardError = true
+            };
+
+            Console.WriteLine($"Launching process for {method.Name}...");
+            using (Process p = Process.Start(startInfo))
+            {
+                p.OutputDataReceived += (_, args) => Console.WriteLine(args.Data);
+                p.BeginOutputReadLine();
+
+                p.ErrorDataReceived += (_, args) => Console.Error.WriteLine(args.Data);
+                p.BeginErrorReadLine();
+
+                p.WaitForExit();
+                return p.ExitCode == 100;
+            }
+        }
+
+        private static void ValidateSingleBind(BinderEventListener listener, string assemblyName, BindOperation expected)
+        {
+            BindOperation[] binds = listener.WaitAndGetEventsForAssembly(assemblyName);
+            Assert.IsTrue(binds.Length == 1, $"Bind event count for {assemblyName} - expected: 1, actual: {binds.Length}");
+            BindOperation actual = binds[0];
+
+            ValidateAssemblyName(expected.AssemblyName, actual.AssemblyName, nameof(BindOperation.AssemblyName));
+            Assert.AreEqual(expected.AssemblyPath ?? string.Empty, actual.AssemblyPath, $"Unexpected value for {nameof(BindOperation.AssemblyPath)} on event");
+            Assert.AreEqual(expected.AssemblyLoadContext, actual.AssemblyLoadContext, $"Unexpected value for {nameof(BindOperation.AssemblyLoadContext)} on event");
+            ValidateAssemblyName(expected.RequestingAssembly, actual.RequestingAssembly, nameof(BindOperation.RequestingAssembly));
+            Assert.AreEqual(expected.RequestingAssemblyLoadContext ?? string.Empty, actual.RequestingAssemblyLoadContext, $"Unexpected value for {nameof(BindOperation.RequestingAssemblyLoadContext)} on event");
+
+            Assert.AreEqual(expected.Success, actual.Success, $"Unexpected value for {nameof(BindOperation.Success)} on event");
+            Assert.AreEqual(expected.ResultAssemblyPath ?? string.Empty, actual.ResultAssemblyPath, $"Unexpected value for {nameof(BindOperation.ResultAssemblyPath)} on event");
+            Assert.AreEqual(expected.Cached, actual.Cached, $"Unexpected value for {nameof(BindOperation.Cached)} on event");
+            ValidateAssemblyName(expected.ResultAssemblyName, actual.ResultAssemblyName, nameof(BindOperation.ResultAssemblyName));
+        }
+
+        private static void ValidateAssemblyName(AssemblyName expected, AssemblyName actual, string propertyName)
+        {
+            if (expected == null)
+            {
+                return;
+            }
+
+            if (expected.Version != null)
+            {
+                Assert.AreEqual(expected.FullName, actual.FullName, $"Unexpected value for {propertyName} on event");
+            }
+            else
+            {
+                Assert.AreEqual(expected.Name, actual.Name, $"Unexpected value for {propertyName} on event");
+            }
         }
     }
 }
index 286fec5..f581bf1 100644 (file)
@@ -8,5 +8,6 @@
   </ItemGroup>
   <ItemGroup>
     <ProjectReference Include="$(SourceDir)Common/CoreCLRTestLibrary/CoreCLRTestLibrary.csproj" />
+    <ProjectReference Include="AssemblyToLoad.csproj" />
   </ItemGroup>
 </Project>
index 8ee2cc1..07c2112 100644 (file)
@@ -52,7 +52,7 @@ namespace System.Diagnostics.Tracing
         ///
         /// If activity tracing is not on, then activityId and relatedActivityId are not set
         /// </summary>
-        public void OnStart(string providerName, string activityName, int task, ref Guid activityId, ref Guid relatedActivityId, EventActivityOptions options)
+        public void OnStart(string providerName, string activityName, int task, ref Guid activityId, ref Guid relatedActivityId, EventActivityOptions options, bool useTplSource = true)
         {
             if (m_current == null)        // We are not enabled
             {
@@ -62,7 +62,7 @@ namespace System.Diagnostics.Tracing
                 if (m_checkedForEnable)
                     return;
                 m_checkedForEnable = true;
-                if (TplEventSource.Log.IsEnabled(EventLevel.Informational, TplEventSource.Keywords.TasksFlowActivityIds))
+                if (useTplSource && TplEventSource.Log.IsEnabled(EventLevel.Informational, TplEventSource.Keywords.TasksFlowActivityIds))
                     Enable();
                 if (m_current == null)
                     return;
@@ -73,11 +73,12 @@ namespace System.Diagnostics.Tracing
             ActivityInfo? currentActivity = m_current.Value;
             string fullActivityName = NormalizeActivityName(providerName, activityName, task);
 
-            TplEventSource log = TplEventSource.Log;
-            if (log.Debug)
+            TplEventSource? log = useTplSource ? TplEventSource.Log : null;
+            bool tplDebug = log != null && log.Debug;
+            if (tplDebug)
             {
-                log.DebugFacilityMessage("OnStartEnter", fullActivityName);
-                log.DebugFacilityMessage("OnStartEnterActivityState", ActivityInfo.LiveActivities(currentActivity));
+                log!.DebugFacilityMessage("OnStartEnter", fullActivityName);
+                log!.DebugFacilityMessage("OnStartEnterActivityState", ActivityInfo.LiveActivities(currentActivity));
             }
 
             if (currentActivity != null)
@@ -87,8 +88,8 @@ namespace System.Diagnostics.Tracing
                 {
                     activityId = Guid.Empty;
                     relatedActivityId = Guid.Empty;
-                    if (log.Debug)
-                        log.DebugFacilityMessage("OnStartRET", "Fail");
+                    if (tplDebug)
+                        log!.DebugFacilityMessage("OnStartRET", "Fail");
                     return;
                 }
                 // Check for recursion, and force-stop any activities if the activity already started.
@@ -120,10 +121,10 @@ namespace System.Diagnostics.Tracing
             // Remember the current ID so we can log it
             activityId = newActivity.ActivityId;
 
-            if (log.Debug)
+            if (tplDebug)
             {
-                log.DebugFacilityMessage("OnStartRetActivityState", ActivityInfo.LiveActivities(newActivity));
-                log.DebugFacilityMessage1("OnStartRet", activityId.ToString(), relatedActivityId.ToString());
+                log!.DebugFacilityMessage("OnStartRetActivityState", ActivityInfo.LiveActivities(newActivity));
+                log!.DebugFacilityMessage1("OnStartRet", activityId.ToString(), relatedActivityId.ToString());
             }
         }
 
@@ -133,18 +134,19 @@ namespace System.Diagnostics.Tracing
         ///
         /// If activity tracing is not on, then activityId and relatedActivityId are not set
         /// </summary>
-        public void OnStop(string providerName, string activityName, int task, ref Guid activityId)
+        public void OnStop(string providerName, string activityName, int task, ref Guid activityId, bool useTplSource = true)
         {
             if (m_current == null)        // We are not enabled
                 return;
 
             string fullActivityName = NormalizeActivityName(providerName, activityName, task);
 
-            TplEventSource log = TplEventSource.Log;
-            if (log.Debug)
+            TplEventSource? log = useTplSource ? TplEventSource.Log : null;
+            bool tplDebug = log != null && log.Debug;
+            if (tplDebug)
             {
-                log.DebugFacilityMessage("OnStopEnter", fullActivityName);
-                log.DebugFacilityMessage("OnStopEnterActivityState", ActivityInfo.LiveActivities(m_current.Value));
+                log!.DebugFacilityMessage("OnStopEnter", fullActivityName);
+                log!.DebugFacilityMessage("OnStopEnterActivityState", ActivityInfo.LiveActivities(m_current.Value));
             }
 
             while (true) // This is a retry loop.
@@ -162,8 +164,8 @@ namespace System.Diagnostics.Tracing
                 {
                     activityId = Guid.Empty;
                     // TODO add some logging about this. Basically could not find matching start.
-                    if (log.Debug)
-                        log.DebugFacilityMessage("OnStopRET", "Fail");
+                    if (tplDebug)
+                        log!.DebugFacilityMessage("OnStopRET", "Fail");
                     return;
                 }
 
@@ -201,10 +203,10 @@ namespace System.Diagnostics.Tracing
 
                     m_current.Value = newCurrentActivity;
 
-                    if (log.Debug)
+                    if (tplDebug)
                     {
-                        log.DebugFacilityMessage("OnStopRetActivityState", ActivityInfo.LiveActivities(newCurrentActivity));
-                        log.DebugFacilityMessage("OnStopRet", activityId.ToString());
+                        log!.DebugFacilityMessage("OnStopRetActivityState", ActivityInfo.LiveActivities(newCurrentActivity));
+                        log!.DebugFacilityMessage("OnStopRet", activityId.ToString());
                     }
                     return;
                 }