From 92ea624a1345fc9f0512bab2bd5d0d1ebeb8cf21 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 29 Sep 2021 09:36:32 +0000 Subject: [PATCH] [mlir][Linalg] Rewrite CodegenStrategy to populate a pass pipeline. This revision retires a good portion of the complexity of the codegen strategy and puts the logic behind pass logic. Differential revision: https://reviews.llvm.org/D110678 --- mlir/include/mlir/Dialect/Linalg/Passes.h | 38 +++ mlir/include/mlir/Dialect/Linalg/Passes.td | 62 +++++ .../Dialect/Linalg/Transforms/CodegenStrategy.h | 227 ++++-------------- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 74 +++++- .../include/mlir/Dialect/Vector/VectorTransforms.h | 25 +- mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt | 1 + .../Dialect/Linalg/Transforms/CodegenStrategy.cpp | 103 +++------ .../Linalg/Transforms/LinalgStrategyPasses.cpp | 256 +++++++++++++++++++++ mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 6 + mlir/lib/Dialect/Vector/VectorTransforms.cpp | 28 ++- mlir/test/Dialect/Linalg/codegen-strategy.mlir | 10 +- .../Dialect/Linalg/CPU/benchmark_matmul.mlir | 4 +- .../lib/Dialect/Linalg/TestConvVectorization.cpp | 8 +- .../Dialect/Linalg/TestLinalgCodegenStrategy.cpp | 70 ++---- 14 files changed, 558 insertions(+), 354 deletions(-) create mode 100644 mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 56c709b..867921c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_LINALG_PASSES_H_ #define MLIR_DIALECT_LINALG_PASSES_H_ +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Pass/Pass.h" @@ -78,6 +79,43 @@ std::unique_ptr createLinalgDetensorizePass(); std::unique_ptr> createLinalgTileAndFuseTensorOpsPass(); //===----------------------------------------------------------------------===// +/// Linalg strategy passes. +//===----------------------------------------------------------------------===// +/// Create a LinalgStrategyTilePass. +std::unique_ptr> createLinalgStrategyTilePass( + StringRef opName = "", + linalg::LinalgTilingOptions opt = linalg::LinalgTilingOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyPromotePass. +std::unique_ptr> createLinalgStrategyPromotePass( + StringRef opName = "", + linalg::LinalgPromotionOptions opt = linalg::LinalgPromotionOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyVectorizePass. +std::unique_ptr> +createLinalgStrategyVectorizePass(StringRef opName = "", + linalg::LinalgVectorizationOptions opt = + linalg::LinalgVectorizationOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyEnablePass. +std::unique_ptr> createLinalgStrategyEnablePass( + linalg::LinalgEnablingOptions opt = linalg::LinalgEnablingOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); + +/// Create a LinalgStrategyLowerVectorsPass. +std::unique_ptr> +createLinalgStrategyLowerVectorsPass(linalg::LinalgVectorLoweringOptions opt = + linalg::LinalgVectorLoweringOptions(), + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter()); +//===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 3f331b1..32327cd 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -229,4 +229,66 @@ def LinalgTileAndFuseTensorOps let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"]; } +def LinalgStrategyTilePass + : FunctionPass<"linalg-strategy-tile-pass"> { + let summary = "Configurable pass to apply pattern-based linalg tiling."; + let constructor = "mlir::createLinalgStrategyTilePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyPromotePass + : FunctionPass<"linalg-strategy-promote-pass"> { + let summary = "Configurable pass to apply pattern-based linalg promotion."; + let constructor = "mlir::createLinalgStrategyPromotePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyVectorizePass + : FunctionPass<"linalg-strategy-vectorize-pass"> { + let summary = "Configurable pass to apply pattern-based linalg vectorization."; + let constructor = "mlir::createLinalgStrategyVectorizePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"", + "Which linalg op within the func is the anchor to latch on.">, + ]; +} + +def LinalgStrategyEnablePass + : FunctionPass<"linalg-strategy-enable-pass"> { + let summary = "Configurable pass to enable the application of other " + "pattern-based linalg passes."; + let constructor = "mlir::createLinalgStrategyEnablePass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + +def LinalgStrategyLowerVectorsPass + : FunctionPass<"linalg-strategy-lower-vectors-pass"> { + let summary = "Configurable pass to lower vector operations."; + let constructor = "mlir::createLinalgStrategyLowerVectorsPass()"; + let dependentDialects = ["linalg::LinalgDialect"]; + let options = [ + Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", + "Which func op is the anchor to latch on.">, + ]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h index d33dd81..ff37207 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -10,7 +10,8 @@ #define MLIR_DIALECT_LINALG_TRANSFORMS_CODEGENSTRATEGY_H_ #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Pass/PassManager.h" namespace mlir { @@ -21,69 +22,23 @@ namespace linalg { /// Abstract Transformation class applied in a sequence that also handles state /// through markers. struct Transformation { - explicit Transformation(linalg::LinalgTransformationFilter::FilterFunction f) + explicit Transformation(LinalgTransformationFilter::FilterFunction f) : filter(f) {} virtual ~Transformation() = default; - virtual RewritePatternSet - buildRewritePatterns(MLIRContext *context, - linalg::LinalgTransformationFilter m) = 0; - linalg::LinalgTransformationFilter::FilterFunction filter = nullptr; + virtual void addToPassPipeline(OpPassManager &pm, + LinalgTransformationFilter m) const = 0; + LinalgTransformationFilter::FilterFunction filter = nullptr; }; -/// SFINAE: Enqueue helper for ConcreteOpType that have a `getOperationName`. -template