From: Simon Pilgrim Date: Tue, 20 Oct 2020 12:27:43 +0000 (+0100) Subject: [InstCombine] Add or((icmp ult/ule (A + C1), C3), (icmp ult/ule (A + C2), C3)) unifor... X-Git-Tag: llvmorg-13-init~8740 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e372a5f86f6488bb0c2593a665d51fdd3a97c6e4;p=platform%2Fupstream%2Fllvm.git [InstCombine] Add or((icmp ult/ule (A + C1), C3), (icmp ult/ule (A + C2), C3)) uniform vector support Reapplied rGa704d8238c86 with a check for integer/integervector types to prevent matching with pointer types --- diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index b34ba4e..2b8f5ff 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -2283,8 +2283,6 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); - auto *LHSC = dyn_cast(LHS1); - auto *RHSC = dyn_cast(RHS1); // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) @@ -2296,42 +2294,43 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // 3) C1 ^ C2 is one-bit mask. // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. // This implies all values in the two ranges differ by exactly one bit. + const APInt *LHSVal, *RHSVal; if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && - PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && - LHSC->getType() == RHSC->getType() && - LHSC->getValue() == (RHSC->getValue())) { - + PredL == PredR && LHS->getType() == RHS->getType() && + LHS->getType()->isIntOrIntVectorTy() && match(LHS1, m_APInt(LHSVal)) && + match(RHS1, m_APInt(RHSVal)) && *LHSVal == *RHSVal && LHS->hasOneUse() && + RHS->hasOneUse()) { Value *LAddOpnd, *RAddOpnd; - ConstantInt *LAddC, *RAddC; - if (match(LHS0, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && - match(RHS0, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && - LAddC->getValue().ugt(LHSC->getValue()) && - RAddC->getValue().ugt(LHSC->getValue())) { + const APInt *LAddVal, *RAddVal; + if (match(LHS0, m_Add(m_Value(LAddOpnd), m_APInt(LAddVal))) && + match(RHS0, m_Add(m_Value(RAddOpnd), m_APInt(RAddVal))) && + LAddVal->ugt(*LHSVal) && RAddVal->ugt(*LHSVal)) { - APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + APInt DiffC = *LAddVal ^ *RAddVal; if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { - ConstantInt *MaxAddC = nullptr; - if (LAddC->getValue().ult(RAddC->getValue())) - MaxAddC = RAddC; + const APInt *MaxAddC = nullptr; + if (LAddVal->ult(*RAddVal)) + MaxAddC = RAddVal; else - MaxAddC = LAddC; + MaxAddC = LAddVal; - APInt RRangeLow = -RAddC->getValue(); - APInt RRangeHigh = RRangeLow + LHSC->getValue(); - APInt LRangeLow = -LAddC->getValue(); - APInt LRangeHigh = LRangeLow + LHSC->getValue(); + APInt RRangeLow = -*RAddVal; + APInt RRangeHigh = RRangeLow + *LHSVal; + APInt LRangeLow = -*LAddVal; + APInt LRangeHigh = LRangeLow + *LHSVal; APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(LHSC->getValue())) { - Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); - - Value *NewAnd = Builder.CreateAnd(LAddOpnd, MaskC); - Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC); - return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC); + RangeDiff.ugt(*LHSVal)) { + Value *NewAnd = Builder.CreateAnd( + LAddOpnd, ConstantInt::get(LHS0->getType(), ~DiffC)); + Value *NewAdd = Builder.CreateAdd( + NewAnd, ConstantInt::get(LHS0->getType(), *MaxAddC)); + return Builder.CreateICmp(LHS->getPredicate(), NewAdd, + ConstantInt::get(LHS0->getType(), *LHSVal)); } } } @@ -2417,6 +2416,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, } // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). + auto *LHSC = dyn_cast(LHS1); + auto *RHSC = dyn_cast(RHS1); if (!LHSC || !RHSC) return nullptr; diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll index a563d2f..e13b593 100644 --- a/llvm/test/Transforms/InstCombine/or.ll +++ b/llvm/test/Transforms/InstCombine/or.ll @@ -577,12 +577,10 @@ define i1 @test46(i8 signext %c) { define <2 x i1> @test46_uniform(<2 x i8> %c) { ; CHECK-LABEL: @test46_uniform( -; CHECK-NEXT: [[C_OFF:%.*]] = add <2 x i8> [[C:%.*]], -; CHECK-NEXT: [[CMP1:%.*]] = icmp ult <2 x i8> [[C_OFF]], -; CHECK-NEXT: [[C_OFF17:%.*]] = add <2 x i8> [[C]], -; CHECK-NEXT: [[CMP2:%.*]] = icmp ult <2 x i8> [[C_OFF17]], -; CHECK-NEXT: [[OR:%.*]] = or <2 x i1> [[CMP1]], [[CMP2]] -; CHECK-NEXT: ret <2 x i1> [[OR]] +; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i8> [[C:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = add <2 x i8> [[TMP1]], +; CHECK-NEXT: [[TMP3:%.*]] = icmp ult <2 x i8> [[TMP2]], +; CHECK-NEXT: ret <2 x i1> [[TMP3]] ; %c.off = add <2 x i8> %c, %cmp1 = icmp ult <2 x i8> %c.off,