From 59d3a9e0877b2b12fc98eea0f9bbbc93f3c7a094 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 5 Apr 2022 14:17:22 -0400 Subject: [PATCH] [mlir][vector] Separate high-D insert/extract strided slice rewrite Right now `populateVectorInsertExtractStridedSliceTransforms` contains two categories of patterns, one for decomposing high-D insert/extract strided slices, the other for lowering them to shuffle ops. They are at different levels---the former is in the middle, while the latter is a step of final lowering. Split them to give users more control of which pattern to pick. This means break down the previous `VectorExtractStridedSliceOpRewritePattern`, which is doing two things together. Also renamed those patterns to be clearer. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D123137 --- .../Vector/Transforms/VectorRewritePatterns.h | 28 +++++--- ...torInsertExtractStridedSliceRewritePatterns.cpp | 84 +++++++++++++++------- 2 files changed, 78 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 0a298ea..0522ef5 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -241,8 +241,8 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( /// Populate `patterns` with the following patterns. /// -/// [VectorInsertStridedSliceOpDifferentRankRewritePattern] -/// ======================================================= +/// [DecomposeDifferentRankInsertStridedSlice] +/// ========================================== /// RewritePattern for InsertStridedSliceOp where source and destination vectors /// have different ranks. /// @@ -257,8 +257,19 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( /// 2. k-D -> (n-1)-D InsertStridedSlice op /// 3. InsertOp that is the reverse of 1. /// -/// [VectorInsertStridedSliceOpSameRankRewritePattern] -/// ================================================== +/// [DecomposeNDExtractStridedSlice] +/// ================================ +/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower +/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. +void populateVectorInsertExtractStridedSliceDecompositionPatterns( + RewritePatternSet &patterns); + +/// Populate `patterns` with the following patterns. +/// +/// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns(); +/// +/// [ConvertSameRankInsertStridedSliceIntoShuffle] +/// ============================================== /// RewritePattern for InsertStridedSliceOp where source and destination vectors /// have the same rank. For each outermost index in the slice: /// begin end stride @@ -268,12 +279,9 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( /// 3. the destination subvector is inserted back in the proper place /// 3. InsertOp that is the reverse of 1. /// -/// [VectorExtractStridedSliceOpRewritePattern] -/// =========================================== -/// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. single offset extract as a direct vector::ShuffleOp. -/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + -/// InsertOp/InsertElementOp for the n-D case. +/// [Convert1DExtractStridedSliceIntoShuffle] +/// ========================================= +/// For such cases, we can lower it to a ShuffleOp. void populateVectorInsertExtractStridedSliceTransforms( RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 2a384c3..a1e80e1 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -45,14 +45,14 @@ static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, /// When ranks are different, InsertStridedSlice needs to extract a properly /// ranked vector from the destination vector into which to insert. This pattern /// only takes care of this extraction part and forwards the rest to -/// [VectorInsertStridedSliceOpSameRankRewritePattern]. +/// [ConvertSameRankInsertStridedSliceIntoShuffle]. /// /// For a k-D source and n-D destination vector (k < n), we emit: /// 1. ExtractOp to extract the (unique) (n-1)-D subvector into which to /// insert the k-D source. /// 2. k-D -> (n-1)-D InsertStridedSlice op /// 3. InsertOp that is the reverse of 1. -class VectorInsertStridedSliceOpDifferentRankRewritePattern +class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -102,7 +102,7 @@ public: /// 2. InsertStridedSlice (k-1)-D into (n-1)-D /// 3. the destination subvector is inserted back in the proper place /// 3. InsertOp that is the reverse of 1. -class VectorInsertStridedSliceOpSameRankRewritePattern +class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -193,11 +193,50 @@ public: } }; -/// Progressive lowering of ExtractStridedSliceOp to either: -/// 1. single offset extract as a direct vector::ShuffleOp. -/// 2. ExtractOp/ExtractElementOp + lower rank ExtractStridedSliceOp + -/// InsertOp/InsertElementOp for the n-D case. -class VectorExtractStridedSliceOpRewritePattern +/// RewritePattern for ExtractStridedSliceOp where source and destination +/// vectors are 1-D. For such cases, we can lower it to a ShuffleOp. +class Convert1DExtractStridedSliceIntoShuffle + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractStridedSliceOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getType(); + + assert(!op.getOffsets().getValue().empty() && "Unexpected empty offsets"); + + int64_t offset = + op.getOffsets().getValue().front().cast().getInt(); + int64_t size = + op.getSizes().getValue().front().cast().getInt(); + int64_t stride = + op.getStrides().getValue().front().cast().getInt(); + + auto loc = op.getLoc(); + auto elemType = dstType.getElementType(); + assert(elemType.isSignlessIntOrIndexOrFloat()); + + // Single offset can be more efficiently shuffled. + if (op.getOffsets().getValue().size() != 1) + return failure(); + + SmallVector offsets; + offsets.reserve(size); + for (int64_t off = offset, e = offset + size * stride; off < e; + off += stride) + offsets.push_back(off); + rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), + op.getVector(), + rewriter.getI64ArrayAttr(offsets)); + return success(); + } +}; + +/// RewritePattern for ExtractStridedSliceOp where the source vector is n-D. +/// For such cases, we can rewrite it to ExtractOp/ExtractElementOp + lower +/// rank ExtractStridedSliceOp + InsertOp/InsertElementOp for the n-D case. +class DecomposeNDExtractStridedSlice : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -225,18 +264,10 @@ public: auto elemType = dstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); - // Single offset can be more efficiently shuffled. - if (op.getOffsets().getValue().size() == 1) { - SmallVector offsets; - offsets.reserve(size); - for (int64_t off = offset, e = offset + size * stride; off < e; - off += stride) - offsets.push_back(off); - rewriter.replaceOpWithNewOp(op, dstType, op.getVector(), - op.getVector(), - rewriter.getI64ArrayAttr(offsets)); - return success(); - } + // Single offset can be more efficiently shuffled. It's handled in + // Convert1DExtractStridedSliceIntoShuffle. + if (op.getOffsets().getValue().size() == 1) + return failure(); // Extract/insert on a lower ranked extract strided slice op. Value zero = rewriter.create( @@ -256,11 +287,16 @@ public: } }; +void mlir::vector::populateVectorInsertExtractStridedSliceDecompositionPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + /// Populate the given list with patterns that convert from Vector to LLVM. void mlir::vector::populateVectorInsertExtractStridedSliceTransforms( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + populateVectorInsertExtractStridedSliceDecompositionPatterns(patterns); + patterns.add(patterns.getContext()); } -- 2.7.4