[InstrSimplify,NewGVN] Add option to ignore additional instr info when simplifying.
authorFlorian Hahn <florian.hahn@arm.com>
Fri, 17 Aug 2018 14:39:04 +0000 (14:39 +0000)
committerFlorian Hahn <florian.hahn@arm.com>
Fri, 17 Aug 2018 14:39:04 +0000 (14:39 +0000)
NewGVN uses InstructionSimplify for simplifications of leaders of
congruence classes. It is not guaranteed that the metadata or other
flags/keywords (like nsw or exact) of the leader is available for all members
in a congruence class, so we cannot use it for simplification.

This patch adds a InstrInfoQuery struct with a boolean field
UseInstrInfo (which defaults to true to keep the current behavior as
default) and a set of helper methods to get metadata/keywords for a
given instruction, if UseInstrInfo is true. The whole thing might need a
better name, to avoid confusion with TargetInstrInfo but I am not sure
what a better name would be.

The current patch threads through InstrInfoQuery to the required
places, which is messier then it would need to be, if
InstructionSimplify and ValueTracking would share the same Query struct.

The reason I added it as a separate struct is that it can be shared
between InstructionSimplify and ValueTracking's query objects. Also,
some places do not need a full query object, just the InstrInfoQuery.

It also updates some interfaces that do not take a Query object, but a
set of optional parameters to take an additional boolean UseInstrInfo.

See https://bugs.llvm.org/show_bug.cgi?id=37540.

Reviewers: dberlin, davide, efriedma, sebpop, hiraditya

Reviewed By: hiraditya

Differential Revision: https://reviews.llvm.org/D47143

llvm-svn: 340031

llvm/include/llvm/Analysis/InstructionSimplify.h
llvm/include/llvm/Analysis/ValueTracking.h
llvm/lib/Analysis/InstructionSimplify.cpp
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/Transforms/Scalar/NewGVN.cpp
llvm/test/Transforms/NewGVN/pr33185.ll

index 4f896bd..6662e91 100644 (file)
@@ -32,6 +32,8 @@
 #ifndef LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H
 #define LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H
 
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/User.h"
 
 namespace llvm {
@@ -40,7 +42,6 @@ template <typename T, typename... TArgs> class AnalysisManager;
 template <class T> class ArrayRef;
 class AssumptionCache;
 class DominatorTree;
-class Instruction;
 class ImmutableCallSite;
 class DataLayout;
 class FastMathFlags;
@@ -50,6 +51,41 @@ class Pass;
 class TargetLibraryInfo;
 class Type;
 class Value;
+class MDNode;
+class BinaryOperator;
+
+/// InstrInfoQuery provides an interface to query additional information for
+/// instructions like metadata or keywords like nsw, which provides conservative
+/// results if the users specified it is safe to use.
+struct InstrInfoQuery {
+  InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {}
+  InstrInfoQuery() : UseInstrInfo(true) {}
+  bool UseInstrInfo = true;
+
+  MDNode *getMetadata(const Instruction *I, unsigned KindID) const {
+    if (UseInstrInfo)
+      return I->getMetadata(KindID);
+    return nullptr;
+  }
+
+  template <class InstT> bool hasNoUnsignedWrap(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoUnsignedWrap();
+    return false;
+  }
+
+  template <class InstT> bool hasNoSignedWrap(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->hasNoSignedWrap();
+    return false;
+  }
+
+  bool isExact(const BinaryOperator *Op) const {
+    if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
+      return cast<PossiblyExactOperator>(Op)->isExact();
+    return false;
+  }
+};
 
 struct SimplifyQuery {
   const DataLayout &DL;
@@ -58,14 +94,19 @@ struct SimplifyQuery {
   AssumptionCache *AC = nullptr;
   const Instruction *CxtI = nullptr;
 
+  // Wrapper to query additional information for instructions like metadata or
+  // keywords like nsw, which provides conservative results if those cannot
+  // be safely used.
+  const InstrInfoQuery IIQ;
+
   SimplifyQuery(const DataLayout &DL, const Instruction *CXTI = nullptr)
       : DL(DL), CxtI(CXTI) {}
 
   SimplifyQuery(const DataLayout &DL, const TargetLibraryInfo *TLI,
                 const DominatorTree *DT = nullptr,
                 AssumptionCache *AC = nullptr,
-                const Instruction *CXTI = nullptr)
-      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CXTI) {}
+                const Instruction *CXTI = nullptr, bool UseInstrInfo = true)
+      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo) {}
   SimplifyQuery getWithInstruction(Instruction *I) const {
     SimplifyQuery Copy(*this);
     Copy.CxtI = I;
index 99b47fb..57cf37f 100644 (file)
@@ -55,14 +55,16 @@ class Value;
                         AssumptionCache *AC = nullptr,
                         const Instruction *CxtI = nullptr,
                         const DominatorTree *DT = nullptr,
-                        OptimizationRemarkEmitter *ORE = nullptr);
+                        OptimizationRemarkEmitter *ORE = nullptr,
+                        bool UseInstrInfo = true);
 
   /// Returns the known bits rather than passing by reference.
   KnownBits computeKnownBits(const Value *V, const DataLayout &DL,
                              unsigned Depth = 0, AssumptionCache *AC = nullptr,
                              const Instruction *CxtI = nullptr,
                              const DominatorTree *DT = nullptr,
-                             OptimizationRemarkEmitter *ORE = nullptr);
+                             OptimizationRemarkEmitter *ORE = nullptr,
+                             bool UseInstrInfo = true);
 
   /// Compute known bits from the range metadata.
   /// \p KnownZero the set of bits that are known to be zero
