[mlir][Vector] Fold InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).
authorjacquesguan <Jianjian.Guan@streamcomputing.com>
Fri, 1 Jul 2022 08:02:28 +0000 (16:02 +0800)
committerjacquesguan <Jianjian.Guan@streamcomputing.com>
Wed, 6 Jul 2022 03:27:23 +0000 (11:27 +0800)
This patch folds InsertOp(SplatOp(X), SplatOp(X)) to SplatOp(X).

Differential Revision: https://reviews.llvm.org/D129058

mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir

index a387479..00db465 100644 (file)
@@ -2031,11 +2031,32 @@ public:
   }
 };
 
+/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
+class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
+
+    if (!srcSplat || !dstSplat)
+      return failure();
+
+    if (srcSplat.getInput() != dstSplat.getInput())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 // Eliminates insert operations that produce values identical to their source
index e7747c7..84b5a45 100644 (file)
@@ -1669,3 +1669,16 @@ func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
   return %shuffle : vector<4xi32>
 }
 
+
+// -----
+
+// CHECK-LABEL: func @insert_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: i32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.splat %x : vector<4x3xi32>
+  %v1 = vector.splat %x : vector<2x4x3xi32>
+  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+  return %insert : vector<2x4x3xi32>
+}