[mlir] StandardToLLVM: option to disable AllocOp lowering
authorButygin <ivan.butygin@intel.com>
Wed, 19 May 2021 19:04:29 +0000 (22:04 +0300)
committerButygin <ivan.butygin@intel.com>
Sun, 30 May 2021 14:50:25 +0000 (17:50 +0300)
Differential Revision: https://reviews.llvm.org/D103237

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

index a7a68ef..25be9f9 100644 (file)
@@ -37,8 +37,19 @@ public:
   bool useBarePtrCallConv = false;
   bool emitCWrappers = false;
 
-  /// Use aligned_alloc for heap allocations.
-  bool useAlignedAlloc = false;
+  enum class AllocLowering {
+    /// Use malloc for for heap allocations.
+    Malloc,
+
+    /// Use aligned_alloc for heap allocations.
+    AlignedAlloc,
+
+    /// Do not lower heap allocations. Users must provide their own patterns for
+    /// AllocOp and DeallocOp lowering.
+    None
+  };
+
+  AllocLowering allocLowering = AllocLowering::Malloc;
 
   /// The data layout of the module to produce. This must be consistent with the
   /// data layout used in the upper levels of the lowering pipeline.
index a9560cf..2fbfa11 100644 (file)
@@ -3922,7 +3922,6 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
   // clang-format off
   patterns.add<
       AssumeAlignmentOpLowering,
-      DeallocOpLowering,
       DimOpLowering,
       GlobalMemrefOpLowering,
       GetGlobalMemrefOpLowering,
@@ -3936,10 +3935,11 @@ void mlir::populateStdToLLVMMemoryConversionPatterns(
       TransposeOpLowering,
       ViewOpLowering>(converter);
   // clang-format on
-  if (converter.getOptions().useAlignedAlloc)
-    patterns.add<AlignedAllocOpLowering>(converter);
-  else
-    patterns.add<AllocOpLowering>(converter);
+  auto allocLowering = converter.getOptions().allocLowering;
+  if (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc)
+    patterns.add<AlignedAllocOpLowering, DeallocOpLowering>(converter);
+  else if (allocLowering == LowerToLLVMOptions::AllocLowering::Malloc)
+    patterns.add<AllocOpLowering, DeallocOpLowering>(converter);
 }
 
 void mlir::populateStdToLLVMFuncOpConversionPattern(
@@ -4071,7 +4071,9 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
     options.emitCWrappers = emitCWrappers;
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
-    options.useAlignedAlloc = useAlignedAlloc;
+    options.allocLowering =
+        (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
+                         : LowerToLLVMOptions::AllocLowering::Malloc);
     options.dataLayout = llvm::DataLayout(this->dataLayout);
     LLVMTypeConverter typeConverter(&getContext(), options);
 
@@ -4139,9 +4141,16 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createLowerToLLVMPass() {
 
 std::unique_ptr<OperationPass<ModuleOp>>
 mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
+  auto allocLowering = options.allocLowering;
+  // There is no way to provide additional patterns for pass, so
+  // AllocLowering::None will always fail.
+  assert(allocLowering != LowerToLLVMOptions::AllocLowering::None &&
+         "LLVMLoweringPass doesn't support AllocLowering::None");
+  bool useAlignedAlloc =
+      (allocLowering == LowerToLLVMOptions::AllocLowering::AlignedAlloc);
   return std::make_unique<LLVMLoweringPass>(
       options.useBarePtrCallConv, options.emitCWrappers,
-      options.getIndexBitwidth(), options.useAlignedAlloc, options.dataLayout);
+      options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
 }
 
 mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx)