From 310f62b4ff3ecb67cf696a977b194cceb326fa43 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim
Date: Sat, 24 Oct 2020 12:42:43 +0100
Subject: [PATCH] [InstCombine] narrowFunnelShift - fold trunc/zext
 or(shl(a,x),lshr(b,sub(bw,x))) -> fshl(a,b,x) (PR35155)

As discussed on PR35155, this extends narrowFunnelShift (recently renamed
from narrowRotate) to support basic funnel shift patterns.

Unlike matchFunnelShift, we don't include the computeKnownBits limitation,
as extracting the pattern from the zext/trunc layers should be an indicator
of reasonable funnel shift codegen; in D89139 we demonstrated how to
efficiently promote funnel shifts to wider types.

Differential Revision: https://reviews.llvm.org/D89542
---
 .../Transforms/InstCombine/InstCombineCasts.cpp | 24 +++++++-----
 llvm/test/Transforms/InstCombine/funnel.ll      | 43 ++++++----------------
 2 files changed, 25 insertions(+), 42 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index c42b240..9bfaa31 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -553,12 +553,17 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
 
   // Match the shift amount operands for a funnel/rotate pattern. This always
   // matches a subtraction on the R operand.
-  auto matchShiftAmount = [](Value *L, Value *R, unsigned Width) -> Value * {
+  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
     // The shift amounts may add up to the narrow bit width:
     // (shl ShVal0, L) | (lshr ShVal1, Width - L)
     if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
       return L;
 
+    // The following patterns currently only work for rotation patterns.
+    // TODO: Add more general funnel-shift compatible patterns.
+    if (ShVal0 != ShVal1)
+      return nullptr;
+
     // The shift amount may be masked with negation:
     // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1)))
     Value *X;
@@ -575,11 +580,6 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
     return nullptr;
   };
 
-  // TODO: Add support for funnel shifts (ShVal0 != ShVal1).
-  if (ShVal0 != ShVal1)
-    return nullptr;
-  Value *ShVal = ShVal0;
-
   Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
   bool IsFshl = true; // Sub on LSHR.
   if (!ShAmt) {
@@ -593,18 +593,22 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
   // will be a zext, but it could also be the result of an 'and' or 'shift'.
   unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
   APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
-  if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+  if (!MaskedValueIsZero(ShVal0, HiBitMask, 0, &Trunc) ||
+      !MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
     return nullptr;
 
   // We have an unnecessarily wide rotate!
-  // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+  // trunc (or (lshr ShVal0, ShAmt), (shl ShVal1, BitWidth - ShAmt))
   // Narrow the inputs and convert to funnel shift intrinsic:
   // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt))
   Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
-  Value *X = Builder.CreateTrunc(ShVal, DestTy);
+  Value *X, *Y;
+  X = Y = Builder.CreateTrunc(ShVal0, DestTy);
+  if (ShVal0 != ShVal1)
+    Y = Builder.CreateTrunc(ShVal1, DestTy);
   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy);
-  return IntrinsicInst::Create(F, { X, X, NarrowShAmt });
+  return IntrinsicInst::Create(F, {X, Y, NarrowShAmt});
 }
 
 /// Try to narrow the width of math or bitwise logic instructions by pulling a
diff --git a/llvm/test/Transforms/InstCombine/funnel.ll b/llvm/test/Transforms/InstCombine/funnel.ll
index 2f01c40..1bc3c26 100644
--- a/llvm/test/Transforms/InstCombine/funnel.ll
+++ b/llvm/test/Transforms/InstCombine/funnel.ll
@@ -205,14 +205,8 @@ define <2 x i64> @fshr_sub_mask_vector(<2 x i64> %x, <2 x i64> %y, <2 x i64> %a)
 
 define i16 @fshl_16bit(i16 %x, i16 %y, i32 %shift) {
 ; CHECK-LABEL: @fshl_16bit(
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[SHIFT:%.*]], 15
-; CHECK-NEXT:    [[CONVX:%.*]] = zext i16 [[X:%.*]] to i32
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 16, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = zext i16 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc i32 [[OR]] to i16
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i16
+; CHECK-NEXT:    [[CONV2:%.*]] = call i16 @llvm.fshl.i16(i16 [[X:%.*]], i16 [[Y:%.*]], i16 [[TMP1]])
 ; CHECK-NEXT:    ret i16 [[CONV2]]
 ;
   %and = and i32 %shift, 15
@@ -230,14 +224,8 @@ define i16 @fshl_16bit(i16 %x, i16 %y, i32 %shift) {
 
 define <2 x i16> @fshl_commute_16bit_vec(<2 x i16> %x, <2 x i16> %y, <2 x i32> %shift) {
 ; CHECK-LABEL: @fshl_commute_16bit_vec(
-; CHECK-NEXT:    [[AND:%.*]] = and <2 x i32> [[SHIFT:%.*]], <i32 15, i32 15>
-; CHECK-NEXT:    [[CONVX:%.*]] = zext <2 x i16> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw <2 x i32> <i32 16, i32 16>, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = zext <2 x i16> [[Y:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[SHR:%.*]] = lshr <2 x i32> [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i32> [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc <2 x i32> [[OR]] to <2 x i16>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> [[SHIFT:%.*]] to <2 x i16>
+; CHECK-NEXT:    [[CONV2:%.*]] = call <2 x i16> @llvm.fshl.v2i16(<2 x i16> [[X:%.*]], <2 x i16> [[Y:%.*]], <2 x i16> [[TMP1]])
 ; CHECK-NEXT:    ret <2 x i16> [[CONV2]]
 ;
   %and = and <2 x i32> %shift, <i32 15, i32 15>
@@ -255,14 +243,8 @@ define <2 x i16> @fshl_commute_16bit_vec(<2 x i16> %x, <2 x i16> %y, <2 x i32> %
 
 define i8 @fshr_8bit(i8 %x, i8 %y, i3 %shift) {
 ; CHECK-LABEL: @fshr_8bit(
-; CHECK-NEXT:    [[AND:%.*]] = zext i3 [[SHIFT:%.*]] to i32
-; CHECK-NEXT:    [[CONVX:%.*]] = zext i8 [[X:%.*]] to i32
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 8, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = zext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHL]], [[SHR]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc i32 [[OR]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i3 [[SHIFT:%.*]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[Y:%.*]], i8 [[X:%.*]], i8 [[TMP1]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = zext i3 %shift to i32
@@ -281,14 +263,11 @@ define i8 @fshr_8bit(i8 %x, i8 %y, i3 %shift) {
 
 define i8 @fshr_commute_8bit(i32 %x, i32 %y, i32 %shift) {
 ; CHECK-LABEL: @fshr_commute_8bit(
-; CHECK-NEXT:    [[AND:%.*]] = and i32 [[SHIFT:%.*]], 3
-; CHECK-NEXT:    [[CONVX:%.*]] = and i32 [[X:%.*]], 255
-; CHECK-NEXT:    [[SHR:%.*]] = lshr i32 [[CONVX]], [[AND]]
-; CHECK-NEXT:    [[SUB:%.*]] = sub nuw nsw i32 8, [[AND]]
-; CHECK-NEXT:    [[CONVY:%.*]] = and i32 [[Y:%.*]], 255
-; CHECK-NEXT:    [[SHL:%.*]] = shl i32 [[CONVY]], [[SUB]]
-; CHECK-NEXT:    [[OR:%.*]] = or i32 [[SHR]], [[SHL]]
-; CHECK-NEXT:    [[CONV2:%.*]] = trunc i32 [[OR]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[SHIFT:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i32 [[Y:%.*]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[CONV2:%.*]] = call i8 @llvm.fshr.i8(i8 [[TMP3]], i8 [[TMP4]], i8 [[TMP2]])
 ; CHECK-NEXT:    ret i8 [[CONV2]]
 ;
   %and = and i32 %shift, 3
-- 
2.7.4
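
For illustration, a minimal standalone before/after sketch of the fold, based on
the @fshl_16bit test above; the input body is reconstructed from the removed
CHECK lines, and the @expected_fshl_16bit function and the value names
(%narrowamt, etc.) are illustrative, not part of the patch. The input can be
fed through 'opt -instcombine -S' to reproduce the narrowed form:

; Input: a 16-bit funnel-shift-left pattern widened to i32.
define i16 @fshl_16bit(i16 %x, i16 %y, i32 %shift) {
  %and = and i32 %shift, 15
  %convx = zext i16 %x to i32
  %shl = shl i32 %convx, %and
  %sub = sub nuw nsw i32 16, %and
  %convy = zext i16 %y to i32
  %shr = lshr i32 %convy, %sub
  %or = or i32 %shr, %shl
  %conv2 = trunc i32 %or to i16
  ret i16 %conv2
}

; Expected shape after InstCombine, per the new CHECK lines above.
declare i16 @llvm.fshl.i16(i16, i16, i16)

define i16 @expected_fshl_16bit(i16 %x, i16 %y, i32 %shift) {
  %narrowamt = trunc i32 %shift to i16
  %conv2 = call i16 @llvm.fshl.i16(i16 %x, i16 %y, i16 %narrowamt)
  ret i16 %conv2
}

Note that the 'and i32 %shift, 15' mask disappears as well: the i16 funnel
shift amount is taken modulo 16, so only the low 4 bits are demanded and the
mask is redundant after narrowing.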