return SDValue();
}
-/// \brief Returns a vector of 0s if the node in input is a vector logical
-/// shift by a constant amount which is known to be bigger than or equal
-/// to the vector element size in bits.
-static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
- const X86Subtarget &Subtarget) {
- EVT VT = N->getValueType(0);
-
- if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
- (!Subtarget.hasInt256() ||
- (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
- return SDValue();
-
- SDValue Amt = N->getOperand(1);
- SDLoc DL(N);
- if (auto *AmtBV = dyn_cast<BuildVectorSDNode>(Amt))
- if (auto *AmtSplat = AmtBV->getConstantSplatNode()) {
- const APInt &ShiftAmt = AmtSplat->getAPIntValue();
- unsigned MaxAmount =
- VT.getSimpleVT().getScalarSizeInBits();
-
- // SSE2/AVX2 logical shifts always return a vector of 0s
- // if the shift amount is bigger than or equal to
- // the element size. The constant shift amount will be
- // encoded as a 8-bit immediate.
- if (ShiftAmt.trunc(8).uge(MaxAmount))
- return getZeroVector(VT.getSimpleVT(), Subtarget, DAG, DL);
- }
-
- return SDValue();
-}
-
static SDValue combineShift(SDNode* N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
if (SDValue V = combineShiftRightLogical(N, DAG))
return V;
- // Try to fold this logical shift into a zero vector.
- if (N->getOpcode() != ISD::SRA)
- if (SDValue V = performShiftToAllZeros(N, DAG, Subtarget))
- return V;
-
return SDValue();
}