From e625aea33a653d23d83aab8ea30e6bf7dd0b6b51 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Tue, 23 Aug 2022 09:28:16 +0200 Subject: [PATCH] [mlir][Linalg] Retire Linalg generic interchange pattern and pass This revision removes the Linalg generic interchange pattern and pass. It also changes transform-patterns test to make use of transform dialect. Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785 Differential Revision: https://reviews.llvm.org/D132368 --- mlir/include/mlir/Dialect/Linalg/Passes.h | 7 ---- mlir/include/mlir/Dialect/Linalg/Passes.td | 11 ------ .../Dialect/Linalg/Transforms/CodegenStrategy.h | 31 ----------------- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 32 ----------------- .../Linalg/Transforms/LinalgStrategyPasses.cpp | 40 ---------------------- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 23 ------------- mlir/test/Dialect/Linalg/transform-patterns.mlir | 10 +++++- .../lib/Dialect/Linalg/TestLinalgTransforms.cpp | 9 ----- 8 files changed, 9 insertions(+), 154 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h index ec7db84..ecf684c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -108,13 +108,6 @@ std::unique_ptr> createLinalgStrategyDecomposePass( const linalg::LinalgTransformationFilter &filter = linalg::LinalgTransformationFilter()); -/// Create a LinalgStrategyInterchangePass. -std::unique_ptr> -createLinalgStrategyInterchangePass( - ArrayRef iteratorInterchange = {}, - const linalg::LinalgTransformationFilter &filter = - linalg::LinalgTransformationFilter()); - /// Create a LinalgStrategyPeelPass. std::unique_ptr> createLinalgStrategyPeelPass( StringRef opName = "", diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td index c497135..85cbdd8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -235,17 +235,6 @@ def LinalgStrategyDecomposePass ]; } -def LinalgStrategyInterchangePass - : Pass<"linalg-strategy-interchange-pass", "func::FuncOp"> { - let summary = "Configurable pass to apply pattern-based iterator interchange."; - let constructor = "mlir::createLinalgStrategyInterchangePass()"; - let dependentDialects = ["linalg::LinalgDialect"]; - let options = [ - Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"", - "Which func op is the anchor to latch on.">, - ]; -} - def LinalgStrategyPeelPass : Pass<"linalg-strategy-peel-pass", "func::FuncOp"> { let summary = "Configurable pass to apply pattern-based linalg peeling."; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h index 82a4948..d28f1cc 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -96,23 +96,6 @@ private: std::string opName; }; -/// Represent one application of createLinalgStrategyInterchangePass. -struct Interchange : public Transformation { - explicit Interchange(ArrayRef iteratorInterchange, - LinalgTransformationFilter::FilterFunction f = nullptr) - : Transformation(std::move(f)), - iteratorInterchange(iteratorInterchange.begin(), - iteratorInterchange.end()) {} - - void addToPassPipeline(OpPassManager &pm, - LinalgTransformationFilter m) const override { - pm.addPass(createLinalgStrategyInterchangePass(iteratorInterchange, m)); - } - -private: - SmallVector iteratorInterchange; -}; - /// Represent one application of createLinalgStrategyDecomposePass. struct Decompose : public Transformation { explicit Decompose(LinalgTransformationFilter::FilterFunction f = nullptr) @@ -250,20 +233,6 @@ struct CodegenStrategy { LinalgTransformationFilter::FilterFunction f = nullptr) { return b ? generalize(opName, std::move(f)) : *this; } - /// Append a pattern to interchange iterators. - CodegenStrategy & - interchange(ArrayRef iteratorInterchange, - const LinalgTransformationFilter::FilterFunction &f = nullptr) { - transformationSequence.emplace_back( - std::make_unique(iteratorInterchange, f)); - return *this; - } - /// Conditionally append a pattern to interchange iterators. - CodegenStrategy & - interchangeIf(bool b, ArrayRef iteratorInterchange, - LinalgTransformationFilter::FilterFunction f = nullptr) { - return b ? interchange(iteratorInterchange, std::move(f)) : *this; - } /// Append patterns to decompose convolutions. CodegenStrategy & decompose(const LinalgTransformationFilter::FilterFunction &f = nullptr) { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 5a4c41a..8f53d5a 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -829,38 +829,6 @@ private: }; /// -/// Linalg generic interchange pattern. -/// -/// Apply the `interchange` transformation on a RewriterBase. -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `interchange` for more details. -struct GenericOpInterchangePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - /// GenericOp-specific constructor with an optional `filter`. - GenericOpInterchangePattern( - MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(GenericOp op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; - /// The interchange vector to reorder the iterators and indexing_maps dims. - SmallVector interchangeVector; -}; - -/// /// Linalg generalization pattern. /// /// Apply the `generalization` transformation as a pattern. diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp index ee0846f..22b97ff 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -199,37 +199,6 @@ struct LinalgStrategyDecomposePass LinalgTransformationFilter filter; }; -/// Configurable pass to apply pattern-based linalg generalization. -struct LinalgStrategyInterchangePass - : public LinalgStrategyInterchangePassBase { - - LinalgStrategyInterchangePass() = default; - - LinalgStrategyInterchangePass(ArrayRef iteratorInterchange, - LinalgTransformationFilter filter) - : iteratorInterchange(iteratorInterchange.begin(), - iteratorInterchange.end()), - filter(std::move(filter)) {} - - void runOnOperation() override { - auto funcOp = getOperation(); - if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) - return; - - SmallVector interchangeVector(iteratorInterchange.begin(), - iteratorInterchange.end()); - RewritePatternSet interchangePattern(funcOp.getContext()); - interchangePattern.add( - funcOp.getContext(), interchangeVector, filter); - if (failed(applyPatternsAndFoldGreedily(funcOp, - std::move(interchangePattern)))) - signalPassFailure(); - } - - SmallVector iteratorInterchange; - LinalgTransformationFilter filter; -}; - /// Configurable pass to apply pattern-based linalg peeling. struct LinalgStrategyPeelPass : public LinalgStrategyPeelPassBase { @@ -491,15 +460,6 @@ mlir::createLinalgStrategyDecomposePass( return std::make_unique(filter); } -/// Create a LinalgStrategyInterchangePass. -std::unique_ptr> -mlir::createLinalgStrategyInterchangePass( - ArrayRef iteratorInterchange, - const LinalgTransformationFilter &filter) { - return std::make_unique(iteratorInterchange, - filter); -} - /// Create a LinalgStrategyPeelPass. std::unique_ptr> mlir::createLinalgStrategyPeelPass(StringRef opName, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 11152f1..2fcbe68 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -537,29 +537,6 @@ mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite( return tileLoopNest; } -/// Linalg generic interchange pattern. -mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( - MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter f, PatternBenefit benefit) - : OpRewritePattern(context, benefit), filter(std::move(f)), - interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} - -FailureOr -mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite( - GenericOp genericOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, genericOp))) - return failure(); - - FailureOr transformedOp = - interchangeGenericOp(rewriter, genericOp, interchangeVector); - if (failed(transformedOp)) - return failure(); - - // New filter if specified. - filter.replaceLinalgTransformationFilter(rewriter, genericOp); - return transformedOp; -} - /// Linalg generalization pattern. mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir index 3a704e4..e7053c3 100644 --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns=test-patterns -split-input-file -test-transform-dialect-interpreter | FileCheck %s // CHECK-DAG: #[[$STRIDED_1D:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> // Map corresponding to a 2D memory access where the stride along the last dim is known to be 1. @@ -114,6 +114,14 @@ func.func @permute_generic(%A: memref, } return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.interchange %0 { iterator_interchange = [1, 2, 0]} + } +} // CHECK-LABEL: func @permute_generic // CHECK: linalg.generic { // CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]], diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index fc988ea..576082a 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -230,15 +230,6 @@ static void applyPatterns(func::FuncOp funcOp) { .addOpFilter()); patterns.add(ctx); - //===--------------------------------------------------------------------===// - // Linalg generic interchange pattern. - //===--------------------------------------------------------------------===// - patterns.add( - ctx, - /*interchangeVector=*/ArrayRef{1, 2, 0}, - LinalgTransformationFilter(ArrayRef{}, - StringAttr::get(ctx, "PERMUTED"))); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); // Drop the marker. -- 2.7.4