From dfcc09890a91b1085139fee175936b0e67824e47 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 28 Jul 2020 15:39:49 +0000 Subject: [PATCH] [MLIR][Shape] Lower `shape.const_shape` to `tensor_from_elements` Differential Revision: https://reviews.llvm.org/D82848 --- .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 34 ++++++++++++++++++++++ .../ShapeToStandard/shape-to-standard.mlir | 16 ++++++++++ 2 files changed, 50 insertions(+) diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index f239d1c..b84b6ba 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -104,6 +104,39 @@ LogicalResult ShapeOfOpConversion::matchAndRewrite( } namespace { +class ConstShapeOpConverter : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ConstShapeOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace + +LogicalResult ConstShapeOpConverter::matchAndRewrite( + ConstShapeOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + + // For now, this lowering supports only extent tensors, not `shape.shape` + // types. + if (op.getType().isa()) + return failure(); + + auto loc = op.getLoc(); + SmallVector extentOperands; + for (auto extent : op.shape()) { + extentOperands.push_back( + rewriter.create(loc, extent.getLimitedValue())); + } + Value tensor = rewriter.create(loc, extentOperands); + Type indexTy = rewriter.getIndexType(); + Type resultTy = RankedTensorType::get({ShapedType::kDynamicSize}, indexTy); + rewriter.replaceOpWithNewOp(op, tensor, resultTy); + return success(); +} + +namespace { class GetExtentOpConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -209,6 +242,7 @@ void mlir::populateShapeToStandardConversionPatterns( patterns.insert< AnyOpConversion, BinaryOpConversion, + ConstShapeOpConverter, BinaryOpConversion, GetExtentOpConverter, RankOpConverter, diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 9336402..7f875f3 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -111,6 +111,22 @@ func @get_extent_from_extent_tensor(%extents : tensor, %idx : index) // ----- +// Lower `const_shape` to `tensor_from_elements`. +// CHECK-LABEL: @const_shape +// CHECK-SAME: () -> tensor +func @const_shape() -> tensor { + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK: %[[C3:.*]] = constant 3 : index + // CHECK: %[[TENSOR3:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) + // CHECK: %[[RESULT:.*]] = tensor_cast %[[TENSOR3]] : tensor<3xindex> to tensor + // CHECK: return %[[RESULT]] : tensor + %shape = shape.const_shape [1, 2, 3] : tensor + return %shape : tensor +} + +// ----- + // Lower `any` to its first operand. // CHECK-LABEL: @any_of_three // CHECK-SAME: (%[[A:.*]]: tensor, %[[B:.*]]: tensor, %[[C:.*]]: tensor) -> tensor -- 2.7.4