From 9cdf6b641da1a7ba0145b224460c64efd65017e0 Mon Sep 17 00:00:00 2001
From: Matthias Springer
Date: Wed, 7 Dec 2022 16:22:07 +0100
Subject: [PATCH] [mlir][tensor] Support parallel_insert_slice in reassociative reshape folder

Differential Revision: https://reviews.llvm.org/D139540
---
 .../Dialect/Tensor/Transforms/ReshapePatterns.cpp | 23 ++++++++++++-----------
 .../Tensor/fold-reassociative-reshapes.mlir       | 19 +++++++++++++++++++
 2 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index b655df3..d40e5f3 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -11,8 +11,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
-#define DEBUG_TYPE "mlir-tensor-split-padding"
-
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -51,13 +49,14 @@ struct FoldExpandOfRankReducingExtract
 };
 
 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
-struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+template <typename OpTy>
+struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     auto collapseShapeOp =
-        insertSliceOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+        insertSliceOp.getSource().template getDefiningOp<tensor::CollapseShapeOp>();
     if (!collapseShapeOp)
       return failure();
     RankedTensorType srcType = collapseShapeOp.getSrcType();
@@ -67,16 +66,16 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp>
     // has no rank-reduction anymore are supported at the moment.
     RankedTensorType nonReducingInsertType =
         RankedTensorType::get(insertSliceOp.getStaticSizes(),
-                              insertSliceOp.getType().getElementType());
+                              insertSliceOp.getDestType().getElementType());
     if (nonReducingInsertType != srcType)
       return failure();
 
     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
-        insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
+    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
+                                      insertSliceOp.getDest(), mixedOffsets,
+                                      mixedSizes, mixedStrides);
     return success();
   }
 };
@@ -84,6 +83,8 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp>
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
+  patterns.add<FoldExpandOfRankReducingExtract,
+               FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
       patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 15a00a5..e6256a9 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -33,3 +33,22 @@
       : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
+//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+func.func @rank_reducing_parallel_insert_of_collapse_shape(
+    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
+  -> tensor<?x?x?x?xf32> {
+  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
+      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
+  %1 = scf.foreach_thread (%iv) in (%thr) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %0 into %o[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
+          : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}
-- 
2.7.4
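Note (reviewer aid, not part of the patch): templating FoldInsertOfRankReducingInsert over the op type lets tensor.insert_slice and tensor.parallel_insert_slice share one folding implementation, since both ops provide the same mixed offsets/sizes/strides and dest accessors. The sketch below shows one way a downstream pass could exercise the extended pattern set; the wrapper function name and the choice of the greedy rewrite driver are illustrative assumptions, only populateReassociativeReshapeFoldingPatterns comes from this patch.

// Sketch only: the helper name and driver choice are assumptions for
// demonstration; they are not introduced by the patch.
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult foldReassociativeReshapes(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Registers FoldExpandOfRankReducingExtract plus both instantiations of
  // FoldInsertOfRankReducingInsert (InsertSliceOp and ParallelInsertSliceOp).
  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
  // Greedily applies the folding patterns until a fixed point is reached.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}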