From f67d57c95f50fabdfa0bbd454faa564f5059d5f4 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 12 Oct 2021 18:01:54 +0200 Subject: [PATCH] [mlir][Shape] Add a pattern to turn extract from shape_of into tensor.dim If I remember correctly this wasn't done previously because dim used to be in the memref dialect. Differential Revision: https://reviews.llvm.org/D111651 --- mlir/lib/Dialect/Shape/IR/Shape.cpp | 3 ++- mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td | 9 +++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 62b4e02..0a5da5d 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1473,7 +1473,8 @@ struct ShapeOfCastExtentTensor : public OpRewritePattern { void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add(context); } LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes( diff --git a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td index 7460dc5..0825f0f 100644 --- a/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td +++ b/mlir/lib/Dialect/Shape/IR/ShapeCanonicalization.td @@ -16,6 +16,9 @@ def HasStaticShape : Constraint().hasStaticShape() }]>>; +// Helper that takes the first element of a range. +def TakeFront : NativeCodeCall<"$0.front()">; + // Canonicalization patterns. def AssumingAllOneOp : Pat<(Shape_AssumingAllOp $args), @@ -43,3 +46,9 @@ def SizeToIndexToSizeCanonicalization : Pat< def TensorCastConstShape : Pat < (Tensor_CastOp:$res (Shape_ConstShapeOp $arg)), (Shape_ConstShapeOp $arg), [(HasStaticShape $res)]>; + +// tensor.extract from shape_of -> tensor.dim. We can take the first index +// because shape_of always returns a 1D tensor. +def ExtractFromShapeOfExtentTensor : Pat< + (Tensor_ExtractOp (Shape_ShapeOfOp $arg), $indices), + (Tensor_DimOp $arg, (TakeFront $indices))>; diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index b0c2181..a6b93e8 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1380,3 +1380,17 @@ func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>, -> tensor return %0 : tensor } + +// ----- + +// CHECK-LABEL: func @extract_shapeof +// CHECK-SAME: %[[ARG0:.*]]: tensor +func @extract_shapeof(%arg0 : tensor) -> index { + %c1 = constant 1 : index +// CHECK: %[[C1:.*]] = constant 1 + %shape = shape.shape_of %arg0 : tensor -> tensor<2xindex> +// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]] + %result = tensor.extract %shape[%c1] : tensor<2xindex> +// CHECK: return %[[DIM]] + return %result : index +} -- 2.7.4