`RankedTensorOf` and `TensorRankOf` (in Tablegen files) now generate code that uses `RankedTensorType` instead of `TensorType`. This gives us more accurate type information (e.g., when calling `op.getType()`).
Also restrict tensor.expand_shape/tensor.collapse_shape/tensor.pad to ranked tensors. Only cast ops should deal with unranked tensors.
Also improves a few places in the code base (e.g., Toy tutorial) where a ranked tensor is assumed (e.g., because `getRank` is called) but a `TensorType` is currently used: cast to `RankedTensorType` directly, so that the assertion is triggered directly at the cast.
Differential Revision: https://reviews.llvm.org/D147149
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// Check that the rank of the attribute type matches the rank of the constant
// result type.
- auto attrType = getValue().getType().cast<mlir::TensorType>();
+ auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return emitOpError("return type must match the one of the attached value "
"attribute: ")
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
// Check that the rank of the attribute type matches the rank of the
// constant result type.
- auto attrType = attrValue.getType().cast<mlir::TensorType>();
+ auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
if (attrType.getRank() != resultType.getRank()) {
return op->emitOpError("return type must match the one of the attached "
"value attribute: ")
// ToyToAffine RewritePatterns
//===----------------------------------------------------------------------===//
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
- assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
return MemRefType::get(type.getShape(), type.getElementType());
}
static void lowerOpToLoops(Operation *op, ValueRange operands,
PatternRewriter &rewriter,
LoopIterationFn processIteration) {
- auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+ auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
auto loc = op->getLoc();
// Insert an allocation and deallocation for the result of this operation.
// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
- auto tensorType = op.getType().cast<TensorType>();
+ auto tensorType = op.getType().cast<RankedTensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
Tensor_Op<mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
Pure])>,
- Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
- Results<(outs AnyTensor:$result)> {
+ Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
+ Results<(outs AnyRankedTensor:$result)> {
code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrStrName() { return "reassociation"; }
}];
let arguments = (ins
- AnyTensor:$source,
+ AnyRankedTensor:$source,
Variadic<Index>:$low,
Variadic<Index>:$high,
DenseI64ArrayAttr:$static_low,
let regions = (region SizedRegion<1>:$region);
- let results = (outs AnyTensor:$result);
+ let results = (outs AnyRankedTensor:$result);
// TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
let assemblyFormat = [{
"$_self">])> {
code commonExtraClassDeclaration = [{
- size_t getSourceRank() { return getSource().getType().getRank(); };
- size_t getDestRank() { return getDest().getType().getRank(); };
+ size_t getSourceRank() { return getSourceType().getRank(); };
+ size_t getDestRank() { return getDestType().getRank(); };
RankedTensorType getSourceType() {
return getSource().getType().cast<RankedTensorType>(); };
RankedTensorType getDestType() {
def IsUnrankedTensorTypePred
: CPred<"$_self.isa<::mlir::UnrankedTensorType>()">;
+// Whether a type is a RankedTensorType
+def IsRankedTensorTypePred
+ : CPred<"$_self.isa<::mlir::RankedTensorType>()">;
+
// Whether a type is a BaseMemRefType
def IsBaseMemRefTypePred
: CPred<"$_self.isa<::mlir::BaseMemRefType>()">;
//===----------------------------------------------------------------------===//
// Tensor types.
-// Unranked tensor type whose element type is from the given
-// `allowedTypes` list.
-class UnrankedTensorOf<list<Type> allowedTypes>
- : ShapedContainerType<allowedTypes, IsUnrankedTensorTypePred,
- "unranked.tensor", "::mlir::UnrankedTensorType">;
+// Unranked tensor type whose element type is from the given `allowedTypes`
+// list, and which additionally satisfies an optional list of predicates.
+class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "unranked tensor">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([IsUnrankedTensorTypePred], preds)>,
+ summary, "::mlir::UnrankedTensorType">;
+
+// Ranked tensor type whose element type is from the given `allowedTypes` list,
+// and which additionally satisfies an optional list of predicates.
+class RankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
+ string summary = "ranked tensor">
+ : ShapedContainerType<
+ allowedTypes, And<!listconcat([IsRankedTensorTypePred], preds)>,
+ summary, "::mlir::RankedTensorType">;
// Any tensor type whose element type is from the given `allowedTypes`
// list, and which additionally satisfies an optional list of predicates.
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;
-class RankedTensorOf<
- list<Type> allowedTypes,
- list<Pred> preds = [],
- string summary = "ranked tensor">
- : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;
-
class Non0RankedTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>],
"non-0-ranked.tensor">;
def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>;
def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>;
-def AnyNon0RankedOrUnrankedTensor:
- AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>;
+def AnyNon0RankedOrUnrankedTensor
+ : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor],
+ "non-0-ranked or unranked tensor", "::mlir::TensorType">;
// Ranked tensor type with one of the specified types and ranks.
class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
- : TensorOf<allowedTypes,
+ : RankedTensorOf<allowedTypes,
[HasAnyRankOfPred<ranks>],
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
class StaticShapeTensorOf<list<Type> allowedTypes>
- : TensorOf<allowedTypes, [HasStaticShapePred], "statically shaped tensor">;
+ : RankedTensorOf<allowedTypes, [HasStaticShapePred],
+ "statically shaped tensor">;
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
LogicalResult
matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TensorType tensorType = extractOp.getTensor().getType().cast<TensorType>();
+ auto tensorType = extractOp.getTensor().getType().cast<RankedTensorType>();
if (!tensorType.hasStaticShape())
return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
auto extractOperand =
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
+ // Cannot fold cast to unranked tensor.
+ auto rankedResultType = tensorCast.getType().dyn_cast<RankedTensorType>();
+ if (!rankedResultType)
+ return failure();
+
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
- tensorCast.getType().getShape() == tensorCast.getSource()
- .getType()
- .cast<RankedTensorType>()
- .getShape())
+ rankedResultType.getShape() == tensorCast.getSource()
+ .getType()
+ .cast<RankedTensorType>()
+ .getShape())
return failure();
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
for (size_t i = 0, e = sizes.size(); i < e; i++) {
if (dimMask && dimMask->count(i))
continue;
- int64_t dim = tensorCast.getType().getShape()[dimIndex++];
+ int64_t dim = rankedResultType.getShape()[dimIndex++];
if (ShapedType::isDynamic(dim))
continue;
sizes[i] = rewriter.getIndexAttr(dim);
}
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
- tensorCast, tensorCast.getType().cast<RankedTensorType>(),
- extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
+ tensorCast, rankedResultType, extractOperand.getSource(),
+ extractOperand.getMixedOffsets(), sizes,
extractOperand.getMixedStrides());
return success();
}
return failure();
// Skip static dims. These are folded to constant ops.
- TensorType resultType = expandShapeOp.getResultType();
+ RankedTensorType resultType = expandShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))
return failure();
return failure();
// Skip static dims. These are folded to constant ops.
- TensorType resultType = collapseShapeOp.getResultType();
+ RankedTensorType resultType = collapseShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))
return failure();
// Asking the dimension of a 0-D shape doesn't make sense.
func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
- tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}
+ tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be non-0-ranked or unranked tensor, but got 'tensor<f32>'}}
return
}
// -----
func.func @tensor.from_elements_wrong_result_type() {
- // expected-error@+2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}}
+ // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
%c0 = arith.constant 0 : i32
%0 = tensor.from_elements %c0 : tensor<*xi32>
return