[SCEVExpander] Migrate costAndCollectOperands to use InstructionCost.
authorSander de Smalen <sander.desmalen@arm.com>
Tue, 16 Feb 2021 08:42:19 +0000 (08:42 +0000)
committerSander de Smalen <sander.desmalen@arm.com>
Tue, 16 Feb 2021 09:27:34 +0000 (09:27 +0000)
This patch changes costAndCollectOperands to use InstructionCost for
accumulated cost values.

isHighCostExpansion will return true if the cost has exceeded the budget.

Reviewed By: CarolineConcatto, ctetreau

Differential Revision: https://reviews.llvm.org/D92238

llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

index e0653c9e403c7aa8f2e8931b43adeb218d0244bc..288e2f014d5cabe0cbfac85ce1095ca2dc282048 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InstructionCost.h"
 
 namespace llvm {
 extern cl::opt<unsigned> SCEVCheapExpansionBudget;
@@ -237,15 +238,16 @@ public:
       return true; // by always claiming to be high-cost.
     SmallVector<SCEVOperand, 8> Worklist;
     SmallPtrSet<const SCEV *, 8> Processed;
-    int BudgetRemaining = Budget * TargetTransformInfo::TCC_Basic;
+    InstructionCost Cost = 0;
+    unsigned ScaledBudget = Budget * TargetTransformInfo::TCC_Basic;
     Worklist.emplace_back(-1, -1, Expr);
     while (!Worklist.empty()) {
       const SCEVOperand WorkItem = Worklist.pop_back_val();
-      if (isHighCostExpansionHelper(WorkItem, L, *At, BudgetRemaining,
-                                    *TTI, Processed, Worklist))
+      if (isHighCostExpansionHelper(WorkItem, L, *At, Cost, ScaledBudget, *TTI,
+                                    Processed, Worklist))
         return true;
     }
-    assert(BudgetRemaining >= 0 && "Should have returned from inner loop.");
+    assert(Cost <= ScaledBudget && "Should have returned from inner loop.");
     return false;
   }
 
@@ -399,11 +401,12 @@ private:
   Value *expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *I, bool Root);
 
   /// Recursive helper function for isHighCostExpansion.
-  bool isHighCostExpansionHelper(
-    const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
-    int &BudgetRemaining, const TargetTransformInfo &TTI,
-    SmallPtrSetImpl<const SCEV *> &Processed,
-    SmallVectorImpl<SCEVOperand> &Worklist);
+  bool isHighCostExpansionHelper(const SCEVOperand &WorkItem, Loop *L,
+                                 const Instruction &At, InstructionCost &Cost,
+                                 unsigned Budget,
+                                 const TargetTransformInfo &TTI,
+                                 SmallPtrSetImpl<const SCEV *> &Processed,
+                                 SmallVectorImpl<SCEVOperand> &Worklist);
 
   /// Insert the specified binary operator, doing a small amount of work to
   /// avoid inserting an obviously redundant operation, and hoisting to an
index 518238d59e27ed0d5ed1b05a6bcbe032694b6611..2dbf732ad3647f3d0b65901b6621403660f23318 100644 (file)
@@ -2179,13 +2179,13 @@ SCEVExpander::getRelatedExistingExpansion(const SCEV *S, const Instruction *At,
   return None;
 }
 
