[InstCombine] fold (trunc (X>>C1)) << C to shift+mask directly
authorSanjay Patel <spatel@rotateright.com>
Fri, 1 Oct 2021 17:30:44 +0000 (13:30 -0400)
committerSanjay Patel <spatel@rotateright.com>
Fri, 1 Oct 2021 18:22:44 +0000 (14:22 -0400)
This is no-externally-visible-functional-difference-intended.
That is, the test diffs show identical instructions other than
name changes (those are included specifically to verify the logic).

The existing transforms created extra instructions and relied
on subsequent folds to get to the final result, but that could
conflict with other transforms like the proposed D110170 (and
caused that patch to be reverted twice so far because of infinite
combine loops).

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
llvm/test/Transforms/InstCombine/bswap.ll
llvm/test/Transforms/InstCombine/shift-shift.ll

index df7fb01..0f83be5 100644 (file)
@@ -845,6 +845,26 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
       }
     }
 
+    // Similar to above, but look through an intermediate trunc instruction.
+    BinaryOperator *Shr;
+    if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) &&
+        match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) {
+      // The larger shift direction survives through the transform.
+      unsigned ShrAmtC = C1->getZExtValue();
+      unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC;
+      Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff);
+      auto ShiftOpc = ShrAmtC > ShAmtC ? Shr->getOpcode() : Instruction::Shl;
+
+      // If C1 > C:
+      // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) && (-1 << C)
+      // If C > C1:
+      // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) && (-1 << C)
+      Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff");
+      Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff");
+      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
+      return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask));
+    }
+
     if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
       unsigned AmtSum = ShAmtC + C1->getZExtValue();
       // Oversized shifts are simplified to zero in InstSimplify.
@@ -853,41 +873,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
         return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
     }
 
