From 492a4a428f77556d413c2179fad9ba4ae6d130b9 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Sat, 16 Oct 2021 20:25:18 +0200 Subject: [PATCH] [APInt] Fix 1-bit edge case in smul_ov() The sdiv used to check for overflow can itself overflow if the LHS is signed min and the RHS is -1. The code tried to account for this by also checking the commuted version. However, for 1-bit values, signed min and -1 are the same value, so both divisions overflow. As such, the overflow for -1 * -1 was not detected (which results in -1 rather than 1 for 1-bit values). Fix this by explicitly checking for this case instead. Noticed while adding exhaustive test coverage for smul_ov(), which is also part of this commit. --- llvm/lib/Support/APInt.cpp | 5 +++-- llvm/unittests/ADT/APIntTest.cpp | 21 ++++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp index 15b099a..52a2099 100644 --- a/llvm/lib/Support/APInt.cpp +++ b/llvm/lib/Support/APInt.cpp @@ -1952,8 +1952,9 @@ APInt APInt::sdiv_ov(const APInt &RHS, bool &Overflow) const { APInt APInt::smul_ov(const APInt &RHS, bool &Overflow) const { APInt Res = *this * RHS; - if (*this != 0 && RHS != 0) - Overflow = Res.sdiv(RHS) != *this || Res.sdiv(*this) != RHS; + if (RHS != 0) + Overflow = Res.sdiv(RHS) != *this || + (isMinSignedValue() && RHS.isAllOnes()); else Overflow = false; return Res; diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 83bbc5e..58c17b4 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -2729,9 +2729,24 @@ TEST(APIntTest, umul_ov) { for (unsigned Bits = 1; Bits <= 5; ++Bits) for (unsigned A = 0; A != 1u << Bits; ++A) for (unsigned B = 0; B != 1u << Bits; ++B) { - APInt C = APInt(Bits, A).umul_ov(APInt(Bits, B), Overflow); - APInt D = APInt(2 * Bits, A) * APInt(2 * Bits, B); - EXPECT_TRUE(D.getHiBits(Bits).isNullValue() != Overflow); + APInt N1 = APInt(Bits, A), N2 = APInt(Bits, B); + APInt Narrow = N1.umul_ov(N2, Overflow); + APInt Wide = N1.zext(2 * Bits) * N2.zext(2 * Bits); + EXPECT_EQ(Wide.trunc(Bits), Narrow); + EXPECT_EQ(Narrow.zext(2 * Bits) != Wide, Overflow); + } +} + +TEST(APIntTest, smul_ov) { + for (unsigned Bits = 1; Bits <= 5; ++Bits) + for (unsigned A = 0; A != 1u << Bits; ++A) + for (unsigned B = 0; B != 1u << Bits; ++B) { + bool Overflow; + APInt N1 = APInt(Bits, A), N2 = APInt(Bits, B); + APInt Narrow = N1.smul_ov(N2, Overflow); + APInt Wide = N1.sext(2 * Bits) * N2.sext(2 * Bits); + EXPECT_EQ(Wide.trunc(Bits), Narrow); + EXPECT_EQ(Narrow.sext(2 * Bits) != Wide, Overflow); } } -- 2.7.4