From 42c195f0ec8f2d9236b237c5ad2c6f3ca9b4184c Mon Sep 17 00:00:00 2001
From: Benjamin Kramer
Date: Mon, 8 Mar 2021 15:23:28 +0100
Subject: [PATCH] [mlir][Shape] Allow shape.split_at to return extent tensors
 and lower it to std.subtensor

split_at can return an error if the split index is out of bounds. If the
user knows that the index can never be out of bounds, it is safe to use
extent tensors. This has a straightforward lowering to std.subtensor.

Differential Revision: https://reviews.llvm.org/D98177
---
 mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td     |  6 ++--
 .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 42 ++++++++++++++++++++++
 .../ShapeToStandard/shape-to-standard.mlir         | 20 +++++++++++
 3 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
index 27e219d..a176e6d 100644
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -604,7 +604,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
     If `index` is negative, it is treated as indexing from the back of the
     shape. This negative-handling behavior is important when handling unranked
     shapes, where the positive index is not necessarily knowable due to a
-    dynamic number of leading dimensions.
+    dynamic number of leading dimensions. If the result is in extent tensor
+    form, out-of-bounds indices result in undefined behavior.
 
     Examples:
     - split_at([4,5,6], index=0) -> [], [4,5,6]
@@ -623,7 +624,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
                        Shape_SizeOrIndexType:$index);
-  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+  let results = (outs Shape_ShapeOrExtentTensorType:$head,
+                      Shape_ShapeOrExtentTensorType:$tail);
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 2b5d619..49c44ad 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -591,6 +591,47 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
 }
 
 namespace {
+class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
+public:
+  using OpConversionPattern<SplitAtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult SplitAtOpConversion::matchAndRewrite(
+    SplitAtOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Error conditions are not implemented; only lower if all operands and
+  // results are extent tensors.
+  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
+                   [](Value v) { return v.getType().isa<ShapeType>(); }))
+    return failure();
+
+  SplitAtOp::Adaptor transformed(op);
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+  Value zero = b.create<ConstantIndexOp>(0);
+  Value rank = b.create<DimOp>(transformed.operand(), zero);
+
+  // index < 0 ? index + rank : index
+  Value originalIndex = transformed.index();
+  Value add = b.create<AddIOp>(originalIndex, rank);
+  Value indexIsNegative =
+      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
+  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
+
+  Value one = b.create<ConstantIndexOp>(1);
+  Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
+  Value tailSize = b.create<SubIOp>(rank, index);
+  Value tail =
+      b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
+  rewriter.replaceOp(op, {head, tail});
+  return success();
+}
+
+namespace {
 class ToExtentTensorOpConversion
     : public OpConversionPattern<ToExtentTensorOp> {
 public:
@@ -660,6 +701,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       ReduceOpConverter,
       ShapeEqOpConverter,
       ShapeOfOpConversion,
+      SplitAtOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
index d8aec02..a4a0f7e 100644
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -592,3 +592,23 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
       : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+// Lower `split_at`
+// CHECK-LABEL: @split_at
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
+func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+  // CHECK-NEXT: %[[POSINDEX:.*]] = addi %[[INDEX]], %[[RANK]] : index
+  // CHECK-NEXT: %[[ISNEG:.*]] = cmpi slt, %[[INDEX]], %[[C0]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[HEAD:.*]] = subtensor %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = subi %[[RANK]], %[[SELECT]] : index
+  // CHECK-NEXT: %[[TAIL:.*]] = subtensor %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
+  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+  return %head, %tail : tensor<?xindex>, tensor<?xindex>
+}
-- 
2.7.4
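
For reference, below is a hand-worked sketch (not part of the patch) of what the new SplitAtOpConversion pattern produces for a negative split index, assuming the operand is already an extent tensor and the IR is run through the same shape-to-standard conversion the test file exercises. The function name, value names, and the concrete shape [2, 3, 4] are illustrative only.

  // Split one position from the back: split_at([2,3,4], index=-1) -> [2,3], [4].
  func @split_at_example(%shape : tensor<?xindex>) -> (tensor<?xindex>, tensor<?xindex>) {
    %cm1 = constant -1 : index
    %head, %tail = "shape.split_at"(%shape, %cm1)
        : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
    // After the conversion this becomes, roughly:
    //   %c0    = constant 0 : index
    //   %rank  = dim %shape, %c0 : tensor<?xindex>    // 3 for [2,3,4]
    //   %pos   = addi %cm1, %rank : index             // -1 + 3 = 2
    //   %isneg = cmpi slt, %cm1, %c0 : index          // true
    //   %idx   = select %isneg, %pos, %cm1 : index    // 2
    //   %c1    = constant 1 : index
    //   %head  = subtensor %shape[%c0] [%idx] [%c1]   // extents [2, 3]
    //   %size  = subi %rank, %idx : index             // 3 - 2 = 1
    //   %tail  = subtensor %shape[%idx] [%size] [%c1] // extent [4]
    return %head, %tail : tensor<?xindex>, tensor<?xindex>
  }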