using namespace mlir::impl;
//===----------------------------------------------------------------------===//
-// ProducerGenerator
+// ArgConverter
//===----------------------------------------------------------------------===//
namespace {
-/// This class provides a simple interface for generating fake producers during
-/// the conversion process. These fake producers are used when replacing the
-/// results of an operation with values of a new, legal, type. The producer
-/// provides a definition for the remaining uses of the old value while they
-/// await conversion.
-struct ProducerGenerator {
- ProducerGenerator(MLIRContext *ctx)
- : producerOpName(kProducerName, ctx), loc(UnknownLoc::get(ctx)) {}
-
- /// Cleanup any generated conversion values. Returns failure if there are any
- /// dangling references to a producer operation, success otherwise.
- LogicalResult cleanupGeneratedOps() {
- for (auto *op : producerOps) {
- if (!op->use_empty()) {
- auto diag = op->getContext()->emitError(loc)
- << "Converter did not convert all uses of replaced value "
- "with illegal type";
- for (auto *user : op->getResult(0)->getUsers())
- diag.attachNote(user->getLoc())
- << "user was not converted : " << *user;
- return diag;
- }
+/// This class provides a simple interface for converting the types of block
+/// arguments. This is done by inserting fake cast operations for the illegal
+/// type that allow for updating the real type to return the correct type.
+struct ArgConverter {
+ ArgConverter(MLIRContext *ctx)
+ : castOpName(kCastName, ctx), loc(UnknownLoc::get(ctx)) {}
+
+ /// Cleanup and undo any generated conversion values.
+ void discardRewrites() {
+ // On failure drop all uses of the cast operation and destroy it.
+ for (auto *op : castOps) {
+ op->getResult(0)->dropAllUses();
+ op->destroy();
+ }
+ castOps.clear();
+ }
+
+ /// Replace usages of the cast operations with the argument directly.
+ void applyRewrites() {
+ // On success, we update the type of the block argument and replace uses of
+ // the cast.
+ for (auto *op : castOps) {
+ op->getOperand(0)->setType(op->getResult(0)->getType());
+ op->getResult(0)->replaceAllUsesWith(op->getOperand(0));
op->destroy();
}
- return success();
}
- /// Generate a producer value for 'oldValue'. These new producers replace all
- /// of the current uses of the original value, and record a mapping between
- /// for replacement with the 'newValue'.
- void generateAndReplace(Value *oldValue, Value *newValue,
- BlockAndValueMapping &mapping) {
- if (oldValue->use_empty())
- return;
-
- // Otherwise, generate a new producer operation for the given value type.
- auto *producer = Operation::create(
- loc, producerOpName, llvm::None, oldValue->getType(), llvm::None,
- llvm::None, 0, false, oldValue->getContext());
-
- // Replace the uses of the old value and record the mapping.
- oldValue->replaceAllUsesWith(producer->getResult(0));
- mapping.map(producer->getResult(0), newValue);
- producerOps.push_back(producer);
+ /// Generate a cast operation for 'arg' that produces the new, legal, type.
+ void castArgument(BlockArgument *arg, Type newType,
+ BlockAndValueMapping &mapping) {
+ // Otherwise, generate a new cast operation for the given value type.
+ auto *cast = Operation::create(loc, castOpName, arg, newType, llvm::None,
+ llvm::None, 0, false, arg->getContext());
+
+ // Replace the uses of the argument and record the mapping.
+ mapping.map(arg, cast->getResult(0));
+ castOps.push_back(cast);
}
/// This is an operation name for a fake operation that is inserted during the
/// conversion process. Operations of this type are guaranteed to never escape
/// the converter.
- static constexpr StringLiteral kProducerName = "__mlir_conversion.producer";
- OperationName producerOpName;
+ static constexpr StringLiteral kCastName = "__mlir_conversion.cast";
+ OperationName castOpName;
- /// This is a collection of producer values that were generated during the
+ /// This is a collection of cast values that were generated during the
/// conversion process.
- std::vector<Operation *> producerOps;
+ std::vector<Operation *> castOps;
/// An instance of the unknown location that is used when generating
/// producers.
UnknownLoc loc;
};
-constexpr StringLiteral ProducerGenerator::kProducerName;
+constexpr StringLiteral ArgConverter::kCastName;
//===----------------------------------------------------------------------===//
// DialectConversionRewriter
/// This class implements a pattern rewriter for DialectConversionPattern
/// patterns. It automatically performs remapping of replaced operation values.
struct DialectConversionRewriter final : public PatternRewriter {
+ /// This class represents one requested operation replacement via 'replaceOp'.
+ struct OpReplacement {
+ OpReplacement() = default;
+ OpReplacement(Operation *op, ArrayRef<Value *> newValues)
+ : op(op), newValues(newValues.begin(), newValues.end()) {}
+
+ Operation *op;
+ SmallVector<Value *, 2> newValues;
+ };
+
DialectConversionRewriter(Function *fn)
- : PatternRewriter(fn), tempGenerator(fn->getContext()) {}
+ : PatternRewriter(fn), argConverter(fn->getContext()) {}
~DialectConversionRewriter() = default;
- // Implement the hook for replacing an operation with new values.
+ /// Cleanup and destroy any generated rewrite operations. This method is
+ /// invoked when the conversion process fails.
+ void discardRewrites() {
+ argConverter.discardRewrites();
+ for (auto *op : createdOps) {
+ op->dropAllDefinedValueUses();
+ op->erase();
+ }
+ }
+
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites() {
+ argConverter.applyRewrites();
+
+ // Apply all of the rewrites replacements requested during conversion.
+ for (auto &repl : replacements) {
+ for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i)
+ repl.op->getResult(i)->replaceAllUsesWith(repl.newValues[i]);
+ repl.op->erase();
+ }
+ }
+
+ /// PatternRewriter hook for replacing the results of an operation.
void replaceOp(Operation *op, ArrayRef<Value *> newValues,
ArrayRef<Value *> valuesToRemoveIfDead) override {
assert(newValues.size() == op->getNumResults());
- for (unsigned i = 0, e = newValues.size(); i < e; ++i) {
- Value *result = op->getResult(i);
- if (result->getType() != newValues[i]->getType())
- tempGenerator.generateAndReplace(result, newValues[i], mapping);
- else
- result->replaceAllUsesWith(newValues[i]);
- }
- op->erase();
+ // Create mappings for any type changes.
+ for (unsigned i = 0, e = newValues.size(); i < e; ++i)
+ if (op->getResult(i)->getType() != newValues[i]->getType())
+ mapping.map(op->getResult(i), newValues[i]);
+
+ // Record the requested operation replacement.
+ replacements.emplace_back(op, newValues);
}
- // Implement the hook for creating operations, and make sure that newly
- // created ops are added to the worklist for processing.
+ /// PatternRewriter hook for creating a new operation.
Operation *createOperation(const OperationState &state) override {
- return FuncBuilder::createOperation(state);
+ auto *result = FuncBuilder::createOperation(state);
+ createdOps.push_back(result);
+ return result;
+ }
+
+ /// PatternRewriter hook for updating the root operation in-place.
+ void notifyRootUpdated(Operation *op) override {
+ // The rewriter caches changes to the IR to allow for operating in-place and
+ // backtracking. The rewrite is currently not capable of backtracking
+ // in-place modifications.
+ llvm_unreachable("in-place operation updates are not supported");
}
- void lookupValues(Operation::operand_range operands,
- SmallVectorImpl<Value *> &remapped) {
+ /// Remap the given operands to those with potentially different types.
+ void remapValues(Operation::operand_range operands,
+ SmallVectorImpl<Value *> &remapped) {
remapped.reserve(llvm::size(operands));
for (Value *operand : operands)
remapped.push_back(mapping.lookupOrDefault(operand));
}
- // Mapping between values(blocks) in the original function and in the new
- // function.
+ // Mapping between replaced values that differ in type. This happens when
+ // replacing a value with one of a different type.
BlockAndValueMapping mapping;
- /// Utility used to create temporary producers operations.
- ProducerGenerator tempGenerator;
+ /// Utility used to convert block arguments.
+ ArgConverter argConverter;
+
+ /// Ordered vector of all of the newly created operations during conversion.
+ SmallVector<Operation *, 4> createdOps;
+
+ /// Ordered vector of any requested operation replacements.
+ SmallVector<OpReplacement, 4> replacements;
};
} // end anonymous namespace
// DialectConversionPattern
//===----------------------------------------------------------------------===//
-/// Rewrite the IR rooted at the specified operation with the result of
-/// this pattern, generating any new operations with the specified
-/// builder. If an unexpected error is encountered (an internal
-/// compiler error), it is emitted through the normal MLIR diagnostic
-/// hooks and the IR is left in a valid state.
+/// Rewrite the IR rooted at the specified operation with the result of this
+/// pattern. If an unexpected error is encountered (an internal compiler
+/// error), it is emitted through the normal MLIR diagnostic hooks and the IR is
+/// left in a valid state.
void DialectConversionPattern::rewrite(Operation *op,
PatternRewriter &rewriter) const {
SmallVector<Value *, 4> operands;
auto &dialectRewriter = static_cast<DialectConversionRewriter &>(rewriter);
- dialectRewriter.lookupValues(op->getOperands(), operands);
+ dialectRewriter.remapValues(op->getOperands(), operands);
// If this operation has no successors, invoke the rewrite directly.
if (op->getNumSuccessors() == 0)
// FunctionConverter
//===----------------------------------------------------------------------===//
namespace {
-// Implementation detail class of the DialectConversion utility. Performs
-// function-by-function conversions by creating new functions, filling them in
-// with converted blocks, updating the function attributes, and replacing the
-// old functions with the new ones in the module.
+// This class converts a single function using a given DialectConversion
+// structure.
class FunctionConverter {
public:
// Constructs a FunctionConverter.
RewritePatternMatcher &matcher)
: dialectConversion(conversion), matcher(matcher) {}
- // Converts the given function to the dialect using hooks defined in
- // `dialectConversion`. Returns the converted function or `nullptr` on error.
- Function *convertFunction(Function *f);
+ /// Converts the given function to the dialect using hooks defined in
+ /// `dialectConversion`. Returns failure on error, success otherwise.
+ LogicalResult convertFunction(Function *f);
- // Converts the given region starting from the entry block and following the
- // block successors. Returns failure on error, success otherwise.
+ /// Converts the given region starting from the entry block and following the
+ /// block successors. Returns failure on error, success otherwise.
template <typename RegionParent>
LogicalResult convertRegion(DialectConversionRewriter &rewriter,
Region ®ion, RegionParent *parent);
- // Converts a block by traversing its operations sequentially, attempting to
- // match a pattern. If there is no match, recurses the operations regions if
- // it has any.
+ /// Converts a block by traversing its operations sequentially, attempting to
+ /// match a pattern. If there is no match, recurses the operations regions if
+ /// it has any.
//
- // After converting operations, traverses the successor blocks unless they
- // have been visited already as indicated in `visitedBlocks`.
+ /// After converting operations, traverses the successor blocks unless they
+ /// have been visited already as indicated in `visitedBlocks`.
LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block,
DenseSet<Block *> &visitedBlocks);
- // Converts the type of the given block argument. Returns success if the
- // argument type could be successfully converted, failure otherwise.
+ /// Converts the type of the given block argument. Returns success if the
+ /// argument type could be successfully converted, failure otherwise.
LogicalResult convertArgument(DialectConversionRewriter &rewriter,
BlockArgument *arg, Location loc);
- // Pointer to a specific dialect conversion info.
+ /// Pointer to a specific dialect conversion info.
DialectConversion *dialectConversion;
/// The matcher to use when converting operations.
<< "could not convert block argument of type : " << arg->getType();
// Generate a replacement value, with the new type, for this argument.
- if (convertedType != arg->getType()) {
- rewriter.tempGenerator.generateAndReplace(arg, arg, rewriter.mapping);
- arg->setType(convertedType);
- }
+ if (convertedType != arg->getType())
+ rewriter.argConverter.castArgument(arg, convertedType, rewriter.mapping);
return success();
}
if (matcher.matchAndRewrite(&op, rewriter))
continue;
- // If a rewrite wasn't matched, update any mapped operands in place.
- for (auto &operand : op.getOpOperands())
- if (auto *newOperand = rewriter.mapping.lookupOrNull(operand.get()))
- operand.set(newOperand);
-
// Traverse any held regions.
for (auto ®ion : op.getRegions())
if (!region.empty() && failed(convertRegion(rewriter, region, &op)))
return success();
}
-Function *FunctionConverter::convertFunction(Function *f) {
- // Convert the function type using the dialect converter.
- SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
- Type newFunctionType = dialectConversion->convertFunctionSignatureType(
- f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs);
- if (!newFunctionType)
- return f->emitError("could not convert function type"), nullptr;
-
- // Create a new function using the mapped function type and arg attributes.
- auto *newFunc = new Function(f->getLoc(), f->getName().strref(),
- newFunctionType.cast<FunctionType>(),
- f->getAttrs(), newFunctionArgAttrs);
- f->getModule()->getFunctions().push_back(newFunc);
-
- // If this is not an external function, we need to convert the body.
- if (!f->isExternal()) {
- DialectConversionRewriter rewriter(f);
- f->getBody().cloneInto(&newFunc->getBody(), rewriter.mapping,
- f->getContext());
- rewriter.mapping.clear();
- if (failed(convertRegion(rewriter, newFunc->getBody(), &*newFunc))) {
- f->getModule()->getFunctions().pop_back();
- return nullptr;
- }
+LogicalResult FunctionConverter::convertFunction(Function *f) {
+ // If this is an external function, there is nothing else to do.
+ if (f->isExternal())
+ return success();
- // Cleanup any temp producer operations that were generated by the rewriter.
- if (failed(rewriter.tempGenerator.cleanupGeneratedOps())) {
- f->getModule()->getFunctions().pop_back();
- return nullptr;
- }
+ // Rewrite the function body.
+ DialectConversionRewriter rewriter(f);
+ if (failed(convertRegion(rewriter, f->getBody(), f))) {
+ // Reset any of the converted arguments.
+ rewriter.argConverter.discardRewrites();
+ return failure();
}
- return newFunc;
+
+ // Otherwise the conversion succeeded, so apply all rewrites.
+ rewriter.applyRewrites();
+ return success();
}
//===----------------------------------------------------------------------===//
// DialectConversion
//===----------------------------------------------------------------------===//
+namespace {
+/// This class represents a function to be converted. It allows for converting
+/// the body of functions and the signature in two phases.
+struct ConvertedFunction {
+ ConvertedFunction(Function *fn, FunctionType newType,
+ ArrayRef<NamedAttributeList> newFunctionArgAttrs)
+ : fn(fn), newType(newType),
+ newFunctionArgAttrs(newFunctionArgAttrs.begin(),
+ newFunctionArgAttrs.end()) {}
+
+ /// The function to convert.
+ Function *fn;
+ /// The new type and argument attributes for the function.
+ FunctionType newType;
+ SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
+};
+} // end anonymous namespace
+
// Create a function type with arguments and results converted, and argument
// attributes passed through.
FunctionType DialectConversion::convertFunctionSignatureType(
initConverters(patterns, context);
RewritePatternMatcher matcher(std::move(patterns));
- SmallVector<Function *, 0> originalFuncs, convertedFuncs;
- DenseMap<Attribute, FunctionAttr> functionAttrRemapping;
- originalFuncs.reserve(module->getFunctions().size());
- for (auto &func : *module)
- originalFuncs.push_back(&func);
- convertedFuncs.reserve(originalFuncs.size());
-
- // Convert each function.
- FunctionConverter converter(context, this, matcher);
- for (auto *func : originalFuncs) {
- Function *converted = converter.convertFunction(func);
- if (!converted) {
- // Make sure to erase any previously converted functions.
- while (!convertedFuncs.empty())
- convertedFuncs.pop_back_val()->erase();
+ // Try to convert each of the functions within the module. Defer updating the
+ // signatures of the functions until after all of the bodies have been
+ // converted. This allows for the conversion patterns to still rely on the
+ // public signatures of the functions within the module before they are
+ // updated.
+ std::vector<ConvertedFunction> toConvert;
+ toConvert.reserve(module->getFunctions().size());
+ for (auto &func : *module) {
+ // Convert the function type using the dialect converter.
+ SmallVector<NamedAttributeList, 4> newFunctionArgAttrs;
+ FunctionType newType = convertFunctionSignatureType(
+ func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs);
+ if (!newType || !newType.isa<FunctionType>())
+ return func.emitError("could not convert function type");
+
+ // Convert the body of this function.
+ FunctionConverter converter(context, this, matcher);
+ if (failed(converter.convertFunction(&func)))
return failure();
- }
- convertedFuncs.push_back(converted);
- auto origFuncAttr = FunctionAttr::get(func);
- auto convertedFuncAttr = FunctionAttr::get(converted);
- functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
+ // Add function signature to be updated.
+ toConvert.emplace_back(&func, newType.cast<FunctionType>(),
+ newFunctionArgAttrs);
}
- // Remap function attributes in the converted functions. Original functions
- // will disappear anyway so there is no need to remap attributes in them.
- for (const auto &funcPair : functionAttrRemapping)
- remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping);
-
- // Remove the original functions from the module and update the names of the
- // converted functions.
- for (unsigned i = 0, e = originalFuncs.size(); i != e; ++i) {
- convertedFuncs[i]->takeName(*originalFuncs[i]);
- originalFuncs[i]->erase();
+ // Finally, update the signatures of all of the converted functions.
+ for (auto &it : toConvert) {
+ it.fn->setType(it.newType);
+ it.fn->setAllArgAttrs(it.newFunctionArgAttrs);
}
return success();