[mlir] Use RankedTensorType when rank is required
authorMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 04:20:41 +0000 (13:20 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 6 Apr 2023 04:22:10 +0000 (13:22 +0900)
`RankedTensorOf` and `TensorRankOf` (in Tablegen files) now generate code that uses `RankedTensorType` instead of `TensorType`. This gives us more accurate type information (e.g., when calling `op.getType()`).

Also use restrict tensor.expand_shape/tensor.collapse_shape/tensor.pad to ranked tensors. Only cast ops should deal with unranked tensors.

Also improves a few places in the code base (e.g., Toy tutorial) where a ranked tensor is assumed (e.g., because `getRank` is called) but a `TensorType` is currently used: cast to `RankedTensorType` directly, so that the assertion is triggered directly at the cast.

Differential Revision: https://reviews.llvm.org/D147149

14 files changed:
mlir/examples/toy/Ch2/mlir/Dialect.cpp
mlir/examples/toy/Ch3/mlir/Dialect.cpp
mlir/examples/toy/Ch4/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/Dialect.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/Dialect.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/Dialect.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
mlir/include/mlir/IR/OpBase.td
mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRV.cpp
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
mlir/test/Dialect/Tensor/invalid.mlir

index a6ccbbf..ef07af2 100644 (file)
@@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() {
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::TensorType>();
+  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
index 913979a..43f8d5b 100644 (file)
@@ -139,7 +139,7 @@ mlir::LogicalResult ConstantOp::verify() {
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::TensorType>();
+  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
index f5258eb..75a5171 100644 (file)
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::TensorType>();
+  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
index a959969..98c8eb5 100644 (file)
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::TensorType>();
+  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
index c52f5bd..a40353e 100644 (file)
@@ -30,9 +30,8 @@ using namespace mlir;
 // ToyToAffine RewritePatterns
 //===----------------------------------------------------------------------===//
 
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
   return MemRefType::get(type.getShape(), type.getElementType());
 }
 
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<TensorType>();
+    auto tensorType = op.getType().cast<RankedTensorType>();
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
 
index a959969..98c8eb5 100644 (file)
@@ -199,7 +199,7 @@ mlir::LogicalResult ConstantOp::verify() {
 
   // Check that the rank of the attribute type matches the rank of the constant
   // result type.
-  auto attrType = getValue().getType().cast<mlir::TensorType>();
+  auto attrType = getValue().getType().cast<mlir::RankedTensorType>();
   if (attrType.getRank() != resultType.getRank()) {
     return emitOpError("return type must match the one of the attached value "
                        "attribute: ")
index c52f5bd..a40353e 100644 (file)
@@ -30,9 +30,8 @@ using namespace mlir;
 // ToyToAffine RewritePatterns
 //===----------------------------------------------------------------------===//
 
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
   return MemRefType::get(type.getShape(), type.getElementType());
 }
 
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<TensorType>();
+    auto tensorType = op.getType().cast<RankedTensorType>();
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
 
index d332411..5fcb0be 100644 (file)
@@ -195,7 +195,7 @@ static mlir::LogicalResult verifyConstantForType(mlir::Type type,
 
     // Check that the rank of the attribute type matches the rank of the
     // constant result type.
-    auto attrType = attrValue.getType().cast<mlir::TensorType>();
+    auto attrType = attrValue.getType().cast<mlir::RankedTensorType>();
     if (attrType.getRank() != resultType.getRank()) {
       return op->emitOpError("return type must match the one of the attached "
                              "value attribute: ")
index c52f5bd..a40353e 100644 (file)
@@ -30,9 +30,8 @@ using namespace mlir;
 // ToyToAffine RewritePatterns
 //===----------------------------------------------------------------------===//
 
-/// Convert the given TensorType into the corresponding MemRefType.
-static MemRefType convertTensorToMemRef(TensorType type) {
-  assert(type.hasRank() && "expected only ranked shapes");
+/// Convert the given RankedTensorType into the corresponding MemRefType.
+static MemRefType convertTensorToMemRef(RankedTensorType type) {
   return MemRefType::get(type.getShape(), type.getElementType());
 }
 
@@ -63,7 +62,7 @@ using LoopIterationFn = function_ref<Value(
 static void lowerOpToLoops(Operation *op, ValueRange operands,
                            PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
-  auto tensorType = (*op->result_type_begin()).cast<TensorType>();
+  auto tensorType = (*op->result_type_begin()).cast<RankedTensorType>();
   auto loc = op->getLoc();
 
   // Insert an allocation and deallocation for the result of this operation.
@@ -144,7 +143,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 
     // When lowering the constant operation, we allocate and assign the constant
     // values to a corresponding memref allocation.
-    auto tensorType = op.getType().cast<TensorType>();
+    auto tensorType = op.getType().cast<RankedTensorType>();
     auto memRefType = convertTensorToMemRef(tensorType);
     auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
 
index 65b2c12..d106f12 100644 (file)
@@ -974,8 +974,8 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
     Tensor_Op<mnemonic, !listconcat(traits, [
       DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
       Pure])>,
-    Arguments<(ins AnyTensor:$src, IndexListArrayAttr:$reassociation)>,
-    Results<(outs AnyTensor:$result)> {
+    Arguments<(ins AnyRankedTensor:$src, IndexListArrayAttr:$reassociation)>,
+    Results<(outs AnyRankedTensor:$result)> {
 
   code commonExtraClassDeclaration = [{
     static StringRef getReassociationAttrStrName() { return "reassociation"; }
@@ -1210,7 +1210,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
   }];
 
   let arguments = (ins
-    AnyTensor:$source,
+    AnyRankedTensor:$source,
     Variadic<Index>:$low,
     Variadic<Index>:$high,
     DenseI64ArrayAttr:$static_low,
@@ -1219,7 +1219,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 
   let regions = (region SizedRegion<1>:$region);
 
-  let results = (outs AnyTensor:$result);
+  let results = (outs AnyRankedTensor:$result);
 
   // TODO: Remove custom<InferType> when AllTypesMatch supports opt. operands.
   let assemblyFormat = [{
@@ -1678,8 +1678,8 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
                    "$_self">])> {
 
   code commonExtraClassDeclaration = [{
-    size_t getSourceRank() { return getSource().getType().getRank(); };
-    size_t getDestRank() { return getDest().getType().getRank(); };
+    size_t getSourceRank() { return getSourceType().getRank(); };
+    size_t getDestRank() { return getDestType().getRank(); };
     RankedTensorType getSourceType() {
       return getSource().getType().cast<RankedTensorType>(); };
     RankedTensorType getDestType() {
index 554f026..f7f009c 100644 (file)
@@ -240,6 +240,10 @@ def IsUnrankedMemRefTypePred
 def IsUnrankedTensorTypePred
         : CPred<"$_self.isa<::mlir::UnrankedTensorType>()">;
 
+// Whether a type is a RankedTensorType
+def IsRankedTensorTypePred
+        : CPred<"$_self.isa<::mlir::RankedTensorType>()">;
+
 // Whether a type is a BaseMemRefType
 def IsBaseMemRefTypePred
         : CPred<"$_self.isa<::mlir::BaseMemRefType>()">;
@@ -721,11 +725,21 @@ def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
 //===----------------------------------------------------------------------===//
 // Tensor types.
 
-// Unranked tensor type whose element type is from the given
-// `allowedTypes` list.
-class UnrankedTensorOf<list<Type> allowedTypes>
-  : ShapedContainerType<allowedTypes, IsUnrankedTensorTypePred,
-      "unranked.tensor", "::mlir::UnrankedTensorType">;
+// Unranked tensor type whose element type is from the given `allowedTypes`
+// list, and which additionally satisfies an optional list of predicates.
+class UnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
+                       string summary = "unranked tensor">
+  : ShapedContainerType<
+      allowedTypes, And<!listconcat([IsUnrankedTensorTypePred], preds)>,
+      summary, "::mlir::UnrankedTensorType">;
+
+// Ranked tensor type whose element type is from the given `allowedTypes` list,
+// and which additionally satisfies an optional list of predicates.
+class RankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [],
+                     string summary = "ranked tensor">
+  : ShapedContainerType<
+      allowedTypes, And<!listconcat([IsRankedTensorTypePred], preds)>,
+      summary, "::mlir::RankedTensorType">;
 
 // Any tensor type whose element type is from the given `allowedTypes`
 // list, and which additionally satisfies an optional list of predicates.
@@ -754,12 +768,6 @@ def F16Tensor  : TensorOf<[F16]>;
 def F32Tensor  : TensorOf<[F32]>;
 def F64Tensor  : TensorOf<[F64]>;
 
-class RankedTensorOf<
-    list<Type> allowedTypes,
-    list<Pred> preds = [],
-    string summary = "ranked tensor">
-  : TensorOf<allowedTypes, !listconcat([HasRankPred], preds), summary>;
-
 class Non0RankedTensorOf<list<Type> allowedTypes>
   : TensorOf<allowedTypes, [HasRankGreaterOrEqualPred<1>],
       "non-0-ranked.tensor">;
@@ -768,12 +776,13 @@ def AnyRankedTensor : RankedTensorOf<[AnyType]>;
 def AnyNon0RankedTensor  : Non0RankedTensorOf<[AnyType]>;
 def AnyUnrankedTensor  : UnrankedTensorOf<[AnyType]>;
 
-def AnyNon0RankedOrUnrankedTensor:
-    AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor]>;
+def AnyNon0RankedOrUnrankedTensor
+  : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor],
+              "non-0-ranked or unranked tensor", "::mlir::TensorType">;
 
 // Ranked tensor type with one of the specified types and ranks.
 class TensorRankOf<list<Type> allowedTypes, list<int> ranks>
-  : TensorOf<allowedTypes,
+  : RankedTensorOf<allowedTypes,
       [HasAnyRankOfPred<ranks>],
       !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
 
@@ -784,7 +793,8 @@ class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
 class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
 
 class StaticShapeTensorOf<list<Type> allowedTypes>
-  : TensorOf<allowedTypes, [HasStaticShapePred], "statically shaped tensor">;
+  : RankedTensorOf<allowedTypes, [HasStaticShapePred],
+                   "statically shaped tensor">;
 
 def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
 
index 776d6c7..ed13ab3 100644 (file)
@@ -44,7 +44,7 @@ public:
   LogicalResult
   matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TensorType tensorType = extractOp.getTensor().getType().cast<TensorType>();
+    auto tensorType = extractOp.getTensor().getType().cast<RankedTensorType>();
 
     if (!tensorType.hasStaticShape())
       return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
index 7ee9325..ccea6dd 100644 (file)
@@ -369,11 +369,16 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
     auto extractOperand =
         tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
 
+    // Cannot fold cast to unranked tensor.
+    auto rankedResultType = tensorCast.getType().dyn_cast<RankedTensorType>();
+    if (!rankedResultType)
+      return failure();
+
     if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
-        tensorCast.getType().getShape() == tensorCast.getSource()
-                                               .getType()
-                                               .cast<RankedTensorType>()
-                                               .getShape())
+        rankedResultType.getShape() == tensorCast.getSource()
+                                           .getType()
+                                           .cast<RankedTensorType>()
+                                           .getShape())
       return failure();
 
     SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
@@ -383,15 +388,15 @@ struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
     for (size_t i = 0, e = sizes.size(); i < e; i++) {
       if (dimMask && dimMask->count(i))
         continue;
-      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
+      int64_t dim = rankedResultType.getShape()[dimIndex++];
       if (ShapedType::isDynamic(dim))
         continue;
       sizes[i] = rewriter.getIndexAttr(dim);
     }
 
     rewriter.replaceOpWithNewOp<ExtractSliceOp>(
-        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
-        extractOperand.getSource(), extractOperand.getMixedOffsets(), sizes,
+        tensorCast, rankedResultType, extractOperand.getSource(),
+        extractOperand.getMixedOffsets(), sizes,
         extractOperand.getMixedStrides());
     return success();
   }
@@ -1500,7 +1505,7 @@ struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
       return failure();
 
     // Skip static dims. These are folded to constant ops.
-    TensorType resultType = expandShapeOp.getResultType();
+    RankedTensorType resultType = expandShapeOp.getResultType();
     if (!resultType.isDynamicDim(*dim))
       return failure();
 
@@ -1544,7 +1549,7 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
       return failure();
 
     // Skip static dims. These are folded to constant ops.
-    TensorType resultType = collapseShapeOp.getResultType();
+    RankedTensorType resultType = collapseShapeOp.getResultType();
     if (!resultType.isDynamicDim(*dim))
       return failure();
 
index f74bd94..61f03f1 100644 (file)
@@ -2,7 +2,7 @@
 
 // Asking the dimension of a 0-D shape doesn't make sense.
 func.func @dim_0_ranked(%arg : tensor<f32>, %arg1 : index) {
-  tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be unranked.tensor of any type values or non-0-ranked.tensor of any type values, but got 'tensor<f32>'}}
+  tensor.dim %arg, %arg1 : tensor<f32> // expected-error {{'tensor.dim' op operand #0 must be non-0-ranked or unranked tensor, but got 'tensor<f32>'}}
   return
 }
 
@@ -33,7 +33,7 @@ func.func @insert_too_many_indices(%arg0: f32, %arg1: tensor<?xf32>) {
 // -----
 
 func.func @tensor.from_elements_wrong_result_type() {
-  // expected-error@+2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}}
+  // expected-error@+2 {{'tensor.from_elements' invalid kind of type specified}}
   %c0 = arith.constant 0 : i32
   %0 = tensor.from_elements %c0 : tensor<*xi32>
   return