/// Fully replace uses of the old arguments with the new, materializing cast
/// operations as necessary.
- // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
- // implementation of replaceUsesOfBlockArgument is buggy.
void applyRewrites(ConversionValueMapping &mapping);
//===--------------------------------------------------------------------===//
/// This is useful when saving and undoing a set of rewrites.
struct RewriterState {
RewriterState(unsigned numCreatedOps, unsigned numReplacements,
- unsigned numBlockActions, unsigned numIgnoredOperations,
- unsigned numRootUpdates)
+ unsigned numArgReplacements, unsigned numBlockActions,
+ unsigned numIgnoredOperations, unsigned numRootUpdates)
: numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numArgReplacements(numArgReplacements),
numBlockActions(numBlockActions),
numIgnoredOperations(numIgnoredOperations),
numRootUpdates(numRootUpdates) {}
/// The current number of replacements queued.
unsigned numReplacements;
+ /// The current number of argument replacements queued.
+ unsigned numArgReplacements;
+
/// The current number of block actions performed.
unsigned numBlockActions;
/// Ordered vector of any requested operation replacements.
SmallVector<OpReplacement, 4> replacements;
+ /// Ordered vector of any requested block argument replacements.
+ SmallVector<BlockArgument, 4> argReplacements;
+
/// Ordered list of block operations (creations, splits, motions).
SmallVector<BlockAction, 4> blockActions;
RewriterState ConversionPatternRewriterImpl::getCurrentState() {
return RewriterState(createdOps.size(), replacements.size(),
- blockActions.size(), ignoredOps.size(),
- rootUpdates.size());
+ argReplacements.size(), blockActions.size(),
+ ignoredOps.size(), rootUpdates.size());
}
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
rootUpdates[i].resetOperation();
rootUpdates.resize(state.numRootUpdates);
+ // Reset any replaced arguments.
+ for (BlockArgument replacedArg :
+ llvm::drop_begin(argReplacements, state.numArgReplacements))
+ mapping.erase(replacedArg);
+ argReplacements.resize(state.numArgReplacements);
+
// Undo any block actions.
undoBlockActions(state.numBlockActions);
argConverter.notifyOpRemoved(repl.op);
}
+ // Apply all of the requested argument replacements.
+ for (BlockArgument arg : argReplacements) {
+ Value repl = mapping.lookupOrDefault(arg);
+ if (repl.isa<BlockArgument>()) {
+ arg.replaceAllUsesWith(repl);
+ continue;
+ }
+
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = repl.cast<OpResult>().getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+ }
+
// In a second pass, erase all of the replaced operations in reverse. This
// allows processing nested operations before their parent region is
// destroyed.
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
- for (auto &u : from.getUses()) {
- if (u.getOwner() == to.getDefiningOp())
- continue;
- u.getOwner()->replaceUsesOfWith(from, to);
- }
+ LLVM_DEBUG({
+ Operation *parentOp = from.getOwner()->getParentOp();
+ impl->logger.startLine() << "** Replace Argument : '" << from
+ << "'(in region of '" << parentOp->getName()
+ << "'(" << from.getOwner()->getParentOp() << ")\n";
+ });
+ impl->argReplacements.push_back(from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
}
};
+/// A simple pattern that tests the undo mechanism when replacing the uses of a
+/// block argument.
+struct TestUndoBlockArgReplace : public ConversionPattern {
+ TestUndoBlockArgReplace(MLIRContext *ctx)
+ : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {}
+
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ auto illegalOp =
+ rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+ rewriter.replaceUsesOfBlockArgument(op->getRegion(0).front().getArgument(0),
+ illegalOp);
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Type-Conversion Rewrite Testing
TestTypeConverter converter;
mlir::OwningRewritePatternList patterns;
populateWithGenerated(&getContext(), &patterns);
- patterns.insert<
- TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
- TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
- TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
- TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
- TestNonRootReplacement, TestBoundedRecursiveRewrite>(&getContext());
+ patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
+ TestCreateBlock, TestCreateIllegalBlock,
+ TestUndoBlockArgReplace, TestPassthroughInvalidOp,
+ TestSplitReturnType, TestChangeProducerTypeI32ToF32,
+ TestChangeProducerTypeF32ToF64,
+ TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+ TestNonRootReplacement, TestBoundedRecursiveRewrite>(
+ &getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
converter);