From 3470e14ba47b338c604216a32e2c8345dcb94694 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Fri, 14 Apr 2017 17:42:10 +0000 Subject: [PATCH] Rewrite SCEV Normalization using SCEVRewriteVisitor; NFC Removes all of the boilerplate, cache management etc. from ScalarEvolutionNormalization, and keeps only the interesting bits. llvm-svn: 300349 --- llvm/lib/Analysis/ScalarEvolutionNormalization.cpp | 178 +++++++-------------- 1 file changed, 57 insertions(+), 121 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp index 949281e..2aaa4c1 100644 --- a/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -27,124 +27,63 @@ enum TransformKind { Denormalize }; -typedef DenseMap NormalizedCacheTy; - -static const SCEV *transformSubExpr(const TransformKind Kind, - NormalizePredTy Pred, ScalarEvolution &SE, - NormalizedCacheTy &Cache, const SCEV *S); - -/// Implement post-inc transformation for all valid expression types. -static const SCEV *transformImpl(const TransformKind Kind, NormalizePredTy Pred, - ScalarEvolution &SE, NormalizedCacheTy &Cache, - const SCEV *S) { - if (const SCEVCastExpr *X = dyn_cast(S)) { - const SCEV *O = X->getOperand(); - const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O); - if (O != N) - switch (S->getSCEVType()) { - case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType()); - case scSignExtend: return SE.getSignExtendExpr(N, S->getType()); - case scTruncate: return SE.getTruncateExpr(N, S->getType()); - default: llvm_unreachable("Unexpected SCEVCastExpr kind!"); - } - return S; - } - - if (const SCEVAddRecExpr *AR = dyn_cast(S)) { - // An addrec. This is the interesting part. - SmallVector Operands; - - transform(AR->operands(), std::back_inserter(Operands), - [&](const SCEV *Op) { - return transformSubExpr(Kind, Pred, SE, Cache, Op); - }); - - // Conservatively use AnyWrap until/unless we need FlagNW. - const SCEV *Result = - SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); - switch (Kind) { - case Normalize: - // We want to normalize step expression, because otherwise we might not be - // able to denormalize to the original expression. - // - // Here is an example what will happen if we don't normalize step: - // ORIGINAL ISE: - // {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25> - // NORMALIZED ISE: - // {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+, - // (100 /u {0,+,1}<%bb16>)}<%bb25> - // DENORMALIZED BACK ISE: - // {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+, - // (100 /u {1,+,1}<%bb16>)}<%bb25> - // Note that the initial value changes after normalization + - // denormalization, which isn't correct. - if (Pred(AR)) { - const SCEV *TransformedStep = - transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE)); - Result = SE.getMinusSCEV(Result, TransformedStep); - } - break; - case Denormalize: - // Here we want to normalize step expressions for the same reasons, as - // stated above. - if (Pred(AR)) { - const SCEV *TransformedStep = - transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE)); - Result = SE.getAddExpr(Result, TransformedStep); - } - break; +namespace { +struct NormalizeDenormalizeRewriter + : public SCEVRewriteVisitor { + const TransformKind Kind; + + // NB! Pred is a function_ref. Storing it here is okay only because + // we're careful about the lifetime of NormalizeDenormalizeRewriter. + const NormalizePredTy Pred; + + NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred, + ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), Kind(Kind), + Pred(Pred) {} + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr); +}; +} // namespace + +const SCEV * +NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) { + SmallVector Operands; + + transform(AR->operands(), std::back_inserter(Operands), + [&](const SCEV *Op) { return visit(Op); }); + + // Conservatively use AnyWrap until/unless we need FlagNW. + const SCEV *Result = + SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap); + switch (Kind) { + case Normalize: + // We want to normalize step expression, because otherwise we might not be + // able to denormalize to the original expression. + // + // Here is an example what will happen if we don't normalize step: + // ORIGINAL ISE: + // {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25> + // NORMALIZED ISE: + // {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+, + // (100 /u {0,+,1}<%bb16>)}<%bb25> + // DENORMALIZED BACK ISE: + // {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+, + // (100 /u {1,+,1}<%bb16>)}<%bb25> + // Note that the initial value changes after normalization + + // denormalization, which isn't correct. + if (Pred(AR)) { + const SCEV *TransformedStep = visit(AR->getStepRecurrence(SE)); + Result = SE.getMinusSCEV(Result, TransformedStep); } - return Result; - } - - if (const SCEVNAryExpr *X = dyn_cast(S)) { - SmallVector Operands; - bool Changed = false; - // Transform each operand. - for (auto *O : X->operands()) { - const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O); - Changed |= N != O; - Operands.push_back(N); + break; + case Denormalize: + // Here we want to normalize step expressions for the same reasons, as + // stated above. + if (Pred(AR)) { + const SCEV *TransformedStep = visit(AR->getStepRecurrence(SE)); + Result = SE.getAddExpr(Result, TransformedStep); } - // If any operand actually changed, return a transformed result. - if (Changed) - switch (S->getSCEVType()) { - case scAddExpr: return SE.getAddExpr(Operands); - case scMulExpr: return SE.getMulExpr(Operands); - case scSMaxExpr: return SE.getSMaxExpr(Operands); - case scUMaxExpr: return SE.getUMaxExpr(Operands); - default: llvm_unreachable("Unexpected SCEVNAryExpr kind!"); - } - return S; - } - - if (const SCEVUDivExpr *X = dyn_cast(S)) { - const SCEV *LO = X->getLHS(); - const SCEV *RO = X->getRHS(); - const SCEV *LN = transformSubExpr(Kind, Pred, SE, Cache, LO); - const SCEV *RN = transformSubExpr(Kind, Pred, SE, Cache, RO); - if (LO != LN || RO != RN) - return SE.getUDivExpr(LN, RN); - return S; + break; } - - llvm_unreachable("Unexpected SCEV kind!"); -} - -/// Manage recursive transformation across an expression DAG. Revisiting -/// expressions would lead to exponential recursion. -static const SCEV *transformSubExpr(const TransformKind Kind, - NormalizePredTy Pred, ScalarEvolution &SE, - NormalizedCacheTy &Cache, const SCEV *S) { - if (isa(S) || isa(S)) - return S; - - const SCEV *Result = Cache.lookup(S); - if (Result) - return Result; - - Result = transformImpl(Kind, Pred, SE, Cache, S); - Cache[S] = Result; return Result; } @@ -154,14 +93,12 @@ const SCEV *llvm::normalizeForPostIncUse(const SCEV *S, auto Pred = [&](const SCEVAddRecExpr *AR) { return Loops.count(AR->getLoop()); }; - NormalizedCacheTy Cache; - return transformSubExpr(Normalize, Pred, SE, Cache, S); + return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); } const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred, ScalarEvolution &SE) { - NormalizedCacheTy Cache; - return transformSubExpr(Normalize, Pred, SE, Cache, S); + return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S); } const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, @@ -170,6 +107,5 @@ const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S, auto Pred = [&](const SCEVAddRecExpr *AR) { return Loops.count(AR->getLoop()); }; - NormalizedCacheTy Cache; - return transformSubExpr(Denormalize, Pred, SE, Cache, S); + return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S); } -- 2.7.4