// DialectConversionRewriter
//===----------------------------------------------------------------------===//
+/// This class contains a snapshot of the current conversion rewriter state.
+/// This is useful when saving and undoing a set of rewrites.
+struct RewriterState {
+ RewriterState(unsigned numCreatedOperations, unsigned numReplacements)
+ : numCreatedOperations(numCreatedOperations),
+ numReplacements(numReplacements) {}
+
+ /// The current number of created operations.
+ unsigned numCreatedOperations;
+
+ /// The current number of replacements queued.
+ unsigned numReplacements;
+};
+
/// This class implements a pattern rewriter for ConversionPattern
/// patterns. It automatically performs remapping of replaced operation values.
struct DialectConversionRewriter final : public PatternRewriter {
: PatternRewriter(region), argConverter(region.getContext()) {}
~DialectConversionRewriter() = default;
+ /// Return the current state of the rewriter.
+ RewriterState getCurrentState() {
+ return RewriterState(createdOps.size(), replacements.size());
+ }
+
+ /// Reset the state of the rewriter to a previously saved point.
+ void resetState(RewriterState state) {
+ // Reset any replaced operations and undo any saved mappings.
+ for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
+ for (auto *result : repl.op->getResults())
+ mapping.erase(result);
+ replacements.resize(state.numReplacements);
+
+ // Pop all of the newly created operations.
+ while (createdOps.size() != state.numCreatedOperations)
+ createdOps.pop_back_val()->erase();
+ }
+
/// Cleanup and destroy any generated rewrite operations. This method is
/// invoked when the conversion process fails.
void discardRewrites() {
return failure();
}
- auto curOpCount = rewriter.createdOps.size();
- auto curReplCount = rewriter.replacements.size();
+ RewriterState curState = rewriter.getCurrentState();
auto cleanupFailure = [&] {
- // Pop all of the newly created operations and replacements.
- while (rewriter.createdOps.size() != curOpCount)
- rewriter.createdOps.pop_back_val()->erase();
- rewriter.replacements.resize(curReplCount);
+ // Reset the rewriter state and pop this pattern.
+ rewriter.resetState(curState);
appliedPatterns.erase(pattern);
return failure();
};
}
// Recursively legalize each of the new operations.
- for (unsigned i = curOpCount, e = rewriter.createdOps.size(); i != e; ++i) {
- if (succeeded(legalize(rewriter.createdOps[i], rewriter)))
- continue;
- LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
- return cleanupFailure();
+ for (unsigned i = curState.numCreatedOperations,
+ e = rewriter.createdOps.size();
+ i != e; ++i) {
+ if (failed(legalize(rewriter.createdOps[i], rewriter))) {
+ LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n");
+ return cleanupFailure();
+ }
}
appliedPatterns.erase(pattern);