[NFCI] Introduce `ICmpInst::compare()` and use it where appropriate
authorRoman Lebedev <lebedev.ri@gmail.com>
Sat, 30 Oct 2021 14:36:23 +0000 (17:36 +0300)
committerRoman Lebedev <lebedev.ri@gmail.com>
Sat, 30 Oct 2021 14:50:06 +0000 (17:50 +0300)
As noted in https://reviews.llvm.org/D90924#inline-1076197
apparently this is a pretty common pattern,
let's not repeat it yet again, but have it in a common place.

There may be some more places where it could be used,
but these are the most obvious ones.

llvm/include/llvm/CodeGen/Analysis.h
llvm/include/llvm/IR/Instructions.h
llvm/include/llvm/IR/PatternMatch.h
llvm/lib/CodeGen/Analysis.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/lib/IR/ConstantFold.cpp
llvm/lib/IR/Instructions.cpp
llvm/lib/Transforms/IPO/AttributorAttributes.cpp
llvm/unittests/IR/ConstantRangeTest.cpp

index bdfb416..6044232 100644 (file)
@@ -104,9 +104,12 @@ ISD::CondCode getFCmpCodeWithoutNaN(ISD::CondCode CC);
 
 /// getICmpCondCode - Return the ISD condition code corresponding to
 /// the given LLVM IR integer condition code.
-///
 ISD::CondCode getICmpCondCode(ICmpInst::Predicate Pred);
 
+/// getICmpCondCode - Return the LLVM IR integer condition code
+/// corresponding to the given ISD integer condition code.
+ICmpInst::Predicate getICmpCondCode(ISD::CondCode Pred);
+
 /// Test if the given instruction is in a position to be optimized
 /// with a tail-call. This roughly means that it's in a block with
 /// a return and there's nothing that needs to be scheduled
index aa8613f..3c7911d 100644 (file)
@@ -1349,6 +1349,10 @@ public:
     Op<0>().swap(Op<1>());
   }
 
+  /// Return result of `LHS Pred RHS` comparison.
+  static bool compare(const APInt &LHS, const APInt &RHS,
+                      ICmpInst::Predicate Pred);
+
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
     return I->getOpcode() == Instruction::ICmp;
index 6096cbd..483e927 100644 (file)
@@ -593,32 +593,7 @@ inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
 struct icmp_pred_with_threshold {
   ICmpInst::Predicate Pred;
   const APInt *Thr;
-  bool isValue(const APInt &C) {
-    switch (Pred) {
-    case ICmpInst::Predicate::ICMP_EQ:
-      return C.eq(*Thr);
-    case ICmpInst::Predicate::ICMP_NE:
-      return C.ne(*Thr);
-    case ICmpInst::Predicate::ICMP_UGT:
-      return C.ugt(*Thr);
-    case ICmpInst::Predicate::ICMP_UGE:
-      return C.uge(*Thr);
-    case ICmpInst::Predicate::ICMP_ULT:
-      return C.ult(*Thr);
-    case ICmpInst::Predicate::ICMP_ULE:
-      return C.ule(*Thr);
-    case ICmpInst::Predicate::ICMP_SGT:
-      return C.sgt(*Thr);
-    case ICmpInst::Predicate::ICMP_SGE:
-      return C.sge(*Thr);
-    case ICmpInst::Predicate::ICMP_SLT:
-      return C.slt(*Thr);
-    case ICmpInst::Predicate::ICMP_SLE:
-      return C.sle(*Thr);
-    default:
-      llvm_unreachable("Unhandled ICmp predicate");
-    }
-  }
+  bool isValue(const APInt &C) { return ICmpInst::compare(C, *Thr, Pred); }
 };
 /// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
 /// to Threshold. For vectors, this includes constants with undefined elements.
index c21db82..7d8a73e 100644 (file)
@@ -221,9 +221,6 @@ ISD::CondCode llvm::getFCmpCodeWithoutNaN(ISD::CondCode CC) {
   }
 }
 
