From: River Riddle Date: Tue, 27 Oct 2020 00:23:41 +0000 (-0700) Subject: [mlir][Pattern] Refactor the Pattern class into a "metadata only" class X-Git-Tag: llvmorg-13-init~8085 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b99bd771626fbbf8b9b29ce312d4151968796826;p=platform%2Fupstream%2Fllvm.git [mlir][Pattern] Refactor the Pattern class into a "metadata only" class The Pattern class was originally intended to be used for solely matching operations, but that use never materialized. All of the pattern infrastructure uses RewritePattern, and the infrastructure for pure matching(Matchers.h) is implemented inline. This means that this class isn't a useful abstraction at the moment, so this revision refactors it to solely encapsulate the "metadata" of a pattern. The metadata includes the various state describing a pattern; benefit, root operation, etc. The API on PatternApplicator is updated to now operate on `Pattern`s as nothing special from `RewritePattern` is necessary. This refactoring is also necessary for the upcoming use of PDL patterns alongside C++ rewrite patterns. Differential Revision: https://reviews.llvm.org/D86258 --- diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index 2a2c30d..ab93245 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -174,10 +174,10 @@ Each driver is responsible for defining its own operation visitation order as well as pattern cost model, but the final application is performed via a `PatternApplicator` class. This class takes as input the `OwningRewritePatternList` and transforms the patterns based upon a provided -cost model. This cost model computes a final benefit for a given rewrite -pattern, using whatever driver specific information necessary. After a cost -model has been computed, the driver may begin to match patterns against -operations using `PatternApplicator::matchAndRewrite`. +cost model. This cost model computes a final benefit for a given pattern, using +whatever driver specific information necessary. After a cost model has been +computed, the driver may begin to match patterns against operations using +`PatternApplicator::matchAndRewrite`. An example is shown below: @@ -209,7 +209,7 @@ void applyMyPatternDriver(Operation *op, // Create the applicator and apply our cost model. PatternApplicator applicator(patterns); - applicator.applyCostModel([](const RewritePattern &pattern) { + applicator.applyCostModel([](const Pattern &pattern) { // Apply a default cost model. // Note: This is just for demonstration, if the default cost model is truly // desired `applicator.applyDefaultCostModel()` should be used diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index ea8f410..ef6e3bd 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -58,15 +58,23 @@ private: }; //===----------------------------------------------------------------------===// -// Pattern class +// Pattern //===----------------------------------------------------------------------===// -/// Instances of Pattern can be matched against SSA IR. These matches get used -/// in ways dependent on their subclasses and the driver doing the matching. -/// For example, RewritePatterns implement a rewrite from one matched pattern -/// to a replacement DAG tile. +/// This class contains all of the data related to a pattern, but does not +/// contain any methods or logic for the actual matching. This class is solely +/// used to interface with the metadata of a pattern, such as the benefit or +/// root operation. class Pattern { public: + /// Return a list of operations that may be generated when rewriting an + /// operation instance with this pattern. + ArrayRef getGeneratedOps() const { return generatedOps; } + + /// Return the root node that this pattern matches. Patterns that can match + /// multiple root types return None. + Optional getRootKind() const { return rootKind; } + /// Return the benefit (the inverse of "cost") of matching this pattern. The /// benefit of a Pattern is always static - rewrites that may have dynamic /// benefit can be instantiated multiple times (different Pattern instances) @@ -74,19 +82,11 @@ public: /// condition predicates. PatternBenefit getBenefit() const { return benefit; } - /// Return the root node that this pattern matches. Patterns that can match - /// multiple root types return None. - Optional getRootKind() const { return rootKind; } - - //===--------------------------------------------------------------------===// - // Implementation hooks for patterns to implement. - //===--------------------------------------------------------------------===// - - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). - virtual LogicalResult match(Operation *op) const = 0; - - virtual ~Pattern() {} + /// Returns true if this pattern is known to result in recursive application, + /// i.e. this pattern may generate IR that also matches this pattern, but is + /// known to bound the recursion. This signals to a rewrite driver that it is + /// safe to apply this pattern recursively to generated IR. + bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; } protected: /// This class acts as a special tag that makes the desire to match "any" @@ -94,19 +94,38 @@ protected: /// feature, and ensures that the user is making a conscious decision. struct MatchAnyOpTypeTag {}; - /// This constructor is used for patterns that match against a specific - /// operation type. The `benefit` is the expected benefit of matching this - /// pattern. + /// Construct a pattern with a certain benefit that matches the operation + /// with the given root name. Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context); - - /// This constructor is used when a pattern may match against multiple - /// different types of operations. The `benefit` is the expected benefit of - /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that - /// the "match any" behavior is what the user actually desired, - /// `MatchAnyOpTypeTag()` should always be supplied here. - Pattern(PatternBenefit benefit, MatchAnyOpTypeTag); + /// Construct a pattern with a certain benefit that matches any operation + /// type. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag); + /// Construct a pattern with a certain benefit that matches the operation with + /// the given root name. `generatedNames` contains the names of operations + /// that may be generated during a successful rewrite. + Pattern(StringRef rootName, ArrayRef generatedNames, + PatternBenefit benefit, MLIRContext *context); + /// Construct a pattern that may match any operation type. `generatedNames` + /// contains the names of operations that may be generated during a successful + /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + Pattern(ArrayRef generatedNames, PatternBenefit benefit, + MLIRContext *context, MatchAnyOpTypeTag tag); + + /// Set the flag detailing if this pattern has bounded rewrite recursion or + /// not. + void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) { + hasBoundedRecursion = hasBoundedRecursionArg; + } private: + /// A list of the potential operations that may be generated when rewriting + /// an op with this pattern. + SmallVector generatedOps; + /// The root operation of the pattern. If the pattern matches a specific /// operation, this contains the name of that operation. Contains None /// otherwise. @@ -115,9 +134,14 @@ private: /// The expected benefit of matching this pattern. const PatternBenefit benefit; - virtual void anchor(); + /// A boolean flag of whether this pattern has bounded recursion or not. + bool hasBoundedRecursion = false; }; +//===----------------------------------------------------------------------===// +// RewritePattern +//===----------------------------------------------------------------------===// + /// RewritePattern is the common base class for all DAG to DAG replacements. /// There are two possible usages of this class: /// * Multi-step RewritePattern with "match" and "rewrite" @@ -129,6 +153,8 @@ private: /// class RewritePattern : public Pattern { public: + virtual ~RewritePattern() {} + /// Rewrite the IR rooted at the specified operation with the result of /// this pattern, generating any new operations with the specified /// builder. If an unexpected error is encountered (an internal @@ -138,7 +164,7 @@ public: /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). - LogicalResult match(Operation *op) const override; + virtual LogicalResult match(Operation *op) const; /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this @@ -152,44 +178,12 @@ public: return failure(); } - /// Returns true if this pattern is known to result in recursive application, - /// i.e. this pattern may generate IR that also matches this pattern, but is - /// known to bound the recursion. This signals to a rewriter that it is safe - /// to apply this pattern recursively to generated IR. - virtual bool hasBoundedRewriteRecursion() const { return false; } - - /// Return a list of operations that may be generated when rewriting an - /// operation instance with this pattern. - ArrayRef getGeneratedOps() const { return generatedOps; } - protected: - /// Construct a rewrite pattern with a certain benefit that matches the - /// operation with the given root name. - RewritePattern(StringRef rootName, PatternBenefit benefit, - MLIRContext *context) - : Pattern(rootName, benefit, context) {} - /// Construct a rewrite pattern with a certain benefit that matches any - /// operation type. `MatchAnyOpTypeTag` is just a tag to ensure that the - /// "match any" behavior is what the user actually desired, - /// `MatchAnyOpTypeTag()` should always be supplied here. - RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag) - : Pattern(benefit, tag) {} - /// Construct a rewrite pattern with a certain benefit that matches the - /// operation with the given root name. `generatedNames` contains the names of - /// operations that may be generated during a successful rewrite. - RewritePattern(StringRef rootName, ArrayRef generatedNames, - PatternBenefit benefit, MLIRContext *context); - /// Construct a rewrite pattern that may match any operation type. - /// `generatedNames` contains the names of operations that may be generated - /// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure - /// that the "match any" behavior is what the user actually desired, - /// `MatchAnyOpTypeTag()` should always be supplied here. - RewritePattern(ArrayRef generatedNames, PatternBenefit benefit, - MLIRContext *context, MatchAnyOpTypeTag tag); + /// Inherit the base constructors from `Pattern`. + using Pattern::Pattern; - /// A list of the potential operations that may be generated when rewriting - /// an op with this pattern. - SmallVector generatedOps; + /// An anchor for the virtual table. + virtual void anchor(); }; /// OpRewritePattern is a wrapper around RewritePattern that allows for @@ -232,7 +226,7 @@ template struct OpRewritePattern : public RewritePattern { }; //===----------------------------------------------------------------------===// -// PatternRewriter class +// PatternRewriter //===----------------------------------------------------------------------===// /// This class coordinates the application of a pattern to the current function, @@ -498,7 +492,7 @@ public: /// pattern. Users can query contained patterns and pass analysis results to /// applyCostModel. Patterns to be discarded should have a benefit of /// `impossibleToMatch`. - using CostModel = function_ref; + using CostModel = function_ref; explicit PatternApplicator(const OwningRewritePatternList &owningPatternList) : owningPatternList(owningPatternList) {} @@ -512,11 +506,11 @@ public: /// onFailure: called when a pattern fails to match to perform cleanup. /// onSuccess: called when a pattern match succeeds; return failure() to /// invalidate the match and try another pattern. - LogicalResult matchAndRewrite( - Operation *op, PatternRewriter &rewriter, - function_ref canApply = {}, - function_ref onFailure = {}, - function_ref onSuccess = {}); + LogicalResult + matchAndRewrite(Operation *op, PatternRewriter &rewriter, + function_ref canApply = {}, + function_ref onFailure = {}, + function_ref onSuccess = {}); /// Apply a cost model to the patterns within this applicator. void applyCostModel(CostModel model); @@ -524,22 +518,22 @@ public: /// Apply the default cost model that solely uses the pattern's static /// benefit. void applyDefaultCostModel() { - applyCostModel( - [](const RewritePattern &pattern) { return pattern.getBenefit(); }); + applyCostModel([](const Pattern &pattern) { return pattern.getBenefit(); }); } - /// Walk all of the rewrite patterns within the applicator. - void walkAllPatterns(function_ref walk); + /// Walk all of the patterns within the applicator. + void walkAllPatterns(function_ref walk); private: /// Attempt to match and rewrite the given op with the given pattern, allowing /// a predicate to decide if a pattern can be applied or not, and hooks for if /// the pattern match was a success or failure. - LogicalResult matchAndRewrite( - Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess); + LogicalResult + matchAndRewrite(Operation *op, const RewritePattern &pattern, + PatternRewriter &rewriter, + function_ref canApply, + function_ref onFailure, + function_ref onSuccess); /// The list that owns the patterns used within this applicator. const OwningRewritePatternList &owningPatternList; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a95e100..71eaf0d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1042,7 +1042,12 @@ public: class VectorInsertStridedSliceOpSameRankRewritePattern : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) + : OpRewritePattern(ctx) { + // This pattern creates recursive InsertStridedSliceOp, but the recursion is + // bounded as the rank is strictly decreasing. + setHasBoundedRewriteRecursion(); + } LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -1093,9 +1098,6 @@ public: rewriter.replaceOp(op, res); return success(); } - /// This pattern creates recursive InsertStridedSliceOp, but the recursion is - /// bounded as the rank is strictly decreasing. - bool hasBoundedRewriteRecursion() const final { return true; } }; /// Returns the strides if the memory underlying `memRefType` has a contiguous @@ -1505,7 +1507,12 @@ private: class VectorExtractStridedSliceOpConversion : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + VectorExtractStridedSliceOpConversion(MLIRContext *ctx) + : OpRewritePattern(ctx) { + // This pattern creates recursive ExtractStridedSliceOp, but the recursion + // is bounded as the rank is strictly decreasing. + setHasBoundedRewriteRecursion(); + } LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -1552,9 +1559,6 @@ public: rewriter.replaceOp(op, res); return success(); } - /// This pattern creates recursive ExtractStridedSliceOp, but the recursion is - /// bounded as the rank is strictly decreasing. - bool hasBoundedRewriteRecursion() const final { return true; } }; } // namespace diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index d1da8d1..136d019 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -16,6 +16,10 @@ using namespace mlir; #define DEBUG_TYPE "pattern-match" +//===----------------------------------------------------------------------===// +// PatternBenefit +//===----------------------------------------------------------------------===// + PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { assert(representation == benefit && benefit != ImpossibleToMatchSentinel && "This pattern match benefit is too large to represent"); @@ -27,34 +31,16 @@ unsigned short PatternBenefit::getBenefit() const { } //===----------------------------------------------------------------------===// -// Pattern implementation +// Pattern //===----------------------------------------------------------------------===// Pattern::Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context) : rootKind(OperationName(rootName, context)), benefit(benefit) {} -Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag) +Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag) : benefit(benefit) {} - -// Out-of-line vtable anchor. -void Pattern::anchor() {} - -//===----------------------------------------------------------------------===// -// RewritePattern and PatternRewriter implementation -//===----------------------------------------------------------------------===// - -void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { - llvm_unreachable("need to implement either matchAndRewrite or one of the " - "rewrite functions!"); -} - -LogicalResult RewritePattern::match(Operation *op) const { - llvm_unreachable("need to implement either match or matchAndRewrite!"); -} - -RewritePattern::RewritePattern(StringRef rootName, - ArrayRef generatedNames, - PatternBenefit benefit, MLIRContext *context) +Pattern::Pattern(StringRef rootName, ArrayRef generatedNames, + PatternBenefit benefit, MLIRContext *context) : Pattern(rootName, benefit, context) { generatedOps.reserve(generatedNames.size()); std::transform(generatedNames.begin(), generatedNames.end(), @@ -62,9 +48,8 @@ RewritePattern::RewritePattern(StringRef rootName, return OperationName(name, context); }); } -RewritePattern::RewritePattern(ArrayRef generatedNames, - PatternBenefit benefit, MLIRContext *context, - MatchAnyOpTypeTag tag) +Pattern::Pattern(ArrayRef generatedNames, PatternBenefit benefit, + MLIRContext *context, MatchAnyOpTypeTag tag) : Pattern(benefit, tag) { generatedOps.reserve(generatedNames.size()); std::transform(generatedNames.begin(), generatedNames.end(), @@ -73,6 +58,26 @@ RewritePattern::RewritePattern(ArrayRef generatedNames, }); } +//===----------------------------------------------------------------------===// +// RewritePattern +//===----------------------------------------------------------------------===// + +void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { + llvm_unreachable("need to implement either matchAndRewrite or one of the " + "rewrite functions!"); +} + +LogicalResult RewritePattern::match(Operation *op) const { + llvm_unreachable("need to implement either match or matchAndRewrite!"); +} + +/// Out-of-line vtable anchor. +void RewritePattern::anchor() {} + +//===----------------------------------------------------------------------===// +// PatternRewriter +//===----------------------------------------------------------------------===// + PatternRewriter::~PatternRewriter() { // Out of line to provide a vtable anchor for the class. } @@ -201,7 +206,7 @@ void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) { } //===----------------------------------------------------------------------===// -// PatternMatcher implementation +// PatternApplicator //===----------------------------------------------------------------------===// void PatternApplicator::applyCostModel(CostModel model) { @@ -266,16 +271,16 @@ void PatternApplicator::applyCostModel(CostModel model) { } void PatternApplicator::walkAllPatterns( - function_ref walk) { + function_ref walk) { for (auto &it : owningPatternList) walk(*it); } LogicalResult PatternApplicator::matchAndRewrite( Operation *op, PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess) { + function_ref canApply, + function_ref onFailure, + function_ref onSuccess) { // Check to see if there are patterns matching this specific operation type. MutableArrayRef opPatterns; auto patternIt = patterns.find(op->getName()); @@ -315,9 +320,9 @@ LogicalResult PatternApplicator::matchAndRewrite( LogicalResult PatternApplicator::matchAndRewrite( Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess) { + function_ref canApply, + function_ref onFailure, + function_ref onSuccess) { // Check that the pattern can be applied. if (canApply && !canApply(pattern)) return failure(); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 5f6c972..692cd49 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1452,7 +1452,7 @@ ConversionPattern::matchAndRewrite(Operation *op, namespace { /// A set of rewrite patterns that can be used to legalize a given operation. -using LegalizationPatterns = SmallVector; +using LegalizationPatterns = SmallVector; /// This class defines a recursive operation legalizer. class OperationLegalizer { @@ -1484,12 +1484,11 @@ private: /// Return true if the given pattern may be applied to the given operation, /// false otherwise. - bool canApplyPattern(Operation *op, const RewritePattern &pattern, + bool canApplyPattern(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter); /// Legalize the resultant IR after successfully applying the given pattern. - LogicalResult legalizePatternResult(Operation *op, - const RewritePattern &pattern, + LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, RewriterState &curState); @@ -1546,7 +1545,7 @@ private: DenseMap &legalizerPatterns); /// The current set of patterns that have been applied. - SmallPtrSet appliedPatterns; + SmallPtrSet appliedPatterns; /// The legalization information provided by the target. ConversionTarget ⌖ @@ -1697,13 +1696,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op, auto &rewriterImpl = rewriter.getImpl(); // Functor that returns if the given pattern may be applied. - auto canApply = [&](const RewritePattern &pattern) { + auto canApply = [&](const Pattern &pattern) { return canApplyPattern(op, pattern, rewriter); }; // Functor that cleans up the rewriter state after a pattern failed to match. RewriterState curState = rewriterImpl.getCurrentState(); - auto onFailure = [&](const RewritePattern &pattern) { + auto onFailure = [&](const Pattern &pattern) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); rewriterImpl.resetState(curState); appliedPatterns.erase(&pattern); @@ -1711,7 +1710,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that performs additional legalization when a pattern is // successfully applied. - auto onSuccess = [&](const RewritePattern &pattern) { + auto onSuccess = [&](const Pattern &pattern) { auto result = legalizePatternResult(op, pattern, rewriter, curState); appliedPatterns.erase(&pattern); if (failed(result)) @@ -1724,8 +1723,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, onSuccess); } -bool OperationLegalizer::canApplyPattern(Operation *op, - const RewritePattern &pattern, +bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter) { LLVM_DEBUG({ auto &os = rewriter.getImpl().logger; @@ -1747,9 +1745,10 @@ bool OperationLegalizer::canApplyPattern(Operation *op, return true; } -LogicalResult OperationLegalizer::legalizePatternResult( - Operation *op, const RewritePattern &pattern, - ConversionPatternRewriter &rewriter, RewriterState &curState) { +LogicalResult +OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern, + ConversionPatternRewriter &rewriter, + RewriterState &curState) { auto &impl = rewriter.getImpl(); #ifndef NDEBUG @@ -1877,13 +1876,12 @@ void OperationLegalizer::buildLegalizationGraph( // generate it. DenseMap> parentOps; // A mapping between an operation and any currently invalid patterns it has. - DenseMap> - invalidPatterns; + DenseMap> invalidPatterns; // A worklist of patterns to consider for legality. - llvm::SetVector patternWorklist; + llvm::SetVector patternWorklist; // Build the mapping from operations to the parent ops that may generate them. - applicator.walkAllPatterns([&](const RewritePattern &pattern) { + applicator.walkAllPatterns([&](const Pattern &pattern) { Optional root = pattern.getRootKind(); // If the pattern has no specific root, we can't analyze the relationship @@ -1914,7 +1912,7 @@ void OperationLegalizer::buildLegalizationGraph( // recurse into itself. It would be better to perform this kind of filtering // at a higher level than here anyways. if (!anyOpLegalizerPatterns.empty()) { - for (const RewritePattern *pattern : patternWorklist) + for (const Pattern *pattern : patternWorklist) legalizerPatterns[*pattern->getRootKind()].push_back(pattern); return; } @@ -1964,15 +1962,15 @@ void OperationLegalizer::computeLegalizationGraphBenefit( // Apply a cost model to the pattern applicator. We order patterns first by // depth then benefit. `legalizerPatterns` contains per-op patterns by // decreasing benefit. - applicator.applyCostModel([&](const RewritePattern &p) { - ArrayRef orderedPatternList; - if (Optional rootName = p.getRootKind()) + applicator.applyCostModel([&](const Pattern &pattern) { + ArrayRef orderedPatternList; + if (Optional rootName = pattern.getRootKind()) orderedPatternList = legalizerPatterns[*rootName]; else orderedPatternList = anyOpLegalizerPatterns; // If the pattern is not found, then it was removed and cannot be matched. - auto it = llvm::find(orderedPatternList, &p); + auto it = llvm::find(orderedPatternList, &pattern); if (it == orderedPatternList.end()) return PatternBenefit::impossibleToMatch(); @@ -2014,9 +2012,9 @@ unsigned OperationLegalizer::applyCostModelToPatterns( unsigned minDepth = std::numeric_limits::max(); // Compute the depth for each pattern within the set. - SmallVector, 4> patternsByDepth; + SmallVector, 4> patternsByDepth; patternsByDepth.reserve(patterns.size()); - for (const RewritePattern *pattern : patterns) { + for (const Pattern *pattern : patterns) { unsigned depth = 0; for (auto generatedOp : pattern->getGeneratedOps()) { unsigned generatedOpDepth = computeOpLegalizationDepth( @@ -2037,8 +2035,8 @@ unsigned OperationLegalizer::applyCostModelToPatterns( // Sort the patterns by those likely to be the most beneficial. llvm::array_pod_sort( patternsByDepth.begin(), patternsByDepth.end(), - [](const std::pair *lhs, - const std::pair *rhs) { + [](const std::pair *lhs, + const std::pair *rhs) { // First sort by the smaller pattern legalization depth. if (lhs->second != rhs->second) return llvm::array_pod_sort_comparator(&lhs->second, diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 282d310..04a21ec 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -452,7 +452,11 @@ struct TestNonRootReplacement : public RewritePattern { /// bounded recursion. struct TestBoundedRecursiveRewrite : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + TestBoundedRecursiveRewrite(MLIRContext *ctx) + : OpRewritePattern(ctx) { + // The conversion target handles bounding the recursion of this pattern. + setHasBoundedRewriteRecursion(); + } LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, PatternRewriter &rewriter) const final { @@ -462,9 +466,6 @@ struct TestBoundedRecursiveRewrite }); return success(); } - - /// The conversion target handles bounding the recursion of this pattern. - bool hasBoundedRewriteRecursion() const final { return true; } }; struct TestNestedOpCreationUndoRewrite