From ac3e5c4d93fbe7fb2db3c745c721aff41cc1b851 Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 19 Jun 2020 15:09:36 +0000 Subject: [PATCH] [MLIR][Shape] Lower `shape.shape_of` to standard dialect Lower `shape.shape_of` to standard dialect. This lowering supports statically and dynamically shaped tensors. Support for unranked tensors will be added as part of the lowering to `scf`. Differential Revision: https://reviews.llvm.org/D82098 --- .../Conversion/ShapeToStandard/ShapeToStandard.cpp | 42 +++++++++++++++++++++- .../ShapeToStandard/shape-to-standard.mlir | 29 +++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index d02f5e3..6a02bdc 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -38,6 +38,45 @@ public: } }; +class ShapeOfOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ShapeOfOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + ShapeOfOp::Adaptor transformed(operands); + auto loc = op.getLoc(); + auto tensorVal = transformed.arg(); + auto tensorTy = tensorVal.getType(); + + // For unranked tensors `shape_of` lowers to `scf` and the pattern can be + // found in the corresponding pass. + if (tensorTy.isa()) + return failure(); + + // Build values for individual dimensions. + SmallVector dimValues; + auto rankedTensorTy = tensorTy.cast(); + int64_t rank = rankedTensorTy.getRank(); + for (int64_t i = 0; i < rank; i++) { + if (rankedTensorTy.isDynamicDim(i)) { + auto dimVal = rewriter.create(loc, tensorVal, i); + dimValues.push_back(dimVal); + } else { + int64_t dim = rankedTensorTy.getDimSize(i); + auto dimVal = rewriter.create(loc, dim); + dimValues.push_back(dimVal); + } + } + + // Materialize shape as ranked tensor. + rewriter.replaceOpWithNewOp(op.getOperation(), + dimValues); + return success(); + } +}; + class ConstSizeOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -107,7 +146,8 @@ void mlir::populateShapeToStandardConversionPatterns( patterns.insert< BinaryOpConversion, BinaryOpConversion, - ConstSizeOpConverter>(ctx); + ConstSizeOpConverter, + ShapeOfOpConversion>(ctx); // clang-format on } diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 1caf005..bfe3c2b 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -86,3 +86,32 @@ func @size_const() -> !shape.size { } // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: return %[[C1]] : index + +// ----- + +// Lower `shape_of` for statically shaped tensor. +// CHECK-LABEL: @shape_of_stat +// CHECK-SAME: (%[[ARG:.*]]: tensor<1x2x3xf32>) +func @shape_of_stat(%arg : tensor<1x2x3xf32>) { + // CHECK-DAG: %[[C1:.*]] = constant 1 : index + // CHECK-DAG: %[[C2:.*]] = constant 2 : index + // CHECK-DAG: %[[C3:.*]] = constant 3 : index + // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C2]], %[[C3]]) : tensor<3xindex> + %shape = shape.shape_of %arg : tensor<1x2x3xf32> + return +} + +// ----- + +// Lower `shape_of` for dynamically shaped tensor. +// CHECK-LABEL: @shape_of_dyn +// CHECK-SAME: (%[[ARG:.*]]: tensor<1x5x?xf32>) +func @shape_of_dyn(%arg : tensor<1x5x?xf32>) { + // CHECK-DAG: %[[C1:.*]] = constant 1 : index + // CHECK-DAG: %[[C5:.*]] = constant 5 : index + // CHECK-DAG: %[[C2:.*]] = constant 2 : index + // CHECK-DAG: %[[DYN_DIM:.*]] = dim %[[ARG]], %[[C2]] : tensor<1x5x?xf32> + // CHECK-DAG: %[[SHAPE:.*]] = tensor_from_elements(%[[C1]], %[[C5]], %[[DYN_DIM]]) : tensor<3xindex> + %shape = shape.shape_of %arg : tensor<1x5x?xf32> + return +} -- 2.7.4