[NFC][ScalarEvolution] Fix SCEVNAryExpr::getType().
authorEli Friedman <efriedma@quicinc.com>
Wed, 23 Jun 2021 19:42:47 +0000 (12:42 -0700)
committerEli Friedman <efriedma@quicinc.com>
Wed, 23 Jun 2021 19:55:59 +0000 (12:55 -0700)
SCEVNAryExpr::getType() could return the wrong type for a SCEVAddExpr.
Remove it, and add getType() methods to the relevant subclasses.

NFC because nothing uses it directly, as far as I know; this is just
future-proofing.

llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
llvm/lib/Analysis/ScalarEvolution.cpp

index ad0c747..c0da311 100644 (file)
@@ -210,8 +210,6 @@ class Type;
       return make_range(op_begin(), op_end());
     }
 
-    Type *getType() const { return getOperand(0)->getType(); }
-
     NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
       return (NoWrapFlags)(SubclassData & Mask);
     }
@@ -293,6 +291,8 @@ class Type;
       : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
 
   public:
+    Type *getType() const { return getOperand(0)->getType(); }
+
     /// Methods for support type inquiry through isa, cast, and dyn_cast:
     static bool classof(const SCEV *S) {
       return S->getSCEVType() == scMulExpr;
@@ -359,6 +359,7 @@ class Type;
       : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
 
   public:
+    Type *getType() const { return getStart()->getType(); }
     const SCEV *getStart() const { return Operands[0]; }
     const Loop *getLoop() const { return L; }
 
@@ -445,6 +446,8 @@ class Type;
     }
 
   public:
+    Type *getType() const { return getOperand(0)->getType(); }
+
     static bool classof(const SCEV *S) {
       return isMinMaxType(S->getSCEVType());
     }
index 79925ed..e330ed2 100644 (file)
@@ -386,12 +386,14 @@ Type *SCEV::getType() const {
   case scSignExtend:
     return cast<SCEVCastExpr>(this)->getType();
   case scAddRecExpr:
+    return cast<SCEVAddRecExpr>(this)->getType();
   case scMulExpr:
+    return cast<SCEVMulExpr>(this)->getType();
   case scUMaxExpr:
   case scSMaxExpr:
   case scUMinExpr:
   case scSMinExpr:
-    return cast<SCEVNAryExpr>(this)->getType();
+    return cast<SCEVMinMaxExpr>(this)->getType();
   case scAddExpr:
     return cast<SCEVAddExpr>(this)->getType();
   case scUDivExpr: