[MLIR][Shape] Allow `shape.rank` to accept extent tensors `tensor?xindex>`
authorFrederik Gossen <frgossen@google.com>
Mon, 20 Jul 2020 14:46:18 +0000 (14:46 +0000)
committerFrederik Gossen <frgossen@google.com>
Mon, 20 Jul 2020 14:47:19 +0000 (14:47 +0000)
Differential Revision: https://reviews.llvm.org/D84156

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir

index 46400c8..703353c 100644 (file)
@@ -195,13 +195,13 @@ def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", [NoSideEffect]> {
 def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
   let summary = "Gets the rank of a shape";
   let description = [{
-    Returns the rank of the shape, i.e. the number of extents.
+    Returns the rank of the shape or extent tensor, i.e. the number of extents.
   }];
 
-  let arguments = (ins Shape_ShapeType:$shape);
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_SizeType:$rank);
 
-  let assemblyFormat = "attr-dict $shape";
+  let assemblyFormat = "$shape `:` type($shape) attr-dict";
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
index 2220663..0619a73 100644 (file)
@@ -124,7 +124,7 @@ func @rank(%shape : !shape.shape) -> !shape.size {
   // CHECK-DAG: %[[C0:.*]] = constant 0 : index
   // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
   // CHECK-DAG: return %[[RESULT]] : index
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
index 156063e..80b7cb9 100644 (file)
@@ -499,7 +499,7 @@ func @fold_rank() -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5
   // CHECK-DAG: return %[[RESULT]] : !shape.size
   %shape = shape.const_shape [3, 4, 5, 6, 7]
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
@@ -511,7 +511,7 @@ func @fold_rank() -> !shape.size {
 func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
   // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]]
   // CHECK-DAG: return %[[RESULT]] : !shape.size
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
@@ -520,11 +520,11 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size {
 // Canonicalize `rank` when shape is derived from ranked tensor.
 // CHECK-LABEL: @canonicalize_rank
 func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size {
-// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
-// CHECK-DAG: return %[[RESULT]] : !shape.size
-%shape = shape.shape_of %arg : tensor<1x2x?xf32>
-%rank = shape.rank %shape
-return %rank : !shape.size
+  // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3
+  // CHECK-DAG: return %[[RESULT]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<1x2x?xf32>
+  %rank = shape.rank %shape : !shape.shape
+  return %rank : !shape.size
 }
 
 // -----
@@ -533,12 +533,12 @@ return %rank : !shape.size
 // CHECK-LABEL: @dont_canonicalize_rank
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size
 func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size {
-// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
-// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
-// CHECK-DAG: return %[[SIZE]] : !shape.size
-%shape = shape.shape_of %arg : tensor<*xf32>
-%rank = shape.rank %shape
-return %rank : !shape.size
+  // CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]]
+  // CHECK-DAG: return %[[SIZE]] : !shape.size
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  %rank = shape.rank %shape : !shape.shape
+  return %rank : !shape.size
 }
 
 // Canonicalize redundant conversion from `index` to `size` and back.
index 30cf29a..1187d7a 100644 (file)
@@ -130,10 +130,16 @@ func @test_from_extent_tensor(%arg: tensor<?xindex>) -> !shape.shape {
 }
 
 func @rank(%shape : !shape.shape) -> !shape.size {
-  %rank = shape.rank %shape
+  %rank = shape.rank %shape : !shape.shape
   return %rank : !shape.size
 }
 
+func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> !shape.size {
+  %rank = shape.rank %shape : tensor<?xindex>
+  return %rank : !shape.size
+}
+
+
 func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
   return %result : i1