Fix canonicalizer to copy the entire GreedyRewriteConfig instead of selected fields
authorMehdi Amini <joker.eph@gmail.com>
Wed, 9 Aug 2023 01:58:12 +0000 (18:58 -0700)
committerTobias Hieta <tobias@hieta.se>
Fri, 25 Aug 2023 07:42:01 +0000 (09:42 +0200)
It is surprising for the user that only some fields were honored.

Also make the FrozenRewritePatternSet a shared_ptr<const T>.

Fixes #64543

Differential Revision: https://reviews.llvm.org/D157469

mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Transforms/Canonicalizer.cpp

index 0eeb8bb..2131fe3 100644 (file)
@@ -715,18 +715,20 @@ public:
   //===--------------------------------------------------------------------===//
 
   /// This class represents a StringSwitch like class that is useful for parsing
-  /// expected keywords. On construction, it invokes `parseKeyword` and
-  /// processes each of the provided cases statements until a match is hit. The
-  /// provided `ResultT` must be assignable from `failure()`.
+  /// expected keywords. On construction, unless a non-empty keyword is
+  /// provided, it invokes `parseKeyword` and processes each of the provided
+  /// cases statements until a match is hit. The provided `ResultT` must be
+  /// assignable from `failure()`.
   template <typename ResultT = ParseResult>
   class KeywordSwitch {
   public:
-    KeywordSwitch(AsmParser &parser)
+    KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
         : parser(parser), loc(parser.getCurrentLocation()) {
-      if (failed(parser.parseKeywordOrCompletion(&keyword)))
+      if (keyword && !keyword->empty())
+        this->keyword = *keyword;
+      else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
         result = failure();
     }
-
     /// Case that uses the provided value when true.
     KeywordSwitch &Case(StringLiteral str, ResultT value) {
       return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
index b4ad85c..d50019b 100644 (file)
@@ -29,7 +29,8 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   Canonicalizer() = default;
   Canonicalizer(const GreedyRewriteConfig &config,
                 ArrayRef<std::string> disabledPatterns,
-                ArrayRef<std::string> enabledPatterns) {
+                ArrayRef<std::string> enabledPatterns)
+      : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
     this->maxIterations = config.maxIterations;
@@ -41,30 +42,31 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
+    // Set the config from possible pass options set in the meantime.
+    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.enableRegionSimplification = enableRegionSimplification;
+    config.maxIterations = maxIterations;
+    config.maxNumRewrites = maxNumRewrites;
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
       dialect->getCanonicalizationPatterns(owningPatterns);
     for (RegisteredOperationName op : context->getRegisteredOperations())
       op.getCanonicalizationPatterns(owningPatterns, context);
 
-    patterns = FrozenRewritePatternSet(std::move(owningPatterns),
-                                       disabledPatterns, enabledPatterns);
+    patterns = std::make_shared<FrozenRewritePatternSet>(
+        std::move(owningPatterns), disabledPatterns, enabledPatterns);
     return success();
   }
   void runOnOperation() override {
-    GreedyRewriteConfig config;
-    config.useTopDownTraversal = topDownProcessingEnabled;
-    config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
   }
-
-  FrozenRewritePatternSet patterns;
+  GreedyRewriteConfig config;
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
 };
 } // namespace