@@ -75,7 +77,8 @@ class Value;
                            const DataLayout &DL,
                            AssumptionCache *AC = nullptr,
                            const Instruction *CxtI = nullptr,
-                           const DominatorTree *DT = nullptr);
+                           const DominatorTree *DT = nullptr,
+                           bool UseInstrInfo = true);
 
   /// Return true if the given value is known to have exactly one bit set when
   /// defined. For vectors return true if every element is known to be a power
@@ -86,7 +89,8 @@ class Value;
                               bool OrZero = false, unsigned Depth = 0,
                               AssumptionCache *AC = nullptr,
                               const Instruction *CxtI = nullptr,
-                              const DominatorTree *DT = nullptr);
+                              const DominatorTree *DT = nullptr,
+                              bool UseInstrInfo = true);
 
   bool isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI);
 
@@ -99,7 +103,8 @@ class Value;
   bool isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth = 0,
                       AssumptionCache *AC = nullptr,
                       const Instruction *CxtI = nullptr,
-                      const DominatorTree *DT = nullptr);
+                      const DominatorTree *DT = nullptr,
+                      bool UseInstrInfo = true);
 
   /// Return true if the two given values are negation.
   /// Currently can recoginze Value pair:
@@ -112,28 +117,32 @@ class Value;
                           unsigned Depth = 0,
                           AssumptionCache *AC = nullptr,
                           const Instruction *CxtI = nullptr,
-                          const DominatorTree *DT = nullptr);
+                          const DominatorTree *DT = nullptr,
+                          bool UseInstrInfo = true);
 
   /// Returns true if the given value is known be positive (i.e. non-negative
   /// and non-zero).
   bool isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth = 0,
                        AssumptionCache *AC = nullptr,
                        const Instruction *CxtI = nullptr,
-                       const DominatorTree *DT = nullptr);
+                       const DominatorTree *DT = nullptr,
+                       bool UseInstrInfo = true);
 
   /// Returns true if the given value is known be negative (i.e. non-positive
   /// and non-zero).
   bool isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth = 0,
                        AssumptionCache *AC = nullptr,
                        const Instruction *CxtI = nullptr,
-                       const DominatorTree *DT = nullptr);
+                       const DominatorTree *DT = nullptr,
+                       bool UseInstrInfo = true);
 
   /// Return true if the given values are known to be non-equal when defined.
   /// Supports scalar integer types only.
   bool isKnownNonEqual(const Value *V1, const Value *V2, const DataLayout &DL,
-                      AssumptionCache *AC = nullptr,
-                      const Instruction *CxtI = nullptr,
-                      const DominatorTree *DT = nullptr);
+                       AssumptionCache *AC = nullptr,
+                       const Instruction *CxtI = nullptr,
+                       const DominatorTree *DT = nullptr,
+                       bool UseInstrInfo = true);
 
   /// Return true if 'V & Mask' is known to be zero. We use this predicate to
   /// simplify operations downstream. Mask is known to be zero for bits that V
@@ -148,7 +157,8 @@ class Value;
                          const DataLayout &DL,
                          unsigned Depth = 0, AssumptionCache *AC = nullptr,
                          const Instruction *CxtI = nullptr,
-                         const DominatorTree *DT = nullptr);
+                         const DominatorTree *DT = nullptr,
+                         bool UseInstrInfo = true);
 
   /// Return the number of times the sign bit of the register is replicated into
   /// the other bits. We know that at least 1 bit is always equal to the sign
@@ -160,7 +170,8 @@ class Value;
   unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL,
                               unsigned Depth = 0, AssumptionCache *AC = nullptr,
                               const Instruction *CxtI = nullptr,
-                              const DominatorTree *DT = nullptr);
+                              const DominatorTree *DT = nullptr,
+                              bool UseInstrInfo = true);
 
   /// This function computes the integer multiple of Base that equals V. If
   /// successful, it returns true and returns the multiple in Multiple. If
@@ -406,18 +417,21 @@ class Value;
                                                const DataLayout &DL,
                                                AssumptionCache *AC,
                                                const Instruction *CxtI,
-                                               const DominatorTree *DT);
+                                               const DominatorTree *DT,
+                                               bool UseInstrInfo = true);
   OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                              const DataLayout &DL,
                                              AssumptionCache *AC,
                                              const Instruction *CxtI,
-                                             const DominatorTree *DT);
+                                             const DominatorTree *DT,
+                                             bool UseInstrInfo = true);
   OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
                                                const Value *RHS,
                                                const DataLayout &DL,
                                                AssumptionCache *AC,
                                                const Instruction *CxtI,
-                                               const DominatorTree *DT);
+                                               const DominatorTree *DT,
+                                               bool UseInstrInfo = true);
   OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
                                              const DataLayout &DL,
                                              AssumptionCache *AC = nullptr,
index dee681a..35aaf96 100644 (file)
@@ -861,8 +861,10 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
 
   // (X / Y) * Y -> X if the division is exact.
   Value *X = nullptr;
-  if (match(Op0, m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) || // (X / Y) * Y
-      match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))   // Y * (X / Y)
+  if (Q.IIQ.UseInstrInfo &&
+      (match(Op0,
+             m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) ||     // (X / Y) * Y
+       match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))) // Y * (X / Y)
     return X;
 
   // i1 mul -> and.
