When cleaning up after a failed legalization pattern, make sure to remove any newly...
authorRiver Riddle <riverriddle@google.com>
Wed, 5 Jun 2019 16:36:32 +0000 (09:36 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:18:32 +0000 (16:18 -0700)
PiperOrigin-RevId: 251658984

mlir/lib/Transforms/DialectConversion.cpp

index 1deedc1..333214e 100644 (file)
@@ -95,6 +95,20 @@ constexpr StringLiteral ArgConverter::kCastName;
 // DialectConversionRewriter
 //===----------------------------------------------------------------------===//
 
+/// This class contains a snapshot of the current conversion rewriter state.
+/// This is useful when saving and undoing a set of rewrites.
+struct RewriterState {
+  RewriterState(unsigned numCreatedOperations, unsigned numReplacements)
+      : numCreatedOperations(numCreatedOperations),
+        numReplacements(numReplacements) {}
+
+  /// The current number of created operations.
+  unsigned numCreatedOperations;
+
+  /// The current number of replacements queued.
+  unsigned numReplacements;
+};
+
 /// This class implements a pattern rewriter for ConversionPattern
 /// patterns. It automatically performs remapping of replaced operation values.
 struct DialectConversionRewriter final : public PatternRewriter {
@@ -112,6 +126,24 @@ struct DialectConversionRewriter final : public PatternRewriter {
       : PatternRewriter(region), argConverter(region.getContext()) {}
   ~DialectConversionRewriter() = default;
 
+  /// Return the current state of the rewriter.
+  RewriterState getCurrentState() {
+    return RewriterState(createdOps.size(), replacements.size());
+  }
+
+  /// Reset the state of the rewriter to a previously saved point.
+  void resetState(RewriterState state) {
+    // Reset any replaced operations and undo any saved mappings.
+    for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
+      for (auto *result : repl.op->getResults())
+        mapping.erase(result);
+    replacements.resize(state.numReplacements);
+
+    // Pop all of the newly created operations.
+    while (createdOps.size() != state.numCreatedOperations)
+      createdOps.pop_back_val()->erase();
+  }
+
   /// Cleanup and destroy any generated rewrite operations. This method is
   /// invoked when the conversion process fails.
   void discardRewrites() {
@@ -354,13 +386,10 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
     return failure();
   }
 
-  auto curOpCount = rewriter.createdOps.size();
-  auto curReplCount = rewriter.replacements.size();
+  RewriterState curState = rewriter.getCurrentState();
   auto cleanupFailure = [&] {
-    // Pop all of the newly created operations and replacements.
-    while (rewriter.createdOps.size() != curOpCount)
-      rewriter.createdOps.pop_back_val()->erase();
-    rewriter.replacements.resize(curReplCount);
+    // Reset the rewriter state and pop this pattern.
+    rewriter.resetState(curState);
     appliedPatterns.erase(pattern);
     return failure();
   };
@@ -373,11 +402,13 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
   }
 
   // Recursively legalize each of the new operations.
-  for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) {
-    if (succeeded(legalize(rewriter.createdOps[i], rewriter)))
-      continue;
-    LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
-    return cleanupFailure();
+  for (unsigned i = curState.numCreatedOperations,
+                e = rewriter.createdOps.size();
+       i != e; ++i) {
+    if (failed(legalize(rewriter.createdOps[i], rewriter))) {
+      LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
+      return cleanupFailure();
+    }
   }
 
   appliedPatterns.erase(pattern);