[DAGCombiner] Add target hook function to decide folding (mul (add x, c1), c2)
authorBen Shi <powerman1st@163.com>
Thu, 19 Aug 2021 13:51:09 +0000 (21:51 +0800)
committerBen Shi <powerman1st@163.com>
Sun, 22 Aug 2021 08:53:32 +0000 (16:53 +0800)
Reviewed by: lebedev.ri, spatel, craig.topper, luismarques, jrtc27

Differential Revision: https://reviews.llvm.org/D107711

llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/test/CodeGen/RISCV/addimm-mulimm.ll

index 07de68f..100f4e2 100644 (file)
@@ -2081,6 +2081,20 @@ public:
     return false;
   }
 
+  /// Return true if it may be profitable to transform
+  /// (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
+  /// This may not be true if c1 and c2 can be represented as immediates but
+  /// c1*c2 cannot, for example.
+  /// The target should check if c1, c2 and c1*c2 can be represented as
+  /// immediates, or have to be materialized into registers. If it is not sure
+  /// about some cases, a default true can be returned to let the DAGCombiner
+  /// decide.
+  /// AddNode is (add x, c1), and ConstNode is c2.
+  virtual bool isMulAddWithConstProfitable(const SDValue &AddNode,
+                                           const SDValue &ConstNode) const {
+    return true;
+  }
+
   /// Return true if it is more correct/profitable to use strict FP_TO_INT
   /// conversion operations - canonicalizing the FP source value instead of
   /// converting all cases and then selecting based on value.
