Rewrite SCEV Normalization using SCEVRewriteVisitor; NFC
authorSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 14 Apr 2017 17:42:10 +0000 (17:42 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 14 Apr 2017 17:42:10 +0000 (17:42 +0000)
Removes all of the boilerplate, cache management etc. from
ScalarEvolutionNormalization, and keeps only the interesting bits.

llvm-svn: 300349

llvm/lib/Analysis/ScalarEvolutionNormalization.cpp

index 949281e..2aaa4c1 100644 (file)
@@ -27,124 +27,63 @@ enum TransformKind {
   Denormalize
 };
 
-typedef DenseMap<const SCEV *, const SCEV *> NormalizedCacheTy;
-
-static const SCEV *transformSubExpr(const TransformKind Kind,
-                                    NormalizePredTy Pred, ScalarEvolution &SE,
-                                    NormalizedCacheTy &Cache, const SCEV *S);
-
-/// Implement post-inc transformation for all valid expression types.
-static const SCEV *transformImpl(const TransformKind Kind, NormalizePredTy Pred,
-                                 ScalarEvolution &SE, NormalizedCacheTy &Cache,
-                                 const SCEV *S) {
-  if (const SCEVCastExpr *X = dyn_cast<SCEVCastExpr>(S)) {
-    const SCEV *O = X->getOperand();
-    const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O);
-    if (O != N)
-      switch (S->getSCEVType()) {
-      case scZeroExtend: return SE.getZeroExtendExpr(N, S->getType());
-      case scSignExtend: return SE.getSignExtendExpr(N, S->getType());
-      case scTruncate: return SE.getTruncateExpr(N, S->getType());
-      default: llvm_unreachable("Unexpected SCEVCastExpr kind!");
-      }
-    return S;
-  }
-
-  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
-    // An addrec. This is the interesting part.
-    SmallVector<const SCEV *, 8> Operands;
-
-    transform(AR->operands(), std::back_inserter(Operands),
-              [&](const SCEV *Op) {
-                return transformSubExpr(Kind, Pred, SE, Cache, Op);
-              });
-
-    // Conservatively use AnyWrap until/unless we need FlagNW.
-    const SCEV *Result =
-        SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
-    switch (Kind) {
-    case Normalize:
-      // We want to normalize step expression, because otherwise we might not be
-      // able to denormalize to the original expression.
-      //
-      // Here is an example what will happen if we don't normalize step:
-      //  ORIGINAL ISE:
-      //    {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25>
-      //  NORMALIZED ISE:
-      //    {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+,
-      //     (100 /u {0,+,1}<%bb16>)}<%bb25>
-      //  DENORMALIZED BACK ISE:
-      //    {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+,
-      //     (100 /u {1,+,1}<%bb16>)}<%bb25>
-      //  Note that the initial value changes after normalization +
-      //  denormalization, which isn't correct.
-      if (Pred(AR)) {
-        const SCEV *TransformedStep =
-            transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE));
-        Result = SE.getMinusSCEV(Result, TransformedStep);
-      }
-      break;
-    case Denormalize:
-      // Here we want to normalize step expressions for the same reasons, as
-      // stated above.
-      if (Pred(AR)) {
-        const SCEV *TransformedStep =
-            transformSubExpr(Kind, Pred, SE, Cache, AR->getStepRecurrence(SE));
-        Result = SE.getAddExpr(Result, TransformedStep);
-      }
-      break;
+namespace {
+struct NormalizeDenormalizeRewriter
+    : public SCEVRewriteVisitor<NormalizeDenormalizeRewriter> {
+  const TransformKind Kind;
+
+  // NB! Pred is a function_ref.  Storing it here is okay only because
+  // we're careful about the lifetime of NormalizeDenormalizeRewriter.
+  const NormalizePredTy Pred;
+
+  NormalizeDenormalizeRewriter(TransformKind Kind, NormalizePredTy Pred,
+                               ScalarEvolution &SE)
+      : SCEVRewriteVisitor<NormalizeDenormalizeRewriter>(SE), Kind(Kind),
+        Pred(Pred) {}
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr);
+};
+} // namespace
+
+const SCEV *
+NormalizeDenormalizeRewriter::visitAddRecExpr(const SCEVAddRecExpr *AR) {
+  SmallVector<const SCEV *, 8> Operands;
+
+  transform(AR->operands(), std::back_inserter(Operands),
+            [&](const SCEV *Op) { return visit(Op); });
+
+  // Conservatively use AnyWrap until/unless we need FlagNW.
+  const SCEV *Result =
+      SE.getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagAnyWrap);
+  switch (Kind) {
+  case Normalize:
+    // We want to normalize step expression, because otherwise we might not be
+    // able to denormalize to the original expression.
+    //
+    // Here is an example what will happen if we don't normalize step:
+    //  ORIGINAL ISE:
+    //    {(100 /u {1,+,1}<%bb16>),+,(100 /u {1,+,1}<%bb16>)}<%bb25>
+    //  NORMALIZED ISE:
+    //    {((-1 * (100 /u {1,+,1}<%bb16>)) + (100 /u {0,+,1}<%bb16>)),+,
+    //     (100 /u {0,+,1}<%bb16>)}<%bb25>
+    //  DENORMALIZED BACK ISE:
+    //    {((2 * (100 /u {1,+,1}<%bb16>)) + (-1 * (100 /u {2,+,1}<%bb16>))),+,
+    //     (100 /u {1,+,1}<%bb16>)}<%bb25>
+    //  Note that the initial value changes after normalization +
+    //  denormalization, which isn't correct.
+    if (Pred(AR)) {
+      const SCEV *TransformedStep = visit(AR->getStepRecurrence(SE));
+      Result = SE.getMinusSCEV(Result, TransformedStep);
     }
-    return Result;
-  }
-
-  if (const SCEVNAryExpr *X = dyn_cast<SCEVNAryExpr>(S)) {
-    SmallVector<const SCEV *, 8> Operands;
-    bool Changed = false;
-    // Transform each operand.
-    for (auto *O : X->operands()) {
-      const SCEV *N = transformSubExpr(Kind, Pred, SE, Cache, O);
-      Changed |= N != O;
-      Operands.push_back(N);
+    break;
+  case Denormalize:
+    // Here we want to normalize step expressions for the same reasons, as
+    // stated above.
+    if (Pred(AR)) {
+      const SCEV *TransformedStep = visit(AR->getStepRecurrence(SE));
+      Result = SE.getAddExpr(Result, TransformedStep);
     }
-    // If any operand actually changed, return a transformed result.
-    if (Changed)
-      switch (S->getSCEVType()) {
-      case scAddExpr: return SE.getAddExpr(Operands);
-      case scMulExpr: return SE.getMulExpr(Operands);
-      case scSMaxExpr: return SE.getSMaxExpr(Operands);
-      case scUMaxExpr: return SE.getUMaxExpr(Operands);
-      default: llvm_unreachable("Unexpected SCEVNAryExpr kind!");
-      }
-    return S;
-  }
-
-  if (const SCEVUDivExpr *X = dyn_cast<SCEVUDivExpr>(S)) {
-    const SCEV *LO = X->getLHS();
-    const SCEV *RO = X->getRHS();
-    const SCEV *LN = transformSubExpr(Kind, Pred, SE, Cache, LO);
-    const SCEV *RN = transformSubExpr(Kind, Pred, SE, Cache, RO);
-    if (LO != LN || RO != RN)
-      return SE.getUDivExpr(LN, RN);
-    return S;
+    break;
   }
-
-  llvm_unreachable("Unexpected SCEV kind!");
-}
-
-/// Manage recursive transformation across an expression DAG. Revisiting
-/// expressions would lead to exponential recursion.
-static const SCEV *transformSubExpr(const TransformKind Kind,
-                                    NormalizePredTy Pred, ScalarEvolution &SE,
-                                    NormalizedCacheTy &Cache, const SCEV *S) {
-  if (isa<SCEVConstant>(S) || isa<SCEVUnknown>(S))
-    return S;
-
-  const SCEV *Result = Cache.lookup(S);
-  if (Result)
-    return Result;
-
-  Result = transformImpl(Kind, Pred, SE, Cache, S);
-  Cache[S] = Result;
   return Result;
 }
 
@@ -154,14 +93,12 @@ const SCEV *llvm::normalizeForPostIncUse(const SCEV *S,
   auto Pred = [&](const SCEVAddRecExpr *AR) {
     return Loops.count(AR->getLoop());
   };
-  NormalizedCacheTy Cache;
-  return transformSubExpr(Normalize, Pred, SE, Cache, S);
+  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
 }
 
 const SCEV *llvm::normalizeForPostIncUseIf(const SCEV *S, NormalizePredTy Pred,
                                            ScalarEvolution &SE) {
-  NormalizedCacheTy Cache;
-  return transformSubExpr(Normalize, Pred, SE, Cache, S);
+  return NormalizeDenormalizeRewriter(Normalize, Pred, SE).visit(S);
 }
 
 const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
@@ -170,6 +107,5 @@ const SCEV *llvm::denormalizeForPostIncUse(const SCEV *S,
   auto Pred = [&](const SCEVAddRecExpr *AR) {
     return Loops.count(AR->getLoop());
   };
-  NormalizedCacheTy Cache;
-  return transformSubExpr(Denormalize, Pred, SE, Cache, S);
+  return NormalizeDenormalizeRewriter(Denormalize, Pred, SE).visit(S);
 }