[mlir][Linalg] Retire LinalgStrategyDecomposePass and filter-based pattern.
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 7 Oct 2022 15:27:17 +0000 (08:27 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 7 Oct 2022 16:01:35 +0000 (09:01 -0700)
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785

Differential Revision: https://reviews.llvm.org/D135450

mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

index 342ef1c..2738838 100644 (file)
@@ -97,12 +97,6 @@ std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyPadPass(
     const linalg::LinalgTransformationFilter &filter =
         linalg::LinalgTransformationFilter());
 
-/// Create a LinalgStrategyDecomposePass.
-// TODO: if/when we need finer control add an `opName` parameter.
-std::unique_ptr<OperationPass<func::FuncOp>> createLinalgStrategyDecomposePass(
-    const linalg::LinalgTransformationFilter &filter =
-        linalg::LinalgTransformationFilter());
-
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 createLinalgStrategyRemoveMarkersPass();
index 43a6cad..9c0b414 100644 (file)
@@ -200,18 +200,6 @@ def LinalgStrategyPadPass
   ];
 }
 
-// TODO: if/when we need finer control add an anchorOp option.
-def LinalgStrategyDecomposePass
-    : Pass<"linalg-strategy-decompose-pass", "func::FuncOp"> {
-  let summary = "Configurable pass to apply pattern-based generalization.";
-  let constructor = "mlir::createLinalgStrategyDecomposePass()";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
-      "Which func op is the anchor to latch on.">,
-  ];
-}
-
 def LinalgStrategyRemoveMarkersPass
     : Pass<"linalg-strategy-remove-markers-pass", "func::FuncOp"> {
   let summary = "Cleanup pass that drops markers.";
index 6f80f41..b9b26bf 100644 (file)
@@ -81,17 +81,6 @@ private:
   linalg::LinalgPaddingOptions options;
 };
 
-/// Represent one application of createLinalgStrategyDecomposePass.
-struct Decompose : public Transformation {
-  explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr)
-      : Transformation(std::move(f)) {}
-
-  void addToPassPipeline(OpPassManager &pm,
-                         LinalgTransformationFilter m) const override {
-    pm.addPass(createLinalgStrategyDecomposePass(m));
-  }
-};
-
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
   /// Append a pattern to tile the Op `opName` and fuse its producers with
@@ -142,17 +131,6 @@ struct CodegenStrategy {
         LinalgTransformationFilter::FilterFunction f = nullptr) {
     return b ? pad(opName, std::move(options), std::move(f)) : *this;
   }
-  /// Append patterns to decompose convolutions.
-  CodegenStrategy &
-  decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) {
-    transformationSequence.emplace_back(std::make_unique<Decompose>(f));
-    return *this;
-  }
-  /// Conditionally append patterns to decompose convolutions.
-  CodegenStrategy &
-  decomposeIf(bool b, LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? decompose(std::move(f)) : *this;
-  }
   /// Configure the post staged-patterns global enabling passes options.
   CodegenStrategy &
   setVectorTransferToSCFOptions(LinalgEnablingOptions options) {
index 6d044ed..735ce0a 100644 (file)
@@ -764,11 +764,7 @@ private:
 template <typename Conv2DOp, typename Conv1DOp>
 struct DownscaleSizeOneWindowed2DConvolution final
     : public OpRewritePattern<Conv2DOp> {
-  DownscaleSizeOneWindowed2DConvolution(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<Conv2DOp>(context, benefit), filter(std::move(f)) {}
+  using OpRewritePattern<Conv2DOp>::OpRewritePattern;
 
   FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp,
                                                PatternRewriter &rewriter) const;
@@ -777,10 +773,6 @@ struct DownscaleSizeOneWindowed2DConvolution final
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
 };
 
 extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp,
@@ -792,12 +784,9 @@ extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
 /// dimensions into 1-D depthwise convolution ops.
 struct DownscaleDepthwiseConv2DNhwcHwcOp final
     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
-  DownscaleDepthwiseConv2DNhwcHwcOp(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
-        filter(std::move(f)) {}
+  DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context,
+                                    PatternBenefit benefit = 1)
+      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {}
 
   FailureOr<DepthwiseConv1DNwcWcOp>
   returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
@@ -807,10 +796,6 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(convOp, rewriter);
   }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
 };
 
 ///