index 1c56d9e..839787e 100644 (file)
@@ -16854,8 +16854,10 @@ bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode,
                                               SDValue &ConstNode) {
   APInt Val;
 
-  // If the add only has one use, this would be OK to do.
-  if (AddNode.getNode()->hasOneUse())
+  // If the add only has one use, and the target thinks the folding is
+  // profitable or does not lead to worse code, this would be OK to do.
+  if (AddNode.getNode()->hasOneUse() &&
+      TLI.isMulAddWithConstProfitable(AddNode, ConstNode))
     return true;
 
   // Walk all the users of the constant with which we're multiplying.
index 6659d26..a80d9f9 100644 (file)
@@ -9044,6 +9044,29 @@ bool RISCVTargetLowering::decomposeMulByConstant(LLVMContext &Context, EVT VT,
   return false;
 }
 
+bool RISCVTargetLowering::isMulAddWithConstProfitable(
+    const SDValue &AddNode, const SDValue &ConstNode) const {
+  // Let the DAGCombiner decide for vectors.
+  EVT VT = AddNode.getValueType();
+  if (VT.isVector())
+    return true;
+
+  // Let the DAGCombiner decide for larger types.
+  if (VT.getScalarSizeInBits() > Subtarget.getXLen())
+    return true;
+
+  // It is worse if c1 is simm12 while c1*c2 is not.
+  ConstantSDNode *C1Node = cast<ConstantSDNode>(AddNode.getOperand(1));
+  ConstantSDNode *C2Node = cast<ConstantSDNode>(ConstNode);
+  const APInt &C1 = C1Node->getAPIntValue();
+  const APInt &C2 = C2Node->getAPIntValue();
+  if (C1.isSignedIntN(12) && !(C1 * C2).isSignedIntN(12))
+    return false;
+
+  // Default to true and let the DAGCombiner decide.
+  return true;
+}
+
 bool RISCVTargetLowering::allowsMisalignedMemoryAccesses(
     EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
     bool *Fast) const {
index 092bbc9..06f4235 100644 (file)
@@ -461,6 +461,9 @@ public:
   bool decomposeMulByConstant(LLVMContext &Context, EVT VT,
                               SDValue C) const override;
 
+  bool isMulAddWithConstProfitable(const SDValue &AddNode,
+                                   const SDValue &ConstNode) const override;
+
   TargetLowering::AtomicExpansionKind
   shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override;
   Value *emitMaskedAtomicRMWIntrinsic(IRBuilderBase &Builder, AtomicRMWInst *AI,
index fa1ca24..e56b037 100644 (file)
@@ -146,20 +146,16 @@ define i64 @add_mul_combine_accept_b3(i64 %x) {
 define i32 @add_mul_combine_reject_a1(i32 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_a1:
 ; RV32IMB:       # %bb.0:
+; RV32IMB-NEXT:    addi a0, a0, 1971
 ; RV32IMB-NEXT:    addi a1, zero, 29
 ; RV32IMB-NEXT:    mul a0, a0, a1
-; RV32IMB-NEXT:    lui a1, 14
-; RV32IMB-NEXT:    addi a1, a1, -185
-; RV32IMB-NEXT:    add a0, a0, a1
 ; RV32IMB-NEXT:    ret
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_a1:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addiw a0, a0, 1971
 ; RV64IMB-NEXT:    addi a1, zero, 29
 ; RV64IMB-NEXT:    mulw a0, a0, a1
-; RV64IMB-NEXT:    lui a1, 14
-; RV64IMB-NEXT:    addiw a1, a1, -185
-; RV64IMB-NEXT:    addw a0, a0, a1
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i32 %x, 1971
   %tmp1 = mul i32 %tmp0, 29
@@ -169,20 +165,16 @@ define i32 @add_mul_combine_reject_a1(i32 %x) {
 define signext i32 @add_mul_combine_reject_a2(i32 signext %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_a2:
 ; RV32IMB:       # %bb.0:
+; RV32IMB-NEXT:    addi a0, a0, 1971
 ; RV32IMB-NEXT:    addi a1, zero, 29
 ; RV32IMB-NEXT:    mul a0, a0, a1
-; RV32IMB-NEXT:    lui a1, 14
-; RV32IMB-NEXT:    addi a1, a1, -185
-; RV32IMB-NEXT:    add a0, a0, a1
 ; RV32IMB-NEXT:    ret
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_a2:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addiw a0, a0, 1971
 ; RV64IMB-NEXT:    addi a1, zero, 29
 ; RV64IMB-NEXT:    mulw a0, a0, a1
-; RV64IMB-NEXT:    lui a1, 14
-; RV64IMB-NEXT:    addiw a1, a1, -185
-; RV64IMB-NEXT:    addw a0, a0, a1
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i32 %x, 1971
   %tmp1 = mul i32 %tmp0, 29
@@ -206,11 +198,9 @@ define i64 @add_mul_combine_reject_a3(i64 %x) {
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_a3:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1971
 ; RV64IMB-NEXT:    addi a1, zero, 29
 ; RV64IMB-NEXT:    mul a0, a0, a1
-; RV64IMB-NEXT:    lui a1, 14
-; RV64IMB-NEXT:    addiw a1, a1, -185
-; RV64IMB-NEXT:    add a0, a0, a1
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i64 %x, 1971
   %tmp1 = mul i64 %tmp0, 29
@@ -220,20 +210,17 @@ define i64 @add_mul_combine_reject_a3(i64 %x) {
 define i32 @add_mul_combine_reject_c1(i32 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_c1:
 ; RV32IMB:       # %bb.0:
+; RV32IMB-NEXT:    addi a0, a0, 1000
 ; RV32IMB-NEXT:    sh3add a1, a0, a0
 ; RV32IMB-NEXT:    sh3add a0, a1, a0
-; RV32IMB-NEXT:    lui a1, 18
-; RV32IMB-NEXT:    addi a1, a1, -728
-; RV32IMB-NEXT:    add a0, a0, a1
 ; RV32IMB-NEXT:    ret
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_c1:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1000
 ; RV64IMB-NEXT:    sh3add a1, a0, a0
 ; RV64IMB-NEXT:    sh3add a0, a1, a0
-; RV64IMB-NEXT:    lui a1, 18
-; RV64IMB-NEXT:    addiw a1, a1, -728
-; RV64IMB-NEXT:    addw a0, a0, a1
+; RV64IMB-NEXT:    sext.w a0, a0
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i32 %x, 1000
   %tmp1 = mul i32 %tmp0, 73
@@ -243,20 +230,17 @@ define i32 @add_mul_combine_reject_c1(i32 %x) {
 define signext i32 @add_mul_combine_reject_c2(i32 signext %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_c2:
 ; RV32IMB:       # %bb.0:
+; RV32IMB-NEXT:    addi a0, a0, 1000
 ; RV32IMB-NEXT:    sh3add a1, a0, a0
 ; RV32IMB-NEXT:    sh3add a0, a1, a0
-; RV32IMB-NEXT:    lui a1, 18
-; RV32IMB-NEXT:    addi a1, a1, -728
-; RV32IMB-NEXT:    add a0, a0, a1
 ; RV32IMB-NEXT:    ret
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_c2:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1000
 ; RV64IMB-NEXT:    sh3add a1, a0, a0
 ; RV64IMB-NEXT:    sh3add a0, a1, a0
-; RV64IMB-NEXT:    lui a1, 18
-; RV64IMB-NEXT:    addiw a1, a1, -728
-; RV64IMB-NEXT:    addw a0, a0, a1
+; RV64IMB-NEXT:    sext.w a0, a0
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i32 %x, 1000
   %tmp1 = mul i32 %tmp0, 73
@@ -280,11 +264,9 @@ define i64 @add_mul_combine_reject_c3(i64 %x) {
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_c3:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1000
 ; RV64IMB-NEXT:    sh3add a1, a0, a0
 ; RV64IMB-NEXT:    sh3add a0, a1, a0
-; RV64IMB-NEXT:    lui a1, 18
-; RV64IMB-NEXT:    addiw a1, a1, -728
-; RV64IMB-NEXT:    add a0, a0, a1
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i64 %x, 1000
   %tmp1 = mul i64 %tmp0, 73
@@ -294,20 +276,16 @@ define i64 @add_mul_combine_reject_c3(i64 %x) {
 define i32 @add_mul_combine_reject_d1(i32 %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_d1:
 ; RV32IMB:       # %bb.0:
+; RV32IMB-NEXT:    addi a0, a0, 1000
 ; RV32IMB-NEXT:    sh1add a0, a0, a0
 ; RV32IMB-NEXT:    slli a0, a0, 6
-; RV32IMB-NEXT:    lui a1, 47
-; RV32IMB-NEXT:    addi a1, a1, -512
-; RV32IMB-NEXT:    add a0, a0, a1
 ; RV32IMB-NEXT:    ret
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_d1:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1000
 ; RV64IMB-NEXT:    sh1add a0, a0, a0
-; RV64IMB-NEXT:    slli a0, a0, 6
-; RV64IMB-NEXT:    lui a1, 47
-; RV64IMB-NEXT:    addiw a1, a1, -512
-; RV64IMB-NEXT:    addw a0, a0, a1
+; RV64IMB-NEXT:    slliw a0, a0, 6
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i32 %x, 1000
   %tmp1 = mul i32 %tmp0, 192
@@ -317,20 +295,16 @@ define i32 @add_mul_combine_reject_d1(i32 %x) {
 define signext i32 @add_mul_combine_reject_d2(i32 signext %x) {
 ; RV32IMB-LABEL: add_mul_combine_reject_d2:
 ; RV32IMB:       # %bb.0:
+; RV32IMB-NEXT:    addi a0, a0, 1000
 ; RV32IMB-NEXT:    sh1add a0, a0, a0
 ; RV32IMB-NEXT:    slli a0, a0, 6
-; RV32IMB-NEXT:    lui a1, 47
-; RV32IMB-NEXT:    addi a1, a1, -512
-; RV32IMB-NEXT:    add a0, a0, a1
 ; RV32IMB-NEXT:    ret
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_d2:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1000
 ; RV64IMB-NEXT:    sh1add a0, a0, a0
-; RV64IMB-NEXT:    slli a0, a0, 6
-; RV64IMB-NEXT:    lui a1, 47
-; RV64IMB-NEXT:    addiw a1, a1, -512
-; RV64IMB-NEXT:    addw a0, a0, a1
+; RV64IMB-NEXT:    slliw a0, a0, 6
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i32 %x, 1000
   %tmp1 = mul i32 %tmp0, 192
@@ -356,11 +330,9 @@ define i64 @add_mul_combine_reject_d3(i64 %x) {
 ;
 ; RV64IMB-LABEL: add_mul_combine_reject_d3:
 ; RV64IMB:       # %bb.0:
+; RV64IMB-NEXT:    addi a0, a0, 1000
 ; RV64IMB-NEXT:    sh1add a0, a0, a0
 ; RV64IMB-NEXT:    slli a0, a0, 6
-; RV64IMB-NEXT:    lui a1, 47
-; RV64IMB-NEXT:    addiw a1, a1, -512
-; RV64IMB-NEXT:    add a0, a0, a1
 ; RV64IMB-NEXT:    ret
   %tmp0 = add i64 %x, 1000
   %tmp1 = mul i64 %tmp0, 192