[mlir][Linalg] Retire LinalgStrategyLowerVectorsPass and filter-based patterns
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Tue, 4 Oct 2022 22:42:41 +0000 (15:42 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 5 Oct 2022 07:55:27 +0000 (00:55 -0700)
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785

Depends on D135200

Differential Revision: https://reviews.llvm.org/D135222

mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

index 13db77e..342ef1c 100644 (file)
@@ -103,14 +103,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyDecomposePass(
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());
 
-/// Create a LinalgStrategyLowerVectorsPass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-createLinalgStrategyLowerVectorsPass(
-    linalg::LinalgVectorLoweringOptions opt =
-        linalg::LinalgVectorLoweringOptions(),
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter());
-
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyRemoveMarkersPass();
index 4533110..43a6cad 100644 (file)
@@ -212,17 +212,6 @@ def LinalgStrategyDecomposePass
   ];
 }
 
-def LinalgStrategyLowerVectorsPass
-    : Pass<"linalg-strategy-lower-vectors-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to lower vector operations.";
-  let constructor = "mlir::createLinalgStrategyLowerVectorsPass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-  ];
-}
-
 def LinalgStrategyRemoveMarkersPass
     : Pass<"linalg-strategy-remove-markers-pass", "func::FuncOp"> {
   let summary = "Cleanup pass that drops markers.";
index f8f89ef..6f80f41 100644 (file)
@@ -92,22 +92,6 @@ struct Decompose : public Transformation {
   }
 };
 
-/// Represent one application of createLinalgStrategyLowerVectorsPass.
-struct VectorLowering : public Transformation {
-  explicit VectorLowering(
-      linalg::LinalgVectorLoweringOptions options,
-      LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(std::move(f)), options(options) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyLowerVectorsPass(options, m));
-  }
-
-private:
-  linalg::LinalgVectorLoweringOptions options;
-};
-
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
   /// Append a pattern to tile the Op `opName` and fuse its producers with
@@ -169,12 +153,6 @@ struct CodegenStrategy {
   decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? decompose(std::move(f)) : *this;
   }
-  /// Append a pattern to lower all vector operations.
-  CodegenStrategy &vectorLowering(LinalgVectorLoweringOptions options) {
-    transformationSequence.emplace_back(
-        std::make_unique<VectorLowering>(options));
-    return *this;
-  }
   /// Configure the post staged-patterns global enabling passes options.
   CodegenStrategy &
   setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
index 8b441b2..727c8d4 100644 (file)
@@ -927,96 +927,6 @@ struct LinalgEnablingOptions {
   }
 };
 
-/// Vector lowering options control how ops are lowered down to 1-D and scf.for
-/// form.
-struct LinalgVectorLoweringOptions {
-  /// Enable lowering of vector.contract.
-  /// In a progressive lowering of vectors, this would be the 1st step.
-  bool contractionLowering = false;
-  LinalgVectorLoweringOptions &enableContractionLowering(bool val = true) {
-    contractionLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.multi_reduce.
-  /// In a progressive lowering of vectors, this would be the 2nd step.
-  bool multiReductionLowering = false;
-  LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) {
-    multiReductionLowering = val;
-    return *this;
-  }
-  /// Trigger full / partial vector.transfer splits.
-  /// In a progressive lowering of vectors, this would be the 3rd step.
-  bool transferPartialRewrite = false;
-  LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) {
-    transferPartialRewrite = val;
-    return *this;
-  }
-  /// Enable lowering of vector.transfer to scf.
-  /// In a progressive lowering of vectors, this would be the 4th step.
-  bool transferToSCFConversion = false;
-  LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) {
-    transferToSCFConversion = val;
-    return *this;
-  }
-  /// Maximal transfer rank under which we do not lower further.
-  int64_t maxTransferRank = 1;
-  LinalgVectorLoweringOptions &setMaxTransferRank(int64_t val) {
-    maxTransferRank = val;
-    return *this;
-  }
-  /// Vector lowering operations may result in surprising behavior when
-  /// composing multiple codegen strategies and must be enabled explicitly.
-  /// In a progressive lowering of vectors, this would be the 5th step.
-  bool transferLowering = true;
-  LinalgVectorLoweringOptions &enableTransferLowering(bool val = true) {
-    transferLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.shape_cast to insert/extract.
-  /// In a progressive lowering of vectors, this would be the 6th step.
-  bool shapeCastLowering = true;
-  LinalgVectorLoweringOptions &enableShapeCastLowering(bool val = true) {
-    shapeCastLowering = val;
-    return *this;
-  }
-  /// Enable lowering of vector.transpose.
-  /// In a progressive lowering of vectors, this would be the 7th step.
-  bool transposeLowering = false;
-  LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) {
-    transposeLowering = val;
-    return *this;
-  }
-  /// Enable AVX2-specific lowerings.
-  bool avx2Lowering = false;
-  LinalgVectorLoweringOptions &enableAVX2Lowering(bool val = true) {
-    avx2Lowering = val;
-    return *this;
-  }
-
-  /// Configure the post staged-patterns late vector.transfer to scf
-  /// conversion.
-  VectorTransferToSCFOptions vectorTransferToSCFOptions;
-  LinalgVectorLoweringOptions &
-  setVectorTransferToSCFOptions(VectorTransferToSCFOptions options) {
-    vectorTransferToSCFOptions = options;
-    return *this;
-  }
-  /// Configure late vector transformations.
-  vector::VectorTransformsOptions vectorTransformOptions;
-  LinalgVectorLoweringOptions &
-  setVectorTransformsOptions(vector::VectorTransformsOptions options) {
-    vectorTransformOptions = options;
-    return *this;
-  }
-  /// Configure specialized vector lowerings.
-  x86vector::avx2::LoweringOptions avx2LoweringOptions;
-  LinalgVectorLoweringOptions &
-  setAVX2LoweringOptions(x86vector::avx2::LoweringOptions options) {
-    avx2LoweringOptions = options;
-    return *this;
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
index 30609e5..fd91b0a 100644 (file)
@@ -181,71 +181,6 @@ struct LinalgStrategyDecomposePass
 };
 
 /// Configurable pass to lower vector operations.
-struct LinalgStrategyLowerVectorsPass
-    : public impl::LinalgStrategyLowerVectorsPassBase<
-          LinalgStrategyLowerVectorsPass> {
-
-  LinalgStrategyLowerVectorsPass(LinalgVectorLoweringOptions opt,
-                                 LinalgTransformationFilter filt)
-      : options(opt), filter(std::move(filt)) {}
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-
-    MLIRContext *context = funcOp.getContext();
-    RewritePatternSet patterns(context);
-    vector::populateVectorToVectorCanonicalizationPatterns(patterns);
-    // In a progressive lowering of vectors, this would be the 1st step.
-    if (options.contractionLowering) {
-      patterns.add<ContractionOpToOuterProductOpLowering,
-                   ContractionOpToMatmulOpLowering, ContractionOpLowering>(
-          options.vectorTransformOptions, context);
-      vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-    }
-    // In a progressive lowering of vectors, this would be the 2nd step.
-    if (options.multiReductionLowering) {
-      vector::populateVectorMultiReductionLoweringPatterns(
-          patterns,
-          options.vectorTransformOptions.vectorMultiReductionLowering);
-    }
-    // In a progressive lowering of vectors, this would be the 3rd step.
-    if (options.transferPartialRewrite) {
-      patterns.add<vector::VectorTransferFullPartialRewriter>(
-          context, options.vectorTransformOptions);
-    }
-    // In a progressive lowering of vectors, this would be the 4th step.
-    if (options.transferLowering) {
-      vector::populateVectorTransferLoweringPatterns(patterns,
-                                                     options.maxTransferRank);
-    }
-    // In a progressive lowering of vectors, this would be the 5th step.
-    if (options.transferToSCFConversion) {
-      populateVectorToSCFConversionPatterns(
-          patterns, options.vectorTransferToSCFOptions.setTargetRank(
-                        options.maxTransferRank));
-    }
-    // In a progressive lowering of vectors, this would be the 6th step.
-    if (options.shapeCastLowering) {
-      vector::populateVectorShapeCastLoweringPatterns(patterns);
-    }
-    // In a progressive lowering of vectors, this would be the 7th step.
-    if (options.transposeLowering) {
-      vector::populateVectorTransposeLoweringPatterns(
-          patterns, options.vectorTransformOptions);
-      if (options.avx2Lowering)
-        x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
-            patterns, options.avx2LoweringOptions, /*benefit=*/10);
-    }
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
-  }
-
-  LinalgVectorLoweringOptions options;
-  LinalgTransformationFilter filter;
-};
-
-/// Configurable pass to lower vector operations.
 struct LinalgStrategyRemoveMarkersPass
     : public impl::LinalgStrategyRemoveMarkersPassBase<
           LinalgStrategyRemoveMarkersPass> {
@@ -294,13 +229,6 @@ mlir::createLinalgStrategyDecomposePass(
   return std::make_unique<LinalgStrategyDecomposePass>(filter);
 }
 
-/// Create a LinalgStrategyLowerVectorsPass.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyLowerVectorsPass(
-    LinalgVectorLoweringOptions opt, const LinalgTransformationFilter &filter) {
-  return std::make_unique<LinalgStrategyLowerVectorsPass>(opt, filter);
-}
-
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyRemoveMarkersPass() {
index 3426bbb..5547a96 100644 (file)
@@ -235,39 +235,40 @@ struct TestVectorTransposeLowering
   }
 
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
+    func::FuncOp funcOp = getOperation();
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
 
-    // Test on one pattern in isolation.
-    // Explicitly disable shape_cast lowering.
-    LinalgVectorLoweringOptions options = LinalgVectorLoweringOptions()
-                                              .enableVectorTransposeLowering()
-                                              .enableShapeCastLowering(false);
+    vector::VectorTransformsOptions vectorTransformOptions;
     if (lowerToEltwise) {
-      options = options.setVectorTransformsOptions(
-          VectorTransformsOptions().setVectorTransposeLowering(
-              VectorTransposeLowering::EltWise));
+      vectorTransformOptions =
+          vectorTransformOptions.setVectorTransposeLowering(
+              VectorTransposeLowering::EltWise);
     }
     if (lowerToFlatTranspose) {
-      options = options.setVectorTransformsOptions(
-          VectorTransformsOptions().setVectorTransposeLowering(
-              VectorTransposeLowering::Flat));
+      vectorTransformOptions =
+          vectorTransformOptions.setVectorTransposeLowering(
+              VectorTransposeLowering::Flat);
     }
     if (lowerToShuffleTranspose) {
-      options = options.setVectorTransformsOptions(
-          VectorTransformsOptions().setVectorTransposeLowering(
-              VectorTransposeLowering::Shuffle));
+      vectorTransformOptions =
+          vectorTransformOptions.setVectorTransposeLowering(
+              VectorTransposeLowering::Shuffle);
     }
+    vector::populateVectorTransposeLoweringPatterns(patterns,
+                                                    vectorTransformOptions);
+
     if (lowerToAvx2) {
-      options = options.enableAVX2Lowering().setAVX2LoweringOptions(
+      auto avx2LoweringOptions =
           x86vector::avx2::LoweringOptions().setTransposeOptions(
               x86vector::avx2::TransposeLoweringOptions()
                   .lower4x8xf32()
-                  .lower8x8xf32()));
+                  .lower8x8xf32());
+      x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
+          patterns, avx2LoweringOptions, /*benefit=*/10);
     }
 
-    OpPassManager dynamicPM("func.func");
-    dynamicPM.addPass(createLinalgStrategyLowerVectorsPass(options));
-    if (failed(runPipeline(dynamicPM, getOperation())))
+    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
       return signalPassFailure();
   }
 };