@@ -1007,10 +992,8 @@ void populateLinalgNamedOpsGeneralizationPatterns(
 /// Populates patterns to decompose high-D convolution ops into low-D ones. This
 /// is a step in progressive lowering for convolution ops, afterwards we can
 /// vectorize the low-D convolution ops.
-void populateDecomposeConvolutionPatterns(
-    RewritePatternSet &patterns,
-    const LinalgTransformationFilter &filter = LinalgTransformationFilter(),
-    PatternBenefit benefit = 1);
+void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
 
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
index fd91b0a..ddf5d55 100644 (file)
@@ -155,31 +155,6 @@ struct LinalgStrategyPadPass
   LinalgTransformationFilter filter;
 };
 
-/// Configurable pass to apply lowering of coarser-grained named linalg ops into
-/// finer-grained named versions.
-struct LinalgStrategyDecomposePass
-    : public impl::LinalgStrategyDecomposePassBase<
-          LinalgStrategyDecomposePass> {
-
-  LinalgStrategyDecomposePass() = default;
-
-  LinalgStrategyDecomposePass(LinalgTransformationFilter filter)
-      : filter(std::move(filter)) {}
-
-  void runOnOperation() override {
-    auto funcOp = getOperation();
-    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
-      return;
-    RewritePatternSet decompositionPattern(funcOp.getContext());
-    populateDecomposeConvolutionPatterns(decompositionPattern, filter);
-    if (failed(applyPatternsAndFoldGreedily(funcOp,
-                                            std::move(decompositionPattern))))
-      signalPassFailure();
-  }
-
-  LinalgTransformationFilter filter;
-};
-
 /// Configurable pass to lower vector operations.
 struct LinalgStrategyRemoveMarkersPass
     : public impl::LinalgStrategyRemoveMarkersPassBase<
@@ -221,14 +196,6 @@ mlir::createLinalgStrategyPadPass(StringRef opName,
   return std::make_unique<LinalgStrategyPadPass>(opName, opt, filter);
 }
 
-/// Create a LinalgStrategyDecomposePass.
-// TODO: if/when we need finer control add an `opName` parameter.
-std::unique_ptr<OperationPass<func::FuncOp>>
-mlir::createLinalgStrategyDecomposePass(
-    const LinalgTransformationFilter &filter) {
-  return std::make_unique<LinalgStrategyDecomposePass>(filter);
-}
-
 /// Create a LinalgStrategyRemoveMarkersPass.
 std::unique_ptr<OperationPass<func::FuncOp>>
 mlir::createLinalgStrategyRemoveMarkersPass() {
index f6a68e4..9a4f9a8 100644 (file)
@@ -782,8 +782,6 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 template <typename Conv2DOp, typename Conv1DOp>
 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
     returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, convOp)))
-    return failure();
   if (convOp.hasBufferSemantics())
     return failure(); // To be implemented.
 
@@ -867,7 +865,6 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
       rewriter, loc, conv1DOp.getResult(0), output);
   rewriter.replaceOp(convOp, inserted);
 
-  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
   return conv1DOp;
 }
 
@@ -879,8 +876,6 @@ template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp,
 FailureOr<DepthwiseConv1DNwcWcOp>
 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
     DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, convOp)))
-    return failure();
   if (convOp.hasBufferSemantics())
     return failure(); // To be implemented.
 
@@ -943,17 +938,15 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
       rewriter, loc, conv1DOp.getResult(0), output);
   rewriter.replaceOp(convOp, inserted);
 
-  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
   return conv1DOp;
 }
 
-void linalg::populateDecomposeConvolutionPatterns(
-    RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
-    PatternBenefit benefit) {
+void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns,
+                                                  PatternBenefit benefit) {
   patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
                                                      Conv1DNwcWcfOp>,
                DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
                                                      Conv1DNcwFcwOp>,
-               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(), filter,
+               DownscaleDepthwiseConv2DNhwcHwcOp>(patterns.getContext(),
                                                   benefit);
 }