From: River Riddle Date: Tue, 16 Jul 2019 18:57:45 +0000 (-0700) Subject: Refactor DialectConversion to support different conversion modes. X-Git-Tag: llvmorg-11-init~1466^2~1169 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2b9855b5b4e964305510caf4f170107b40fd847e;p=platform%2Fupstream%2Fllvm.git Refactor DialectConversion to support different conversion modes. Users generally want several different modes of conversion. This cl refactors DialectConversion to provide two: * Partial (applyPartialConversion) - This mode allows for illegal operations to exist in the IR, and does not fail if an operation fails to be legalized. * Full (applyFullConversion) - This mode fails if any operation is not properly legalized to the conversion target. This allows for ensuring that the IR after a conversion only contains operations legal for the target. PiperOrigin-RevId: 258412243 --- diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index 392d9be..989915e 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -421,8 +421,7 @@ LogicalResult linalg::convertToLLVM(mlir::ModuleOp module) { ConversionTarget target(*module.getContext()); target.addLegalDialect(); - return applyConversionPatterns(module, target, converter, - std::move(patterns)); + return applyFullConversion(module, target, converter, std::move(patterns)); } namespace { diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index 54d1f55..e4eaca8 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -160,8 +160,8 @@ LogicalResult linalg::convertLinalg3ToLLVM(ModuleOp module) { ConversionTarget target(*module.getContext()); target.addLegalDialect(); - if (failed(applyConversionPatterns(module, target, converter, - std::move(patterns)))) + if (failed( + applyFullConversion(module, target, converter, std::move(patterns)))) return failure(); return success(); diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 5677f35..e4df917 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -132,8 +132,8 @@ struct EarlyLoweringPass : public FunctionPass { OwningRewritePatternList patterns; RewriteListBuilder::build(patterns, &getContext()); - if (failed(applyConversionPatterns(getFunction(), target, - std::move(patterns)))) { + if (failed(applyPartialConversion(getFunction(), target, + std::move(patterns)))) { emitError(mlir::UnknownLoc::get(&getContext()), "Error lowering Toy\n"); signalPassFailure(); } diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index ca8185c..ebc81ef 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -356,8 +356,8 @@ struct LateLoweringPass : public ModulePass { target.addLegalDialect(); target.addLegalOp(); - if (failed(applyConversionPatterns(getModule(), target, typeConverter, - std::move(toyPatterns)))) { + if (failed(applyPartialConversion(getModule(), target, typeConverter, + std::move(toyPatterns)))) { emitError(UnknownLoc::get(getModule().getContext()), "Error lowering Toy\n"); signalPassFailure(); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 2e8ecfa..79f0d38 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -349,31 +349,61 @@ private: }; //===----------------------------------------------------------------------===// -// Conversion Application +// Op Conversion Entry Points //===----------------------------------------------------------------------===// -/// Convert the given module with the provided conversion patterns and type -/// conversion object. This function returns failure if a type conversion -/// failed. -LLVM_NODISCARD LogicalResult applyConversionPatterns( +/// Apply a partial conversion on the given operations, and all nested +/// operations. This method converts as many operations to the target as +/// possible, ignoring operations that failed to legalize. This method only +/// returns failure if there are unreachable blocks in any of the regions nested +/// within 'ops'. +LLVM_NODISCARD LogicalResult +applyPartialConversion(ArrayRef ops, ConversionTarget &target, + OwningRewritePatternList &&patterns); +LLVM_NODISCARD LogicalResult +applyPartialConversion(Operation *op, ConversionTarget &target, + OwningRewritePatternList &&patterns); + +/// Apply a complete conversion on the given operations, and all nested +/// operations. This method returns failure if the conversion of any operation +/// fails, or if there are unreachable blocks in any of the regions nested +/// within 'ops'. +LLVM_NODISCARD LogicalResult +applyFullConversion(ArrayRef ops, ConversionTarget &target, + OwningRewritePatternList &&patterns); +LLVM_NODISCARD LogicalResult +applyFullConversion(Operation *op, ConversionTarget &target, + OwningRewritePatternList &&patterns); + +//===----------------------------------------------------------------------===// +// Op + Type Conversion Entry Points +//===----------------------------------------------------------------------===// + +/// Apply a partial conversion on the function operations within the given +/// module. This method returns failure if a type conversion was encountered. +LLVM_NODISCARD LogicalResult applyPartialConversion( ModuleOp module, ConversionTarget &target, TypeConverter &converter, OwningRewritePatternList &&patterns); -/// Convert the given functions with the provided conversion patterns. This -/// function returns failure if a type conversion failed. -LLVM_NODISCARD -LogicalResult applyConversionPatterns(MutableArrayRef fns, - ConversionTarget &target, - TypeConverter &converter, - OwningRewritePatternList &&patterns); - -/// Convert the given function with the provided conversion patterns. This will -/// convert as many of the operations within 'fn' as possible given the set of -/// patterns. -LLVM_NODISCARD -LogicalResult applyConversionPatterns(FuncOp fn, ConversionTarget &target, - OwningRewritePatternList &&patterns); +/// Apply a partial conversion on the given function operations. This method +/// returns failure if a type conversion was encountered. +LLVM_NODISCARD LogicalResult applyPartialConversion( + MutableArrayRef fns, ConversionTarget &target, + TypeConverter &converter, OwningRewritePatternList &&patterns); + +/// Apply a full conversion on the function operations within the given +/// module. This method returns failure if a type conversion was encountered, or +/// if the conversion of any operations failed. +LLVM_NODISCARD LogicalResult applyFullConversion( + ModuleOp module, ConversionTarget &target, TypeConverter &converter, + OwningRewritePatternList &&patterns); +/// Apply a partial conversion on the given function operations. This method +/// returns failure if a type conversion was encountered, or if the conversion +/// of any operation failed. +LLVM_NODISCARD LogicalResult applyFullConversion( + MutableArrayRef fns, ConversionTarget &target, + TypeConverter &converter, OwningRewritePatternList &&patterns); } // end namespace mlir #endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_ diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index 5064bbaa..9c2053d 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -282,7 +282,7 @@ void ControlFlowToCFGPass::runOnFunction() { ConversionTarget target(getContext()); target.addLegalDialect(); if (failed( - applyConversionPatterns(getFunction(), target, std::move(patterns)))) + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 2e8e313..bc5cbfa 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1064,13 +1064,13 @@ struct LLVMLoweringPass : public ModulePass { ConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyConversionPatterns(m, target, *typeConverter, - std::move(patterns)))) + if (failed(applyPartialConversion(m, target, *typeConverter, + std::move(patterns)))) signalPassFailure(); } // Callback for creating a list of patterns. It is called every time in - // runOnModule since applyConversionPatterns consumes the list. + // runOnModule since applyPartialConversion consumes the list. LLVMPatternListFiller patternListFiller; // Callback for creating an instance of type converter. The converter diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index 298f978..c7ea50fa5 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -763,8 +763,8 @@ void LowerLinalgToLLVMPass::runOnModule() { ConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyConversionPatterns(module, target, converter, - std::move(patterns)))) { + if (failed(applyPartialConversion(module, target, converter, + std::move(patterns)))) { signalPassFailure(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index a0fba2c..fd01ad9 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -841,30 +841,48 @@ void OperationLegalizer::computeLegalizationGraphBenefit() { } //===----------------------------------------------------------------------===// -// FunctionConverter +// OperationConverter //===----------------------------------------------------------------------===// namespace { -// This class converts a single function using the given pattern matcher. If a +enum OpConversionMode { + // In this mode, the conversion will ignore failed conversions to allow + // illegal operations to co-exist in the IR. + Partial, + + // In this mode, all operations must be legal for the given target for the + // conversion to succeeed. + Full, +}; + +// This class converts operations using the given pattern matcher. If a // TypeConverter object is provided, then the types of block arguments will be // converted using the appropriate 'convertType' calls. -struct FunctionConverter { - explicit FunctionConverter(ConversionTarget &target, - OwningRewritePatternList &patterns, - TypeConverter *conversion = nullptr) - : typeConverter(conversion), opLegalizer(target, patterns) {} +struct OperationConverter { + explicit OperationConverter(ConversionTarget &target, + OwningRewritePatternList &patterns, + OpConversionMode mode, + TypeConverter *conversion = nullptr) + : typeConverter(conversion), opLegalizer(target, patterns), mode(mode) {} /// Converts the given function to the conversion target. Returns failure on - /// error, success otherwise. If 'signatureConversion' is provided, the - /// arguments of the entry block are updated accordingly. + /// error, success otherwise. LogicalResult convertFunction(FuncOp f, - TypeConverter::SignatureConversion *signatureConversion); + TypeConverter::SignatureConversion &signatureConversion); + + /// Converts the given operations to the conversion target. + LogicalResult convertOperations(ArrayRef ops); private: /// Converts a block or operation with the given rewriter. LogicalResult convert(DialectConversionRewriter &rewriter, llvm::PointerUnion &ptr); + /// Converts a set of blocks/operations with the given rewriter. + LogicalResult + convert(DialectConversionRewriter &rewriter, + std::vector> &toConvert); + /// Recursively collect all of the blocks, and operations, to convert from /// within 'region'. LogicalResult computeConversionSet( @@ -876,11 +894,14 @@ private: /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; + + /// The conversion mode to use when legalizing operations. + OpConversionMode mode; }; } // end anonymous namespace /// Recursively collect all of the blocks to convert from within 'region'. -LogicalResult FunctionConverter::computeConversionSet( +LogicalResult OperationConverter::computeConversionSet( Region ®ion, std::vector> &toConvert) { if (region.empty()) @@ -919,49 +940,78 @@ LogicalResult FunctionConverter::computeConversionSet( /// Converts a block or operation with the given rewriter. LogicalResult -FunctionConverter::convert(DialectConversionRewriter &rewriter, - llvm::PointerUnion &ptr) { +OperationConverter::convert(DialectConversionRewriter &rewriter, + llvm::PointerUnion &ptr) { // If this is a block, then convert the types of each of the arguments. if (auto *block = ptr.dyn_cast()) { assert(typeConverter && "expected valid type converter"); return rewriter.argConverter.convertArguments(block, rewriter.mapping); } - // Otherwise, this is an operation to legalize. - (void)opLegalizer.legalize(ptr.get(), rewriter); + // Otherwise, legalize the given operation. + auto *op = ptr.get(); + auto result = opLegalizer.legalize(op, rewriter); + + // Failed conversions are only important if this is a full conversion. + if (mode == OpConversionMode::Full && failed(result)) + return op->emitError() << "failed to legalize operation '" << op->getName() + << "'"; + + // In any other case, illegal operations are allowed to remain in the IR. return success(); } -LogicalResult FunctionConverter::convertFunction( - FuncOp f, TypeConverter::SignatureConversion *signatureConversion) { +LogicalResult OperationConverter::convert( + DialectConversionRewriter &rewriter, + std::vector> &toConvert) { + // Convert each operation/block and discard rewrites on failure. + for (auto &it : toConvert) { + if (failed(convert(rewriter, it))) { + rewriter.discardRewrites(); + return failure(); + } + } + + // Otherwise the body conversion succeeded, so apply all rewrites. + rewriter.applyRewrites(); + return success(); +} + +LogicalResult OperationConverter::convertFunction( + FuncOp f, TypeConverter::SignatureConversion &signatureConversion) { // If this is an external function, there is nothing else to do. if (f.isExternal()) return success(); - DialectConversionRewriter rewriter(f.getContext(), typeConverter); - // Update the signature of the entry block. - if (signatureConversion) { - rewriter.argConverter.convertSignature( - &f.getBody().front(), *signatureConversion, rewriter.mapping); - } + DialectConversionRewriter rewriter(f.getContext(), typeConverter); + rewriter.argConverter.convertSignature(&f.getBody().front(), + signatureConversion, rewriter.mapping); - /// Compute the set of operations and blocks to convert. + // Compute the set of operations and blocks to convert. std::vector> toConvert; if (failed(computeConversionSet(f.getBody(), toConvert))) return failure(); + return convert(rewriter, toConvert); +} - // Convert each operation/block and discard rewrites on failure. - for (auto &it : toConvert) { - if (failed(convert(rewriter, it))) { - rewriter.discardRewrites(); - return failure(); - } +/// Converts the given top-level operation to the conversion target. +LogicalResult OperationConverter::convertOperations(ArrayRef ops) { + if (ops.empty()) + return success(); + + /// Compute the set of operations and blocks to convert. + std::vector> toConvert; + for (auto *op : ops) { + toConvert.emplace_back(op); + for (auto ®ion : op->getRegions()) + if (failed(computeConversionSet(region, toConvert))) + return failure(); } - // Otherwise the body conversion succeeded, so apply all rewrites. - rewriter.applyRewrites(); - return success(); + // Rewrite the blocks and operations. + DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter); + return convert(rewriter, toConvert); } //===----------------------------------------------------------------------===// @@ -1102,34 +1152,59 @@ auto ConversionTarget::getOpAction(OperationName op) const } //===----------------------------------------------------------------------===// -// Conversion Application +// Op Conversion Entry Points //===----------------------------------------------------------------------===// -/// Convert the given module with the provided conversion patterns and type -/// conversion object. If conversion fails for specific functions, those -/// functions remains unmodified. +/// Apply a partial conversion on the given operations, and all nested +/// operations. This method converts as many operations to the target as +/// possible, ignoring operations that failed to legalize. LogicalResult -mlir::applyConversionPatterns(ModuleOp module, ConversionTarget &target, - TypeConverter &converter, - OwningRewritePatternList &&patterns) { - SmallVector allFunctions(module.getOps()); - return applyConversionPatterns(allFunctions, target, converter, - std::move(patterns)); +mlir::applyPartialConversion(ArrayRef ops, + ConversionTarget &target, + OwningRewritePatternList &&patterns) { + OperationConverter converter(target, patterns, OpConversionMode::Partial); + return converter.convertOperations(ops); +} +LogicalResult +mlir::applyPartialConversion(Operation *op, ConversionTarget &target, + OwningRewritePatternList &&patterns) { + return applyPartialConversion(llvm::makeArrayRef(op), target, + std::move(patterns)); +} + +/// Apply a complete conversion on the given operations, and all nested +/// operations. This method will return failure if the conversion of any +/// operation fails. +LogicalResult mlir::applyFullConversion(ArrayRef ops, + ConversionTarget &target, + OwningRewritePatternList &&patterns) { + OperationConverter converter(target, patterns, OpConversionMode::Full); + return converter.convertOperations(ops); } +LogicalResult mlir::applyFullConversion(Operation *op, ConversionTarget &target, + OwningRewritePatternList &&patterns) { + return applyFullConversion(llvm::makeArrayRef(op), target, + std::move(patterns)); +} + +//===----------------------------------------------------------------------===// +// Op + Type Conversion Entry Points +//===----------------------------------------------------------------------===// -/// Convert the given functions with the provided conversion patterns. -LogicalResult mlir::applyConversionPatterns( - MutableArrayRef fns, ConversionTarget &target, - TypeConverter &converter, OwningRewritePatternList &&patterns) { +static LogicalResult applyConversion(MutableArrayRef fns, + ConversionTarget &target, + TypeConverter &converter, + OwningRewritePatternList &&patterns, + OpConversionMode mode) { if (fns.empty()) return success(); // Build the function converter. - auto *ctx = fns.front().getContext(); - FunctionConverter funcConverter(target, patterns, &converter); + OperationConverter funcConverter(target, patterns, mode, &converter); // Try to convert each of the functions within the module. SmallVector argAttrs; + auto *ctx = fns.front().getContext(); for (auto func : fns) { argAttrs.clear(); func.getAllArgAttrs(argAttrs); @@ -1144,20 +1219,53 @@ LogicalResult mlir::applyConversionPatterns( func.setAllArgAttrs(conversion->getConvertedArgAttrs()); // Convert the body of this function. - if (failed(funcConverter.convertFunction(func, &*conversion))) + if (failed(funcConverter.convertFunction(func, *conversion))) return failure(); } return success(); } -/// Convert the given function with the provided conversion patterns. This will -/// convert as many of the operations within 'fn' as possible given the set of -/// patterns. +/// Apply a partial conversion on the function operations within the given +/// module. This method returns failure if a type conversion was encountered. LogicalResult -mlir::applyConversionPatterns(FuncOp fn, ConversionTarget &target, - OwningRewritePatternList &&patterns) { - // Convert the body of this function. - FunctionConverter converter(target, patterns); - return converter.convertFunction(fn, /*signatureConversion=*/nullptr); +mlir::applyPartialConversion(ModuleOp module, ConversionTarget &target, + TypeConverter &converter, + OwningRewritePatternList &&patterns) { + SmallVector allFunctions(module.getOps()); + return applyPartialConversion(allFunctions, target, converter, + std::move(patterns)); +} + +/// Apply a partial conversion on the given function operations. This method +/// returns failure if a type conversion was encountered. +LogicalResult +mlir::applyPartialConversion(MutableArrayRef fns, + ConversionTarget &target, TypeConverter &converter, + OwningRewritePatternList &&patterns) { + return applyConversion(fns, target, converter, std::move(patterns), + OpConversionMode::Partial); +} + +/// Apply a full conversion on the function operations within the given module. +/// This method returns failure if a type conversion was encountered, or if the +/// conversion of any operations failed. +LogicalResult mlir::applyFullConversion(ModuleOp module, + ConversionTarget &target, + TypeConverter &converter, + OwningRewritePatternList &&patterns) { + SmallVector allFunctions(module.getOps()); + return applyFullConversion(allFunctions, target, converter, + std::move(patterns)); +} + +/// Apply a full conversion on the given function operations. This method +/// returns failure if a type conversion was encountered, or if the conversion +/// of any operation failed. +LogicalResult mlir::applyFullConversion(MutableArrayRef fns, + ConversionTarget &target, + TypeConverter &converter, + OwningRewritePatternList &&patterns) { + return applyConversion(fns, target, converter, std::move(patterns), + OpConversionMode::Full); } diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 82b7074..20a9134 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -521,8 +521,8 @@ class LowerAffinePass : public FunctionPass { populateAffineToStdConversionPatterns(patterns, &getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); - if (failed(applyConversionPatterns(getFunction(), target, - std::move(patterns)))) + if (failed( + applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index d4aef38..ed5aea8 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -177,8 +177,8 @@ struct TestLegalizePatternDriver TestTypeConverter converter; TestConversionTarget target(getContext()); - if (failed(applyConversionPatterns(getModule(), target, converter, - std::move(patterns)))) + if (failed(applyPartialConversion(getModule(), target, converter, + std::move(patterns)))) signalPassFailure(); } };