From e0ea706a59b9032b7f3590478080adf4f3e1486a Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 3 Feb 2020 13:49:21 +0100 Subject: [PATCH] [mlir] ConvertStandardToLLVM: do not rely on command line options internally The patterns for converting `std.alloc` and `std.dealoc` can be configured to use `llvm.alloca` instead of calling `malloc` and `free`. This configuration has been only possible through a command-line flag, despite the presence of a (misleading) parameter in the pass constructor. Use the parameter instead and only initalize it from the command line flags if the pass is constructed from the mlir-opt registration. --- .../StandardToLLVM/ConvertStandardToLLVMPass.h | 15 ++++++++----- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 26 ++++++++++++++-------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h index 0822454..4179fff 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -21,7 +21,8 @@ class OwningRewritePatternList; /// Standard dialect to the LLVM dialect, excluding non-memory-related /// operations and FuncOp. void populateStdToLLVMMemoryConversionPatters( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool useAlloca); /// Collect a set of patterns to convert from the Standard dialect to the LLVM /// dialect, excluding the memory-related operations. @@ -33,15 +34,19 @@ void populateStdToLLVMDefaultFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns); /// Collect a set of default patterns to convert from the Standard dialect to -/// LLVM. +/// LLVM. If `useAlloca` is set, the patterns for AllocOp and DeallocOp will +/// generate `llvm.alloca` instead of calls to "malloc". void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns); + OwningRewritePatternList &patterns, + bool useAlloca = false); /// Collect a set of patterns to convert from the Standard dialect to /// LLVM using the bare pointer calling convention for MemRef function -/// arguments. +/// arguments. If `useAlloca` is set, the patterns for AllocOp and DeallocOp +/// will generate `llvm.alloca` instead of calls to "malloc". void populateStdToLLVMBarePtrConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool useAlloca = false); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. /// By default stdlib malloc/free are used for allocating MemRef payloads. diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index 336481f..d0668ca 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -2255,7 +2255,8 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns( } void mlir::populateStdToLLVMMemoryConversionPatters( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool useAlloca) { // clang-format off patterns.insert< DimOpLowering, @@ -2267,7 +2268,7 @@ void mlir::populateStdToLLVMMemoryConversionPatters( patterns.insert< AllocOpLowering, DeallocOpLowering>( - *converter.getDialect(), converter, clUseAlloca.getValue()); + *converter.getDialect(), converter, useAlloca); // clang-format on } @@ -2277,10 +2278,11 @@ void mlir::populateStdToLLVMDefaultFuncOpConversionPattern( } void mlir::populateStdToLLVMConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool useAlloca) { populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); - populateStdToLLVMMemoryConversionPatters(converter, patterns); + populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca); } static void populateStdToLLVMBarePtrFuncOpConversionPattern( @@ -2289,10 +2291,11 @@ static void populateStdToLLVMBarePtrFuncOpConversionPattern( } void mlir::populateStdToLLVMBarePtrConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + bool useAlloca) { populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); - populateStdToLLVMMemoryConversionPatters(converter, patterns); + populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca); } // Convert types using the stored LLVM IR module. @@ -2360,7 +2363,7 @@ struct LLVMLoweringPass : public ModulePass { /// Creates an LLVM lowering pass. explicit LLVMLoweringPass(bool useAlloca = false, bool useBarePtrCallConv = false) - : useBarePtrCallConv(useBarePtrCallConv) {} + : useAlloca(useAlloca), useBarePtrCallConv(useBarePtrCallConv) {} /// Run the dialect converter on the module. void runOnModule() override { @@ -2374,9 +2377,10 @@ struct LLVMLoweringPass : public ModulePass { OwningRewritePatternList patterns; if (useBarePtrCallConv) - populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns); + populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns, + useAlloca); else - populateStdToLLVMConversionPatterns(typeConverter, patterns); + populateStdToLLVMConversionPatterns(typeConverter, patterns, useAlloca); ConversionTarget target(getContext()); target.addLegalDialect(); @@ -2384,6 +2388,10 @@ struct LLVMLoweringPass : public ModulePass { signalPassFailure(); } + /// Use `alloca` instead of `call @malloc` for converting std.alloc. + bool useAlloca; + + /// Convert memrefs to bare pointers in function signatures. bool useBarePtrCallConv; }; } // end namespace -- 2.7.4