There are `clone` overloads that take a shape as a parameter. These overloads are guaranteed to return a ranked shaped type.
`TensorType::clone`/`BaseMemRefType::clone` used to always return a `TensorType`/`BaseMemRefType`. The variants that take a shape parameter now return a `RankedTensorType`/`MemRefType`. Better static type information can make extra casts at the call site obsolete.
E.g.:
```
{TensorType/RankedTensorType} t;
t.clone({1, 2}) // now returns RankedTensorType instead of TensorType
```
Also improve documentation for `clone`.
Differential Revision: https://reviews.llvm.org/D150865
}];
let methods = [
InterfaceMethod<[{
- Returns a clone of this type with the given shape and element
- type. If a shape is not provided, the current shape of the type is used.
+ Returns a clone of this type with the given shape and element type.
+
+ If no shape is provided, the shape of this type is used. In that case, if
+ this type is unranked, so is the resulting type.
+
+ If a shape is provided, the resulting type is always ranked, even if this
+ type is unranked.
}],
"::mlir::ShapedType", "cloneWith", (ins
"::std::optional<::llvm::ArrayRef<int64_t>>":$shape,
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dValue) {
- return dValue == kDynamic;
+ return dValue == kDynamic;
}
/// Whether the given shape has any size that indicates a dynamic dimension.
/// Return the number of elements present in the given shape.
static int64_t getNumElements(ArrayRef<int64_t> shape);
- }];
- let extraSharedClassDeclaration = [{
/// Return a clone of this type with the given new shape and element type.
+ /// The returned type is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
- return $_type.cloneWith(shape, elementType);
+ return cloneWith(shape, elementType);
}
- /// Return a clone of this type with the given new shape.
+
+ /// Return a clone of this type with the given new shape. The returned type
+ /// is ranked, even if this type is unranked.
auto clone(::llvm::ArrayRef<int64_t> shape) {
- return $_type.cloneWith(shape, $_type.getElementType());
+ return cloneWith(shape, getElementType());
}
- /// Return a clone of this type with the given new element type.
+ }];
+
+ let extraSharedClassDeclaration = [{
+ /// Return a clone of this type with the given new element type. The
+ /// returned type is ranked if and only if this type is ranked. In that
+ /// case, the returned type has the same shape as this type.
auto clone(::mlir::Type elementType) {
return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
}
class FloatType;
class IndexType;
class IntegerType;
+class MemRefType;
+class RankedTensorType;
class StringAttr;
class TypeRange;
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;
+ // Make sure that base class overloads are visible.
+ using ShapedType::Trait<TensorType>::clone;
+
+ /// Return a clone of this type with the given new shape and element type.
+ /// The returned type is ranked, even if this type is unranked.
+ RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+ /// Return a clone of this type with the given new shape. The returned type
+ /// is ranked, even if this type is unranked.
+ RankedTensorType clone(ArrayRef<int64_t> shape) const;
+
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type);
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const;
+ // Make sure that base class overloads are visible.
+ using ShapedType::Trait<BaseMemRefType>::clone;
+
+ /// Return a clone of this type with the given new shape and element type.
+ /// The returned type is ranked, even if this type is unranked.
+ MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+ /// Return a clone of this type with the given new shape. The returned type
+ /// is ranked, even if this type is unranked.
+ MemRefType clone(ArrayRef<int64_t> shape) const;
+
/// Return true if the specified element type is ok in a memref.
static bool isValidElementType(Type type);
"unsigned":$memorySpaceInd)>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<MemRefType>::clone;
+ using BaseMemRefType::clone;
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<MemRefType>::getRank;
using ShapedType::Trait<MemRefType>::getNumElements;
}]>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<RankedTensorType>::clone;
+ using TensorType::clone;
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<RankedTensorType>::getRank;
using ShapedType::Trait<RankedTensorType>::getNumElements;
/// This is a builder type that keeps local references to arguments.
/// Arguments that are passed into the builder must outlive the builder.
class Builder;
+
+ /// Return a clone of this type with the given new element type and the same
+ /// shape as this type.
+ RankedTensorType clone(::mlir::Type elementType) {
+ return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+ }
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
}]>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<UnrankedMemRefType>::clone;
+ using BaseMemRefType::clone;
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedMemRefType>::getRank;
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
/// [deprecated] Returns the memory space in old raw integer representation.
/// New `Attribute getMemorySpace()` method should be used instead.
unsigned getMemorySpaceAsInt() const;
+
+ /// Return a clone of this type with the given new element type and the same
+ /// shape as this type.
+ MemRefType clone(::mlir::Type elementType) {
+ return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
+ }
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
}]>
];
let extraClassDeclaration = [{
- using ShapedType::Trait<UnrankedTensorType>::clone;
+ using TensorType::clone;
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
using ShapedType::Trait<UnrankedTensorType>::getRank;
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
rankedTy.getEncoding());
}
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
+ Type elementType) const {
+ return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
+}
+
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
+ return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
+}
+
// Check if "elementType" can be an element type of a tensor.
static LogicalResult
checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
return builder;
}
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
+ Type elementType) const {
+ return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
+}
+
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
+ return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
+}
+
Attribute BaseMemRefType::getMemorySpace() const {
if (auto rankedMemRefTy = dyn_cast<MemRefType>())
return rankedMemRefTy.getMemorySpace();