From 70b3beb0e22dd0eb33c6fcef019a24f9f1f09ef9 Mon Sep 17 00:00:00 2001
From: Max Kazantsev
Date: Mon, 31 Jan 2022 12:12:48 +0700
Subject: [PATCH] [InstCombine] Generalize and-reduce pattern to handle `ne`
 case as well as `eq`

Following Sanjay's proposal from discussion in D118317, this patch generalizes
and-reduce handling to fold the following pattern
```
  icmp ne (bitcast(icmp ne (lhs, rhs)), 0)
```
into
```
  icmp ne (bitcast(lhs), bitcast(rhs))
```

https://alive2.llvm.org/ce/z/WDcuJ_

Differential Revision: https://reviews.llvm.org/D118431
Reviewed By: lebedev.ri
---
 llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 11 ++++++-----
 llvm/test/Transforms/InstCombine/icmp-vec.ll            |  6 +++---
 .../Transforms/InstCombine/reduction-or-sext-zext-i1.ll | 14 ++++++--------
 3 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 677403a..e45be57 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5897,13 +5897,15 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
   // Match lowering of @llvm.vector.reduce.and. Turn
   ///  %vec_ne = icmp ne <8 x i8> %lhs, %rhs
   ///  %scalar_ne = bitcast <8 x i1> %vec_ne to i8
-  ///  %all_eq = icmp eq i8 %scalar_ne, 0
+  ///  %res = icmp <pred> i8 %scalar_ne, 0
   ///
   /// into
   ///
   ///  %lhs.scalar = bitcast <8 x i8> %lhs to i64
   ///  %rhs.scalar = bitcast <8 x i8> %rhs to i64
-  ///  %all_eq = icmp eq i64 %lhs.scalar, %rhs.scalar
+  ///  %res = icmp <pred> i64 %lhs.scalar, %rhs.scalar
+  ///
+  /// for <pred> in {ne, eq}.
   if (!match(&I, m_ICmp(OuterPred,
                         m_OneUse(m_BitCast(m_OneUse(
                             m_ICmp(InnerPred, m_Value(LHS), m_Value(RHS))))),
@@ -5918,12 +5920,11 @@ static Instruction *foldReductionIdiom(ICmpInst &I,
   if (!DL.isLegalInteger(NumBits))
     return nullptr;
 
-  // TODO: Generalize to isEquality and support other patterns.
-  if (OuterPred == ICmpInst::ICMP_EQ && InnerPred == ICmpInst::ICMP_NE) {
+  if (ICmpInst::isEquality(OuterPred) && InnerPred == ICmpInst::ICMP_NE) {
     auto *ScalarTy = Builder.getIntNTy(NumBits);
     LHS = Builder.CreateBitCast(LHS, ScalarTy, LHS->getName() + ".scalar");
     RHS = Builder.CreateBitCast(RHS, ScalarTy, RHS->getName() + ".scalar");
-    return ICmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ, LHS, RHS,
+    return ICmpInst::Create(Instruction::ICmp, OuterPred, LHS, RHS,
                             I.getName());
   }
 
diff --git a/llvm/test/Transforms/InstCombine/icmp-vec.ll b/llvm/test/Transforms/InstCombine/icmp-vec.ll
index 2888a82..50c0632 100644
--- a/llvm/test/Transforms/InstCombine/icmp-vec.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-vec.ll
@@ -443,9 +443,9 @@ define i1 @eq_cast_ne-1(<2 x i7> %x, <2 x i7> %y) {
 
 define i1 @eq_cast_ne-1-legal-scalar(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @eq_cast_ne-1-legal-scalar(
-; CHECK-NEXT:    [[IC:%.*]] = icmp ne <2 x i8> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[IC]] to i2
-; CHECK-NEXT:    [[R:%.*]] = icmp ne i2 [[TMP1]], 0
+; CHECK-NEXT:    [[X_SCALAR:%.*]] = bitcast <2 x i8> [[X:%.*]] to i16
+; CHECK-NEXT:    [[Y_SCALAR:%.*]] = bitcast <2 x i8> [[Y:%.*]] to i16
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i16 [[X_SCALAR]], [[Y_SCALAR]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %ic = icmp eq <2 x i8> %x, %y
diff --git a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
index 35174d4..cd1b10a 100644
--- a/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
+++ b/llvm/test/Transforms/InstCombine/reduction-or-sext-zext-i1.ll
@@ -145,14 +145,12 @@ bb:
 define i1 @reduce_or_pointer_cast_ne(i8* %arg, i8* %arg1) {
 ; CHECK-LABEL: @reduce_or_pointer_cast_ne(
 ; CHECK-NEXT:  bb:
-; CHECK-NEXT:    [[PTR1:%.*]] = bitcast i8* [[ARG1:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[PTR2:%.*]] = bitcast i8* [[ARG:%.*]] to <8 x i8>*
-; CHECK-NEXT:    [[LHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR1]], align 8
-; CHECK-NEXT:    [[RHS:%.*]] = load <8 x i8>, <8 x i8>* [[PTR2]], align 8
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <8 x i8> [[LHS]], [[RHS]]
-; CHECK-NEXT:    [[TMP0:%.*]] = bitcast <8 x i1> [[CMP]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp ne i8 [[TMP0]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i8* [[ARG1:%.*]] to i64*
+; CHECK-NEXT:    [[LHS1:%.*]] = load i64, i64* [[TMP0]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[ARG:%.*]] to i64*
+; CHECK-NEXT:    [[RHS2:%.*]] = load i64, i64* [[TMP1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[LHS1]], [[RHS2]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
 bb:
   %ptr1 = bitcast i8* %arg1 to <8 x i8>*
-- 
2.7.4
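
As a quick way to see the newly handled `ne` form in action, here is a minimal standalone reproducer sketch (not part of the patch). The function name `@src` and the `target datalayout` string are illustrative assumptions; the fold only fires when the packed bitwidth (here i16, from `<2 x i8>`) passes the `DL.isLegalInteger` check.
```
; Illustrative module, not taken from the patch. The datalayout only declares
; native integer widths so that DL.isLegalInteger(16) holds.
target datalayout = "n8:16:32:64"

; Lowered and-reduce idiom for "any element differs":
;   icmp ne (bitcast (icmp ne lhs, rhs)), 0
define i1 @src(<2 x i8> %x, <2 x i8> %y) {
  %cmp = icmp ne <2 x i8> %x, %y        ; per-element disequality mask
  %packed = bitcast <2 x i1> %cmp to i2 ; pack the mask into an integer
  %any_ne = icmp ne i2 %packed, 0       ; "does any lane differ?"
  ret i1 %any_ne
}

; With this patch applied, `opt -passes=instcombine -S` should rewrite the
; body roughly as:
;   %x.scalar = bitcast <2 x i8> %x to i16
;   %y.scalar = bitcast <2 x i8> %y to i16
;   %any_ne = icmp ne i16 %x.scalar, %y.scalar
```
This is the same shape the updated CHECK lines in icmp-vec.ll verify; the only difference from the previously supported case is that the outer predicate may now be `ne` as well as `eq`.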