/// `tensor.collapse_shape` into other ops.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that fold tensor.empty with
+/// tensor.[extract_slice|expand_shape|collapse_shape].
+void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
+ tensor::populateFoldTensorEmptyPatterns(patterns);
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}
}
};
-/// `tensor.empty` does not define any tensor contents, so a slice of a
-/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
-struct FoldEmptyTensorWithExtractSliceOp
- : public OpRewritePattern<ExtractSliceOp> {
- using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
- PatternRewriter &rewriter) const override {
- if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
- return failure();
-
- // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
- // preserved as well as its result type.
- auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
- sliceOp.getType().getElementType(),
- sliceOp.getType().getEncoding());
- rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
- sliceOp.getSizes());
- return success();
- }
-};
-
-template <typename ReshapeOp>
-struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
- using OpRewritePattern<ReshapeOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
- PatternRewriter &rewriter) const override {
- if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
- return failure();
- Location loc = reshapeOp.getLoc();
- ReifiedRankedShapedTypeDims resultShapes;
- ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
- cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
- if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
- resultShapes)) ||
- !llvm::hasSingleElement(resultShapes))
- return failure();
- // TODO: Do not drop tensor type encoding.
- Value emptyTensor =
- rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
- reshapeOp.getResultType().getElementType());
- if (emptyTensor.getType() != reshapeOp.getResultType()) {
- rewriter.replaceOpWithNewOp<tensor::CastOp>(
- reshapeOp, reshapeOp.getResultType(), emptyTensor);
- } else {
- rewriter.replaceOp(reshapeOp, emptyTensor);
- }
- return success();
- }
-};
-
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
- FoldEmptyTensorWithExtractSliceOp,
- FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
- FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>,
ReplaceEmptyTensorStaticShapeDims>(context);
}
add_mlir_dialect_library(MLIRTensorTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ EmptyOpPatterns.cpp
ExtractSliceFromReshapeUtils.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
ReshapePatterns.cpp
--- /dev/null
+//===- EmptyOpPatterns.cpp - Patterns related to tensor.empty folding ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
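+/// `tensor.empty` does not define any tensor contents, so a reshape
+/// (`tensor.expand_shape` / `tensor.collapse_shape`) of a `tensor.empty` can
+/// be replaced with a new `tensor.empty` that has the reshaped result shape.
+/// A `tensor.cast` is inserted when the new type differs from the original
+/// result type.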
+template <typename ReshapeOp>
+struct FoldEmptyTensorWithReshapeOp : public OpRewritePattern<ReshapeOp> {
+ using OpRewritePattern<ReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ if (!reshapeOp.getSrc().template getDefiningOp<EmptyOp>())
+ return failure();
+ Location loc = reshapeOp.getLoc();
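+    // Reify the result shape of the reshape; this may create `tensor.dim` and
+    // `affine.apply` ops that compute the dynamic output dimensions.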
+ ReifiedRankedShapedTypeDims resultShapes;
+ ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+ cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
+ if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
+ resultShapes)) ||
+ !llvm::hasSingleElement(resultShapes))
+ return failure();
+ // TODO: Do not drop tensor type encoding.
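+    // Replace the reshape with a tensor.empty of the reified shape. If the
+    // new type differs from the original result type, cast back to it.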
+ Value emptyTensor =
+ rewriter.create<EmptyOp>(loc, getAsOpFoldResult(resultShapes[0]),
+ reshapeOp.getResultType().getElementType());
+ if (emptyTensor.getType() != reshapeOp.getResultType()) {
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(
+ reshapeOp, reshapeOp.getResultType(), emptyTensor);
+ } else {
+ rewriter.replaceOp(reshapeOp, emptyTensor);
+ }
+ return success();
+ }
+};
+
+/// `tensor.empty` does not define any tensor contents, so a slice of a
+/// `tensor.empty` can be canonicalized to a smaller `tensor.empty`.
+struct FoldEmptyTensorWithExtractSliceOp
+ : public OpRewritePattern<ExtractSliceOp> {
+ using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!sliceOp.getSource().getDefiningOp<EmptyOp>())
+ return failure();
+
+ // ExtractSliceOp may be rank-reducing; its dynamic sizes must be
+ // preserved as well as its result type.
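+    // The dynamic sizes of the slice are forwarded as the dynamic sizes of
+    // the new `tensor.empty`.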
+ auto tensorType = RankedTensorType::get(sliceOp.getType().getShape(),
+ sliceOp.getType().getElementType(),
+ sliceOp.getType().getEncoding());
+ rewriter.replaceOpWithNewOp<EmptyOp>(sliceOp, tensorType,
+ sliceOp.getSizes());
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateFoldTensorEmptyPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<FoldEmptyTensorWithExtractSliceOp,
+ FoldEmptyTensorWithReshapeOp<tensor::ExpandShapeOp>,
+ FoldEmptyTensorWithReshapeOp<tensor::CollapseShapeOp>>(
+ patterns.getContext());
+}
// CHECK-LABEL: func @fold_fill_reshape()
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = arith.constant 0.0 : f32
- // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<6x4xf32>
%empty = tensor.empty() : tensor<1x2x3x4xf32>
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%cst : f32) outs(%[[INIT]] : tensor<6x4xf32>) -> tensor<6x4xf32>
+ // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape
+ // CHECK-NEXT: %[[FILL:.+]] = linalg.fill ins(%cst : f32)
+ // CHECK-SAME: outs(%[[COLLAPSE]] : tensor<6x4xf32>)
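+  // Folding tensor.empty with collapse_shape is not a canonicalization; the
+  // collapse_shape is kept and the fill is moved across the reshape instead.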
%fill = linalg.fill ins(%zero : f32) outs(%empty : tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%reshape = tensor.collapse_shape %fill [[0, 1, 2], [3]]
: tensor<1x2x3x4xf32> into tensor<6x4xf32>
// -----
-func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
- %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
- %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
- : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
- return %1 : tensor<2x3x5x4x?x7xf32>
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
-// CHECK: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
-// CHECK-NEXT: return %[[INIT]]
-
-// -----
-
-func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
- %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
- %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
- : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
- return %1 : tensor<6x5x?xf32>
-}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
-// CHECK: func @empty_reshape_collapse
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
-// CHECK-NEXT: return %[[INIT]]
-
-// -----
-
-func.func @fold_empty_tensor_with_slice
- (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
-{
- %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
- %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
- : tensor<?x10x40xf32> to tensor<5x?x20xf32>
- return %1 : tensor<5x?x20xf32>
-}
-// CHECK: func @fold_empty_tensor_with_slice
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
-// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
-// CHECK: return %[[T0]]
-
-// -----
-
func.func @fold_empty_tensor_with_cast(%arg0 : index) -> tensor<1x12xf32> {
%0 = tensor.empty(%arg0) : tensor<?x12xf32>
%1 = tensor.cast %0 : tensor<?x12xf32> to tensor<1x12xf32>
// -----
-// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
-func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
- // CHECK: tensor.empty() : tensor<2xf32>
- %a = tensor.empty(%sz) : tensor<?x2xf32>
-
- // CHECK-NOT: extract
- %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
- return %r: tensor<2xf32>
-}
-
-// -----
-
// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
// CHECK-LABEL: func @dim_of_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
--- /dev/null
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-empty-op-folding %s | FileCheck %s
+
+func.func @empty_reshape_expansion(%arg0 : index) -> tensor<2x3x5x4x?x7xf32> {
+ %0 = tensor.empty(%arg0) : tensor<6x5x?xf32>
+ %1 = tensor.expand_shape %0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<6x5x?xf32> into tensor<2x3x5x4x?x7xf32>
+ return %1 : tensor<2x3x5x4x?x7xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
+// CHECK: func @empty_reshape_expansion
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
+// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-NEXT: return %[[INIT]]
+
+// -----
+
+func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
+ %0 = tensor.empty(%arg0) : tensor<2x3x5x4x?x7xf32>
+ %1 = tensor.collapse_shape %0 [[0, 1], [2], [3, 4, 5]]
+ : tensor<2x3x5x4x?x7xf32> into tensor<6x5x?xf32>
+ return %1 : tensor<6x5x?xf32>
+}
+// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 * 28)>
+// CHECK: func @empty_reshape_collapse
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<2x3x5x4x?x7xf32>
+// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
+// CHECK-NEXT: %[[D:.+]] = affine.apply #[[MAP]]()[%[[DIM]]]
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-NEXT: return %[[INIT]]
+
+// -----
+
+func.func @fold_empty_tensor_with_slice
+ (%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
+{
+ %0 = tensor.empty(%arg0) : tensor<?x10x40xf32>
+ %1 = tensor.extract_slice %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
+ : tensor<?x10x40xf32> to tensor<5x?x20xf32>
+ return %1 : tensor<5x?x20xf32>
+}
+// CHECK: func @fold_empty_tensor_with_slice
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[T0:.+]] = tensor.empty(%[[ARG1]])
+// CHECK: return %[[T0]]
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_empty_tensor_extract
+func.func @rank_reducing_empty_tensor_extract(%sz : index, %idx : index) -> tensor<2xf32> {
+ // CHECK: tensor.empty() : tensor<2xf32>
+ %a = tensor.empty(%sz) : tensor<?x2xf32>
+
+ // CHECK-NOT: extract
+ %r = tensor.extract_slice %a[%idx, 0] [1, 2] [1, 1] : tensor<?x2xf32> to tensor<2xf32>
+ return %r: tensor<2xf32>
+}
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
llvm::cl::init(false)};
+ Option<bool> testEmptyOpFolding{
+ *this, "test-empty-op-folding",
+ llvm::cl::desc("Test folding of tensor.empty"), llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applyEmptyOpFoldingPatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateFoldTensorEmptyPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
static void applySplitPaddingPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateSplitPaddingPatterns(patterns);
applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
if (testReassociativeReshapeFolding)
applyReassociativeReshapeFoldingPatterns(rootOp);
+ if (testEmptyOpFolding)
+ applyEmptyOpFoldingPatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
if (failed(
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))