From 42b3fe833502390d44d2df126048c8310dffa9bd Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 17 Dec 2019 12:09:33 -0800 Subject: [PATCH] Make it possible to override the lowering of MemRef to the LLVM dialect. NFC. The lowering of MemRef types to the LLVM dialect is connected to the underlying runtime representation of structured memory buffers. It has changed several times in the past and reached the current state of a LLVM structured-typed descriptor containing two pointers and all sizes. In several reported use cases, a different, often simpler, lowering scheme is required. For example, lowering statically-shaped memrefs to bare LLVM pointers to simplify aliasing annotation. Split the pattern population functions into those include memref-related operations and the remaining ones. Users are expected to extend TypeConverter::convertType to handle the memref types differently. PiperOrigin-RevId: 286030610 --- mlir/g3doc/ConversionToLLVMDialect.md | 8 ++--- .../StandardToLLVM/ConvertStandardToLLVMPass.h | 11 +++++++ .../StandardToLLVM/ConvertStandardToLLVM.cpp | 37 +++++++++++++++------- 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/mlir/g3doc/ConversionToLLVMDialect.md b/mlir/g3doc/ConversionToLLVMDialect.md index 3881ee0..19403e2 100644 --- a/mlir/g3doc/ConversionToLLVMDialect.md +++ b/mlir/g3doc/ConversionToLLVMDialect.md @@ -302,11 +302,11 @@ llvm.func @bar(%arg0: !llvm.i64) { llvm.call @foo(%16) : (!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> () llvm.return } - - - ``` +*This convention may or may not apply if the conversion of MemRef types is +overridden by the user.* + ## Repeated Successor Removal Since the goal of the LLVM IR dialect is to reflect LLVM IR in MLIR, the dialect @@ -349,7 +349,7 @@ before the conversion to the LLVM IR dialect: llvm.br ^bb1(%2 : !llvm.i32) ``` -## Memref Model +## Default Memref Model ### Memref Descriptor diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index c5c17b3..d49c1c2 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -52,6 +52,17 @@ using LLVMPatternListFiller = using LLVMTypeConverterMaker = std::function(MLIRContext *)>; +/// Collect a set of patterns to convert memory-related operations from the +/// Standard dialect to the LLVM dialect, excluding the memory-related +/// operations. +void populateStdToLLVMMemoryConversionPatters( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Collect a set of patterns to convert from the Standard dialect to the LLVM +/// dialect, excluding the memory-related operations. +void populateStdToLLVMNonMemoryConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + /// Collect a set of patterns to convert from the Standard dialect to LLVM. void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 9b2113a..51cdd72 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -2020,7 +2020,7 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) { } /// Collect a set of patterns to convert from the Standard dialect to LLVM. -void mlir::populateStdToLLVMConversionPatterns( +void mlir::populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed // clang-format off @@ -2035,7 +2035,6 @@ void mlir::populateStdToLLVMConversionPatterns( CmpIOpLowering, CondBranchOpLowering, ConstLLVMOpLowering, - DimOpLowering, DivFOpLowering, DivISOpLowering, DivIUOpLowering, @@ -2045,10 +2044,7 @@ void mlir::populateStdToLLVMConversionPatterns( Log2OpLowering, FPExtLowering, FPTruncLowering, - FuncOpConversion, IndexCastOpLowering, - LoadOpLowering, - MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering, @@ -2061,22 +2057,39 @@ void mlir::populateStdToLLVMConversionPatterns( SignExtendIOpLowering, SplatOpLowering, SplatNdOpLowering, - StoreOpLowering, SubFOpLowering, SubIOpLowering, - SubViewOpLowering, TanhOpLowering, TruncateIOpLowering, - ViewOpLowering, XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter); - patterns.insert< - AllocOpLowering, - DeallocOpLowering>( - *converter.getDialect(), converter, clUseAlloca.getValue()); // clang-format on } +void mlir::populateStdToLLVMMemoryConversionPatters( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // clang-format off + patterns.insert< + DimOpLowering, + FuncOpConversion, + LoadOpLowering, + MemRefCastOpLowering, + StoreOpLowering, + SubViewOpLowering, + ViewOpLowering>(*converter.getDialect(), converter); + patterns.insert< + AllocOpLowering, + DeallocOpLowering>( + *converter.getDialect(), converter, clUseAlloca.getValue()); + // clang-format on +} + +void mlir::populateStdToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); + populateStdToLLVMMemoryConversionPatters(converter, patterns); +} + // Convert types using the stored LLVM IR module. Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); } -- 2.7.4