From 9304168103bd82f3fd838d7df36fbf23d32ec419 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Tue, 2 Jul 2019 13:30:04 +0000 Subject: [PATCH] [X86][AVX] combineX86ShuffleChain - pull out CombineShuffleWithExtract lambda. NFCI. Pull out CombineShuffleWithExtract lambda to new combineX86ShuffleChainWithExtract wrapper and refactor it to handle more than 2 shuffle inputs - this will allow combineX86ShufflesRecursively to call it in a future patch. llvm-svn: 364924 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 221 +++++++++++++++++--------------- 1 file changed, 116 insertions(+), 105 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index b6e3106..75fdd0c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -31810,6 +31810,11 @@ static bool matchBinaryPermuteShuffle( return false; } +static SDValue combineX86ShuffleChainWithExtract( + ArrayRef Inputs, SDValue Root, ArrayRef BaseMask, int Depth, + bool HasVariableMask, bool AllowVariableMask, SelectionDAG &DAG, + const X86Subtarget &Subtarget); + /// Combine an arbitrary chain of shuffles into a single instruction if /// possible. 
/// @@ -32083,87 +32088,6 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, bool MaskContainsZeros = any_of(Mask, [](int M) { return M == SM_SentinelZero; }); - // Unwrap shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1) -> - // shuffle(x,y,m2) - auto CombineShuffleWithExtract = - [&](SDValue &NewRoot, SmallVectorImpl &NewMask, - SmallVectorImpl &NewInputs) -> bool { - assert(NewMask.empty() && NewInputs.empty() && "Non-empty shuffle mask"); - if (UnaryShuffle) - return false; - - SDValue Src1 = V1, Src2 = V2; - unsigned Offset1 = 0, Offset2 = 0; - while (Src1.getOpcode() == ISD::EXTRACT_SUBVECTOR && - isa(Src1.getOperand(1))) { - Offset1 += Src1.getConstantOperandVal(1); - Src1 = Src1.getOperand(0); - } - while (Src2.getOpcode() == ISD::EXTRACT_SUBVECTOR && - isa(Src2.getOperand(1))) { - Offset2 += Src2.getConstantOperandVal(1); - Src2 = Src2.getOperand(0); - } - if (Offset1 == 0 && Offset2 == 0) - return false; - - // If the src vector types aren't the same, see if we can extend - // one to match the other. 
- if ((Src1.getValueType().getScalarType() != - Src2.getValueType().getScalarType()) || - !DAG.getTargetLoweringInfo().isTypeLegal(Src1.getValueType()) || - !DAG.getTargetLoweringInfo().isTypeLegal(Src2.getValueType())) - return false; - - unsigned Src1SizeInBits = Src1.getValueSizeInBits(); - unsigned Src2SizeInBits = Src2.getValueSizeInBits(); - assert(((Src1SizeInBits % Src2SizeInBits) == 0 || - (Src2SizeInBits % Src1SizeInBits) == 0) && - "Shuffle vector size mismatch"); - if (Src1SizeInBits != Src2SizeInBits) { - if (Src1SizeInBits > Src2SizeInBits) { - Src2 = widenSubVector(Src2, false, Subtarget, DAG, DL, Src1SizeInBits); - Src2SizeInBits = Src1SizeInBits; - } else { - Src1 = widenSubVector(Src1, false, Subtarget, DAG, DL, Src2SizeInBits); - Src1SizeInBits = Src2SizeInBits; - } - } - - assert(((Offset1 % VT1.getVectorNumElements()) == 0 && - (Offset2 % VT2.getVectorNumElements()) == 0 && - (Src1SizeInBits % RootSizeInBits) == 0 && - Src1SizeInBits == Src2SizeInBits) && - "Unexpected subvector extraction"); - unsigned Scale = Src1SizeInBits / RootSizeInBits; - - // Convert extraction indices to mask size. - Offset1 /= VT1.getVectorNumElements(); - Offset2 /= VT2.getVectorNumElements(); - Offset1 *= NumMaskElts; - Offset2 *= NumMaskElts; - - NewInputs.push_back(Src1); - if (Src1 != Src2) { - NewInputs.push_back(Src2); - Offset2 += Scale * NumMaskElts; - } - - // Create new mask for larger type. - NewMask.append(Mask.begin(), Mask.end()); - for (int &M : NewMask) { - if (M < 0) - continue; - if (M < (int)NumMaskElts) - M += Offset1; - else - M = (M - NumMaskElts) + Offset2; - } - NewMask.append((Scale - 1) * NumMaskElts, SM_SentinelUndef); - NewRoot = Src1; - return true; - }; - if (is128BitLaneCrossingShuffleMask(MaskVT, Mask)) { // If we have a single input lane-crossing shuffle then lower to VPERMV. 
if (UnaryShuffle && AllowVariableMask && !MaskContainsZeros && @@ -32209,18 +32133,10 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, // If that failed and either input is extracted then try to combine as a // shuffle with the larger type. - SDValue NewRoot; - SmallVector NewMask; - SmallVector NewInputs; - if (CombineShuffleWithExtract(NewRoot, NewMask, NewInputs)) { - if (SDValue Res = combineX86ShuffleChain( - NewInputs, NewRoot, NewMask, Depth, HasVariableMask, - AllowVariableMask, DAG, Subtarget)) { - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT1, Res, - DAG.getIntPtrConstant(0, DL)); - return DAG.getBitcast(RootVT, Res); - } - } + if (SDValue WideShuffle = combineX86ShuffleChainWithExtract( + Inputs, Root, BaseMask, Depth, HasVariableMask, AllowVariableMask, + DAG, Subtarget)) + return WideShuffle; // If we have a dual input lane-crossing shuffle then lower to VPERMV3. if (AllowVariableMask && !MaskContainsZeros && @@ -32389,18 +32305,10 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, // If that failed and either input is extracted then try to combine as a // shuffle with the larger type. - SDValue NewRoot; - SmallVector NewMask; - SmallVector NewInputs; - if (CombineShuffleWithExtract(NewRoot, NewMask, NewInputs)) { - if (SDValue Res = combineX86ShuffleChain(NewInputs, NewRoot, NewMask, Depth, - HasVariableMask, AllowVariableMask, - DAG, Subtarget)) { - Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT1, Res, - DAG.getIntPtrConstant(0, DL)); - return DAG.getBitcast(RootVT, Res); - } - } + if (SDValue WideShuffle = combineX86ShuffleChainWithExtract( + Inputs, Root, BaseMask, Depth, HasVariableMask, AllowVariableMask, + DAG, Subtarget)) + return WideShuffle; // If we have a dual input shuffle then lower to VPERMV3. 
if (!UnaryShuffle && AllowVariableMask && !MaskContainsZeros && @@ -32428,6 +32336,109 @@ static SDValue combineX86ShuffleChain(ArrayRef Inputs, SDValue Root, return SDValue(); } +// Combine an arbitrary chain of shuffles + extract_subvectors into a single +// instruction if possible. +// +// Wrapper for combineX86ShuffleChain that extends the shuffle mask to a larger +// type size to attempt to combine: +// shuffle(extract_subvector(x,c1),extract_subvector(y,c2),m1) +// --> +// extract_subvector(shuffle(x,y,m2),0) +static SDValue combineX86ShuffleChainWithExtract( + ArrayRef Inputs, SDValue Root, ArrayRef BaseMask, int Depth, + bool HasVariableMask, bool AllowVariableMask, SelectionDAG &DAG, + const X86Subtarget &Subtarget) { + unsigned NumMaskElts = BaseMask.size(); + unsigned NumInputs = Inputs.size(); + if (NumInputs == 0) + return SDValue(); + + SmallVector WideInputs(Inputs.begin(), Inputs.end()); + SmallVector Offsets(NumInputs, 0); + + // Peek through subvectors. + // TODO: Support inter-mixed EXTRACT_SUBVECTORs + BITCASTs? + unsigned WideSizeInBits = WideInputs[0].getValueSizeInBits(); + for (unsigned i = 0; i != NumInputs; ++i) { + SDValue &Src = WideInputs[i]; + unsigned &Offset = Offsets[i]; + Src = peekThroughBitcasts(Src); + EVT BaseVT = Src.getValueType(); + while (Src.getOpcode() == ISD::EXTRACT_SUBVECTOR && + isa(Src.getOperand(1))) { + Offset += Src.getConstantOperandVal(1); + Src = Src.getOperand(0); + } + WideSizeInBits = std::max(WideSizeInBits, Src.getValueSizeInBits()); + assert((Offset % BaseVT.getVectorNumElements()) == 0 && + "Unexpected subvector extraction"); + Offset /= BaseVT.getVectorNumElements(); + Offset *= NumMaskElts; + } + + // Bail if we're always extracting from the lowest subvectors, + // combineX86ShuffleChain should match this for the current width. 
+ if (llvm::all_of(Offsets, [](unsigned Offset) { return Offset == 0; })) + return SDValue(); + + EVT RootVT = Root.getValueType(); + unsigned RootSizeInBits = RootVT.getSizeInBits(); + unsigned Scale = WideSizeInBits / RootSizeInBits; + assert((WideSizeInBits % RootSizeInBits) == 0 && + "Unexpected subvector extraction"); + + // If the src vector types aren't the same, see if we can extend + // them to match each other. + // TODO: Support different scalar types? + EVT WideSVT = WideInputs[0].getValueType().getScalarType(); + if (llvm::any_of(WideInputs, [&WideSVT, &DAG](SDValue Op) { + return !DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType()) || + Op.getValueType().getScalarType() != WideSVT; + })) + return SDValue(); + + for (SDValue &NewInput : WideInputs) { + assert((WideSizeInBits % NewInput.getValueSizeInBits()) == 0 && + "Shuffle vector size mismatch"); + if (WideSizeInBits > NewInput.getValueSizeInBits()) + NewInput = widenSubVector(NewInput, false, Subtarget, DAG, + SDLoc(NewInput), WideSizeInBits); + assert(WideSizeInBits == NewInput.getValueSizeInBits() && + "Unexpected subvector extraction"); + } + + // Create new mask for larger type. + for (unsigned i = 1; i != NumInputs; ++i) + Offsets[i] += i * Scale * NumMaskElts; + + SmallVector WideMask(BaseMask.begin(), BaseMask.end()); + for (int &M : WideMask) { + if (M < 0) + continue; + M = (M % NumMaskElts) + Offsets[M / NumMaskElts]; + } + WideMask.append((Scale - 1) * NumMaskElts, SM_SentinelUndef); + + // Remove unused/repeated shuffle source ops. + resolveTargetShuffleInputsAndMask(WideInputs, WideMask); + assert(!WideInputs.empty() && "Shuffle with no inputs detected"); + + if (WideInputs.size() > 2) + return SDValue(); + + // Attempt to combine wider chain. + // TODO: Can we use a better Root? 
+ SDValue WideRoot = WideInputs[0]; + if (SDValue WideShuffle = combineX86ShuffleChain( + WideInputs, WideRoot, WideMask, Depth, HasVariableMask, + AllowVariableMask, DAG, Subtarget)) { + WideShuffle = + extractSubVector(WideShuffle, 0, DAG, SDLoc(Root), RootSizeInBits); + return DAG.getBitcast(RootVT, WideShuffle); + } + return SDValue(); +} + // Attempt to constant fold all of the constant source ops. // Returns true if the entire shuffle is folded to a constant. // TODO: Extend this to merge multiple constant Ops and update the mask. -- 2.7.4