Limit the number of places where shaped type has to explicitly reference its...
authorGeoffrey Martin-Noble <gcmn@google.com>
Sat, 25 May 2019 00:44:21 +0000 (17:44 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:01:22 +0000 (20:01 -0700)
    Introduces a hasRank() method to make checking for rank a bit easier.

    This is partially to make it easier to make MemRef subclass ShapedType

--

PiperOrigin-RevId: 249927442

mlir/include/mlir/IR/StandardTypes.h
mlir/lib/IR/StandardTypes.cpp

index de11f87..867fb8d 100644 (file)
@@ -192,38 +192,38 @@ public:
   /// Return the element type.
   Type getElementType() const;
 
-  /// If an element type is an integer or a float, return its width.  Abort
-  /// otherwise.
+  /// If an element type is an integer or a float, return its width. Otherwise,
+  /// abort.
   unsigned getElementTypeBitWidth() const;
 
-  /// If this is ranked tensor or vector type, return the number of elements. If
-  /// it is an unranked tensor, abort.
+  /// If this is a ranked type, return the number of elements. Otherwise, abort.
   unsigned getNumElements() const;
 
-  /// If this is ranked tensor or vector type, return the rank. If it is an
-  /// unranked tensor, return -1.
+  /// If this is a ranked type, return the rank. Otherwise, return -1.
   int64_t getRank() const;
 
-  /// If this is ranked tensor or vector type, return the shape. If it is an
-  /// unranked tensor, abort.
+  /// Whether or not this is a ranked type. Vector and ranked tensors have a
+  /// rank, while unranked tensors do not.
+  bool hasRank() const;
+
+  /// If this is a ranked type, return the shape. Otherwise, abort.
   ArrayRef<int64_t> getShape() const;
 
-  /// If this is unranked tensor or any dimension has unknown size (<0),
-  /// it doesn't have static shape. If all dimensions have known size (>= 0),
-  /// it has static shape.
+  /// If this is unranked type or any dimension has unknown size (<0), it
+  /// doesn't have static shape. If all dimensions have known size (>= 0), it
+  /// has static shape.
   bool hasStaticShape() const;
 
-  /// If this is ranked tensor or vector type, return the size of the specified
-  /// dimension. It aborts if the tensor is unranked (this can be checked by
-  /// the getRank call method).
+  /// If this is ranked type, return the size of the specified dimension.
+  /// Otherwise, abort.
   int64_t getDimSize(unsigned i) const;
 
   /// Get the total amount of bits occupied by a value of this type.  This does
   /// not take into account any memory layout or widening constraints, e.g. a
   /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
   /// it will likely be stored as in a 4xi64 vector register.  Fail an assertion
-  /// if the size cannot be computed statically, i.e. if the tensor has a
-  /// dynamic shape or if its elemental type does not have a known bit width.
+  /// if the size cannot be computed statically, i.e. if the type has a dynamic
+  /// shape or if its elemental type does not have a known bit width.
   int64_t getSizeInBits() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
index 1958a43..03bce65 100644 (file)
@@ -116,48 +116,29 @@ unsigned ShapedType::getElementTypeBitWidth() const {
 }
 
 unsigned ShapedType::getNumElements() const {
-  switch (getKind()) {
-  case StandardTypes::Vector:
-  case StandardTypes::RankedTensor: {
-    assert(hasStaticShape() && "expected type to have static shape");
-    auto shape = getShape();
-    unsigned num = 1;
-    for (auto dim : shape)
-      num *= dim;
-    return num;
-  }
-  default:
-    llvm_unreachable("not a ShapedType or not ranked");
-  }
+  assert(hasStaticShape() && "expected type to have static shape");
+  auto shape = getShape();
+  unsigned num = 1;
+  for (auto dim : shape)
+    num *= dim;
+  return num;
 }
 
-/// If this is ranked tensor or vector type, return the rank. If it is an
-/// unranked tensor, return -1.
 int64_t ShapedType::getRank() const {
-  switch (getKind()) {
-  case StandardTypes::Vector:
-  case StandardTypes::RankedTensor:
-    return getShape().size();
-  case StandardTypes::UnrankedTensor:
-    return -1;
-  default:
-    llvm_unreachable("not a ShapedType");
-  }
+  return hasRank() ? getShape().size() : -1;
 }
 
+bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
+
 int64_t ShapedType::getDimSize(unsigned i) const {
-  switch (getKind()) {
-  case StandardTypes::Vector:
-  case StandardTypes::RankedTensor:
+  if (hasRank())
     return getShape()[i];
-  default:
-    llvm_unreachable("not a ShapedType or not ranked");
-  }
+  llvm_unreachable("not a ShapedType or not ranked");
 }
 
-// Get the number of number of bits require to store a value of the given vector
-// or tensor types.  Compute the value recursively since tensors are allowed to
-// have vectors as elements.
+/// Get the number of bits require to store a value of the given shaped type.
+/// Compute the value recursively since tensors are allowed to have vectors as
+/// elements.
 int64_t ShapedType::getSizeInBits() const {
   assert(hasStaticShape() &&
          "cannot get the bit size of an aggregate with a dynamic shape");
@@ -185,7 +166,7 @@ ArrayRef<int64_t> ShapedType::getShape() const {
 }
 
 bool ShapedType::hasStaticShape() const {
-  if (isa<UnrankedTensorType>())
+  if (!hasRank())
     return false;
   return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
 }