From ad1efb51914780937c704eb6e1ca9554feca08a7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 7 Oct 2022 08:27:17 -0700 Subject: [PATCH] [mlir][Linalg] Retire LinalgStrategyDecomposePass and filter-based pattern. Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785 Differential Revision: https://reviews.llvm.org/D135450 --- mlir/include/mlir/Dialect/Linalg/Passes.h | 6 ---- mlir/include/mlir/Dialect/Linalg/Passes.td | 12 -------- .../Dialect/Linalg/Transforms/CodegenStrategy.h | 22 --------------- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 29 ++++--------------- .../Linalg/Transforms/LinalgStrategyPasses.cpp | 33 ---------------------- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 13 ++------- 6 files changed, 9 insertions(+), 106 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index 342ef1c..2738838 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -97,12 +97,6 @@ std::unique_ptr> createLinalgStrategyPadPass( const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); -/// Create a LinalgStrategyDecomposePass. -// TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> createLinalgStrategyDecomposePass( - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyRemoveMarkersPass. std::unique_ptr> createLinalgStrategyRemoveMarkersPass(); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index 43a6cad..9c0b414 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -200,18 +200,6 @@ def LinalgStrategyPadPass ]; } -// TODO: if/when we need finer control add an anchorOp option. -def LinalgStrategyDecomposePass - : Pass<"linalg-strategy-decompose-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based generalization."; - let constructor = "mlir::createLinalgStrategyDecomposePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - def LinalgStrategyRemoveMarkersPass : Pass<"linalg-strategy-remove-markers-pass", "func::FuncOp"> { let summary = "Cleanup pass that drops markers."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h index 6f80f419..b9b26bf 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -81,17 +81,6 @@ private: linalg::LinalgPaddingOptions options; }; -/// Represent one application of createLinalgStrategyDecomposePass. -struct Decompose : public Transformation { - explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyDecomposePass(m)); - } -}; - /// Codegen strategy controls how a Linalg op is progressively lowered. struct CodegenStrategy { /// Append a pattern to tile the Op `opName` and fuse its producers with @@ -142,17 +131,6 @@ struct CodegenStrategy { LinalgTransformationFilter::FilterFunction f = nullptr) { return b ? pad(opName, std::move(options), std::move(f)) : *this; } - /// Append patterns to decompose convolutions. - CodegenStrategy & - decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back(std::make_unique(f)); - return *this; - } - /// Conditionally append patterns to decompose convolutions. - CodegenStrategy & - decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? decompose(std::move(f)) : *this; - } /// Configure the post staged-patterns global enabling passes options. CodegenStrategy & setVectorTransferToSCFOptions(LinalgEnablingOptions options) { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 6d044ed4..735ce0a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -764,11 +764,7 @@ private: template struct DownscaleSizeOneWindowed2DConvolution final : public OpRewritePattern { - DownscaleSizeOneWindowed2DConvolution( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), filter(std::move(f)) {} + using OpRewritePattern::OpRewritePattern; FailureOr returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const; @@ -777,10 +773,6 @@ struct DownscaleSizeOneWindowed2DConvolution final PatternRewriter &rewriter) const override { return returningMatchAndRewrite(convOp, rewriter); } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; }; extern template struct DownscaleSizeOneWindowed2DConvolution { - DownscaleDepthwiseConv2DNhwcHwcOp( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - filter(std::move(f)) {} + DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} FailureOr returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, @@ -807,10 +796,6 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final PatternRewriter &rewriter) const override { return returningMatchAndRewrite(convOp, rewriter); } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; }; /// @@ -1007,10 +992,8 @@ void populateLinalgNamedOpsGeneralizationPatterns( /// Populates patterns to decompose high-D convolution ops into low-D ones. This /// is a step in progressive lowering for convolution ops, afterwards we can /// vectorize the low-D convolution ops. -void populateDecomposeConvolutionPatterns( - RewritePatternSet &patterns, - const LinalgTransformationFilter &filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1); +void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); //===----------------------------------------------------------------------===// // Op-specific patterns. diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp index fd91b0a..ddf5d55 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -155,31 +155,6 @@ struct LinalgStrategyPadPass LinalgTransformationFilter filter; }; -/// Configurable pass to apply lowering of coarser-grained named linalg ops into -/// finer-grained named versions. -struct LinalgStrategyDecomposePass - : public impl::LinalgStrategyDecomposePassBase< - LinalgStrategyDecomposePass> { - - LinalgStrategyDecomposePass() = default; - - LinalgStrategyDecomposePass(LinalgTransformationFilter filter) - : filter(std::move(filter)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - RewritePatternSet decompositionPattern(funcOp.getContext()); - populateDecomposeConvolutionPatterns(decompositionPattern, filter); - if (failed(applyPatternsAndFoldGreedily(funcOp, - std::move(decompositionPattern)))) - signalPassFailure(); - } - - LinalgTransformationFilter filter; -}; - /// Configurable pass to lower vector operations. struct LinalgStrategyRemoveMarkersPass : public impl::LinalgStrategyRemoveMarkersPassBase< @@ -221,14 +196,6 @@ mlir::createLinalgStrategyPadPass(StringRef opName, return std::make_unique(opName, opt, filter); } -/// Create a LinalgStrategyDecomposePass. -// TODO: if/when we need finer control add an `opName` parameter. -std::unique_ptr> -mlir::createLinalgStrategyDecomposePass( - const LinalgTransformationFilter &filter) { - return std::make_unique(filter); -} - /// Create a LinalgStrategyRemoveMarkersPass. std::unique_ptr> mlir::createLinalgStrategyRemoveMarkersPass() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index f6a68e4..9a4f9a8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -782,8 +782,6 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( template FailureOr DownscaleSizeOneWindowed2DConvolution:: returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, convOp))) - return failure(); if (convOp.hasBufferSemantics()) return failure(); // To be implemented. @@ -867,7 +865,6 @@ FailureOr DownscaleSizeOneWindowed2DConvolution:: rewriter, loc, conv1DOp.getResult(0), output); rewriter.replaceOp(convOp, inserted); - filter.replaceLinalgTransformationFilter(rewriter, conv1DOp); return conv1DOp; } @@ -879,8 +876,6 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, convOp))) - return failure(); if (convOp.hasBufferSemantics()) return failure(); // To be implemented. @@ -943,17 +938,15 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( rewriter, loc, conv1DOp.getResult(0), output); rewriter.replaceOp(convOp, inserted); - filter.replaceLinalgTransformationFilter(rewriter, conv1DOp); return conv1DOp; } -void linalg::populateDecomposeConvolutionPatterns( - RewritePatternSet &patterns, const LinalgTransformationFilter &filter, - PatternBenefit benefit) { +void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, + PatternBenefit benefit) { patterns.add, DownscaleSizeOneWindowed2DConvolution, - DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter, + DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), benefit); } -- 2.7.4