JIT: Avoid xchg for resolution (#81216)
author: Jakob Botsch Nielsen <Jakob.botsch.nielsen@gmail.com>
Wed, 1 Feb 2023 10:12:25 +0000 (11:12 +0100)
committer: GitHub <noreply@github.com>
Wed, 1 Feb 2023 10:12:25 +0000 (11:12 +0100)
LSRA today uses xchg for reg-reg resolution when there are cycles in the
resolution graph. Benchmarks show that Intel CPUs handle xchg much less
efficiently than a few movs through a temporary register. This PR enables
using temporary registers for this kind of resolution on xarch (it was
already enabled for non-xarch architectures). xchg is still used on xarch
if no temporary register is available.

Additionally this PR adds support for getting a temporary register even
for shared critical edges. Before this change we would spill on
non-xarch for this case.

Finally, we now try to prefer a non-callee-saved temporary register so
that we don't need to save/restore it in the prolog/epilog.

This mostly fixes the string hashcode regressions from #80743.

src/coreclr/jit/lsra.cpp
src/coreclr/jit/lsra.h

index 840bdb3..d8754cb 100644 (file)
@@ -7374,9 +7374,12 @@ void LinearScan::insertSwap(
 // getTempRegForResolution: Get a free register to use for resolution code.
 //
 // Arguments:
-//    fromBlock - The "from" block on the edge being resolved.
-//    toBlock   - The "to" block on the edge
-//    type      - the type of register required
+//    fromBlock              - The "from" block on the edge being resolved.
+//    toBlock                - The "to" block on the edge. Can be null for shared critical edge resolution.
+//    type                   - The type of register required
+//    sharedCriticalLiveSet  - The set of live vars that require shared critical resolution. Only used when toBlock is
+//                             nullptr.
+//    terminatorConsumedRegs - Registers consumed by 'fromBlock's terminating node.
 //
 // Return Value:
 //    Returns a register that is free on the given edge, or REG_NA if none is available.
@@ -7386,12 +7389,16 @@ void LinearScan::insertSwap(
 //    available, and to handle that case appropriately.
 //    It is also up to the caller to cache the return value, as this is not cheap to compute.
 
-regNumber LinearScan::getTempRegForResolution(BasicBlock* fromBlock, BasicBlock* toBlock, var_types type)
+regNumber LinearScan::getTempRegForResolution(BasicBlock*      fromBlock,
+                                              BasicBlock*      toBlock,
+                                              var_types        type,
+                                              VARSET_VALARG_TP sharedCriticalLiveSet,
+                                              regMaskTP        terminatorConsumedRegs)
 {
     // TODO-Throughput: This would be much more efficient if we add RegToVarMaps instead of VarToRegMaps
     // and they would be more space-efficient as well.
     VarToRegMap fromVarToRegMap = getOutVarToRegMap(fromBlock->bbNum);
-    VarToRegMap toVarToRegMap   = getInVarToRegMap(toBlock->bbNum);
+    VarToRegMap toVarToRegMap   = toBlock == nullptr ? nullptr : getInVarToRegMap(toBlock->bbNum);
 
 #ifdef TARGET_ARM
     regMaskTP freeRegs;
@@ -7416,21 +7423,47 @@ regNumber LinearScan::getTempRegForResolution(BasicBlock* fromBlock, BasicBlock*
 #endif // DEBUG
     INDEBUG(freeRegs = stressLimitRegs(nullptr, freeRegs));
 
+    freeRegs &= ~terminatorConsumedRegs;
+
     // We are only interested in the variables that are live-in to the "to" block.
-    VarSetOps::Iter iter(compiler, toBlock->bbLiveIn);
+    VarSetOps::Iter iter(compiler, toBlock == nullptr ? fromBlock->bbLiveOut : toBlock->bbLiveIn);
     unsigned        varIndex = 0;
     while (iter.NextElem(&varIndex) && freeRegs != RBM_NONE)
     {
         regNumber fromReg = getVarReg(fromVarToRegMap, varIndex);
-        regNumber toReg   = getVarReg(toVarToRegMap, varIndex);
-        assert(fromReg != REG_NA && toReg != REG_NA);
+        assert(fromReg != REG_NA);
         if (fromReg != REG_STK)
         {
-            freeRegs &= ~genRegMask(fromReg, getIntervalForLocalVar(varIndex)->registerType);
+            freeRegs &= ~genRegMask(fromReg ARM_ARG(getIntervalForLocalVar(varIndex)->registerType));
         }
-        if (toReg != REG_STK)
+
+        if (toBlock != nullptr)
         {
-            freeRegs &= ~genRegMask(toReg, getIntervalForLocalVar(varIndex)->registerType);
+            regNumber toReg = getVarReg(toVarToRegMap, varIndex);
+            assert(toReg != REG_NA);
+            if (toReg != REG_STK)
+            {
+                freeRegs &= ~genRegMask(toReg ARM_ARG(getIntervalForLocalVar(varIndex)->registerType));
+            }
+        }
+    }
+
+    if (toBlock == nullptr)
+    {
+        // Resolution of critical edge that was determined to be shared (i.e.
+        // all vars requiring resolution are going into the same registers for
+        // all successor edges).
+
+        VarSetOps::Iter iter(compiler, sharedCriticalLiveSet);
+        varIndex = 0;
+        while (iter.NextElem(&varIndex) && freeRegs != RBM_NONE)
+        {
+            regNumber reg = getVarReg(sharedCriticalVarToRegMap, varIndex);
+            assert(reg != REG_NA);
+            if (reg != REG_STK)
+            {
+                freeRegs &= ~genRegMask(reg ARM_ARG(getIntervalForLocalVar(varIndex)->registerType));
+            }
         }
     }
 
@@ -7448,6 +7481,12 @@ regNumber LinearScan::getTempRegForResolution(BasicBlock* fromBlock, BasicBlock*
     }
     else
     {
+        // Prefer a callee-trashed register if possible to prevent new prolog/epilog saves/restores.
+        if ((freeRegs & RBM_CALLEE_TRASH) != 0)
+        {
+            freeRegs &= RBM_CALLEE_TRASH;
+        }
+
         regNumber tempReg = genRegNumFromMask(genFindLowestBit(freeRegs));
         return tempReg;
     }
@@ -7851,7 +7890,7 @@ void LinearScan::handleOutgoingCriticalEdges(BasicBlock* block)
         else
         {
             // For any vars in the sameResolutionSet, we can simply add the move at the end of "block".
-            resolveEdge(block, nullptr, ResolveSharedCritical, sameResolutionSet);
+            resolveEdge(block, nullptr, ResolveSharedCritical, sameResolutionSet, consumedRegs);
         }
     }
     if (!VarSetOps::IsEmpty(compiler, diffResolutionSet))
@@ -7907,7 +7946,7 @@ void LinearScan::handleOutgoingCriticalEdges(BasicBlock* block)
                 }
                 else
                 {
-                    resolveEdge(block, succBlock, ResolveCritical, edgeResolutionSet);
+                    resolveEdge(block, succBlock, ResolveCritical, edgeResolutionSet, consumedRegs);
                 }
             }
         }
@@ -7940,7 +7979,6 @@ void LinearScan::resolveEdges()
     // The resolutionCandidateVars set was initialized with all the lclVars that are live-in to
     // any block. We now intersect that set with any lclVars that ever spilled or split.
     // If there are no candidates for resolution, simply return.
-
     VarSetOps::IntersectionD(compiler, resolutionCandidateVars, splitOrSpilledVars);
     if (VarSetOps::IsEmpty(compiler, resolutionCandidateVars))
     {
@@ -7998,7 +8036,7 @@ void LinearScan::resolveEdges()
                     uniquePredBlock = uniquePredBlock->GetUniquePred(compiler);
                     noway_assert(uniquePredBlock != nullptr);
                 }
-                resolveEdge(uniquePredBlock, block, ResolveSplit, inResolutionSet);
+                resolveEdge(uniquePredBlock, block, ResolveSplit, inResolutionSet, RBM_NONE);
             }
         }
 