@@ -1035,8 +1037,8 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
   if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
     auto *Mul = cast<OverflowingBinaryOperator>(Op0);
     // If the Mul does not overflow, then we are good to go.
-    if ((IsSigned && Mul->hasNoSignedWrap()) ||
-        (!IsSigned && Mul->hasNoUnsignedWrap()))
+    if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
+        (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)))
       return X;
     // If X has the form X = A / Y, then X * Y cannot overflow.
     if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
@@ -1094,10 +1096,11 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
     return Op0;
 
   // (X << Y) % X -> 0
-  if ((Opcode == Instruction::SRem &&
-       match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) ||
-      (Opcode == Instruction::URem &&
-       match(Op0, m_NUWShl(m_Specific(Op1), m_Value()))))
+  if (Q.IIQ.UseInstrInfo &&
+      ((Opcode == Instruction::SRem &&
+        match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) ||
+       (Opcode == Instruction::URem &&
+        match(Op0, m_NUWShl(m_Specific(Op1), m_Value())))))
     return Constant::getNullValue(Op0->getType());
 
   // If the operation is with the result of a select instruction, check whether
@@ -1295,7 +1298,8 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
 
   // (X >> A) << A -> X
   Value *X;
-  if (match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
+  if (Q.IIQ.UseInstrInfo &&
+      match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))
     return X;
 
   // shl nuw i8 C, %x  ->  C  iff C has sign bit set.
@@ -1365,7 +1369,7 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
 
   // (X << A) >> A -> X
   Value *X;
-  if (match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
   // Arithmetic shifting an all-sign-bit value is a no-op.
@@ -1552,7 +1556,8 @@ static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1,
   return nullptr;
 }
 
-static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
+static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
+                                        const InstrInfoQuery &IIQ) {
   // (icmp (add V, C0), C1) & (icmp V, C0)
   ICmpInst::Predicate Pred0, Pred1;
   const APInt *C0, *C1;
@@ -1563,13 +1568,13 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
   if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value())))
     return nullptr;
 
-  auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0));
+  auto *AddInst = cast<OverflowingBinaryOperator>(Op0->getOperand(0));
   if (AddInst->getOperand(1) != Op1->getOperand(1))
     return nullptr;
 
   Type *ITy = Op0->getType();
-  bool isNSW = AddInst->hasNoSignedWrap();
-  bool isNUW = AddInst->hasNoUnsignedWrap();
+  bool isNSW = IIQ.hasNoSignedWrap(AddInst);
+  bool isNUW = IIQ.hasNoUnsignedWrap(AddInst);
 
   const APInt Delta = *C1 - *C0;
   if (C0->isStrictlyPositive()) {
@@ -1598,7 +1603,8 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
   return nullptr;
 }
 
-static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
+static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
+                                 const InstrInfoQuery &IIQ) {
   if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true))
     return X;
   if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true))
@@ -1615,15 +1621,16 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true))
     return X;
 
-  if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1))
+  if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1, IIQ))
     return X;
-  if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0))
+  if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0, IIQ))
     return X;
 
   return nullptr;
 }
 
-static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
+static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
+                                       const InstrInfoQuery &IIQ) {
   // (icmp (add V, C0), C1) | (icmp V, C0)
   ICmpInst::Predicate Pred0, Pred1;
   const APInt *C0, *C1;
@@ -1639,8 +1646,8 @@ static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
     return nullptr;
 
   Type *ITy = Op0->getType();
-  bool isNSW = AddInst->hasNoSignedWrap();
-  bool isNUW = AddInst->hasNoUnsignedWrap();
+  bool isNSW = IIQ.hasNoSignedWrap(AddInst);
+  bool isNUW = IIQ.hasNoUnsignedWrap(AddInst);
 
   const APInt Delta = *C1 - *C0;
   if (C0->isStrictlyPositive()) {
@@ -1669,7 +1676,8 @@ static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {
   return nullptr;
 }
 
-static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
+static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1,
+                                const InstrInfoQuery &IIQ) {
   if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false))
     return X;
   if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false))
@@ -1686,9 +1694,9 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {
   if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false))
     return X;
 
-  if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1))
+  if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1, IIQ))
     return X;
-  if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0))
+  if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0, IIQ))
     return X;
 
   return nullptr;
@@ -1732,7 +1740,7 @@ static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI,
   return nullptr;
 }
 
-static Value *simplifyAndOrOfCmps(const TargetLibraryInfo *TLI,
+static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q,
                                   Value *Op0, Value *Op1, bool IsAnd) {
   // Look through casts of the 'and' operands to find compares.
   auto *Cast0 = dyn_cast<CastInst>(Op0);
@@ -1747,13 +1755,13 @@ static Value *simplifyAndOrOfCmps(const TargetLibraryInfo *TLI,
   auto *ICmp0 = dyn_cast<ICmpInst>(Op0);
   auto *ICmp1 = dyn_cast<ICmpInst>(Op1);
   if (ICmp0 && ICmp1)
-    V = IsAnd ? simplifyAndOfICmps(ICmp0, ICmp1) :
-                simplifyOrOfICmps(ICmp0, ICmp1);
+    V = IsAnd ? simplifyAndOfICmps(ICmp0, ICmp1, Q.IIQ)
+              : simplifyOrOfICmps(ICmp0, ICmp1, Q.IIQ);
 
   auto *FCmp0 = dyn_cast<FCmpInst>(Op0);
   auto *FCmp1 = dyn_cast<FCmpInst>(Op1);
   if (FCmp0 && FCmp1)
-    V = simplifyAndOrOfFCmps(TLI, FCmp0, FCmp1, IsAnd);
+    V = simplifyAndOrOfFCmps(Q.TLI, FCmp0, FCmp1, IsAnd);
 
   if (!V)
     return nullptr;
@@ -1833,7 +1841,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
       return Op1;
   }
 
-  if (Value *V = simplifyAndOrOfCmps(Q.TLI, Op0, Op1, true))
+  if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true))
     return V;
 
   // Try some generic simplifications for associative operations.
