From e9eaee9da196265d20dbeaf7920c24ccb33e2d04 Mon Sep 17 00:00:00 2001 From: David Green Date: Mon, 13 Feb 2023 14:35:10 +0000 Subject: [PATCH] [AArch64] Reassociate sub(x, add(m1, m2)) to sub(sub(x, m1), m2) The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)). This reassociates it back to allow the creation of more mls instructions. Differential Revision: https://reviews.llvm.org/D143143 --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 28 +++++++++++ llvm/test/CodeGen/AArch64/arm64-vmul.ll | 40 +++++++-------- llvm/test/CodeGen/AArch64/reassocmls.ll | 66 +++++++++++++------------ 3 files changed, 81 insertions(+), 53 deletions(-) diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 7ad92aa..4db2b10 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -17703,6 +17703,32 @@ static SDValue performAddCombineForShiftedOperands(SDNode *N, return SDValue(); } +// The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2)) +// This reassociates it back to allow the creation of more mls instructions. +static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() != ISD::SUB) + return SDValue(); + SDValue Add = N->getOperand(1); + if (Add.getOpcode() != ISD::ADD) + return SDValue(); + + SDValue X = N->getOperand(0); + if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(X))) + return SDValue(); + SDValue M1 = Add.getOperand(0); + SDValue M2 = Add.getOperand(1); + if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL && + M1.getOpcode() != AArch64ISD::UMULL) + return SDValue(); + if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL && + M2.getOpcode() != AArch64ISD::UMULL) + return SDValue(); + + EVT VT = N->getValueType(0); + SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1); + return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2); +} + static SDValue performAddSubCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { @@ -17719,6 +17745,8 @@ static SDValue performAddSubCombine(SDNode *N, return Val; if (SDValue Val = performAddCombineForShiftedOperands(N, DAG)) return Val; + if (SDValue Val = performSubAddMULCombine(N, DAG)) + return Val; return performAddSubLongCombine(N, DCI, DAG); } diff --git a/llvm/test/CodeGen/AArch64/arm64-vmul.ll b/llvm/test/CodeGen/AArch64/arm64-vmul.ll index 7f743f6..3a9f031 100644 --- a/llvm/test/CodeGen/AArch64/arm64-vmul.ll +++ b/llvm/test/CodeGen/AArch64/arm64-vmul.ll @@ -457,12 +457,11 @@ define <2 x i64> @smlsl2d(ptr %A, ptr %B, ptr %C) nounwind { define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) { ; CHECK-LABEL: smlsl8h_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: smull.8h v0, v0, v2 -; CHECK-NEXT: mvn.8b v2, v2 ; CHECK-NEXT: movi.16b v3, #1 -; CHECK-NEXT: smlal.8h v0, v1, v2 -; CHECK-NEXT: sub.8h v0, v3, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: smlsl.8h v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: smlsl.8h v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <8 x i8> %v3, %smull.1 = tail call <8 x i16> @llvm.aarch64.neon.smull.v8i16(<8 x i8> %v1, <8 x i8> %v3) @@ -476,13 +475,12 @@ define void @smlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, < define void @smlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) { ; CHECK-LABEL: smlsl2d_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: smull.2d v0, v0, v2 ; CHECK-NEXT: mov w8, #257 -; CHECK-NEXT: mvn.8b v2, v2 -; CHECK-NEXT: smlal.2d v0, v1, v2 -; CHECK-NEXT: dup.2d v1, x8 -; CHECK-NEXT: sub.2d v0, v1, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: dup.2d v3, x8 +; CHECK-NEXT: smlsl.2d v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: smlsl.2d v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <2 x i32> %v3, %smull.1 = tail call <2 x i64> @llvm.aarch64.neon.smull.v2i64(<2 x i32> %v1, <2 x i32> %v3) @@ -738,12 +736,11 @@ define <2 x i64> @umlsl2d(ptr %A, ptr %B, ptr %C) nounwind { define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3) { ; CHECK-LABEL: umlsl8h_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: umull.8h v0, v0, v2 -; CHECK-NEXT: mvn.8b v2, v2 ; CHECK-NEXT: movi.16b v3, #1 -; CHECK-NEXT: umlal.8h v0, v1, v2 -; CHECK-NEXT: sub.8h v0, v3, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: umlsl.8h v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: umlsl.8h v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <8 x i8> %v3, %umull.1 = tail call <8 x i16> @llvm.aarch64.neon.umull.v8i16(<8 x i8> %v1, <8 x i8> %v3) @@ -757,13 +754,12 @@ define void @umlsl8h_chain_with_constant(ptr %dst, <8 x i8> %v1, <8 x i8> %v2, < define void @umlsl2d_chain_with_constant(ptr %dst, <2 x i32> %v1, <2 x i32> %v2, <2 x i32> %v3) { ; CHECK-LABEL: umlsl2d_chain_with_constant: ; CHECK: // %bb.0: -; CHECK-NEXT: umull.2d v0, v0, v2 ; CHECK-NEXT: mov w8, #257 -; CHECK-NEXT: mvn.8b v2, v2 -; CHECK-NEXT: umlal.2d v0, v1, v2 -; CHECK-NEXT: dup.2d v1, x8 -; CHECK-NEXT: sub.2d v0, v1, v0 -; CHECK-NEXT: str q0, [x0] +; CHECK-NEXT: dup.2d v3, x8 +; CHECK-NEXT: umlsl.2d v3, v0, v2 +; CHECK-NEXT: mvn.8b v0, v2 +; CHECK-NEXT: umlsl.2d v3, v1, v0 +; CHECK-NEXT: str q3, [x0] ; CHECK-NEXT: ret %xor = xor <2 x i32> %v3, %umull.1 = tail call <2 x i64> @llvm.aarch64.neon.umull.v2i64(<2 x i32> %v1, <2 x i32> %v3) diff --git a/llvm/test/CodeGen/AArch64/reassocmls.ll b/llvm/test/CodeGen/AArch64/reassocmls.ll index cf201ca..731d973 100644 --- a/llvm/test/CodeGen/AArch64/reassocmls.ll +++ b/llvm/test/CodeGen/AArch64/reassocmls.ll @@ -4,9 +4,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { ; CHECK-LABEL: smlsl_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: smull x8, w4, w3 -; CHECK-NEXT: smaddl x8, w2, w1, x8 -; CHECK-NEXT: sub x0, x0, x8 +; CHECK-NEXT: smsubl x8, w4, w3, x0 +; CHECK-NEXT: smsubl x0, w2, w1, x8 ; CHECK-NEXT: ret %be = sext i32 %b to i64 %ce = sext i32 %c to i64 @@ -22,9 +21,8 @@ define i64 @smlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { ; CHECK-LABEL: umlsl_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: umull x8, w4, w3 -; CHECK-NEXT: umaddl x8, w2, w1, x8 -; CHECK-NEXT: sub x0, x0, x8 +; CHECK-NEXT: umsubl x8, w4, w3, x0 +; CHECK-NEXT: umsubl x0, w2, w1, x8 ; CHECK-NEXT: ret %be = zext i32 %b to i64 %ce = zext i32 %c to i64 @@ -40,9 +38,8 @@ define i64 @umlsl_i64(i64 %a, i32 %b, i32 %c, i32 %d, i32 %e) { define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) { ; CHECK-LABEL: mls_i64: ; CHECK: // %bb.0: -; CHECK-NEXT: mul x8, x2, x1 -; CHECK-NEXT: madd x8, x4, x3, x8 -; CHECK-NEXT: sub x0, x0, x8 +; CHECK-NEXT: msub x8, x4, x3, x0 +; CHECK-NEXT: msub x0, x2, x1, x8 ; CHECK-NEXT: ret %m1.neg = mul i64 %c, %b %m2.neg = mul i64 %e, %d @@ -54,9 +51,8 @@ define i64 @mls_i64(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) { define i16 @mls_i16(i16 %a, i16 %b, i16 %c, i16 %d, i16 %e) { ; CHECK-LABEL: mls_i16: ; CHECK: // %bb.0: -; CHECK-NEXT: mul w8, w2, w1 -; CHECK-NEXT: madd w8, w4, w3, w8 -; CHECK-NEXT: sub w0, w0, w8 +; CHECK-NEXT: msub w8, w4, w3, w0 +; CHECK-NEXT: msub w0, w2, w1, w8 ; CHECK-NEXT: ret %m1.neg = mul i16 %c, %b %m2.neg = mul i16 %e, %d @@ -97,9 +93,8 @@ define i64 @mls_i64_C(i64 %a, i64 %b, i64 %c, i64 %d, i64 %e) { define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) { ; CHECK-LABEL: smlsl_v8i16: ; CHECK: // %bb.0: -; CHECK-NEXT: smull v3.8h, v4.8b, v3.8b -; CHECK-NEXT: smlal v3.8h, v2.8b, v1.8b -; CHECK-NEXT: sub v0.8h, v0.8h, v3.8h +; CHECK-NEXT: smlsl v0.8h, v4.8b, v3.8b +; CHECK-NEXT: smlsl v0.8h, v2.8b, v1.8b ; CHECK-NEXT: ret %be = sext <8 x i8> %b to <8 x i16> %ce = sext <8 x i8> %c to <8 x i16> @@ -115,9 +110,8 @@ define <8 x i16> @smlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> % define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> %d, <8 x i8> %e) { ; CHECK-LABEL: umlsl_v8i16: ; CHECK: // %bb.0: -; CHECK-NEXT: umull v3.8h, v4.8b, v3.8b -; CHECK-NEXT: umlal v3.8h, v2.8b, v1.8b -; CHECK-NEXT: sub v0.8h, v0.8h, v3.8h +; CHECK-NEXT: umlsl v0.8h, v4.8b, v3.8b +; CHECK-NEXT: umlsl v0.8h, v2.8b, v1.8b ; CHECK-NEXT: ret %be = zext <8 x i8> %b to <8 x i16> %ce = zext <8 x i8> %c to <8 x i16> @@ -133,9 +127,8 @@ define <8 x i16> @umlsl_v8i16(<8 x i16> %a, <8 x i8> %b, <8 x i8> %c, <8 x i8> % define <8 x i16> @mls_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) { ; CHECK-LABEL: mls_v8i16: ; CHECK: // %bb.0: -; CHECK-NEXT: mul v1.8h, v2.8h, v1.8h -; CHECK-NEXT: mla v1.8h, v4.8h, v3.8h -; CHECK-NEXT: sub v0.8h, v0.8h, v1.8h +; CHECK-NEXT: mls v0.8h, v4.8h, v3.8h +; CHECK-NEXT: mls v0.8h, v2.8h, v1.8h ; CHECK-NEXT: ret %m1.neg = mul <8 x i16> %c, %b %m2.neg = mul <8 x i16> %e, %d @@ -157,6 +150,20 @@ define <8 x i16> @mla_v8i16(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> ret <8 x i16> %s2 } +define <8 x i16> @mls_v8i16_C(<8 x i16> %a, <8 x i16> %b, <8 x i16> %c, <8 x i16> %d, <8 x i16> %e) { +; CHECK-LABEL: mls_v8i16_C: +; CHECK: // %bb.0: +; CHECK-NEXT: movi v0.8h, #10 +; CHECK-NEXT: mls v0.8h, v4.8h, v3.8h +; CHECK-NEXT: mls v0.8h, v2.8h, v1.8h +; CHECK-NEXT: ret + %m1.neg = mul <8 x i16> %c, %b + %m2.neg = mul <8 x i16> %e, %d + %reass.add = add <8 x i16> %m2.neg, %m1.neg + %s2 = sub <8 x i16> , %reass.add + ret <8 x i16> %s2 +} + define @smlsl_nxv8i16( %a, %b, %c, %d, %e) { ; CHECK-LABEL: smlsl_nxv8i16: @@ -166,9 +173,8 @@ define @smlsl_nxv8i16( %a, %b to %ce = sext %c to @@ -184,14 +190,13 @@ define @smlsl_nxv8i16( %a, @umlsl_nxv8i16( %a, %b, %c, %d, %e) { ; CHECK-LABEL: umlsl_nxv8i16: ; CHECK: // %bb.0: +; CHECK-NEXT: ptrue p0.h ; CHECK-NEXT: and z3.h, z3.h, #0xff ; CHECK-NEXT: and z4.h, z4.h, #0xff -; CHECK-NEXT: ptrue p0.h ; CHECK-NEXT: and z1.h, z1.h, #0xff ; CHECK-NEXT: and z2.h, z2.h, #0xff -; CHECK-NEXT: mul z3.h, z4.h, z3.h -; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h -; CHECK-NEXT: sub z0.h, z0.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h ; CHECK-NEXT: ret %be = zext %b to %ce = zext %c to @@ -208,9 +213,8 @@ define @mls_nxv8i16( %a, ; CHECK-LABEL: mls_nxv8i16: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p0.h -; CHECK-NEXT: mul z3.h, z4.h, z3.h -; CHECK-NEXT: mla z3.h, p0/m, z2.h, z1.h -; CHECK-NEXT: sub z0.h, z0.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z4.h, z3.h +; CHECK-NEXT: mls z0.h, p0/m, z2.h, z1.h ; CHECK-NEXT: ret %m1.neg = mul %c, %b %m2.neg = mul %e, %d -- 2.7.4