From a7a2860d0eee37d9e0fd0b6a8e3d884f8ee4ec16 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 20 Jan 2022 14:51:45 -0500 Subject: [PATCH] [InstCombine] convert mul with sexted bool and constant to select We already have the related folds for zext-of-bool, so it should make things more consistent to have this transform to select for sext-of-bool too: https://alive2.llvm.org/ce/z/YikdfA Fixes #53319 --- llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 12 ++++++++++-- llvm/test/Transforms/InstCombine/mul.ll | 8 +++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index aca7ec8..076c313 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -348,13 +348,21 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { return CastInst::Create(Instruction::SExt, And, I.getType()); } - // (bool X) * Y --> X ? Y : 0 - // Y * (bool X) --> X ? Y : 0 + // (zext bool X) * Y --> X ? Y : 0 + // Y * (zext bool X) --> X ? Y : 0 if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(X, Op1, ConstantInt::get(I.getType(), 0)); if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(X, Op0, ConstantInt::get(I.getType(), 0)); + // (sext bool X) * C --> X ? -C : 0 + Constant *ImmC; + if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1) && + match(Op1, m_ImmConstant(ImmC))) { + Constant *NegC = ConstantExpr::getNeg(ImmC); + return SelectInst::Create(X, NegC, ConstantInt::getNullValue(I.getType())); + } + // (lshr X, 31) * Y --> (ashr X, 31) & Y // Y * (lshr X, 31) --> (ashr X, 31) & Y // TODO: We are not checking one-use because the elimination of the multiply diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll index e99d1e9..5694704 100644 --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -133,8 +133,7 @@ define <2 x i32> @mul_bool_vec_commute(<2 x i32> %px, <2 x i1> %y) { define i32 @mul_sext_bool(i1 %x) { ; CHECK-LABEL: @mul_sext_bool( -; CHECK-NEXT: [[S:%.*]] = sext i1 [[X:%.*]] to i32 -; CHECK-NEXT: [[M:%.*]] = mul nsw i32 [[S]], 42 +; CHECK-NEXT: [[M:%.*]] = select i1 [[X:%.*]], i32 -42, i32 0 ; CHECK-NEXT: ret i32 [[M]] ; %s = sext i1 %x to i32 @@ -146,7 +145,7 @@ define i32 @mul_sext_bool_use(i1 %x) { ; CHECK-LABEL: @mul_sext_bool_use( ; CHECK-NEXT: [[S:%.*]] = sext i1 [[X:%.*]] to i32 ; CHECK-NEXT: call void @use32(i32 [[S]]) -; CHECK-NEXT: [[M:%.*]] = mul nsw i32 [[S]], 42 +; CHECK-NEXT: [[M:%.*]] = select i1 [[X]], i32 -42, i32 0 ; CHECK-NEXT: ret i32 [[M]] ; %s = sext i1 %x to i32 @@ -157,8 +156,7 @@ define i32 @mul_sext_bool_use(i1 %x) { define <2 x i8> @mul_sext_bool_vec(<2 x i1> %x) { ; CHECK-LABEL: @mul_sext_bool_vec( -; CHECK-NEXT: [[S:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i8> -; CHECK-NEXT: [[M:%.*]] = mul <2 x i8> [[S]], +; CHECK-NEXT: [[M:%.*]] = select <2 x i1> [[X:%.*]], <2 x i8> , <2 x i8> zeroinitializer ; CHECK-NEXT: ret <2 x i8> [[M]] ; %s = sext <2 x i1> %x to <2 x i8> -- 2.7.4