From 13ec913bdf500e2354cc55bf29e2f5d99e0c709e Mon Sep 17 00:00:00 2001 From: Roman Lebedev Date: Tue, 20 Apr 2021 21:18:26 +0300 Subject: [PATCH] [InstCombine] Recognize `((x * y) s/ x) !=/== y` as an signed multiplication overflow check (PR48769) We already had support for it's unsigned variant, so simply extend it to also handle the signed variant. Fixes https://bugs.llvm.org/show_bug.cgi?id=48769 --- .../Transforms/InstCombine/InstCombineCompares.cpp | 39 ++++++++++------- .../Transforms/InstCombine/InstCombineInternal.h | 2 +- ...gned-mul-lack-of-overflow-check-via-mul-sdiv.ll | 51 +++++++++++----------- .../signed-mul-overflow-check-via-mul-sdiv.ll | 45 +++++++++---------- 4 files changed, 71 insertions(+), 66 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 4e3ddae..41b4857 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3672,19 +3672,22 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, /// Fold /// (-1 u/ x) u< y -/// ((x * y) u/ x) != y +/// ((x * y) ?/ x) != y /// to -/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit +/// @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit /// Note that the comparison is commutative, while inverted (u>=, ==) predicate /// will mean that we are looking for the opposite answer. -Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { +Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) { ICmpInst::Predicate Pred; Value *X, *Y; Instruction *Mul; + Instruction *Div; bool NeedNegation; // Look for: (-1 u/ x) u= y if (!I.isEquality() && - match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + match(&I, m_c_ICmp(Pred, + m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + m_Instruction(Div)), m_Value(Y)))) { Mul = nullptr; @@ -3699,13 +3702,16 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { default: return nullptr; // Wrong predicate. } - } else // Look for: ((x * y) u/ x) !=/== y + } else // Look for: ((x * y) / x) !=/== y if (I.isEquality() && - match(&I, m_c_ICmp(Pred, m_Value(Y), - m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + match(&I, + m_c_ICmp(Pred, m_Value(Y), + m_CombineAnd( + m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), m_Value(X)), m_Instruction(Mul)), - m_Deferred(X)))))) { + m_Deferred(X))), + m_Instruction(Div))))) { NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; } else return nullptr; @@ -3717,19 +3723,22 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { if (MulHadOtherUses) Builder.SetInsertPoint(Mul); - Function *F = Intrinsic::getDeclaration( - I.getModule(), Intrinsic::umul_with_overflow, X->getType()); - CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + Function *F = Intrinsic::getDeclaration(I.getModule(), + Div->getOpcode() == Instruction::UDiv + ? Intrinsic::umul_with_overflow + : Intrinsic::smul_with_overflow, + X->getType()); + CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul"); // If the multiplication was used elsewhere, to ensure that we don't leave // "duplicate" instructions, replace uses of that original multiplication // with the multiplication result from the with.overflow intrinsic. if (MulHadOtherUses) - replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val")); - Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); + Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov"); if (NeedNegation) // This technically increases instruction count. - Res = Builder.CreateNot(Res, "umul.not.ov"); + Res = Builder.CreateNot(Res, "mul.not.ov"); // If we replaced the mul, erase it. Do this after all uses of Builder, // as the mul is used as insertion point. @@ -4126,7 +4135,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, } } - if (Value *V = foldUnsignedMultiplicationOverflowCheck(I)) + if (Value *V = foldMultiplicationOverflowCheck(I)) return replaceInstUsesWith(I, V); if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index edf8f0f..15152bb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -656,7 +656,7 @@ public: Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); - Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); + Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp); Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); diff --git a/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll b/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll index 39a1bc7..d2a5d5a 100644 --- a/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll +++ b/llvm/test/Transforms/InstCombine/signed-mul-lack-of-overflow-check-via-mul-sdiv.ll @@ -8,10 +8,10 @@ define i1 @t0_basic(i8 %x, i8 %y) { ; CHECK-LABEL: @t0_basic( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true +; CHECK-NEXT: ret i1 [[MUL_NOT_OV]] ; %t0 = mul i8 %x, %y %t1 = sdiv i8 %t0, %x @@ -21,10 +21,10 @@ define i1 @t0_basic(i8 %x, i8 %y) { define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @t1_vec( -; CHECK-NEXT: [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv <2 x i8> [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[T1]], [[Y]] -; CHECK-NEXT: ret <2 x i1> [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.smul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[MUL]], 1 +; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor <2 x i1> [[MUL_OV]], +; CHECK-NEXT: ret <2 x i1> [[MUL_NOT_OV]] ; %t0 = mul <2 x i8> %x, %y %t1 = sdiv <2 x i8> %t0, %x @@ -37,10 +37,10 @@ declare i8 @gen8() define i1 @t2_commutative(i8 %x) { ; CHECK-LABEL: @t2_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true +; CHECK-NEXT: ret i1 [[MUL_NOT_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -52,10 +52,10 @@ define i1 @t2_commutative(i8 %x) { define i1 @t3_commutative(i8 %x) { ; CHECK-LABEL: @t3_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true +; CHECK-NEXT: ret i1 [[MUL_NOT_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -67,10 +67,10 @@ define i1 @t3_commutative(i8 %x) { define i1 @t4_commutative(i8 %x) { ; CHECK-LABEL: @t4_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[Y]], [[T1]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true +; CHECK-NEXT: ret i1 [[MUL_NOT_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -85,11 +85,12 @@ declare void @use8(i8) define i1 @t5_extrause0(i8 %x, i8 %y) { ; CHECK-LABEL: @t5_extrause0( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: call void @use8(i8 [[T0]]) -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[MUL_VAL:%.*]] = extractvalue { i8, i1 } [[MUL]], 0 +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: [[MUL_NOT_OV:%.*]] = xor i1 [[MUL_OV]], true +; CHECK-NEXT: call void @use8(i8 [[MUL_VAL]]) +; CHECK-NEXT: ret i1 [[MUL_NOT_OV]] ; %t0 = mul i8 %x, %y call void @use8(i8 %t0) diff --git a/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll b/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll index 81c04a0..f84ae67 100644 --- a/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll +++ b/llvm/test/Transforms/InstCombine/signed-mul-overflow-check-via-mul-sdiv.ll @@ -8,10 +8,9 @@ define i1 @t0_basic(i8 %x, i8 %y) { ; CHECK-LABEL: @t0_basic( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: ret i1 [[MUL_OV]] ; %t0 = mul i8 %x, %y %t1 = sdiv i8 %t0, %x @@ -21,10 +20,9 @@ define i1 @t0_basic(i8 %x, i8 %y) { define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @t1_vec( -; CHECK-NEXT: [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv <2 x i8> [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[T1]], [[Y]] -; CHECK-NEXT: ret <2 x i1> [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.smul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[MUL]], 1 +; CHECK-NEXT: ret <2 x i1> [[MUL_OV]] ; %t0 = mul <2 x i8> %x, %y %t1 = sdiv <2 x i8> %t0, %x @@ -37,10 +35,9 @@ declare i8 @gen8() define i1 @t2_commutative(i8 %x) { ; CHECK-LABEL: @t2_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: ret i1 [[MUL_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -52,10 +49,9 @@ define i1 @t2_commutative(i8 %x) { define i1 @t3_commutative(i8 %x) { ; CHECK-LABEL: @t3_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: ret i1 [[MUL_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -67,10 +63,9 @@ define i1 @t3_commutative(i8 %x) { define i1 @t4_commutative(i8 %x) { ; CHECK-LABEL: @t4_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[Y]], [[T1]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: ret i1 [[MUL_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -85,11 +80,11 @@ declare void @use8(i8) define i1 @t5_extrause0(i8 %x, i8 %y) { ; CHECK-LABEL: @t5_extrause0( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: call void @use8(i8 [[T0]]) -; CHECK-NEXT: [[T1:%.*]] = sdiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[MUL:%.*]] = call { i8, i1 } @llvm.smul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[MUL_VAL:%.*]] = extractvalue { i8, i1 } [[MUL]], 0 +; CHECK-NEXT: [[MUL_OV:%.*]] = extractvalue { i8, i1 } [[MUL]], 1 +; CHECK-NEXT: call void @use8(i8 [[MUL_VAL]]) +; CHECK-NEXT: ret i1 [[MUL_OV]] ; %t0 = mul i8 %x, %y call void @use8(i8 %t0) -- 2.7.4