From: River Riddle Date: Wed, 22 May 2019 18:49:04 +0000 (-0700) Subject: Refactor DialectConversion to operate on functions in-place *without* any cloning... X-Git-Tag: llvmorg-11-init~1466^2~1636 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d15d107da104fb8e41feeb37a86228263e79e91b;p=platform%2Fupstream%2Fllvm.git Refactor DialectConversion to operate on functions in-place *without* any cloning. This works by caching all of the requested pattern rewrite operations, e.g. replace operation, and only applying them on a completely successful conversion. -- PiperOrigin-RevId: 249490306 --- diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h index 1a7332d..4a9de2b 100644 --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -178,6 +178,15 @@ public: assert(index < getNumArguments() && "invalid argument number"); argAttrs[index].setAttrs(attributes); } + void setArgAttrs(unsigned index, NamedAttributeList attributes) { + assert(index < getNumArguments() && "invalid argument number"); + argAttrs[index] = attributes; + } + void setAllArgAttrs(ArrayRef attributes) { + assert(attributes.size() == getNumArguments()); + for (unsigned i = 0, e = attributes.size(); i != e; ++i) + argAttrs[i] = attributes[i]; + } /// Return all argument attributes of this function. MutableArrayRef getAllArgAttrs() { return argAttrs; } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 8bbce1d..3facb7a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -86,10 +86,9 @@ public: } /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern, generating any new operations with the specified - /// builder. If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. + /// this pattern. If an unexpected error is encountered (an internal compiler + /// error), it is emitted through the normal MLIR diagnostic hooks and the IR + /// is left in a valid state. void rewrite(Operation *op, PatternRewriter &rewriter) const final; private: @@ -118,7 +117,9 @@ private: /// If any error happened during the conversion, the pass fails as soon as /// possible. /// -/// If the conversion fails, the module is not modified. +/// If conversion fails for a specific function, that functions remains +/// unmodified. Otherwise, successfully converted functions will remain +/// converted. class DialectConversion { public: virtual ~DialectConversion() = default; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index ff53880..de66d01 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -26,71 +26,65 @@ using namespace mlir; using namespace mlir::impl; //===----------------------------------------------------------------------===// -// ProducerGenerator +// ArgConverter //===----------------------------------------------------------------------===// namespace { -/// This class provides a simple interface for generating fake producers during -/// the conversion process. These fake producers are used when replacing the -/// results of an operation with values of a new, legal, type. The producer -/// provides a definition for the remaining uses of the old value while they -/// await conversion. -struct ProducerGenerator { - ProducerGenerator(MLIRContext *ctx) - : producerOpName(kProducerName, ctx), loc(UnknownLoc::get(ctx)) {} - - /// Cleanup any generated conversion values. Returns failure if there are any - /// dangling references to a producer operation, success otherwise. - LogicalResult cleanupGeneratedOps() { - for (auto *op : producerOps) { - if (!op->use_empty()) { - auto diag = op->getContext()->emitError(loc) - << "Converter did not convert all uses of replaced value " - "with illegal type"; - for (auto *user : op->getResult(0)->getUsers()) - diag.attachNote(user->getLoc()) - << "user was not converted : " << *user; - return diag; - } +/// This class provides a simple interface for converting the types of block +/// arguments. This is done by inserting fake cast operations for the illegal +/// type that allow for updating the real type to return the correct type. +struct ArgConverter { + ArgConverter(MLIRContext *ctx) + : castOpName(kCastName, ctx), loc(UnknownLoc::get(ctx)) {} + + /// Cleanup and undo any generated conversion values. + void discardRewrites() { + // On failure drop all uses of the cast operation and destroy it. + for (auto *op : castOps) { + op->getResult(0)->dropAllUses(); + op->destroy(); + } + castOps.clear(); + } + + /// Replace usages of the cast operations with the argument directly. + void applyRewrites() { + // On success, we update the type of the block argument and replace uses of + // the cast. + for (auto *op : castOps) { + op->getOperand(0)->setType(op->getResult(0)->getType()); + op->getResult(0)->replaceAllUsesWith(op->getOperand(0)); op->destroy(); } - return success(); } - /// Generate a producer value for 'oldValue'. These new producers replace all - /// of the current uses of the original value, and record a mapping between - /// for replacement with the 'newValue'. - void generateAndReplace(Value *oldValue, Value *newValue, - BlockAndValueMapping &mapping) { - if (oldValue->use_empty()) - return; - - // Otherwise, generate a new producer operation for the given value type. - auto *producer = Operation::create( - loc, producerOpName, llvm::None, oldValue->getType(), llvm::None, - llvm::None, 0, false, oldValue->getContext()); - - // Replace the uses of the old value and record the mapping. - oldValue->replaceAllUsesWith(producer->getResult(0)); - mapping.map(producer->getResult(0), newValue); - producerOps.push_back(producer); + /// Generate a cast operation for 'arg' that produces the new, legal, type. + void castArgument(BlockArgument *arg, Type newType, + BlockAndValueMapping &mapping) { + // Otherwise, generate a new cast operation for the given value type. + auto *cast = Operation::create(loc, castOpName, arg, newType, llvm::None, + llvm::None, 0, false, arg->getContext()); + + // Replace the uses of the argument and record the mapping. + mapping.map(arg, cast->getResult(0)); + castOps.push_back(cast); } /// This is an operation name for a fake operation that is inserted during the /// conversion process. Operations of this type are guaranteed to never escape /// the converter. - static constexpr StringLiteral kProducerName = "__mlir_conversion.producer"; - OperationName producerOpName; + static constexpr StringLiteral kCastName = "__mlir_conversion.cast"; + OperationName castOpName; - /// This is a collection of producer values that were generated during the + /// This is a collection of cast values that were generated during the /// conversion process. - std::vector producerOps; + std::vector castOps; /// An instance of the unknown location that is used when generating /// producers. UnknownLoc loc; }; -constexpr StringLiteral ProducerGenerator::kProducerName; +constexpr StringLiteral ArgConverter::kCastName; //===----------------------------------------------------------------------===// // DialectConversionRewriter @@ -99,43 +93,91 @@ constexpr StringLiteral ProducerGenerator::kProducerName; /// This class implements a pattern rewriter for DialectConversionPattern /// patterns. It automatically performs remapping of replaced operation values. struct DialectConversionRewriter final : public PatternRewriter { + /// This class represents one requested operation replacement via 'replaceOp'. + struct OpReplacement { + OpReplacement() = default; + OpReplacement(Operation *op, ArrayRef newValues) + : op(op), newValues(newValues.begin(), newValues.end()) {} + + Operation *op; + SmallVector newValues; + }; + DialectConversionRewriter(Function *fn) - : PatternRewriter(fn), tempGenerator(fn->getContext()) {} + : PatternRewriter(fn), argConverter(fn->getContext()) {} ~DialectConversionRewriter() = default; - // Implement the hook for replacing an operation with new values. + /// Cleanup and destroy any generated rewrite operations. This method is + /// invoked when the conversion process fails. + void discardRewrites() { + argConverter.discardRewrites(); + for (auto *op : createdOps) { + op->dropAllDefinedValueUses(); + op->erase(); + } + } + + /// Apply all requested operation rewrites. This method is invoked when the + /// conversion process succeeds. + void applyRewrites() { + argConverter.applyRewrites(); + + // Apply all of the rewrites replacements requested during conversion. + for (auto &repl : replacements) { + for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) + repl.op->getResult(i)->replaceAllUsesWith(repl.newValues[i]); + repl.op->erase(); + } + } + + /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ArrayRef newValues, ArrayRef valuesToRemoveIfDead) override { assert(newValues.size() == op->getNumResults()); - for (unsigned i = 0, e = newValues.size(); i < e; ++i) { - Value *result = op->getResult(i); - if (result->getType() != newValues[i]->getType()) - tempGenerator.generateAndReplace(result, newValues[i], mapping); - else - result->replaceAllUsesWith(newValues[i]); - } - op->erase(); + // Create mappings for any type changes. + for (unsigned i = 0, e = newValues.size(); i < e; ++i) + if (op->getResult(i)->getType() != newValues[i]->getType()) + mapping.map(op->getResult(i), newValues[i]); + + // Record the requested operation replacement. + replacements.emplace_back(op, newValues); } - // Implement the hook for creating operations, and make sure that newly - // created ops are added to the worklist for processing. + /// PatternRewriter hook for creating a new operation. Operation *createOperation(const OperationState &state) override { - return FuncBuilder::createOperation(state); + auto *result = FuncBuilder::createOperation(state); + createdOps.push_back(result); + return result; + } + + /// PatternRewriter hook for updating the root operation in-place. + void notifyRootUpdated(Operation *op) override { + // The rewriter caches changes to the IR to allow for operating in-place and + // backtracking. The rewrite is currently not capable of backtracking + // in-place modifications. + llvm_unreachable("in-place operation updates are not supported"); } - void lookupValues(Operation::operand_range operands, - SmallVectorImpl &remapped) { + /// Remap the given operands to those with potentially different types. + void remapValues(Operation::operand_range operands, + SmallVectorImpl &remapped) { remapped.reserve(llvm::size(operands)); for (Value *operand : operands) remapped.push_back(mapping.lookupOrDefault(operand)); } - // Mapping between values(blocks) in the original function and in the new - // function. + // Mapping between replaced values that differ in type. This happens when + // replacing a value with one of a different type. BlockAndValueMapping mapping; - /// Utility used to create temporary producers operations. - ProducerGenerator tempGenerator; + /// Utility used to convert block arguments. + ArgConverter argConverter; + + /// Ordered vector of all of the newly created operations during conversion. + SmallVector createdOps; + + /// Ordered vector of any requested operation replacements. + SmallVector replacements; }; } // end anonymous namespace @@ -143,16 +185,15 @@ struct DialectConversionRewriter final : public PatternRewriter { // DialectConversionPattern //===----------------------------------------------------------------------===// -/// Rewrite the IR rooted at the specified operation with the result of -/// this pattern, generating any new operations with the specified -/// builder. If an unexpected error is encountered (an internal -/// compiler error), it is emitted through the normal MLIR diagnostic -/// hooks and the IR is left in a valid state. +/// Rewrite the IR rooted at the specified operation with the result of this +/// pattern. If an unexpected error is encountered (an internal compiler +/// error), it is emitted through the normal MLIR diagnostic hooks and the IR is +/// left in a valid state. void DialectConversionPattern::rewrite(Operation *op, PatternRewriter &rewriter) const { SmallVector operands; auto &dialectRewriter = static_cast(rewriter); - dialectRewriter.lookupValues(op->getOperands(), operands); + dialectRewriter.remapValues(op->getOperands(), operands); // If this operation has no successors, invoke the rewrite directly. if (op->getNumSuccessors() == 0) @@ -185,10 +226,8 @@ void DialectConversionPattern::rewrite(Operation *op, // FunctionConverter //===----------------------------------------------------------------------===// namespace { -// Implementation detail class of the DialectConversion utility. Performs -// function-by-function conversions by creating new functions, filling them in -// with converted blocks, updating the function attributes, and replacing the -// old functions with the new ones in the module. +// This class converts a single function using a given DialectConversion +// structure. class FunctionConverter { public: // Constructs a FunctionConverter. @@ -196,31 +235,31 @@ public: RewritePatternMatcher &matcher) : dialectConversion(conversion), matcher(matcher) {} - // Converts the given function to the dialect using hooks defined in - // `dialectConversion`. Returns the converted function or `nullptr` on error. - Function *convertFunction(Function *f); + /// Converts the given function to the dialect using hooks defined in + /// `dialectConversion`. Returns failure on error, success otherwise. + LogicalResult convertFunction(Function *f); - // Converts the given region starting from the entry block and following the - // block successors. Returns failure on error, success otherwise. + /// Converts the given region starting from the entry block and following the + /// block successors. Returns failure on error, success otherwise. template LogicalResult convertRegion(DialectConversionRewriter &rewriter, Region ®ion, RegionParent *parent); - // Converts a block by traversing its operations sequentially, attempting to - // match a pattern. If there is no match, recurses the operations regions if - // it has any. + /// Converts a block by traversing its operations sequentially, attempting to + /// match a pattern. If there is no match, recurses the operations regions if + /// it has any. // - // After converting operations, traverses the successor blocks unless they - // have been visited already as indicated in `visitedBlocks`. + /// After converting operations, traverses the successor blocks unless they + /// have been visited already as indicated in `visitedBlocks`. LogicalResult convertBlock(DialectConversionRewriter &rewriter, Block *block, DenseSet &visitedBlocks); - // Converts the type of the given block argument. Returns success if the - // argument type could be successfully converted, failure otherwise. + /// Converts the type of the given block argument. Returns success if the + /// argument type could be successfully converted, failure otherwise. LogicalResult convertArgument(DialectConversionRewriter &rewriter, BlockArgument *arg, Location loc); - // Pointer to a specific dialect conversion info. + /// Pointer to a specific dialect conversion info. DialectConversion *dialectConversion; /// The matcher to use when converting operations. @@ -237,10 +276,8 @@ FunctionConverter::convertArgument(DialectConversionRewriter &rewriter, << "could not convert block argument of type : " << arg->getType(); // Generate a replacement value, with the new type, for this argument. - if (convertedType != arg->getType()) { - rewriter.tempGenerator.generateAndReplace(arg, arg, rewriter.mapping); - arg->setType(convertedType); - } + if (convertedType != arg->getType()) + rewriter.argConverter.castArgument(arg, convertedType, rewriter.mapping); return success(); } @@ -260,11 +297,6 @@ FunctionConverter::convertBlock(DialectConversionRewriter &rewriter, if (matcher.matchAndRewrite(&op, rewriter)) continue; - // If a rewrite wasn't matched, update any mapped operands in place. - for (auto &operand : op.getOpOperands()) - if (auto *newOperand = rewriter.mapping.lookupOrNull(operand.get())) - operand.set(newOperand); - // Traverse any held regions. for (auto ®ion : op.getRegions()) if (!region.empty() && failed(convertRegion(rewriter, region, &op))) @@ -306,44 +338,46 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, return success(); } -Function *FunctionConverter::convertFunction(Function *f) { - // Convert the function type using the dialect converter. - SmallVector newFunctionArgAttrs; - Type newFunctionType = dialectConversion->convertFunctionSignatureType( - f->getType(), f->getAllArgAttrs(), newFunctionArgAttrs); - if (!newFunctionType) - return f->emitError("could not convert function type"), nullptr; - - // Create a new function using the mapped function type and arg attributes. - auto *newFunc = new Function(f->getLoc(), f->getName().strref(), - newFunctionType.cast(), - f->getAttrs(), newFunctionArgAttrs); - f->getModule()->getFunctions().push_back(newFunc); - - // If this is not an external function, we need to convert the body. - if (!f->isExternal()) { - DialectConversionRewriter rewriter(f); - f->getBody().cloneInto(&newFunc->getBody(), rewriter.mapping, - f->getContext()); - rewriter.mapping.clear(); - if (failed(convertRegion(rewriter, newFunc->getBody(), &*newFunc))) { - f->getModule()->getFunctions().pop_back(); - return nullptr; - } +LogicalResult FunctionConverter::convertFunction(Function *f) { + // If this is an external function, there is nothing else to do. + if (f->isExternal()) + return success(); - // Cleanup any temp producer operations that were generated by the rewriter. - if (failed(rewriter.tempGenerator.cleanupGeneratedOps())) { - f->getModule()->getFunctions().pop_back(); - return nullptr; - } + // Rewrite the function body. + DialectConversionRewriter rewriter(f); + if (failed(convertRegion(rewriter, f->getBody(), f))) { + // Reset any of the converted arguments. + rewriter.argConverter.discardRewrites(); + return failure(); } - return newFunc; + + // Otherwise the conversion succeeded, so apply all rewrites. + rewriter.applyRewrites(); + return success(); } //===----------------------------------------------------------------------===// // DialectConversion //===----------------------------------------------------------------------===// +namespace { +/// This class represents a function to be converted. It allows for converting +/// the body of functions and the signature in two phases. +struct ConvertedFunction { + ConvertedFunction(Function *fn, FunctionType newType, + ArrayRef newFunctionArgAttrs) + : fn(fn), newType(newType), + newFunctionArgAttrs(newFunctionArgAttrs.begin(), + newFunctionArgAttrs.end()) {} + + /// The function to convert. + Function *fn; + /// The new type and argument attributes for the function. + FunctionType newType; + SmallVector newFunctionArgAttrs; +}; +} // end anonymous namespace + // Create a function type with arguments and results converted, and argument // attributes passed through. FunctionType DialectConversion::convertFunctionSignatureType( @@ -386,40 +420,35 @@ LogicalResult DialectConversion::convert(Module *module) { initConverters(patterns, context); RewritePatternMatcher matcher(std::move(patterns)); - SmallVector originalFuncs, convertedFuncs; - DenseMap functionAttrRemapping; - originalFuncs.reserve(module->getFunctions().size()); - for (auto &func : *module) - originalFuncs.push_back(&func); - convertedFuncs.reserve(originalFuncs.size()); - - // Convert each function. - FunctionConverter converter(context, this, matcher); - for (auto *func : originalFuncs) { - Function *converted = converter.convertFunction(func); - if (!converted) { - // Make sure to erase any previously converted functions. - while (!convertedFuncs.empty()) - convertedFuncs.pop_back_val()->erase(); + // Try to convert each of the functions within the module. Defer updating the + // signatures of the functions until after all of the bodies have been + // converted. This allows for the conversion patterns to still rely on the + // public signatures of the functions within the module before they are + // updated. + std::vector toConvert; + toConvert.reserve(module->getFunctions().size()); + for (auto &func : *module) { + // Convert the function type using the dialect converter. + SmallVector newFunctionArgAttrs; + FunctionType newType = convertFunctionSignatureType( + func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs); + if (!newType || !newType.isa()) + return func.emitError("could not convert function type"); + + // Convert the body of this function. + FunctionConverter converter(context, this, matcher); + if (failed(converter.convertFunction(&func))) return failure(); - } - convertedFuncs.push_back(converted); - auto origFuncAttr = FunctionAttr::get(func); - auto convertedFuncAttr = FunctionAttr::get(converted); - functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr}); + // Add function signature to be updated. + toConvert.emplace_back(&func, newType.cast(), + newFunctionArgAttrs); } - // Remap function attributes in the converted functions. Original functions - // will disappear anyway so there is no need to remap attributes in them. - for (const auto &funcPair : functionAttrRemapping) - remapFunctionAttrs(*funcPair.getSecond().getValue(), functionAttrRemapping); - - // Remove the original functions from the module and update the names of the - // converted functions. - for (unsigned i = 0, e = originalFuncs.size(); i != e; ++i) { - convertedFuncs[i]->takeName(*originalFuncs[i]); - originalFuncs[i]->erase(); + // Finally, update the signatures of all of the converted functions. + for (auto &it : toConvert) { + it.fn->setType(it.newType); + it.fn->setAllArgAttrs(it.newFunctionArgAttrs); } return success();