[DAGCombiner] Fold add (mul x, C), x to mul x, C+1
authorNikita Popov <npopov@redhat.com>
Fri, 14 Apr 2023 13:46:45 +0000 (15:46 +0200)
committerNikita Popov <npopov@redhat.com>
Mon, 17 Apr 2023 10:33:46 +0000 (12:33 +0200)
While this is normally non-canonical IR, this pattern can appear
during SDAG lowering if the add is actually a getelementptr, as
illustrated in `@test_ptr`. This pattern comes up when doing
provenance-aware high-bit pointer tagging.

Proof: https://alive2.llvm.org/ce/z/DLoEcs

Fixes https://github.com/llvm/llvm-project/issues/62093.

Differential Revision: https://reviews.llvm.org/D148341

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/arm64-promote-const.ll
llvm/test/CodeGen/X86/add-of-mul.ll

index 3d11a74..ddb91ad 100644 (file)
@@ -3058,6 +3058,15 @@ SDValue DAGCombiner::visitADDLikeCommutative(SDValue N0, SDValue N1,
     }
   }
 
+  // add (mul x, C), x -> mul x, C+1
+  if (N0.getOpcode() == ISD::MUL && N0.getOperand(0) == N1 &&
+      isConstantOrConstantVector(N0.getOperand(1), /*NoOpaques=*/true) &&
+      N0.hasOneUse()) {
+    SDValue NewC = DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(1),
+                               DAG.getConstant(1, DL, VT));
+    return DAG.getNode(ISD::MUL, DL, VT, N0.getOperand(0), NewC);
+  }
+
   // If the target's bool is represented as 0/1, prefer to make this 'sub 0/1'
   // rather than 'add 0/-1' (the zext should get folded).
   // add (sext i1 Y), X --> sub X, (zext i1 Y)
index 93ff057..1531615 100644 (file)
@@ -44,7 +44,7 @@ entry:
 ; PROMOTED: ldr q[[REGNUM:[0-9]+]], [[[PAGEADDR]], [[CSTV1]]@PAGEOFF]
 ; Destination register is defined by ABI
 ; PROMOTED-NEXT: add.16b v0, v0, v[[REGNUM]]
-; PROMOTED-NEXT: mla.16b v0, v0, v[[REGNUM]]
+; PROMOTED-NEXT: mls.16b v0, v0, v[[REGNUM]]
 ; PROMOTED-NEXT: ret
 
 ; REGULAR-LABEL: test2:
@@ -55,12 +55,12 @@ entry:
 ; REGULAR: ldr q[[REGNUM:[0-9]+]], [[[PAGEADDR]], [[CSTLABEL]]@PAGEOFF]
 ; Destination register is defined by ABI
 ; REGULAR-NEXT: add.16b v0, v0, v[[REGNUM]]
-; REGULAR-NEXT: mla.16b v0, v0, v[[REGNUM]]
+; REGULAR-NEXT: mls.16b v0, v0, v[[REGNUM]]
 ; REGULAR-NEXT: ret
   %add.i = add <16 x i8> %arg, <i8 -40, i8 -93, i8 -118, i8 -99, i8 -75, i8 -105, i8 74, i8 -110, i8 62, i8 -115, i8 -119, i8 -120, i8 34, i8 -124, i8 0, i8 -128>
   %mul.i = mul <16 x i8> %add.i, <i8 -40, i8 -93, i8 -118, i8 -99, i8 -75, i8 -105, i8 74, i8 -110, i8 62, i8 -115, i8 -119, i8 -120, i8 34, i8 -124, i8 0, i8 -128>
-  %add.i9 = add <16 x i8> %add.i, %mul.i
-  ret <16 x i8> %add.i9
+  %sub.i9 = sub <16 x i8> %add.i, %mul.i
+  ret <16 x i8> %sub.i9
 }
 
 ; Two different uses of the same constant in two different basic blocks,
index ac02204..63f9fa8 100644 (file)
@@ -5,8 +5,7 @@ define i32 @test_scalar(i32 %x) {
 ; CHECK-LABEL: test_scalar:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
-; CHECK-NEXT:    leal (%rdi,%rdi,2), %eax
-; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    leal (,%rdi,4), %eax
 ; CHECK-NEXT:    retq
   %mul = mul i32 %x, 3
   %add = add i32 %mul, %x
@@ -17,8 +16,7 @@ define i32 @test_scalar_commuted(i32 %x) {
 ; CHECK-LABEL: test_scalar_commuted:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    # kill: def $edi killed $edi def $rdi
-; CHECK-NEXT:    leal (%rdi,%rdi,2), %eax
-; CHECK-NEXT:    addl %edi, %eax
+; CHECK-NEXT:    leal (,%rdi,4), %eax
 ; CHECK-NEXT:    retq
   %mul = mul i32 %x, 3
   %add = add i32 %x, %mul
@@ -28,8 +26,7 @@ define i32 @test_scalar_commuted(i32 %x) {
 define <4 x i32> @test_vector(<4 x i32> %x) {
 ; CHECK-LABEL: test_vector:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    paddd %xmm0, %xmm0
-; CHECK-NEXT:    paddd %xmm0, %xmm0
+; CHECK-NEXT:    pslld $2, %xmm0
 ; CHECK-NEXT:    retq
   %mul = mul <4 x i32> %x, <i32 3, i32 3, i32 3, i32 3>
   %add = add <4 x i32> %mul, %x
@@ -39,8 +36,7 @@ define <4 x i32> @test_vector(<4 x i32> %x) {
 define ptr @test_ptr(ptr %p) {
 ; CHECK-LABEL: test_ptr:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    leaq (%rdi,%rdi,2), %rax
-; CHECK-NEXT:    addq %rdi, %rax
+; CHECK-NEXT:    leaq (,%rdi,4), %rax
 ; CHECK-NEXT:    retq
   %addr = ptrtoint ptr %p to i64
   %mul = mul i64 %addr, 3