@@ -1983,7 +1991,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
        match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))
     return Op0;
 
-  if (Value *V = simplifyAndOrOfCmps(Q.TLI, Op0, Op1, false))
+  if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false))
     return V;
 
   // Try some generic simplifications for associative operations.
@@ -2144,13 +2152,15 @@ static Constant *
 computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,
                    const DominatorTree *DT, CmpInst::Predicate Pred,
                    AssumptionCache *AC, const Instruction *CxtI,
-                   Value *LHS, Value *RHS) {
+                   const InstrInfoQuery &IIQ, Value *LHS, Value *RHS) {
   // First, skip past any trivial no-ops.
   LHS = LHS->stripPointerCasts();
   RHS = RHS->stripPointerCasts();
 
   // A non-null pointer is not equal to a null pointer.
-  if (llvm::isKnownNonZero(LHS, DL) && isa<ConstantPointerNull>(RHS) &&
+  if (llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr,
+                           IIQ.UseInstrInfo) &&
+      isa<ConstantPointerNull>(RHS) &&
       (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE))
     return ConstantInt::get(GetCompareTy(LHS),
                             !CmpInst::isTrueWhenEqual(Pred));
@@ -2415,12 +2425,12 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
     return getTrue(ITy);
   case ICmpInst::ICMP_EQ:
   case ICmpInst::ICMP_ULE:
-    if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+    if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo))
       return getFalse(ITy);
     break;
   case ICmpInst::ICMP_NE:
   case ICmpInst::ICMP_UGT:
-    if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+    if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo))
       return getTrue(ITy);
     break;
   case ICmpInst::ICMP_SLT: {
@@ -2465,17 +2475,18 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
 /// Many binary operators with a constant operand have an easy-to-compute
 /// range of outputs. This can be used to fold a comparison to always true or
 /// always false.
-static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
+static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper,
+                              const InstrInfoQuery &IIQ) {
   unsigned Width = Lower.getBitWidth();
   const APInt *C;
   switch (BO.getOpcode()) {
   case Instruction::Add:
     if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) {
       // FIXME: If we have both nuw and nsw, we should reduce the range further.
-      if (BO.hasNoUnsignedWrap()) {
+      if (IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(&BO))) {
         // 'add nuw x, C' produces [C, UINT_MAX].
         Lower = *C;
-      } else if (BO.hasNoSignedWrap()) {
+      } else if (IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(&BO))) {
         if (C->isNegative()) {
           // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
           Lower = APInt::getSignedMinValue(Width);
@@ -2508,7 +2519,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
       Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1;
     } else if (match(BO.getOperand(0), m_APInt(C))) {
       unsigned ShiftAmount = Width - 1;
-      if (!C->isNullValue() && BO.isExact())
+      if (!C->isNullValue() && IIQ.isExact(&BO))
         ShiftAmount = C->countTrailingZeros();
       if (C->isNegative()) {
         // 'ashr C, x' produces [C, C >> (Width-1)]
@@ -2529,7 +2540,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
     } else if (match(BO.getOperand(0), m_APInt(C))) {
       // 'lshr C, x' produces [C >> (Width-1), C].
       unsigned ShiftAmount = Width - 1;
-      if (!C->isNullValue() && BO.isExact())
+      if (!C->isNullValue() && IIQ.isExact(&BO))
         ShiftAmount = C->countTrailingZeros();
       Lower = C->lshr(ShiftAmount);
       Upper = *C + 1;
@@ -2538,7 +2549,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
 
   case Instruction::Shl:
     if (match(BO.getOperand(0), m_APInt(C))) {
-      if (BO.hasNoUnsignedWrap()) {
+      if (IIQ.hasNoUnsignedWrap(&BO)) {
         // 'shl nuw C, x' produces [C, C << CLZ(C)]
         Lower = *C;
         Upper = Lower.shl(Lower.countLeadingZeros()) + 1;
@@ -2620,7 +2631,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {
 }
 
 static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
-                                       Value *RHS) {
+                                       Value *RHS, const InstrInfoQuery &IIQ) {
   Type *ITy = GetCompareTy(RHS); // The return type.
 
   Value *X;
@@ -2651,13 +2662,13 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   APInt Lower = APInt(Width, 0);
   APInt Upper = APInt(Width, 0);
   if (auto *BO = dyn_cast<BinaryOperator>(LHS))
-    setLimitsForBinOp(*BO, Lower, Upper);
+    setLimitsForBinOp(*BO, Lower, Upper, IIQ);
 
   ConstantRange LHS_CR =
       Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true);
 
   if (auto *I = dyn_cast<Instruction>(LHS))
-    if (auto *Ranges = I->getMetadata(LLVMContext::MD_range))
+    if (auto *Ranges = IIQ.getMetadata(I, LLVMContext::MD_range))
       LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges));
 
   if (!LHS_CR.isFullSet()) {
@@ -2690,16 +2701,20 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
       B = LBO->getOperand(1);
       NoLHSWrapProblem =
           ICmpInst::isEquality(Pred) ||
-          (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) ||
-          (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap());
+          (CmpInst::isUnsigned(Pred) &&
+           Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO))) ||
+          (CmpInst::isSigned(Pred) &&
+           Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)));
     }
     if (RBO && RBO->getOpcode() == Instruction::Add) {
       C = RBO->getOperand(0);
       D = RBO->getOperand(1);
       NoRHSWrapProblem =
           ICmpInst::isEquality(Pred) ||
-          (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) ||
-          (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap());
+          (CmpInst::isUnsigned(Pred) &&
+           Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(RBO))) ||
+          (CmpInst::isSigned(Pred) &&
+           Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(RBO)));
     }
 
     // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
