From a47c8e40c734429903d4000285ca45a1c3299321 Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Fri, 24 Sep 2021 13:27:02 -0400 Subject: [PATCH] [InstCombine] fold lshr(trunc(lshr X, C1)) C2 Only the multi-use cases are changing here because there's another fold that catches the simpler patterns. But that other fold is the source of infinite loops when we try to add D110170, so removing that is planned as a follow-up. Attempt to show the general proof in Alive2: https://alive2.llvm.org/ce/z/Ns1uS2 Note that the overshift fold-to-zero tests are not currently handled by instsimplify. If they were, we could assert that the shift amount sum is less than the source bitwidth. --- .../Transforms/InstCombine/InstCombineShifts.cpp | 16 +++++++++++-- llvm/test/Transforms/InstCombine/lshr.ll | 26 +++++++++++++--------- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 92bfae2..b0d328f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -1149,14 +1149,26 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } + // (X >>u C1) >>u C --> X >>u (C1 + C) if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) { - unsigned AmtSum = ShAmtC + C1->getZExtValue(); // Oversized shifts are simplified to zero in InstSimplify. + unsigned AmtSum = ShAmtC + C1->getZExtValue(); if (AmtSum < BitWidth) - // (X >>u C1) >>u C --> X >>u (C1 + C) return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); } + // If the first shift covers the number of bits truncated and the combined + // shift fits in the source width: + // (trunc (X >>u C1)) >>u C --> trunc (X >>u (C1 + C)) + if (match(Op0, m_OneUse(m_Trunc(m_LShr(m_Value(X), m_APInt(C1)))))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + if (C1->uge(SrcWidth - BitWidth) && AmtSum < SrcWidth) { + Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift"); + return new TruncInst(SumShift, Ty); + } + } + // Look for a "splat" mul pattern - it replicates bits across each half of // a value, so a right shift is just a mask of the low bits: // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1 diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll index 2172749..b8b1438 100644 --- a/llvm/test/Transforms/InstCombine/lshr.ll +++ b/llvm/test/Transforms/InstCombine/lshr.ll @@ -487,8 +487,8 @@ define i12 @trunc_sandwich_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 28 ; CHECK-NEXT: call void @use(i32 [[SH]]) -; CHECK-NEXT: [[TR:%.*]] = trunc i32 [[SH]] to i12 -; CHECK-NEXT: [[R:%.*]] = lshr i12 [[TR]], 2 +; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 30 +; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12 ; CHECK-NEXT: ret i12 [[R]] ; %sh = lshr i32 %x, 28 @@ -502,8 +502,8 @@ define <3 x i9> @trunc_sandwich_splat_vec_use1(<3 x i14> %x) { ; CHECK-LABEL: @trunc_sandwich_splat_vec_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr <3 x i14> [[X:%.*]], ; CHECK-NEXT: call void @usevec(<3 x i14> [[SH]]) -; CHECK-NEXT: [[TR:%.*]] = trunc <3 x i14> [[SH]] to <3 x i9> -; CHECK-NEXT: [[R:%.*]] = lshr <3 x i9> [[TR]], +; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr <3 x i14> [[X]], +; CHECK-NEXT: [[R:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9> ; CHECK-NEXT: ret <3 x i9> [[R]] ; %sh = lshr <3 x i14> %x, @@ -517,8 +517,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_min_shift1_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 20 ; CHECK-NEXT: call void @use(i32 [[SH]]) -; CHECK-NEXT: [[TR:%.*]] = trunc i32 [[SH]] to i12 -; CHECK-NEXT: [[R:%.*]] = lshr i12 [[TR]], 1 +; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 21 +; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12 ; CHECK-NEXT: ret i12 [[R]] ; %sh = lshr i32 %x, 20 @@ -528,6 +528,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) { ret i12 %r } +; negative test - trunc is bigger than first shift + define i12 @trunc_sandwich_small_shift1_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_small_shift1_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 19 @@ -547,8 +549,8 @@ define i12 @trunc_sandwich_max_sum_shift_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_max_sum_shift_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 20 ; CHECK-NEXT: call void @use(i32 [[SH]]) -; CHECK-NEXT: [[TR:%.*]] = trunc i32 [[SH]] to i12 -; CHECK-NEXT: [[R:%.*]] = lshr i12 [[TR]], 11 +; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31 +; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12 ; CHECK-NEXT: ret i12 [[R]] ; %sh = lshr i32 %x, 20 @@ -562,8 +564,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_max_sum_shift2_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 30 ; CHECK-NEXT: call void @use(i32 [[SH]]) -; CHECK-NEXT: [[TR:%.*]] = trunc i32 [[SH]] to i12 -; CHECK-NEXT: [[R:%.*]] = lshr i12 [[TR]], 1 +; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31 +; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12 ; CHECK-NEXT: ret i12 [[R]] ; %sh = lshr i32 %x, 30 @@ -573,6 +575,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) { ret i12 %r } +; negative test - but overshift is simplified to zero by another fold + define i12 @trunc_sandwich_big_sum_shift1_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_big_sum_shift1_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 21 @@ -586,6 +590,8 @@ define i12 @trunc_sandwich_big_sum_shift1_use1(i32 %x) { ret i12 %r } +; negative test - but overshift is simplified to zero by another fold + define i12 @trunc_sandwich_big_sum_shift2_use1(i32 %x) { ; CHECK-LABEL: @trunc_sandwich_big_sum_shift2_use1( ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 31 -- 2.7.4