From f6fb0a4f35d152d154aeb8a8e3d47ff1392c1bad Mon Sep 17 00:00:00 2001
From: Alexander Belyaev
Date: Wed, 7 Dec 2022 18:30:36 +0100
Subject: [PATCH] [mlir] Make patterns for folding tensor.empty optional.

At the moment, these patterns are part of
EmptyOp::getCanonicalizationPatterns. When extract_slice(tensor.empty) is
rewritten as a new tensor.empty, we can end up with two tensor.empty ops,
since the original tensor.empty may have other users. After bufferization,
such cases result in two allocations.
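
For example, in IR like the following (illustrative only; the names and
shapes are invented for this message), %empty has a second user besides
the extract_slice, so rewriting the extract_slice as a new tensor.empty
leaves two tensor.empty ops behind, which bufferize to two allocations:

  %cst = arith.constant 0.0 : f32
  %empty = tensor.empty() : tensor<10xf32>
  %fill = linalg.fill ins(%cst : f32)
            outs(%empty : tensor<10xf32>) -> tensor<10xf32>
  %slice = tensor.extract_slice %empty[0] [5] [1]
             : tensor<10xf32> to tensor<5xf32>
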
Differential Revision: https://reviews.llvm.org/D139308
---
 .../mlir/Dialect/Tensor/Transforms/Transforms.h    |  4 ++
 .../lib/Dialect/Linalg/Transforms/DropUnitDims.cpp |  2 +
 mlir/lib/Dialect/Tensor/IR/TensorOps.cpp           | 55 ---------------
 mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt  |  1 +
 .../Dialect/Tensor/Transforms/EmptyOpPatterns.cpp  | 79 ++++++++++++++++++++++
 mlir/test/Dialect/Linalg/canonicalize.mlir         |  5 +-
 mlir/test/Dialect/Tensor/canonicalize.mlir         | 58 ----------------
 mlir/test/Dialect/Tensor/fold-empty-op.mlir        | 61 +++++++++++++++++
 .../lib/Dialect/Tensor/TestTensorTransforms.cpp    | 12 ++++
 9 files changed, 162 insertions(+), 115 deletions(-)
 create mode 100644 mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
 create mode 100644 mlir/test/Dialect/Tensor/fold-empty-op.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 267972e..430842b 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -40,6 +40,10 @@ void populateMergeConsecutiveInsertExtractSlicePatterns(
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that fold tensor.empty with
+/// tensor.[extract_slice|expand_shape|collapse_shape].
+void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index f0cd9ca..c0bf57c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -667,6 +668,7 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
   tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
   tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+  tensor::populateFoldTensorEmptyPatterns(patterns);
   memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
   memref::populateResolveShapedTypeResultDimsPatterns(patterns);
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index e0f45b7..f5cb9fe 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -620,58 +620,6 @@ struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
   }
 };
 
-/// `tensor.empty` does not define any tensor contents, so a slice of a
-/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
-struct FoldEmptyTensorWithExtractSliceOp
-    : public OpRewritePattern<ExtractSliceOp> {
-  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
-                                PatternRewriter &rewriter) const override {
-    if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
-      return failure();
-
-    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
-    // preserved as well as its result type.
-    auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
-                                            sliceOp.getType().getElementType(),
-                                            sliceOp.getType().getEncoding());
-    rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
-                                         sliceOp.getSizes());
-    return success();
-  }
-};
-
-template <typename ReshapeOp>
-struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
-  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
-      return failure();
-    Location loc = reshapeOp.getLoc();
-    ReifiedRankedShapedTypeDims resultShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
-    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
-                                                          resultShapes)) ||
-        !llvm::hasSingleElement(resultShapes))
-      return failure();
-    // TODO: Do not drop tensor type encoding.
-    Value emptyTensor =
-        rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
-                                 reshapeOp.getResultType().getElementType());
-    if (emptyTensor.getType() != reshapeOp.getResultType()) {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(
-          reshapeOp, reshapeOp.getResultType(), emptyTensor);
-    } else {
-      rewriter.replaceOp(reshapeOp, emptyTensor);
-    }
-    return success();
-  }
-};
-
 struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
   using OpRewritePattern<DimOp>::OpRewritePattern;
 
@@ -765,9 +713,6 @@ struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
 void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
   results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
-              FoldEmptyTensorWithExtractSliceOp,
-              FoldEmptyTensorWithReshapeOp<ExpandShapeOp>,
-              FoldEmptyTensorWithReshapeOp<CollapseShapeOp>,
               ReplaceEmptyTensorStaticShapeDims>(context);
 }
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index 08a0d5a..216fc8e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRTensorTransforms
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  EmptyOpPatterns.cpp
   ExtractSliceFromReshapeUtils.cpp
   MergeConsecutiveInsertExtractSlicePatterns.cpp
   ReshapePatterns.cpp
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
new file mode 100644
index 0000000..66cbd64
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/EmptyOpPatterns.cpp
@@ -0,0 +1,79 @@
+//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
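+/// Fold `tensor.expand_shape`/`tensor.collapse_shape` of a `tensor.empty`
+/// into a new `tensor.empty`, reifying the result shape from the reshape op.
+/// If the reified type differs from the reshape's result type, a
+/// `tensor.cast` is inserted.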
+template <typename ReshapeOp>
+struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
+  using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
+      return failure();
+    Location loc = reshapeOp.getLoc();
+    ReifiedRankedShapedTypeDims resultShapes;
+    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+        cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
+    if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
+                                                          resultShapes)) ||
+        !llvm::hasSingleElement(resultShapes))
+      return failure();
+    // TODO: Do not drop tensor type encoding.
+    Value emptyTensor =
+        rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
+                                 reshapeOp.getResultType().getElementType());
+    if (emptyTensor.getType() != reshapeOp.getResultType()) {
+      rewriter.replaceOpWithNewOp<tensor::CastOp>(
+          reshapeOp, reshapeOp.getResultType(), emptyTensor);
+    } else {
+      rewriter.replaceOp(reshapeOp, emptyTensor);
+    }
+    return success();
+  }
+};
+
+/// `tensor.empty` does not define any tensor contents, so a slice of a
+/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
+struct FoldEmptyTensorWithExtractSliceOp
+    : public OpRewritePattern<ExtractSliceOp> {
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
+      return failure();
+
+    // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
+    // preserved as well as its result type.
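+    // E.g. `tensor.extract_slice %e[0, 0, 0] [5, %sz, 20] [1, 1, 1]
+    //      : tensor<?x10x40xf32> to tensor<5x?x20xf32>`
+    // folds to `tensor.empty(%sz) : tensor<5x?x20xf32>`.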
+    auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
+                                            sliceOp.getType().getElementType(),
+                                            sliceOp.getType().getEncoding());
+    rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
+                                         sliceOp.getSizes());
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateFoldTensorEmptyPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldEmptyTensorWithExtractSliceOp,
+               FoldEmptyTensorWithReshapeOp<ExpandShapeOp>,
+               FoldEmptyTensorWithReshapeOp<CollapseShapeOp>>(
+      patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index c9f1726..4510d20 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -299,9 +299,10 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
-  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<6x4xf32>
   %empty = tensor.empty() : tensor<1x2x3x4xf32>
-  // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32>
+  // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+  // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32)
+  // CHECK-SAME:   outs(%[[COLLAPSE]] : tensor<6x4xf32>)
   %fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
   %reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
       : tensor<1x2x3x4xf32> into tensor<6x4xf32>
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 04e7207..fed2ca7 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1538,52 +1538,6 @@ func.func @empty_canonicalize() -> (tensor<4x5x?xf32>) {
 
 // -----
 
-func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
-  %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
-  return %1 : tensor<2x3x5x4x?x7xf32>
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-// CHECK: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
-// CHECK-NEXT: return %[[INIT]]
-
-// -----
-
-func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
-  %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
-  %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
-      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
-  return %1 : tensor<6x5x?xf32>
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-// CHECK: func @empty_reshape_collapse
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
-// CHECK-NEXT: return %[[INIT]]
-
-// -----
-
-func.func @fold_empty_tensor_with_slice
-  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
-{
-  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
-  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
-    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
-  return %1 : tensor<5x?x20xf32>
-}
-// CHECK: func @fold_empty_tensor_with_slice
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
-// CHECK: return %[[T0]]
-
-// -----
-
 func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
   %0 = tensor.empty(%arg0) : tensor<?x12xf32>
   %1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
@@ -1619,18 +1573,6 @@ func.func @empty_tensor_canonicalize(%i : index) {
 
 // -----
 
-// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
-func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
-  // CHECK: tensor.empty() : tensor<2xf32>
-  %a = tensor.empty(%sz) : tensor<?x2xf32>
-
-  // CHECK-NOT: extract
-  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
-  return %r: tensor<2xf32>
-}
-
-// -----
-
 // CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
 // CHECK-LABEL: func @dim_of_expand_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
new file mode 100644
index 0000000..799c691
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -0,0 +1,61 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-empty-op-folding %s | FileCheck %s
+
+func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+  %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+  return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @empty_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
+// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-NEXT: return %[[INIT]]
+
+// -----
+
+func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+  %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
+      : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+  return %1 : tensor<6x5x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @empty_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<2x3x5x4x?x7xf32>
+// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-NEXT: return %[[INIT]]
+
+// -----
+
+func.func @fold_empty_tensor_with_slice
+  (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
+{
+  %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
+    : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+  return %1 : tensor<5x?x20xf32>
+}
+// CHECK: func @fold_empty_tensor_with_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
+// CHECK: return %[[T0]]
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
+func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
+  // CHECK: tensor.empty() : tensor<2xf32>
+  %a = tensor.empty(%sz) : tensor<?x2xf32>
+
+  // CHECK-NOT: extract
+  %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+  return %r: tensor<2xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 1802387..fed6aec 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -70,6 +70,10 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
expand_shape/collapse_shape"), llvm::cl::init(false)}; + Option testEmptyOpFolding{ + *this, "test-empty-op-folding", + llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)}; + Option useForeach{ *this, "use-foreach", llvm::cl::desc( @@ -85,6 +89,12 @@ static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +static void applyEmptyOpFoldingPatterns(Operation *rootOp) { + RewritePatternSet patterns(rootOp->getContext()); + tensor::populateFoldTensorEmptyPatterns(patterns); + (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + static void applySplitPaddingPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateSplitPaddingPatterns(patterns); @@ -264,6 +274,8 @@ void TestTensorTransforms::runOnOperation() { applyFoldConsecutiveInsertExtractSlicePatterns(rootOp); if (testReassociativeReshapeFolding) applyReassociativeReshapeFoldingPatterns(rootOp); + if (testEmptyOpFolding) + applyEmptyOpFoldingPatterns(rootOp); if (testRewriteExtractSliceWithTiledCollapseShape) { if (failed( applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) -- 2.7.4