From 1206a18d417aa6d9d3d6f9e25cbf0b07ed1409a4 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 11 Apr 2022 15:44:57 -0400 Subject: [PATCH] [InstCombine] guard against splat-mul corner case The test is already simplified, and I'm not sure how to write a test to exercise the new clause. But it protects the 2-bit pattern from miscompiling as noted in D123453. https://alive2.llvm.org/ce/z/QPyVfv (If we managed to fall into the mul transform, it would wrongly create a zero on this pattern.) --- llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp | 4 ++-- llvm/test/Transforms/InstCombine/lshr.ll | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index b38ee71..26fe8e0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1165,11 +1165,11 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // Look for a "splat" mul pattern - it replicates bits across each half of // a value, so a right shift is just a mask of the low bits: - // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1 + // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1 // TODO: Generalize to allow more than just half-width shifts? const APInt *MulC; if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) && - ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && + BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && MulC->logBase2() == ShAmtC) return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll index 2db2c09..e4e869b 100644 --- a/llvm/test/Transforms/InstCombine/lshr.ll +++ b/llvm/test/Transforms/InstCombine/lshr.ll @@ -399,6 +399,17 @@ define i32 @mul_splat_fold_no_nuw(i32 %x) { ret i32 %t } +; Negative test (but simplifies before we reach the mul_splat transform)- need more than 2 bits + +define i2 @mul_splat_fold_too_narrow(i2 %x) { +; CHECK-LABEL: @mul_splat_fold_too_narrow( +; CHECK-NEXT: ret i2 [[X:%.*]] +; + %m = mul nuw i2 %x, 2 + %t = lshr i2 %m, 1 + ret i2 %t +} + define i32 @negative_and_odd(i32 %x) { ; CHECK-LABEL: @negative_and_odd( ; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 31 -- 2.7.4