[AArch64] Reassociate sub(x, add(m1, m2)) to sub(sub(x, m1), m2)
authorDavid Green <david.green@arm.com>
Mon, 13 Feb 2023 14:35:10 +0000 (14:35 +0000)
committerDavid Green <david.green@arm.com>
Mon, 13 Feb 2023 14:35:10 +0000 (14:35 +0000)
The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)). This
reassociates it back to allow the creation of more mls instructions.

Differential Revision: https://reviews.llvm.org/D143143

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/test/CodeGen/AArch64/arm64-vmul.ll
llvm/test/CodeGen/AArch64/reassocmls.ll

index 7ad92aa..4db2b10 100644 (file)
@@ -17703,6 +17703,32 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N,
   return SDValue();
 }
 
+// The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2))
+// This reassociates it back to allow the creation of more mls instructions.
+static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
+  if (N->getOpcode() != ISD::SUB)
+    return SDValue();
+  SDValue Add = N->getOperand(1);
+  if (Add.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  SDValue X = N->getOperand(0);
+  if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(X)))
+    return SDValue();
+  SDValue M1 = Add.getOperand(0);
+  SDValue M2 = Add.getOperand(1);
+  if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL &&
+      M1.getOpcode() != AArch64ISD::UMULL)
+    return SDValue();
+  if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL &&
+      M2.getOpcode() != AArch64ISD::UMULL)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+  SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1);
+  return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
@@ -17719,6 +17745,8 @@ static SDValue performAddSubCombine(SDNode *N,
     return Val;
   if (SDValue Val = performAddCombineForShiftedOperands(N, DAG))
     return Val;
+  if (SDValue Val = performSubAddMULCombine(N, DAG))
+    return Val;
 
   return performAddSubLongCombine(N, DCI, DAG);
 }
index 7f743f6..3a9f031 100644 (file)
@@ -457,12 +457,11 @@ define <2 x i64> @smlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
 define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: smlsl8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    smlal.8h v0, v1, v2
-; CHECK-NEXT:    sub.8h v0, v3, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    smlsl.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlsl.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -476,13 +475,12 @@ define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
 define void @smlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: smlsl2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    smlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    sub.2d v0, v1, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    smlsl.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    smlsl.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
@@ -738,12 +736,11 @@ define <2 x i64> @umlsl2d(ptr %A, ptr %B, ptr %C) nounwind {
 define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) {
 ; CHECK-LABEL: umlsl8h_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.8h v0, v0, v2
-; CHECK-NEXT:    mvn.8b v2, v2
 ; CHECK-NEXT:    movi.16b v3, #1
-; CHECK-NEXT:    umlal.8h v0, v1, v2
-; CHECK-NEXT:    sub.8h v0, v3, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    umlsl.8h v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlsl.8h v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <8 x i8> %v3, <i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1, i8 -1>
   %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3)
@@ -757,13 +754,12 @@ define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <
 define void @umlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) {
 ; CHECK-LABEL: umlsl2d_chain_with_constant:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull.2d v0, v0, v2
 ; CHECK-NEXT:    mov w8, #257
-; CHECK-NEXT:    mvn.8b v2, v2
-; CHECK-NEXT:    umlal.2d v0, v1, v2
-; CHECK-NEXT:    dup.2d v1, x8
-; CHECK-NEXT:    sub.2d v0, v1, v0
-; CHECK-NEXT:    str q0, [x0]
+; CHECK-NEXT:    dup.2d v3, x8
+; CHECK-NEXT:    umlsl.2d v3, v0, v2
+; CHECK-NEXT:    mvn.8b v0, v2
+; CHECK-NEXT:    umlsl.2d v3, v1, v0
+; CHECK-NEXT:    str q3, [x0]
 ; CHECK-NEXT:    ret
   %xor = xor <2 x i32> %v3, <i32 -1, i32 -1>
   %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3)
index cf201ca..731d973 100644 (file)
@@ -4,9 +4,8 @@
 define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 ; CHECK-LABEL: smlsl_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull x8, w4, w3
-; CHECK-NEXT:    smaddl x8, w2, w1, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    smsubl x8, w4, w3, x0
+; CHECK-NEXT:    smsubl x0, w2, w1, x8
 ; CHECK-NEXT:    ret
   %be = sext i32 %b to i64
   %ce = sext i32 %c to i64
@@ -22,9 +21,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 ; CHECK-LABEL: umlsl_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull x8, w4, w3
-; CHECK-NEXT:    umaddl x8, w2, w1, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    umsubl x8, w4, w3, x0
+; CHECK-NEXT:    umsubl x0, w2, w1, x8
 ; CHECK-NEXT:    ret
   %be = zext i32 %b to i64
   %ce = zext i32 %c to i64
