[mlir][Vector] Fold InsertStridedSliceOp of ExtractStridedSliceOp.
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Thu, 30 Jun 2022 11:24:31 +0000 (19:24 +0800)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 1 Jul 2022 03:43:35 +0000 (11:43 +0800)
This patch supports to fold InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) to dst.

Differential Revision: https://reviews.llvm.org/D128903

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 3edd23f..38f38f8 100644 (file)
@@ -2205,11 +2205,43 @@ public:
     return success();
   }
 };
+
+/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
+/// to dst.
+class FoldInsertStridedSliceOfExtract final
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto extractStridedSliceOp =
+        insertStridedSliceOp.getSource()
+            .getDefiningOp<vector::ExtractStridedSliceOp>();
+
+    if (!extractStridedSliceOp)
+      return failure();
+
+    if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
+      return failure();
+
+    // Check if have the same strides and offsets.
+    if (extractStridedSliceOp.getStrides() !=
+            insertStridedSliceOp.getStrides() ||
+        extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
+      return failure();
+
+    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldInsertStridedSliceSplat>(context);
+  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
+      context);
 }
 
 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
index 7f50d90..515a2d1 100644 (file)
@@ -1641,3 +1641,17 @@ func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
     : vector<4x4xf32> into vector<8x16xf32>
   return %0 : vector<8x16xf32>
 }
+
+
+// -----
+
+// CHECK-LABEL: @insert_extract_strided_slice
+//  CHECK-SAME: (%[[ARG:.*]]: vector<8x16xf32>)
+//  CHECK-NEXT:   return %[[ARG]] : vector<8x16xf32>
+func.func @insert_extract_strided_slice(%x: vector<8x16xf32>) -> (vector<8x16xf32>) {
+  %0 = vector.extract_strided_slice %x {offsets = [0, 8], sizes = [2, 4], strides = [1, 1]}
+        : vector<8x16xf32> to vector<2x4xf32>
+  %1 = vector.insert_strided_slice %0, %x {offsets = [0, 8], strides = [1, 1]}
+        : vector<2x4xf32> into vector<8x16xf32>
+  return %1 : vector<8x16xf32>
+}