[RISCV] Remove sext_inreg+add/sub/mul/shl isel patterns.
authorCraig Topper <craig.topper@sifive.com>
Wed, 18 Aug 2021 17:46:09 +0000 (10:46 -0700)
committerCraig Topper <craig.topper@sifive.com>
Wed, 18 Aug 2021 18:07:11 +0000 (11:07 -0700)
Let the sext_inreg be selected to sext.w. Remove unneeded sext.w
during PostProcessISelDAG.

This gives opportunities for some other isel patterns to match
like the ADDIPair or matching mul with immediate to shXadd.

This becomes possible after D107658 started selecting W instructions
based on users. The sext.w will be considered a W user so isel
will often select a W instruction for the sext.w input and we can
just remove the sext.w. Otherwise we can combine the sext.w with
a ADD/SUB/MUL/SLLI to create a new W instruction in parallel
to the the original instruction.

Reviewed By: luismarques

Differential Revision: https://reviews.llvm.org/D107708

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
llvm/lib/Target/RISCV/RISCVInstrInfo.td
llvm/lib/Target/RISCV/RISCVInstrInfoM.td
llvm/test/CodeGen/RISCV/add-imm.ll
llvm/test/CodeGen/RISCV/rv64zba.ll
llvm/test/CodeGen/RISCV/xaluo.ll