@@ -8017,7 +8055,7 @@ void LinearScan::resolveEdges()
                     VarSetOps::Intersection(compiler, succBlock->bbLiveIn, resolutionCandidateVars));
                 if (!VarSetOps::IsEmpty(compiler, outResolutionSet))
                 {
-                    resolveEdge(block, succBlock, ResolveJoin, outResolutionSet);
+                    resolveEdge(block, succBlock, ResolveJoin, outResolutionSet, RBM_NONE);
                 }
             }
         }
@@ -8130,10 +8168,12 @@ void LinearScan::resolveEdges()
 // resolveEdge: Perform the specified type of resolution between two blocks.
 //
 // Arguments:
-//    fromBlock     - the block from which the edge originates
-//    toBlock       - the block at which the edge terminates
-//    resolveType   - the type of resolution to be performed
-//    liveSet       - the set of tracked lclVar indices which may require resolution
+//    fromBlock              - the block from which the edge originates
+//    toBlock                - the block at which the edge terminates
+//    resolveType            - the type of resolution to be performed
+//    liveSet                - the set of tracked lclVar indices which may require resolution
+//    terminatorConsumedRegs - the registers consumed by the terminator node.
+//                             These registers will be used after any resolution added at the end of the 'fromBlock'.
 //
 // Return Value:
 //    None.