@@ -2917,7 +2932,8 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
         // - The shift is nuw, we can't shift out the one bit.
         // - CI2 is one
         // - CI isn't zero
-        if (LBO->hasNoSignedWrap() || LBO->hasNoUnsignedWrap() ||
+        if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) ||
+            Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) ||
             CI2Val->isOneValue() || !CI->isZero()) {
           if (Pred == ICmpInst::ICMP_EQ)
             return ConstantInt::getFalse(RHS->getContext());
@@ -2941,29 +2957,31 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
       break;
     case Instruction::UDiv:
     case Instruction::LShr:
-      if (ICmpInst::isSigned(Pred) || !LBO->isExact() || !RBO->isExact())
+      if (ICmpInst::isSigned(Pred) || !Q.IIQ.isExact(LBO) ||
+          !Q.IIQ.isExact(RBO))
         break;
       if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
                                       RBO->getOperand(0), Q, MaxRecurse - 1))
           return V;
       break;
     case Instruction::SDiv:
-      if (!ICmpInst::isEquality(Pred) || !LBO->isExact() || !RBO->isExact())
+      if (!ICmpInst::isEquality(Pred) || !Q.IIQ.isExact(LBO) ||
+          !Q.IIQ.isExact(RBO))
         break;
       if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
                                       RBO->getOperand(0), Q, MaxRecurse - 1))
         return V;
       break;
     case Instruction::AShr:
-      if (!LBO->isExact() || !RBO->isExact())
+      if (!Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO))
         break;
       if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0),
                                       RBO->getOperand(0), Q, MaxRecurse - 1))
         return V;
       break;
     case Instruction::Shl: {
-      bool NUW = LBO->hasNoUnsignedWrap() && RBO->hasNoUnsignedWrap();
-      bool NSW = LBO->hasNoSignedWrap() && RBO->hasNoSignedWrap();
+      bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO);
+      bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO);
       if (!NUW && !NSW)
         break;
       if (!NSW && ICmpInst::isSigned(Pred))
@@ -3211,7 +3229,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q))
     return V;
 
-  if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS))
+  if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS, Q.IIQ))
     return V;
 
   // If both operands have range metadata, use the metadata
@@ -3220,8 +3238,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
     auto RHS_Instr = cast<Instruction>(RHS);
     auto LHS_Instr = cast<Instruction>(LHS);
 
-    if (RHS_Instr->getMetadata(LLVMContext::MD_range) &&
-        LHS_Instr->getMetadata(LLVMContext::MD_range)) {
+    if (Q.IIQ.getMetadata(RHS_Instr, LLVMContext::MD_range) &&
+        Q.IIQ.getMetadata(LHS_Instr, LLVMContext::MD_range)) {
       auto RHS_CR = getConstantRangeFromMetadata(
           *RHS_Instr->getMetadata(LLVMContext::MD_range));
       auto LHS_CR = getConstantRangeFromMetadata(
@@ -3399,7 +3417,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
 
   // icmp eq|ne X, Y -> false|true if X != Y
   if (ICmpInst::isEquality(Pred) &&
-      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) {
+      isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) {
     return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy);
   }
 
@@ -3412,8 +3430,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   // Simplify comparisons of related pointers using a powerful, recursive
   // GEP-walk when we have target data available..
   if (LHS->getType()->isPointerTy())
-    if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, LHS,
-                                     RHS))
+    if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
+                                     Q.IIQ, LHS, RHS))
       return C;
   if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS))
     if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS))
@@ -3422,7 +3440,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
           Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) ==
               Q.DL.getTypeSizeInBits(CRHS->getType()))
         if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI,
-                                         CLHS->getPointerOperand(),
+                                         Q.IIQ, CLHS->getPointerOperand(),
                                          CRHS->getPointerOperand()))
           return C;
 
@@ -3636,11 +3654,10 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
     //
     // We can't replace %sel with %add unless we strip away the flags.
     if (isa<OverflowingBinaryOperator>(B))
-      if (B->hasNoSignedWrap() || B->hasNoUnsignedWrap())
-        return nullptr;
-    if (isa<PossiblyExactOperator>(B))
-      if (B->isExact())
+      if (Q.IIQ.hasNoSignedWrap(B) || Q.IIQ.hasNoUnsignedWrap(B))
         return nullptr;
+    if (isa<PossiblyExactOperator>(B) && Q.IIQ.isExact(B))
+      return nullptr;
 
     if (MaxRecurse) {
       if (B->getOperand(0) == Op)
@@ -4970,18 +4987,20 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
                               I->getFastMathFlags(), Q);
     break;
   case Instruction::Add:
