static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
SDLoc DL(N);
+ auto *GorS = cast<MaskedGatherScatterSDNode>(N);
+ SDValue Chain = GorS->getChain();
+ SDValue Index = GorS->getIndex();
+ SDValue Mask = GorS->getMask();
+ SDValue Base = GorS->getBasePtr();
+ SDValue Scale = GorS->getScale();
if (DCI.isBeforeLegalizeOps()) {
- SDValue Index = N->getOperand(4);
// Remove any sign extends from 32 or smaller to larger than 32.
// Only do this before LegalizeOps in case we need the sign extend for
// legalization.
- if (Index.getOpcode() == ISD::SIGN_EXTEND) {
- if (Index.getScalarValueSizeInBits() > 32 &&
- Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
- SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
- NewOps[4] = Index.getOperand(0);
- SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
- if (Res == N) {
- // The original sign extend has less users, add back to worklist in
- // case it needs to be removed
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- }
- return SDValue(Res, 0);
- }
+ if (Index.getOpcode() == ISD::SIGN_EXTEND &&
+ Index.getScalarValueSizeInBits() > 32 &&
+ Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
+ Index = Index.getOperand(0);
+ if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+ SDValue Ops[] = { Chain, Gather->getPassThru(),
+ Mask, Base, Index, Scale } ;
+ return DAG.getMaskedGather(Gather->getVTList(),
+ Gather->getMemoryVT(), DL, Ops,
+ Gather->getMemOperand(),
+ Gather->getIndexType());
+ }
+ auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+ SDValue Ops[] = { Chain, Scatter->getValue(),
+ Mask, Base, Index, Scale };
+ return DAG.getMaskedScatter(Scatter->getVTList(),
+ Scatter->getMemoryVT(), DL,
+ Ops, Scatter->getMemOperand(),
+ Scatter->getIndexType());
}
// Make sure the index is either i32 or i64
EVT IndexVT = EVT::getVectorVT(*DAG.getContext(), EltVT,
Index.getValueType().getVectorNumElements());
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
- SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
- NewOps[4] = Index;
- SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
- if (Res == N)
- DCI.AddToWorklist(N);
- return SDValue(Res, 0);
+ if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+ SDValue Ops[] = { Chain, Gather->getPassThru(),
+ Mask, Base, Index, Scale } ;
+ return DAG.getMaskedGather(Gather->getVTList(),
+ Gather->getMemoryVT(), DL, Ops,
+ Gather->getMemOperand(),
+ Gather->getIndexType());
+ }
+ auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+ SDValue Ops[] = { Chain, Scatter->getValue(),
+ Mask, Base, Index, Scale };
+ return DAG.getMaskedScatter(Scatter->getVTList(),
+ Scatter->getMemoryVT(), DL,
+ Ops, Scatter->getMemOperand(),
+ Scatter->getIndexType());
}
// Try to remove zero extends from 32->64 if we know the sign bit of
// the input is zero.
if (Index.getOpcode() == ISD::ZERO_EXTEND &&
Index.getScalarValueSizeInBits() == 64 &&
- Index.getOperand(0).getScalarValueSizeInBits() == 32) {
- if (DAG.SignBitIsZero(Index.getOperand(0))) {
- SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
- NewOps[4] = Index.getOperand(0);
- SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
- if (Res == N) {
- // The original sign extend has less users, add back to worklist in
- // case it needs to be removed
- DCI.AddToWorklist(Index.getNode());
- DCI.AddToWorklist(N);
- }
- return SDValue(Res, 0);
- }
+ Index.getOperand(0).getScalarValueSizeInBits() == 32 &&
+ DAG.SignBitIsZero(Index.getOperand(0))) {
+ Index = Index.getOperand(0);
+ if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
+ SDValue Ops[] = { Chain, Gather->getPassThru(),
+ Mask, Base, Index, Scale } ;
+ return DAG.getMaskedGather(Gather->getVTList(),
+ Gather->getMemoryVT(), DL, Ops,
+ Gather->getMemOperand(),
+ Gather->getIndexType());
+ }
+ auto *Scatter = cast<MaskedScatterSDNode>(GorS);
+ SDValue Ops[] = { Chain, Scatter->getValue(),
+ Mask, Base, Index, Scale };
+ return DAG.getMaskedScatter(Scatter->getVTList(),
+ Scatter->getMemoryVT(), DL,
+ Ops, Scatter->getMemOperand(),
+ Scatter->getIndexType());
}
}
// With vector masks we only demand the upper bit of the mask.
- SDValue Mask = cast<MaskedGatherScatterSDNode>(N)->getMask();
if (Mask.getScalarValueSizeInBits() != 1) {
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));