[DAG] try to convert multiply to shift via demanded bits
authorSanjay Patel <spatel@rotateright.com>
Wed, 23 Feb 2022 16:25:01 +0000 (11:25 -0500)
committerSanjay Patel <spatel@rotateright.com>
Wed, 23 Feb 2022 17:09:32 +0000 (12:09 -0500)
This is a fix for a regression discussed in:
https://github.com/llvm/llvm-project/issues/53829

We cleared more high multiplier bits with 995d400,
but that can lead to worse codegen because we would fail
to recognize the now disguised multiplication by neg-power-of-2
as a shift-left. The problem exists independently of the IR
change in the case that the multiply already had cleared high
bits. We also convert shl+sub into mul+add in instcombine's
negator.

This patch fills in the high-bits to see the shift transform
opportunity. Alive2 attempt to show correctness:
https://alive2.llvm.org/ce/z/GgSKVX

The AArch64, RISCV, and MIPS diffs look like clear wins. The
x86 code requires an extra move register in the minimal examples,
but it's still an improvement to get rid of the multiply on all
CPUs that I am aware of (because multiply is never as fast as a
shift).

There's a potential follow-up noted by the TODO comment. We
should already convert that pattern into shl+add in IR, so
it's probably not common:
https://alive2.llvm.org/ce/z/7QY_Ga

Fixes #53829

Differential Revision: https://reviews.llvm.org/D120216

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/test/CodeGen/AArch64/mul_pow2.ll
llvm/test/CodeGen/Mips/urem-seteq-illegal-types.ll
llvm/test/CodeGen/RISCV/mul.ll
llvm/test/CodeGen/X86/mul-demand.ll

index 6619f1c..5c86129 100644 (file)
@@ -2486,6 +2486,46 @@ bool TargetLowering::SimplifyDemandedBits(
       return TLO.CombineTo(Op, NewOp);
     }
 
+    // Match a multiply with a disguised negated-power-of-2 and convert to a
+    // an equivalent shift-left amount.
+    // Example: (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
+    auto getShiftLeftAmt = [&HighMask](SDValue Mul) -> unsigned {
+      if (Mul.getOpcode() != ISD::MUL || !Mul.hasOneUse())
+        return 0;
+
+      // Don't touch opaque constants. Also, ignore zero and power-of-2
+      // multiplies. Those will get folded later.
+      ConstantSDNode *MulC = isConstOrConstSplat(Mul.getOperand(1));
+      if (MulC && !MulC->isOpaque() && !MulC->isZero() &&
+          !MulC->getAPIntValue().isPowerOf2()) {
+        APInt UnmaskedC = MulC->getAPIntValue() | HighMask;
+        if (UnmaskedC.isNegatedPowerOf2())
+          return (-UnmaskedC).logBase2();
+      }
+      return 0;
+    };
+
+    auto foldMul = [&](SDValue X, SDValue Y, unsigned ShlAmt) {
+      EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
+      SDValue ShlAmtC = TLO.DAG.getConstant(ShlAmt, dl, ShiftAmtTy);
+      SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
+      SDValue Sub = TLO.DAG.getNode(ISD::SUB, dl, VT, Y, Shl);
+      return TLO.CombineTo(Op, Sub);
+    };
+
+    if (isOperationLegalOrCustom(ISD::SHL, VT)) {
+      if (Op.getOpcode() == ISD::ADD) {
+        // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
+        if (unsigned ShAmt = getShiftLeftAmt(Op0))
+          return foldMul(Op0.getOperand(0), Op1, ShAmt);
+        // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
+        if (unsigned ShAmt = getShiftLeftAmt(Op1))
+          return foldMul(Op1.getOperand(0), Op0, ShAmt);
+        // TODO:
+        // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
+      }
+    }
+
     LLVM_FALLTHROUGH;
   }
   default:
index 31ff289..2c0bec9 100644 (file)
@@ -704,8 +704,7 @@ define i32 @ntest16(i32 %x) {
 define i32 @muladd_demand(i32 %x, i32 %y) {
 ; CHECK-LABEL: muladd_demand:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #131008
-; CHECK-NEXT:    madd w8, w0, w8, w1
+; CHECK-NEXT:    sub w8, w1, w0, lsl #6
 ; CHECK-NEXT:    and w0, w8, #0x1ffc0
 ; CHECK-NEXT:    ret
 ;
@@ -724,11 +723,10 @@ define i32 @muladd_demand(i32 %x, i32 %y) {
 define <4 x i32> @muladd_demand_commute(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: muladd_demand_commute:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #131008
-; CHECK-NEXT:    dup v2.4s, w8
-; CHECK-NEXT:    mla v1.4s, v0.4s, v2.4s
-; CHECK-NEXT:    movi v0.4s, #1, msl #16
-; CHECK-NEXT:    and v0.16b, v1.16b, v0.16b
+; CHECK-NEXT:    movi v2.4s, #1, msl #16
+; CHECK-NEXT:    shl v0.4s, v0.4s, #6
+; CHECK-NEXT:    sub v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: muladd_demand_commute:
index 236addd..6c477e6 100644 (file)
@@ -151,45 +151,45 @@ define i1 @test_urem_oversized(i66 %X) nounwind {
 ; MIPSEL-NEXT:    lui $1, 12057
 ; MIPSEL-NEXT:    ori $1, $1, 37186
 ; MIPSEL-NEXT:    multu $6, $1
-; MIPSEL-NEXT:    mflo $2
-; MIPSEL-NEXT:    mfhi $3
-; MIPSEL-NEXT:    lui $7, 52741
-; MIPSEL-NEXT:    ori $7, $7, 40665
-; MIPSEL-NEXT:    multu $6, $7
-; MIPSEL-NEXT:    mflo $8
+; MIPSEL-NEXT:    mflo $1
+; MIPSEL-NEXT:    mfhi $2
+; MIPSEL-NEXT:    lui $3, 52741
+; MIPSEL-NEXT:    ori $3, $3, 40665
+; MIPSEL-NEXT:    multu $6, $3
+; MIPSEL-NEXT:    mflo $7
+; MIPSEL-NEXT:    mfhi $8
+; MIPSEL-NEXT:    multu $5, $3
 ; MIPSEL-NEXT:    mfhi $9
-; MIPSEL-NEXT:    multu $5, $7
-; MIPSEL-NEXT:    mfhi $10
-; MIPSEL-NEXT:    mflo $11
-; MIPSEL-NEXT:    addu $9, $11, $9
-; MIPSEL-NEXT:    addu $12, $2, $9
-; MIPSEL-NEXT:    sltu $9, $9, $11
-; MIPSEL-NEXT:    sll $11, $12, 31
-; MIPSEL-NEXT:    sltu $2, $12, $2
-; MIPSEL-NEXT:    srl $13, $8, 1
-; MIPSEL-NEXT:    sll $8, $8, 1
-; MIPSEL-NEXT:    addu $2, $3, $2
-; MIPSEL-NEXT:    or $3, $13, $11
-; MIPSEL-NEXT:    srl $11, $12, 1
-; MIPSEL-NEXT:    addu $9, $10, $9
-; MIPSEL-NEXT:    mul $4, $4, $7
-; MIPSEL-NEXT:    mul $1, $5, $1
-; MIPSEL-NEXT:    sll $5, $6, 1
+; MIPSEL-NEXT:    mflo $10
+; MIPSEL-NEXT:    addu $8, $10, $8
+; MIPSEL-NEXT:    addu $11, $1, $8
+; MIPSEL-NEXT:    sltu $8, $8, $10
+; MIPSEL-NEXT:    sll $10, $11, 31
+; MIPSEL-NEXT:    sltu $1, $11, $1
+; MIPSEL-NEXT:    srl $12, $7, 1
+; MIPSEL-NEXT:    sll $7, $7, 1
+; MIPSEL-NEXT:    addu $1, $2, $1
+; MIPSEL-NEXT:    or $10, $12, $10
+; MIPSEL-NEXT:    srl $2, $11, 1
+; MIPSEL-NEXT:    addu $8, $9, $8
+; MIPSEL-NEXT:    mul $3, $4, $3
+; MIPSEL-NEXT:    sll $4, $6, 1
+; MIPSEL-NEXT:    sll $5, $5, 1
 ; MIPSEL-NEXT:    lui $6, 60010
 ; MIPSEL-NEXT:    ori $6, $6, 61135
-; MIPSEL-NEXT:    addu $2, $9, $2
-; MIPSEL-NEXT:    addu $1, $1, $2
-; MIPSEL-NEXT:    addu $2, $5, $4
-; MIPSEL-NEXT:    addu $1, $1, $2
+; MIPSEL-NEXT:    addu $1, $8, $1
+; MIPSEL-NEXT:    subu $1, $1, $5
+; MIPSEL-NEXT:    addu $3, $4, $3
+; MIPSEL-NEXT:    addu $1, $1, $3
 ; MIPSEL-NEXT:    andi $1, $1, 3
-; MIPSEL-NEXT:    sll $2, $1, 31
-; MIPSEL-NEXT:    or $4, $11, $2
-; MIPSEL-NEXT:    sltiu $2, $4, 13
-; MIPSEL-NEXT:    xori $4, $4, 13
-; MIPSEL-NEXT:    sltu $3, $3, $6
-; MIPSEL-NEXT:    movz $2, $3, $4
+; MIPSEL-NEXT:    sll $3, $1, 31
+; MIPSEL-NEXT:    or $3, $2, $3
+; MIPSEL-NEXT:    sltiu $2, $3, 13
+; MIPSEL-NEXT:    xori $3, $3, 13
+; MIPSEL-NEXT:    sltu $4, $10, $6
+; MIPSEL-NEXT:    movz $2, $4, $3
 ; MIPSEL-NEXT:    srl $1, $1, 1
-; MIPSEL-NEXT:    or $1, $1, $8
+; MIPSEL-NEXT:    or $1, $1, $7
 ; MIPSEL-NEXT:    andi $1, $1, 3
 ; MIPSEL-NEXT:    jr $ra
 ; MIPSEL-NEXT:    movn $2, $zero, $1
index ad72080..5782a60 100644 (file)
@@ -1550,47 +1550,29 @@ define i64 @mulhsu_i64(i64 %a, i64 %b) nounwind {
 define i8 @muladd_demand(i8 %x, i8 %y) nounwind {
 ; RV32I-LABEL: muladd_demand:
 ; RV32I:       # %bb.0:
-; RV32I-NEXT:    addi sp, sp, -16
-; RV32I-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
-; RV32I-NEXT:    sw s0, 8(sp) # 4-byte Folded Spill
-; RV32I-NEXT:    mv s0, a1
-; RV32I-NEXT:    li a1, 14
-; RV32I-NEXT:    call __mulsi3@plt
-; RV32I-NEXT:    add a0, s0, a0
+; RV32I-NEXT:    slli a0, a0, 1
+; RV32I-NEXT:    sub a0, a1, a0
 ; RV32I-NEXT:    andi a0, a0, 15
-; RV32I-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
-; RV32I-NEXT:    lw s0, 8(sp) # 4-byte Folded Reload
-; RV32I-NEXT:    addi sp, sp, 16
 ; RV32I-NEXT:    ret
 ;
 ; RV32IM-LABEL: muladd_demand:
 ; RV32IM:       # %bb.0:
-; RV32IM-NEXT:    li a2, 14
-; RV32IM-NEXT:    mul a0, a0, a2
-; RV32IM-NEXT:    add a0, a1, a0
+; RV32IM-NEXT:    slli a0, a0, 1
+; RV32IM-NEXT:    sub a0, a1, a0
 ; RV32IM-NEXT:    andi a0, a0, 15
 ; RV32IM-NEXT:    ret
 ;
 ; RV64I-LABEL: muladd_demand:
 ; RV64I:       # %bb.0:
-; RV64I-NEXT:    addi sp, sp, -16
-; RV64I-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64I-NEXT:    sd s0, 0(sp) # 8-byte Folded Spill
-; RV64I-NEXT:    mv s0, a1
-; RV64I-NEXT:    li a1, 14
-; RV64I-NEXT:    call __muldi3@plt
-; RV64I-NEXT:    addw a0, s0, a0
+; RV64I-NEXT:    slliw a0, a0, 1
+; RV64I-NEXT:    subw a0, a1, a0
 ; RV64I-NEXT:    andi a0, a0, 15
-; RV64I-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
-; RV64I-NEXT:    ld s0, 0(sp) # 8-byte Folded Reload
-; RV64I-NEXT:    addi sp, sp, 16
 ; RV64I-NEXT:    ret
 ;
 ; RV64IM-LABEL: muladd_demand:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    li a2, 14
-; RV64IM-NEXT:    mulw a0, a0, a2
-; RV64IM-NEXT:    addw a0, a1, a0
+; RV64IM-NEXT:    slliw a0, a0, 1
+; RV64IM-NEXT:    subw a0, a1, a0
 ; RV64IM-NEXT:    andi a0, a0, 15
 ; RV64IM-NEXT:    ret
   %m = mul i8 %x, 14
index 0af5cb3..3454a84 100644 (file)
@@ -4,8 +4,9 @@
 define i64 @muladd_demand(i64 %x, i64 %y) {
 ; CHECK-LABEL: muladd_demand:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    imull $131008, %edi, %eax # imm = 0x1FFC0
-; CHECK-NEXT:    addl %esi, %eax
+; CHECK-NEXT:    movq %rsi, %rax
+; CHECK-NEXT:    shll $6, %edi
+; CHECK-NEXT:    subl %edi, %eax
 ; CHECK-NEXT:    shlq $47, %rax
 ; CHECK-NEXT:    retq
   %m = mul i64 %x, 131008 ; 0x0001ffc0
@@ -17,9 +18,10 @@ define i64 @muladd_demand(i64 %x, i64 %y) {
 define <2 x i64> @muladd_demand_commute(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-LABEL: muladd_demand_commute:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; CHECK-NEXT:    paddq %xmm1, %xmm0
-; CHECK-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
+; CHECK-NEXT:    psllq $6, %xmm0
+; CHECK-NEXT:    psubq %xmm0, %xmm1
+; CHECK-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
+; CHECK-NEXT:    movdqa %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %m = mul <2 x i64> %x, <i64 131008, i64 131008>
   %a = add <2 x i64> %y, %m