-    Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1),
-                             cast<BinaryOperator>(I)->hasNoSignedWrap(),
-                             cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q);
+    Result =
+        SimplifyAddInst(I->getOperand(0), I->getOperand(1),
+                        Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+                        Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
     break;
   case Instruction::FSub:
     Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1),
                               I->getFastMathFlags(), Q);
     break;
   case Instruction::Sub:
-    Result = SimplifySubInst(I->getOperand(0), I->getOperand(1),
-                             cast<BinaryOperator>(I)->hasNoSignedWrap(),
-                             cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q);
+    Result =
+        SimplifySubInst(I->getOperand(0), I->getOperand(1),
+                        Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+                        Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
     break;
   case Instruction::FMul:
     Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1),
@@ -5011,17 +5030,18 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
                               I->getFastMathFlags(), Q);
     break;
   case Instruction::Shl:
-    Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1),
-                             cast<BinaryOperator>(I)->hasNoSignedWrap(),
-                             cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q);
+    Result =
+        SimplifyShlInst(I->getOperand(0), I->getOperand(1),
+                        Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+                        Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
     break;
   case Instruction::LShr:
     Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1),
-                              cast<BinaryOperator>(I)->isExact(), Q);
+                              Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
     break;
   case Instruction::AShr:
     Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1),
-                              cast<BinaryOperator>(I)->isExact(), Q);
+                              Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
     break;
   case Instruction::And:
     Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q);
index 2484cec..a89c877 100644 (file)
@@ -118,14 +118,18 @@ struct Query {
   /// (all of which can call computeKnownBits), and so on.
   std::array<const Value *, MaxDepth> Excluded;
 
+  /// If true, it is safe to use metadata during simplification.
+  InstrInfoQuery IIQ;
+
   unsigned NumExcluded = 0;
 
   Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
-        const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr)
-      : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE) {}
+        const DominatorTree *DT, bool UseInstrInfo,
+        OptimizationRemarkEmitter *ORE = nullptr)
+      : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
 
   Query(const Query &Q, const Value *NewExcl)
-      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE),
+      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
         NumExcluded(Q.NumExcluded) {
     Excluded = Q.Excluded;
     Excluded[NumExcluded++] = NewExcl;
@@ -165,9 +169,9 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                             const DataLayout &DL, unsigned Depth,
                             AssumptionCache *AC, const Instruction *CxtI,
                             const DominatorTree *DT,
-                            OptimizationRemarkEmitter *ORE) {
+                            OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
   ::computeKnownBits(V, Known, Depth,
-                     Query(DL, AC, safeCxtI(V, CxtI), DT, ORE));
+                     Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
 }
 
 static KnownBits computeKnownBits(const Value *V, unsigned Depth,
@@ -177,15 +181,16 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
                                  unsigned Depth, AssumptionCache *AC,
                                  const Instruction *CxtI,
                                  const DominatorTree *DT,
-                                 OptimizationRemarkEmitter *ORE) {
-  return ::computeKnownBits(V, Depth,
-                            Query(DL, AC, safeCxtI(V, CxtI), DT, ORE));
+                                 OptimizationRemarkEmitter *ORE,
+                                 bool UseInstrInfo) {
+  return ::computeKnownBits(
+      V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
 }
 
 bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
-                               const DataLayout &DL,
-                               AssumptionCache *AC, const Instruction *CxtI,
-                               const DominatorTree *DT) {
+                               const DataLayout &DL, AssumptionCache *AC,
+                               const Instruction *CxtI, const DominatorTree *DT,
+                               bool UseInstrInfo) {
   assert(LHS->getType() == RHS->getType() &&
          "LHS and RHS should have the same type");
   assert(LHS->getType()->isIntOrIntVectorTy() &&
@@ -201,8 +206,8 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
   IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
   KnownBits LHSKnown(IT->getBitWidth());
   KnownBits RHSKnown(IT->getBitWidth());
-  computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT);
-  computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT);
+  computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
+  computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo);
   return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue();
 }
 
@@ -222,69 +227,71 @@ static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
                                    const Query &Q);
 
 bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
-                                  bool OrZero,
-                                  unsigned Depth, AssumptionCache *AC,
-                                  const Instruction *CxtI,
-                                  const DominatorTree *DT) {
-  return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth,
-                                  Query(DL, AC, safeCxtI(V, CxtI), DT));
+                                  bool OrZero, unsigned Depth,
+                                  AssumptionCache *AC, const Instruction *CxtI,
+                                  const DominatorTree *DT, bool UseInstrInfo) {
+  return ::isKnownToBeAPowerOfTwo(
+      V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
 }
 
 static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q);
 
 bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth,
                           AssumptionCache *AC, const Instruction *CxtI,
-                          const DominatorTree *DT) {
-  return ::isKnownNonZero(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT));
+                          const DominatorTree *DT, bool UseInstrInfo) {
+  return ::isKnownNonZero(V, Depth,
+                          Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
 }
 
 bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL,
-                              unsigned Depth,
-                              AssumptionCache *AC, const Instruction *CxtI,
-                              const DominatorTree *DT) {
-  KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT);
+                              unsigned Depth, AssumptionCache *AC,
+                              const Instruction *CxtI, const DominatorTree *DT,
+                              bool UseInstrInfo) {
+  KnownBits Known =
+      computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
   return Known.isNonNegative();
 }
 
 bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth,
                            AssumptionCache *AC, const Instruction *CxtI,
