[RISCV] Teach DAG combine what bits Zbp instructions demanded from their inputs.
authorCraig Topper <craig.topper@sifive.com>
Mon, 26 Apr 2021 03:17:13 +0000 (20:17 -0700)
committerCraig Topper <craig.topper@sifive.com>
Mon, 26 Apr 2021 04:54:06 +0000 (21:54 -0700)
This teaches DAG combine that shift amount operands for grev, gorc
shfl, unshfl only read a few bits.

This also teaches DAG combine that grevw, gorcw, shflw, unshflw,
bcompressw, bdecompressw only consume the lower 32 bits of their
inputs.

In the future we can teach SimplifyDemandedBits to also propagate
demanded bits of the output to the inputs in some cases.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rv32zbp-intrinsic.ll
llvm/test/CodeGen/RISCV/rv64zbe-intrinsic.ll
llvm/test/CodeGen/RISCV/rv64zbp-intrinsic.ll

index 4bd5376..c40b2b2 100644 (file)
@@ -5273,12 +5273,30 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     break;
   }
+  case RISCVISD::GREV:
+  case RISCVISD::GORC: {
+    // Only the lower log2(Bitwidth) bits of the the shift amount are read.
+    SDValue ShAmt = N->getOperand(1);
+    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
+    APInt ShAmtMask(BitWidth, BitWidth - 1);
+    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+
+    return combineGREVI_GORCI(N, DCI.DAG);
+  }
   case RISCVISD::GREVW:
   case RISCVISD::GORCW: {
-    // Only the lower 32 bits of the first operand are read
-    SDValue Op0 = N->getOperand(0);
-    APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
-    if (SimplifyDemandedBits(Op0, Mask, DCI)) {
+    // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
+    SDValue LHS = N->getOperand(0);
+    SDValue RHS = N->getOperand(1);
+    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
+    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
+    if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
+        SimplifyDemandedBits(RHS, RHSMask, DCI)) {
       if (N->getOpcode() != ISD::DELETED_NODE)
         DCI.AddToWorklist(N);
       return SDValue(N, 0);
@@ -5286,6 +5304,52 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
 
     return combineGREVI_GORCI(N, DCI.DAG);
   }
+  case RISCVISD::SHFL:
+  case RISCVISD::UNSHFL: {
+    // Only the lower log2(Bitwidth) bits of the the shift amount are read.
+    SDValue ShAmt = N->getOperand(1);
+    unsigned BitWidth = ShAmt.getValueSizeInBits();
+    assert(isPowerOf2_32(BitWidth) && "Unexpected bit width");
+    APInt ShAmtMask(BitWidth, (BitWidth / 2) - 1);
+    if (SimplifyDemandedBits(ShAmt, ShAmtMask, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+
+    break;
+  }
+  case RISCVISD::SHFLW:
+  case RISCVISD::UNSHFLW: {
+    // Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
+    SDValue LHS = N->getOperand(0);
+    SDValue RHS = N->getOperand(1);
+    APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
+    APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 4);
+    if (SimplifyDemandedBits(LHS, LHSMask, DCI) ||
+        SimplifyDemandedBits(RHS, RHSMask, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+
+    break;
+  }
+  case RISCVISD::BCOMPRESSW:
+  case RISCVISD::BDECOMPRESSW: {
+    // Only the lower 32 bits of LHS and RHS are read.
+    SDValue LHS = N->getOperand(0);
+    SDValue RHS = N->getOperand(1);
+    APInt Mask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
+    if (SimplifyDemandedBits(LHS, Mask, DCI) ||
+        SimplifyDemandedBits(RHS, Mask, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+
+    break;
+  }
   case RISCVISD::FMV_X_ANYEXTW_RV64: {
     SDLoc DL(N);
     SDValue Op0 = N->getOperand(0);
@@ -5316,9 +5380,6 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(ISD::AND, DL, MVT::i64, NewFMV,
                        DAG.getConstant(~SignBit, DL, MVT::i64));
   }
-  case RISCVISD::GREV:
-  case RISCVISD::GORC:
-    return combineGREVI_GORCI(N, DCI.DAG);
   case ISD::OR:
     if (auto GREV = combineORToGREV(SDValue(N, 0), DCI.DAG, Subtarget))
       return GREV;
index 9786579..1c617be 100644 (file)
@@ -20,7 +20,20 @@ define i32 @grev32(i32 %a, i32 %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.grevi.i32(i32 %a)
+define i32 @grev32_demandedbits(i32 %a, i32 %b) nounwind {
+; RV32IB-LABEL: grev32_demandedbits:
+; RV32IB:       # %bb.0:
+; RV32IB-NEXT:    grev a0, a0, a1
+; RV32IB-NEXT:    ret
+;
+; RV32IBP-LABEL: grev32_demandedbits:
+; RV32IBP:       # %bb.0:
+; RV32IBP-NEXT:    grev a0, a0, a1
+; RV32IBP-NEXT:    ret
+  %c = and i32 %b, 31
+  %tmp = call i32 @llvm.riscv.grev.i32(i32 %a, i32 %b)
+  ret i32 %tmp
+}
 
 define i32 @grevi32(i32 %a) nounwind {
 ; RV32IB-LABEL: grevi32:
@@ -52,7 +65,20 @@ define i32 @gorc32(i32 %a, i32 %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.gorci.i32(i32 %a)
+define i32 @gorc32_demandedbits(i32 %a, i32 %b) nounwind {
+; RV32IB-LABEL: gorc32_demandedbits:
+; RV32IB:       # %bb.0:
+; RV32IB-NEXT:    gorc a0, a0, a1
+; RV32IB-NEXT:    ret
+;
+; RV32IBP-LABEL: gorc32_demandedbits:
+; RV32IBP:       # %bb.0:
+; RV32IBP-NEXT:    gorc a0, a0, a1
+; RV32IBP-NEXT:    ret
+  %c = and i32 %b, 31
+  %tmp = call i32 @llvm.riscv.gorc.i32(i32 %a, i32 %b)
+  ret i32 %tmp
+}
 
 define i32 @gorci32(i32 %a) nounwind {
 ; RV32IB-LABEL: gorci32:
@@ -84,7 +110,20 @@ define i32 @shfl32(i32 %a, i32 %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.shfli.i32(i32 %a)
+define i32 @shfl32_demandedbits(i32 %a, i32 %b) nounwind {
+; RV32IB-LABEL: shfl32_demandedbits:
+; RV32IB:       # %bb.0:
+; RV32IB-NEXT:    shfl a0, a0, a1
+; RV32IB-NEXT:    ret
+;
+; RV32IBP-LABEL: shfl32_demandedbits:
+; RV32IBP:       # %bb.0:
+; RV32IBP-NEXT:    shfl a0, a0, a1
+; RV32IBP-NEXT:    ret
+  %c = and i32 %b, 15
+  %tmp = call i32 @llvm.riscv.shfl.i32(i32 %a, i32 %c)
+  ret i32 %tmp
+}
 
 define i32 @shfli32(i32 %a) nounwind {
 ; RV32IB-LABEL: shfli32:
@@ -116,7 +155,20 @@ define i32 @unshfl32(i32 %a, i32 %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.unshfli.i32(i32 %a)
+define i32 @unshfl32_demandedbits(i32 %a, i32 %b) nounwind {
+; RV32IB-LABEL: unshfl32_demandedbits:
+; RV32IB:       # %bb.0:
+; RV32IB-NEXT:    unshfl a0, a0, a1
+; RV32IB-NEXT:    ret
+;
+; RV32IBP-LABEL: unshfl32_demandedbits:
+; RV32IBP:       # %bb.0:
+; RV32IBP-NEXT:    unshfl a0, a0, a1
+; RV32IBP-NEXT:    ret
+  %c = and i32 %b, 15
+  %tmp = call i32 @llvm.riscv.unshfl.i32(i32 %a, i32 %c)
+  ret i32 %tmp
+}
 
 define i32 @unshfli32(i32 %a) nounwind {
 ; RV32IB-LABEL: unshfli32:
index 07a0689..3e7af31 100644 (file)
@@ -17,7 +17,27 @@ define signext i32 @bcompress32(i32 signext %a, i32 signext %b) nounwind {
 ; RV64IBE-NEXT:    bcompressw a0, a0, a1
 ; RV64IBE-NEXT:    ret
   %tmp = call i32 @llvm.riscv.bcompress.i32(i32 %a, i32 %b)
- ret i32 %tmp
+  ret i32 %tmp
+}
+
+define signext i32 @bcompress32_demandedbits(i32 signext %a, i32 signext %b, i32 signext %c, i32 signext %d) nounwind {
+; RV64IB-LABEL: bcompress32_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    add a0, a0, a1
+; RV64IB-NEXT:    add a1, a2, a3
+; RV64IB-NEXT:    bcompressw a0, a0, a1
+; RV64IB-NEXT:    ret
+;
+; RV64IBE-LABEL: bcompress32_demandedbits:
+; RV64IBE:       # %bb.0:
+; RV64IBE-NEXT:    add a0, a0, a1
+; RV64IBE-NEXT:    add a1, a2, a3
+; RV64IBE-NEXT:    bcompressw a0, a0, a1
+; RV64IBE-NEXT:    ret
+  %e = add i32 %a, %b
+  %f = add i32 %c, %d
+  %tmp = call i32 @llvm.riscv.bcompress.i32(i32 %e, i32 %f)
+  ret i32 %tmp
 }
 
 declare i32 @llvm.riscv.bdecompress.i32(i32 %a, i32 %b)
@@ -33,7 +53,27 @@ define signext i32 @bdecompress32(i32 signext %a, i32 signext %b) nounwind {
 ; RV64IBE-NEXT:    bdecompressw a0, a0, a1
 ; RV64IBE-NEXT:    ret
   %tmp = call i32 @llvm.riscv.bdecompress.i32(i32 %a, i32 %b)
- ret i32 %tmp
+  ret i32 %tmp
+}
+
+define signext i32 @bdecompress32_demandedbits(i32 signext %a, i32 signext %b, i32 signext %c, i32 signext %d) nounwind {
+; RV64IB-LABEL: bdecompress32_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    add a0, a0, a1
+; RV64IB-NEXT:    add a1, a2, a3
+; RV64IB-NEXT:    bdecompressw a0, a0, a1
+; RV64IB-NEXT:    ret
+;
+; RV64IBE-LABEL: bdecompress32_demandedbits:
+; RV64IBE:       # %bb.0:
+; RV64IBE-NEXT:    add a0, a0, a1
+; RV64IBE-NEXT:    add a1, a2, a3
+; RV64IBE-NEXT:    bdecompressw a0, a0, a1
+; RV64IBE-NEXT:    ret
+  %e = add i32 %a, %b
+  %f = add i32 %c, %d
+  %tmp = call i32 @llvm.riscv.bdecompress.i32(i32 %e, i32 %f)
+  ret i32 %tmp
 }
 
 declare i64 @llvm.riscv.bcompress.i64(i64 %a, i64 %b)
@@ -49,7 +89,7 @@ define i64 @bcompress64(i64 %a, i64 %b) nounwind {
 ; RV64IBE-NEXT:    bcompress a0, a0, a1
 ; RV64IBE-NEXT:    ret
   %tmp = call i64 @llvm.riscv.bcompress.i64(i64 %a, i64 %b)
- ret i64 %tmp
 ret i64 %tmp
 }
 
 declare i64 @llvm.riscv.bdecompress.i64(i64 %a, i64 %b)
@@ -65,5 +105,5 @@ define i64 @bdecompress64(i64 %a, i64 %b) nounwind {
 ; RV64IBE-NEXT:    bdecompress a0, a0, a1
 ; RV64IBE-NEXT:    ret
   %tmp = call i64 @llvm.riscv.bdecompress.i64(i64 %a, i64 %b)
- ret i64 %tmp
 ret i64 %tmp
 }
index e2c325a..81f14d6 100644 (file)
@@ -20,6 +20,24 @@ define signext i32 @grev32(i32 signext %a, i32 signext %b) nounwind {
  ret i32 %tmp
 }
 
+define signext i32 @grev32_demandedbits(i32 signext %a, i32 signext %b, i32 signext %c) nounwind {
+; RV64IB-LABEL: grev32_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    add a0, a0, a1
+; RV64IB-NEXT:    grevw a0, a0, a2
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: grev32_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    add a0, a0, a1
+; RV64IBP-NEXT:    grevw a0, a0, a2
+; RV64IBP-NEXT:    ret
+  %d = add i32 %a, %b
+  %e = and i32 %c, 31
+  %tmp = call i32 @llvm.riscv.grev.i32(i32 %d, i32 %e)
+  ret i32 %tmp
+}
+
 declare i32 @llvm.riscv.grevi.i32(i32 %a)
 
 define signext i32 @grevi32(i32 signext %a) nounwind {
@@ -52,7 +70,23 @@ define signext i32 @gorc32(i32 signext %a, i32 signext %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.gorci.i32(i32 %a)
+define signext i32 @gorc32_demandedbits(i32 signext %a, i32 signext %b, i32 signext %c) nounwind {
+; RV64IB-LABEL: gorc32_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    add a0, a0, a1
+; RV64IB-NEXT:    gorcw a0, a0, a2
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: gorc32_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    add a0, a0, a1
+; RV64IBP-NEXT:    gorcw a0, a0, a2
+; RV64IBP-NEXT:    ret
+  %d = add i32 %a, %b
+  %e = and i32 %c, 31
+  %tmp = call i32 @llvm.riscv.gorc.i32(i32 %d, i32 %e)
+  ret i32 %tmp
+}
 
 define signext i32 @gorci32(i32 signext %a) nounwind {
 ; RV64IB-LABEL: gorci32:
@@ -84,7 +118,23 @@ define signext i32 @shfl32(i32 signext %a, i32 signext %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.shfli.i32(i32 %a)
+define signext i32 @shfl32_demandedbits(i32 signext %a, i32 signext %b, i32 signext %c) nounwind {
+; RV64IB-LABEL: shfl32_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    add a0, a0, a1
+; RV64IB-NEXT:    shflw a0, a0, a2
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: shfl32_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    add a0, a0, a1
+; RV64IBP-NEXT:    shflw a0, a0, a2
+; RV64IBP-NEXT:    ret
+  %d = add i32 %a, %b
+  %e = and i32 %c, 15
+  %tmp = call i32 @llvm.riscv.shfl.i32(i32 %d, i32 %e)
+  ret i32 %tmp
+}
 
 define signext i32 @shfli32(i32 signext %a) nounwind {
 ; RV64IB-LABEL: shfli32:
@@ -116,7 +166,23 @@ define signext i32 @unshfl32(i32 signext %a, i32 signext %b) nounwind {
  ret i32 %tmp
 }
 
-declare i32 @llvm.riscv.unshfli.i32(i32 %a)
+define signext i32 @unshfl32_demandedbits(i32 signext %a, i32 signext %b, i32 signext %c) nounwind {
+; RV64IB-LABEL: unshfl32_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    add a0, a0, a1
+; RV64IB-NEXT:    unshflw a0, a0, a2
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: unshfl32_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    add a0, a0, a1
+; RV64IBP-NEXT:    unshflw a0, a0, a2
+; RV64IBP-NEXT:    ret
+  %d = add i32 %a, %b
+  %e = and i32 %c, 15
+  %tmp = call i32 @llvm.riscv.unshfl.i32(i32 %d, i32 %e)
+  ret i32 %tmp
+}
 
 define signext i32 @unshfli32(i32 signext %a) nounwind {
 ; RV64IB-LABEL: unshfli32:
@@ -148,7 +214,20 @@ define i64 @grev64(i64 %a, i64 %b) nounwind {
  ret i64 %tmp
 }
 
-declare i64 @llvm.riscv.grevi.i64(i64 %a)
+define i64 @grev64_demandedbits(i64 %a, i64 %b) nounwind {
+; RV64IB-LABEL: grev64_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    grev a0, a0, a1
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: grev64_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    grev a0, a0, a1
+; RV64IBP-NEXT:    ret
+  %c = and i64 %b, 63
+  %tmp = call i64 @llvm.riscv.grev.i64(i64 %a, i64 %c)
+  ret i64 %tmp
+}
 
 define i64 @grevi64(i64 %a) nounwind {
 ; RV64IB-LABEL: grevi64:
@@ -180,6 +259,21 @@ define i64 @gorc64(i64 %a, i64 %b) nounwind {
  ret i64 %tmp
 }
 
+define i64 @gorc64_demandedbits(i64 %a, i64 %b) nounwind {
+; RV64IB-LABEL: gorc64_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    gorc a0, a0, a1
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: gorc64_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    gorc a0, a0, a1
+; RV64IBP-NEXT:    ret
+  %c = and i64 %b, 63
+  %tmp = call i64 @llvm.riscv.gorc.i64(i64 %a, i64 %c)
+  ret i64 %tmp
+}
+
 declare i64 @llvm.riscv.gorci.i64(i64 %a)
 
 define i64 @gorci64(i64 %a) nounwind {
@@ -212,7 +306,20 @@ define i64 @shfl64(i64 %a, i64 %b) nounwind {
  ret i64 %tmp
 }
 
-declare i64 @llvm.riscv.shfli.i64(i64 %a)
+define i64 @shfl64_demandedbits(i64 %a, i64 %b) nounwind {
+; RV64IB-LABEL: shfl64_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    shfl a0, a0, a1
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: shfl64_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    shfl a0, a0, a1
+; RV64IBP-NEXT:    ret
+  %c = and i64 %b, 31
+  %tmp = call i64 @llvm.riscv.shfl.i64(i64 %a, i64 %c)
+  ret i64 %tmp
+}
 
 define i64 @shfli64(i64 %a) nounwind {
 ; RV64IB-LABEL: shfli64:
@@ -244,7 +351,20 @@ define i64 @unshfl64(i64 %a, i64 %b) nounwind {
  ret i64 %tmp
 }
 
-declare i64 @llvm.riscv.unshfli.i64(i64 %a)
+define i64 @unshfl64_demandedbits(i64 %a, i64 %b) nounwind {
+; RV64IB-LABEL: unshfl64_demandedbits:
+; RV64IB:       # %bb.0:
+; RV64IB-NEXT:    unshfl a0, a0, a1
+; RV64IB-NEXT:    ret
+;
+; RV64IBP-LABEL: unshfl64_demandedbits:
+; RV64IBP:       # %bb.0:
+; RV64IBP-NEXT:    unshfl a0, a0, a1
+; RV64IBP-NEXT:    ret
+  %c = and i64 %b, 31
+  %tmp = call i64 @llvm.riscv.unshfl.i64(i64 %a, i64 %c)
+  ret i64 %tmp
+}
 
 define i64 @unshfli64(i64 %a) nounwind {
 ; RV64IB-LABEL: unshfli64: