Some cleanup of ShapedType now that MemRef subclasses it.
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 31 May 2019 20:28:19 +0000 (13:28 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:12:51 +0000 (20:12 -0700)
    Extract common methods into ShapedType.
    Simplify methods.
    Remove some extraneous asserts.
    Replace sentinel value with a helper method to check the same.

--

PiperOrigin-RevId: 250945261

mlir/include/mlir/IR/StandardTypes.h
mlir/lib/Analysis/Utils.cpp
mlir/lib/IR/StandardTypes.cpp

index 7b794c9..713fc6c 100644 (file)
@@ -214,6 +214,10 @@ public:
   /// has static shape.
   bool hasStaticShape() const;
 
+  /// If this is a ranked type, return the number of dimensions with dynamic
+  /// size. Otherwise, abort.
+  unsigned getNumDynamicDims() const;
+
   /// If this is ranked type, return the size of the specified dimension.
   /// Otherwise, abort.
   int64_t getDimSize(unsigned i) const;
@@ -233,6 +237,9 @@ public:
            type.getKind() == StandardTypes::UnrankedTensor ||
            type.getKind() == StandardTypes::MemRef;
   }
+
+  /// Whether the given dimension size indicates a dynamic dimension.
+  static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
 };
 
 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
@@ -402,21 +409,8 @@ public:
   /// Returns the memory space in which data referred to by this memref resides.
   unsigned getMemorySpace() const;
 
-  // TODO(b/132735995) Extract into shaped type.
-  /// Returns the number of dimensions with dynamic size.
-  unsigned getNumDynamicDims() const;
-
-  // TODO(b/132735995) Extract into shaped type.
-  /// If any dimension of the shape has unknown size (<0), it doesn't have
-  /// static shape.
-  bool hasStaticShape() const { return getNumDynamicDims() == 0; }
-
   static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
 
-  // TODO(b/132735995) Extract into shaped type.
-  /// Integer value indicating that the size in a dimension is dynamic.
-  static constexpr int64_t kDynamicDimSize = -1;
-
 private:
   /// Get or create a new MemRefType defined by the arguments.  If the resulting
   /// type would be ill-formed, return nullptr.  If the location is provided,
index 59ab8f7..476c7c8 100644 (file)
@@ -308,7 +308,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
     for (unsigned r = 0; r < rank; r++) {
       cst.addConstantLowerBound(r, 0);
       int64_t dimSize = memRefType.getDimSize(r);
-      if (dimSize == MemRefType::kDynamicDimSize)
+      if (ShapedType::isDynamic(dimSize))
         continue;
       cst.addConstantUpperBound(r, dimSize - 1);
     }
index 80170a1..c8e7e69 100644 (file)
@@ -116,7 +116,7 @@ unsigned ShapedType::getElementTypeBitWidth() const {
 }
 
 unsigned ShapedType::getNumElements() const {
-  assert(hasStaticShape() && "expected type to have static shape");
+  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
   auto shape = getShape();
   unsigned num = 1;
   for (auto dim : shape)
@@ -124,18 +124,11 @@ unsigned ShapedType::getNumElements() const {
   return num;
 }
 
-int64_t ShapedType::getRank() const {
-  assert(hasRank());
-  return getShape().size();
-}
+int64_t ShapedType::getRank() const { return getShape().size(); }
 
 bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
 
-int64_t ShapedType::getDimSize(unsigned i) const {
-  if (hasRank())
-    return getShape()[i];
-  llvm_unreachable("not a ShapedType or not ranked");
-}
+int64_t ShapedType::getDimSize(unsigned i) const { return getShape()[i]; }
 
 /// Get the number of bits require to store a value of the given shaped type.
 /// Compute the value recursively since tensors are allowed to have vectors as
@@ -169,10 +162,12 @@ ArrayRef<int64_t> ShapedType::getShape() const {
   }
 }
 
+unsigned ShapedType::getNumDynamicDims() const {
+  return llvm::count_if(getShape(), isDynamic);
+}
+
 bool ShapedType::hasStaticShape() const {
-  if (!hasRank())
-    return false;
-  return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
+  return hasRank() && llvm::none_of(getShape(), isDynamic);
 }
 
 //===----------------------------------------------------------------------===//
@@ -291,9 +286,6 @@ LogicalResult UnrankedTensorType::verifyConstructionInvariants(
 // MemRefType
 //===----------------------------------------------------------------------===//
 
-// static constexpr must have a definition (until in C++17 and inline variable).
-constexpr int64_t MemRefType::kDynamicDimSize;
-
 /// Get or create a new MemRefType defined by the arguments.  If the resulting
 /// type would be ill-formed, return nullptr.  If the location is provided,
 /// emit detailed error messages.  To emit errors when the location is unknown,
@@ -355,10 +347,6 @@ ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
 
 unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
 
-unsigned MemRefType::getNumDynamicDims() const {
-  return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
-}
-
 //===----------------------------------------------------------------------===//
 /// ComplexType
 //===----------------------------------------------------------------------===//