//===----------------------------------------------------------------------===//
namespace {
/// This class provides a simple interface for converting the types of block
-/// arguments. This is done by inserting fake cast operations that map from the
-/// illegal type to the original type to allow for undoing pending rewrites in
-/// the case of failure.
+/// arguments. This is done by creating a new block that contains the new legal
+/// types and extracting the block that contains the old illegal types to allow
+/// for undoing pending rewrites in the case of failure.
struct ArgConverter {
ArgConverter(TypeConverter *typeConverter, PatternRewriter &rewriter)
- : castOpName(kCastName, rewriter.getContext()),
- loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
+ : loc(rewriter.getUnknownLoc()), typeConverter(typeConverter),
rewriter(rewriter) {}
- /// Erase any rewrites registered for arguments to blocks within the given
- /// region. This function is called when the given region is to be destroyed.
- void cancelPendingRewrites(Block *block);
+ /// This structure contains the information pertaining to an argument that has
+ /// been converted.
+ struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value *castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
- /// Cleanup and undo any generated conversions for the arguments of block.
- /// This method differs from 'cancelPendingRewrites' in that it returns the
- /// block signature to its original state.
- void discardPendingRewrites(Block *block);
+ /// The start index of in the new argument list that contains arguments that
+ /// replace the original.
+ unsigned newArgIdx;
- /// Replace usages of the cast operations with the argument directly.
- void applyRewrites();
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+ /// old. This only used if 'newArgSize' > 1.
+ Value *castValue;
+ };
+
+ /// This structure contains information pertaining to a block that has had its
+ /// signature converted.
+ struct ConvertedBlockInfo {
+ ConvertedBlockInfo(Block *origBlock) : origBlock(origBlock) {}
+
+ /// The original block that was requested to have its signature converted.
+ Block *origBlock;
+
+ /// The conversion information for each of the arguments. The information is
+ /// None if the argument was dropped during conversion.
+ SmallVector<Optional<ConvertedArgInfo>, 1> argInfo;
+ };
/// Return if the signature of the given block has already been converted.
- bool hasBeenConverted(Block *block) const { return argMapping.count(block); }
+ bool hasBeenConverted(Block *block) const {
+ return conversionInfo.count(block);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Rewrite Application
+ //===--------------------------------------------------------------------===//
- /// Attempt to convert the signature of the given block.
- LogicalResult convertSignature(Block *block, ConversionValueMapping &mapping);
+ /// Erase any rewrites registered for the current block that is about to be
+ /// removed. This merely drops the rewrites without undoing them.
+ void notifyBlockRemoved(Block *block);
- /// Apply the given signature conversion on the given block.
- void applySignatureConversion(
+ /// Cleanup and undo any generated conversions for the arguments of block.
+ /// This method replaces the new block with the original, reverting the IR to
+ /// its original state.
+ void discardRewrites(Block *block);
+
+ /// Fully replace uses of the old arguments with the new, materializing cast
+ /// operations as necessary.
+ // FIXME(riverriddle) The 'mapping' parameter is only necessary because the
+ // implementation of replaceUsesOfBlockArgument is buggy.
+ void applyRewrites(ConversionValueMapping &mapping);
+
+ //===--------------------------------------------------------------------===//
+ // Conversion
+ //===--------------------------------------------------------------------===//
+
+ /// Attempt to convert the signature of the given block, if successful a new
+ /// block is returned containing the new arguments. On failure, nullptr is
+ /// returned.
+ Block *convertSignature(Block *block, ConversionValueMapping &mapping);
+
+ /// Apply the given signature conversion on the given block. The new block
+ /// containing the updated signature is returned.
+ Block *applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping);
- /// Convert the given block argument given the provided set of new argument
- /// values that are to replace it. This function returns the operation used
- /// to perform the conversion.
- Operation *convertArgument(BlockArgument *origArg,
- ArrayRef<Value *> newValues,
- ConversionValueMapping &mapping);
-
- /// A utility function used to create a conversion cast operation with the
- /// given input and result types.
- Operation *createCast(ArrayRef<Value *> inputs, Type outputType);
-
- /// This is an operation name for a fake operation that is inserted during the
- /// conversion process. Operations of this type are guaranteed to never escape
- /// the converter.
- static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
- OperationName castOpName;
-
- /// This is a collection of cast operations that were generated during the
- /// conversion process when converting the types of block arguments.
- llvm::MapVector<Block *, SmallVector<Operation *, 4>> argMapping;
-
- /// An instance of the unknown location that is used when generating
- /// producers.
+ /// A collection of blocks that have had their arguments converted.
+ llvm::MapVector<Block *, ConvertedBlockInfo> conversionInfo;
+
+ /// An instance of the unknown location that is used when materializing
+ /// conversions.
Location loc;
/// The type converter to use when changing types.
};
} // end anonymous namespace
-constexpr StringLiteral ArgConverter::kCastName;
+//===----------------------------------------------------------------------===//
+// Rewrite Application
-/// Erase any rewrites registered for arguments to the given block.
-void ArgConverter::cancelPendingRewrites(Block *block) {
- auto it = argMapping.find(block);
- if (it == argMapping.end())
+void ArgConverter::notifyBlockRemoved(Block *block) {
+ auto it = conversionInfo.find(block);
+ if (it == conversionInfo.end())
return;
- for (auto *op : it->second) {
- op->dropAllDefinedValueUses();
- op->erase();
- }
- argMapping.erase(it);
+
+ // Drop all uses of the original arguments and delete the original block.
+ Block *origBlock = it->second.origBlock;
+ for (BlockArgument *arg : origBlock->getArguments())
+ arg->dropAllUses();
+ delete origBlock;
+
+ conversionInfo.erase(it);
}
-/// Cleanup and undo any generated conversions for the arguments of block.
-/// This method differs from 'cancelPendingRewrites' in that it returns the
-/// block signature to its original state.
-void ArgConverter::discardPendingRewrites(Block *block) {
- auto it = argMapping.find(block);
- if (it == argMapping.end())
+void ArgConverter::discardRewrites(Block *block) {
+ auto it = conversionInfo.find(block);
+ if (it == conversionInfo.end())
return;
+ Block *origBlock = it->second.origBlock;
- // Erase all of the new arguments.
- for (int i = block->getNumArguments() - 1; i >= 0; --i) {
+ // Drop all uses of the new block arguments and replace uses of the new block.
+ for (int i = block->getNumArguments() - 1; i >= 0; --i)
block->getArgument(i)->dropAllUses();
- block->eraseArgument(i, /*updatePredTerms=*/false);
- }
+ block->replaceAllUsesWith(origBlock);
- // Re-instate the old arguments.
- auto &mapping = it->second;
- for (unsigned i = 0, e = mapping.size(); i != e; ++i) {
- auto *op = mapping[i];
- auto *arg = block->addArgument(op->getResult(0)->getType());
- op->getResult(0)->replaceAllUsesWith(arg);
+ // Move the operations back the original block and the delete the new block.
+ origBlock->getOperations().splice(origBlock->end(), block->getOperations());
+ origBlock->insertBefore(block);
+ block->erase();
- // If this operation is within a block, it will be cleaned up automatically.
- if (!op->getBlock())
- op->erase();
- }
- argMapping.erase(it);
+ conversionInfo.erase(it);
}
-/// Replace usages of the cast operations with the argument directly.
-void ArgConverter::applyRewrites() {
- Block *block;
- ArrayRef<Operation *> argOps;
- for (auto &mapping : argMapping) {
- std::tie(block, argOps) = mapping;
+void ArgConverter::applyRewrites(ConversionValueMapping &mapping) {
+ for (auto &info : conversionInfo) {
+ Block *newBlock = info.first;
+ ConvertedBlockInfo &blockInfo = info.second;
+ Block *origBlock = blockInfo.origBlock;
// Process the remapping for each of the original arguments.
- for (unsigned i = 0, e = argOps.size(); i != e; ++i) {
- auto *op = argOps[i];
-
- // Handle the case of a 1->N value mapping.
- if (op->getNumOperands() > 1) {
- // If all of the uses were removed, we can drop this op. Otherwise,
- // keep the operation alive and let the user handle any remaining
- // usages.
- if (op->use_empty())
- op->erase();
+ for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
+ Optional<ConvertedArgInfo> &argInfo = blockInfo.argInfo[i];
+ BlockArgument *origArg = origBlock->getArgument(i);
+
+ // Handle the case of a 1->0 value mapping.
+ if (!argInfo) {
+ // If there are any dangling uses then replace the argument with one
+ // generated by the type converter. This is necessary as the cast must
+ // persist in the IR after conversion.
+ if (!origArg->use_empty()) {
+ rewriter.setInsertionPointToStart(newBlock);
+ auto *newOp = typeConverter->materializeConversion(
+ rewriter, origArg->getType(), llvm::None, loc);
+ origArg->replaceAllUsesWith(newOp->getResult(0));
+ }
continue;
}
// operation.
// FIXME(riverriddle) This should check that the result type and operand
// type are the same, otherwise it should force a conversion to be
- // materialized. This works around a current limitation with regards to
- // region entry argument type conversion.
- if (op->getNumOperands() == 1) {
- op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
- op->destroy();
+ // materialized.
+ if (argInfo->newArgSize == 1) {
+ origArg->replaceAllUsesWith(
+ mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx)));
continue;
}
- // Otherwise, if there are any dangling uses then replace the fake
- // conversion operation with one generated by the type converter. This
- // is necessary as the cast must persist in the IR after conversion.
- auto *opResult = op->getResult(0);
- if (!opResult->use_empty()) {
- rewriter.setInsertionPointToStart(block);
- SmallVector<Value *, 1> operands(op->getOperands());
- auto *newOp = typeConverter->materializeConversion(
- rewriter, opResult->getType(), operands, op->getLoc());
- opResult->replaceAllUsesWith(newOp->getResult(0));
- }
- op->destroy();
+ // Otherwise this is a 1->N value mapping.
+ Value *castValue = argInfo->castValue;
+ assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping");
+
+ // If the argument is still used, replace it with the generated cast.
+ if (!origArg->use_empty())
+ origArg->replaceAllUsesWith(mapping.lookupOrDefault(castValue));
+
+ // If all users of the cast were removed, we can drop it. Otherwise, keep
+ // the operation alive and let the user handle any remaining usages.
+ if (castValue->use_empty())
+ castValue->getDefiningOp()->erase();
}
+
+ // Drop the original block now the rewrites were applied.
+ delete origBlock;
}
}
-/// Converts the signature of the given entry block.
-LogicalResult ArgConverter::convertSignature(Block *block,
- ConversionValueMapping &mapping) {
+//===----------------------------------------------------------------------===//
+// Conversion
+
+Block *ArgConverter::convertSignature(Block *block,
+ ConversionValueMapping &mapping) {
if (auto conversion = typeConverter->convertBlockSignature(block))
- return applySignatureConversion(block, *conversion, mapping), success();
- return failure();
+ return applySignatureConversion(block, *conversion, mapping);
+ return nullptr;
}
-/// Apply the given signature conversion on the given block.
-void ArgConverter::applySignatureConversion(
+Block *ArgConverter::applySignatureConversion(
Block *block, TypeConverter::SignatureConversion &signatureConversion,
ConversionValueMapping &mapping) {
+ // If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedTypes();
if (origArgCount == 0 && convertedTypes.empty())
- return;
+ return block;
- SmallVector<Value *, 4> newArgRange(block->addArguments(convertedTypes));
- ArrayRef<Value *> newArgRef(newArgRange);
+ // Split the block at the beginning to get a new block to use for the updated
+ // signature.
+ Block *newBlock = block->splitBlock(block->begin());
+ block->replaceAllUsesWith(newBlock);
+
+ SmallVector<Value *, 4> newArgRange(newBlock->addArguments(convertedTypes));
+ ArrayRef<Value *> newArgs(newArgRange);
// Remap each of the original arguments as determined by the signature
// conversion.
- auto &newArgMapping = argMapping[block];
+ ConvertedBlockInfo info(block);
+ info.argInfo.resize(origArgCount);
+
OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(block);
+ rewriter.setInsertionPointToStart(newBlock);
for (unsigned i = 0; i != origArgCount; ++i) {
- ArrayRef<Value *> remappedValues;
- if (auto &inputMap = signatureConversion.getInputMapping(i)) {
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- remappedValues = inputMap->replacementValue;
- } else
- remappedValues = newArgRef.slice(inputMap->inputNo, inputMap->size);
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap)
+ continue;
+ BlockArgument *origArg = block->getArgument(i);
+
+ // If inputMap->replacementValue is not nullptr, then the argument is
+ // dropped and a replacement value is provided to be the remappedValue.
+ if (inputMap->replacementValue) {
+ assert(inputMap->size == 0 &&
+ "invalid to provide a replacement value when the argument isn't "
+ "dropped");
+ mapping.map(origArg, inputMap->replacementValue);
+ continue;
}
- BlockArgument *arg = block->getArgument(i);
- newArgMapping.push_back(convertArgument(arg, remappedValues, mapping));
- }
+ // If this is a 1->1 mapping, then map the argument directly.
+ if (inputMap->size == 1) {
+ mapping.map(origArg, newArgs[inputMap->inputNo]);
+ info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size);
+ continue;
+ }
- // Erase all of the original arguments.
- for (unsigned i = 0; i != origArgCount; ++i)
- block->eraseArgument(0, /*updatePredTerms=*/false);
-}
-
-/// Convert the given block argument given the provided set of new argument
-/// values that are to replace it. This function returns the operation used
-/// to perform the conversion.
-Operation *ArgConverter::convertArgument(BlockArgument *origArg,
- ArrayRef<Value *> newValues,
- ConversionValueMapping &mapping) {
- // Handle the cases of 1->0 or 1->1 mappings.
- if (newValues.size() < 2) {
- // Create a temporary producer for the argument during the conversion
- // process.
- auto *cast = createCast(newValues, origArg->getType());
- origArg->replaceAllUsesWith(cast->getResult(0));
-
- // Insert a mapping between this argument and the one that is replacing
- // it.
- if (!newValues.empty())
- mapping.map(cast->getResult(0), newValues[0]);
- return cast;
+ // Otherwise, this is a 1->N mapping. Call into the provided type converter
+ // to pack the new values.
+ auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+ Operation *cast = typeConverter->materializeConversion(
+ rewriter, origArg->getType(), replArgs, loc);
+ assert(cast->getNumResults() == 1 &&
+ cast->getNumOperands() == replArgs.size());
+ mapping.map(origArg, cast->getResult(0));
+ info.argInfo[i] =
+ ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0));
}
- // Otherwise, this is a 1->N mapping. Call into the provided type converter
- // to pack the new values.
- auto *cast = typeConverter->materializeConversion(
- rewriter, origArg->getType(), newValues, loc);
- assert(cast->getNumResults() == 1 &&
- cast->getNumOperands() == newValues.size());
- origArg->replaceAllUsesWith(cast->getResult(0));
- return cast;
-}
-
-/// A utility function used to create a conversion cast operation with the
-/// given input and result types.
-Operation *ArgConverter::createCast(ArrayRef<Value *> inputs, Type outputType) {
- return Operation::create(loc, castOpName, outputType, inputs, llvm::None,
- llvm::None, 0, false);
+ // Remove the original block from the region and return the new one.
+ newBlock->getParent()->getBlocks().remove(block);
+ conversionInfo.insert({newBlock, std::move(info)});
+ return newBlock;
}
//===----------------------------------------------------------------------===//
LogicalResult convertBlockSignature(Block *block);
/// Apply a signature conversion on the given region.
- void applySignatureConversion(Region *region,
- TypeConverter::SignatureConversion &conversion);
+ Block *
+ applySignatureConversion(Region *region,
+ TypeConverter::SignatureConversion &conversion);
/// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
}
// Undo the type conversion.
case BlockActionKind::TypeConversion: {
- argConverter.discardPendingRewrites(action.block);
+ argConverter.discardRewrites(action.block);
break;
}
}
if (argConverter.typeConverter && repl.op->getNumRegions()) {
for (auto ®ion : repl.op->getRegions())
for (auto &block : region)
- argConverter.cancelPendingRewrites(&block);
+ argConverter.notifyBlockRemoved(&block);
}
}
for (auto &repl : llvm::reverse(replacements))
repl.op->erase();
- argConverter.applyRewrites();
+ argConverter.applyRewrites(mapping);
}
LogicalResult
// * The block has already been converted.
// * This is an entry block, these are converted explicitly via patterns.
if (!argConverter.typeConverter || argConverter.hasBeenConverted(block) ||
- block->isEntryBlock())
+ !block->getParent() || block->isEntryBlock())
return success();
// Otherwise, try to convert the block signature.
- if (failed(argConverter.convertSignature(block, mapping)))
- return failure();
- blockActions.push_back(BlockAction::getTypeConversion(block));
- return success();
+ Block *newBlock = argConverter.convertSignature(block, mapping);
+ if (newBlock)
+ blockActions.push_back(BlockAction::getTypeConversion(newBlock));
+ return success(newBlock);
}
-void ConversionPatternRewriterImpl::applySignatureConversion(
+Block *ConversionPatternRewriterImpl::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion) {
if (!region->empty()) {
- argConverter.applySignatureConversion(®ion->front(), conversion,
- mapping);
- blockActions.push_back(BlockAction::getTypeConversion(®ion->front()));
+ Block *newEntry = argConverter.applySignatureConversion(
+ ®ion->front(), conversion, mapping);
+ blockActions.push_back(BlockAction::getTypeConversion(newEntry));
+ return newEntry;
}
+ return nullptr;
}
void ConversionPatternRewriterImpl::replaceOp(
}
/// Apply a signature conversion to the entry block of the given region.
-void ConversionPatternRewriter::applySignatureConversion(
+Block *ConversionPatternRewriter::applySignatureConversion(
Region *region, TypeConverter::SignatureConversion &conversion) {
- impl->applySignatureConversion(region, conversion);
+ return impl->applySignatureConversion(region, conversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument *from,
return success();
for (auto ®ion : op->getRegions()) {
- for (auto &block : region)
+ for (auto &block : llvm::make_early_inc_range(region))
if (failed(rewriter.getImpl().convertBlockSignature(&block)))
return failure();
}