From 6c7d713cf5d9bb188f1e73452a256386f0288bf7 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Tue, 6 Oct 2020 17:32:35 +0100 Subject: [PATCH] [X86][SSE] combineX86ShuffleChain add 'CanonicalizeShuffleInput' helper. NFCI. As part of PR45974, we're getting closer to not creating 'padded' vectors on-the-fly in combineX86ShufflesRecursively, and only pad the source inputs if we have a definite match inside combineX86ShuffleChain. At the moment combineX86ShuffleChain just has to bitcast an input to the correct shuffle type, but eventually we'll need to pad them as well. So, move the bitcast into a 'CanonicalizeShuffleInput' helper for now, making the diff for future padding support a lot smaller. --- llvm/lib/Target/X86/X86ISelLowering.cpp | 100 +++++++++++++++++--------------- 1 file changed, 54 insertions(+), 46 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index bd80812..66986a1 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35013,6 +35013,12 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, unsigned RootSizeInBits = RootVT.getSizeInBits(); unsigned NumRootElts = RootVT.getVectorNumElements(); + // Canonicalize shuffle input op to the requested type. + // TODO: Support cases where Op is smaller than VT. + auto CanonicalizeShuffleInput = [&](MVT VT, SDValue Op) { + return DAG.getBitcast(VT, Op); + }; + // Find the inputs that enter the chain. Note that multiple uses are OK // here, we're not going to remove the operands we find. 
bool UnaryShuffle = (Inputs.size() == 1); @@ -35031,7 +35037,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, unsigned NumBaseMaskElts = BaseMask.size(); if (NumBaseMaskElts == 1) { assert(BaseMask[0] == 0 && "Invalid shuffle index found!"); - return DAG.getBitcast(RootVT, V1); + return CanonicalizeShuffleInput(RootVT, V1); } bool OptForSize = DAG.shouldOptForSize(); @@ -35055,8 +35061,9 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, // we can just use the broadcast directly. This works for smaller broadcast // elements as well as they already repeat across each mask element if (UnaryShuffle && isTargetShuffleSplat(V1) && !isAnyZero(BaseMask) && - (BaseMaskEltSizeInBits % V1.getScalarValueSizeInBits()) == 0) { - return DAG.getBitcast(RootVT, V1); + (BaseMaskEltSizeInBits % V1.getScalarValueSizeInBits()) == 0 && + V1.getValueSizeInBits() >= RootSizeInBits) { + return CanonicalizeShuffleInput(RootVT, V1); } // Attempt to match a subvector broadcast. @@ -35089,7 +35096,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, return SDValue(); // Nothing to do! assert(isInRange(BaseMask[0], 0, NumBaseMaskElts) && "Unexpected lane shuffle"); - Res = DAG.getBitcast(ShuffleVT, V1); + Res = CanonicalizeShuffleInput(ShuffleVT, V1); unsigned SubIdx = BaseMask[0] * (8 / NumBaseMaskElts); bool UseZero = isAnyZero(BaseMask); Res = extractSubVector(Res, SubIdx, DAG, DL, BaseMaskEltSizeInBits); @@ -35103,8 +35110,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, narrowShuffleMaskElts(BaseMaskEltSizeInBits / 128, BaseMask, Mask); // Try to lower to vshuf64x2/vshuf32x4. - auto MatchSHUF128 = [](MVT ShuffleVT, const SDLoc &DL, ArrayRef Mask, - SDValue V1, SDValue V2, SelectionDAG &DAG) { + auto MatchSHUF128 = [&](MVT ShuffleVT, const SDLoc &DL, ArrayRef Mask, + SDValue V1, SDValue V2, SelectionDAG &DAG) { unsigned PermMask = 0; // Insure elements came from the same Op. 
SDValue Ops[2] = {DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT)}; @@ -35127,8 +35134,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, } return DAG.getNode(X86ISD::SHUF128, DL, ShuffleVT, - DAG.getBitcast(ShuffleVT, Ops[0]), - DAG.getBitcast(ShuffleVT, Ops[1]), + CanonicalizeShuffleInput(ShuffleVT, Ops[0]), + CanonicalizeShuffleInput(ShuffleVT, Ops[1]), DAG.getTargetConstant(PermMask, DL, MVT::i8)); }; @@ -35161,7 +35168,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, if (Depth == 0 && Root.getOpcode() == ISD::INSERT_SUBVECTOR) return SDValue(); // Nothing to do! assert(isInRange(BaseMask[0], 0, 2) && "Unexpected lane shuffle"); - Res = DAG.getBitcast(ShuffleVT, V1); + Res = CanonicalizeShuffleInput(ShuffleVT, V1); Res = extract128BitVector(Res, BaseMask[0] * 2, DAG, DL); Res = widenSubVector(Res, BaseMask[1] == SM_SentinelZero, Subtarget, DAG, DL, 256); @@ -35181,7 +35188,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, PermMask |= ((BaseMask[0] < 0 ? 0x8 : (BaseMask[0] & 1)) << 0); PermMask |= ((BaseMask[1] < 0 ? 0x8 : (BaseMask[1] & 1)) << 4); - Res = DAG.getBitcast(ShuffleVT, V1); + Res = CanonicalizeShuffleInput(ShuffleVT, V1); Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, Res, DAG.getUNDEF(ShuffleVT), DAG.getTargetConstant(PermMask, DL, MVT::i8)); @@ -35202,11 +35209,12 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, PermMask |= ((BaseMask[0] & 3) << 0); PermMask |= ((BaseMask[1] & 3) << 4); - Res = DAG.getNode( - X86ISD::VPERM2X128, DL, ShuffleVT, - DAG.getBitcast(ShuffleVT, isInRange(BaseMask[0], 0, 2) ? V1 : V2), - DAG.getBitcast(ShuffleVT, isInRange(BaseMask[1], 0, 2) ? V1 : V2), - DAG.getTargetConstant(PermMask, DL, MVT::i8)); + SDValue LHS = isInRange(BaseMask[0], 0, 2) ? V1 : V2; + SDValue RHS = isInRange(BaseMask[1], 0, 2) ? 
V1 : V2; + Res = DAG.getNode(X86ISD::VPERM2X128, DL, ShuffleVT, + CanonicalizeShuffleInput(ShuffleVT, LHS), + CanonicalizeShuffleInput(ShuffleVT, RHS), + DAG.getTargetConstant(PermMask, DL, MVT::i8)); return DAG.getBitcast(RootVT, Res); } } @@ -35282,7 +35290,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, if (Subtarget.hasAVX2()) { if (Depth == 0 && Root.getOpcode() == X86ISD::VBROADCAST) return SDValue(); // Nothing to do! - Res = DAG.getBitcast(MaskVT, V1); + Res = CanonicalizeShuffleInput(MaskVT, V1); Res = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVT, Res); return DAG.getBitcast(RootVT, Res); } @@ -35297,7 +35305,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 0 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - Res = DAG.getBitcast(ShuffleSrcVT, NewV1); + Res = CanonicalizeShuffleInput(ShuffleSrcVT, NewV1); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res); return DAG.getBitcast(RootVT, Res); } @@ -35309,7 +35317,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 0 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - Res = DAG.getBitcast(ShuffleVT, V1); + Res = CanonicalizeShuffleInput(ShuffleVT, V1); Res = DAG.getNode(Shuffle, DL, ShuffleVT, Res, DAG.getTargetConstant(PermuteImm, DL, MVT::i8)); return DAG.getBitcast(RootVT, Res); @@ -35330,8 +35338,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTPS) return SDValue(); // Nothing to do! 
Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32, - DAG.getBitcast(MVT::v4f32, SrcV1), - DAG.getBitcast(MVT::v4f32, SrcV2), + CanonicalizeShuffleInput(MVT::v4f32, SrcV1), + CanonicalizeShuffleInput(MVT::v4f32, SrcV2), DAG.getTargetConstant(PermuteImm, DL, MVT::i8)); return DAG.getBitcast(RootVT, Res); } @@ -35344,8 +35352,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, return SDValue(); // Nothing to do! PermuteImm = (/*DstIdx*/2 << 4) | (/*SrcIdx*/0 << 0); Res = DAG.getNode(X86ISD::INSERTPS, DL, MVT::v4f32, - DAG.getBitcast(MVT::v4f32, V1), - DAG.getBitcast(MVT::v4f32, V2), + CanonicalizeShuffleInput(MVT::v4f32, V1), + CanonicalizeShuffleInput(MVT::v4f32, V2), DAG.getTargetConstant(PermuteImm, DL, MVT::i8)); return DAG.getBitcast(RootVT, Res); } @@ -35359,8 +35367,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (!IsMaskedShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 0 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! - NewV1 = DAG.getBitcast(ShuffleSrcVT, NewV1); - NewV2 = DAG.getBitcast(ShuffleSrcVT, NewV2); + NewV1 = CanonicalizeShuffleInput(ShuffleSrcVT, NewV1); + NewV2 = CanonicalizeShuffleInput(ShuffleSrcVT, NewV2); Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2); return DAG.getBitcast(RootVT, Res); } @@ -35373,8 +35381,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (!IsMaskedShuffle || (NumRootElts == ShuffleVT.getVectorNumElements()))) { if (Depth == 0 && Root.getOpcode() == Shuffle) return SDValue(); // Nothing to do! 
- NewV1 = DAG.getBitcast(ShuffleVT, NewV1); - NewV2 = DAG.getBitcast(ShuffleVT, NewV2); + NewV1 = CanonicalizeShuffleInput(ShuffleVT, NewV1); + NewV2 = CanonicalizeShuffleInput(ShuffleVT, NewV2); Res = DAG.getNode(Shuffle, DL, ShuffleVT, NewV1, NewV2, DAG.getTargetConstant(PermuteImm, DL, MVT::i8)); return DAG.getBitcast(RootVT, Res); @@ -35391,7 +35399,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, Zeroable)) { if (Depth == 0 && Root.getOpcode() == X86ISD::EXTRQI) return SDValue(); // Nothing to do! - V1 = DAG.getBitcast(IntMaskVT, V1); + V1 = CanonicalizeShuffleInput(IntMaskVT, V1); Res = DAG.getNode(X86ISD::EXTRQI, DL, IntMaskVT, V1, DAG.getTargetConstant(BitLen, DL, MVT::i8), DAG.getTargetConstant(BitIdx, DL, MVT::i8)); @@ -35401,8 +35409,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, if (matchShuffleAsINSERTQ(IntMaskVT, V1, V2, Mask, BitLen, BitIdx)) { if (Depth == 0 && Root.getOpcode() == X86ISD::INSERTQI) return SDValue(); // Nothing to do! - V1 = DAG.getBitcast(IntMaskVT, V1); - V2 = DAG.getBitcast(IntMaskVT, V2); + V1 = CanonicalizeShuffleInput(IntMaskVT, V1); + V2 = CanonicalizeShuffleInput(IntMaskVT, V2); Res = DAG.getNode(X86ISD::INSERTQI, DL, IntMaskVT, V1, V2, DAG.getTargetConstant(BitLen, DL, MVT::i8), DAG.getTargetConstant(BitIdx, DL, MVT::i8)); @@ -35421,7 +35429,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, IsTRUNCATE ? (unsigned)ISD::TRUNCATE : (unsigned)X86ISD::VTRUNC; if (Depth == 0 && Root.getOpcode() == Opc) return SDValue(); // Nothing to do! - V1 = DAG.getBitcast(ShuffleSrcVT, V1); + V1 = CanonicalizeShuffleInput(ShuffleSrcVT, V1); Res = DAG.getNode(Opc, DL, ShuffleVT, V1); if (ShuffleVT.getSizeInBits() < RootSizeInBits) Res = widenSubVector(Res, true, Subtarget, DAG, DL, RootSizeInBits); @@ -35438,8 +35446,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, return SDValue(); // Nothing to do! 
ShuffleSrcVT = MVT::getIntegerVT(MaskEltSizeInBits * 2); ShuffleSrcVT = MVT::getVectorVT(ShuffleSrcVT, NumMaskElts / 2); - V1 = DAG.getBitcast(ShuffleSrcVT, V1); - V2 = DAG.getBitcast(ShuffleSrcVT, V2); + V1 = CanonicalizeShuffleInput(ShuffleSrcVT, V1); + V2 = CanonicalizeShuffleInput(ShuffleSrcVT, V2); ShuffleSrcVT = MVT::getIntegerVT(MaskEltSizeInBits * 2); ShuffleSrcVT = MVT::getVectorVT(ShuffleSrcVT, NumMaskElts); Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, ShuffleSrcVT, V1, V2); @@ -35468,7 +35476,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, if (Subtarget.hasAVX2() && (MaskVT == MVT::v8f32 || MaskVT == MVT::v8i32)) { SDValue VPermMask = getConstVector(Mask, IntMaskVT, DAG, DL, true); - Res = DAG.getBitcast(MaskVT, V1); + Res = CanonicalizeShuffleInput(MaskVT, V1); Res = DAG.getNode(X86ISD::VPERMV, DL, MaskVT, VPermMask, Res); return DAG.getBitcast(RootVT, Res); } @@ -35480,7 +35488,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) || (Subtarget.hasVBMI() && (MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8))) { - V1 = DAG.getBitcast(MaskVT, V1); + V1 = CanonicalizeShuffleInput(MaskVT, V1); V2 = DAG.getUNDEF(MaskVT); Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG); return DAG.getBitcast(RootVT, Res); @@ -35503,7 +35511,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, for (unsigned i = 0; i != NumMaskElts; ++i) if (Mask[i] == SM_SentinelZero) Mask[i] = NumMaskElts + i; - V1 = DAG.getBitcast(MaskVT, V1); + V1 = CanonicalizeShuffleInput(MaskVT, V1); V2 = getZeroVector(MaskVT, Subtarget, DAG, DL); Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG); return DAG.getBitcast(RootVT, Res); @@ -35528,8 +35536,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) || (Subtarget.hasVBMI() && AllowBWIVPERMV3 && (MaskVT == MVT::v32i8 || MaskVT == 
MVT::v64i8)))) { - V1 = DAG.getBitcast(MaskVT, V1); - V2 = DAG.getBitcast(MaskVT, V2); + V1 = CanonicalizeShuffleInput(MaskVT, V1); + V2 = CanonicalizeShuffleInput(MaskVT, V2); Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG); return DAG.getBitcast(RootVT, Res); } @@ -35556,7 +35564,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, EltBits[i] = AllOnes; } SDValue BitMask = getConstVector(EltBits, UndefElts, MaskVT, DAG, DL); - Res = DAG.getBitcast(MaskVT, V1); + Res = CanonicalizeShuffleInput(MaskVT, V1); unsigned AndOpcode = MaskVT.isFloatingPoint() ? unsigned(X86ISD::FAND) : unsigned(ISD::AND); Res = DAG.getNode(AndOpcode, DL, MaskVT, Res, BitMask); @@ -35576,7 +35584,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, VPermIdx.push_back(Idx); } SDValue VPermMask = DAG.getBuildVector(IntMaskVT, DL, VPermIdx); - Res = DAG.getBitcast(MaskVT, V1); + Res = CanonicalizeShuffleInput(MaskVT, V1); Res = DAG.getNode(X86ISD::VPERMILPV, DL, MaskVT, Res, VPermMask); return DAG.getBitcast(RootVT, Res); } @@ -35608,8 +35616,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, Index = (MaskVT.getScalarSizeInBits() == 64 ? 
Index << 1 : Index); VPerm2Idx.push_back(Index); } - V1 = DAG.getBitcast(MaskVT, V1); - V2 = DAG.getBitcast(MaskVT, V2); + V1 = CanonicalizeShuffleInput(MaskVT, V1); + V2 = CanonicalizeShuffleInput(MaskVT, V2); SDValue VPerm2MaskOp = getConstVector(VPerm2Idx, IntMaskVT, DAG, DL, true); Res = DAG.getNode(X86ISD::VPERMIL2, DL, MaskVT, V1, V2, VPerm2MaskOp, DAG.getTargetConstant(M2ZImm, DL, MVT::i8)); @@ -35643,7 +35651,7 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, PSHUFBMask.push_back(DAG.getConstant(M, DL, MVT::i8)); } MVT ByteVT = MVT::getVectorVT(MVT::i8, NumBytes); - Res = DAG.getBitcast(ByteVT, V1); + Res = CanonicalizeShuffleInput(ByteVT, V1); SDValue PSHUFBMaskOp = DAG.getBuildVector(ByteVT, DL, PSHUFBMask); Res = DAG.getNode(X86ISD::PSHUFB, DL, ByteVT, Res, PSHUFBMaskOp); return DAG.getBitcast(RootVT, Res); @@ -35673,8 +35681,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, VPPERMMask.push_back(DAG.getConstant(M, DL, MVT::i8)); } MVT ByteVT = MVT::v16i8; - V1 = DAG.getBitcast(ByteVT, V1); - V2 = DAG.getBitcast(ByteVT, V2); + V1 = CanonicalizeShuffleInput(ByteVT, V1); + V2 = CanonicalizeShuffleInput(ByteVT, V2); SDValue VPPERMMaskOp = DAG.getBuildVector(ByteVT, DL, VPPERMMask); Res = DAG.getNode(X86ISD::VPPERM, DL, ByteVT, V1, V2, VPPERMMaskOp); return DAG.getBitcast(RootVT, Res); @@ -35700,8 +35708,8 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, (MaskVT == MVT::v8i16 || MaskVT == MVT::v16i16 || MaskVT == MVT::v32i16)) || (Subtarget.hasVBMI() && AllowBWIVPERMV3 && (MaskVT == MVT::v16i8 || MaskVT == MVT::v32i8 || MaskVT == MVT::v64i8)))) { - V1 = DAG.getBitcast(MaskVT, V1); - V2 = DAG.getBitcast(MaskVT, V2); + V1 = CanonicalizeShuffleInput(MaskVT, V1); + V2 = CanonicalizeShuffleInput(MaskVT, V2); Res = lowerShuffleWithPERMV(DL, MaskVT, Mask, V1, V2, Subtarget, DAG); return DAG.getBitcast(RootVT, Res); } -- 2.7.4