Support `getShape`, `hasStaticShape` and `getDimSize` methods for all the Vector...
authorFeng Liu <fengliuai@google.com>
Tue, 9 Oct 2018 23:49:39 +0000 (16:49 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:26:38 +0000 (13:26 -0700)
PiperOrigin-RevId: 216447553

mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Types.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/StandardOps.cpp
mlir/lib/IR/Types.cpp
mlir/lib/Parser/Parser.cpp

index 0d56db4b9e75ec43ce2ec88b32a42f8fcbe966d4..2c1b6ddc726dde7d56e2883460f2571b4c28150b 100644 (file)
@@ -86,7 +86,7 @@ public:
   MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
                             ArrayRef<AffineMap> affineMapComposition = {},
                             unsigned memorySpace = 0);
-  VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
+  VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
   RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
   UnrankedTensorType *getTensorType(Type *elementType);
 
index 65c355a9042916480781e642ec68739f04f08af6..ee34e203d1c43d7d91cebab1a8d41b14f46c0374 100644 (file)
@@ -293,9 +293,25 @@ class VectorOrTensorType : public Type {
 public:
   Type *getElementType() const { return elementType; }
 
-  /// If this is ranked tensor or vector type, return the rank.  If it is an
+  /// If this is ranked tensor or vector type, return the rank. If it is an
   /// unranked tensor, return -1.
-  int getRankIfPresent() const;
+  int getRank() const;
+
+  /// If this is ranked tensor or vector type, return the shape. If it is an
+  /// unranked tensor, return an empty array.
+  ArrayRef<int> getShape() const;
+
+  /// If any dimension has unknown size (<0), it doesn't have static shape.
+  /// If all dimensions has known size (>= 0), it has static shape.
+  bool hasStaticShape() const {
+    auto dims = getShape();
+    return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
+  }
+
+  /// If this is ranked tensor or vector type, return the size of the specified
+  /// dimension. It aborts if the tensor is unranked (this can be checked by
+  /// the getRank call method).
+  int getDimSize(unsigned i) const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Type *type) {
@@ -315,27 +331,22 @@ public:
 /// known constant shape with one or more dimension.
 class VectorType : public VectorOrTensorType {
 public:
-  static VectorType *get(ArrayRef<unsigned> shape, Type *elementType);
+  static VectorType *get(ArrayRef<int> shape, Type *elementType);
 
-  unsigned getRank() const { return getSubclassData(); }
-
-  ArrayRef<unsigned> getShape() const {
-    return ArrayRef<unsigned>(shapeElements, getSubclassData());
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
   }
 
-  /// Return the size of the specified dimension.
-  unsigned getDimSize(unsigned i) const { return getShape()[i]; }
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Type *type) {
     return type->getKind() == Kind::Vector;
   }
 
 private:
-  const unsigned *shapeElements;
+  const int *shapeElements;
   Type *elementType;
 
-  VectorType(ArrayRef<unsigned> shape, Type *elementType, MLIRContext *context);
+  VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
   ~VectorType() = delete;
 };
 
@@ -363,15 +374,10 @@ public:
   static RankedTensorType *get(ArrayRef<int> shape,
                                Type *elementType);
 
-  unsigned getRank() const { return getSubclassData(); }
-
   ArrayRef<int> getShape() const {
     return ArrayRef<int>(shapeElements, getSubclassData());
   }
 
-  /// Return the size of the specified dimension, or -1 if unspecified.
-  int getDimSize(unsigned i) const { return getShape()[i]; }
-
   static bool classof(const Type *type) {
     return type->getKind() == Kind::RankedTensor;
   }
@@ -390,6 +396,8 @@ class UnrankedTensorType : public TensorType {
 public:
   static UnrankedTensorType *get(Type *elementType);
 
+  ArrayRef<int> getShape() const { return ArrayRef<int>(); }
+
   static bool classof(const Type *type) {
     return type->getKind() == Kind::UnrankedTensor;
   }
index 449c4dc822f12b89b051ce30675ada191ed66bc3..59299eec2d58113b740142d526d3e1173cd38b00 100644 (file)
@@ -95,8 +95,7 @@ MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
   return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
 }
 
-VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
-                                   Type *elementType) {
+VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
   return VectorType::get(shape, elementType);
 }
 
index fceb0760c22c04bdb9863d7174d9379e0b5af109..7501c7d6b57cc6d2d07e66d517c555637a41c7cf 100644 (file)
@@ -85,7 +85,7 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
 
 struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
   // Vectors are uniqued based on their element type and shape.
-  using KeyTy = std::pair<Type *, ArrayRef<unsigned>>;
+  using KeyTy = std::pair<Type *, ArrayRef<int>>;
   using DenseMapInfo<VectorType *>::getHashValue;
   using DenseMapInfo<VectorType *>::isEqual;
 
@@ -484,10 +484,13 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
   return *existing.first = result;
 }
 
-VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
+VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
   assert(!shape.empty() && "vector types must have at least one dimension");
   assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
          "vectors elements must be primitives");
+  assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
+    return i < 0;
+  }) && "vector types must have static shape");
 
   auto *context = elementType->getContext();
   auto &impl = context->getImpl();
index 9524c3056a707b53e235fea10b436ce711cdad7b..1099dc45ab741a5f0d3472e89a659317d3d05dc1 100644 (file)
@@ -792,7 +792,7 @@ bool ExtractElementOp::verify() const {
       return emitOpError("index to extract_element must have 'index' type");
 
   // Verify the # indices match if we have a ranked type.
-  auto aggregateRank = aggregateType->getRankIfPresent();
+  auto aggregateRank = aggregateType->getRank();
   if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
     return emitOpError("incorrect number of indices for extract_element");
 
index 1ff5de40b864a0a93f608332c8145cb09907a2d8..12880a3b1aef1ae666d060386e1546cd012eb307 100644 (file)
@@ -40,22 +40,33 @@ VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
                                        Type *elementType, unsigned subClassData)
     : Type(kind, context, subClassData), elementType(elementType) {}
 
-/// If this is ranked tensor or vector type, return the rank.  If it is an
+/// If this is ranked tensor or vector type, return the rank. If it is an
 /// unranked tensor, return -1.
-int VectorOrTensorType::getRankIfPresent() const {
+int VectorOrTensorType::getRank() const {
   switch (getKind()) {
   default:
     llvm_unreachable("not a VectorOrTensorType");
   case Kind::Vector:
-    return cast<VectorType>(this)->getRank();
+    return cast<VectorType>(this)->getShape().size();
   case Kind::RankedTensor:
-    return cast<RankedTensorType>(this)->getRank();
+    return cast<RankedTensorType>(this)->getShape().size();
   case Kind::UnrankedTensor:
     return -1;
   }
 }
 
-VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
+int VectorOrTensorType::getDimSize(unsigned i) const {
+  switch (getKind()) {
+  case Kind::Vector:
+    return cast<VectorType>(this)->getShape()[i];
+  case Kind::RankedTensor:
+    return cast<RankedTensorType>(this)->getShape()[i];
+  default:
+    llvm_unreachable("not a VectorOrTensorType");
+  }
+}
+
+VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
                        MLIRContext *context)
     : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
       shapeElements(shape.data()) {}
index 5e0f5f31313e68f50f9533623ecbb2577fb0da6b..727c86e1caa717a47ecd6c0b33b9ba407eebf05d 100644 (file)
@@ -372,13 +372,13 @@ VectorType *Parser::parseVectorType() {
   if (getToken().isNot(Token::integer))
     return (emitError("expected dimension size in vector type"), nullptr);
 
-  SmallVector<unsigned, 4> dimensions;
+  SmallVector<int, 4> dimensions;
   while (getToken().is(Token::integer)) {
     // Make sure this integer value is in bound and valid.
     auto dimension = getToken().getUnsignedIntegerValue();
     if (!dimension.hasValue())
       return (emitError("invalid dimension in vector type"), nullptr);
-    dimensions.push_back(dimension.getValue());
+    dimensions.push_back((int)dimension.getValue());
 
     consumeToken(Token::integer);