[mlir][Vector] Fold InsertStridedSliceOp of two splat with the same input to splat.
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Thu, 30 Jun 2022 08:30:59 +0000 (16:30 +0800)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 1 Jul 2022 02:46:47 +0000 (10:46 +0800)
This patch folds InsertStridedSliceOp(SplatOp(X):src_type, SplatOp(X):dst_type) to SplatOp(X):dst_type.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D128891

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index 57c02c9..6f68f83 100644 (file)
@@ -886,6 +886,7 @@ def Vector_InsertStridedSliceOp :
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_OuterProductOp :
index c3b4d1e..3edd23f 100644 (file)
@@ -2180,6 +2180,38 @@ LogicalResult InsertStridedSliceOp::verify() {
   return success();
 }
 
+namespace {
+/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
+/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+class FoldInsertStridedSliceSplat final
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcSplatOp =
+        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+    auto destSplatOp =
+        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
+
+    if (!srcSplatOp || !destSplatOp)
+      return failure();
+
+    if (srcSplatOp.getInput() != destSplatOp.getInput())
+      return failure();
+
+    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+    return success();
+  }
+};
+} // namespace
+
+void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<FoldInsertStridedSliceSplat>(context);
+}
+
 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
   if (getSourceVectorType() == getDestVectorType())
     return getSource();
index d16a6fa..7f50d90 100644 (file)
@@ -1627,3 +1627,17 @@ func.func @bitcast(%a: vector<4x8xf32>) -> vector<4x16xi16> {
   %1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
   return %1 : vector<4x16xi16>
 }
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+//  CHECK-SAME: (%[[ARG:.*]]: f32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<8x16xf32>
+func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
+  %splat0 = vector.splat %x : vector<4x4xf32>
+  %splat1 = vector.splat %x : vector<8x16xf32>
+  %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}