@@ -40,9 +38,8 @@ define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) {
 define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 ; CHECK-LABEL: mls_i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul x8, x2, x1
-; CHECK-NEXT:    madd x8, x4, x3, x8
-; CHECK-NEXT:    sub x0, x0, x8
+; CHECK-NEXT:    msub x8, x4, x3, x0
+; CHECK-NEXT:    msub x0, x2, x1, x8
 ; CHECK-NEXT:    ret
   %m1.neg = mul i64 %c, %b
   %m2.neg = mul i64 %e, %d
@@ -54,9 +51,8 @@ define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 define i16 @mls_i16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e) {
 ; CHECK-LABEL: mls_i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul w8, w2, w1
-; CHECK-NEXT:    madd w8, w4, w3, w8
-; CHECK-NEXT:    sub w0, w0, w8
+; CHECK-NEXT:    msub w8, w4, w3, w0
+; CHECK-NEXT:    msub w0, w2, w1, w8
 ; CHECK-NEXT:    ret
   %m1.neg = mul i16 %c, %b
   %m2.neg = mul i16 %e, %d
@@ -97,9 +93,8 @@ define i64 @mls_i64_C(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) {
 define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
 ; CHECK-LABEL: smlsl_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    smull v3.8h, v4.8b, v3.8b
-; CHECK-NEXT:    smlal v3.8h, v2.8b, v1.8b
-; CHECK-NEXT:    sub v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    smlsl v0.8h, v4.8b, v3.8b
+; CHECK-NEXT:    smlsl v0.8h, v2.8b, v1.8b
 ; CHECK-NEXT:    ret
   %be = sext <8 x i8> %b to <8 x i16>
   %ce = sext <8 x i8> %c to <8 x i16>
@@ -115,9 +110,8 @@ define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
 define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) {
 ; CHECK-LABEL: umlsl_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    umull v3.8h, v4.8b, v3.8b
-; CHECK-NEXT:    umlal v3.8h, v2.8b, v1.8b
-; CHECK-NEXT:    sub v0.8h, v0.8h, v3.8h
+; CHECK-NEXT:    umlsl v0.8h, v4.8b, v3.8b
+; CHECK-NEXT:    umlsl v0.8h, v2.8b, v1.8b
 ; CHECK-NEXT:    ret
   %be = zext <8 x i8> %b to <8 x i16>
   %ce = zext <8 x i8> %c to <8 x i16>
@@ -133,9 +127,8 @@ define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %
 define <8 x i16> @mls_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) {
 ; CHECK-LABEL: mls_v8i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mul v1.8h, v2.8h, v1.8h
-; CHECK-NEXT:    mla v1.8h, v4.8h, v3.8h
-; CHECK-NEXT:    sub v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    mls v0.8h, v4.8h, v3.8h
+; CHECK-NEXT:    mls v0.8h, v2.8h, v1.8h
 ; CHECK-NEXT:    ret
   %m1.neg = mul <8 x i16> %c, %b
   %m2.neg = mul <8 x i16> %e, %d
@@ -157,6 +150,20 @@ define <8 x i16> @mla_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16>
   ret <8 x i16> %s2
 }
 
+define <8 x i16> @mls_v8i16_C(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) {
+; CHECK-LABEL: mls_v8i16_C:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v0.8h, #10
+; CHECK-NEXT:    mls v0.8h, v4.8h, v3.8h
+; CHECK-NEXT:    mls v0.8h, v2.8h, v1.8h
+; CHECK-NEXT:    ret
+  %m1.neg = mul <8 x i16> %c, %b
+  %m2.neg = mul <8 x i16> %e, %d
+  %reass.add = add <8 x i16> %m2.neg, %m1.neg
+  %s2 = sub <8 x i16> <i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10, i16 10>, %reass.add
+  ret <8 x i16> %s2
+}
+
 
 define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8> %b, <vscale x 8 x i8> %c, <vscale x 8 x i8> %d, <vscale x 8 x i8> %e) {
 ; CHECK-LABEL: smlsl_nxv8i16:
@@ -166,9 +173,8 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
 ; CHECK-NEXT:    sxtb z4.h, p0/m, z4.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z2.h, p0/m, z2.h
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %be = sext <vscale x 8 x i8> %b to <vscale x 8 x i16>
   %ce = sext <vscale x 8 x i8> %c to <vscale x 8 x i16>
@@ -184,14 +190,13 @@ define <vscale x 8 x i16> @smlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8
 define <vscale x 8 x i16> @umlsl_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i8> %b, <vscale x 8 x i8> %c, <vscale x 8 x i8> %d, <vscale x 8 x i8> %e) {
 ; CHECK-LABEL: umlsl_nxv8i16:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    and z3.h, z3.h, #0xff
 ; CHECK-NEXT:    and z4.h, z4.h, #0xff
-; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    and z1.h, z1.h, #0xff
 ; CHECK-NEXT:    and z2.h, z2.h, #0xff
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %be = zext <vscale x 8 x i8> %b to <vscale x 8 x i16>
   %ce = zext <vscale x 8 x i8> %c to <vscale x 8 x i16>
@@ -208,9 +213,8 @@ define <vscale x 8 x i16> @mls_nxv8i16(<vscale x 8 x i16> %a, <vscale x 8 x i16>
 ; CHECK-LABEL: mls_nxv8i16:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    mul z3.h, z4.h, z3.h
-; CHECK-NEXT:    mla z3.h, p0/m, z2.h, z1.h
-; CHECK-NEXT:    sub z0.h, z0.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z4.h, z3.h
+; CHECK-NEXT:    mls z0.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    ret
   %m1.neg = mul <vscale x 8 x i16> %c, %b
   %m2.neg = mul <vscale x 8 x i16> %e, %d