-    // Fold shl(trunc(shr(x,c1)),c2) -> trunc(and(shl(shr(x,c1),c2),c2'))
-    // Require that the input operand is a non-poison shift-by-constant so that
-    // we have confidence that the shifts will get folded together.
-    Instruction *TrOp;
-    const APInt *TrShiftAmt;
-    if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TrOp)))) &&
-        match(TrOp, m_OneUse(m_Shr(m_Value(), m_APInt(TrShiftAmt)))) &&
-        TrShiftAmt->ult(TrOp->getType()->getScalarSizeInBits())) {
-      Type *SrcTy = TrOp->getType();
-
-      // Okay, we'll do this xform.  Make the shift of shift.
-      unsigned SrcSize = SrcTy->getScalarSizeInBits();
-      Constant *ShAmt = ConstantInt::get(SrcTy, C->zext(SrcSize));
-
-      // (shift2 (shift1 & 0x00FF), c2)
-      Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName());
-
-      // For logical shifts, the truncation has the effect of making the high
-      // part of the register be zeros.  Emulate this by inserting an AND to
-      // clear the top bits as needed.  This 'and' will usually be zapped by
-      // other xforms later if dead.
-      Constant *MaskV =
-          ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, BitWidth));
-
-      // The mask we constructed says what the trunc would do if occurring
-      // between the shifts.  We want to know the effect *after* the second
-      // shift.  We know that it is a logical shift by a constant, so adjust the
-      // mask as appropriate.
-      MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt);
-      // shift1 & 0x00FF
-      Value *And = Builder.CreateAnd(NSh, MaskV, Op0->getName());
-      // Return the value truncated to the interesting size.
-      return new TruncInst(And, Ty);
-    }
-
     // If we have an opposite shift by the same amount, we may be able to
     // reorder binops and shifts to eliminate math/logic.
     auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) {
index 32caf4a..18b8a25 100644 (file)
@@ -743,9 +743,9 @@ define i16 @trunc_bswap_i160(i160* %a0) {
 ; CHECK-NEXT:    [[LSHR1:%.*]] = lshr i160 [[LOAD]], 136
 ; CHECK-NEXT:    [[CAST1:%.*]] = trunc i160 [[LSHR1]] to i16
 ; CHECK-NEXT:    [[AND1:%.*]] = and i16 [[CAST1]], 255
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr i160 [[LOAD]], 120
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i160 [[TMP1]] to i16
-; CHECK-NEXT:    [[SHL:%.*]] = and i16 [[TMP2]], -256
+; CHECK-NEXT:    [[SH_DIFF:%.*]] = lshr i160 [[LOAD]], 120
+; CHECK-NEXT:    [[TR_SH_DIFF:%.*]] = trunc i160 [[SH_DIFF]] to i16
+; CHECK-NEXT:    [[SHL:%.*]] = and i16 [[TR_SH_DIFF]], -256
 ; CHECK-NEXT:    [[OR:%.*]] = or i16 [[AND1]], [[SHL]]
 ; CHECK-NEXT:    ret i16 [[OR]]
 ;
index 7b93364..7c6618f 100644 (file)
@@ -139,9 +139,9 @@ define <2 x i32> @lshr_lshr_vec(<2 x i32> %A) {
 
 define i8 @shl_trunc_bigger_lshr(i32 %x) {
 ; CHECK-LABEL: @shl_trunc_bigger_lshr(
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 2
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[TMP1]] to i8
-; CHECK-NEXT:    [[LT:%.*]] = and i8 [[TMP2]], -8
+; CHECK-NEXT:    [[SH_DIFF:%.*]] = lshr i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[TR_SH_DIFF:%.*]] = trunc i32 [[SH_DIFF]] to i8
+; CHECK-NEXT:    [[LT:%.*]] = and i8 [[TR_SH_DIFF]], -8
 ; CHECK-NEXT:    ret i8 [[LT]]
 ;
   %rt = lshr i32 %x, 5
@@ -153,8 +153,8 @@ define i8 @shl_trunc_bigger_lshr(i32 %x) {
 define i8 @shl_trunc_smaller_lshr(i32 %x) {
 ; CHECK-LABEL: @shl_trunc_smaller_lshr(
 ; CHECK-NEXT:    [[X_TR:%.*]] = trunc i32 [[X:%.*]] to i8
-; CHECK-NEXT:    [[TMP1:%.*]] = shl i8 [[X_TR]], 2
-; CHECK-NEXT:    [[LT:%.*]] = and i8 [[TMP1]], -32
+; CHECK-NEXT:    [[TR_SH_DIFF:%.*]] = shl i8 [[X_TR]], 2
+; CHECK-NEXT:    [[LT:%.*]] = and i8 [[TR_SH_DIFF]], -32
 ; CHECK-NEXT:    ret i8 [[LT]]
 ;
   %rt = lshr i32 %x, 3
@@ -165,9 +165,9 @@ define i8 @shl_trunc_smaller_lshr(i32 %x) {
 
 define i24 @shl_trunc_bigger_ashr(i32 %x) {
 ; CHECK-LABEL: @shl_trunc_bigger_ashr(
-; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[X:%.*]], 9
-; CHECK-NEXT:    [[TMP2:%.*]] = trunc i32 [[TMP1]] to i24
-; CHECK-NEXT:    [[LT:%.*]] = and i24 [[TMP2]], -8
+; CHECK-NEXT:    [[SH_DIFF:%.*]] = ashr i32 [[X:%.*]], 9
+; CHECK-NEXT:    [[TR_SH_DIFF:%.*]] = trunc i32 [[SH_DIFF]] to i24
+; CHECK-NEXT:    [[LT:%.*]] = and i24 [[TR_SH_DIFF]], -8
 ; CHECK-NEXT:    ret i24 [[LT]]
 ;
   %rt = ashr i32 %x, 12
@@ -179,8 +179,8 @@ define i24 @shl_trunc_bigger_ashr(i32 %x) {
 define i24 @shl_trunc_smaller_ashr(i32 %x) {
 ; CHECK-LABEL: @shl_trunc_smaller_ashr(
 ; CHECK-NEXT:    [[X_TR:%.*]] = trunc i32 [[X:%.*]] to i24
-; CHECK-NEXT:    [[TMP1:%.*]] = shl i24 [[X_TR]], 3
-; CHECK-NEXT:    [[LT:%.*]] = and i24 [[TMP1]], -8192
+; CHECK-NEXT:    [[TR_SH_DIFF:%.*]] = shl i24 [[X_TR]], 3
+; CHECK-NEXT:    [[LT:%.*]] = and i24 [[TR_SH_DIFF]], -8192
 ; CHECK-NEXT:    ret i24 [[LT]]
 ;
   %rt = ashr i32 %x, 10