Little refactoring based on PR feedback
authorJan Vorlicek <janvorli@microsoft.com>
Sat, 4 Apr 2020 00:22:30 +0000 (17:22 -0700)
committerJan Vorlicek <janvorli@microsoft.com>
Sat, 4 Apr 2020 00:22:30 +0000 (17:22 -0700)
src/coreclr/src/vm/threads.h
src/coreclr/src/vm/threadsuspend.cpp

index 5d8b068..5b082e0 100644 (file)
@@ -3412,7 +3412,7 @@ public:
 
 private:
 #ifdef FEATURE_HIJACK
-    void    HijackThread(VOID *pvHijackAddr, ExecutionState *esb);
+    void    HijackThread(ReturnKind returnKind, ExecutionState *esb);
 
     VOID        *m_pvHJRetAddr;           // original return address (before hijack)
     VOID       **m_ppvHJRetAddrPtr;       // place we bashed a new return address
index 7a4234a..1b35a11 100644 (file)
@@ -5280,7 +5280,7 @@ struct ExecutionState
 };
 
 // Client is responsible for suspending the thread before calling
-void Thread::HijackThread(VOID *pvHijackAddr, ExecutionState *esb)
+void Thread::HijackThread(ReturnKind returnKind, ExecutionState *esb)
 {
     CONTRACTL {
         NOTHROW;
@@ -5288,6 +5288,15 @@ void Thread::HijackThread(VOID *pvHijackAddr, ExecutionState *esb)
     }
     CONTRACTL_END;
 
+    _ASSERTE(IsValidReturnKind(returnKind));
+    VOID *pvHijackAddr = reinterpret_cast<VOID *>(OnHijackTripThread);
+#ifdef TARGET_X86
+    if (returnKind == RT_Float)
+    {
+        hijackAddress = reinterpret_cast<VOID *>(OnHijackFPTripThread);
+    }
+#endif // TARGET_X86
+
     // Don't hijack if are in the first level of running a filter/finally/catch.
     // This is because they share ebp with their containing function further down the
     // stack and we will hijack their containing function incorrectly
@@ -5305,7 +5314,7 @@ void Thread::HijackThread(VOID *pvHijackAddr, ExecutionState *esb)
         return;
     }
 
-    IS_VALID_CODE_PTR((FARPROC) pvHijackAddr);
+    SetHijackReturnKind(returnKind);
 
     if (m_State & TS_Hijacked)
         UnhijackThread();
@@ -5616,28 +5625,10 @@ void STDCALL OnHijackWorker(HijackArgs * pArgs)
 #endif // HIJACK_NONINTERRUPTIBLE_THREADS
 }
 
-bool GetReturnAddressHijackInfo(Thread *pThread, EECodeInfo *codeInfo, void** hijackAddress)
+bool GetReturnAddressHijackInfo(Thread *pThread, EECodeInfo *pCodeInfo, ReturnKind *pReturnKind)
 {
-    ReturnKind returnKind;
-    GCInfoToken gcInfoToken = codeInfo->GetGCInfoToken();
-    if (!codeInfo->GetCodeManager()->GetReturnAddressHijackInfo(gcInfoToken, &returnKind))
-    {
-        return false;
-    }
-
-    _ASSERTE(IsValidReturnKind(returnKind));
-    pThread->SetHijackReturnKind(returnKind);
-
-#ifdef TARGET_X86
-    if (returnKind == RT_Float)
-    {
-        *hijackAddress = reinterpret_cast<VOID *>(OnHijackFPTripThread);
-    }
-#endif // TARGET_X86
-
-    *hijackAddress = reinterpret_cast<VOID *>(OnHijackTripThread);
-
-    return true;
+    GCInfoToken gcInfoToken = pCodeInfo->GetGCInfoToken();
+    return pCodeInfo->GetCodeManager()->GetReturnAddressHijackInfo(gcInfoToken, pReturnKind);
 }
 
 #ifndef TARGET_UNIX
@@ -6093,9 +6084,9 @@ BOOL Thread::HandledJITCase(BOOL ForTaskSwitchIn)
             // it or not.
             EECodeInfo codeInfo(ip);
 
-            VOID *pvHijackAddr;
+            ReturnKind returnKind;
 
-            if (GetReturnAddressHijackInfo(this, &codeInfo, &pvHijackAddr))
+            if (GetReturnAddressHijackInfo(this, &codeInfo, &returnKind))
             {
 
 #ifdef FEATURE_ENABLE_GCPOLL
@@ -6107,7 +6098,7 @@ BOOL Thread::HandledJITCase(BOOL ForTaskSwitchIn)
                 if (EEConfig::GCPOLL_TYPE_HIJACK == pollType || EEConfig::GCPOLL_TYPE_DEFAULT == pollType)
 #endif // FEATURE_ENABLE_GCPOLL
                 {
-                    HijackThread(pvHijackAddr, &esb);
+                    HijackThread(returnKind, &esb);
                 }
             }
         }
@@ -6646,8 +6637,9 @@ void HandleGCSuspensionForInterruptedThread(CONTEXT *interruptedContext)
         if (executionState.m_ppvRetAddrPtr == NULL)
             return;
 
-        void *pvHijackAddr;
-        if (!GetReturnAddressHijackInfo(pThread, &codeInfo, &pvHijackAddr))
+        ReturnKind returnKind;
+
+        if (!GetReturnAddressHijackInfo(pThread, &codeInfo, &returnKind))
         {
             return;
         }
@@ -6661,7 +6653,7 @@ void HandleGCSuspensionForInterruptedThread(CONTEXT *interruptedContext)
         StackWalkerWalkingThreadHolder threadStackWalking(pThread);
 
         // Hijack the return address to point to the appropriate routine based on the method's return type.
-        pThread->HijackThread(pvHijackAddr, &executionState);
+        pThread->HijackThread(returnKind, &executionState);
     }
 }