namespace {
/// Conversion patterns.
-class SizeToIndexOpConversion
- : public OpConversionPattern<shape::SizeToIndexOp> {
+class FromExtentTensorOpConversion
+ : public OpConversionPattern<shape::FromExtentTensorOp> {
public:
- using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+ using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
LogicalResult
- matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+ matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
- shape::SizeToIndexOpOperandAdaptor transformed(operands);
- rewriter.replaceOp(op.getOperation(), transformed.arg());
+ shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+ rewriter.replaceOp(op.getOperation(), transformed.input());
return success();
}
};
}
};
+class SizeToIndexOpConversion
+ : public OpConversionPattern<shape::SizeToIndexOp> {
+public:
+ using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ shape::SizeToIndexOpOperandAdaptor transformed(operands);
+ rewriter.replaceOp(op.getOperation(), transformed.arg());
+ return success();
+ }
+};
+
+class ToExtentTensorOpConversion
+ : public OpConversionPattern<shape::ToExtentTensorOp> {
+public:
+ using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const override {
+ shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+ rewriter.replaceOp(op.getOperation(), transformed.input());
+ return success();
+ }
+};
+
/// Type conversions.
class ShapeTypeConverter : public TypeConverter {
public:
ShapeTypeConverter(MLIRContext *ctx) {
// Add default pass-through conversion.
addConversion([&](Type type) { return type; });
+
addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
addConversion([ctx](shape::ShapeType type) {
return RankedTensorType::get({ShapedType::kDynamicSize},
OwningRewritePatternList &patterns, MLIRContext *ctx) {
// clang-format off
patterns.insert<
+ FromExtentTensorOpConversion,
IndexToSizeOpConversion,
- SizeToIndexOpConversion>(ctx);
+ SizeToIndexOpConversion,
+ ToExtentTensorOpConversion>(ctx);
// clang-format on
}
// CHECK: return %[[SHAPE]] : tensor<?xindex>
return %shape : !shape.shape
}
+
+// -----
+
+// Lower `to_extent_tensor` operation to no-op.
+// CHECK-LABEL: @to_extent_tensor
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @to_extent_tensor(%shape : !shape.shape) -> tensor<?xindex> {
+ // CHECK-NEXT: return %[[SHAPE]] : tensor<?xindex>
+ %tensor = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<?xindex>
+ return %tensor : tensor<?xindex>
+}
+
+// -----
+
+// Lower `from_extent_tensor` operation to no-op.
+// CHECK-LABEL: @from_extent_tensor
+// CHECK-SAME: (%[[TENSOR:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @from_extent_tensor(%tensor : tensor<?xindex>) -> !shape.shape {
+ // CHECK-NEXT: return %[[TENSOR]] : tensor<?xindex>
+ %shape = "shape.from_extent_tensor"(%tensor)
+ : (tensor<?xindex>) -> !shape.shape
+ return %shape : !shape.shape
+}