[mlir][Shape] Allow shape.split_at to return extent tensors and lower it to std.subtensor
authorBenjamin Kramer <benny.kra@googlemail.com>
Mon, 8 Mar 2021 14:23:28 +0000 (15:23 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Mon, 8 Mar 2021 15:48:05 +0000 (16:48 +0100)
split_at can return an error if the split index is out of bounds. If the
user knows that the index can never be out of bounds it's safe to use
extent tensors. This has a straight-forward lowering to std.subtensor.

Differential Revision: https://reviews.llvm.org/D98177

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

index 27e219d..a176e6d 100644 (file)
@@ -604,7 +604,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
     If `index` is negative, it is treated as indexing from the back of the
     shape. This negative-handling behavior is important when handling unranked
     shapes, where the positive index is not necessarily knowable due to a
-    dynamic number of leading dimensions.
+    dynamic number of leading dimensions. If the result is in extent tensor form
+    out of bounds indices result in undefined behavior.
 
     Examples:
     - split_at([4,5,6], index=0) -> [], [4,5,6]
@@ -623,7 +624,8 @@ def Shape_SplitAtOp : Shape_Op<"split_at", [NoSideEffect]> {
 
   let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
                        Shape_SizeOrIndexType:$index);
-  let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+  let results = (outs Shape_ShapeOrExtentTensorType:$head,
+                      Shape_ShapeOrExtentTensorType:$tail);
   let hasFolder = 1;
 }
 
index 2b5d619..49c44ad 100644 (file)
@@ -591,6 +591,47 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite(
 }
 
 namespace {
+class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
+public:
+  using OpConversionPattern<SplitAtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult SplitAtOpConversion::matchAndRewrite(
+    SplitAtOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Error conditions are not implemented, only lower if all operands and
+  // results are extent tensors.
+  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
+                   [](Value v) { return v.getType().isa<ShapeType>(); }))
+    return failure();
+
+  SplitAtOp::Adaptor transformed(op);
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+  Value zero = b.create<ConstantIndexOp>(0);
+  Value rank = b.create<DimOp>(transformed.operand(), zero);
+
+  // index < 0 ? index + rank : index
+  Value originalIndex = transformed.index();
+  Value add = b.create<AddIOp>(originalIndex, rank);
+  Value indexIsNegative =
+      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
+  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
+
+  Value one = b.create<ConstantIndexOp>(1);
+  Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
+  Value tailSize = b.create<SubIOp>(rank, index);
+  Value tail =
+      b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
+  rewriter.replaceOp(op, {head, tail});
+  return success();
+}
+
+namespace {
 class ToExtentTensorOpConversion
     : public OpConversionPattern<ToExtentTensorOp> {
 public:
@@ -660,6 +701,7 @@ void mlir::populateShapeToStandardConversionPatterns(
       ReduceOpConverter,
       ShapeEqOpConverter,
       ShapeOfOpConversion,
+      SplitAtOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
index d8aec02..a4a0f7e 100644 (file)
@@ -592,3 +592,23 @@ func @broadcast_3_shapes_different_extents(%a : tensor<2xindex>,
       : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
   return
 }
+
+// -----
+
+// Lower `split_at`
+// CHECK-LABEL: @split_at
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
+func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+  // CHECK-NEXT: %[[POSINDEX:.*]] = addi %[[INDEX]], %[[RANK]] : index
+  // CHECK-NEXT: %[[ISNEG:.*]] = cmpi slt, %[[INDEX]], %[[C0]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[HEAD:.*]] = subtensor %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = subi %[[RANK]], %[[SELECT]] : index
+  // CHECK-NEXT: %[[TAIL:.*]] = subtensor %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
+  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+  return %head, %tail : tensor<?xindex>, tensor<?xindex>
+}