From bcfbf8cc41e0712533b44286635a09a0f9b84afe Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 12 Oct 2022 04:14:33 -0700 Subject: [PATCH] [mlir][Linalg] NFC - Drop filter from LinalgGeneralizationPattern Differential Revision: https://reviews.llvm.org/D135761 --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 28 +++++++--------------- .../Dialect/Linalg/Transforms/Generalization.cpp | 4 ++-- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 24 ------------------- 3 files changed, 10 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 044ce8d..e77d4e8 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -772,35 +772,25 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final /// Linalg generalization pattern. /// /// Apply the `generalization` transformation as a pattern. -/// `filter` controls LinalgTransformMarker matching and update when specified. /// See `generalization` for more details. +// +// TODO: Automatic default pattern class that just unwraps a function returning +// FailureOr. struct LinalgGeneralizationPattern : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgGeneralizationPattern( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgGeneralizationPattern( - StringRef opName, MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; /// `matchAndRewrite` implementation that returns the significant transformed /// pieces of IR. FailureOr - returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const; + returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const { + return generalizeNamedOp(rewriter, op); + } LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override { return returningMatchAndRewrite(op, rewriter); } - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; }; /// @@ -917,9 +907,7 @@ private: /// Populates `patterns` with patterns to convert spec-generated named ops to /// linalg.generic ops. -void populateLinalgNamedOpsGeneralizationPatterns( - RewritePatternSet &patterns, - const LinalgTransformationFilter &filter = LinalgTransformationFilter()); +void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns); /// Linalg decompose convolutions patterns diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 97eee8b..d656e92 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -86,8 +86,8 @@ void LinalgGeneralizationPass::runOnOperation() { } void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( - RewritePatternSet &patterns, const LinalgTransformationFilter &marker) { - patterns.add(patterns.getContext(), marker); + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); } std::unique_ptr> diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 58923bc..63f74a3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -427,30 +427,6 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite( return paddedOp; } -/// Linalg generalization pattern. -mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( - MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)) {} - -mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( - StringRef opName, MLIRContext *context, LinalgTransformationFilter f, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)) {} - -FailureOr -mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite( - LinalgOp linalgOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - FailureOr genericOp = generalizeNamedOp(rewriter, linalgOp); - if (failed(genericOp)) - return failure(); - filter.replaceLinalgTransformationFilter(rewriter, *genericOp); - return genericOp; -} - LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( memref::CopyOp copyOp, PatternRewriter &rewriter) const { return vectorizeCopy(rewriter, copyOp); -- 2.7.4