void populateMergeConsecutiveInsertExtractSlicePatterns(
RewritePatternSet &patterns);
+/// Populates `patterns` with patterns that fold `tensor.expand_shape` and
+/// `tensor.collapse_shape` into other ops.
+void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
+
} // namespace tensor
} // namespace mlir
Bufferize.cpp
ExtractSliceFromReshapeUtils.cpp
MergeConsecutiveInsertExtractSlicePatterns.cpp
+ ReshapePatterns.cpp
SplitPaddingPatterns.cpp
SwapExtractSliceWithProducerPatterns.cpp
MLIRTensorDialect
MLIRTilingInterface
MLIRTransforms
- )
+)
--- /dev/null
+//===- ReshapePatterns.cpp - Tensor reshape folding patterns -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "mlir-tensor-reshape-patterns"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+/// Fold a `tensor.expand_shape` whose source is a rank-reducing
+/// `tensor.extract_slice` when the two ops cancel each other out: the
+/// expand_shape restores exactly the unit dims that the extract_slice
+/// dropped, so both ops can be replaced with a single non-rank-reducing
+/// extract_slice.
+struct FoldExpandOfRankReducingExtract
+    : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType resultType = expandShapeOp.getResultType();
+    // The pattern applies only when the expand_shape source is produced by
+    // an extract_slice.
+    auto extractSliceOp =
+        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
+    if (!extractSliceOp)
+      return failure();
+    RankedTensorType srcType = extractSliceOp.getSourceType();
+
+    // Only cases where the ExpandShapeOp can be folded away entirely are
+    // supported. Moreover, only simple cases where the resulting ExtractSliceOp
+    // has no rank-reduction anymore are supported at the moment.
+    // Concretely: the type of a non-rank-reducing slice with the same
+    // offsets/sizes/strides must equal the expand_shape result type exactly.
+    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
+        srcType, extractSliceOp.getStaticOffsets(),
+        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+    if (nonReducingExtractType != resultType)
+      return failure();
+
+    // Replace both ops with a single extract_slice that keeps the unit dims,
+    // reusing the original (mixed static/dynamic) offsets, sizes and strides.
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
+        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
+        mixedStrides);
+    return success();
+  }
+};
+} // namespace
+
+/// Populate `patterns` with the reassociative reshape folding patterns.
+/// Currently registers only FoldExpandOfRankReducingExtract.
+void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
+}
--- /dev/null
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-reassociative-reshape-folding %s | FileCheck %s
+
+// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
+// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+// CHECK-DAG: %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+// CHECK: return %[[extract1]], %[[extract2]]
+// Both expand_shapes below re-create the unit dims that the rank-reducing
+// extract_slice dropped (via two different reassociations), so each
+// expand_shape(extract_slice) pair is expected to fold into a single
+// non-rank-reducing extract_slice, as checked by the CHECK-DAG lines above.
+func.func @expand_shape_of_rank_reducing_extract(
+    %t: tensor<?x?x?x?xf32>, %idx: index)
+  -> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
+{
+  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
+      : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
+  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
+  return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
+}
"with loop nest"),
llvm::cl::init(false)};
+ Option<bool> testReassociativeReshapeFolding{
+ *this, "test-reassociative-reshape-folding",
+ llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
+ llvm::cl::init(false)};
+
Option<bool> useForeach{
*this, "use-foreach",
llvm::cl::desc(
};
} // namespace
+// Greedily apply the reassociative reshape folding patterns to `rootOp`.
+// Best-effort test helper: the greedy driver's convergence result is
+// deliberately discarded.
+static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
static void applySplitPaddingPatterns(Operation *rootOp) {
RewritePatternSet patterns(rootOp->getContext());
tensor::populateSplitPaddingPatterns(patterns);
applyFoldConstantExtractSlicePatterns(rootOp);
if (testFoldConsecutiveInsertExtractSlice)
applyFoldConsecutiveInsertExtractSlicePatterns(rootOp);
+ if (testReassociativeReshapeFolding)
+ applyReassociativeReshapeFoldingPatterns(rootOp);
if (testRewriteExtractSliceWithTiledCollapseShape) {
if (failed(
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))