-/// getICmpCondCode - Return the ISD condition code corresponding to
-/// the given LLVM IR integer condition code.
-///
 ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) {
   switch (Pred) {
   case ICmpInst::ICMP_EQ:  return ISD::SETEQ;
@@ -241,6 +238,33 @@ ISD::CondCode llvm::getICmpCondCode(ICmpInst::Predicate Pred) {
   }
 }
 
+ICmpInst::Predicate llvm::getICmpCondCode(ISD::CondCode Pred) {
+  switch (Pred) {
+  case ISD::SETEQ:
+    return ICmpInst::ICMP_EQ;
+  case ISD::SETNE:
+    return ICmpInst::ICMP_NE;
+  case ISD::SETLE:
+    return ICmpInst::ICMP_SLE;
+  case ISD::SETULE:
+    return ICmpInst::ICMP_ULE;
+  case ISD::SETGE:
+    return ICmpInst::ICMP_SGE;
+  case ISD::SETUGE:
+    return ICmpInst::ICMP_UGE;
+  case ISD::SETLT:
+    return ICmpInst::ICMP_SLT;
+  case ISD::SETULT:
+    return ICmpInst::ICMP_ULT;
+  case ISD::SETGT:
+    return ICmpInst::ICMP_SGT;
+  case ISD::SETUGT:
+    return ICmpInst::ICMP_UGT;
+  default:
+    llvm_unreachable("Invalid ISD integer condition code!");
+  }
+}
+
 static bool isNoopBitcast(Type *T1, Type *T2,
                           const TargetLoweringBase& TLI) {
   return T1 == T2 || (T1->isPointerTy() && T2->isPointerTy()) ||
index b928fd3..1c25f5f 100644 (file)
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/FunctionLoweringInfo.h"
 #include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
@@ -2312,19 +2313,8 @@ SDValue SelectionDAG::FoldSetCC(EVT VT, SDValue N1, SDValue N2,
     if (ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1)) {
       const APInt &C1 = N1C->getAPIntValue();
 
-      switch (Cond) {
-      default: llvm_unreachable("Unknown integer setcc!");
-      case ISD::SETEQ:  return getBoolConstant(C1 == C2, dl, VT, OpVT);
-      case ISD::SETNE:  return getBoolConstant(C1 != C2, dl, VT, OpVT);
-      case ISD::SETULT: return getBoolConstant(C1.ult(C2), dl, VT, OpVT);
-      case ISD::SETUGT: return getBoolConstant(C1.ugt(C2), dl, VT, OpVT);
-      case ISD::SETULE: return getBoolConstant(C1.ule(C2), dl, VT, OpVT);
-      case ISD::SETUGE: return getBoolConstant(C1.uge(C2), dl, VT, OpVT);
-      case ISD::SETLT:  return getBoolConstant(C1.slt(C2), dl, VT, OpVT);
-      case ISD::SETGT:  return getBoolConstant(C1.sgt(C2), dl, VT, OpVT);
-      case ISD::SETLE:  return getBoolConstant(C1.sle(C2), dl, VT, OpVT);
-      case ISD::SETGE:  return getBoolConstant(C1.sge(C2), dl, VT, OpVT);
-      }
+      return getBoolConstant(ICmpInst::compare(C1, C2, getICmpCondCode(Cond)),
+                             dl, VT, OpVT);
     }
   }
 
index e7357a8..7c49adb 100644 (file)
@@ -1792,19 +1792,8 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
     const APInt &V1 = cast<ConstantInt>(C1)->getValue();
     const APInt &V2 = cast<ConstantInt>(C2)->getValue();
-    switch (pred) {
-    default: llvm_unreachable("Invalid ICmp Predicate");
-    case ICmpInst::ICMP_EQ:  return ConstantInt::get(ResultTy, V1 == V2);
-    case ICmpInst::ICMP_NE:  return ConstantInt::get(ResultTy, V1 != V2);
-    case ICmpInst::ICMP_SLT: return ConstantInt::get(ResultTy, V1.slt(V2));
-    case ICmpInst::ICMP_SGT: return ConstantInt::get(ResultTy, V1.sgt(V2));
-    case ICmpInst::ICMP_SLE: return ConstantInt::get(ResultTy, V1.sle(V2));
-    case ICmpInst::ICMP_SGE: return ConstantInt::get(ResultTy, V1.sge(V2));
-    case ICmpInst::ICMP_ULT: return ConstantInt::get(ResultTy, V1.ult(V2));
-    case ICmpInst::ICMP_UGT: return ConstantInt::get(ResultTy, V1.ugt(V2));
-    case ICmpInst::ICMP_ULE: return ConstantInt::get(ResultTy, V1.ule(V2));
-    case ICmpInst::ICMP_UGE: return ConstantInt::get(ResultTy, V1.uge(V2));
-    }
+    return ConstantInt::get(
+        ResultTy, ICmpInst::compare(V1, V2, (ICmpInst::Predicate)pred));
   } else if (isa<ConstantFP>(C1) && isa<ConstantFP>(C2)) {
     const APFloat &C1V = cast<ConstantFP>(C1)->getValueAPF();
     const APFloat &C2V = cast<ConstantFP>(C2)->getValueAPF();
index 83be432..1bc4ebc 100644 (file)
@@ -4055,6 +4055,35 @@ bool CmpInst::isSigned(Predicate predicate) {
   }
 }
 
