[MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors
authorFrederik Gossen <frgossen@google.com>
Fri, 24 Jul 2020 08:34:00 +0000 (08:34 +0000)
committerFrederik Gossen <frgossen@google.com>
Fri, 24 Jul 2020 08:34:37 +0000 (08:34 +0000)
`shape.get_extent` now accepts extent tensors `tensor<?xindex>` as an argument.

Differential Revision: https://reviews.llvm.org/D84158

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
mlir/test/Dialect/Shape/canonicalize.mlir
mlir/test/Dialect/Shape/ops.mlir

index 1bdfd9a..2302c51 100644 (file)
@@ -228,17 +228,15 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
 }
 
 def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
-  let summary = "Gets the specified extent from a shape";
+  let summary = "Gets the specified extent from a shape or extent tensor";
   let description = [{
-    Gets the extent indexed by `dim` from `shape`.
-    If the shape is an error, it returns an error size.
+    Gets the extent indexed by `dim` from the `shape` operand. If the shape is
+    an error then it returns an error size.
   }];
-  let arguments = (ins
-    Shape_ShapeType:$shape,
-    Shape_SizeType:$dim
-  );
+  let arguments = (ins Shape_ShapeOrExtentTensorType:$shape,
+                       Shape_SizeType:$dim);
   let results = (outs Shape_SizeType:$extent);
-  let assemblyFormat = "$shape `,` $dim attr-dict";
+  let assemblyFormat = "$shape `,` $dim `:` type($shape)  attr-dict";
 
   let builders = [
     // Builder that allows passing a constant dimension as a simple integer.
index 0619a73..f50b653 100644 (file)
@@ -139,7 +139,7 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size)
   // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.shape_of %arg : tensor<2x3xf32>
-  %result = shape.get_extent %shape, %idx
+  %result = shape.get_extent %shape, %idx : !shape.shape
   return %result : !shape.size
 }
 
@@ -154,7 +154,7 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
   // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor<?xindex>
   // CHECK: return %[[RESULT]] : index
   %shape = shape.from_extent_tensor %extents : tensor<?xindex>
-  %result = shape.get_extent %shape, %idx
+  %result = shape.get_extent %shape, %idx : !shape.shape
   return %result : !shape.size
 }
 
index 20f21bb..9e691b8 100644 (file)
@@ -239,9 +239,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size {
 // CHECK-LABEL: func @basic
 func @basic() -> !shape.size {
   // CHECK: shape.const_size 2
-  %0 = shape.const_shape [0, 1, 2] : !shape.shape
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
   %c2 = shape.const_size 2
-  %1 = shape.get_extent %0, %c2
+  %1 = shape.get_extent %0, %c2 : tensor<?xindex>
   return %1 : !shape.size
 }
 
@@ -252,9 +252,9 @@ func @basic() -> !shape.size {
 func @out_of_bounds() -> !shape.size {
   // CHECK: shape.const_shape
   // CHECK: shape.get_extent
-  %0 = shape.const_shape [0, 1, 2] : !shape.shape
+  %0 = shape.const_shape [0, 1, 2] : tensor<?xindex>
   %c3 = shape.const_size 3
-  %1 = shape.get_extent %0, %c3
+  %1 = shape.get_extent %0, %c3 : tensor<?xindex>
   return %1 : !shape.size
 }
 
@@ -262,10 +262,10 @@ func @out_of_bounds() -> !shape.size {
 
 // Should not fold.
 // CHECK-LABEL: func @not_const
-func @not_const(%arg0: !shape.shape) -> !shape.size {
+func @not_const(%arg0: tensor<?xindex>) -> !shape.size {
   // CHECK: shape.get_extent
   %c3 = shape.const_size 3
-  %0 = shape.get_extent %arg0, %c3
+  %0 = shape.get_extent %arg0, %c3 : tensor<?xindex>
   return %0 : !shape.size
 }
 
index aace26d..66b5834 100644 (file)
@@ -161,3 +161,15 @@ func @shape_eq_on_mixed(%a : tensor<?xindex>, %b : !shape.shape) -> i1 {
   %result = shape.shape_eq %a, %b : tensor<?xindex>, !shape.shape
   return %result : i1
 }
+
+func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size {
+  %c0 = shape.const_size 0
+  %result = shape.get_extent %arg, %c0 : !shape.shape
+  return %result : !shape.size
+}
+
+func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
+  %c0 = shape.const_size 0
+  %result = shape.get_extent %arg, %c0 : tensor<?xindex>
+  return %result : !shape.size
+}