//===--------------------------------------------------------------------===//
/// This class represents a StringSwitch like class that is useful for parsing
- /// expected keywords. On construction, it invokes `parseKeyword` and
- /// processes each of the provided cases statements until a match is hit. The
- /// provided `ResultT` must be assignable from `failure()`.
+ /// expected keywords. On construction, unless a non-empty keyword is
+ /// provided, it invokes `parseKeyword` and processes each of the provided
+ /// cases statements until a match is hit. The provided `ResultT` must be
+ /// assignable from `failure()`.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
- KeywordSwitch(AsmParser &parser)
+ KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
: parser(parser), loc(parser.getCurrentLocation()) {
- if (failed(parser.parseKeywordOrCompletion(&keyword)))
+ if (keyword && !keyword->empty())
+ this->keyword = *keyword;
+ else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
result = failure();
}
-
/// Case that uses the provided value when true.
KeywordSwitch &Case(StringLiteral str, ResultT value) {
return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
Canonicalizer() = default;
Canonicalizer(const GreedyRewriteConfig &config,
ArrayRef<std::string> disabledPatterns,
- ArrayRef<std::string> enabledPatterns) {
+ ArrayRef<std::string> enabledPatterns)
+ : config(config) {
this->topDownProcessingEnabled = config.useTopDownTraversal;
this->enableRegionSimplification = config.enableRegionSimplification;
this->maxIterations = config.maxIterations;
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
+ // Set the config from possible pass options set in the meantime.
+ config.useTopDownTraversal = topDownProcessingEnabled;
+ config.enableRegionSimplification = enableRegionSimplification;
+ config.maxIterations = maxIterations;
+ config.maxNumRewrites = maxNumRewrites;
+
RewritePatternSet owningPatterns(context);
for (auto *dialect : context->getLoadedDialects())
dialect->getCanonicalizationPatterns(owningPatterns);
for (RegisteredOperationName op : context->getRegisteredOperations())
op.getCanonicalizationPatterns(owningPatterns, context);
- patterns = FrozenRewritePatternSet(std::move(owningPatterns),
- disabledPatterns, enabledPatterns);
+ patterns = std::make_shared<FrozenRewritePatternSet>(
+ std::move(owningPatterns), disabledPatterns, enabledPatterns);
return success();
}
void runOnOperation() override {
- GreedyRewriteConfig config;
- config.useTopDownTraversal = topDownProcessingEnabled;
- config.enableRegionSimplification = enableRegionSimplification;
- config.maxIterations = maxIterations;
- config.maxNumRewrites = maxNumRewrites;
LogicalResult converged =
- applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+ applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
if (testConvergence && failed(converged))
signalPassFailure();
}
-
- FrozenRewritePatternSet patterns;
+ GreedyRewriteConfig config;
+ std::shared_ptr<const FrozenRewritePatternSet> patterns;
};
} // namespace