if (!PredVT.isScalableVector() || PredVT.getVectorElementType() != MVT::i1)
return EVT();
- const unsigned NumElts = PredVT.getVectorNumElements();
-
- if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
+ if (PredVT != MVT::nxv16i1 && PredVT != MVT::nxv8i1 &&
+ PredVT != MVT::nxv4i1 && PredVT != MVT::nxv2i1)
return EVT();
- EVT ScalarVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / NumElts);
- EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, NumElts, /*IsScalable=*/true);
+ ElementCount EC = PredVT.getVectorElementCount();
+ EVT ScalarVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.Min);
+ EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, EC);
return MemVT;
}