index b3ee2bc..669245f 100644 (file)
@@ -108,7 +108,21 @@ void RISCVDAGToDAGISel::PreprocessISelDAG() {
 }
 
 void RISCVDAGToDAGISel::PostprocessISelDAG() {
-  doPeepholeLoadStoreADDI();
+  SelectionDAG::allnodes_iterator Position = CurDAG->allnodes_end();
+
+  bool MadeChange = false;
+  while (Position != CurDAG->allnodes_begin()) {
+    SDNode *N = &*--Position;
+    // Skip dead nodes and any non-machine opcodes.
+    if (N->use_empty() || !N->isMachineOpcode())
+      continue;
+
+    MadeChange |= doPeepholeSExtW(N);
+    MadeChange |= doPeepholeLoadStoreADDI(N);
+  }
+
+  if (MadeChange)
+    CurDAG->RemoveDeadNodes();
 }
 
 static SDNode *selectImm(SelectionDAG *CurDAG, const SDLoc &DL, int64_t Imm,
@@ -1702,113 +1716,162 @@ bool RISCVDAGToDAGISel::selectRVVSimm5(SDValue N, unsigned Width,
 // (load (addi base, off1), off2) -> (load base, off1+off2)
 // (store val, (addi base, off1), off2) -> (store val, base, off1+off2)
 // This is possible when off1+off2 fits a 12-bit immediate.
-void RISCVDAGToDAGISel::doPeepholeLoadStoreADDI() {
-  SelectionDAG::allnodes_iterator Position(CurDAG->getRoot().getNode());
-  ++Position;
+bool RISCVDAGToDAGISel::doPeepholeLoadStoreADDI(SDNode *N) {
+  int OffsetOpIdx;
+  int BaseOpIdx;
 
-  while (Position != CurDAG->allnodes_begin()) {
-    SDNode *N = &*--Position;
-    // Skip dead nodes and any non-machine opcodes.
-    if (N->use_empty() || !N->isMachineOpcode())
-      continue;
+  // Only attempt this optimisation for I-type loads and S-type stores.
+  switch (N->getMachineOpcode()) {
+  default:
+    return false;
+  case RISCV::LB:
+  case RISCV::LH:
+  case RISCV::LW:
+  case RISCV::LBU:
+  case RISCV::LHU:
+  case RISCV::LWU:
+  case RISCV::LD:
+  case RISCV::FLH:
+  case RISCV::FLW:
+  case RISCV::FLD:
+    BaseOpIdx = 0;
+    OffsetOpIdx = 1;
+    break;
+  case RISCV::SB:
+  case RISCV::SH:
+  case RISCV::SW:
+  case RISCV::SD:
+  case RISCV::FSH:
+  case RISCV::FSW:
+  case RISCV::FSD:
+    BaseOpIdx = 1;
+    OffsetOpIdx = 2;
+    break;
+  }
 
-    int OffsetOpIdx;
-    int BaseOpIdx;
+  if (!isa<ConstantSDNode>(N->getOperand(OffsetOpIdx)))
+    return false;
 
-    // Only attempt this optimisation for I-type loads and S-type stores.
-    switch (N->getMachineOpcode()) {
-    default:
-      continue;
-    case RISCV::LB:
-    case RISCV::LH:
-    case RISCV::LW:
-    case RISCV::LBU:
-    case RISCV::LHU:
-    case RISCV::LWU:
-    case RISCV::LD:
-    case RISCV::FLH:
-    case RISCV::FLW:
-    case RISCV::FLD:
-      BaseOpIdx = 0;
-      OffsetOpIdx = 1;
-      break;
-    case RISCV::SB:
-    case RISCV::SH:
-    case RISCV::SW:
-    case RISCV::SD:
-    case RISCV::FSH:
-    case RISCV::FSW:
-    case RISCV::FSD:
-      BaseOpIdx = 1;
-      OffsetOpIdx = 2;
-      break;
-    }
+  SDValue Base = N->getOperand(BaseOpIdx);
 
-    if (!isa<ConstantSDNode>(N->getOperand(OffsetOpIdx)))
-      continue;
+  // If the base is an ADDI, we can merge it in to the load/store.
+  if (!Base.isMachineOpcode() || Base.getMachineOpcode() != RISCV::ADDI)
+    return false;
 
-    SDValue Base = N->getOperand(BaseOpIdx);
+  SDValue ImmOperand = Base.getOperand(1);
+  uint64_t Offset2 = N->getConstantOperandVal(OffsetOpIdx);
 
-    // If the base is an ADDI, we can merge it in to the load/store.
-    if (!Base.isMachineOpcode() || Base.getMachineOpcode() != RISCV::ADDI)
-      continue;
+  if (auto *Const = dyn_cast<ConstantSDNode>(ImmOperand)) {
+    int64_t Offset1 = Const->getSExtValue();
+    int64_t CombinedOffset = Offset1 + Offset2;
+    if (!isInt<12>(CombinedOffset))
+      return false;
+    ImmOperand = CurDAG->getTargetConstant(CombinedOffset, SDLoc(ImmOperand),
+                                           ImmOperand.getValueType());
+  } else if (auto *GA = dyn_cast<GlobalAddressSDNode>(ImmOperand)) {
+    // If the off1 in (addi base, off1) is a global variable's address (its
+    // low part, really), then we can rely on the alignment of that variable
+    // to provide a margin of safety before off1 can overflow the 12 bits.
+    // Check if off2 falls within that margin; if so off1+off2 can't overflow.
+    const DataLayout &DL = CurDAG->getDataLayout();
+    Align Alignment = GA->getGlobal()->getPointerAlignment(DL);
+    if (Offset2 != 0 && Alignment <= Offset2)
+      return false;
+    int64_t Offset1 = GA->getOffset();
+    int64_t CombinedOffset = Offset1 + Offset2;
+    ImmOperand = CurDAG->getTargetGlobalAddress(
+        GA->getGlobal(), SDLoc(ImmOperand), ImmOperand.getValueType(),
+        CombinedOffset, GA->getTargetFlags());
+  } else if (auto *CP = dyn_cast<ConstantPoolSDNode>(ImmOperand)) {
+    // Ditto.
+    Align Alignment = CP->getAlign();
+    if (Offset2 != 0 && Alignment <= Offset2)
+      return false;
+    int64_t Offset1 = CP->getOffset();
+    int64_t CombinedOffset = Offset1 + Offset2;
+    ImmOperand = CurDAG->getTargetConstantPool(
+        CP->getConstVal(), ImmOperand.getValueType(), CP->getAlign(),
+        CombinedOffset, CP->getTargetFlags());
+  } else {
+    return false;
+  }
 
-    SDValue ImmOperand = Base.getOperand(1);
-    uint64_t Offset2 = N->getConstantOperandVal(OffsetOpIdx);
-
-    if (auto *Const = dyn_cast<ConstantSDNode>(ImmOperand)) {
-      int64_t Offset1 = Const->getSExtValue();
-      int64_t CombinedOffset = Offset1 + Offset2;
-      if (!isInt<12>(CombinedOffset))
-        continue;
-      ImmOperand = CurDAG->getTargetConstant(CombinedOffset, SDLoc(ImmOperand),
-                                             ImmOperand.getValueType());
-    } else if (auto *GA = dyn_cast<GlobalAddressSDNode>(ImmOperand)) {
-      // If the off1 in (addi base, off1) is a global variable's address (its
-      // low part, really), then we can rely on the alignment of that variable
-      // to provide a margin of safety before off1 can overflow the 12 bits.
-      // Check if off2 falls within that margin; if so off1+off2 can't overflow.
-      const DataLayout &DL = CurDAG->getDataLayout();
-      Align Alignment = GA->getGlobal()->getPointerAlignment(DL);
-      if (Offset2 != 0 && Alignment <= Offset2)
-        continue;
-      int64_t Offset1 = GA->getOffset();
-      int64_t CombinedOffset = Offset1 + Offset2;
-      ImmOperand = CurDAG->getTargetGlobalAddress(
-          GA->getGlobal(), SDLoc(ImmOperand), ImmOperand.getValueType(),
-          CombinedOffset, GA->getTargetFlags());
-    } else if (auto *CP = dyn_cast<ConstantPoolSDNode>(ImmOperand)) {
-      // Ditto.
-      Align Alignment = CP->getAlign();
-      if (Offset2 != 0 && Alignment <= Offset2)
-        continue;
-      int64_t Offset1 = CP->getOffset();
-      int64_t CombinedOffset = Offset1 + Offset2;
-      ImmOperand = CurDAG->getTargetConstantPool(
-          CP->getConstVal(), ImmOperand.getValueType(), CP->getAlign(),
-          CombinedOffset, CP->getTargetFlags());
-    } else {
-      continue;
+  LLVM_DEBUG(dbgs() << "Folding add-immediate into mem-op:\nBase:    ");
+  LLVM_DEBUG(Base->dump(CurDAG));
+  LLVM_DEBUG(dbgs() << "\nN: ");
+  LLVM_DEBUG(N->dump(CurDAG));
+  LLVM_DEBUG(dbgs() << "\n");
+
+  // Modify the offset operand of the load/store.
+  if (BaseOpIdx == 0) // Load
+    CurDAG->UpdateNodeOperands(N, Base.getOperand(0), ImmOperand,
+                               N->getOperand(2));
+  else // Store
+    CurDAG->UpdateNodeOperands(N, N->getOperand(0), Base.getOperand(0),
+                               ImmOperand, N->getOperand(3));
+
+  return true;
+}
+
+// Try to remove sext.w if the input is a W instruction or can be made into
+// a W instruction cheaply.
+bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
+  // Look for the sext.w pattern, addiw rd, rs1, 0.
+  if (N->getMachineOpcode() != RISCV::ADDIW ||
+      !isNullConstant(N->getOperand(1)))
+    return false;
+
+  SDValue N0 = N->getOperand(0);
+  if (!N0.isMachineOpcode())
+    return false;
+
+  switch (N0.getMachineOpcode()) {
+  default:
+    break;
+  case RISCV::ADD:
+  case RISCV::ADDI:
+  case RISCV::SUB:
+  case RISCV::MUL:
+  case RISCV::SLLI: {
+    // Convert sext.w+add/sub/mul to their W instructions. This will create
+    // a new independent instruction. This improves latency.
+    unsigned Opc;
+    switch (N0.getMachineOpcode()) {
+    default:
+      llvm_unreachable("Unexpected opcode!");
+    case RISCV::ADD:  Opc = RISCV::ADDW;  break;
+    case RISCV::ADDI: Opc = RISCV::ADDIW; break;
+    case RISCV::SUB:  Opc = RISCV::SUBW;  break;
+    case RISCV::MUL:  Opc = RISCV::MULW;  break;
+    case RISCV::SLLI: Opc = RISCV::SLLIW; break;
     }
 
-    LLVM_DEBUG(dbgs() << "Folding add-immediate into mem-op:\nBase:    ");
-    LLVM_DEBUG(Base->dump(CurDAG));
-    LLVM_DEBUG(dbgs() << "\nN: ");
-    LLVM_DEBUG(N->dump(CurDAG));
-    LLVM_DEBUG(dbgs() << "\n");
-
-    // Modify the offset operand of the load/store.
-    if (BaseOpIdx == 0) // Load
-      CurDAG->UpdateNodeOperands(N, Base.getOperand(0), ImmOperand,
-                                 N->getOperand(2));
-    else // Store
-      CurDAG->UpdateNodeOperands(N, N->getOperand(0), Base.getOperand(0),
-                                 ImmOperand, N->getOperand(3));
-
-    // The add-immediate may now be dead, in which case remove it.
-    if (Base.getNode()->use_empty())
-      CurDAG->RemoveDeadNode(Base.getNode());
+    SDValue N00 = N0.getOperand(0);
+    SDValue N01 = N0.getOperand(1);
+
+    // Shift amount needs to be uimm5.
+    if (N0.getMachineOpcode() == RISCV::SLLI &&
+        !isUInt<5>(cast<ConstantSDNode>(N01)->getSExtValue()))
+      break;
+
+    SDNode *Result =
+        CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
+                               N00, N01);
+    ReplaceUses(N, Result);
+    return true;
   }
+  case RISCV::ADDW:
+  case RISCV::ADDIW:
+  case RISCV::SUBW:
+  case RISCV::MULW:
+  case RISCV::SLLIW:
+    // Result is already sign extended just remove the sext.w.
+    // NOTE: We only handle the nodes that are selected with hasAllWUsers.
+    ReplaceUses(N, N0.getNode());
+    return true;
+  }
+
+  return false;
 }
 
 // This pass converts a legalized DAG into a RISCV-specific DAG, ready
index ca61caf..5feaa8f 100644 (file)
@@ -113,7 +113,8 @@ public:
 #include "RISCVGenDAGISel.inc"
 
 private:
-  void doPeepholeLoadStoreADDI();
+  bool doPeepholeLoadStoreADDI(SDNode *Node);
+  bool doPeepholeSExtW(SDNode *Node);
 };
 
 namespace RISCV {
index b7c05cb..ce23dc6 100644 (file)
@@ -1270,14 +1270,6 @@ def : Pat<(sext_inreg GPR:$rs1, i32), (ADDIW GPR:$rs1, 0)>;
 
 /// ALU operations
 
-def : Pat<(sext_inreg (add GPR:$rs1, GPR:$rs2), i32),
-          (ADDW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(sext_inreg (add GPR:$rs1, simm12:$imm12), i32),
-          (ADDIW GPR:$rs1, simm12:$imm12)>;
-def : Pat<(sext_inreg (sub GPR:$rs1, GPR:$rs2), i32),
-          (SUBW GPR:$rs1, GPR:$rs2)>;
-def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
-          (SLLIW GPR:$rs1, uimm5:$shamt)>;
 def : Pat<(i64 (srl (and GPR:$rs1, 0xffffffff), uimm5:$shamt)),
           (SRLIW GPR:$rs1, uimm5:$shamt)>;
 def : Pat<(i64 (srl (shl GPR:$rs1, (i64 32)), uimm6gt32:$shamt)),
@@ -1338,7 +1330,8 @@ def : Pat<(add GPR:$rs1, (AddiPair:$rs2)),
                 (AddiPairImmA GPR:$rs2))>;
 
 let Predicates = [IsRV64] in {
-def : Pat<(sext_inreg (add_oneuse GPR:$rs1, (AddiPair:$rs2)), i32),
+// Select W instructions if only the lower 32-bits of the result are used.
+def : Pat<(overflowingbinopw<add> GPR:$rs1, (AddiPair:$rs2)),
           (ADDIW (ADDIW GPR:$rs1, (AddiPairImmB AddiPair:$rs2)),
                  (AddiPairImmA AddiPair:$rs2))>;
 }
index b38998f..e035360 100644 (file)
@@ -72,9 +72,6 @@ def : PatGprGpr<urem, REMU>;
 } // Predicates = [HasStdExtM]
 
 let Predicates = [HasStdExtM, IsRV64] in {
-def : Pat<(sext_inreg (mul GPR:$rs1, GPR:$rs2), i32),
-          (MULW GPR:$rs1, GPR:$rs2)>;
-
 // Select W instructions without sext_inreg if only the lower 32-bits of the
 // result are used.
 def : PatGprGpr<overflowingbinopw<mul>, MULW>;
@@ -109,11 +106,4 @@ let Predicates = [HasStdExtM, IsRV64, NotHasStdExtZba] in {
 // still be better off shifting both left by 32.
 def : Pat<(i64 (mul (and GPR:$rs1, 0xffffffff), (and GPR:$rs2, 0xffffffff))),
           (MULHU (SLLI GPR:$rs1, 32), (SLLI GPR:$rs2, 32))>;
-// Prevent matching the first part of this pattern to mulw. The mul here has
-// additionals users or the ANDs would have been removed. The above pattern
-// will be used for the other users. If we form a mulw we'll keep the ANDs alive
-// and they'll still become SLLI+SRLI.
-def : Pat<(sext_inreg (mul (and GPR:$rs1, 0xffffffff),
-                           (and GPR:$rs2, 0xffffffff)), i32),
-          (ADDIW (MULHU (SLLI GPR:$rs1, 32), (SLLI GPR:$rs2, 32)), 0)>;
 } // Predicates = [HasStdExtM, IsRV64, NotHasStdExtZba]
index 970f809..82178da 100644 (file)
@@ -178,9 +178,8 @@ define signext i32 @add32_sext_reject_on_rv64(i32 signext %a) nounwind {
 ;
 ; RV64I-LABEL: add32_sext_reject_on_rv64:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    lui a1, 1
-; RV64I-NEXT:    addiw a1, a1, -1096
-; RV64I-NEXT:    addw a0, a0, a1
+; RV64I-NEXT:    addiw a0, a0, 1500
+; RV64I-NEXT:    addiw a0, a0, 1500
 ; RV64I-NEXT:    lui a1, %hi(gv0)
 ; RV64I-NEXT:    sw a0, %lo(gv0)(a1)
 ; RV64I-NEXT:    ret
index e16efa5..1cf572d 100644 (file)
@@ -1254,14 +1254,14 @@ define signext i32 @mulw192(i32 signext %a) {
 ;
 ; RV64B-LABEL: mulw192:
 ; RV64B:       # %bb.0:
-; RV64B-NEXT:    addi a1, zero, 192
-; RV64B-NEXT:    mulw a0, a0, a1
+; RV64B-NEXT:    sh1add a0, a0, a0
+; RV64B-NEXT:    slliw a0, a0, 6
 ; RV64B-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: mulw192:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    addi a1, zero, 192
-; RV64ZBA-NEXT:    mulw a0, a0, a1
+; RV64ZBA-NEXT:    sh1add a0, a0, a0
+; RV64ZBA-NEXT:    slliw a0, a0, 6
 ; RV64ZBA-NEXT:    ret
   %c = mul i32 %a, 192
   ret i32 %c
@@ -1276,14 +1276,14 @@ define signext i32 @mulw320(i32 signext %a) {
 ;
 ; RV64B-LABEL: mulw320:
 ; RV64B:       # %bb.0:
-; RV64B-NEXT:    addi a1, zero, 320
-; RV64B-NEXT:    mulw a0, a0, a1
+; RV64B-NEXT:    sh2add a0, a0, a0
+; RV64B-NEXT:    slliw a0, a0, 6
 ; RV64B-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: mulw320:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    addi a1, zero, 320
-; RV64ZBA-NEXT:    mulw a0, a0, a1
+; RV64ZBA-NEXT:    sh2add a0, a0, a0
+; RV64ZBA-NEXT:    slliw a0, a0, 6
 ; RV64ZBA-NEXT:    ret
   %c = mul i32 %a, 320
   ret i32 %c
@@ -1298,14 +1298,14 @@ define signext i32 @mulw576(i32 signext %a) {
 ;
 ; RV64B-LABEL: mulw576:
 ; RV64B:       # %bb.0:
-; RV64B-NEXT:    addi a1, zero, 576
-; RV64B-NEXT:    mulw a0, a0, a1
+; RV64B-NEXT:    sh3add a0, a0, a0
+; RV64B-NEXT:    slliw a0, a0, 6
 ; RV64B-NEXT:    ret
 ;
 ; RV64ZBA-LABEL: mulw576:
 ; RV64ZBA:       # %bb.0:
-; RV64ZBA-NEXT:    addi a1, zero, 576
-; RV64ZBA-NEXT:    mulw a0, a0, a1
+; RV64ZBA-NEXT:    sh3add a0, a0, a0
+; RV64ZBA-NEXT:    slliw a0, a0, 6
 ; RV64ZBA-NEXT:    ret
   %c = mul i32 %a, 576
   ret i32 %c
index 5fe5d68..2695508 100644 (file)
@@ -883,12 +883,12 @@ define zeroext i1 @smulo2.i32(i32 %v1, i32* %res) {
 ; RV64ZBA-LABEL: smulo2.i32:
 ; RV64ZBA:       # %bb.0: # %entry
 ; RV64ZBA-NEXT:    sext.w a0, a0
-; RV64ZBA-NEXT:    addi a2, zero, 13
-; RV64ZBA-NEXT:    mul a3, a0, a2
-; RV64ZBA-NEXT:    mulw a0, a0, a2
-; RV64ZBA-NEXT:    xor a0, a0, a3
+; RV64ZBA-NEXT:    sh1add a2, a0, a0
+; RV64ZBA-NEXT:    sh2add a2, a2, a0
+; RV64ZBA-NEXT:    sext.w a0, a2
+; RV64ZBA-NEXT:    xor a0, a0, a2
 ; RV64ZBA-NEXT:    snez a0, a0
-; RV64ZBA-NEXT:    sw a3, 0(a1)
+; RV64ZBA-NEXT:    sw a2, 0(a1)
 ; RV64ZBA-NEXT:    ret
 entry:
   %t = call {i32, i1} @llvm.smul.with.overflow.i32(i32 %v1, i32 13)