From ebbc37391f9d1de5e8c4bee14493fce20f9c6906 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Wed, 18 May 2022 14:34:48 -0400 Subject: [PATCH] [InstCombine] allow variable shift amount in bswap + shift fold When shifting by a byte-multiple: bswap (shl X, Y) --> lshr (bswap X), Y bswap (lshr X, Y) --> shl (bswap X), Y This was limited to constants as a first step in D122010 / 60820e53ec9d3be02 , but issue #55327 shows a source example (and there's a test based on that here) where a variable shift amount is used in this pattern. --- .../Transforms/InstCombine/InstCombineCalls.cpp | 24 ++++++++++++---------- llvm/test/Transforms/InstCombine/bswap-fold.ll | 15 ++++++++------ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index c966ac0..f572227 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -1433,23 +1433,25 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); - Value *X = nullptr; // Try to canonicalize bswap-of-logical-shift-by-8-bit-multiple as // inverse-shift-of-bswap: - // bswap (shl X, C) --> lshr (bswap X), C - // bswap (lshr X, C) --> shl (bswap X), C - // TODO: Use knownbits to allow variable shift and non-splat vector match. - BinaryOperator *BO; - if (match(IIOperand, m_OneUse(m_BinOp(BO)))) { + // bswap (shl X, Y) --> lshr (bswap X), Y + // bswap (lshr X, Y) --> shl (bswap X), Y + Value *X, *Y; + if (match(IIOperand, m_OneUse(m_LogicalShift(m_Value(X), m_Value(Y))))) { + // The transform allows undef vector elements, so try a constant match + // first. If knownbits can handle that case, that clause could be removed. + unsigned BitWidth = IIOperand->getType()->getScalarSizeInBits(); const APInt *C; - if (match(BO, m_LogicalShift(m_Value(X), m_APIntAllowUndef(C))) && - (*C & 7) == 0) { + if ((match(Y, m_APIntAllowUndef(C)) && (*C & 7) == 0) || + MaskedValueIsZero(Y, APInt::getLowBitsSet(BitWidth, 3))) { Value *NewSwap = Builder.CreateUnaryIntrinsic(Intrinsic::bswap, X); BinaryOperator::BinaryOps InverseShift = - BO->getOpcode() == Instruction::Shl ? Instruction::LShr - : Instruction::Shl; - return BinaryOperator::Create(InverseShift, NewSwap, BO->getOperand(1)); + cast(IIOperand)->getOpcode() == Instruction::Shl + ? Instruction::LShr + : Instruction::Shl; + return BinaryOperator::Create(InverseShift, NewSwap, Y); } } diff --git a/llvm/test/Transforms/InstCombine/bswap-fold.ll b/llvm/test/Transforms/InstCombine/bswap-fold.ll index 49809f8..1e1903f1 100644 --- a/llvm/test/Transforms/InstCombine/bswap-fold.ll +++ b/llvm/test/Transforms/InstCombine/bswap-fold.ll @@ -159,11 +159,14 @@ define i64 @swap_shl16_i64(i64 %x) { ret i64 %r } +; canonicalize shift after bswap if shift amount is multiple of 8-bits +; (including non-uniform vector elements) + define <2 x i32> @variable_lshr_v2i32(<2 x i32> %x, <2 x i32> %n) { ; CHECK-LABEL: @variable_lshr_v2i32( ; CHECK-NEXT: [[SHAMT:%.*]] = and <2 x i32> [[N:%.*]], -; CHECK-NEXT: [[S:%.*]] = shl <2 x i32> [[X:%.*]], [[SHAMT]] -; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[S]]) +; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> [[X:%.*]]) +; CHECK-NEXT: [[R:%.*]] = lshr <2 x i32> [[TMP1]], [[SHAMT]] ; CHECK-NEXT: ret <2 x i32> [[R]] ; %shamt = and <2 x i32> %n, @@ -172,15 +175,13 @@ define <2 x i32> @variable_lshr_v2i32(<2 x i32> %x, <2 x i32> %n) { ret <2 x i32> %r } -; PR55327 +; PR55327 - swaps cancel define i64 @variable_shl_i64(i64 %x, i64 %n) { ; CHECK-LABEL: @variable_shl_i64( -; CHECK-NEXT: [[B:%.*]] = tail call i64 @llvm.bswap.i64(i64 [[X:%.*]]) ; CHECK-NEXT: [[N8:%.*]] = shl i64 [[N:%.*]], 3 ; CHECK-NEXT: [[SHAMT:%.*]] = and i64 [[N8]], 56 -; CHECK-NEXT: [[S:%.*]] = shl i64 [[B]], [[SHAMT]] -; CHECK-NEXT: [[R:%.*]] = tail call i64 @llvm.bswap.i64(i64 [[S]]) +; CHECK-NEXT: [[R:%.*]] = lshr i64 [[X:%.*]], [[SHAMT]] ; CHECK-NEXT: ret i64 [[R]] ; %b = tail call i64 @llvm.bswap.i64(i64 %x) @@ -191,6 +192,8 @@ define i64 @variable_shl_i64(i64 %x, i64 %n) { ret i64 %r } +; negative test - must have multiple of 8-bit shift amount + define i64 @variable_shl_not_masked_enough_i64(i64 %x, i64 %n) { ; CHECK-LABEL: @variable_shl_not_masked_enough_i64( ; CHECK-NEXT: [[SHAMT:%.*]] = and i64 [[N:%.*]], -4 -- 2.7.4