From 7bca97d960ab9451185a997208057a89355b406a Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 25 Jun 2020 08:37:18 +0000 Subject: [PATCH] [MLIR][Shape] Add canonicalization pattern for `shape.rank` Replace any `rank(shape_of(tensor))` that relies on a ranked tensor with the corresponding constant `const_size`. Differential Revision: https://reviews.llvm.org/D82077 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 5 +++ mlir/lib/Dialect/Shape/IR/Shape.cpp | 44 ++++++++++++++++++++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 26 +++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 379f861..2430fe6 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -130,6 +130,10 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [ let arguments = (ins IndexAttr:$value); let results = (outs Shape_SizeType:$result); + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result, int64_t value"> + ]; + let assemblyFormat = "$value attr-dict"; let hasFolder = 1; } @@ -181,6 +185,7 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> { let assemblyFormat = "attr-dict $shape"; let hasFolder = 1; + let hasCanonicalizer = 1; } def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index cdbc892..2d952183 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -364,6 +364,11 @@ OpFoldResult CstrEqOp::fold(ArrayRef operands) { // ConstSizeOp //===----------------------------------------------------------------------===// +void ConstSizeOp::build(OpBuilder &builder, OperationState &result, + int64_t value) { + build(builder, result, builder.getIndexAttr(value)); +} + OpFoldResult ConstSizeOp::fold(ArrayRef) { return valueAttr(); } void ConstSizeOp::getAsmResultNames( @@ -450,6 +455,45 @@ OpFoldResult RankOp::fold(ArrayRef operands) { return builder.getIndexAttr(rank); } +/// Evaluate the `rank` operation for shapes of ranked tensors at compile time. +/// Constant folding fails in cases where only the rank is constant, not the +/// shape itself. +/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`. +/// +/// Example: +/// +/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32> +/// %rank = shape.rank %shape +/// +/// becomes +/// +/// %rank = shape.const_size 3 + +namespace { +struct RankShapeOfCanonicalizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RankOp op, + PatternRewriter &rewriter) const override { + auto shapeOfOp = op.shape().getDefiningOp(); + if (!shapeOfOp) + return failure(); + auto rankedTensorType = + shapeOfOp.arg().getType().dyn_cast(); + if (!rankedTensorType) + return failure(); + int64_t rank = rankedTensorType.getRank(); + rewriter.replaceOpWithNewOp(op.getOperation(), rank); + return success(); + } +}; +} // namespace + +void RankOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, + MLIRContext *context) { + patterns.insert(context); +} + //===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 00f6b36..9fb48e6 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -466,3 +466,29 @@ func @dont_fold_rank(%shape : !shape.shape) -> !shape.size { %rank = shape.rank %shape return %rank : !shape.size } + +// ----- + +// Canonicalize `rank` when shape is derived from ranked tensor. +// CHECK-LABEL: @canonicalize_rank +func @canonicalize_rank(%arg : tensor<1x2x?xf32>) -> !shape.size { +// CHECK-DAG: %[[RESULT:.*]] = shape.const_size 3 +// CHECK-DAG: return %[[RESULT]] : !shape.size +%shape = shape.shape_of %arg : tensor<1x2x?xf32> +%rank = shape.rank %shape +return %rank : !shape.size +} + +// ----- + +// Do not canonicalize `rank` when shape is derived from unranked tensor. +// CHECK-LABEL: @dont_canonicalize_rank +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) -> !shape.size +func @dont_canonicalize_rank(%arg : tensor<*xf32>) -> !shape.size { +// CHECK-DAG: %[[SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<*xf32> +// CHECK-DAG: %[[SIZE:.*]] = shape.rank %[[SHAPE]] +// CHECK-DAG: return %[[SIZE]] : !shape.size +%shape = shape.shape_of %arg : tensor<*xf32> +%rank = shape.rank %shape +return %rank : !shape.size +} -- 2.7.4