From: River Riddle Date: Thu, 23 May 2019 16:23:33 +0000 (-0700) Subject: Decouple running a conversion from the DialectConversion class. The DialectConver... X-Git-Tag: llvmorg-11-init~1466^2~1626 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=14d1cfbccbad35318b4e5075612e6142ebc479bf;p=platform%2Fupstream%2Fllvm.git Decouple running a conversion from the DialectConversion class. The DialectConversion class is only necessary for type signature changes(block arguments or function arguments). This isn't always desired when performing a dialect conversion. This allows for those conversions without this need to run per function instead of per module. -- PiperOrigin-RevId: 249657549 --- diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 7a0e1a5..b81fc25 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -436,7 +436,8 @@ void linalg::convertToLLVM(mlir::Module &module) { // Convert Linalg ops to the LLVM IR dialect using the converter defined // above. - auto r = Lowering(getDescriptorConverters).convert(&module); + Lowering lowering(getDescriptorConverters); + auto r = applyConverter(module, lowering); (void)r; assert(succeeded(r) && "conversion failed"); } diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 9f2d46c..4a258f8 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -143,7 +143,7 @@ void linalg::convertLinalg3ToLLVM(Module &module) { assert(succeeded(rr) && "affine loop lowering failed"); auto lowering = makeLinalgToLLVMLowering(getConversions); - auto r = lowering->convert(&module); + auto r = applyConverter(module, *lowering); (void)r; assert(succeeded(r) && "conversion failed"); } diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index b6e0703..72ef800 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -134,7 +134,8 @@ protected: /// dialect. struct EarlyLoweringPass : public ModulePass { void runOnModule() override { - if (failed(EarlyLowering().convert(&getModule()))) { + EarlyLowering lowering; + if (failed(applyConverter(getModule(), lowering))) { getModule().getContext()->emitError( mlir::UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index 5a0a901..2837807 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -343,8 +343,9 @@ protected: /// and is targeting LLVM otherwise. struct LateLoweringPass : public ModulePass { void runOnModule() override { - // Perform Toy specific lowering - if (failed(LateLowering().convert(&getModule()))) { + // Perform Toy specific lowering. + LateLowering lowering; + if (failed(applyConverter(getModule(), lowering))) { getModule().getContext()->emitError( UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); signalPassFailure(); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 3facb7a..d1b3318 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -39,8 +39,8 @@ class Value; /// Base class for the dialect conversion patterns that require type changes. /// Specific conversions must derive this class and implement least one /// `rewrite` method. -/// NOTE: These conversion patterns can only be used with the DialectConversion -/// class. +/// NOTE: These conversion patterns can only be used with the 'apply*' methods +/// below. class DialectConversionPattern : public RewritePattern { public: /// Construct an DialectConversionPattern. `rootName` must correspond to the @@ -112,22 +112,10 @@ private: // match against the list of conversions. On the first match, call // `rewrite` for the operations, and advance to the next iteration. If no // match is found, replicate the operation as is. -/// 3. Update all attributes of function type to point to the new functions. -/// 4. Replace old functions with new functions in the module. -/// If any error happened during the conversion, the pass fails as soon as -/// possible. -/// -/// If conversion fails for a specific function, that functions remains -/// unmodified. Otherwise, successfully converted functions will remain -/// converted. class DialectConversion { public: virtual ~DialectConversion() = default; - /// Run the converter on the provided module. - LLVM_NODISCARD - LogicalResult convert(Module *m); - /// Derived classes must implement this hook to produce a set of conversion /// patterns to apply. They may use `mlirContext` to obtain registered /// dialects or operations. This will be called in the beginning of the @@ -170,6 +158,19 @@ public: SmallVectorImpl &convertedArgAttrs); }; +/// Convert the given module with the provided dialect conversion object. +/// If conversion fails for a specific function, those functions remains +/// unmodified. +LLVM_NODISCARD +LogicalResult applyConverter(Module &module, DialectConversion &converter); + +/// Convert the given function with the provided conversion patterns. This will +/// convert as many of the operations within 'fn' as possible given the set of +/// patterns. +LLVM_NODISCARD +LogicalResult applyConversionPatterns(Function &fn, + OwningRewritePatternList &&patterns); + } // end namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index a2476dc..347280d 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -1006,9 +1006,9 @@ class LLVMLoweringPass : public ModulePass { public: // Run the dialect converter on the module. void runOnModule() override { - Module *m = &getModule(); - LLVM::ensureDistinctSuccessors(m); - if (failed(impl.convert(m))) + Module &m = getModule(); + LLVM::ensureDistinctSuccessors(&m); + if (failed(applyConverter(m, impl))) signalPassFailure(); } diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 8c2bdb7..cf8a2cc 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -608,7 +608,8 @@ void LowerLinalgToLLVMPass::runOnModule() { signalPassFailure(); // Convert to the LLVM IR dialect using the converter defined above. - if (failed(Lowering().convert(&module))) + Lowering lowering; + if (failed(applyConverter(module, lowering))) signalPassFailure(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index de66d01..389b5ad 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -226,13 +226,13 @@ void DialectConversionPattern::rewrite(Operation *op, // FunctionConverter //===----------------------------------------------------------------------===// namespace { -// This class converts a single function using a given DialectConversion -// structure. +// This class converts a single function using the given pattern matcher. If a +// DialectConversion object is also provided, then the types of block arguments +// will be converted using the appropriate 'convertType' calls. class FunctionConverter { public: - // Constructs a FunctionConverter. - explicit FunctionConverter(MLIRContext *ctx, DialectConversion *conversion, - RewritePatternMatcher &matcher) + explicit FunctionConverter(MLIRContext *ctx, RewritePatternMatcher &matcher, + DialectConversion *conversion = nullptr) : dialectConversion(conversion), matcher(matcher) {} /// Converts the given function to the dialect using hooks defined in @@ -319,11 +319,15 @@ FunctionConverter::convertRegion(DialectConversionRewriter &rewriter, Region ®ion, RegionParent *parent) { assert(!region.empty() && "expected non-empty region"); - // Create the arguments of each of the blocks in the region. - for (Block &block : region) - for (auto *arg : block.getArguments()) - if (failed(convertArgument(rewriter, arg, parent->getLoc()))) - return failure(); + // Create the arguments of each of the blocks in the region. If a type + // converter was not provided, then we don't need to change any of the block + // types. + if (dialectConversion) { + for (Block &block : region) + for (auto *arg : block.getArguments()) + if (failed(convertArgument(rewriter, arg, parent->getLoc()))) + return failure(); + } // Start a DFS-order traversal of the CFG to make sure defs are converted // before uses in dominated blocks. @@ -346,8 +350,8 @@ LogicalResult FunctionConverter::convertFunction(Function *f) { // Rewrite the function body. DialectConversionRewriter rewriter(f); if (failed(convertRegion(rewriter, f->getBody(), f))) { - // Reset any of the converted arguments. - rewriter.argConverter.discardRewrites(); + // Reset any of the generated rewrites. + rewriter.discardRewrites(); return failure(); } @@ -360,24 +364,6 @@ LogicalResult FunctionConverter::convertFunction(Function *f) { // DialectConversion //===----------------------------------------------------------------------===// -namespace { -/// This class represents a function to be converted. It allows for converting -/// the body of functions and the signature in two phases. -struct ConvertedFunction { - ConvertedFunction(Function *fn, FunctionType newType, - ArrayRef newFunctionArgAttrs) - : fn(fn), newType(newType), - newFunctionArgAttrs(newFunctionArgAttrs.begin(), - newFunctionArgAttrs.end()) {} - - /// The function to convert. - Function *fn; - /// The new type and argument attributes for the function. - FunctionType newType; - SmallVector newFunctionArgAttrs; -}; -} // end anonymous namespace - // Create a function type with arguments and results converted, and argument // attributes passed through. FunctionType DialectConversion::convertFunctionSignatureType( @@ -403,21 +389,38 @@ FunctionType DialectConversion::convertFunctionSignatureType( return FunctionType::get(arguments, results, type.getContext()); } -// Converts the module as follows. -// 1. Call `convertFunction` on each function of the module and collect the -// mapping between old and new functions. -// 2. Remap all function attributes in the new functions to point to the new -// functions instead of the old ones. -// 3. Replace old functions with the new in the module. -LogicalResult DialectConversion::convert(Module *module) { - if (!module) - return failure(); +//===----------------------------------------------------------------------===// +// applyConversionPatterns +//===----------------------------------------------------------------------===// +namespace { +/// This class represents a function to be converted. It allows for converting +/// the body of functions and the signature in two phases. +struct ConvertedFunction { + ConvertedFunction(Function *fn, FunctionType newType, + ArrayRef newFunctionArgAttrs) + : fn(fn), newType(newType), + newFunctionArgAttrs(newFunctionArgAttrs.begin(), + newFunctionArgAttrs.end()) {} + + /// The function to convert. + Function *fn; + /// The new type and argument attributes for the function. + FunctionType newType; + SmallVector newFunctionArgAttrs; +}; +} // end anonymous namespace + +/// Convert the given module with the provided dialect conversion object. +/// If conversion fails for a specific function, those functions remains +/// unmodified. +LogicalResult mlir::applyConverter(Module &module, + DialectConversion &converter) { // Grab the conversion patterns from the converter and create the pattern // matcher. - MLIRContext *context = module->getContext(); + MLIRContext *context = module.getContext(); OwningRewritePatternList patterns; - initConverters(patterns, context); + converter.initConverters(patterns, context); RewritePatternMatcher matcher(std::move(patterns)); // Try to convert each of the functions within the module. Defer updating the @@ -426,18 +429,18 @@ LogicalResult DialectConversion::convert(Module *module) { // public signatures of the functions within the module before they are // updated. std::vector toConvert; - toConvert.reserve(module->getFunctions().size()); - for (auto &func : *module) { + toConvert.reserve(module.getFunctions().size()); + for (auto &func : module) { // Convert the function type using the dialect converter. SmallVector newFunctionArgAttrs; - FunctionType newType = convertFunctionSignatureType( + FunctionType newType = converter.convertFunctionSignatureType( func.getType(), func.getAllArgAttrs(), newFunctionArgAttrs); if (!newType || !newType.isa()) return func.emitError("could not convert function type"); // Convert the body of this function. - FunctionConverter converter(context, this, matcher); - if (failed(converter.convertFunction(&func))) + FunctionConverter funcConverter(context, matcher, &converter); + if (failed(funcConverter.convertFunction(&func))) return failure(); // Add function signature to be updated. @@ -453,3 +456,15 @@ LogicalResult DialectConversion::convert(Module *module) { return success(); } + +/// Convert the given function with the provided conversion patterns. This will +/// convert as many of the operations within 'fn' as possible given the set of +/// patterns. +LogicalResult +mlir::applyConversionPatterns(Function &fn, + OwningRewritePatternList &&patterns) { + // Convert the body of this function. + RewritePatternMatcher matcher(std::move(patterns)); + FunctionConverter converter(fn.getContext(), matcher); + return converter.convertFunction(&fn); +}