From c56846a8928f8708f56c0eb36dcd6345e312faa0 Mon Sep 17 00:00:00 2001 From: David Green Date: Sun, 5 Feb 2023 20:59:49 +0000 Subject: [PATCH] [ARM] Remove FlattenVectorShuffle and add PerformVQDMULHCombine. This removes the FlattenVectorShuffle that folds shuffles through certain binops. This is now handled by generic DAG combines for all but ARMISD::VQDMULH where a PerformVQDMULHCombine is added to compensate. It pushes identical shuffles down through the operation, in a similar way to the other combines in DAG. --- llvm/lib/Target/ARM/ARMISelLowering.cpp | 70 +++++++++++---------------------- 1 file changed, 23 insertions(+), 47 deletions(-) diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 07fa829..24ac30c 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -15468,51 +15468,6 @@ static SDValue PerformSignExtendInregCombine(SDNode *N, SelectionDAG &DAG) { return SDValue(); } -// When lowering complex nodes that we recognize, like VQDMULH and MULH, we -// can end up with shuffle(binop(shuffle, shuffle)), that can be simplified to -// binop as the shuffles cancel out. -static SDValue FlattenVectorShuffle(ShuffleVectorSDNode *N, SelectionDAG &DAG) { - EVT VT = N->getValueType(0); - if (!N->getOperand(1).isUndef() || N->getOperand(0).getValueType() != VT) - return SDValue(); - SDValue Op = N->getOperand(0); - - // Looking for binary operators that will have been folded from - // truncates/extends. - switch (Op.getOpcode()) { - case ARMISD::VQDMULH: - case ISD::MULHS: - case ISD::MULHU: - case ISD::ABDS: - case ISD::ABDU: - case ISD::AVGFLOORS: - case ISD::AVGFLOORU: - case ISD::AVGCEILS: - case ISD::AVGCEILU: - break; - default: - return SDValue(); - } - - ShuffleVectorSDNode *Op0 = dyn_cast(Op.getOperand(0)); - ShuffleVectorSDNode *Op1 = dyn_cast(Op.getOperand(1)); - if (!Op0 || !Op1 || !Op0->getOperand(1).isUndef() || - !Op1->getOperand(1).isUndef() || Op0->getMask() != Op1->getMask() || - Op0->getOperand(0).getValueType() != VT) - return SDValue(); - - // Check the mask turns into an identity shuffle. - ArrayRef NMask = N->getMask(); - ArrayRef OpMask = Op0->getMask(); - for (int i = 0, e = NMask.size(); i != e; i++) { - if (NMask[i] > 0 && OpMask[NMask[i]] > 0 && OpMask[NMask[i]] != i) - return SDValue(); - } - - return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), - Op0->getOperand(0), Op1->getOperand(0)); -} - static SDValue PerformInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { SDValue Vec = N->getOperand(0); @@ -15581,8 +15536,6 @@ static SDValue PerformShuffleVMOVNCombine(ShuffleVectorSDNode *N, /// PerformVECTOR_SHUFFLECombine - Target-specific dag combine xforms for /// ISD::VECTOR_SHUFFLE. static SDValue PerformVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG) { - if (SDValue R = FlattenVectorShuffle(cast(N), DAG)) - return R; if (SDValue R = PerformShuffleVMOVNCombine(cast(N), DAG)) return R; @@ -17227,6 +17180,27 @@ static SDValue PerformVQMOVNCombine(SDNode *N, return SDValue(); } +static SDValue PerformVQDMULHCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + EVT VT = N->getValueType(0); + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + + auto *Shuf0 = dyn_cast(LHS); + auto *Shuf1 = dyn_cast(RHS); + // Turn VQDMULH(shuffle, shuffle) -> shuffle(VQDMULH) + if (Shuf0 && Shuf1 && Shuf0->getMask().equals(Shuf1->getMask()) && + LHS.getOperand(1).isUndef() && RHS.getOperand(1).isUndef() && + (LHS.hasOneUse() || RHS.hasOneUse() || LHS == RHS)) { + SDLoc DL(N); + SDValue NewBinOp = DCI.DAG.getNode(N->getOpcode(), DL, VT, + LHS.getOperand(0), RHS.getOperand(0)); + SDValue UndefV = LHS.getOperand(1); + return DCI.DAG.getVectorShuffle(VT, DL, NewBinOp, UndefV, Shuf0->getMask()); + } + return SDValue(); +} + static SDValue PerformLongShiftCombine(SDNode *N, SelectionDAG &DAG) { SDLoc DL(N); SDValue Op0 = N->getOperand(0); @@ -18755,6 +18729,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N, case ARMISD::VQMOVNs: case ARMISD::VQMOVNu: return PerformVQMOVNCombine(N, DCI); + case ARMISD::VQDMULH: + return PerformVQDMULHCombine(N, DCI); case ARMISD::ASRL: case ARMISD::LSRL: case ARMISD::LSLL: -- 2.7.4