From: Frederik Gossen Date: Fri, 24 Jul 2020 08:34:00 +0000 (+0000) Subject: [MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors X-Git-Tag: llvmorg-13-init~16915 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0e1a42efd8b8702a6adcb09802b84bdd8e727a19;p=platform%2Fupstream%2Fllvm.git [MLIR][Shape] Allow `shape.get_extent` to operate on extent tensors `shape.get_extent` now accepts extent tensors `tensor` as an argument. Differential Revision: https://reviews.llvm.org/D84158 --- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 1bdfd9a..2302c51 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -228,17 +228,15 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { } def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> { - let summary = "Gets the specified extent from a shape"; + let summary = "Gets the specified extent from a shape or extent tensor"; let description = [{ - Gets the extent indexed by `dim` from `shape`. - If the shape is an error, it returns an error size. + Gets the extent indexed by `dim` from the `shape` operand. If the shape is + an error then it returns an error size. }]; - let arguments = (ins - Shape_ShapeType:$shape, - Shape_SizeType:$dim - ); + let arguments = (ins Shape_ShapeOrExtentTensorType:$shape, + Shape_SizeType:$dim); let results = (outs Shape_SizeType:$extent); - let assemblyFormat = "$shape `,` $dim attr-dict"; + let assemblyFormat = "$shape `,` $dim `:` type($shape) attr-dict"; let builders = [ // Builder that allows passing a constant dimension as a simple integer. diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 0619a73..f50b653 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -139,7 +139,7 @@ func @get_extent_shape_of(%arg : tensor<2x3xf32>, %idx : !shape.size) // CHECK: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x3xf32> // CHECK: return %[[RESULT]] : index %shape = shape.shape_of %arg : tensor<2x3xf32> - %result = shape.get_extent %shape, %idx + %result = shape.get_extent %shape, %idx : !shape.shape return %result : !shape.size } @@ -154,7 +154,7 @@ func @get_extent_from_extent_tensor(%extents : tensor, // CHECK: %[[RESULT:.*]] = extract_element %[[EXTENTS]][%[[IDX]]] : tensor // CHECK: return %[[RESULT]] : index %shape = shape.from_extent_tensor %extents : tensor - %result = shape.get_extent %shape, %idx + %result = shape.get_extent %shape, %idx : !shape.shape return %result : !shape.size } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 20f21bb..9e691b8 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -239,9 +239,9 @@ func @nonfoldable_num_elements(%shape : !shape.shape) -> !shape.size { // CHECK-LABEL: func @basic func @basic() -> !shape.size { // CHECK: shape.const_size 2 - %0 = shape.const_shape [0, 1, 2] : !shape.shape + %0 = shape.const_shape [0, 1, 2] : tensor %c2 = shape.const_size 2 - %1 = shape.get_extent %0, %c2 + %1 = shape.get_extent %0, %c2 : tensor return %1 : !shape.size } @@ -252,9 +252,9 @@ func @basic() -> !shape.size { func @out_of_bounds() -> !shape.size { // CHECK: shape.const_shape // CHECK: shape.get_extent - %0 = shape.const_shape [0, 1, 2] : !shape.shape + %0 = shape.const_shape [0, 1, 2] : tensor %c3 = shape.const_size 3 - %1 = shape.get_extent %0, %c3 + %1 = shape.get_extent %0, %c3 : tensor return %1 : !shape.size } @@ -262,10 +262,10 @@ func @out_of_bounds() -> !shape.size { // Should not fold. // CHECK-LABEL: func @not_const -func @not_const(%arg0: !shape.shape) -> !shape.size { +func @not_const(%arg0: tensor) -> !shape.size { // CHECK: shape.get_extent %c3 = shape.const_size 3 - %0 = shape.get_extent %arg0, %c3 + %0 = shape.get_extent %arg0, %c3 : tensor return %0 : !shape.size } diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index aace26d..66b5834 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -161,3 +161,15 @@ func @shape_eq_on_mixed(%a : tensor, %b : !shape.shape) -> i1 { %result = shape.shape_eq %a, %b : tensor, !shape.shape return %result : i1 } + +func @get_extent_on_shape(%arg : !shape.shape) -> !shape.size { + %c0 = shape.const_size 0 + %result = shape.get_extent %arg, %c0 : !shape.shape + return %result : !shape.size +} + +func @get_extent_on_extent_tensor(%arg : tensor) -> !shape.size { + %c0 = shape.const_size 0 + %result = shape.get_extent %arg, %c0 : tensor + return %result : !shape.size +}