From 3b9833597e810d4c485487d2f094a8e223af5548 Mon Sep 17 00:00:00 2001
From: David Green <david.green@arm.com>
Date: Mon, 4 Apr 2022 23:07:47 +0100
Subject: [PATCH] [AArch64] Alter mull buildvectors(ext(..)) combine to work
 on shuffles

D120018 altered this combine to work on buildvectors as opposed to
shuffle dup's. This works well for dups and other things that are
expanded into buildvectors. Some shuffles are legal though, and stay as
vector_shuffle through lowering. This expands the transform to also
handle shuffles, so that we can turn mul(shuffle(sext into
mul(sext(shuffle and more readily make smull/umull instructions.

This can come up from the SLP vectorizer adding shuffles that are
costed from extends.

Differential Revision: https://reviews.llvm.org/D123012
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 54 +++++++++++++++++--------
 llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll    | 52 ++++++------------------
 2 files changed, 49 insertions(+), 57 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 38079a1..f6e1c61 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -13629,15 +13629,17 @@ static EVT calculatePreExtendType(SDValue Extend) {
   }
 }
 
-/// Combines a buildvector(sext/zext) node pattern into sext/zext(buildvector)
-/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
-static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
+/// Combines a buildvector(sext/zext) or shuffle(sext/zext, undef) node pattern
+/// into sext/zext(buildvector) or sext/zext(shuffle) making use of the vector
+/// SExt/ZExt rather than the scalar SExt/ZExt
+static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
   EVT VT = BV.getValueType();
-  if (BV.getOpcode() != ISD::BUILD_VECTOR)
+  if (BV.getOpcode() != ISD::BUILD_VECTOR &&
+      BV.getOpcode() != ISD::VECTOR_SHUFFLE)
     return SDValue();
 
-  // Use the first item in the buildvector to get the size of the extend, and
-  // make sure it looks valid.
+  // Use the first item in the buildvector/shuffle to get the size of the
+  // extend, and make sure it looks valid.
   SDValue Extend = BV->getOperand(0);
   unsigned ExtendOpcode = Extend.getOpcode();
   bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
@@ -13646,15 +13648,22 @@ static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
   if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
       ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
     return SDValue();
+  // Shuffle inputs are vector, limit to SIGN_EXTEND and ZERO_EXTEND to ensure
+  // calculatePreExtendType will work without issue.
+  if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
+      ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
+    return SDValue();
 
   // Restrict valid pre-extend data type
   EVT PreExtendType = calculatePreExtendType(Extend);
   if (PreExtendType == MVT::Other ||
-      PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
+      PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
     return SDValue();
 
   // Make sure all other operands are equally extended
   for (SDValue Op : drop_begin(BV->ops())) {
+    if (Op.isUndef())
+      continue;
     unsigned Opc = Op.getOpcode();
     bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
                      Opc == ISD::AssertSext;
@@ -13662,15 +13671,26 @@ static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
       return SDValue();
   }
 
-  EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
-  EVT PreExtendLegalType =
-      PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+  SDValue NBV;
   SDLoc DL(BV);
-  SmallVector<SDValue, 8> NewOps;
-  for (SDValue Op : BV->ops())
-    NewOps.push_back(
-        DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
-  SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+    EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
+    EVT PreExtendLegalType =
+        PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
+    SmallVector<SDValue, 8> NewOps;
+    for (SDValue Op : BV->ops())
+      NewOps.push_back(Op.isUndef() ? DAG.getUNDEF(PreExtendLegalType)
                                    : DAG.getAnyExtOrTrunc(Op.getOperand(0), DL,
                                                           PreExtendLegalType));
+    NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
+  } else { // BV.getOpcode() == ISD::VECTOR_SHUFFLE
+    EVT PreExtendVT = VT.changeVectorElementType(PreExtendType.getScalarType());
+    NBV = DAG.getVectorShuffle(PreExtendVT, DL, BV.getOperand(0).getOperand(0),
+                               BV.getOperand(1).isUndef()
+                                   ? DAG.getUNDEF(PreExtendVT)
+                                   : BV.getOperand(1).getOperand(0),
+                               cast<ShuffleVectorSDNode>(BV)->getMask());
+  }
   return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT,
                      NBV);
 }
@@ -13682,8 +13702,8 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
   if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
     return SDValue();
 
-  SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
-  SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
+  SDValue Op0 = performBuildShuffleExtendCombine(Mul->getOperand(0), DAG);
+  SDValue Op1 = performBuildShuffleExtendCombine(Mul->getOperand(1), DAG);
 
   // Neither operands have been changed, don't make any further changes
   if (!Op0 && !Op1)
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index 2146221..e0d7759 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -245,9 +245,8 @@ entry:
 define <8 x i16> @missing_insert(<8 x i8> %b) {
 ; CHECK-LABEL: missing_insert:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #4
-; CHECK-NEXT:    mul v0.8h, v1.8h, v0.8h
+; CHECK-NEXT:    ext v1.8b, v0.8b, v0.8b, #2
+; CHECK-NEXT:    smull v0.8h, v1.8b, v0.8b
 ; CHECK-NEXT:    ret
 entry:
   %ext.b = sext <8 x i8> %b to <8 x i16>
@@ -259,11 +258,8 @@ entry:
 define <8 x i16> @shufsext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ; CHECK-LABEL: shufsext_v8i8_v8i16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.8h, v0.8b, #0
-; CHECK-NEXT:    sshll v1.8h, v1.8b, #0
-; CHECK-NEXT:    rev64 v0.8h, v0.8h
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rev64 v0.8b, v0.8b
+; CHECK-NEXT:    smull v0.8h, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <8 x i8> %src to <8 x i16>
@@ -276,17 +272,8 @@ entry:
 define <2 x i64> @shufsext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
 ; CHECK-LABEL: shufsext_v2i32_v2i64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    rev64 v0.2s, v0.2s
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <2 x i32> %src to <2 x i64>
@@ -299,11 +286,8 @@ entry:
 define <8 x i16> @shufzext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
 ; CHECK-LABEL: shufzext_v8i8_v8i16:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    rev64 v0.8h, v0.8h
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    rev64 v0.8b, v0.8b
+; CHECK-NEXT:    umull v0.8h, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
 entry:
   %in = zext <8 x i8> %src to <8 x i16>
@@ -316,17 +300,8 @@ entry:
 define <2 x i64> @shufzext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
 ; CHECK-LABEL: shufzext_v2i32_v2i64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    sshll v0.2d, v0.2s, #0
-; CHECK-NEXT:    sshll v1.2d, v1.2s, #0
-; CHECK-NEXT:    ext v0.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    fmov x9, d1
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    fmov x10, d0
-; CHECK-NEXT:    mov x11, v0.d[1]
-; CHECK-NEXT:    mul x9, x10, x9
-; CHECK-NEXT:    mul x8, x11, x8
-; CHECK-NEXT:    fmov d0, x9
-; CHECK-NEXT:    mov v0.d[1], x8
+; CHECK-NEXT:    rev64 v0.2s, v0.2s
+; CHECK-NEXT:    smull v0.2d, v0.2s, v1.2s
 ; CHECK-NEXT:    ret
 entry:
   %in = sext <2 x i32> %src to <2 x i64>
@@ -339,11 +314,8 @@ entry:
 define <8 x i16> @shufzext_v8i8_v8i16_twoin(<8 x i8> %src1, <8 x i8> %src2, <8 x i8> %b) {
 ; CHECK-LABEL: shufzext_v8i8_v8i16_twoin:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ushll v0.8h, v0.8b, #0
-; CHECK-NEXT:    ushll v1.8h, v1.8b, #0
-; CHECK-NEXT:    trn1 v0.8h, v0.8h, v1.8h
-; CHECK-NEXT:    ushll v1.8h, v2.8b, #0
-; CHECK-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    trn1 v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    umull v0.8h, v0.8b, v2.8b
 ; CHECK-NEXT:    ret
 entry:
   %in1 = zext <8 x i8> %src1 to <8 x i16>
-- 
2.7.4