[RISCV] Refactor an optimization of addition with immediate
authorBen Shi <powerman1st@163.com>
Tue, 20 Apr 2021 10:04:25 +0000 (18:04 +0800)
committerBen Shi <powerman1st@163.com>
Tue, 20 Apr 2021 10:04:25 +0000 (18:04 +0800)
Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D100769

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
llvm/lib/Target/RISCV/RISCVInstrInfo.td

index da9ab75..1e2cd4b 100644 (file)
@@ -379,28 +379,6 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
   MVT VT = Node->getSimpleValueType(0);
 
   switch (Opcode) {
-  case ISD::ADD: {
-    // Optimize (add r, imm) to (addi (addi r, imm0) imm1) if applicable. The
-    // immediate must be in specific ranges and have a single use.
-    if (auto *ConstOp = dyn_cast<ConstantSDNode>(Node->getOperand(1))) {
-      if (!(ConstOp->hasOneUse()))
-        break;
-      // The imm must be in range [-4096,-2049] or [2048,4094].
-      int64_t Imm = ConstOp->getSExtValue();
-      if (!(-4096 <= Imm && Imm <= -2049) && !(2048 <= Imm && Imm <= 4094))
-        break;
-      // Break the imm to imm0+imm1.
-      const SDValue ImmOp0 = CurDAG->getTargetConstant(Imm - Imm / 2, DL, VT);
-      const SDValue ImmOp1 = CurDAG->getTargetConstant(Imm / 2, DL, VT);
-      auto *NodeAddi0 = CurDAG->getMachineNode(RISCV::ADDI, DL, VT,
-                                               Node->getOperand(0), ImmOp0);
-      auto *NodeAddi1 = CurDAG->getMachineNode(RISCV::ADDI, DL, VT,
-                                               SDValue(NodeAddi0, 0), ImmOp1);
-      ReplaceNode(Node, NodeAddi1);
-      return;
-    }
-    break;
-  }
   case ISD::Constant: {
     auto *ConstNode = cast<ConstantSDNode>(Node);
     if (VT == XLenVT && ConstNode->isNullValue()) {
@@ -1092,6 +1070,23 @@ bool RISCVDAGToDAGISel::selectZExti32(SDValue N, SDValue &Val) {
   return false;
 }
 
+// Check if (add r, imm) can be optimized to (ADDI (ADDI r, imm0), imm1),
+// in which imm = imm0 + imm1 and both imm0 and imm1 are simm12.
+bool RISCVDAGToDAGISel::selectAddiPair(SDValue N, SDValue &Val) {
+  if (auto *ConstOp = dyn_cast<ConstantSDNode>(N)) {
+    // The immediate operand must have only use.
+    if (!(ConstOp->hasOneUse()))
+      return false;
+    // The immediate operand must be in range [-4096,-2049] or [2048,4094].
+    int64_t Imm = ConstOp->getSExtValue();
+    if ((-4096 <= Imm && Imm <= -2049) || (2048 <= Imm && Imm <= 4094)) {
+      Val = N;
+      return true;
+    }
+  }
+  return false;
+}
+
 // Check that it is a SLLIUW (Shift Logical Left Immediate Unsigned i32
 // on RV64).
 // SLLIUW is the same as SLLI except for the fact that it clears the bits
index f67e88a..7336bd8 100644 (file)
@@ -57,6 +57,8 @@ public:
   bool selectSExti32(SDValue N, SDValue &Val);
   bool selectZExti32(SDValue N, SDValue &Val);
 
+  bool selectAddiPair(SDValue N, SDValue &Val);
+
   bool MatchSLLIUW(SDNode *N) const;
 
   bool selectVLOp(SDValue N, SDValue &VL);
index a6b95ac..df126f5 100644 (file)
@@ -145,7 +145,7 @@ def uimmlog2xlen : Operand<XLenVT>, ImmLeaf<XLenVT, [{
     if (!MCOp.evaluateAsConstantImm(Imm))
       return false;
     if (STI.getTargetTriple().isArch64Bit())
-      return  isUInt<6>(Imm);
+      return isUInt<6>(Imm);
     return isUInt<5>(Imm);
   }];
   let OperandType = "OPERAND_UIMMLOG2XLEN";
@@ -343,6 +343,22 @@ def ImmSubFrom32 : SDNodeXForm<imm, [{
                                    N->getValueType(0));
 }]>;
 
+// Check if an addition can be broken to a pair of ADDI.
+def AddiPair : ComplexPattern<XLenVT, 1, "selectAddiPair">;
+
+// Return imm/2.
+def AddiPairImmA : SDNodeXForm<imm, [{
+  return CurDAG->getTargetConstant(N->getSExtValue() / 2, SDLoc(N),
+                                   N->getValueType(0));
+}]>;
+
+// Return imm - imm/2.
+def AddiPairImmB : SDNodeXForm<imm, [{
+  int64_t Imm = N->getSExtValue();
+  return CurDAG->getTargetConstant(Imm - Imm / 2, SDLoc(N),
+                                   N->getValueType(0));
+}]>;
+
 //===----------------------------------------------------------------------===//
 // Instruction Formats
 //===----------------------------------------------------------------------===//
@@ -1274,6 +1290,11 @@ def : Pat<(trap), (UNIMP)>;
 // debugger if possible.
 def : Pat<(debugtrap), (EBREAK)>;
 
+/// Simple optimization
+def : Pat<(add GPR:$rs1, (AddiPair GPR:$rs2)),
+          (ADDI (ADDI GPR:$rs1, (AddiPairImmB GPR:$rs2)),
+                (AddiPairImmA GPR:$rs2))>;
+
 //===----------------------------------------------------------------------===//
 // Standard extensions
 //===----------------------------------------------------------------------===//