[RISCV] Add RISCVISD opcodes for CLZW and CTZW.
authorCraig Topper <craig.topper@sifive.com>
Wed, 31 Mar 2021 16:27:22 +0000 (09:27 -0700)
committerCraig Topper <craig.topper@sifive.com>
Wed, 31 Mar 2021 16:40:07 +0000 (09:40 -0700)
Our CLZW isel pattern is quite easily broken by surrounding code
preventing it from matching sometimes. This usually results in
failing to remove the and X, 0xffffffff inserted by type
legalization. The add with -32 that type legalization also inserts
will often gets combined into other add/sub nodes. That doesn't
usually result in extra code when we don't use clzw.

CTTZ seems to be less fragile, but I wanted to keep it consistent
with CTLZ.

Reviewed By: asb, HsiangKai

Differential Revision: https://reviews.llvm.org/D99317

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/lib/Target/RISCV/RISCVInstrInfoB.td
llvm/test/CodeGen/RISCV/rv64zbb.ll

index 3b59bda..16a7817 100644 (file)
@@ -276,6 +276,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::SMAX, XLenVT, Legal);
     setOperationAction(ISD::UMIN, XLenVT, Legal);
     setOperationAction(ISD::UMAX, XLenVT, Legal);
+
+    if (Subtarget.is64Bit()) {
+      setOperationAction(ISD::CTTZ, MVT::i32, Custom);
+      setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Custom);
+      setOperationAction(ISD::CTLZ, MVT::i32, Custom);
+      setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Custom);
+    }
   } else {
     setOperationAction(ISD::CTTZ, XLenVT, Expand);
     setOperationAction(ISD::CTLZ, XLenVT, Expand);
@@ -3885,6 +3892,22 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
            "Unexpected custom legalisation");
     Results.push_back(customLegalizeToWOp(N, DAG));
     break;
+  case ISD::CTTZ:
+  case ISD::CTTZ_ZERO_UNDEF:
+  case ISD::CTLZ:
+  case ISD::CTLZ_ZERO_UNDEF: {
+    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
+           "Unexpected custom legalisation");
+
+    SDValue NewOp0 =
+        DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
+    bool IsCTZ =
+        N->getOpcode() == ISD::CTTZ || N->getOpcode() == ISD::CTTZ_ZERO_UNDEF;
+    unsigned Opc = IsCTZ ? RISCVISD::CTZW : RISCVISD::CLZW;
+    SDValue Res = DAG.getNode(Opc, DL, MVT::i64, NewOp0);
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Res));
+    return;
+  }
   case ISD::SDIV:
   case ISD::UDIV:
   case ISD::UREM: {
@@ -4548,6 +4571,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     break;
   }
+  case RISCVISD::CLZW:
+  case RISCVISD::CTZW: {
+    // Only the lower 32 bits of the first operand are read
+    SDValue Op0 = N->getOperand(0);
+    APInt Mask = APInt::getLowBitsSet(Op0.getValueSizeInBits(), 32);
+    if (SimplifyDemandedBits(Op0, Mask, DCI)) {
+      if (N->getOpcode() != ISD::DELETED_NODE)
+        DCI.AddToWorklist(N);
+      return SDValue(N, 0);
+    }
+    break;
+  }
   case RISCVISD::FSL:
   case RISCVISD::FSR: {
     // Only the lower log2(Bitwidth)+1 bits of the the shift amount are read.
@@ -4989,6 +5024,20 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     Known = Known.sext(BitWidth);
     break;
   }
+  case RISCVISD::CTZW: {
+    KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
+    unsigned PossibleTZ = Known2.trunc(32).countMaxTrailingZeros();
+    unsigned LowBits = Log2_32(PossibleTZ) + 1;
+    Known.Zero.setBitsFrom(LowBits);
+    break;
+  }
+  case RISCVISD::CLZW: {
+    KnownBits Known2 = DAG.computeKnownBits(Op.getOperand(0), Depth + 1);
+    unsigned PossibleLZ = Known2.trunc(32).countMaxLeadingZeros();
+    unsigned LowBits = Log2_32(PossibleLZ) + 1;
+    Known.Zero.setBitsFrom(LowBits);
+    break;
+  }
   case RISCVISD::READ_VLENB:
     // We assume VLENB is at least 8 bytes.
     // FIXME: The 1.0 draft spec defines minimum VLEN as 128 bits.
@@ -6743,6 +6792,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(REMUW)
   NODE_NAME_CASE(ROLW)
   NODE_NAME_CASE(RORW)
+  NODE_NAME_CASE(CLZW)
+  NODE_NAME_CASE(CTZW)
   NODE_NAME_CASE(FSLW)
   NODE_NAME_CASE(FSRW)
   NODE_NAME_CASE(FSL)
index 8dc4fee..20e96c6 100644 (file)
@@ -55,6 +55,10 @@ enum NodeType : unsigned {
   // instructions.
   ROLW,
   RORW,
+  // RV64IZbb bit counting instructions directly matching the semantics of the
+  // named RISC-V instructions.
+  CLZW,
+  CTZW,
   // RV64IB/RV32IB funnel shifts, with the semantics of the named RISC-V
   // instructions, but the same operand order as fshl/fshr intrinsics.
   FSR,
index abdffec..b9cc928 100644 (file)
@@ -17,6 +17,8 @@
 // Operand and SDNode transformation definitions.
 //===----------------------------------------------------------------------===//
 
+def riscv_clzw : SDNode<"RISCVISD::CLZW", SDTIntUnaryOp>;
+def riscv_ctzw : SDNode<"RISCVISD::CTZW", SDTIntUnaryOp>;
 def riscv_rolw : SDNode<"RISCVISD::ROLW", SDTIntShiftOp>;
 def riscv_rorw : SDNode<"RISCVISD::RORW", SDTIntShiftOp>;
 def riscv_fslw : SDNode<"RISCVISD::FSLW", SDTIntShiftDOp>;
@@ -877,14 +879,8 @@ def : Pat<(i64 (riscv_fslw GPR:$rs3, GPR:$rs1, uimm5:$shamt)),
 } // Predicates = [HasStdExtZbt, IsRV64]
 
 let Predicates = [HasStdExtZbb, IsRV64] in {
-def : Pat<(i64 (add (ctlz (and GPR:$rs1, 0xFFFFFFFF)), -32)),
-          (CLZW GPR:$rs1)>;
-// computeKnownBits can't figure out that the and mask on the add result is
-// unnecessary so we need to pattern match it away.
-def : Pat<(i64 (and (add (ctlz (and GPR:$rs1, 0xFFFFFFFF)), -32), 0xFFFFFFFF)),
-          (CLZW GPR:$rs1)>;
-def : Pat<(i64 (cttz (or GPR:$rs1, 0x100000000))),
-          (CTZW GPR:$rs1)>;
+def : Pat<(i64 (riscv_clzw GPR:$rs1)), (CLZW GPR:$rs1)>;
+def : Pat<(i64 (riscv_ctzw GPR:$rs1)), (CTZW GPR:$rs1)>;
 def : Pat<(i64 (ctpop (and GPR:$rs1, 0xFFFFFFFF))), (CPOPW GPR:$rs1)>;
 } // Predicates = [HasStdExtZbb, IsRV64]
 