-template<typename T> static int costAndCollectOperands(
+template<typename T> static InstructionCost costAndCollectOperands(
   const SCEVOperand &WorkItem, const TargetTransformInfo &TTI,
   TargetTransformInfo::TargetCostKind CostKind,
   SmallVectorImpl<SCEVOperand> &Worklist) {
 
   const T *S = cast<T>(WorkItem.S);
-  int Cost = 0;
+  InstructionCost Cost = 0;
   // Object to help map SCEV operands to expanded IR instructions.
   struct OperationIndices {
     OperationIndices(unsigned Opc, size_t min, size_t max) :
@@ -2200,7 +2200,7 @@ template<typename T> static int costAndCollectOperands(
   // we know what the generated user(s) will be.
   SmallVector<OperationIndices, 2> Operations;
 
-  auto CastCost = [&](unsigned Opcode) {
+  auto CastCost = [&](unsigned Opcode) -> InstructionCost {
     Operations.emplace_back(Opcode, 0, 0);
     return TTI.getCastInstrCost(Opcode, S->getType(),
                                 S->getOperand(0)->getType(),
@@ -2208,14 +2208,15 @@ template<typename T> static int costAndCollectOperands(
   };
 
   auto ArithCost = [&](unsigned Opcode, unsigned NumRequired,
-                       unsigned MinIdx = 0, unsigned MaxIdx = 1) {
+                       unsigned MinIdx = 0,
+                       unsigned MaxIdx = 1) -> InstructionCost {
     Operations.emplace_back(Opcode, MinIdx, MaxIdx);
     return NumRequired *
       TTI.getArithmeticInstrCost(Opcode, S->getType(), CostKind);
   };
 
-  auto CmpSelCost = [&](unsigned Opcode, unsigned NumRequired,
-                        unsigned MinIdx, unsigned MaxIdx) {
+  auto CmpSelCost = [&](unsigned Opcode, unsigned NumRequired, unsigned MinIdx,
+                        unsigned MaxIdx) -> InstructionCost {
     Operations.emplace_back(Opcode, MinIdx, MaxIdx);
     Type *OpType = S->getOperand(0)->getType();
     return NumRequired * TTI.getCmpSelInstrCost(
@@ -2286,10 +2287,11 @@ template<typename T> static int costAndCollectOperands(
 
     // Much like with normal add expr, the polynominal will require
     // one less addition than the number of it's terms.
-    int AddCost = ArithCost(Instruction::Add, NumTerms - 1,
-                            /*MinIdx*/1, /*MaxIdx*/1);
+    InstructionCost AddCost = ArithCost(Instruction::Add, NumTerms - 1,
+                                        /*MinIdx*/ 1, /*MaxIdx*/ 1);
     // Here, *each* one of those will require a multiplication.
-    int MulCost = ArithCost(Instruction::Mul, NumNonZeroDegreeNonOneTerms);
+    InstructionCost MulCost =
+        ArithCost(Instruction::Mul, NumNonZeroDegreeNonOneTerms);
     Cost = AddCost + MulCost;
 
     // What is the degree of this polynominal?
@@ -2320,10 +2322,10 @@ template<typename T> static int costAndCollectOperands(
 
 bool SCEVExpander::isHighCostExpansionHelper(
     const SCEVOperand &WorkItem, Loop *L, const Instruction &At,
-    int &BudgetRemaining, const TargetTransformInfo &TTI,
+    InstructionCost &Cost, unsigned Budget, const TargetTransformInfo &TTI,
     SmallPtrSetImpl<const SCEV *> &Processed,
     SmallVectorImpl<SCEVOperand> &Worklist) {
-  if (BudgetRemaining < 0)
+  if (Cost > Budget)
     return true; // Already run out of budget, give up.
 
   const SCEV *S = WorkItem.S;
@@ -2353,17 +2355,16 @@ bool SCEVExpander::isHighCostExpansionHelper(
       return 0;
     const APInt &Imm = cast<SCEVConstant>(S)->getAPInt();
     Type *Ty = S->getType();
-    BudgetRemaining -= TTI.getIntImmCostInst(
+    Cost += TTI.getIntImmCostInst(
         WorkItem.ParentOpcode, WorkItem.OperandIdx, Imm, Ty, CostKind);
-    return BudgetRemaining < 0;
+    return Cost > Budget;
   }
   case scTruncate:
   case scPtrToInt:
   case scZeroExtend:
   case scSignExtend: {
-    int Cost =
+    Cost +=
         costAndCollectOperands<SCEVCastExpr>(WorkItem, TTI, CostKind, Worklist);
-    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
   }
   case scUDivExpr: {
@@ -2379,10 +2380,8 @@ bool SCEVExpander::isHighCostExpansionHelper(
             SE.getAddExpr(S, SE.getConstant(S->getType(), 1)), &At, L))
       return false; // Consider it to be free.
 
-    int Cost =
+    Cost +=
         costAndCollectOperands<SCEVUDivExpr>(WorkItem, TTI, CostKind, Worklist);
-    // Need to count the cost of this UDiv.
-    BudgetRemaining -= Cost;
     return false; // Will answer upon next entry into this function.
   }
   case scAddExpr:
@@ -2395,17 +2394,16 @@ bool SCEVExpander::isHighCostExpansionHelper(
            "Nary expr should have more than 1 operand.");
     // The simple nary expr will require one less op (or pair of ops)
     // than the number of it's terms.
-    int Cost =
+    Cost +=
         costAndCollectOperands<SCEVNAryExpr>(WorkItem, TTI, CostKind, Worklist);
-    BudgetRemaining -= Cost;
-    return BudgetRemaining < 0;
+    return Cost > Budget;
   }
   case scAddRecExpr: {
     assert(cast<SCEVAddRecExpr>(S)->getNumOperands() >= 2 &&
            "Polynomial should be at least linear");
-    BudgetRemaining -= costAndCollectOperands<SCEVAddRecExpr>(
+    Cost += costAndCollectOperands<SCEVAddRecExpr>(
         WorkItem, TTI, CostKind, Worklist);
-    return BudgetRemaining < 0;
+    return Cost > Budget;
   }
   }
   llvm_unreachable("Unknown SCEV kind!");