From 140307379075ddd5aa6593d74c89e519baea7238 Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Mon, 5 Dec 2022 09:16:05 +0100
Subject: [PATCH] [mlir][tensor] Fold rank-reducing insert_slice with inverse
 collapse_shape

Differential Revision: https://reviews.llvm.org/D139221
---
 .../Dialect/Tensor/Transforms/ReshapePatterns.cpp  | 34 +++++++++++++++++++++-
 .../Tensor/fold-reassociative-reshapes.mlir        | 16 ++++++++++
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index c1166c5..b655df3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -49,9 +49,41 @@ struct FoldExpandOfRankReducingExtract
     return success();
   }
 };
+
+/// Fold insert_slice(collapse_shape) ops that cancel each other out.
+struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+    RankedTensorType srcType = collapseShapeOp.getSrcType();
+
+    // Only cases where the CollapseShapeOp can be folded away entirely are
+    // supported. Moreover, only simple cases where the resulting InsertSliceOp
+    // has no rank-reduction anymore are supported at the moment.
+    RankedTensorType nonReducingInsertType =
+        RankedTensorType::get(insertSliceOp.getStaticSizes(),
+                              insertSliceOp.getType().getElementType());
+    if (nonReducingInsertType != srcType)
+      return failure();
+
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
+        mixedOffsets, mixedSizes, mixedStrides);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract>(patterns.getContext());
+  patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
+      patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index c81e531..15a00a5 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -17,3 +17,19 @@ func.func @expand_shape_of_rank_reducing_extract(
       : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
   return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   return %[[insert]]
+func.func @rank_reducing_insert_of_collapse_shape(
+    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
+      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
+  %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
+      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
-- 
2.7.4
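
For context (not part of the patch), a minimal sketch of how the extended pattern set might be driven from a pass, assuming the declaration of populateReassociativeReshapeFoldingPatterns lives in mlir/Dialect/Tensor/Transforms/Transforms.h as upstream and using the standard greedy rewrite driver; the wrapper function foldReassociativeReshapes below is hypothetical:

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: collect the reassociative reshape folding patterns
// (FoldExpandOfRankReducingExtract and the new FoldInsertOfRankReducingInsert)
// and apply them greedily to `op` until no more folds trigger.
static LogicalResult foldReassociativeReshapes(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}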