index a173235..3aea386 100644 (file)
@@ -171,18 +171,15 @@ define signext i32 @log2_i32(i32 signext %a) nounwind {
 ;
 ; RV64IB-LABEL: log2_i32:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    zext.w a0, a0
-; RV64IB-NEXT:    clz a0, a0
-; RV64IB-NEXT:    addi a1, zero, 63
+; RV64IB-NEXT:    clzw a0, a0
+; RV64IB-NEXT:    addi a1, zero, 31
 ; RV64IB-NEXT:    sub a0, a1, a0
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBB-LABEL: log2_i32:
 ; RV64IBB:       # %bb.0:
-; RV64IBB-NEXT:    slli a0, a0, 32
-; RV64IBB-NEXT:    srli a0, a0, 32
-; RV64IBB-NEXT:    clz a0, a0
-; RV64IBB-NEXT:    addi a1, zero, 63
+; RV64IBB-NEXT:    clzw a0, a0
+; RV64IBB-NEXT:    addi a1, zero, 31
 ; RV64IBB-NEXT:    sub a0, a1, a0
 ; RV64IBB-NEXT:    ret
   %1 = call i32 @llvm.ctlz.i32(i32 %a, i1 false)
@@ -270,19 +267,16 @@ define signext i32 @log2_ceil_i32(i32 signext %a) nounwind {
 ; RV64IB-LABEL: log2_ceil_i32:
 ; RV64IB:       # %bb.0:
 ; RV64IB-NEXT:    addi a0, a0, -1
-; RV64IB-NEXT:    zext.w a0, a0
-; RV64IB-NEXT:    clz a0, a0
-; RV64IB-NEXT:    addi a1, zero, 64
+; RV64IB-NEXT:    clzw a0, a0
+; RV64IB-NEXT:    addi a1, zero, 32
 ; RV64IB-NEXT:    sub a0, a1, a0
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBB-LABEL: log2_ceil_i32:
 ; RV64IBB:       # %bb.0:
 ; RV64IBB-NEXT:    addi a0, a0, -1
-; RV64IBB-NEXT:    slli a0, a0, 32
-; RV64IBB-NEXT:    srli a0, a0, 32
-; RV64IBB-NEXT:    clz a0, a0
-; RV64IBB-NEXT:    addi a1, zero, 64
+; RV64IBB-NEXT:    clzw a0, a0
+; RV64IBB-NEXT:    addi a1, zero, 32
 ; RV64IBB-NEXT:    sub a0, a1, a0
 ; RV64IBB-NEXT:    ret
   %1 = sub i32 %a, 1
@@ -469,15 +463,13 @@ define i32 @ctlz_lshr_i32(i32 signext %a) {
 ; RV64IB-LABEL: ctlz_lshr_i32:
 ; RV64IB:       # %bb.0:
 ; RV64IB-NEXT:    srliw a0, a0, 1
-; RV64IB-NEXT:    clz a0, a0
-; RV64IB-NEXT:    addi a0, a0, -32
+; RV64IB-NEXT:    clzw a0, a0
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBB-LABEL: ctlz_lshr_i32:
 ; RV64IBB:       # %bb.0:
 ; RV64IBB-NEXT:    srliw a0, a0, 1
-; RV64IBB-NEXT:    clz a0, a0
-; RV64IBB-NEXT:    addi a0, a0, -32
+; RV64IBB-NEXT:    clzw a0, a0
 ; RV64IBB-NEXT:    ret
   %1 = lshr i32 %a, 1
   %2 = call i32 @llvm.ctlz.i32(i32 %1, i1 false)
@@ -700,12 +692,12 @@ define signext i32 @cttz_zero_undef_i32(i32 signext %a) nounwind {
 ;
 ; RV64IB-LABEL: cttz_zero_undef_i32:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    ctz a0, a0
+; RV64IB-NEXT:    ctzw a0, a0
 ; RV64IB-NEXT:    ret
 ;
 ; RV64IBB-LABEL: cttz_zero_undef_i32:
 ; RV64IBB:       # %bb.0:
-; RV64IBB-NEXT:    ctz a0, a0
+; RV64IBB-NEXT:    ctzw a0, a0
 ; RV64IBB-NEXT:    ret
   %1 = call i32 @llvm.cttz.i32(i32 %a, i1 true)
   ret i32 %1
@@ -775,7 +767,7 @@ define signext i32 @findFirstSet_i32(i32 signext %a) nounwind {
 ;
 ; RV64IB-LABEL: findFirstSet_i32:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    ctz a1, a0
+; RV64IB-NEXT:    ctzw a1, a0
 ; RV64IB-NEXT:    addi a2, zero, -1
 ; RV64IB-NEXT:    cmov a0, a0, a1, a2
 ; RV64IB-NEXT:    ret
@@ -786,7 +778,7 @@ define signext i32 @findFirstSet_i32(i32 signext %a) nounwind {
 ; RV64IBB-NEXT:    addi a0, zero, -1
 ; RV64IBB-NEXT:    beqz a1, .LBB8_2
 ; RV64IBB-NEXT:  # %bb.1:
-; RV64IBB-NEXT:    ctz a0, a1
+; RV64IBB-NEXT:    ctzw a0, a1
 ; RV64IBB-NEXT:  .LBB8_2:
 ; RV64IBB-NEXT:    ret
   %1 = call i32 @llvm.cttz.i32(i32 %a, i1 true)
@@ -860,7 +852,7 @@ define signext i32 @ffs_i32(i32 signext %a) nounwind {
 ;
 ; RV64IB-LABEL: ffs_i32:
 ; RV64IB:       # %bb.0:
-; RV64IB-NEXT:    ctz a1, a0
+; RV64IB-NEXT:    ctzw a1, a0
 ; RV64IB-NEXT:    addi a1, a1, 1
 ; RV64IB-NEXT:    cmov a0, a0, a1, zero
 ; RV64IB-NEXT:    ret
@@ -871,7 +863,7 @@ define signext i32 @ffs_i32(i32 signext %a) nounwind {
 ; RV64IBB-NEXT:    mv a0, zero
 ; RV64IBB-NEXT:    beqz a1, .LBB9_2
 ; RV64IBB-NEXT:  # %bb.1:
-; RV64IBB-NEXT:    ctz a0, a1
+; RV64IBB-NEXT:    ctzw a0, a1
 ; RV64IBB-NEXT:    addi a0, a0, 1
 ; RV64IBB-NEXT:  .LBB9_2:
 ; RV64IBB-NEXT:    ret