From bb9333c3504a4a02b982526ad8264d14c6ec1ad4 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Thu, 23 Sep 2021 09:40:01 -0400 Subject: [PATCH] [InstCombine] fold cast of right-shift if high bits are not demanded (2nd try) The 1st try at this was reverted because it caused an infinite loop in instcombine. That should be fixed after: 1cd6b44f267b (masked) trunc (lshr X, C) --> (masked) lshr (trunc X), C Narrowing the shift should be better for analysis and can lead to follow-on transforms as shown. Attempt at a general proof in Alive2: https://alive2.llvm.org/ce/z/tRnnSF Here are a couple of the specific tests: https://alive2.llvm.org/ce/z/bCnTp- https://alive2.llvm.org/ce/z/TfaHnb Differential Revision: https://reviews.llvm.org/D110170 --- .../InstCombine/InstCombineSimplifyDemanded.cpp | 20 +++++++- llvm/test/Transforms/InstCombine/trunc-demand.ll | 56 ++++++++++++++-------- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a8174d0..2158ae5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -385,8 +385,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known = KnownBits::commonBits(LHSKnown, RHSKnown); break; } - case Instruction::ZExt: case Instruction::Trunc: { + // If we do not demand the high bits of a right-shifted and truncated value, + // then we may be able to truncate it before the shift. + Value *X; + const APInt *C; + if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // The shift amount must be valid (not poison) in the narrow type, and + // it must not exceed the number of undemanded high bits of the result. 
+ if (C->ult(I->getType()->getScalarSizeInBits()) && + C->ule(DemandedMask.countLeadingZeros())) { + // trunc (lshr X, C) --> lshr (trunc X), C + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Trunc = Builder.CreateTrunc(X, I->getType()); + return Builder.CreateLShr(Trunc, C->getZExtValue()); + } + } + } + LLVM_FALLTHROUGH; + case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); diff --git a/llvm/test/Transforms/InstCombine/trunc-demand.ll b/llvm/test/Transforms/InstCombine/trunc-demand.ll index ce638fe..e8df45b 100644 --- a/llvm/test/Transforms/InstCombine/trunc-demand.ll +++ b/llvm/test/Transforms/InstCombine/trunc-demand.ll @@ -6,9 +6,9 @@ declare void @use8(i8) define i6 @trunc_lshr(i8 %x) { ; CHECK-LABEL: @trunc_lshr( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = and i6 [[T]], 14 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 2 +; CHECK-NEXT: [[R:%.*]] = and i6 [[TMP2]], 14 ; CHECK-NEXT: ret i6 [[R]] ; %s = lshr i8 %x, 2 @@ -17,12 +17,13 @@ define i6 @trunc_lshr(i8 %x) { ret i6 %r } +; The 'and' is eliminated. 
+ define i6 @trunc_lshr_exact_mask(i8 %x) { ; CHECK-LABEL: @trunc_lshr_exact_mask( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = and i6 [[T]], 15 -; CHECK-NEXT: ret i6 [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 2 +; CHECK-NEXT: ret i6 [[TMP2]] ; %s = lshr i8 %x, 2 %t = trunc i8 %s to i6 @@ -30,6 +31,8 @@ define i6 @trunc_lshr_exact_mask(i8 %x) { ret i6 %r } +; negative test - a high bit of x is in the result + define i6 @trunc_lshr_big_mask(i8 %x) { ; CHECK-LABEL: @trunc_lshr_big_mask( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 @@ -43,6 +46,8 @@ define i6 @trunc_lshr_big_mask(i8 %x) { ret i6 %r } +; negative test - too many uses + define i6 @trunc_lshr_use1(i8 %x) { ; CHECK-LABEL: @trunc_lshr_use1( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 @@ -58,6 +63,8 @@ define i6 @trunc_lshr_use1(i8 %x) { ret i6 %r } +; negative test - too many uses + define i6 @trunc_lshr_use2(i8 %x) { ; CHECK-LABEL: @trunc_lshr_use2( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 2 @@ -73,11 +80,13 @@ define i6 @trunc_lshr_use2(i8 %x) { ret i6 %r } +; Splat vectors are ok. + define <2 x i7> @trunc_lshr_vec_splat(<2 x i16> %x) { ; CHECK-LABEL: @trunc_lshr_vec_splat( -; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[X:%.*]], -; CHECK-NEXT: [[T:%.*]] = trunc <2 x i16> [[S]] to <2 x i7> -; CHECK-NEXT: [[R:%.*]] = and <2 x i7> [[T]], +; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i16> [[X:%.*]] to <2 x i7> +; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i7> [[TMP1]], +; CHECK-NEXT: [[R:%.*]] = and <2 x i7> [[TMP2]], ; CHECK-NEXT: ret <2 x i7> [[R]] ; %s = lshr <2 x i16> %x, @@ -86,12 +95,13 @@ define <2 x i7> @trunc_lshr_vec_splat(<2 x i16> %x) { ret <2 x i7> %r } +; The 'and' is eliminated. 
+ define <2 x i7> @trunc_lshr_vec_splat_exact_mask(<2 x i16> %x) { ; CHECK-LABEL: @trunc_lshr_vec_splat_exact_mask( -; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[X:%.*]], -; CHECK-NEXT: [[T:%.*]] = trunc <2 x i16> [[S]] to <2 x i7> -; CHECK-NEXT: [[R:%.*]] = and <2 x i7> [[T]], -; CHECK-NEXT: ret <2 x i7> [[R]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i16> [[X:%.*]] to <2 x i7> +; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i7> [[TMP1]], +; CHECK-NEXT: ret <2 x i7> [[TMP2]] ; %s = lshr <2 x i16> %x, %t = trunc <2 x i16> %s to <2 x i7> @@ -99,6 +109,8 @@ define <2 x i7> @trunc_lshr_vec_splat_exact_mask(<2 x i16> %x) { ret <2 x i7> %r } +; negative test - the shift is too big for the narrow type + define <2 x i7> @trunc_lshr_big_shift(<2 x i16> %x) { ; CHECK-LABEL: @trunc_lshr_big_shift( ; CHECK-NEXT: [[S:%.*]] = lshr <2 x i16> [[X:%.*]], @@ -112,11 +124,13 @@ define <2 x i7> @trunc_lshr_big_shift(<2 x i16> %x) { ret <2 x i7> %r } +; High bits could also be set rather than cleared. + define i6 @or_trunc_lshr(i8 %x) { ; CHECK-LABEL: @or_trunc_lshr( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 1 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = or i6 [[T]], -32 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 1 +; CHECK-NEXT: [[R:%.*]] = or i6 [[TMP2]], -32 ; CHECK-NEXT: ret i6 [[R]] ; %s = lshr i8 %x, 1 @@ -127,9 +141,9 @@ define i6 @or_trunc_lshr(i8 %x) { define i6 @or_trunc_lshr_more(i8 %x) { ; CHECK-LABEL: @or_trunc_lshr_more( -; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 4 -; CHECK-NEXT: [[T:%.*]] = trunc i8 [[S]] to i6 -; CHECK-NEXT: [[R:%.*]] = or i6 [[T]], -4 +; CHECK-NEXT: [[TMP1:%.*]] = trunc i8 [[X:%.*]] to i6 +; CHECK-NEXT: [[TMP2:%.*]] = lshr i6 [[TMP1]], 4 +; CHECK-NEXT: [[R:%.*]] = or i6 [[TMP2]], -4 ; CHECK-NEXT: ret i6 [[R]] ; %s = lshr i8 %x, 4 @@ -138,6 +152,8 @@ define i6 @or_trunc_lshr_more(i8 %x) { ret i6 %r } +; negative test - need all high bits to be undemanded + define i6 
@or_trunc_lshr_small_mask(i8 %x) { ; CHECK-LABEL: @or_trunc_lshr_small_mask( ; CHECK-NEXT: [[S:%.*]] = lshr i8 [[X:%.*]], 4 -- 2.7.4