From 23b41986527a3fc5615480a8f7a0b0debd5fcef4 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 4 Jan 2021 12:18:59 +0000 Subject: [PATCH] [Support] Add KnownBits::icmp helpers. Check if all possible values for a pair of knownbits give the same icmp result - these are based off the checks performed in InstCombineCompares.cpp and D86578. Add exhaustive unit test coverage - a followup will update InstCombineCompares.cpp to use this. --- llvm/include/llvm/Support/KnownBits.h | 31 ++++++++++++ llvm/lib/Support/KnownBits.cpp | 69 +++++++++++++++++++++++++ llvm/unittests/Support/KnownBitsTest.cpp | 87 ++++++++++++++++++++++++++++++++ 3 files changed, 187 insertions(+) diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h index ec88b98..edb771d 100644 --- a/llvm/include/llvm/Support/KnownBits.h +++ b/llvm/include/llvm/Support/KnownBits.h @@ -15,6 +15,7 @@ #define LLVM_SUPPORT_KNOWNBITS_H #include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" namespace llvm { @@ -328,6 +329,36 @@ public: /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS. static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS); + /// Determine if these known bits always give the same ICMP_EQ result. + static Optional eq(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_NE result. + static Optional ne(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_UGT result. + static Optional ugt(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_UGE result. + static Optional uge(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_ULT result. + static Optional ult(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_ULE result. + static Optional ule(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_SGT result. + static Optional sgt(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_SGE result. + static Optional sge(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_SLT result. + static Optional slt(const KnownBits &LHS, const KnownBits &RHS); + + /// Determine if these known bits always give the same ICMP_SLE result. + static Optional sle(const KnownBits &LHS, const KnownBits &RHS); + /// Insert the bits from a smaller known bits starting at bitPosition. void insertBits(const KnownBits &SubBits, unsigned BitPosition) { Zero.insertBits(SubBits.Zero, BitPosition); diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 2c25b7d..0147d21 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -268,6 +268,75 @@ KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS) { return Known; } +Optional KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) { + if (LHS.isConstant() && RHS.isConstant()) + return Optional(LHS.getConstant() == RHS.getConstant()); + if (LHS.getMaxValue().ult(RHS.getMinValue()) || + LHS.getMinValue().ugt(RHS.getMaxValue())) + return Optional(false); + if (LHS.One.intersects(RHS.Zero) || RHS.One.intersects(LHS.Zero)) + return Optional(false); + return None; +} + +Optional KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) { + if (Optional KnownEQ = eq(LHS, RHS)) + return Optional(!KnownEQ.getValue()); + return None; +} + +Optional KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) { + if (LHS.isConstant() && RHS.isConstant()) + return Optional(LHS.getConstant().ugt(RHS.getConstant())); + // LHS >u RHS -> false if umax(LHS) <= umax(RHS) + if (LHS.getMaxValue().ule(RHS.getMinValue())) + return Optional(false); + // LHS >u RHS -> true if umin(LHS) > umax(RHS) + if (LHS.getMinValue().ugt(RHS.getMaxValue())) + return Optional(true); + return None; +} + +Optional KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) { + if (Optional IsUGT = ugt(RHS, LHS)) + return Optional(!IsUGT.getValue()); + return None; +} + +Optional KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) { + return ugt(RHS, LHS); +} + +Optional KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) { + return uge(RHS, LHS); +} + +Optional KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) { + if (LHS.isConstant() && RHS.isConstant()) + return Optional(LHS.getConstant().sgt(RHS.getConstant())); + // LHS >s RHS -> false if smax(LHS) <= smax(RHS) + if (LHS.getSignedMaxValue().sle(RHS.getSignedMinValue())) + return Optional(false); + // LHS >s RHS -> true if smin(LHS) > smax(RHS) + if (LHS.getSignedMinValue().sgt(RHS.getSignedMaxValue())) + return Optional(true); + return None; +} + +Optional KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) { + if (Optional KnownSGT = sgt(RHS, LHS)) + return Optional(!KnownSGT.getValue()); + return None; +} + +Optional KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) { + return sgt(RHS, LHS); +} + +Optional KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) { + return sge(RHS, LHS); +} + KnownBits KnownBits::abs(bool IntMinIsPoison) const { // If the source's MSB is zero then we know the rest of the bits already. if (isNonNegative()) diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index 528a645..ba587a1 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -281,6 +281,93 @@ TEST(KnownBitsTest, UnaryExhaustive) { }); } +TEST(KnownBitsTest, ICmpExhaustive) { + unsigned Bits = 4; + ForeachKnownBits(Bits, [&](const KnownBits &Known1) { + ForeachKnownBits(Bits, [&](const KnownBits &Known2) { + bool AllEQ = true, NoneEQ = true; + bool AllNE = true, NoneNE = true; + bool AllUGT = true, NoneUGT = true; + bool AllUGE = true, NoneUGE = true; + bool AllULT = true, NoneULT = true; + bool AllULE = true, NoneULE = true; + bool AllSGT = true, NoneSGT = true; + bool AllSGE = true, NoneSGE = true; + bool AllSLT = true, NoneSLT = true; + bool AllSLE = true, NoneSLE = true; + + ForeachNumInKnownBits(Known1, [&](const APInt &N1) { + ForeachNumInKnownBits(Known2, [&](const APInt &N2) { + AllEQ &= N1.eq(N2); + AllNE &= N1.ne(N2); + AllUGT &= N1.ugt(N2); + AllUGE &= N1.uge(N2); + AllULT &= N1.ult(N2); + AllULE &= N1.ule(N2); + AllSGT &= N1.sgt(N2); + AllSGE &= N1.sge(N2); + AllSLT &= N1.slt(N2); + AllSLE &= N1.sle(N2); + NoneEQ &= !N1.eq(N2); + NoneNE &= !N1.ne(N2); + NoneUGT &= !N1.ugt(N2); + NoneUGE &= !N1.uge(N2); + NoneULT &= !N1.ult(N2); + NoneULE &= !N1.ule(N2); + NoneSGT &= !N1.sgt(N2); + NoneSGE &= !N1.sge(N2); + NoneSLT &= !N1.slt(N2); + NoneSLE &= !N1.sle(N2); + }); + }); + + Optional KnownEQ = KnownBits::eq(Known1, Known2); + Optional KnownNE = KnownBits::ne(Known1, Known2); + Optional KnownUGT = KnownBits::ugt(Known1, Known2); + Optional KnownUGE = KnownBits::uge(Known1, Known2); + Optional KnownULT = KnownBits::ult(Known1, Known2); + Optional KnownULE = KnownBits::ule(Known1, Known2); + Optional KnownSGT = KnownBits::sgt(Known1, Known2); + Optional KnownSGE = KnownBits::sge(Known1, Known2); + Optional KnownSLT = KnownBits::slt(Known1, Known2); + Optional KnownSLE = KnownBits::sle(Known1, Known2); + + EXPECT_EQ(AllEQ || NoneEQ, KnownEQ.hasValue()); + EXPECT_EQ(AllNE || NoneNE, KnownNE.hasValue()); + EXPECT_EQ(AllUGT || NoneUGT, KnownUGT.hasValue()); + EXPECT_EQ(AllUGE || NoneUGE, KnownUGE.hasValue()); + EXPECT_EQ(AllULT || NoneULT, KnownULT.hasValue()); + EXPECT_EQ(AllULE || NoneULE, KnownULE.hasValue()); + EXPECT_EQ(AllSGT || NoneSGT, KnownSGT.hasValue()); + EXPECT_EQ(AllSGE || NoneSGE, KnownSGE.hasValue()); + EXPECT_EQ(AllSLT || NoneSLT, KnownSLT.hasValue()); + EXPECT_EQ(AllSLE || NoneSLE, KnownSLE.hasValue()); + + EXPECT_EQ(AllEQ, KnownEQ.hasValue() && KnownEQ.getValue()); + EXPECT_EQ(AllNE, KnownNE.hasValue() && KnownNE.getValue()); + EXPECT_EQ(AllUGT, KnownUGT.hasValue() && KnownUGT.getValue()); + EXPECT_EQ(AllUGE, KnownUGE.hasValue() && KnownUGE.getValue()); + EXPECT_EQ(AllULT, KnownULT.hasValue() && KnownULT.getValue()); + EXPECT_EQ(AllULE, KnownULE.hasValue() && KnownULE.getValue()); + EXPECT_EQ(AllSGT, KnownSGT.hasValue() && KnownSGT.getValue()); + EXPECT_EQ(AllSGE, KnownSGE.hasValue() && KnownSGE.getValue()); + EXPECT_EQ(AllSLT, KnownSLT.hasValue() && KnownSLT.getValue()); + EXPECT_EQ(AllSLE, KnownSLE.hasValue() && KnownSLE.getValue()); + + EXPECT_EQ(NoneEQ, KnownEQ.hasValue() && !KnownEQ.getValue()); + EXPECT_EQ(NoneNE, KnownNE.hasValue() && !KnownNE.getValue()); + EXPECT_EQ(NoneUGT, KnownUGT.hasValue() && !KnownUGT.getValue()); + EXPECT_EQ(NoneUGE, KnownUGE.hasValue() && !KnownUGE.getValue()); + EXPECT_EQ(NoneULT, KnownULT.hasValue() && !KnownULT.getValue()); + EXPECT_EQ(NoneULE, KnownULE.hasValue() && !KnownULE.getValue()); + EXPECT_EQ(NoneSGT, KnownSGT.hasValue() && !KnownSGT.getValue()); + EXPECT_EQ(NoneSGE, KnownSGE.hasValue() && !KnownSGE.getValue()); + EXPECT_EQ(NoneSLT, KnownSLT.hasValue() && !KnownSLT.getValue()); + EXPECT_EQ(NoneSLE, KnownSLE.hasValue() && !KnownSLE.getValue()); + }); + }); +} + TEST(KnownBitsTest, GetMinMaxVal) { unsigned Bits = 4; ForeachKnownBits(Bits, [&](const KnownBits &Known) { -- 2.7.4