MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0);
- VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
+ VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
UnrankedTensorType *getTensorType(Type *elementType);
public:
Type *getElementType() const { return elementType; }
- /// If this is ranked tensor or vector type, return the rank. If it is an
+ /// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
- int getRankIfPresent() const;
+ int getRank() const;
+
+ /// If this is ranked tensor or vector type, return the shape. If it is an
+ /// unranked tensor, return an empty array.
+ ArrayRef<int> getShape() const;
+
+ /// If any dimension has unknown size (<0), it doesn't have static shape.
+ /// If all dimensions has known size (>= 0), it has static shape.
+ bool hasStaticShape() const {
+ auto dims = getShape();
+ return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
+ }
+
+ /// If this is ranked tensor or vector type, return the size of the specified
+ /// dimension. It aborts if the tensor is unranked (this can be checked by
+ /// the getRank call method).
+ int getDimSize(unsigned i) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
/// known constant shape with one or more dimension.
class VectorType : public VectorOrTensorType {
public:
- static VectorType *get(ArrayRef<unsigned> shape, Type *elementType);
+ static VectorType *get(ArrayRef<int> shape, Type *elementType);
- unsigned getRank() const { return getSubclassData(); }
-
- ArrayRef<unsigned> getShape() const {
- return ArrayRef<unsigned>(shapeElements, getSubclassData());
+ ArrayRef<int> getShape() const {
+ return ArrayRef<int>(shapeElements, getSubclassData());
}
- /// Return the size of the specified dimension.
- unsigned getDimSize(unsigned i) const { return getShape()[i]; }
-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::Vector;
}
private:
- const unsigned *shapeElements;
+ const int *shapeElements;
Type *elementType;
- VectorType(ArrayRef<unsigned> shape, Type *elementType, MLIRContext *context);
+ VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
~VectorType() = delete;
};
static RankedTensorType *get(ArrayRef<int> shape,
Type *elementType);
- unsigned getRank() const { return getSubclassData(); }
-
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
- /// Return the size of the specified dimension, or -1 if unspecified.
- int getDimSize(unsigned i) const { return getShape()[i]; }
-
static bool classof(const Type *type) {
return type->getKind() == Kind::RankedTensor;
}
public:
static UnrankedTensorType *get(Type *elementType);
+ ArrayRef<int> getShape() const { return ArrayRef<int>(); }
+
static bool classof(const Type *type) {
return type->getKind() == Kind::UnrankedTensor;
}
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
}
-VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
- Type *elementType) {
+VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
return VectorType::get(shape, elementType);
}
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
// Vectors are uniqued based on their element type and shape.
- using KeyTy = std::pair<Type *, ArrayRef<unsigned>>;
+ using KeyTy = std::pair<Type *, ArrayRef<int>>;
using DenseMapInfo<VectorType *>::getHashValue;
using DenseMapInfo<VectorType *>::isEqual;
return *existing.first = result;
}
-VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
+VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
assert(!shape.empty() && "vector types must have at least one dimension");
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
"vectors elements must be primitives");
+ assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
+ return i < 0;
+ }) && "vector types must have static shape");
auto *context = elementType->getContext();
auto &impl = context->getImpl();
return emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type.
- auto aggregateRank = aggregateType->getRankIfPresent();
+ auto aggregateRank = aggregateType->getRank();
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
return emitOpError("incorrect number of indices for extract_element");
Type *elementType, unsigned subClassData)
: Type(kind, context, subClassData), elementType(elementType) {}
-/// If this is ranked tensor or vector type, return the rank. If it is an
+/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
-int VectorOrTensorType::getRankIfPresent() const {
+int VectorOrTensorType::getRank() const {
switch (getKind()) {
default:
llvm_unreachable("not a VectorOrTensorType");
case Kind::Vector:
- return cast<VectorType>(this)->getRank();
+ return cast<VectorType>(this)->getShape().size();
case Kind::RankedTensor:
- return cast<RankedTensorType>(this)->getRank();
+ return cast<RankedTensorType>(this)->getShape().size();
case Kind::UnrankedTensor:
return -1;
}
}
-VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
+int VectorOrTensorType::getDimSize(unsigned i) const {
+ switch (getKind()) {
+ case Kind::Vector:
+ return cast<VectorType>(this)->getShape()[i];
+ case Kind::RankedTensor:
+ return cast<RankedTensorType>(this)->getShape()[i];
+ default:
+ llvm_unreachable("not a VectorOrTensorType");
+ }
+}
+
+VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
shapeElements(shape.data()) {}
if (getToken().isNot(Token::integer))
return (emitError("expected dimension size in vector type"), nullptr);
- SmallVector<unsigned, 4> dimensions;
+ SmallVector<int, 4> dimensions;
while (getToken().is(Token::integer)) {
// Make sure this integer value is in bound and valid.
auto dimension = getToken().getUnsignedIntegerValue();
if (!dimension.hasValue())
return (emitError("invalid dimension in vector type"), nullptr);
- dimensions.push_back(dimension.getValue());
+ dimensions.push_back((int)dimension.getValue());
consumeToken(Token::integer);