-                           const DominatorTree *DT) {
+                           const DominatorTree *DT, bool UseInstrInfo) {
   if (auto *CI = dyn_cast<ConstantInt>(V))
     return CI->getValue().isStrictlyPositive();
 
   // TODO: We'd doing two recursive queries here.  We should factor this such
   // that only a single query is needed.
-  return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT) &&
-    isKnownNonZero(V, DL, Depth, AC, CxtI, DT);
+  return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT, UseInstrInfo) &&
+         isKnownNonZero(V, DL, Depth, AC, CxtI, DT, UseInstrInfo);
 }
 
 bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth,
                            AssumptionCache *AC, const Instruction *CxtI,
-                           const DominatorTree *DT) {
-  KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT);
+                           const DominatorTree *DT, bool UseInstrInfo) {
+  KnownBits Known =
+      computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
   return Known.isNegative();
 }
 
 static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q);
 
 bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
-                           const DataLayout &DL,
-                           AssumptionCache *AC, const Instruction *CxtI,
-                           const DominatorTree *DT) {
-  return ::isKnownNonEqual(V1, V2, Query(DL, AC,
-                                         safeCxtI(V1, safeCxtI(V2, CxtI)),
-                                         DT));
+                           const DataLayout &DL, AssumptionCache *AC,
+                           const Instruction *CxtI, const DominatorTree *DT,
+                           bool UseInstrInfo) {
+  return ::isKnownNonEqual(V1, V2,
+                           Query(DL, AC, safeCxtI(V1, safeCxtI(V2, CxtI)), DT,
+                                 UseInstrInfo, /*ORE=*/nullptr));
 }
 
 static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth,
                               const Query &Q);
 
 bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
-                             const DataLayout &DL,
-                             unsigned Depth, AssumptionCache *AC,
-                             const Instruction *CxtI, const DominatorTree *DT) {
-  return ::MaskedValueIsZero(V, Mask, Depth,
-                             Query(DL, AC, safeCxtI(V, CxtI), DT));
+                             const DataLayout &DL, unsigned Depth,
+                             AssumptionCache *AC, const Instruction *CxtI,
+                             const DominatorTree *DT, bool UseInstrInfo) {
+  return ::MaskedValueIsZero(
+      V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
 }
 
 static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
@@ -293,8 +300,9 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
 unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
                                   unsigned Depth, AssumptionCache *AC,
                                   const Instruction *CxtI,
-                                  const DominatorTree *DT) {
-  return ::ComputeNumSignBits(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT));
+                                  const DominatorTree *DT, bool UseInstrInfo) {
+  return ::ComputeNumSignBits(
+      V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
 }
 
 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
@@ -965,7 +973,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
   switch (I->getOpcode()) {
   default: break;
   case Instruction::Load:
-    if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range))
+    if (MDNode *MD =
+            Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range))
       computeKnownBitsFromRangeMetadata(*MD, Known);
     break;
   case Instruction::And: {
@@ -1014,7 +1023,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     break;
   }
   case Instruction::Mul: {
-    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+    bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
     computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known,
                         Known2, Depth, Q);
     break;
@@ -1082,7 +1091,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       // RHS from matchSelectPattern returns the negation part of abs pattern.
       // If the negate has an NSW flag we can assume the sign bit of the result
       // will be 0 because that makes abs(INT_MIN) undefined.
-      if (cast<Instruction>(RHS)->hasNoSignedWrap())
+      if (Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS)))
         MaxHighZeros = 1;
     }
 
@@ -1151,7 +1160,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
   }
   case Instruction::Shl: {
     // (shl X, C1) & C2 == 0   iff   (X & C2 >>u C1) == 0
-    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+    bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
     auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) {
       APInt KZResult = KnownZero << ShiftAmt;
       KZResult.setLowBits(ShiftAmt); // Low bits known 0.
@@ -1202,13 +1211,13 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     break;
   }
   case Instruction::Sub: {
-    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+    bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
     computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
                            Known, Known2, Depth, Q);
     break;
   }
   case Instruction::Add: {
-    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+    bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
     computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
                            Known, Known2, Depth, Q);
     break;
@@ -1369,7 +1378,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
                                          Known3.countMinTrailingZeros()));
 
           auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU);
-          if (OverflowOp && OverflowOp->hasNoSignedWrap()) {
+          if (OverflowOp && Q.IIQ.hasNoSignedWrap(OverflowOp)) {
             // If initial value of recurrence is nonnegative, and we are adding
             // a nonnegative number with nsw, the result can only be nonnegative
             // or poison value regardless of the number of times we execute the
@@ -1442,7 +1451,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     // If range metadata is attached to this call, set known bits from that,
     // and then intersect with known bits based on other properties of the
     // function.
-    if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range))
+    if (MDNode *MD =
+            Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
       computeKnownBitsFromRangeMetadata(*MD, Known);
     if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) {
       computeKnownBits(RV, Known2, Depth + 1, Q);
@@ -1722,7 +1732,8 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
   // either the original power-of-two, a larger power-of-two or zero.
   if (match(V, m_Add(m_Value(X), m_Value(Y)))) {
     const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V);
-    if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) {
+    if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) ||
+        Q.IIQ.hasNoSignedWrap(VOBO)) {
       if (match(X, m_And(m_Specific(Y), m_Value())) ||
           match(X, m_And(m_Value(), m_Specific(Y))))
         if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q))
