Make MemRefType subclass ShapedType
authorGeoffrey Martin-Noble <gcmn@google.com>
Fri, 31 May 2019 20:28:01 +0000 (13:28 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:12:40 +0000 (20:12 -0700)
    MemRefs have the same notion of shape, rank, and fixed element type. This allows us to reuse utilities based on shape for memref.

    All dyn_cast and isa calls for ShapedType have been checked and either modified to explicitly check for vector or tensor, or confirmed to not depend on the result being a vector or tensor.

    Discussion in https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/cHLoyfGu8y8

--

PiperOrigin-RevId: 250945184

mlir/include/mlir/IR/StandardTypes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/StandardTypes.cpp
mlir/lib/IR/TypeDetail.h

index ea88768..7b794c9 100644 (file)
@@ -180,10 +180,10 @@ public:
   const llvm::fltSemantics &getFloatSemantics();
 };
 
-// TODO(b/132735995) Add support for MemRef
-/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
-/// types because they share behavior and semantics around shape, rank, and
-/// fixed element type.
+/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
+/// and MemRef types because they share behavior and semantics around shape,
+/// rank, and fixed element type. Any type with these semantics should inherit
+/// from ShapedType.
 class ShapedType : public Type {
 public:
   using ImplType = detail::ShapedTypeStorage;
@@ -202,8 +202,8 @@ public:
   /// If this is a ranked type, return the rank. Otherwise, abort.
   int64_t getRank() const;
 
-  /// Whether or not this is a ranked type. Vector and ranked tensors have a
-  /// rank, while unranked tensors do not.
+  /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
+  /// have a rank, while unranked tensors do not.
   bool hasRank() const;
 
   /// If this is a ranked type, return the shape. Otherwise, abort.
@@ -230,7 +230,8 @@ public:
   static bool classof(Type type) {
     return type.getKind() == StandardTypes::Vector ||
            type.getKind() == StandardTypes::RankedTensor ||
-           type.getKind() == StandardTypes::UnrankedTensor;
+           type.getKind() == StandardTypes::UnrankedTensor ||
+           type.getKind() == StandardTypes::MemRef;
   }
 };
 
@@ -359,7 +360,7 @@ public:
 /// unknown (represented by any negative integer). MemRef types also have an
 /// affine map composition, represented as an array AffineMap pointers.
 class MemRefType
-    : public Type::TypeBase<MemRefType, Type, detail::MemRefTypeStorage> {
+    : public Type::TypeBase<MemRefType, ShapedType, detail::MemRefTypeStorage> {
 public:
   using Base::Base;
 
@@ -389,17 +390,11 @@ public:
                    location);
   }
 
+  // TODO(b/132735995) Get rid of this unsigned override.
   unsigned getRank() const { return getShape().size(); }
 
-  /// Returns an array of memref shape dimension sizes.
   ArrayRef<int64_t> getShape() const;
 
-  /// Return the size of the specified dimension, or -1 if unspecified.
-  int64_t getDimSize(unsigned i) const { return getShape()[i]; }
-
-  /// Returns the elemental type for this memref shape.
-  Type getElementType() const;
-
   /// Returns an array of affine map pointers representing the memref affine
   /// map composition.
   ArrayRef<AffineMap> getAffineMaps() const;
@@ -407,15 +402,18 @@ public:
   /// Returns the memory space in which data referred to by this memref resides.
   unsigned getMemorySpace() const;
 
+  // TODO(b/132735995) Extract into shaped type.
   /// Returns the number of dimensions with dynamic size.
   unsigned getNumDynamicDims() const;
 
+  // TODO(b/132735995) Extract into shaped type.
   /// If any dimension of the shape has unknown size (<0), it doesn't have
   /// static shape.
   bool hasStaticShape() const { return getNumDynamicDims() == 0; }
 
   static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
 
+  // TODO(b/132735995) Extract into shaped type.
   /// Integer value indicating that the size in a dimension is dynamic.
   static constexpr int64_t kDynamicDimSize = -1;
 
index b3e2649..239d58d 100644 (file)
@@ -203,8 +203,6 @@ void ModuleState::visitType(Type type) {
     // Visit affine maps in memref type.
     for (auto map : memref.getAffineMaps())
       recordAttributeReference(AffineMapAttr::get(map));
-    // TODO(b/132735995) Remove this when MemRef is a subclass of ShapedType.
-    visitType(memref.getElementType());
   }
   if (auto shapedType = type.dyn_cast<ShapedType>()) {
     visitType(shapedType.getElementType());
index 78a39de..80170a1 100644 (file)
@@ -162,6 +162,8 @@ ArrayRef<int64_t> ShapedType::getShape() const {
     return cast<VectorType>().getShape();
   case StandardTypes::RankedTensor:
     return cast<RankedTensorType>().getShape();
+  case StandardTypes::MemRef:
+    return cast<MemRefType>().getShape();
   default:
     llvm_unreachable("not a ShapedType or not ranked");
   }
@@ -347,8 +349,6 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
 
 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
 
-Type MemRefType::getElementType() const { return getImpl()->elementType; }
-
 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
   return getImpl()->getAffineMaps();
 }
index 541aacd..0e7edf0 100644 (file)
@@ -200,13 +200,13 @@ struct UnrankedTensorTypeStorage : public ShapedTypeStorage {
   }
 };
 
-struct MemRefTypeStorage : public TypeStorage {
+struct MemRefTypeStorage : public ShapedTypeStorage {
   MemRefTypeStorage(unsigned shapeSize, Type elementType,
                     const int64_t *shapeElements, const unsigned numAffineMaps,
                     AffineMap const *affineMapList, const unsigned memorySpace)
-      : TypeStorage(shapeSize), elementType(elementType),
-        shapeElements(shapeElements), numAffineMaps(numAffineMaps),
-        affineMapList(affineMapList), memorySpace(memorySpace) {}
+      : ShapedTypeStorage(elementType, shapeSize), shapeElements(shapeElements),
+        numAffineMaps(numAffineMaps), affineMapList(affineMapList),
+        memorySpace(memorySpace) {}
 
   /// The hash key used for uniquing.
   // MemRefs are uniqued based on their shape, element type, affine map
@@ -242,8 +242,6 @@ struct MemRefTypeStorage : public TypeStorage {
     return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
   }
 
-  /// The type of each scalar element of the memref.
-  Type elementType;
   /// An array of integers which stores the shape dimension sizes.
   const int64_t *shapeElements;
   /// The number of affine maps in the 'affineMapList' array.