return success();
}
+namespace {
+/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
+/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+class FoldInsertStridedSliceSplat final
+ : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+ using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ auto srcSplatOp =
+ insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+ auto destSplatOp =
+ insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
+
+ if (!srcSplatOp || !destSplatOp)
+ return failure();
+
+ if (srcSplatOp.getInput() != destSplatOp.getInput())
+ return failure();
+
+ rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+ return success();
+ }
+};
+} // namespace
+
+void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
+ RewritePatternSet &results, MLIRContext *context) {
+ results.add<FoldInsertStridedSliceSplat>(context);
+}
+
OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
if (getSourceVectorType() == getDestVectorType())
return getSource();
%1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
return %1 : vector<4x16xi16>
}
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
+func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
+ %splat0 = vector.splat %x : vector<4x4xf32>
+ %splat1 = vector.splat %x : vector<8x16xf32>
+ %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+ : vector<4x4xf32> into vector<8x16xf32>
+ return %0 : vector<8x16xf32>
+}