}
void RISCVDAGToDAGISel::PostprocessISelDAG() {
- doPeepholeLoadStoreADDI();
+ SelectionDAG::allnodes_iterator Position = CurDAG->allnodes_end(); // Start one past the last node; the loop below walks backwards.
+
+ bool MadeChange = false; // Becomes true once any peephole reports a rewrite.
+ while (Position != CurDAG->allnodes_begin()) {
+ SDNode *N = &*--Position; // Step back first, then visit that node.
+ // Skip nodes without uses and any non-machine opcodes; the peepholes below query N's machine opcode.
+ if (N->use_empty() || !N->isMachineOpcode())
+ continue;
+
+ MadeChange |= doPeepholeSExtW(N); // Both peepholes are always attempted on N.
+ MadeChange |= doPeepholeLoadStoreADDI(N);
+ }
+
+ if (MadeChange) // The peepholes do not delete the nodes they bypass; sweep any now-dead nodes once here.
+ CurDAG->RemoveDeadNodes();
}
static SDNode *selectImm(SelectionDAG *CurDAG, const SDLoc &DL, int64_t Imm,
// (load (addi base, off1), off2) -> (load base, off1+off2)
// (store val, (addi base, off1), off2) -> (store val, base, off1+off2)
// This is possible when off1+off2 fits a 12-bit immediate.
-void RISCVDAGToDAGISel::doPeepholeLoadStoreADDI() {
- SelectionDAG::allnodes_iterator Position(CurDAG->getRoot().getNode());
- ++Position;
+bool RISCVDAGToDAGISel::doPeepholeLoadStoreADDI(SDNode *N) {
+ int OffsetOpIdx; // Operand index of N's immediate offset (off2).
+ int BaseOpIdx; // Operand index of N's base address.
- while (Position != CurDAG->allnodes_begin()) {
- SDNode *N = &*--Position;
- // Skip dead nodes and any non-machine opcodes.
- if (N->use_empty() || !N->isMachineOpcode())
- continue;
+ // Only attempt this optimisation for I-type loads and S-type stores; any other opcode returns false.
+ switch (N->getMachineOpcode()) {
+ default:
+ return false;
+ case RISCV::LB:
+ case RISCV::LH:
+ case RISCV::LW:
+ case RISCV::LBU:
+ case RISCV::LHU:
+ case RISCV::LWU:
+ case RISCV::LD:
+ case RISCV::FLH:
+ case RISCV::FLW:
+ case RISCV::FLD:
+ BaseOpIdx = 0;
+ OffsetOpIdx = 1;
+ break;
+ case RISCV::SB:
+ case RISCV::SH:
+ case RISCV::SW:
+ case RISCV::SD:
+ case RISCV::FSH:
+ case RISCV::FSW:
+ case RISCV::FSD:
+ BaseOpIdx = 1;
+ OffsetOpIdx = 2;
+ break;
+ }
- int OffsetOpIdx;
- int BaseOpIdx;
+ if (!isa<ConstantSDNode>(N->getOperand(OffsetOpIdx))) // off2 must be a plain constant to be combined.
+ return false;
- // Only attempt this optimisation for I-type loads and S-type stores.
- switch (N->getMachineOpcode()) {
- default:
- continue;
- case RISCV::LB:
- case RISCV::LH:
- case RISCV::LW:
- case RISCV::LBU:
- case RISCV::LHU:
- case RISCV::LWU:
- case RISCV::LD:
- case RISCV::FLH:
- case RISCV::FLW:
- case RISCV::FLD:
- BaseOpIdx = 0;
- OffsetOpIdx = 1;
- break;
- case RISCV::SB:
- case RISCV::SH:
- case RISCV::SW:
- case RISCV::SD:
- case RISCV::FSH:
- case RISCV::FSW:
- case RISCV::FSD:
- BaseOpIdx = 1;
- OffsetOpIdx = 2;
- break;
- }
+ SDValue Base = N->getOperand(BaseOpIdx);
- if (!isa<ConstantSDNode>(N->getOperand(OffsetOpIdx)))
- continue;
+ // If the base is an ADDI, we can merge it in to the load/store; otherwise there is nothing to fold.
+ if (!Base.isMachineOpcode() || Base.getMachineOpcode() != RISCV::ADDI)
+ return false;
- SDValue Base = N->getOperand(BaseOpIdx);
+ SDValue ImmOperand = Base.getOperand(1); // off1: the ADDI's immediate operand.
+ uint64_t Offset2 = N->getConstantOperandVal(OffsetOpIdx); // NOTE(review): held as unsigned while off1 is int64_t; presumably a negative off2 appears as a huge value here — confirm intended.
- // If the base is an ADDI, we can merge it in to the load/store.
- if (!Base.isMachineOpcode() || Base.getMachineOpcode() != RISCV::ADDI)
- continue;
+ if (auto *Const = dyn_cast<ConstantSDNode>(ImmOperand)) {
+ int64_t Offset1 = Const->getSExtValue();
+ int64_t CombinedOffset = Offset1 + Offset2;
+ if (!isInt<12>(CombinedOffset)) // The sum must still fit the 12-bit immediate.
+ return false;
+ ImmOperand = CurDAG->getTargetConstant(CombinedOffset, SDLoc(ImmOperand),
+ ImmOperand.getValueType());
+ } else if (auto *GA = dyn_cast<GlobalAddressSDNode>(ImmOperand)) {
+ // If the off1 in (addi base, off1) is a global variable's address (its
+ // low part, really), then we can rely on the alignment of that variable
+ // to provide a margin of safety before off1 can overflow the 12 bits.
+ // Check if off2 falls within that margin; if so off1+off2 can't overflow.
+ const DataLayout &DL = CurDAG->getDataLayout();
+ Align Alignment = GA->getGlobal()->getPointerAlignment(DL);
+ if (Offset2 != 0 && Alignment <= Offset2) // A non-zero off2 must be strictly smaller than the alignment.
+ return false;
+ int64_t Offset1 = GA->getOffset();
+ int64_t CombinedOffset = Offset1 + Offset2;
+ ImmOperand = CurDAG->getTargetGlobalAddress(
+ GA->getGlobal(), SDLoc(ImmOperand), ImmOperand.getValueType(),
+ CombinedOffset, GA->getTargetFlags());
+ } else if (auto *CP = dyn_cast<ConstantPoolSDNode>(ImmOperand)) {
+ // Ditto, using the constant pool entry's alignment as the safety margin.
+ Align Alignment = CP->getAlign();
+ if (Offset2 != 0 && Alignment <= Offset2)
+ return false;
+ int64_t Offset1 = CP->getOffset();
+ int64_t CombinedOffset = Offset1 + Offset2;
+ ImmOperand = CurDAG->getTargetConstantPool(
+ CP->getConstVal(), ImmOperand.getValueType(), CP->getAlign(),
+ CombinedOffset, CP->getTargetFlags());
+ } else {
+ return false; // Any other kind of off1 operand is not handled.
+ }
- SDValue ImmOperand = Base.getOperand(1);
- uint64_t Offset2 = N->getConstantOperandVal(OffsetOpIdx);
-
- if (auto *Const = dyn_cast<ConstantSDNode>(ImmOperand)) {
- int64_t Offset1 = Const->getSExtValue();
- int64_t CombinedOffset = Offset1 + Offset2;
- if (!isInt<12>(CombinedOffset))
- continue;
- ImmOperand = CurDAG->getTargetConstant(CombinedOffset, SDLoc(ImmOperand),
- ImmOperand.getValueType());
- } else if (auto *GA = dyn_cast<GlobalAddressSDNode>(ImmOperand)) {
- // If the off1 in (addi base, off1) is a global variable's address (its
- // low part, really), then we can rely on the alignment of that variable
- // to provide a margin of safety before off1 can overflow the 12 bits.
- // Check if off2 falls within that margin; if so off1+off2 can't overflow.
- const DataLayout &DL = CurDAG->getDataLayout();
- Align Alignment = GA->getGlobal()->getPointerAlignment(DL);
- if (Offset2 != 0 && Alignment <= Offset2)
- continue;
- int64_t Offset1 = GA->getOffset();
- int64_t CombinedOffset = Offset1 + Offset2;
- ImmOperand = CurDAG->getTargetGlobalAddress(
- GA->getGlobal(), SDLoc(ImmOperand), ImmOperand.getValueType(),
- CombinedOffset, GA->getTargetFlags());
- } else if (auto *CP = dyn_cast<ConstantPoolSDNode>(ImmOperand)) {
- // Ditto.
- Align Alignment = CP->getAlign();
- if (Offset2 != 0 && Alignment <= Offset2)
- continue;
- int64_t Offset1 = CP->getOffset();
- int64_t CombinedOffset = Offset1 + Offset2;
- ImmOperand = CurDAG->getTargetConstantPool(
- CP->getConstVal(), ImmOperand.getValueType(), CP->getAlign(),
- CombinedOffset, CP->getTargetFlags());
- } else {
- continue;
+ LLVM_DEBUG(dbgs() << "Folding add-immediate into mem-op:\nBase: ");
+ LLVM_DEBUG(Base->dump(CurDAG));
+ LLVM_DEBUG(dbgs() << "\nN: ");
+ LLVM_DEBUG(N->dump(CurDAG));
+ LLVM_DEBUG(dbgs() << "\n");
+
+ // Modify the base and offset operands of the load/store so it bypasses the ADDI.
+ if (BaseOpIdx == 0) // Load
+ CurDAG->UpdateNodeOperands(N, Base.getOperand(0), ImmOperand,
+ N->getOperand(2));
+ else // Store
+ CurDAG->UpdateNodeOperands(N, N->getOperand(0), Base.getOperand(0),
+ ImmOperand, N->getOperand(3));
+
+ return true; // N was rewritten; the ADDI itself is not removed here.
+}
+
+// Try to remove sext.w (N) if its input is a W instruction or can be made into
+// a W instruction cheaply. Returns true if the uses of N were replaced.
+bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
+ // Look for the sext.w pattern, addiw rd, rs1, 0; anything else is left alone.
+ if (N->getMachineOpcode() != RISCV::ADDIW ||
+ !isNullConstant(N->getOperand(1)))
+ return false;
+
+ SDValue N0 = N->getOperand(0); // The value being sign extended.
+ if (!N0.isMachineOpcode())
+ return false;
+
+ switch (N0.getMachineOpcode()) {
+ default:
+ break; // Not a handled producer; falls out to the final return false.
+ case RISCV::ADD:
+ case RISCV::ADDI:
+ case RISCV::SUB:
+ case RISCV::MUL:
+ case RISCV::SLLI: {
+ // Convert sext.w+add/addi/sub/mul/slli to their W instructions. This will create
+ // a new independent instruction. This improves latency.
+ unsigned Opc;
+ switch (N0.getMachineOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode!");
+ case RISCV::ADD: Opc = RISCV::ADDW; break;
+ case RISCV::ADDI: Opc = RISCV::ADDIW; break;
+ case RISCV::SUB: Opc = RISCV::SUBW; break;
+ case RISCV::MUL: Opc = RISCV::MULW; break;
+ case RISCV::SLLI: Opc = RISCV::SLLIW; break;
}
- LLVM_DEBUG(dbgs() << "Folding add-immediate into mem-op:\nBase: ");
- LLVM_DEBUG(Base->dump(CurDAG));
- LLVM_DEBUG(dbgs() << "\nN: ");
- LLVM_DEBUG(N->dump(CurDAG));
- LLVM_DEBUG(dbgs() << "\n");
-
- // Modify the offset operand of the load/store.
- if (BaseOpIdx == 0) // Load
- CurDAG->UpdateNodeOperands(N, Base.getOperand(0), ImmOperand,
- N->getOperand(2));
- else // Store
- CurDAG->UpdateNodeOperands(N, N->getOperand(0), Base.getOperand(0),
- ImmOperand, N->getOperand(3));
-
- // The add-immediate may now be dead, in which case remove it.
- if (Base.getNode()->use_empty())
- CurDAG->RemoveDeadNode(Base.getNode());
+ SDValue N00 = N0.getOperand(0);
+ SDValue N01 = N0.getOperand(1);
+
+ // Shift amount needs to be uimm5 for SLLIW; otherwise give up (returns false below).
+ if (N0.getMachineOpcode() == RISCV::SLLI &&
+ !isUInt<5>(cast<ConstantSDNode>(N01)->getSExtValue()))
+ break;
+
+ SDNode *Result =
+ CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
+ N00, N01);
+ ReplaceUses(N, Result); // N0 itself is not modified; only N's users are redirected.
+ return true;
}
+ case RISCV::ADDW:
+ case RISCV::ADDIW:
+ case RISCV::SUBW:
+ case RISCV::MULW:
+ case RISCV::SLLIW:
+ // Result is already sign extended, so just remove the sext.w.
+ // NOTE: We only handle the nodes that are selected with hasAllWUsers.
+ ReplaceUses(N, N0.getNode());
+ return true;
+ }
+
+ return false;
}
// This pass converts a legalized DAG into a RISCV-specific DAG, ready
;
; RV64B-LABEL: mulw192:
; RV64B: # %bb.0:
-; RV64B-NEXT: addi a1, zero, 192
-; RV64B-NEXT: mulw a0, a0, a1
+; RV64B-NEXT: sh1add a0, a0, a0
+; RV64B-NEXT: slliw a0, a0, 6
; RV64B-NEXT: ret
;
; RV64ZBA-LABEL: mulw192:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: addi a1, zero, 192
-; RV64ZBA-NEXT: mulw a0, a0, a1
+; RV64ZBA-NEXT: sh1add a0, a0, a0
+; RV64ZBA-NEXT: slliw a0, a0, 6
; RV64ZBA-NEXT: ret
%c = mul i32 %a, 192
ret i32 %c
;
; RV64B-LABEL: mulw320:
; RV64B: # %bb.0:
-; RV64B-NEXT: addi a1, zero, 320
-; RV64B-NEXT: mulw a0, a0, a1
+; RV64B-NEXT: sh2add a0, a0, a0
+; RV64B-NEXT: slliw a0, a0, 6
; RV64B-NEXT: ret
;
; RV64ZBA-LABEL: mulw320:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: addi a1, zero, 320
-; RV64ZBA-NEXT: mulw a0, a0, a1
+; RV64ZBA-NEXT: sh2add a0, a0, a0
+; RV64ZBA-NEXT: slliw a0, a0, 6
; RV64ZBA-NEXT: ret
%c = mul i32 %a, 320
ret i32 %c
;
; RV64B-LABEL: mulw576:
; RV64B: # %bb.0:
-; RV64B-NEXT: addi a1, zero, 576
-; RV64B-NEXT: mulw a0, a0, a1
+; RV64B-NEXT: sh3add a0, a0, a0
+; RV64B-NEXT: slliw a0, a0, 6
; RV64B-NEXT: ret
;
; RV64ZBA-LABEL: mulw576:
; RV64ZBA: # %bb.0:
-; RV64ZBA-NEXT: addi a1, zero, 576
-; RV64ZBA-NEXT: mulw a0, a0, a1
+; RV64ZBA-NEXT: sh3add a0, a0, a0
+; RV64ZBA-NEXT: slliw a0, a0, 6
; RV64ZBA-NEXT: ret
%c = mul i32 %a, 576
ret i32 %c