[NFC][SCEV] `computeSCEVAtScope()`: deduplicate handling
authorRoman Lebedev <lebedev.ri@gmail.com>
Sun, 22 Jan 2023 13:32:02 +0000 (16:32 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Sun, 22 Jan 2023 14:40:52 +0000 (17:40 +0300)
Casts and udiv get the exactly the same handling as n-ary,
there is no point in special-handling anything.

llvm/lib/Analysis/ScalarEvolution.cpp

index 5798471..c6dc9c8 100644 (file)
@@ -9773,24 +9773,6 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
   switch (V->getSCEVType()) {
   case scConstant:
     return V;
-  case scTruncate:
-  case scZeroExtend:
-  case scSignExtend:
-  case scPtrToInt: {
-    const SCEVCastExpr *Cast = cast<SCEVCastExpr>(V);
-    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
-    if (Op == Cast->getOperand())
-      return Cast; // must be loop invariant
-    return getCastExpr(Cast->getSCEVType(), Op, Cast->getType());
-  }
-  case scUDivExpr: {
-    const SCEVUDivExpr *Div = cast<SCEVUDivExpr>(V);
-    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
-    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
-    if (LHS == Div->getLHS() && RHS == Div->getRHS())
-      return Div; // must be loop invariant
-    return getUDivExpr(LHS, RHS);
-  }
   case scAddRecExpr: {
     // If this is a loop recurrence for a loop that does not contain L, then we
     // are dealing with the final value computed by the loop.
@@ -9838,43 +9820,52 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
 
     return AddRec;
   }
+  case scTruncate:
+  case scZeroExtend:
+  case scSignExtend:
+  case scPtrToInt:
   case scAddExpr:
   case scMulExpr:
+  case scUDivExpr:
   case scUMaxExpr:
   case scSMaxExpr:
   case scUMinExpr:
   case scSMinExpr:
   case scSequentialUMinExpr: {
-    const auto *Comm = cast<SCEVNAryExpr>(V);
+    ArrayRef<const SCEV *> Ops = V->operands();
     // Avoid performing the look-up in the common case where the specified
     // expression has no loop-variant portions.
-    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
-      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
-      if (OpAtScope != Comm->getOperand(i)) {
+    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+      const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
+      if (OpAtScope != Ops[i]) {
         // Okay, at least one of these operands is loop variant but might be
         // foldable.  Build a new instance of the folded commutative expression.
         SmallVector<const SCEV *, 8> NewOps;
-        NewOps.reserve(Comm->getNumOperands());
-        append_range(NewOps, Comm->operands().take_front(i));
+        NewOps.reserve(Ops.size());
+        append_range(NewOps, Ops.take_front(i));
         NewOps.push_back(OpAtScope);
 
         for (++i; i != e; ++i) {
-          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
+          OpAtScope = getSCEVAtScope(Ops[i], L);
           NewOps.push_back(OpAtScope);
         }
-        if (isa<SCEVAddExpr>(Comm))
-          return getAddExpr(NewOps, Comm->getNoWrapFlags());
-        if (isa<SCEVMulExpr>(Comm))
-          return getMulExpr(NewOps, Comm->getNoWrapFlags());
-        if (isa<SCEVMinMaxExpr>(Comm))
-          return getMinMaxExpr(Comm->getSCEVType(), NewOps);
-        if (isa<SCEVSequentialMinMaxExpr>(Comm))
-          return getSequentialMinMaxExpr(Comm->getSCEVType(), NewOps);
-        llvm_unreachable("Unknown commutative / sequential min/max SCEV type!");
+        if (isa<SCEVCastExpr>(V))
+          return getCastExpr(V->getSCEVType(), NewOps[0], V->getType());
+        if (isa<SCEVAddExpr>(V))
+          return getAddExpr(NewOps, cast<SCEVAddExpr>(V)->getNoWrapFlags());
+        if (isa<SCEVMulExpr>(V))
+          return getMulExpr(NewOps, cast<SCEVMulExpr>(V)->getNoWrapFlags());
+        if (isa<SCEVUDivExpr>(V))
+          return getUDivExpr(NewOps[0], NewOps[1]);
+        if (isa<SCEVMinMaxExpr>(V))
+          return getMinMaxExpr(V->getSCEVType(), NewOps);
+        if (isa<SCEVSequentialMinMaxExpr>(V))
+          return getSequentialMinMaxExpr(V->getSCEVType(), NewOps);
+        llvm_unreachable("Unknown n-ary-like SCEV type!");
       }
     }
     // If we got here, all operands are loop invariant.
-    return Comm;
+    return V;
   }
   case scUnknown: {
     // If this instruction is evolved from a constant-evolving PHI, compute the