[mlir][IR] Improve `clone` function return type of shaped types
authorMatthias Springer <me@m-sp.org>
Thu, 25 May 2023 07:22:19 +0000 (09:22 +0200)
committerMatthias Springer <me@m-sp.org>
Thu, 25 May 2023 07:27:33 +0000 (09:27 +0200)
There are `clone` overloads that take a shape as a parameter. These overloads are guaranteed to return a ranked shaped type.

`TensorType::clone`/`BaseMemRefType::clone` used to always return a `TensorType`/`BaseMemRefType`. The variants that take a shape parameter now return a `RankedTensorType`/`MemRefType`. Better static type information can make extra casts at the call site obsolete.

E.g.:
```
{TensorType/RankedTensorType} t;
t.clone({1, 2})  // now returns RankedTensorType instead of TensorType
```

Also improve documentation for `clone`.

Differential Revision: https://reviews.llvm.org/D150865

mlir/include/mlir/IR/BuiltinTypeInterfaces.td
mlir/include/mlir/IR/BuiltinTypes.h
mlir/include/mlir/IR/BuiltinTypes.td
mlir/lib/IR/BuiltinTypes.cpp

index bb38985..db38e2e 100644 (file)
@@ -59,8 +59,13 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
   }];
   let methods = [
     InterfaceMethod<[{
-      Returns a clone of this type with the given shape and element
-      type. If a shape is not provided, the current shape of the type is used.
+      Returns a clone of this type with the given shape and element type.
+
+      If no shape is provided, the shape of this type is used. In that case, if
+      this type is unranked, so is the resulting type.
+
+      If a shape is provided, the resulting type is always ranked, even if this
+      type is unranked.
     }],
     "::mlir::ShapedType", "cloneWith", (ins
       "::std::optional<::llvm::ArrayRef<int64_t>>":$shape,
@@ -89,7 +94,7 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Whether the given dimension size indicates a dynamic dimension.
     static constexpr bool isDynamic(int64_t dValue) {
-       return dValue == kDynamic;
+      return dValue == kDynamic;
     }
 
     /// Whether the given shape has any size that indicates a dynamic dimension.
@@ -99,18 +104,24 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
+    /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return $_type.cloneWith(shape, elementType);
+      return cloneWith(shape, elementType);
     }
-    /// Return a clone of this type with the given new shape.
+
+    /// Return a clone of this type with the given new shape. The returned type
+    /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return $_type.cloneWith(shape, $_type.getElementType());
+      return cloneWith(shape, getElementType());
     }
-    /// Return a clone of this type with the given new element type.
+  }];
+
+  let extraSharedClassDeclaration = [{
+    /// Return a clone of this type with the given new element type. The
+    /// returned type is ranked if and only if this type is ranked. In that
+    /// case, the returned type has the same shape as this type.
     auto clone(::mlir::Type elementType) {
       return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
     }
index 4fc82dd..79313b6 100644 (file)
@@ -27,6 +27,8 @@ class AffineMap;
 class FloatType;
 class IndexType;
 class IntegerType;
+class MemRefType;
+class RankedTensorType;
 class StringAttr;
 class TypeRange;
 
@@ -95,6 +97,17 @@ public:
   TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                        Type elementType) const;
 
+  // Make sure that base class overloads are visible.
+  using ShapedType::Trait<TensorType>::clone;
+
+  /// Return a clone of this type with the given new shape and element type.
+  /// The returned type is ranked, even if this type is unranked.
+  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+  /// Return a clone of this type with the given new shape. The returned type
+  /// is ranked, even if this type is unranked.
+  RankedTensorType clone(ArrayRef<int64_t> shape) const;
+
   /// Return true if the specified element type is ok in a tensor.
   static bool isValidElementType(Type type);
 
@@ -131,6 +144,17 @@ public:
   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                            Type elementType) const;
 
+  // Make sure that base class overloads are visible.
+  using ShapedType::Trait<BaseMemRefType>::clone;
+
+  /// Return a clone of this type with the given new shape and element type.
+  /// The returned type is ranked, even if this type is unranked.
+  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+  /// Return a clone of this type with the given new shape. The returned type
+  /// is ranked, even if this type is unranked.
+  MemRefType clone(ArrayRef<int64_t> shape) const;
+
   /// Return true if the specified element type is ok in a memref.
   static bool isValidElementType(Type type);
 
index 218c240..58a0156 100644 (file)
@@ -629,7 +629,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", [
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<MemRefType>::clone;
+    using BaseMemRefType::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -794,7 +794,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<RankedTensorType>::clone;
+    using TensorType::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -807,6 +807,12 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
     /// This is a builder type that keeps local references to arguments.
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
+
+    /// Return a clone of this type with the given new element type and the same
+    /// shape as this type.
+    RankedTensorType clone(::mlir::Type elementType) {
+      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -931,7 +937,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<UnrankedMemRefType>::clone;
+    using BaseMemRefType::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -946,6 +952,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
     /// [deprecated] Returns the memory space in old raw integer representation.
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
+
+    /// Return a clone of this type with the given new element type and the same
+    /// shape as this type.
+    MemRefType clone(::mlir::Type elementType) {
+      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -984,7 +996,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<UnrankedTensorType>::clone;
+    using TensorType::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
index b46ea8a..c816e4a 100644 (file)
@@ -291,6 +291,15 @@ TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                rankedTy.getEncoding());
 }
 
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
+                                   Type elementType) const {
+  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
+}
+
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
+  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
+}
+
 // Check if "elementType" can be an element type of a tensor.
 static LogicalResult
 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -370,6 +379,15 @@ BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
   return builder;
 }
 
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
+                                 Type elementType) const {
+  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
+}
+
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
+  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
+}
+
 Attribute BaseMemRefType::getMemorySpace() const {
   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
     return rankedMemRefTy.getMemorySpace();