[DAGCombine] Mask doesn't have to be (EltSize - 1) exactly when combining rotation
authorwangpc <pc.wang@linux.alibaba.com>
Tue, 26 Jul 2022 13:11:39 +0000 (21:11 +0800)
committerwangpc <pc.wang@linux.alibaba.com>
Tue, 26 Jul 2022 13:14:45 +0000 (21:14 +0800)
I think what we need is the least Log2(EltSize) significant bits are known to be ones.

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D130251

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/RISCV/rotl-rotr.ll

index 0edfefa..21bad73 100644 (file)
@@ -7261,6 +7261,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
 // Otherwise if matching a general funnel shift, it should be clear.
 static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
                            SelectionDAG &DAG, bool IsRotate) {
+  const auto &TLI = DAG.getTargetLoweringInfo();
   // If EltSize is a power of 2 then:
   //
   //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
@@ -7292,19 +7293,19 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
   // always invokes undefined behavior for 32-bit X.
   //
   // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
+  // This allows us to peek through any operations that only affect Mask's
+  // un-demanded bits.
   //
-  // NOTE: We can only do this when matching an AND and not a general
-  // funnel shift.
+  // NOTE: We can only do this when matching operations which won't modify the
+  // least Log2(EltSize) significant bits and not a general funnel shift.
   unsigned MaskLoBits = 0;
-  if (IsRotate && Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
-    if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
-      KnownBits Known = DAG.computeKnownBits(Neg.getOperand(0));
-      unsigned Bits = Log2_64(EltSize);
-      if (NegC->getAPIntValue().getActiveBits() <= Bits &&
-          ((NegC->getAPIntValue() | Known.Zero).countTrailingOnes() >= Bits)) {
-        Neg = Neg.getOperand(0);
-        MaskLoBits = Bits;
-      }
+  if (IsRotate && isPowerOf2_64(EltSize)) {
+    unsigned Bits = Log2_64(EltSize);
+    APInt DemandedBits = APInt::getLowBitsSet(EltSize, Bits);
+    if (SDValue Inner =
+            TLI.SimplifyMultipleUseDemandedBits(Neg, DemandedBits, DAG)) {
+      Neg = Inner;
+      MaskLoBits = Bits;
     }
   }
 
@@ -7316,15 +7317,14 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
     return false;
   SDValue NegOp1 = Neg.getOperand(1);
 
-  // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
-  // Pos'.  The truncation is redundant for the purpose of the equality.
-  if (MaskLoBits && Pos.getOpcode() == ISD::AND) {
-    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) {
-      KnownBits Known = DAG.computeKnownBits(Pos.getOperand(0));
-      if (PosC->getAPIntValue().getActiveBits() <= MaskLoBits &&
-          ((PosC->getAPIntValue() | Known.Zero).countTrailingOnes() >=
-           MaskLoBits))
-        Pos = Pos.getOperand(0);
+  // On the RHS of [A], if Pos is the result of operation on Pos' that won't
+  // affect Mask's demanded bits, just replace Pos with Pos'. These operations
+  // are redundant for the purpose of the equality.
+  if (MaskLoBits) {
+    APInt DemandedBits = APInt::getLowBitsSet(EltSize, MaskLoBits);
+    if (SDValue Inner =
+            TLI.SimplifyMultipleUseDemandedBits(Pos, DemandedBits, DAG)) {
+      Pos = Inner;
     }
   }
 
index aa6381a..0109ca5 100644 (file)
@@ -341,18 +341,12 @@ define i32 @rotl_32_mask_and_63_and_31(i32 %x, i32 %y) nounwind {
 ;
 ; RV32ZBB-LABEL: rotl_32_mask_and_63_and_31:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    sll a2, a0, a1
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    srl a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    rol a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotl_32_mask_and_63_and_31:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sllw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i32 %y, 63
   %b = shl i32 %x, %a
@@ -384,20 +378,12 @@ define i32 @rotl_32_mask_or_64_or_32(i32 %x, i32 %y) nounwind {
 ;
 ; RV32ZBB-LABEL: rotl_32_mask_or_64_or_32:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    ori a2, a1, 64
-; RV32ZBB-NEXT:    sll a2, a0, a2
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    ori a1, a1, 32
-; RV32ZBB-NEXT:    srl a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    rol a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotl_32_mask_or_64_or_32:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sllw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    srlw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rolw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i32 %y, 64
   %b = shl i32 %x, %a
@@ -461,18 +447,12 @@ define i32 @rotr_32_mask_and_63_and_31(i32 %x, i32 %y) nounwind {
 ;
 ; RV32ZBB-LABEL: rotr_32_mask_and_63_and_31:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    srl a2, a0, a1
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    sll a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    ror a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotr_32_mask_and_63_and_31:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srlw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i32 %y, 63
   %b = lshr i32 %x, %a
@@ -504,20 +484,12 @@ define i32 @rotr_32_mask_or_64_or_32(i32 %x, i32 %y) nounwind {
 ;
 ; RV32ZBB-LABEL: rotr_32_mask_or_64_or_32:
 ; RV32ZBB:       # %bb.0:
-; RV32ZBB-NEXT:    ori a2, a1, 64
-; RV32ZBB-NEXT:    srl a2, a0, a2
-; RV32ZBB-NEXT:    neg a1, a1
-; RV32ZBB-NEXT:    ori a1, a1, 32
-; RV32ZBB-NEXT:    sll a0, a0, a1
-; RV32ZBB-NEXT:    or a0, a2, a0
+; RV32ZBB-NEXT:    ror a0, a0, a1
 ; RV32ZBB-NEXT:    ret
 ;
 ; RV64ZBB-LABEL: rotr_32_mask_or_64_or_32:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srlw a2, a0, a1
-; RV64ZBB-NEXT:    negw a1, a1
-; RV64ZBB-NEXT:    sllw a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rorw a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i32 %y, 64
   %b = lshr i32 %x, %a
@@ -718,10 +690,7 @@ define i64 @rotl_64_mask_and_127_and_63(i64 %x, i64 %y) nounwind {
 ;
 ; RV64ZBB-LABEL: rotl_64_mask_and_127_and_63:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    sll a2, a0, a1
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    srl a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rol a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i64 %y, 127
   %b = shl i64 %x, %a
@@ -761,12 +730,7 @@ define i64 @rotl_64_mask_or_128_or_64(i64 %x, i64 %y) nounwind {
 ;
 ; RV64ZBB-LABEL: rotl_64_mask_or_128_or_64:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    ori a2, a1, 128
-; RV64ZBB-NEXT:    sll a2, a0, a2
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    ori a1, a1, 64
-; RV64ZBB-NEXT:    srl a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    rol a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i64 %y, 128
   %b = shl i64 %x, %a
@@ -967,10 +931,7 @@ define i64 @rotr_64_mask_and_127_and_63(i64 %x, i64 %y) nounwind {
 ;
 ; RV64ZBB-LABEL: rotr_64_mask_and_127_and_63:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    srl a2, a0, a1
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    sll a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    ror a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = and i64 %y, 127
   %b = lshr i64 %x, %a
@@ -1010,12 +971,7 @@ define i64 @rotr_64_mask_or_128_or_64(i64 %x, i64 %y) nounwind {
 ;
 ; RV64ZBB-LABEL: rotr_64_mask_or_128_or_64:
 ; RV64ZBB:       # %bb.0:
-; RV64ZBB-NEXT:    ori a2, a1, 128
-; RV64ZBB-NEXT:    srl a2, a0, a2
-; RV64ZBB-NEXT:    neg a1, a1
-; RV64ZBB-NEXT:    ori a1, a1, 64
-; RV64ZBB-NEXT:    sll a0, a0, a1
-; RV64ZBB-NEXT:    or a0, a2, a0
+; RV64ZBB-NEXT:    ror a0, a0, a1
 ; RV64ZBB-NEXT:    ret
   %a = or i64 %y, 128
   %b = lshr i64 %x, %a