From 2c4a56c4183f4f01c0b0959acec6972fddd79b7d Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 6 Jan 2022 08:30:49 -0500 Subject: [PATCH] [mlir][Linalg] NFC - Modernize padding pattern Differential Revision: https://reviews.llvm.org/D116739 --- .../mlir/Dialect/Linalg/Transforms/Transforms.h | 4 ++-- mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 19 ++++++++++--------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index c1185c0..7592094 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -688,7 +688,7 @@ struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern { /// Apply the `padding` transformation as a pattern. /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `padding` for more details. -struct LinalgPaddingPattern : public RewritePattern { +struct LinalgPaddingPattern : public OpInterfaceRewritePattern { // Entry point to match any LinalgOp OpInterface. LinalgPaddingPattern( MLIRContext *context, @@ -701,7 +701,7 @@ struct LinalgPaddingPattern : public RewritePattern { LinalgPaddingOptions options = LinalgPaddingOptions(), LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(LinalgOp, PatternRewriter &rewriter) const override; private: diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 8b9c7bd..177a2ab 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -489,23 +489,24 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite( mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( MLIRContext *context, LinalgPaddingOptions options, LinalgTransformationFilter filter, PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), + : OpInterfaceRewritePattern(context, benefit), filter(std::move(filter)), options(std::move(options)) {} mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( StringRef opName, MLIRContext *context, LinalgPaddingOptions options, LinalgTransformationFilter filter, PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)), - options(std::move(options)) {} + : OpInterfaceRewritePattern(context, benefit), + filter(std::move(filter)), options(std::move(options)) { + this->filter.addFilter([opName](Operation *op) { + return success(op->getName().getStringRef() == opName); + }); +} LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) - return failure(); + LinalgOp linalgOp, PatternRewriter &rewriter) const { if (!linalgOp.hasTensorSemantics()) return failure(); - if (failed(filter.checkAndNotify(rewriter, op))) + if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); // Pad the operation. @@ -538,7 +539,7 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite( } // Replace the original operation to pad. - rewriter.replaceOp(op, newResults.getValue()); + rewriter.replaceOp(linalgOp, newResults.getValue()); filter.replaceLinalgTransformationFilter(rewriter, paddedOp); return success(); } -- 2.7.4