From 58c296a418cb2a3dbd542a39b5077eb32e1f1895 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 6 Apr 2023 13:21:00 +0900 Subject: [PATCH] [mlir] Use BaseMemRefType for ranked/unranked memrefs This makes `RankedOrUnrankedMemRefOf` consistent with `TensorOf`. Differential Revision: https://reviews.llvm.org/D147160 --- mlir/include/mlir/IR/OpBase.td | 19 ++++++++++++++----- mlir/test/Dialect/MemRef/invalid.mlir | 2 +- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index f7f009c..98866c8 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -801,7 +801,7 @@ def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; //===----------------------------------------------------------------------===// // Memref type. -// Unranked Memref type +// Any unranked memref whose element type is from the given `allowedTypes` list. class UnrankedMemRefOf allowedTypes> : ShapedContainerType allowedTypes> : def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>; -// Memrefs are blocks of data with fixed type and rank. +// Any ranked memref whose element type is from the given `allowedTypes` list. class MemRefOf allowedTypes> : ShapedContainerType; + class Non0RankedMemRefOf allowedTypes> : ConfinedType, [HasRankGreaterOrEqualPred<1>], "non-0-ranked." # MemRefOf.summary, @@ -821,10 +822,18 @@ class Non0RankedMemRefOf allowedTypes> : def AnyMemRef : MemRefOf<[AnyType]>; def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>; -class RankedOrUnrankedMemRefOf allowedTypes>: - AnyTypeOf<[UnrankedMemRefOf, MemRefOf]>; +// Any memref (ranked or unranked) whose element type is from the given +// `allowedTypes` list, and which additionally satisfies an optional list of +// predicates. +class RankedOrUnrankedMemRefOf< + list allowedTypes, + list preds = [], + string summary = "ranked or unranked memref"> + : ShapedContainerType, + summary, "::mlir::BaseMemRefType">; -def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>; +def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>; def AnyNon0RankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>; diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 19874f0..37ac1ca 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -910,7 +910,7 @@ func.func @test_alloc_memref_map_rank_mismatch() { // ----- func.func @rank(%0: f32) { - // expected-error@+1 {{'memref.rank' op operand #0 must be unranked.memref of any type values or memref of any type values}} + // expected-error@+1 {{'memref.rank' op operand #0 must be ranked or unranked memref of any type values}} "memref.rank"(%0): (f32)->index return } -- 2.7.4