[mlir][Standard] Allow unranked memrefs as operands to dim and rank
authorStephan Herhut <herhut@google.com>
Wed, 29 Jul 2020 10:50:05 +0000 (12:50 +0200)
committerStephan Herhut <herhut@google.com>
Wed, 29 Jul 2020 12:42:58 +0000 (14:42 +0200)
`std.dim` currently only accepts ranked memrefs and `std.rank` is limited to
tensors.

Differential Revision: https://reviews.llvm.org/D84790

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/constant-fold.mlir

index 78307b8..d9634fa 100644 (file)
@@ -1409,7 +1409,7 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
     ```
   }];
 
-  let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor],
+  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
                                  "any tensor or memref type">:$memrefOrTensor,
                        Index:$index);
   let results = (outs Index:$result);
@@ -2024,16 +2024,18 @@ def PrefetchOp : Std_Op<"prefetch"> {
 def RankOp : Std_Op<"rank", [NoSideEffect]> {
   let summary = "rank operation";
   let description = [{
-    The `rank` operation takes a tensor operand and returns its rank.
+    The `rank` operation takes a memref/tensor operand and returns its rank.
 
     Example:
 
     ```mlir
-    %1 = rank %0 : tensor<*xf32>
+    %1 = rank %arg0 : tensor<*xf32>
+    %2 = rank %arg1 : memref<*xf32>
     ```
   }];
 
-  let arguments = (ins AnyTensor);
+  let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor],
+                                 "any tensor or memref type">:$memrefOrTensor);
   let results = (outs Index);
   let verifier = ?;
 
@@ -2044,7 +2046,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
     }]>];
 
   let hasFolder = 1;
-  let assemblyFormat = "operands attr-dict `:` type(operands)";
+  let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)";
 }
 
 //===----------------------------------------------------------------------===//
index 84c35c9..a67e79a 100644 (file)
@@ -2039,10 +2039,12 @@ LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
 //===----------------------------------------------------------------------===//
 
 OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
-  // Constant fold rank when the rank of the tensor is known.
+  // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
-  if (auto tensorType = type.dyn_cast<RankedTensorType>())
-    return IntegerAttr::get(IndexType::get(getContext()), tensorType.getRank());
+  if (auto shapedType = type.dyn_cast<ShapedType>())
+    if (shapedType.hasRank())
+      return IntegerAttr::get(IndexType::get(getContext()),
+                              shapedType.getRank());
   return IntegerAttr();
 }
 
index 3668c25..6302a8a 100644 (file)
@@ -10,7 +10,7 @@ func @dim(%arg : tensor<1x?xf32>) {
 
 func @rank(f32) {
 ^bb(%0: f32):
-  "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be tensor of any type values}}
+  "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any tensor or memref type}}
   return
 }
 
index 0677b95..36fa234 100644 (file)
@@ -686,6 +686,18 @@ func @fold_rank() -> (index) {
 
 // -----
 
+// CHECK-LABEL: func @fold_rank_memref
+func @fold_rank_memref(%arg0 : memref<?x?xf32>) -> (index) {
+  // Fold a rank into a constant
+  // CHECK-NEXT: [[C2:%.+]] = constant 2 : index
+  %rank_0 = rank %arg0 : memref<?x?xf32>
+
+  // CHECK-NEXT: return [[C2]]
+  return %rank_0 : index
+}
+
+// -----
+
 // CHECK-LABEL: func @nested_isolated_region
 func @nested_isolated_region() {
   // CHECK-NEXT: func @isolated_op