[InstCombine] fold lshr(trunc(lshr X, C1)) C2
authorSanjay Patel <spatel@rotateright.com>
Fri, 24 Sep 2021 17:27:02 +0000 (13:27 -0400)
committerSanjay Patel <spatel@rotateright.com>
Fri, 24 Sep 2021 19:44:07 +0000 (15:44 -0400)
Only the multi-use cases are changing here because there's
another fold that catches the simpler patterns.

But that other fold is the source of infinite loops when we
try to add D110170, so removing that is planned as a follow-up.

Attempt to show the general proof in Alive2:
https://alive2.llvm.org/ce/z/Ns1uS2

Note that the overshift fold-to-zero tests are not
currently handled by instsimplify. If they were, we
could assert that the shift amount sum is less than
the source bitwidth.

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
llvm/test/Transforms/InstCombine/lshr.ll

index 92bfae2..b0d328f 100644 (file)
@@ -1149,14 +1149,26 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       }
     }
 
+    // (X >>u C1) >>u C --> X >>u (C1 + C)
     if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) {
-      unsigned AmtSum = ShAmtC + C1->getZExtValue();
       // Oversized shifts are simplified to zero in InstSimplify.
+      unsigned AmtSum = ShAmtC + C1->getZExtValue();
       if (AmtSum < BitWidth)
-        // (X >>u C1) >>u C --> X >>u (C1 + C)
         return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
     }
 
+    // If the first shift covers the number of bits truncated and the combined
+    // shift fits in the source width:
+    // (trunc (X >>u C1)) >>u C --> trunc (X >>u (C1 + C))
+    if (match(Op0, m_OneUse(m_Trunc(m_LShr(m_Value(X), m_APInt(C1)))))) {
+      unsigned SrcWidth = X->getType()->getScalarSizeInBits();
+      unsigned AmtSum = ShAmtC + C1->getZExtValue();
+      if (C1->uge(SrcWidth - BitWidth) && AmtSum < SrcWidth) {
+        Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift");
+        return new TruncInst(SumShift, Ty);
+      }
+    }
+
     // Look for a "splat" mul pattern - it replicates bits across each half of
     // a value, so a right shift is just a mask of the low bits:
     // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1
index 2172749..b8b1438 100644 (file)
@@ -487,8 +487,8 @@ define i12 @trunc_sandwich_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 28
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 2
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 30
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 28
@@ -502,8 +502,8 @@ define <3 x i9> @trunc_sandwich_splat_vec_use1(<3 x i14> %x) {
 ; CHECK-LABEL: @trunc_sandwich_splat_vec_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr <3 x i14> [[X:%.*]], <i14 6, i14 6, i14 6>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc <3 x i14> [[SH]] to <3 x i9>
-; CHECK-NEXT:    [[R:%.*]] = lshr <3 x i9> [[TR]], <i9 5, i9 5, i9 5>
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr <3 x i14> [[X]], <i14 11, i14 11, i14 11>
+; CHECK-NEXT:    [[R:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9>
 ; CHECK-NEXT:    ret <3 x i9> [[R]]
 ;
   %sh = lshr <3 x i14> %x, <i14 6, i14 6, i14 6>
@@ -517,8 +517,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_min_shift1_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 1
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 21
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 20
@@ -528,6 +528,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) {
   ret i12 %r
 }
 
+; negative test - trunc is bigger than first shift
+
 define i12 @trunc_sandwich_small_shift1_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_small_shift1_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 19
@@ -547,8 +549,8 @@ define i12 @trunc_sandwich_max_sum_shift_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_max_sum_shift_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 20
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 11
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 20
@@ -562,8 +564,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_max_sum_shift2_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 30
 ; CHECK-NEXT:    call void @use(i32 [[SH]])
-; CHECK-NEXT:    [[TR:%.*]] = trunc i32 [[SH]] to i12
-; CHECK-NEXT:    [[R:%.*]] = lshr i12 [[TR]], 1
+; CHECK-NEXT:    [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
+; CHECK-NEXT:    [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
 ; CHECK-NEXT:    ret i12 [[R]]
 ;
   %sh = lshr i32 %x, 30
@@ -573,6 +575,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) {
   ret i12 %r
 }
 
+; negative test - but overshift is simplified to zero by another fold
+
 define i12 @trunc_sandwich_big_sum_shift1_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_big_sum_shift1_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 21
@@ -586,6 +590,8 @@ define i12 @trunc_sandwich_big_sum_shift1_use1(i32 %x) {
   ret i12 %r
 }
 
+; negative test - but overshift is simplified to zero by another fold
+
 define i12 @trunc_sandwich_big_sum_shift2_use1(i32 %x) {
 ; CHECK-LABEL: @trunc_sandwich_big_sum_shift2_use1(
 ; CHECK-NEXT:    [[SH:%.*]] = lshr i32 [[X:%.*]], 31