From d9f645fe5081fccbe59560989cdf8ea4535946fc Mon Sep 17 00:00:00 2001 From: Chenguang Wang Date: Thu, 22 Dec 2022 09:10:15 -0800 Subject: [PATCH] [mlir] Allow specifying benefit for C func ptr style patterns. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D139234 --- mlir/include/mlir/IR/PatternMatch.h | 14 +++++++++----- mlir/unittests/IR/PatternMatchTest.cpp | 19 +++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 0743e37..3ee533c 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -1638,12 +1638,15 @@ public: // Add a matchAndRewrite style pattern represented as a C function pointer. template - RewritePatternSet &add(LogicalResult (*implFn)(OpType, - PatternRewriter &rewriter)) { + RewritePatternSet & + add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), + PatternBenefit benefit = 1, ArrayRef generatedNames = {}) { struct FnPattern final : public OpRewritePattern { FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), - MLIRContext *context) - : OpRewritePattern(context), implFn(implFn) {} + MLIRContext *context, PatternBenefit benefit, + ArrayRef generatedNames) + : OpRewritePattern(context, benefit, generatedNames), + implFn(implFn) {} LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override { @@ -1653,7 +1656,8 @@ public: private: LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); }; - add(std::make_unique(std::move(implFn), getContext())); + add(std::make_unique(std::move(implFn), getContext(), benefit, + generatedNames)); return *this; } diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp index 6454f05..3a58d5c 100644 --- a/mlir/unittests/IR/PatternMatchTest.cpp +++ b/mlir/unittests/IR/PatternMatchTest.cpp @@ -28,3 +28,22 @@ TEST(OpRewritePatternTest, GetGeneratedNames) { ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName()); } } // end anonymous namespace + +namespace { +LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) { + return failure(); +} +TEST(AnOpRewritePatternTest, PatternFuncAttributes) { + MLIRContext context; + RewritePatternSet patterns(&context); + + patterns.add(anOpRewritePatternFunc, /*benefit=*/3, + /*generatedNames=*/{test::OpB::getOperationName()}); + ASSERT_EQ(patterns.getNativePatterns().size(), 1); + auto &pattern = patterns.getNativePatterns().front(); + ASSERT_EQ(pattern->getBenefit(), 3); + ASSERT_EQ(pattern->getGeneratedOps().size(), 1); + ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(), + test::OpB::getOperationName()); +} +} // end anonymous namespace -- 2.7.4