[SCEV] Introduce bulk umin creation utilities
authorSerguei Katkov <serguei.katkov@azul.com>
Fri, 27 Apr 2018 03:56:53 +0000 (03:56 +0000)
committerSerguei Katkov <serguei.katkov@azul.com>
Fri, 27 Apr 2018 03:56:53 +0000 (03:56 +0000)
Add new umin creation method which accepts a list of operands.

SCEV does not represents umin which is required in getExact, so
it transforms umin to umax with not. As a result the transformation of
tree of max to max with several operands does not work.
We just use the new introduced method for creation umin from several operands.

Reviewers: sanjoy, mkazantsev
Reviewed By: sanjoy
Subscribers: javed.absar, llvm-commits
Differential Revision: https://reviews.llvm.org/D46047

llvm-svn: 331015

llvm/include/llvm/Analysis/ScalarEvolution.h
llvm/lib/Analysis/ScalarEvolution.cpp

index 0fef451..20d61d0 100644 (file)
@@ -587,7 +587,9 @@ public:
   const SCEV *getUMaxExpr(const SCEV *LHS, const SCEV *RHS);
   const SCEV *getUMaxExpr(SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getSMinExpr(const SCEV *LHS, const SCEV *RHS);
+  const SCEV *getSMinExpr(SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getUMinExpr(const SCEV *LHS, const SCEV *RHS);
+  const SCEV *getUMinExpr(SmallVectorImpl<const SCEV *> &Operands);
   const SCEV *getUnknown(Value *V);
   const SCEV *getCouldNotCompute();
 
@@ -650,6 +652,10 @@ public:
   /// then perform a umin operation with them.
   const SCEV *getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS);
 
+  /// Promote the operands to the wider of the types using zero-extension, and
+  /// then perform a umin operation with them. N-ary function.
+  const SCEV *getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops);
+
   /// Transitively follow the chain of pointer-type operands until reaching a
   /// SCEV that does not have a single pointer operand. This returns a
   /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
index d6fee9b..f1ea3a8 100644 (file)
@@ -3548,14 +3548,30 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 
 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  // ~smax(~x, ~y) == smin(x, y).
-  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getSMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+  // ~smax(~x, ~y, ~z) == smin(x, y, z).
+  SmallVector<const SCEV *, 2> NotOps;
+  for (auto *S : Ops)
+    NotOps.push_back(getNotSCEV(S));
+  return getNotSCEV(getSMaxExpr(NotOps));
 }
 
 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  // ~umax(~x, ~y) == umin(x, y)
-  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getUMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+  // ~umax(~x, ~y, ~z) == umin(x, y, z).
+  SmallVector<const SCEV *, 2> NotOps;
+  for (auto *S : Ops)
+    NotOps.push_back(getNotSCEV(S));
+  return getNotSCEV(getUMaxExpr(NotOps));
 }
 
 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
@@ -3943,15 +3959,27 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
 
 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
                                                         const SCEV *RHS) {
-  const SCEV *PromotedLHS = LHS;
-  const SCEV *PromotedRHS = RHS;
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getUMinFromMismatchedTypes(Ops);
+}
 
-  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
-    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
-  else
-    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
+    SmallVectorImpl<const SCEV *> &Ops) {
+  // Find the max type first.
+  Type *MaxType = nullptr;
+  for (auto *S : Ops)
+    if (MaxType)
+      MaxType = getWiderType(MaxType, S->getType());
+    else
+      MaxType = S->getType();
+
+  // Extend all ops to max type.
+  SmallVector<const SCEV *, 2> PromotedOps;
+  for (auto *S : Ops)
+    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
 
-  return getUMinExpr(PromotedLHS, PromotedRHS);
+  // Generate umin.
+  return getUMinExpr(PromotedOps);
 }
 
 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
@@ -6666,7 +6694,6 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
   if (!isComplete() || ExitNotTaken.empty())
     return SE->getCouldNotCompute();
 
-  const SCEV *BECount = nullptr;
   const BasicBlock *Latch = L->getLoopLatch();
   // All exiting blocks we have collected must dominate the only backedge.
   if (!Latch)
@@ -6674,16 +6701,15 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
 
   // All exiting blocks we have gathered dominate loop's latch, so exact trip
   // count is simply a minimum out of all these calculated exit counts.
+  SmallVector<const SCEV *, 2> Ops;
   for (auto &ENT : ExitNotTaken) {
-    assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "Bad exit SCEV!");
+    const SCEV *BECount = ENT.ExactNotTaken;
+    assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
     assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
            "We should only have known counts for exiting blocks that dominate "
            "latch!");
 
-    if (!BECount)
-      BECount = ENT.ExactNotTaken;
-    else if (BECount != ENT.ExactNotTaken)
-      BECount = SE->getUMinFromMismatchedTypes(BECount, ENT.ExactNotTaken);
+    Ops.push_back(BECount);
 
     if (Preds && !ENT.hasAlwaysTruePredicate())
       Preds->add(ENT.Predicate.get());
@@ -6692,8 +6718,8 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
            "Predicate should be always true!");
   }
 
-  assert(BECount && "Invalid not taken count for loop exit");
-  return BECount;
+  assert(!Ops.empty() && "Loop without exits");
+  return Ops.size() == 1 ? Ops[0] : SE->getUMinFromMismatchedTypes(Ops);
 }
 
 /// Get the exact not taken count for this loop exit.