If `index` is negative, it is treated as indexing from the back of the
shape. This negative-handling behavior is important when handling unranked
shapes, where the positive index is not necessarily knowable due to a
- dynamic number of leading dimensions.
+ dynamic number of leading dimensions. If the result is in extent tensor form,
+ out-of-bounds indices result in undefined behavior.
Examples:
- split_at([4,5,6], index=0) -> [], [4,5,6]
let arguments = (ins Shape_ShapeOrExtentTensorType:$operand,
Shape_SizeOrIndexType:$index);
- let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+ let results = (outs Shape_ShapeOrExtentTensorType:$head,
+ Shape_ShapeOrExtentTensorType:$tail);
let hasFolder = 1;
}
}
namespace {
+/// Conversion pattern lowering `shape.split_at` to standard-dialect ops.
+/// Only handles the extent-tensor form; the `!shape.shape` form (which would
+/// need error handling) is rejected in matchAndRewrite.
+class SplitAtOpConversion : public OpConversionPattern<SplitAtOp> {
+public:
+  using OpConversionPattern<SplitAtOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SplitAtOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult SplitAtOpConversion::matchAndRewrite(
+    SplitAtOp op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  // Error conditions are not implemented, only lower if all operands and
+  // results are extent tensors (i.e. none has the opaque `!shape.shape` type).
+  if (llvm::any_of(ValueRange{op.operand(), op.head(), op.tail()},
+                   [](Value v) { return v.getType().isa<ShapeType>(); }))
+    return failure();
+
+  SplitAtOp::Adaptor transformed(op);
+  ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+  Value zero = b.create<ConstantIndexOp>(0);
+  // The rank of the shape is the size of the extent tensor's only dimension.
+  Value rank = b.create<DimOp>(transformed.operand(), zero);
+
+  // Normalize a negative split position (counting from the back):
+  // index < 0 ? index + rank : index
+  Value originalIndex = transformed.index();
+  Value add = b.create<AddIOp>(originalIndex, rank);
+  Value indexIsNegative =
+      b.create<CmpIOp>(CmpIPredicate::slt, originalIndex, zero);
+  Value index = b.create<SelectOp>(indexIsNegative, add, originalIndex);
+
+  // head = operand[0 : index], tail = operand[index : rank], both extracted
+  // as unit-stride subtensor slices.
+  Value one = b.create<ConstantIndexOp>(1);
+  Value head = b.create<SubTensorOp>(transformed.operand(), zero, index, one);
+  Value tailSize = b.create<SubIOp>(rank, index);
+  Value tail =
+      b.create<SubTensorOp>(transformed.operand(), index, tailSize, one);
+  rewriter.replaceOp(op, {head, tail});
+  return success();
+}
+
+namespace {
class ToExtentTensorOpConversion
: public OpConversionPattern<ToExtentTensorOp> {
public:
ReduceOpConverter,
ShapeEqOpConverter,
ShapeOfOpConversion,
+ SplitAtOpConversion,
ToExtentTensorOpConversion>(ctx);
// clang-format on
}
: tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> tensor<?xindex>
return
}
+
+// -----
+
+// Lower `split_at` on an extent tensor to two unit-stride `subtensor` slices
+// around the (negativity-normalized) split index.
+// CHECK-LABEL: @split_at
+// CHECK-SAME: %[[SHAPE:.*]]: tensor<?xindex>, %[[INDEX:.*]]: index
+func @split_at(%shape: tensor<?xindex>, %index: index) -> (tensor<?xindex>, tensor<?xindex>) {
+  // CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+  // CHECK-NEXT: %[[RANK:.*]] = dim %[[SHAPE]], %[[C0]] : tensor<?xindex>
+  // CHECK-NEXT: %[[POSINDEX:.*]] = addi %[[INDEX]], %[[RANK]] : index
+  // CHECK-NEXT: %[[ISNEG:.*]] = cmpi slt, %[[INDEX]], %[[C0]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = select %[[ISNEG]], %[[POSINDEX]], %[[INDEX]] : index
+  // CHECK-NEXT: %[[C1:.*]] = constant 1 : index
+  // CHECK-NEXT: %[[HEAD:.*]] = subtensor %[[SHAPE]][%[[C0]]] [%[[SELECT]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: %[[TAIL_SIZE:.*]] = subi %[[RANK]], %[[SELECT]] : index
+  // CHECK-NEXT: %[[TAIL:.*]] = subtensor %[[SHAPE]][%[[SELECT]]] [%[[TAIL_SIZE]]] [%[[C1]]] : tensor<?xindex> to tensor<?xindex>
+  // CHECK-NEXT: return %[[HEAD]], %[[TAIL]] : tensor<?xindex>, tensor<?xindex>
+  %head, %tail = "shape.split_at"(%shape, %index) : (tensor<?xindex>, index) -> (tensor<?xindex>, tensor<?xindex>)
+  return %head, %tail : tensor<?xindex>, tensor<?xindex>
+}