[mlir][vector] Split populateVectorContractLoweringPatterns
authorLei Zhang <antiagainst@google.com>
Thu, 7 Oct 2021 13:33:51 +0000 (09:33 -0400)
committerLei Zhang <antiagainst@google.com>
Thu, 7 Oct 2021 13:39:26 +0000 (09:39 -0400)
It was bundling quite a lot of patterns that convert high-D
vector ops into low-D elementary ops. It might not be good
for all of the patterns to happen for a particular downstream
user. For example, `ShapeCastOpRewritePattern` rewrites
`vector.shape_cast` into data movement extract/insert ops.

Instead, split the entry point into multiple ones so users
can pull in patterns on demand.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D111225

mlir/include/mlir/Dialect/Vector/VectorOps.h
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 9bc2cd4..a98ca36 100644 (file)
@@ -159,23 +159,29 @@ struct VectorTransformsOptions {
   }
 };
 
-/// Collect a set of transformation patterns that are related to contracting
-/// or expanding vector operations:
-///   ContractionOpLowering,
-///   ShapeCastOp2DDownCastRewritePattern,
-///   ShapeCastOp2DUpCastRewritePattern
-///   BroadcastOpLowering,
-///   OuterproductOpLowering
-/// These transformation express higher level vector ops in terms of more
-/// elementary extraction, insertion, reduction, product, and broadcast ops.
+/// Collects patterns to progressively lower vector.broadcast ops on high-D
+/// vectors to low-D vector ops.
+void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector contraction ops on high-D
+/// into low-D reduction and product ops.
 void populateVectorContractLoweringPatterns(
     RewritePatternSet &patterns,
-    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+    VectorTransformsOptions options = VectorTransformsOptions());
+
+/// Collects patterns to progressively lower vector mask ops into elementary
+/// selection and insertion ops.
+void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns);
+
+/// Collects patterns to progressively lower vector.shape_cast ops on high-D
+/// vectors into 1-D/2-D vector ops by generating data movement extract/insert
+/// ops.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns);
 
 /// Insert TransposeLowering patterns into extraction/insertion.
 void populateVectorTransposeLoweringPatterns(
     RewritePatternSet &patterns,
-    VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
+    VectorTransformsOptions options = VectorTransformsOptions());
 
 /// Returns the integer type required for subscripts in the vector dialect.
 IntegerType getVectorSubscriptType(Builder &builder);
index 1a708dc..d920bb7 100644 (file)
@@ -62,7 +62,10 @@ void LowerVectorToLLVMPass::runOnOperation() {
   {
     RewritePatternSet patterns(&getContext());
     populateVectorToVectorCanonicalizationPatterns(patterns);
+    populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns);
+    populateVectorMaskOpLoweringPatterns(patterns);
+    populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns);
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
     populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
index 8f1d9fc..999f37f 100644 (file)
@@ -3847,27 +3847,35 @@ void mlir::vector::populateBubbleVectorBitCastOpPatterns(
                BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
 }
 
+void mlir::vector::populateVectorBroadcastLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BroadcastOpLowering>(patterns.getContext());
+}
+
+void mlir::vector::populateVectorMaskOpLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<CreateMaskOpLowering, ConstantMaskOpLowering>(
+      patterns.getContext());
+}
+
+void mlir::vector::populateVectorShapeCastLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ShapeCastOp2DDownCastRewritePattern,
+               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
+      patterns.getContext());
+}
+
 void mlir::vector::populateVectorContractLoweringPatterns(
-    RewritePatternSet &patterns, VectorTransformsOptions parameters) {
-  // clang-format off
-  patterns.add<BroadcastOpLowering,
-                  CreateMaskOpLowering,
-                  ConstantMaskOpLowering,
-                  OuterProductOpLowering,
-                  ShapeCastOp2DDownCastRewritePattern,
-                  ShapeCastOp2DUpCastRewritePattern,
-                  ShapeCastOpRewritePattern>(patterns.getContext());
-  patterns.add<ContractionOpLowering,
-                  ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
-  // clang-format on
+    RewritePatternSet &patterns, VectorTransformsOptions options) {
+  patterns.add<OuterProductOpLowering>(patterns.getContext());
+  patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
+               ContractionOpToOuterProductOpLowering>(options,
+                                                      patterns.getContext());
 }
 
 void mlir::vector::populateVectorTransposeLoweringPatterns(
-    RewritePatternSet &patterns,
-    VectorTransformsOptions vectorTransformOptions) {
-  patterns.add<TransposeOpLowering>(vectorTransformOptions,
-                                    patterns.getContext());
+    RewritePatternSet &patterns, VectorTransformsOptions options) {
+  patterns.add<TransposeOpLowering>(options, patterns.getContext());
 }
 
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
index bf5cf5a..45c985b 100644 (file)
@@ -112,8 +112,11 @@ void TestConvVectorization::runOnOperation() {
 
   // Programmatic controlled lowering of vector.contract only.
   RewritePatternSet vectorContractLoweringPatterns(context);
+  populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns);
   populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
                                          vectorTransformOptions);
+  populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns);
+  populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns);
   populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns,
                                           vectorTransformOptions);
   (void)applyPatternsAndFoldGreedily(module,
index 907f9ae..d4182f5 100644 (file)
@@ -164,7 +164,10 @@ struct TestVectorContractionConversion
     if (lowerToFlatTranspose)
       transposeLowering = VectorTransposeLowering::Flat;
     VectorTransformsOptions options{contractLowering, transposeLowering};
+    populateVectorBroadcastLoweringPatterns(patterns);
     populateVectorContractLoweringPatterns(patterns, options);
+    populateVectorMaskOpLoweringPatterns(patterns);
+    populateVectorShapeCastLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }