From 81469527ec99a3452c867d997d259ea70f81dda5 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Thu, 25 Jun 2020 08:31:49 +0000 Subject: [PATCH] [MLIR][Shape] Add constant folding to `shape.rank` Add constant folding for the `shape.rank` operation of the shape dialect. Differential Revision: https://reviews.llvm.org/D82076 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 2 ++ mlir/lib/Dialect/Shape/IR/Shape.cpp | 13 +++++++++++++ mlir/test/Dialect/Shape/canonicalize.mlir | 24 ++++++++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 0785a40..379f861 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -179,6 +179,8 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> { let results = (outs Shape_SizeType:$rank); let assemblyFormat = "attr-dict $shape"; + + let hasFolder = 1; } def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 664c0cb..cdbc892 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -438,6 +438,19 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, } //===----------------------------------------------------------------------===// +// RankOp +//===----------------------------------------------------------------------===// + +OpFoldResult RankOp::fold(ArrayRef operands) { + auto shape = operands[0].dyn_cast_or_null(); + if (!shape) + return {}; + int64_t rank = shape.getNumElements(); + Builder builder(getContext()); + return builder.getIndexAttr(rank); +} + +//===----------------------------------------------------------------------===// // NumElementsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index a56a8f9..00f6b36 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -442,3 +442,27 @@ func @f(%arg0 : !shape.shape) { "consume.witness"(%0) : (!shape.witness) -> () return } + +// ----- + +// Fold `rank` based on constant shape. +// CHECK-LABEL: @fold_rank +func @fold_rank() -> !shape.size { + // CHECK-DAG: %[[RESULT:.*]] = shape.const_size 5 + // CHECK-DAG: return %[[RESULT]] : !shape.size + %shape = shape.const_shape [3, 4, 5, 6, 7] + %rank = shape.rank %shape + return %rank : !shape.size +} + +// ----- + +// Do not fold `rank` if shape is dynamic. +// CHECK-LABEL: @dont_fold_rank +// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape) -> !shape.size +func @dont_fold_rank(%shape : !shape.shape) -> !shape.size { + // CHECK-DAG: %[[RESULT:.*]] = shape.rank %[[SHAPE]] + // CHECK-DAG: return %[[RESULT]] : !shape.size + %rank = shape.rank %shape + return %rank : !shape.size +} -- 2.7.4