+bool ICmpInst::compare(const APInt &LHS, const APInt &RHS,
+                       ICmpInst::Predicate Pred) {
+  assert(ICmpInst::isIntPredicate(Pred) && "Only for integer predicates!");
+  switch (Pred) {
+  case ICmpInst::Predicate::ICMP_EQ:
+    return LHS.eq(RHS);
+  case ICmpInst::Predicate::ICMP_NE:
+    return LHS.ne(RHS);
+  case ICmpInst::Predicate::ICMP_UGT:
+    return LHS.ugt(RHS);
+  case ICmpInst::Predicate::ICMP_UGE:
+    return LHS.uge(RHS);
+  case ICmpInst::Predicate::ICMP_ULT:
+    return LHS.ult(RHS);
+  case ICmpInst::Predicate::ICMP_ULE:
+    return LHS.ule(RHS);
+  case ICmpInst::Predicate::ICMP_SGT:
+    return LHS.sgt(RHS);
+  case ICmpInst::Predicate::ICMP_SGE:
+    return LHS.sge(RHS);
+  case ICmpInst::Predicate::ICMP_SLT:
+    return LHS.slt(RHS);
+  case ICmpInst::Predicate::ICMP_SLE:
+    return LHS.sle(RHS);
+  default:
+    llvm_unreachable("Unexpected non-integer predicate.");
+  };
+}
+
 CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
   assert(CmpInst::isRelational(pred) &&
          "Call only with non-equality predicates!");
index 06f3991..badb118 100644 (file)
@@ -8651,31 +8651,7 @@ struct AAPotentialValuesFloating : AAPotentialValuesImpl {
 
   static bool calculateICmpInst(const ICmpInst *ICI, const APInt &LHS,
                                 const APInt &RHS) {
-    ICmpInst::Predicate Pred = ICI->getPredicate();
-    switch (Pred) {
-    case ICmpInst::ICMP_UGT:
-      return LHS.ugt(RHS);
-    case ICmpInst::ICMP_SGT:
-      return LHS.sgt(RHS);
-    case ICmpInst::ICMP_EQ:
-      return LHS.eq(RHS);
-    case ICmpInst::ICMP_UGE:
-      return LHS.uge(RHS);
-    case ICmpInst::ICMP_SGE:
-      return LHS.sge(RHS);
-    case ICmpInst::ICMP_ULT:
-      return LHS.ult(RHS);
-    case ICmpInst::ICMP_SLT:
-      return LHS.slt(RHS);
-    case ICmpInst::ICMP_NE:
-      return LHS.ne(RHS);
-    case ICmpInst::ICMP_ULE:
-      return LHS.ule(RHS);
-    case ICmpInst::ICMP_SLE:
-      return LHS.sle(RHS);
-    default:
-      llvm_unreachable("Invalid ICmp predicate!");
-    }
+    return ICmpInst::compare(LHS, RHS, ICI->getPredicate());
   }
 
   static APInt calculateCastInst(const CastInst *CI, const APInt &Src,
index 2153365..5e6bc88 100644 (file)
@@ -1557,41 +1557,15 @@ TEST(ConstantRange, MakeSatisfyingICmpRegion) {
       ConstantRange(APInt(8, 4), APInt(8, -128)));
 }
 
-static bool icmp(CmpInst::Predicate Pred, const APInt &LHS, const APInt &RHS) {
-  switch (Pred) {
-  case CmpInst::Predicate::ICMP_EQ:
-    return LHS.eq(RHS);
-  case CmpInst::Predicate::ICMP_NE:
-    return LHS.ne(RHS);
-  case CmpInst::Predicate::ICMP_UGT:
-    return LHS.ugt(RHS);
-  case CmpInst::Predicate::ICMP_UGE:
-    return LHS.uge(RHS);
-  case CmpInst::Predicate::ICMP_ULT:
-    return LHS.ult(RHS);
-  case CmpInst::Predicate::ICMP_ULE:
-    return LHS.ule(RHS);
-  case CmpInst::Predicate::ICMP_SGT:
-    return LHS.sgt(RHS);
-  case CmpInst::Predicate::ICMP_SGE:
-    return LHS.sge(RHS);
-  case CmpInst::Predicate::ICMP_SLT:
-    return LHS.slt(RHS);
-  case CmpInst::Predicate::ICMP_SLE:
-    return LHS.sle(RHS);
-  default:
-    llvm_unreachable("Not an ICmp predicate!");
-  }
-}
-
 void ICmpTestImpl(CmpInst::Predicate Pred) {
   unsigned Bits = 4;
   EnumerateTwoConstantRanges(
       Bits, [&](const ConstantRange &CR1, const ConstantRange &CR2) {
         bool Exhaustive = true;
         ForeachNumInConstantRange(CR1, [&](const APInt &N1) {
-          ForeachNumInConstantRange(
-              CR2, [&](const APInt &N2) { Exhaustive &= icmp(Pred, N1, N2); });
+          ForeachNumInConstantRange(CR2, [&](const APInt &N2) {
+            Exhaustive &= ICmpInst::compare(N1, N2, Pred);
+          });
         });
         EXPECT_EQ(CR1.icmp(Pred, CR2), Exhaustive);
       });