@@ -8152,7 +8192,8 @@ void LinearScan::resolveEdges()
 void LinearScan::resolveEdge(BasicBlock*      fromBlock,
                              BasicBlock*      toBlock,
                              ResolveType      resolveType,
-                             VARSET_VALARG_TP liveSet)
+                             VARSET_VALARG_TP liveSet,
+                             regMaskTP        terminatorConsumedRegs)
 {
     VarToRegMap fromVarToRegMap = getOutVarToRegMap(fromBlock->bbNum);
     VarToRegMap toVarToRegMap;
@@ -8194,24 +8235,21 @@ void LinearScan::resolveEdge(BasicBlock*      fromBlock,
             break;
     }
 
-#ifndef TARGET_XARCH
     // We record tempregs for beginning and end of each block.
     // For amd64/x86 we only need a tempReg for float - we'll use xchg for int.
     // TODO-Throughput: It would be better to determine the tempRegs on demand, but the code below
     // modifies the varToRegMaps so we don't have all the correct registers at the time
     // we need to get the tempReg.
-    regNumber tempRegInt =
-        (resolveType == ResolveSharedCritical) ? REG_NA : getTempRegForResolution(fromBlock, toBlock, TYP_INT);
-#endif // !TARGET_XARCH
+    regNumber tempRegInt = getTempRegForResolution(fromBlock, toBlock, TYP_INT, liveSet, terminatorConsumedRegs);
     regNumber tempRegFlt = REG_NA;
 #ifdef TARGET_ARM
     regNumber tempRegDbl = REG_NA;
 #endif
-    if ((compiler->compFloatingPointUsed) && (resolveType != ResolveSharedCritical))
+    if (compiler->compFloatingPointUsed)
     {
 #ifdef TARGET_ARM
         // Try to reserve a double register for TYP_DOUBLE and use it for TYP_FLOAT too if available.
-        tempRegDbl = getTempRegForResolution(fromBlock, toBlock, TYP_DOUBLE);
+        tempRegDbl = getTempRegForResolution(fromBlock, toBlock, TYP_DOUBLE, liveSet, terminatorConsumedRegs);
         if (tempRegDbl != REG_NA)
         {
             tempRegFlt = tempRegDbl;
@@ -8219,7 +8257,7 @@ void LinearScan::resolveEdge(BasicBlock*      fromBlock,
         else
 #endif // TARGET_ARM
         {
-            tempRegFlt = getTempRegForResolution(fromBlock, toBlock, TYP_FLOAT);
+            tempRegFlt = getTempRegForResolution(fromBlock, toBlock, TYP_FLOAT, liveSet, terminatorConsumedRegs);
         }
     }
 
@@ -8486,18 +8524,16 @@ void LinearScan::resolveEdge(BasicBlock*      fromBlock,
                         tempReg = tempRegFlt;
                 }
 #ifdef TARGET_XARCH
-                else
+                else if (tempRegInt == REG_NA)
                 {
                     useSwap = true;
                 }
-#else // !TARGET_XARCH
-
+#endif
                 else
                 {
                     tempReg = tempRegInt;
                 }
 
-#endif // !TARGET_XARCH
                 if (useSwap || tempReg == REG_NA)
                 {
                     // First, we have to figure out the destination register for what's currently in fromReg,
index cfbd744..c63d075 100644 (file)
@@ -706,7 +706,11 @@ public:
 
     void handleOutgoingCriticalEdges(BasicBlock* block);
 
-    void resolveEdge(BasicBlock* fromBlock, BasicBlock* toBlock, ResolveType resolveType, VARSET_VALARG_TP liveSet);
+    void resolveEdge(BasicBlock*      fromBlock,
+                     BasicBlock*      toBlock,
+                     ResolveType      resolveType,
+                     VARSET_VALARG_TP liveSet,
+                     regMaskTP        terminatorConsumedRegs);
 
     void resolveEdges();
 
@@ -1352,7 +1356,11 @@ private:
     // the block)
     VarToRegMap setInVarToRegMap(unsigned int bbNum, VarToRegMap srcVarToRegMap);
 
-    regNumber getTempRegForResolution(BasicBlock* fromBlock, BasicBlock* toBlock, var_types type);
+    regNumber getTempRegForResolution(BasicBlock*      fromBlock,
+                                      BasicBlock*      toBlock,
+                                      var_types        type,
+                                      VARSET_VALARG_TP sharedCriticalLiveSet,
+                                      regMaskTP        terminatorConsumedRegs);
 
 #ifdef DEBUG
     void dumpVarToRegMap(VarToRegMap map);