Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
- let printerPredicate = "!$_val.getNumScalableDims()";
+ let printerPredicate = "!$_val.isScalable()";
}
def VectorTypeWithScalableDims : DialectType<(type
Array<BoolList>:$scalableDims,
- VarInt:$numScalableDims,
Array<SignedVarIntList>:$shape,
Type:$elementType
)> {
- let printerPredicate = "$_val.getNumScalableDims()";
+ let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
}
}
/// Build from another VectorType.
explicit Builder(VectorType other)
: shape(other.getShape()), elementType(other.getElementType()),
- numScalableDims(other.getNumScalableDims()),
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
- : shape(shape), elementType(elementType),
- numScalableDims(numScalableDims) {
+ : shape(shape), elementType(elementType) {
if (scalableDims.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
this->scalableDims = scalableDims;
}
- Builder &setShape(ArrayRef<int64_t> newShape, unsigned newNumScalableDims = 0,
+ Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
- numScalableDims = newNumScalableDims;
if (newIsScalableDim.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
- if (pos >= shape.size() - numScalableDims)
- numScalableDims--;
if (storage.empty())
storage.append(shape.begin(), shape.end());
if (storageScalableDims.empty())
operator Type() {
if (shape.empty())
return elementType;
- return VectorType::get(shape, elementType, numScalableDims, scalableDims);
+ return VectorType::get(shape, elementType, scalableDims);
}
private:
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
Type elementType;
- unsigned numScalableDims;
ArrayRef<bool> scalableDims;
// Owning scalableDims data for copy-on-write operations.
SmallVector<bool> storageScalableDims;
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
- "unsigned":$numScalableDims,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
- CArg<"unsigned", "0">:$numScalableDims,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
- return $_get(elementType.getContext(), shape, elementType,
- numScalableDims, scalableDims);
+ return $_get(elementType.getContext(), shape, elementType, scalableDims);
}]>
];
let extraClassDeclaration = [{
/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
- return getNumScalableDims() > 0;
+ return llvm::is_contained(getScalableDims(), true);
+ }
+ bool allDimsScalable() const {
+ // Treat 0-d vectors as fixed size.
+ if (getRank() == 0)
+ return false;
+ return !llvm::is_contained(getScalableDims(), false);
}
/// Get or create a new VectorType with the same shape as `this` and an
/// Parse a vector type.
VectorType parseVectorType();
ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
- unsigned &numScalableDims,
SmallVectorImpl<bool> &scalableDims);
ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
bool allowDynamic = true,
SmallVector<int64_t, 4> dimensions;
SmallVector<bool, 4> scalableDims;
- unsigned numScalableDims;
- if (parseVectorDimensionList(dimensions, numScalableDims, scalableDims))
+ if (parseVectorDimensionList(dimensions, scalableDims))
return nullptr;
if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
return emitError(getToken().getLoc(),
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;
- return VectorType::get(dimensions, elementType, numScalableDims,
- scalableDims);
+ return VectorType::get(dimensions, elementType, scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
/// For i-th dimension, `scalableDims[i]` contains either:
/// * `false` for a non-scalable dimension (e.g. `4`),
/// * `true` for a scalable dimension (e.g. `[4]`).
-/// This method also returns the number of scalable dimensions in
-/// `numScalableDims`.
///
/// vector-dim-list := (static-dim-list `x`)?
/// static-dim-list ::= static-dim (`x` static-dim)*
///
ParseResult
Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
- unsigned &numScalableDims,
SmallVectorImpl<bool> &scalableDims) {
- numScalableDims = 0;
// If there is a set of fixed-length dimensions, consume it
while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
int64_t value;
if (scalable) {
if (!consumeIf(Token::r_square))
return emitWrongTokenError("missing ']' closing scalable dimension");
- numScalableDims++;
}
scalableDims.push_back(scalable);
// Make sure we have an 'x' or something like 'xbf32'.
return {};
if (type.getShape().empty())
return VectorType::get({1}, elementType);
- Type vectorType =
- VectorType::get(type.getShape().back(), elementType,
- type.getNumScalableDims(), type.getScalableDims().back());
+ Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
+ assert((type.isScalable() == type.allDimsScalable()) &&
+ "expected scalable vector with all dims scalable");
auto shape = type.getShape();
for (int i = shape.size() - 2; i >= 0; --i)
vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]);
// Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- unsigned numScalableDims = tp.getNumScalableDims();
- if (tp.getShape().size() == numScalableDims)
- --numScalableDims;
return VectorType::get(tp.getShape().drop_front(), tp.getElementType(),
- numScalableDims);
+ tp.getScalableDims().drop_front());
}
// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
- unsigned numScalableDims = tp.getNumScalableDims();
- if (numScalableDims > 0)
- --numScalableDims;
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
- numScalableDims);
+ tp.getScalableDims().take_back());
}
// Helper that picks the proper sequence for inserting.
return UnrankedTensorType::get(i1Type);
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(vectorType.getShape(), i1Type,
- vectorType.getNumScalableDims(),
vectorType.getScalableDims());
return i1Type;
}
auto i1Type = IntegerType::get(type.getContext(), 1);
if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
return VectorType::get(sVectorType.getShape(), i1Type,
- sVectorType.getNumScalableDims(),
sVectorType.getScalableDims());
return nullptr;
}
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
// scalable/non-scalable.
- SmallVector<bool> scalableDims(1, isScalable);
-
- return VectorType::get(numElements, elementType,
- static_cast<unsigned>(isScalable), scalableDims);
+ return VectorType::get(numElements, elementType, {isScalable});
}
Type mlir::LLVM::getVectorType(Type elementType,
"type");
if (useLLVM)
return LLVMScalableVectorType::get(elementType, numElements);
- return VectorType::get(numElements, elementType, /*numScalableDims=*/1);
+
+ // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
+ // scalable/non-scalable.
+ return VectorType::get(numElements, elementType, /*scalableDims=*/true);
}
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(areValidScalableVecDims(scalableDims) &&
"Permuted scalable vector dimensions are not supported");
- // TODO: Extend scalable vector type to support a bit map.
- bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
- return VectorType::get(vectorShape, elementType, numScalableDims,
- scalableDims);
+ return VectorType::get(vectorShape, elementType, scalableDims);
}
/// Masks an operation with the canonical vector mask if the operation needs
if (firstMaxRankedType) {
auto vecType = VectorType::get(firstMaxRankedType.getShape(),
getElementTypeOrSelf(vecOperand.getType()),
- firstMaxRankedType.getNumScalableDims(),
firstMaxRankedType.getScalableDims());
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
} else {
resultTypes.push_back(
firstMaxRankedType
? VectorType::get(firstMaxRankedType.getShape(), resultType,
- firstMaxRankedType.getNumScalableDims(),
firstMaxRankedType.getScalableDims())
: resultType);
}
/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
- unsigned numScalableDims = vl.enableVLAVectorization;
- return VectorType::get(vl.vectorLength, etp, numScalableDims,
- vl.enableVLAVectorization);
+ return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}
/// Constructs vector type from a memref value.
// Inspect source type. For vector types, apply the same
// vectorization to the destination type.
if (auto vtp = dyn_cast<VectorType>(src.getType()))
- return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
+ return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
return dtp;
}
/// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
namespace {
/// Returns the mask type expected by this operation.
Type ReductionOp::getExpectedMaskType() {
auto vecType = getSourceVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
assert(!ShapedType::isDynamicShape(maskShape) &&
"Mask shape couldn't be computed");
// TODO: Extend the scalable vector type representation with a bit map.
- assert(lhsType.getNumScalableDims() == 0 &&
- rhsType.getNumScalableDims() == 0 &&
+ assert(!lhsType.isScalable() && !rhsType.isScalable() &&
"Scalable vectors are not supported yet");
return VectorType::get(maskShape,
if (vRHS) {
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
vRHS.getScalableDims()[0]};
- auto numScalableDims =
- count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
- vLHS.getElementType(), numScalableDims,
- scalableDimsRes);
+ vLHS.getElementType(), scalableDimsRes);
} else {
// Scalar RHS operand
SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
- auto numScalableDims =
- count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
- numScalableDims, scalableDimsRes);
+ scalableDimsRes);
}
if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
/// verification purposes. It requires the operation to be vectorized."
Type OuterProductOp::getExpectedMaskType() {
auto vecType = this->getResultVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
//===----------------------------------------------------------------------===//
SmallVector<bool> scalableDims =
applyPermutationMap(invPermMap, vecType.getScalableDims());
- return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims(),
- scalableDims);
+ return VectorType::get(maskShape, i1Type, scalableDims);
}
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
/// verification purposes. It requires the operation to be vectorized."
Type GatherOp::getExpectedMaskType() {
auto vecType = this->getIndexVectorType();
- return VectorType::get(
- vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
- vecType.getNumScalableDims(), vecType.getScalableDims());
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1),
+ vecType.getScalableDims());
}
std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
Value mask = rewriter.create<vector::CreateMaskOp>(
loc,
VectorType::get(vtp.getShape(), rewriter.getI1Type(),
- vtp.getNumScalableDims()),
+ vtp.getScalableDims()),
b);
if (xferOp.getMask()) {
// Intersect the in-bounds with the mask specified as an op parameter.
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
- unsigned numScalableDims,
ArrayRef<bool> scalableDims) {
if (!isValidElementType(elementType))
return emitError()
<< "vector types must have positive constant sizes but got "
<< shape;
- if (numScalableDims > shape.size())
- return emitError()
- << "number of scalable dims cannot exceed the number of dims"
- << " (" << numScalableDims << " vs " << shape.size() << ")";
-
if (scalableDims.size() != shape.size())
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();
- auto numScale =
- count_if(scalableDims, [](bool isScalable) { return isScalable; });
- if (numScale != numScalableDims)
- return emitError() << "number of scalable dims must match, explicit: "
- << numScalableDims << ", and bools:" << numScale;
-
return success();
}
return VectorType();
if (auto et = llvm::dyn_cast<IntegerType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
- return VectorType::get(getShape(), scaledEt, getNumScalableDims());
+ return VectorType::get(getShape(), scaledEt, getScalableDims());
if (auto et = llvm::dyn_cast<FloatType>(getElementType()))
if (auto scaledEt = et.scaleElementBitwidth(scale))
- return VectorType::get(getShape(), scaledEt, getNumScalableDims());
+ return VectorType::get(getShape(), scaledEt, getScalableDims());
return VectorType();
}
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
- getNumScalableDims());
+ getScalableDims());
}
//===----------------------------------------------------------------------===//