From 6176d6a93e71e6d2bf89bd50e41c30e936ed05a9 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 6 Jan 2023 11:59:30 +0100 Subject: [PATCH] [mlir][tensor] Support parallel_insert_slice in MergeConsecutiveInsertExtractSlicePatterns.cpp Differential Revision: https://reviews.llvm.org/D141116 --- .../MergeConsecutiveInsertExtractSlicePatterns.cpp | 15 +++++++++------ .../Tensor/fold-consecutive-insert-extract-slice.mlir | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp index 262ef48..4169882 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp @@ -41,12 +41,13 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> { }; /// Merges consecutive tensor.insert_slice ops into one. 
-struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> { - using OpRewritePattern<InsertSliceOp>::OpRewritePattern; +template <typename OpTy> +struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> { + using OpRewritePattern<OpTy>::OpRewritePattern; - LogicalResult matchAndRewrite(InsertSliceOp nextOp, + LogicalResult matchAndRewrite(OpTy nextOp, PatternRewriter &rewriter) const override { - auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>(); + auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>(); if (!prevOp) return failure(); @@ -67,7 +68,7 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> { !prevOp.getDestType().hasStaticShape()) return failure(); - rewriter.replaceOpWithNewOp<InsertSliceOp>( + rewriter.replaceOpWithNewOp<OpTy>( nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(), nextOp.getMixedSizes(), nextOp.getMixedStrides()); return success(); @@ -77,6 +78,8 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> { void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns( RewritePatternSet &patterns) { - patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>( + patterns.add<MergeConsecutiveExtractSlice, + MergeConsecutiveInsertSlice<InsertSliceOp>, + MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>( patterns.getContext()); } diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir index f5d77f6..a120b0f 100644 --- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir +++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir @@ -81,3 +81,20 @@ func.func @insert_slice_rank_reducing_dynamic_shape( // CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape // CHECK-COUNT-2: tensor.insert_slice + +// ----- + +// CHECK-LABEL: func.func @parallel_insert_slice +// CHECK-NOT: tensor.insert_slice +// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32> +func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> { + %c1 = arith.constant 1 : 
index + %c2 = arith.constant 2 : index + %r = scf.foreach_thread (%arg2, %arg3) in (%c1, %c2) shared_outs(%arg4 = %t0) -> (tensor<1x2xf32>) { + %inserted_slice = tensor.insert_slice %t1 into %t2[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<1x1xf32> + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %inserted_slice into %arg4[%arg2, %arg3] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x2xf32> + } + } + return %r : tensor<1x2xf32> +} -- 2.7.4