@@ -1960,7 +1971,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
   }
 
   if (auto *I = dyn_cast<Instruction>(V)) {
-    if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) {
+    if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) {
       // If the possible ranges don't contain zero, then the value is
       // definitely non-zero.
       if (auto *Ty = dyn_cast<IntegerType>(V->getType())) {
@@ -1988,7 +1999,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
 
     // A Load tagged with nonnull metadata is never null.
     if (const LoadInst *LI = dyn_cast<LoadInst>(V))
-      if (LI->getMetadata(LLVMContext::MD_nonnull))
+      if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull))
         return true;
 
     if (auto CS = ImmutableCallSite(V)) {
@@ -2026,7 +2037,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
   if (match(V, m_Shl(m_Value(X), m_Value(Y)))) {
     // shl nuw can't remove any non-zero bits.
     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
-    if (BO->hasNoUnsignedWrap())
+    if (Q.IIQ.hasNoUnsignedWrap(BO))
       return isKnownNonZero(X, Depth, Q);
 
     KnownBits Known(BitWidth);
@@ -2101,7 +2112,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
     const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V);
     // If X and Y are non-zero then so is X * Y as long as the multiplication
     // does not overflow.
-    if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) &&
+    if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) &&
         isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q))
       return true;
   }
@@ -2123,7 +2134,8 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
       if (ConstantInt *C = dyn_cast<ConstantInt>(Start)) {
         if (!C->isZero() && !C->isNegative()) {
           ConstantInt *X;
-          if ((match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) ||
+          if (Q.IIQ.UseInstrInfo &&
+              (match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) ||
                match(Induction, m_NUWAdd(m_Specific(PN), m_ConstantInt(X)))) &&
               !X->isNegative())
             return true;
@@ -3745,12 +3757,10 @@ bool llvm::mayBeMemoryDependent(const Instruction &I) {
   return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I);
 }
 
-OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
-                                                   const Value *RHS,
-                                                   const DataLayout &DL,
-                                                   AssumptionCache *AC,
-                                                   const Instruction *CxtI,
-                                                   const DominatorTree *DT) {
+OverflowResult llvm::computeOverflowForUnsignedMul(
+    const Value *LHS, const Value *RHS, const DataLayout &DL,
+    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+    bool UseInstrInfo) {
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -3760,8 +3770,10 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
   unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
   KnownBits LHSKnown(BitWidth);
   KnownBits RHSKnown(BitWidth);
-  computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT);
-  computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT);
+  computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT, nullptr,
+                   UseInstrInfo);
+  computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT, nullptr,
+                   UseInstrInfo);
   // Note that underestimating the number of zero bits gives a more
   // conservative answer.
   unsigned ZeroBits = LHSKnown.countMinLeadingZeros() +
@@ -3792,12 +3804,11 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
   return OverflowResult::MayOverflow;
 }
 
-OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
-                                                 const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+OverflowResult
+llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
+                                  const DataLayout &DL, AssumptionCache *AC,
+                                  const Instruction *CxtI,
+                                  const DominatorTree *DT, bool UseInstrInfo) {
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -3826,23 +3837,25 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
     // product is exactly the minimum negative number.
     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
     // For simplicity we just check if at least one side is not negative.
-    KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT);
-    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT);
+    KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
+                                          nullptr, UseInstrInfo);
+    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
+                                          nullptr, UseInstrInfo);
     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
       return OverflowResult::NeverOverflows;
   }
   return OverflowResult::MayOverflow;
 }
 
-OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
-                                                   const Value *RHS,
-                                                   const DataLayout &DL,
-                                                   AssumptionCache *AC,
-                                                   const Instruction *CxtI,
-                                                   const DominatorTree *DT) {
-  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT);
+OverflowResult llvm::computeOverflowForUnsignedAdd(
+    const Value *LHS, const Value *RHS, const DataLayout &DL,
+    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
+    bool UseInstrInfo) {
+  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
+                                        nullptr, UseInstrInfo);
   if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) {
-    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT);
+    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
+                                          nullptr, UseInstrInfo);
 
     if (LHSKnown.isNegative() && RHSKnown.isNegative()) {
       // The sign bit is set in both cases: this MUST overflow.
index bfb7d3d..5c2da6c 100644 (file)
@@ -657,8 +657,8 @@ public:
          TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA,
          const DataLayout &DL)
       : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL),
-        PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC) {
-  }
+        PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)),
+        SQ(DL, TLI, DT, AC, /*CtxI=*/nullptr, /*UseInstrInfo=*/false) {}
 
   bool runGVN();
 
index efec237..f20871f 100644 (file)
@@ -79,7 +79,8 @@ define i32 @test2() local_unnamed_addr {
 ; CHECK:       cond.end.i:
 ; CHECK-NEXT:    [[COND_I:%.*]] = phi i32 [ [[DIV_I]], [[COND_TRUE_I]] ], [ 0, [[FOR_BODY_I]] ]
 ; CHECK-NEXT:    [[INC_I]] = add nuw nsw i32 [[F_08_I]], 1
-; CHECK-NEXT:    [[CALL5:%.*]] = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @.str4, i64 0, i64 0), i32 0)
+; CHECK-NEXT:    [[TEST:%.*]] = udiv i32 [[MUL_I]], [[INC_I]]
+; CHECK-NEXT:    [[CALL5:%.*]] = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @.str4, i64 0, i64 0), i32 [[TEST]])
 ; CHECK-NEXT:    [[EXITCOND_I:%.*]] = icmp eq i32 [[INC_I]], 4
 ; CHECK-NEXT:    br i1 [[EXITCOND_I]], label [[FN1_EXIT:%.*]], label [[FOR_BODY_I]]
 ; CHECK:       fn1.exit: