From 9a1b6fec79c3ae2bdeab06c887adc896daf95eb0 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 4 Jul 2019 04:41:02 -0700 Subject: [PATCH] Make ConvertStandardToLLVMPass extendable with other patterns Extend the LLVM lowering pass to accept callbacks that construct an instance of (a subclass of) LLVMTypeConverter and populate a list of conversion patterns. These callbacks will be called when the pass processes a module and their results will be used to set up the dialect conversion infrastructure. Clients can now provide additional conversion patterns to avoid the need of materializing type conversions between LLVM and other types. PiperOrigin-RevId: 256532415 --- .../StandardToLLVM/ConvertStandardToLLVMPass.h | 38 +++++++++++++++- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 50 +++++++++++++++++++--- 2 files changed, 79 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 572cd56..8a33b75 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -18,6 +18,7 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ +#include "llvm/ADT/STLExtras.h" #include #include @@ -28,21 +29,54 @@ class Module; namespace mlir { class DialectConversion; class LLVMTypeConverter; +class MLIRContext; class ModuleOp; using Module = ModuleOp; class ModulePassBase; class RewritePattern; class Type; +// Owning list of rewriting patterns. using OwningRewritePatternList = std::vector>; -/// Creates a pass to convert Standard dialects into the LLVMIR dialect. -ModulePassBase *createConvertToLLVMIRPass(); +/// Type for a callback constructing the owning list of patterns for the +/// conversion to the LLVMIR dialect. The callback is expected to append +/// patterns to the owning list provided as the second argument. +using LLVMPatternListFiller = + std::function; + +/// Type for a callback constructing the type converter for the conversion to +/// the LLVMIR dialect. The callback is expected to return an instance of the +/// converter. +using LLVMTypeConverterMaker = + std::function(MLIRContext *)>; /// Collect a set of patterns to convert from the Standard dialect to LLVM. void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); +/// Creates a pass to convert the Standard dialect into the LLVMIR dialect. +ModulePassBase *createConvertToLLVMIRPass(); + +/// Creates a pass to convert operations to the LLVMIR dialect. The conversion +/// is defined by a list of patterns and a type converter that will be obtained +/// during the pass using the provided callbacks. +ModulePassBase * +createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, + LLVMTypeConverterMaker typeConverterMaker); + +/// Creates a pass to convert operations to the LLVMIR dialect. The conversion +/// is defined by a list of patterns obtained during the pass using the provided +/// callback and an optional type conversion class, an instance is created +/// during the pass. +template +ModulePassBase * +createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller) { + return createConvertToLLVMIRPass(patternListFiller, [](MLIRContext *context) { + return llvm::make_unique(context); + }); +} + namespace LLVM { /// Make argument-taking successors of each block distinct. PHI nodes in LLVM /// IR use the predecessor ID to identify which value to take. They do not diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index a0b911e..ca59c0b 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1005,29 +1005,65 @@ LLVMTypeConverter::convertSignature(FunctionType type, return failure(); } +/// Create an instance of LLVMTypeConverter in the given context. +static std::unique_ptr +makeStandardToLLVMTypeConverter(MLIRContext *context) { + return llvm::make_unique(context); +} + namespace { -/// A pass converting MLIR Standard operations into the LLVM IR dialect. +/// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ModulePass { + // By default, the patterns are those converting Standard operations to the + // LLVMIR dialect. + explicit LLVMLoweringPass( + LLVMPatternListFiller patternListFiller = + populateStdToLLVMConversionPatterns, + LLVMTypeConverterMaker converterBuilder = makeStandardToLLVMTypeConverter) + : patternListFiller(patternListFiller), + typeConverterMaker(converterBuilder) {} + // Run the dialect converter on the module. void runOnModule() override { + if (!typeConverterMaker || !patternListFiller) + return signalPassFailure(); + Module m = getModule(); LLVM::ensureDistinctSuccessors(m); - LLVMTypeConverter converter(&getContext()); + std::unique_ptr typeConverter = + typeConverterMaker(&getContext()); + if (!typeConverter) + return signalPassFailure(); + OwningRewritePatternList patterns; - populateStdToLLVMConversionPatterns(converter, patterns); + patternListFiller(*typeConverter, patterns); ConversionTarget target(getContext()); target.addLegalDialect(); - if (failed( - applyConversionPatterns(m, target, converter, std::move(patterns)))) + if (failed(applyConversionPatterns(m, target, *typeConverter, + std::move(patterns)))) signalPassFailure(); } + + // Callback for creating a list of patterns. It is called every time in + // runOnModule since applyConversionPatterns consumes the list. + LLVMPatternListFiller patternListFiller; + + // Callback for creating an instance of type converter. The converter + // constructor needs an MLIRContext, which is not available until runOnModule. + LLVMTypeConverterMaker typeConverterMaker; }; -} // end anonymous namespace +} // end namespace ModulePassBase *mlir::createConvertToLLVMIRPass() { - return new LLVMLoweringPass(); + return new LLVMLoweringPass; +} + +ModulePassBase * +createConvertToLLVMIRPass(LLVMPatternListFiller patternListFiller, + LLVMTypeConverterMaker typeConverterMaker) { + return new LLVMLoweringPass(patternListFiller, typeConverterMaker); } static PassRegistration -- 2.7.4