}
};
-/// Collect a set of transformation patterns that are related to contracting
-/// or expanding vector operations:
-/// ContractionOpLowering,
-/// ShapeCastOp2DDownCastRewritePattern,
-/// ShapeCastOp2DUpCastRewritePattern
-/// BroadcastOpLowering,
-/// OuterproductOpLowering
-/// These transformation express higher level vector ops in terms of more
-/// elementary extraction, insertion, reduction, product, and broadcast ops.
+/// Collects patterns to progressively lower vector.broadcast ops on high-D
+/// vectors to low-D vector ops.
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector contraction ops on high-D
+/// into low-D reduction and product ops.
void populateVectorContractLoweringPatterns(
RewritePatternSet &patterns,
- VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+ VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collects patterns to progressively lower vector mask ops into elementary
+/// selection and insertion ops.
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector.shape_cast ops on high-D
+/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
+/// ops.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
/// Insert TransposeLowering patterns into extraction/insertion.
void populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
- VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+ VectorTransformsOptions options = VectorTransformsOptions());
/// Returns the integer type required for subscripts in the vector dialect.
IntegerType getVectorSubscriptType(Builder &builder);
{
RewritePatternSet patterns(&getContext());
populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns);
+ populateVectorMaskOpLoweringPatterns(patterns);
+ populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}
+void mlir::vector::populateVectorBroadcastLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<BroadcastOpLowering>(patterns.getContext());
+}
+
+void mlir::vector::populateVectorMaskOpLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
+ patterns.getContext());
+}
+
+void mlir::vector::populateVectorShapeCastLoweringPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<ShapeCastOp2DDownCastRewritePattern,
+ ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
+ patterns.getContext());
+}
+
void mlir::vector::populateVectorContractLoweringPatterns(
- RewritePatternSet &patterns, VectorTransformsOptions parameters) {
- // clang-format off
- patterns.add<BroadcastOpLowering,
- CreateMaskOpLowering,
- ConstantMaskOpLowering,
- OuterProductOpLowering,
- ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern,
- ShapeCastOpRewritePattern>(patterns.getContext());
- patterns.add<ContractionOpLowering,
- ContractionOpToMatmulOpLowering,
- ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
- // clang-format on
+ RewritePatternSet &patterns, VectorTransformsOptions options) {
+ patterns.add<OuterProductOpLowering>(patterns.getContext());
+ patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
+ ContractionOpToOuterProductOpLowering>(options,
+ patterns.getContext());
}
void mlir::vector::populateVectorTransposeLoweringPatterns(
- RewritePatternSet &patterns,
- VectorTransformsOptions vectorTransformOptions) {
- patterns.add<TransposeOpLowering>(vectorTransformOptions,
- patterns.getContext());
+ RewritePatternSet &patterns, VectorTransformsOptions options) {
+ patterns.add<TransposeOpLowering>(options, patterns.getContext());
}
void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
// Programmatic controlled lowering of vector.contract only.
RewritePatternSet vectorContractLoweringPatterns(context);
+ populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformOptions);
+ populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns);
+ populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns);
populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
vectorTransformOptions);
(void)applyPatternsAndFoldGreedily(module,
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
VectorTransformsOptions options{contractLowering, transposeLowering};
+ populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
+ populateVectorMaskOpLoweringPatterns(patterns);
+ populateVectorShapeCastLoweringPatterns(patterns);
populateVectorTransposeLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}