From c837a94754f9e00235ab64eecc700234e7e501e8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 2 Dec 2022 21:22:04 +0100 Subject: [PATCH] Revert "[mlir][tensor] Fold rank-reducing insert_slice with inverse collapse_shape" This reverts commit 1522a3b7b34b41cf0b17678e4a8687797f44a3f0. --- .../Dialect/Tensor/Transforms/ReshapePatterns.cpp | 34 +--------------------- .../Tensor/fold-reassociative-reshapes.mlir | 16 ---------- 2 files changed, 1 insertion(+), 49 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index b655df3..c1166c5 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -49,41 +49,9 @@ struct FoldExpandOfRankReducingExtract return success(); } }; - -/// Fold insert_slice(collapse_shape) ops that cancel itself out. -struct FoldInsertOfRankReducingInsert : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp, - PatternRewriter &rewriter) const override { - auto collapseShapeOp = - insertSliceOp.getSource().getDefiningOp(); - if (!collapseShapeOp) - return failure(); - RankedTensorType srcType = collapseShapeOp.getSrcType(); - - // Only cases where the CollapseShapeOp can be folded away entirely are - // supported. Moreover, only simple cases where the resulting InsertSliceOp - // has no rank-reduction anymore are supported at the moment. - RankedTensorType nonReducingInsertType = - RankedTensorType::get(insertSliceOp.getStaticSizes(), - insertSliceOp.getType().getElementType()); - if (nonReducingInsertType != srcType) - return failure(); - - SmallVector mixedOffsets = insertSliceOp.getMixedOffsets(); - SmallVector mixedSizes = insertSliceOp.getMixedSizes(); - SmallVector mixedStrides = insertSliceOp.getMixedStrides(); - rewriter.replaceOpWithNewOp( - insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(), - mixedOffsets, mixedSizes, mixedStrides); - return success(); - } -}; } // namespace void mlir::tensor::populateReassociativeReshapeFoldingPatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns.add(patterns.getContext()); } diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir index 15a00a5..c81e531 100644 --- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir +++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir @@ -17,19 +17,3 @@ func.func @expand_shape_of_rank_reducing_extract( : tensor into tensor return %1, %2 : tensor, tensor } - -// ----- - -// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape( -// CHECK-SAME: %[[t:.*]]: tensor -// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor into tensor -// CHECK: return %[[insert]] -func.func @rank_reducing_insert_of_collapse_shape( - %t: tensor, %d: tensor, %sz: index) - -> tensor { - %0 = tensor.collapse_shape %t [[0, 1], [2], [3]] - : tensor into tensor - %1 = tensor.insert_slice %0 into %d[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1] - : tensor into tensor - return %1 : tensor -} -- 2.7.4