}
};
+/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
+class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
+public:
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertOp op,
+ PatternRewriter &rewriter) const override {
+ auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+ auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
+
+ if (!srcSplat || !dstSplat)
+ return failure();
+
+ if (srcSplat.getInput() != dstSplat.getInput())
+ return failure();
+
+ rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+ return success();
+ }
+};
+
} // namespace
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<InsertToBroadcast, BroadcastFolder>(context);
+ results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
}
// Eliminates insert operations that produce values identical to their source
return %shuffle : vector<4xi32>
}
+
+// -----
+
+// CHECK-LABEL: func @insert_splat
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+ %v0 = vector.splat %x : vector<4x3xi32>
+ %v1 = vector.splat %x : vector<2x4x3xi32>
+ %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+ return %insert : vector<2x4x3xi32>
+}