From 4c3322cc843c6233e78a123259808eeb37198aa9 Mon Sep 17 00:00:00 2001 From: Daniil Fukalov Date: Thu, 17 Nov 2016 16:07:52 +0000 Subject: [PATCH] [SCEV] limit recursion depth of CompareSCEVComplexity Summary: CompareSCEVComplexity goes too deep (50+ levels on quite a big unrolled loop) and runs for an almost infinite time. Added a cache of "equal" SCEV pairs to cut off further estimation earlier. A recursion depth limit was also introduced as a parameter. Reviewers: sanjoy Subscribers: mzolotukhin, tstellarAMD, llvm-commits Differential Revision: https://reviews.llvm.org/D26389 llvm-svn: 287232 --- llvm/lib/Analysis/ScalarEvolution.cpp | 61 +++++++++++++++------- llvm/unittests/Analysis/ScalarEvolutionTest.cpp | 67 +++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 0bb624c..8de2c15 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -127,6 +127,11 @@ static cl::opt MulOpsInlineThreshold( cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(1000)); +static cl::opt + MaxCompareDepth("scalar-evolution-max-compare-depth", cl::Hidden, + cl::desc("Maximum depth of recursive compare complexity"), + cl::init(32)); + //===----------------------------------------------------------------------===// // SCEV class definitions //===----------------------------------------------------------------------===// @@ -475,8 +480,8 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { static int CompareValueComplexity(SmallSet, 8> &EqCache, const LoopInfo *const LI, Value *LV, Value *RV, - unsigned DepthLeft = 2) { - if (DepthLeft == 0 || EqCache.count({LV, RV})) + unsigned Depth) { + if (Depth > MaxCompareDepth || EqCache.count({LV, RV})) return 0; // Order pointer values after integer values. 
This helps SCEVExpander form @@ -537,21 +542,23 @@ CompareValueComplexity(SmallSet, 8> &EqCache, for (unsigned Idx : seq(0u, LNumOps)) { int Result = CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx), - RInst->getOperand(Idx), DepthLeft - 1); + RInst->getOperand(Idx), Depth + 1); if (Result != 0) return Result; - EqCache.insert({LV, RV}); } } + EqCache.insert({LV, RV}); return 0; } // Return negative, zero, or positive, if LHS is less than, equal to, or greater // than RHS, respectively. A three-way result allows recursive comparisons to be // more efficient. -static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, - const SCEV *RHS) { +static int CompareSCEVComplexity( + SmallSet, 8> &EqCacheSCEV, + const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, + unsigned Depth = 0) { // Fast-path: SCEVs are uniqued so we can do a quick equality check. if (LHS == RHS) return 0; @@ -561,6 +568,8 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, if (LType != RType) return (int)LType - (int)RType; + if (Depth > MaxCompareDepth || EqCacheSCEV.count({LHS, RHS})) + return 0; // Aside from the getSCEVType() ordering, the particular ordering // isn't very important except that it's beneficial to be consistent, // so that (a + b) and (b + a) don't end up as different expressions. @@ -570,7 +579,11 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEVUnknown *RU = cast(RHS); SmallSet, 8> EqCache; - return CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue()); + int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(), + Depth + 1); + if (X == 0) + EqCacheSCEV.insert({LHS, RHS}); + return X; } case scConstant: { @@ -605,11 +618,12 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, // Lexicographically compare. 
for (unsigned i = 0; i != LNumOps; ++i) { - long X = CompareSCEVComplexity(LI, LA->getOperand(i), RA->getOperand(i)); + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i), + RA->getOperand(i), Depth + 1); if (X != 0) return X; } - + EqCacheSCEV.insert({LHS, RHS}); return 0; } @@ -628,11 +642,13 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, for (unsigned i = 0; i != LNumOps; ++i) { if (i >= RNumOps) return 1; - long X = CompareSCEVComplexity(LI, LC->getOperand(i), RC->getOperand(i)); + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i), + RC->getOperand(i), Depth + 1); if (X != 0) return X; } - return (int)LNumOps - (int)RNumOps; + EqCacheSCEV.insert({LHS, RHS}); + return 0; } case scUDivExpr: { @@ -640,10 +656,15 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEVUDivExpr *RC = cast(RHS); // Lexicographically compare udiv expressions. - long X = CompareSCEVComplexity(LI, LC->getLHS(), RC->getLHS()); + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(), + Depth + 1); if (X != 0) return X; - return CompareSCEVComplexity(LI, LC->getRHS(), RC->getRHS()); + X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), + Depth + 1); + if (X == 0) + EqCacheSCEV.insert({LHS, RHS}); + return X; } case scTruncate: @@ -653,7 +674,11 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEVCastExpr *RC = cast(RHS); // Compare cast expressions by operand. 
- return CompareSCEVComplexity(LI, LC->getOperand(), RC->getOperand()); + int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(), + RC->getOperand(), Depth + 1); + if (X == 0) + EqCacheSCEV.insert({LHS, RHS}); + return X; } case scCouldNotCompute: @@ -675,19 +700,21 @@ static int CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, static void GroupByComplexity(SmallVectorImpl &Ops, LoopInfo *LI) { if (Ops.size() < 2) return; // Noop + + SmallSet, 8> EqCache; if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; - if (CompareSCEVComplexity(LI, RHS, LHS) < 0) + if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. std::stable_sort(Ops.begin(), Ops.end(), - [LI](const SCEV *LHS, const SCEV *RHS) { - return CompareSCEVComplexity(LI, LHS, RHS) < 0; + [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) { + return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0; }); // Now that we are sorted by complexity, group elements of the same diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index 6dcb18f..752cc81 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -465,5 +465,72 @@ TEST_F(ScalarEvolutionsTest, CommutativeExprOperandOrder) { }); } +TEST_F(ScalarEvolutionsTest, SCEVCompareComplexity) { + FunctionType *FTy = + FunctionType::get(Type::getVoidTy(Context), std::vector(), false); + Function *F = cast(M.getOrInsertFunction("f", FTy)); + BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F); + BasicBlock *LoopBB = BasicBlock::Create(Context, "bb1", F); + BranchInst::Create(LoopBB, EntryBB); + + auto *Ty = Type::getInt32Ty(Context); + SmallVector Muls(8), Acc(8), NextAcc(8); + + Acc[0] = PHINode::Create(Ty, 2, "", LoopBB); + Acc[1] = PHINode::Create(Ty, 
2, "", LoopBB); + Acc[2] = PHINode::Create(Ty, 2, "", LoopBB); + Acc[3] = PHINode::Create(Ty, 2, "", LoopBB); + Acc[4] = PHINode::Create(Ty, 2, "", LoopBB); + Acc[5] = PHINode::Create(Ty, 2, "", LoopBB); + Acc[6] = PHINode::Create(Ty, 2, "", LoopBB); + Acc[7] = PHINode::Create(Ty, 2, "", LoopBB); + + for (int i = 0; i < 20; i++) { + Muls[0] = BinaryOperator::CreateMul(Acc[0], Acc[0], "", LoopBB); + NextAcc[0] = BinaryOperator::CreateAdd(Muls[0], Acc[4], "", LoopBB); + Muls[1] = BinaryOperator::CreateMul(Acc[1], Acc[1], "", LoopBB); + NextAcc[1] = BinaryOperator::CreateAdd(Muls[1], Acc[5], "", LoopBB); + Muls[2] = BinaryOperator::CreateMul(Acc[2], Acc[2], "", LoopBB); + NextAcc[2] = BinaryOperator::CreateAdd(Muls[2], Acc[6], "", LoopBB); + Muls[3] = BinaryOperator::CreateMul(Acc[3], Acc[3], "", LoopBB); + NextAcc[3] = BinaryOperator::CreateAdd(Muls[3], Acc[7], "", LoopBB); + + Muls[4] = BinaryOperator::CreateMul(Acc[4], Acc[4], "", LoopBB); + NextAcc[4] = BinaryOperator::CreateAdd(Muls[4], Acc[0], "", LoopBB); + Muls[5] = BinaryOperator::CreateMul(Acc[5], Acc[5], "", LoopBB); + NextAcc[5] = BinaryOperator::CreateAdd(Muls[5], Acc[1], "", LoopBB); + Muls[6] = BinaryOperator::CreateMul(Acc[6], Acc[6], "", LoopBB); + NextAcc[6] = BinaryOperator::CreateAdd(Muls[6], Acc[2], "", LoopBB); + Muls[7] = BinaryOperator::CreateMul(Acc[7], Acc[7], "", LoopBB); + NextAcc[7] = BinaryOperator::CreateAdd(Muls[7], Acc[3], "", LoopBB); + Acc = NextAcc; + } + + auto II = LoopBB->begin(); + for (int i = 0; i < 8; i++) { + PHINode *Phi = cast(&*II++); + Phi->addIncoming(Acc[i], LoopBB); + Phi->addIncoming(UndefValue::get(Ty), EntryBB); + } + + BasicBlock *ExitBB = BasicBlock::Create(Context, "bb2", F); + BranchInst::Create(LoopBB, ExitBB, UndefValue::get(Type::getInt1Ty(Context)), + LoopBB); + + Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB); + Acc[1] = BinaryOperator::CreateAdd(Acc[2], Acc[3], "", ExitBB); + Acc[2] = BinaryOperator::CreateAdd(Acc[4], Acc[5], "", ExitBB); 
+ Acc[3] = BinaryOperator::CreateAdd(Acc[6], Acc[7], "", ExitBB); + Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB); + Acc[1] = BinaryOperator::CreateAdd(Acc[2], Acc[3], "", ExitBB); + Acc[0] = BinaryOperator::CreateAdd(Acc[0], Acc[1], "", ExitBB); + + ReturnInst::Create(Context, nullptr, ExitBB); + + ScalarEvolution SE = buildSE(*F); + + EXPECT_NE(nullptr, SE.getSCEV(Acc[0])); +} + } // end anonymous namespace } // end namespace llvm -- 2.7.4