From 3f3356bdd9c7188530f6582b4a407469131ae679 Mon Sep 17 00:00:00 2001
From: Sanjay Patel
Date: Sat, 10 Oct 2020 11:08:50 -0400
Subject: [PATCH] [InstCombine] allow vector splats for add+xor --> shifts

---
 .../Transforms/InstCombine/InstCombineAddSub.cpp | 50 ++++++++--------------
 llvm/test/Transforms/InstCombine/signext.ll      | 19 +++-----
 2 files changed, 26 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 2dd5180..4987ba7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -936,6 +936,25 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
       if ((*C2 | LHSKnown.Zero).isAllOnesValue())
         return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X);
     }
+
+    // Look for a math+logic pattern that corresponds to sext-in-register of a
+    // value with cleared high bits. Convert that into a pair of shifts:
+    // add (xor X, 0x80), 0xF..F80 --> (X << ShAmtC) >>s ShAmtC
+    // add (xor X, 0xF..F80), 0x80 --> (X << ShAmtC) >>s ShAmtC
+    if (Op0->hasOneUse() && *C2 == -(*C)) {
+      unsigned BitWidth = Ty->getScalarSizeInBits();
+      unsigned ShAmt = 0;
+      if (C->isPowerOf2())
+        ShAmt = BitWidth - C->logBase2() - 1;
+      else if (C2->isPowerOf2())
+        ShAmt = BitWidth - C2->logBase2() - 1;
+      if (ShAmt && MaskedValueIsZero(X, APInt::getHighBitsSet(BitWidth, ShAmt),
+                                     0, &Add)) {
+        Constant *ShAmtC = ConstantInt::get(Ty, ShAmt);
+        Value *NewShl = Builder.CreateShl(X, ShAmtC, "sext");
+        return BinaryOperator::CreateAShr(NewShl, ShAmtC);
+      }
+    }
   }
 
   if (C->isOneValue() && Op0->hasOneUse()) {
@@ -1284,39 +1303,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *X = foldNoWrapAdd(I, Builder))
     return X;
 
-  // FIXME: This should be moved into the above helper function to allow these
-  // transforms for general constant or constant splat vectors.
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Type *Ty = I.getType();
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
-    Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr;
-    if (match(LHS, m_OneUse(m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS))))) {
-      unsigned TySizeBits = Ty->getScalarSizeInBits();
-      const APInt &RHSVal = CI->getValue();
-      unsigned ExtendAmt = 0;
-      // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext.
-      // If we have ADD(XOR(AND(X, 0xFF), 0xF..F80), 0x80), it's a sext.
-      if (XorRHS->getValue() == -RHSVal) {
-        if (RHSVal.isPowerOf2())
-          ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
-        else if (XorRHS->getValue().isPowerOf2())
-          ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
-      }
-
-      if (ExtendAmt) {
-        APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
-        if (!MaskedValueIsZero(XorLHS, Mask, 0, &I))
-          ExtendAmt = 0;
-      }
-
-      if (ExtendAmt) {
-        Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt);
-        Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext");
-        return BinaryOperator::CreateAShr(NewShl, ShAmt);
-      }
-    }
-  }
-
   if (Ty->isIntOrIntVectorTy(1))
     return BinaryOperator::CreateXor(LHS, RHS);
 
diff --git a/llvm/test/Transforms/InstCombine/signext.ll b/llvm/test/Transforms/InstCombine/signext.ll
index 4faf4e3..447e548 100644
--- a/llvm/test/Transforms/InstCombine/signext.ll
+++ b/llvm/test/Transforms/InstCombine/signext.ll
@@ -34,9 +34,8 @@ define i32 @sextinreg_extra_use(i32 %x) {
 
 define <2 x i32> @sextinreg_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @sextinreg_splat(
-; CHECK-NEXT:    [[T1:%.*]] = and <2 x i32> [[X:%.*]],
-; CHECK-NEXT:    [[T2:%.*]] = xor <2 x i32> [[T1]],
-; CHECK-NEXT:    [[T3:%.*]] = add nsw <2 x i32> [[T2]],
+; CHECK-NEXT:    [[SEXT:%.*]] = shl <2 x i32> [[X:%.*]],
+; CHECK-NEXT:    [[T3:%.*]] = ashr exact <2 x i32> [[SEXT]],
 ; CHECK-NEXT:    ret <2 x i32> [[T3]]
 ;
   %t1 = and <2 x i32> %x,
@@ -59,9 +58,8 @@ define i32 @sextinreg_alt(i32 %x) {
 
 define <2 x i32> @sextinreg_alt_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @sextinreg_alt_splat(
-; CHECK-NEXT:    [[T1:%.*]] = and <2 x i32> [[X:%.*]],
-; CHECK-NEXT:    [[T2:%.*]] = xor <2 x i32> [[T1]],
-; CHECK-NEXT:    [[T3:%.*]] = add nsw <2 x i32> [[T2]],
+; CHECK-NEXT:    [[SEXT:%.*]] = shl <2 x i32> [[X:%.*]],
+; CHECK-NEXT:    [[T3:%.*]] = ashr exact <2 x i32> [[SEXT]],
 ; CHECK-NEXT:    ret <2 x i32> [[T3]]
 ;
   %t1 = and <2 x i32> %x,
@@ -121,9 +119,8 @@ define i32 @sextinreg2(i32 %x) {
 
 define <2 x i32> @sextinreg2_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @sextinreg2_splat(
-; CHECK-NEXT:    [[T1:%.*]] = and <2 x i32> [[X:%.*]],
-; CHECK-NEXT:    [[T2:%.*]] = xor <2 x i32> [[T1]],
-; CHECK-NEXT:    [[T3:%.*]] = add nsw <2 x i32> [[T2]],
+; CHECK-NEXT:    [[SEXT:%.*]] = shl <2 x i32> [[X:%.*]],
+; CHECK-NEXT:    [[T3:%.*]] = ashr exact <2 x i32> [[SEXT]],
 ; CHECK-NEXT:    ret <2 x i32> [[T3]]
 ;
   %t1 = and <2 x i32> %x,
@@ -184,9 +181,7 @@ define i32 @ashr(i32 %x) {
 
 define <2 x i32> @ashr_splat(<2 x i32> %x) {
 ; CHECK-LABEL: @ashr_splat(
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i32> [[X:%.*]],
-; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i32> [[SHR]],
-; CHECK-NEXT:    [[SUB:%.*]] = add nsw <2 x i32> [[XOR]],
+; CHECK-NEXT:    [[SUB:%.*]] = ashr <2 x i32> [[X:%.*]],
 ; CHECK-NEXT:    ret <2 x i32> [[SUB]]
 ;
   %shr = lshr <2 x i32> %x,
-- 
2.7.4
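Illustrative sketch (not part of the diff above): the new fold in foldAddWithConstant
rewrites a splatted and+xor+add sign-extend-in-register idiom into a shl/ashr pair,
which is what the updated vector tests check. The IR below is hand-written, with the
function name and the constants (255, 128, -128, shift amount 24) chosen by me for a
low-8-bit value in a <2 x i32>, not copied from signext.ll:

  ; before: sext-in-register of the low 8 bits, spelled as and+xor+add
  define <2 x i32> @sext_inreg_example(<2 x i32> %x) {
    %m = and <2 x i32> %x, <i32 255, i32 255>
    %t = xor <2 x i32> %m, <i32 128, i32 128>
    %r = add <2 x i32> %t, <i32 -128, i32 -128>
    ret <2 x i32> %r
  }

  ; expected shape after running instcombine (e.g. 'opt -instcombine -S'):
  ; the xor/add pair becomes a shift pair, the mask is absorbed by the shl,
  ; and a later fold marks the ashr exact since only zero bits are shifted out
  define <2 x i32> @sext_inreg_example(<2 x i32> %x) {
    %sext = shl <2 x i32> %x, <i32 24, i32 24>
    %r = ashr exact <2 x i32> %sext, <i32 24, i32 24>
    ret <2 x i32> %r
  }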