From 02a27b38909edc46c41732f79a837c95c9992d5a Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 15 Sep 2022 12:01:11 -0400 Subject: [PATCH] [InstCombine] fold X*X == 0 --> X == 0 This is safe when the mul does not overflow: https://alive2.llvm.org/ce/z/LedVVP This could be extended to handle non-zero compare constants and non-squared multiplies. --- .../Transforms/InstCombine/InstCombineCompares.cpp | 10 ++++++++++ llvm/test/Transforms/InstCombine/icmp-mul.ll | 20 +++++++++++++------- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 0969e9f..239cd16 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2053,6 +2053,16 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, const APInt &C) { + // If there's no overflow: + // X * X == 0 --> X == 0 + // X * X != 0 --> X != 0 + Type *MulTy = Mul->getType(); + if (Cmp.isEquality() && C.isZero() && + Mul->getOperand(0) == Mul->getOperand(1) && + (Mul->hasNoUnsignedWrap() || Mul->hasNoSignedWrap())) + return new ICmpInst(Cmp.getPredicate(), Mul->getOperand(0), + ConstantInt::getNullValue(MulTy)); + const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; diff --git a/llvm/test/Transforms/InstCombine/icmp-mul.ll b/llvm/test/Transforms/InstCombine/icmp-mul.ll index 0f8ee5d..f2aa0db 100644 --- a/llvm/test/Transforms/InstCombine/icmp-mul.ll +++ b/llvm/test/Transforms/InstCombine/icmp-mul.ll @@ -5,8 +5,7 @@ declare void @use(i8) define i1 @squared_nsw_eq0(i5 %x) { ; CHECK-LABEL: @squared_nsw_eq0( -; CHECK-NEXT: [[M:%.*]] = mul nsw i5 [[X:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i5 [[M]], 0 +; CHECK-NEXT: [[R:%.*]] = icmp eq i5 [[X:%.*]], 0 ; CHECK-NEXT: ret i1 [[R]] ; %m = mul nsw i5 %x, %x @@ -16,8 +15,7 @@ define i1 @squared_nsw_eq0(i5 %x) { define <2 x i1> @squared_nuw_eq0(<2 x i8> %x) { ; CHECK-LABEL: @squared_nuw_eq0( -; CHECK-NEXT: [[M:%.*]] = mul nuw <2 x i8> [[X:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[M]], zeroinitializer +; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[X:%.*]], zeroinitializer ; CHECK-NEXT: ret <2 x i1> [[R]] ; %m = mul nuw <2 x i8> %x, %x @@ -25,11 +23,13 @@ define <2 x i1> @squared_nuw_eq0(<2 x i8> %x) { ret <2 x i1> %r } +; extra use is ok + define i1 @squared_nsw_nuw_ne0(i8 %x) { ; CHECK-LABEL: @squared_nsw_nuw_ne0( ; CHECK-NEXT: [[M:%.*]] = mul nuw nsw i8 [[X:%.*]], [[X]] ; CHECK-NEXT: call void @use(i8 [[M]]) -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[M]], 0 +; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[X]], 0 ; CHECK-NEXT: ret i1 [[R]] ; %m = mul nsw nuw i8 %x, %x @@ -38,6 +38,8 @@ define i1 @squared_nsw_nuw_ne0(i8 %x) { ret i1 %r } +; negative test - must have no-overflow + define i1 @squared_eq0(i8 %x) { ; CHECK-LABEL: @squared_eq0( ; CHECK-NEXT: [[M:%.*]] = mul i8 [[X:%.*]], [[X]] @@ -49,6 +51,9 @@ define i1 @squared_eq0(i8 %x) { ret i1 %r } +; negative test - not squared +; TODO: This could be or-of-icmps. + define i1 @mul_nsw_eq0(i5 %x, i5 %y) { ; CHECK-LABEL: @mul_nsw_eq0( ; CHECK-NEXT: [[M:%.*]] = mul nsw i5 [[X:%.*]], [[Y:%.*]] @@ -60,6 +65,8 @@ define i1 @mul_nsw_eq0(i5 %x, i5 %y) { ret i1 %r } +; negative test - non-zero cmp + define i1 @squared_nsw_eq1(i5 %x) { ; CHECK-LABEL: @squared_nsw_eq1( ; CHECK-NEXT: [[M:%.*]] = mul nsw i5 [[X:%.*]], [[X]] @@ -73,8 +80,7 @@ define i1 @squared_nsw_eq1(i5 %x) { define i1 @squared_nsw_sgt0(i5 %x) { ; CHECK-LABEL: @squared_nsw_sgt0( -; CHECK-NEXT: [[M:%.*]] = mul nsw i5 [[X:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i5 [[M]], 0 +; CHECK-NEXT: [[R:%.*]] = icmp ne i5 [[X:%.*]], 0 ; CHECK-NEXT: ret i1 [[R]] ; %m = mul nsw i5 %x, %x -- 2.7.4