From 29f8628d1fc8d96670e13562c4d92fc916bd0ce1 Mon Sep 17 00:00:00 2001 From: Juneyoung Lee Date: Tue, 5 Jan 2021 10:09:49 +0900 Subject: [PATCH] [Constant] Add containsPoisonElement This patch - Adds containsPoisonElement that checks existence of poison in constant vector elements, - Renames containsUndefElement to containsUndefOrPoisonElement to clarify its behavior & updates its uses properly With this patch, isGuaranteedNotToBeUndefOrPoison's tests w.r.t constant vectors are added because its analysis is improved. Thanks! Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D94053 --- llvm/include/llvm/IR/Constant.h | 14 +++++--- llvm/lib/Analysis/ValueTracking.cpp | 7 ++-- llvm/lib/IR/ConstantFold.cpp | 2 +- llvm/lib/IR/Constants.cpp | 25 +++++++++++---- .../Transforms/InstCombine/InstCombineCompares.cpp | 5 +-- .../Transforms/InstCombine/InstCombineNegator.cpp | 4 +-- llvm/unittests/Analysis/ValueTrackingTest.cpp | 24 ++++++++++++++ llvm/unittests/IR/ConstantsTest.cpp | 37 ++++++++++++++++++++++ 8 files changed, 98 insertions(+), 20 deletions(-) diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h index 97650c2..0190aca 100644 --- a/llvm/include/llvm/IR/Constant.h +++ b/llvm/include/llvm/IR/Constant.h @@ -101,11 +101,15 @@ public: /// lane, the constants still match. bool isElementWiseEqual(Value *Y) const; - /// Return true if this is a vector constant that includes any undefined - /// elements. Since it is impossible to inspect a scalable vector element- - /// wise at compile time, this function returns true only if the entire - /// vector is undef - bool containsUndefElement() const; + /// Return true if this is a vector constant that includes any undef or + /// poison elements. Since it is impossible to inspect a scalable vector + /// element- wise at compile time, this function returns true only if the + /// entire vector is undef or poison. + bool containsUndefOrPoisonElement() const; + + /// Return true if this is a vector constant that includes any poison + /// elements. + bool containsPoisonElement() const; /// Return true if this is a fixed width vector constant that includes /// any constant expressions. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index e15d4f0..1c75c5f 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -4895,7 +4895,8 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V, return true; if (C->getType()->isVectorTy() && !isa(C)) - return (PoisonOnly || !C->containsUndefElement()) && + return (PoisonOnly ? !C->containsPoisonElement() + : !C->containsUndefOrPoisonElement()) && !C->containsConstantExpression(); } @@ -5636,10 +5637,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, // elements because those can not be back-propagated for analysis. Value *OutputZeroVal = nullptr; if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) && - !cast(TrueVal)->containsUndefElement()) + !cast(TrueVal)->containsUndefOrPoisonElement()) OutputZeroVal = TrueVal; else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) && - !cast(FalseVal)->containsUndefElement()) + !cast(FalseVal)->containsUndefOrPoisonElement()) OutputZeroVal = FalseVal; if (OutputZeroVal) { diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 4774568..03cb108 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -811,7 +811,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond, return true; if (C->getType()->isVectorTy()) - return !C->containsUndefElement() && !C->containsConstantExpression(); + return !C->containsPoisonElement() && !C->containsConstantExpression(); // TODO: Recursively analyze aggregates or other constants. return false; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index a38302d..5aa819d 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -304,31 +304,42 @@ bool Constant::isElementWiseEqual(Value *Y) const { return isa(CmpEq) || match(CmpEq, m_One()); } -bool Constant::containsUndefElement() const { - if (auto *VTy = dyn_cast(getType())) { - if (isa(this)) +static bool +containsUndefinedElement(const Constant *C, + function_ref HasFn) { + if (auto *VTy = dyn_cast(C->getType())) { + if (HasFn(C)) return true; - if (isa(this)) + if (isa(C)) return false; - if (isa(getType())) + if (isa(C->getType())) return false; for (unsigned i = 0, e = cast(VTy)->getNumElements(); i != e; ++i) - if (isa(getAggregateElement(i))) + if (HasFn(C->getAggregateElement(i))) return true; } return false; } +bool Constant::containsUndefOrPoisonElement() const { + return containsUndefinedElement( + this, [&](const auto *C) { return isa(C); }); +} + +bool Constant::containsPoisonElement() const { + return containsUndefinedElement( + this, [&](const auto *C) { return isa(C); }); +} + bool Constant::containsConstantExpression() const { if (auto *VTy = dyn_cast(getType())) { for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) if (isa(getAggregateElement(i))) return true; } - return false; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 87d4b40..0887779 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3370,7 +3370,7 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, Type *OpTy = M->getType(); auto *VecC = dyn_cast(M); auto *OpVTy = dyn_cast(OpTy); - if (OpVTy && VecC && VecC->containsUndefElement()) { + if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) { Constant *SafeReplacementConstant = nullptr; for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { if (!isa(VecC->getAggregateElement(i))) { @@ -5259,7 +5259,8 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // It may not be safe to change a compare predicate in the presence of // undefined elements, so replace those elements with the first safe constant // that we found. - if (C->containsUndefElement()) { + // TODO: in case of poison, it is safe; let's replace undefs only. + if (C->containsUndefOrPoisonElement()) { assert(SafeReplacementConstant && "Replacement constant not set"); C = Constant::replaceUndefsWith(C, SafeReplacementConstant); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 494c58e..7718c8b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -239,8 +239,8 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { // While this is normally not behind a use-check, // let's consider division to be special since it's costly. if (auto *Op1C = dyn_cast(I->getOperand(1))) { - if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() && - Op1C->isNotOneValue()) { + if (!Op1C->containsUndefOrPoisonElement() && + Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) { Value *BO = Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C), I->getName() + ".neg"); diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp index 0d65774..d70fd6e 100644 --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -888,6 +888,30 @@ TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison) { EXPECT_EQ(isGuaranteedNotToBeUndefOrPoison(PoisonValue::get(IntegerType::get(Context, 8))), false); EXPECT_EQ(isGuaranteedNotToBePoison(UndefValue::get(IntegerType::get(Context, 8))), true); EXPECT_EQ(isGuaranteedNotToBePoison(PoisonValue::get(IntegerType::get(Context, 8))), false); + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *CP = PoisonValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + { + Constant *V1 = ConstantVector::get({C1, C2}); + EXPECT_TRUE(isGuaranteedNotToBeUndefOrPoison(V1)); + EXPECT_TRUE(isGuaranteedNotToBePoison(V1)); + } + + { + Constant *V2 = ConstantVector::get({C1, CU}); + EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V2)); + EXPECT_TRUE(isGuaranteedNotToBePoison(V2)); + } + + { + Constant *V3 = ConstantVector::get({C1, CP}); + EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V3)); + EXPECT_FALSE(isGuaranteedNotToBePoison(V3)); + } } TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison_assume) { diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp index 9dd1ba8..afae154 100644 --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/llvm/unittests/IR/ConstantsTest.cpp @@ -585,6 +585,43 @@ TEST(ConstantsTest, FoldGlobalVariablePtr) { Instruction::And, TheConstantExpr, TheConstant)->isNullValue()); } +// Check that containsUndefOrPoisonElement and containsPoisonElement is working +// great + +TEST(ConstantsTest, containsUndefElemTest) { + LLVMContext Context; + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *CP = PoisonValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + { + Constant *V1 = ConstantVector::get({C1, C2}); + EXPECT_FALSE(V1->containsUndefOrPoisonElement()); + EXPECT_FALSE(V1->containsPoisonElement()); + } + + { + Constant *V2 = ConstantVector::get({C1, CU}); + EXPECT_TRUE(V2->containsUndefOrPoisonElement()); + EXPECT_FALSE(V2->containsPoisonElement()); + } + + { + Constant *V3 = ConstantVector::get({C1, CP}); + EXPECT_TRUE(V3->containsUndefOrPoisonElement()); + EXPECT_TRUE(V3->containsPoisonElement()); + } + + { + Constant *V4 = ConstantVector::get({CU, CP}); + EXPECT_TRUE(V4->containsUndefOrPoisonElement()); + EXPECT_TRUE(V4->containsPoisonElement()); + } +} + // Check that undefined elements in vector constants are matched // correctly for both integer and floating-point types. Just don't // crash on vectors of pointers (could be handled?). -- 2.7.4