[mlir][linalg][bufferize] Add `createDealloc` options
authorMatthias Springer <springerm@google.com>
Thu, 6 Jan 2022 21:10:55 +0000 (06:10 +0900)
committerMatthias Springer <springerm@google.com>
Thu, 6 Jan 2022 21:13:57 +0000 (06:13 +0900)
If `createDealloc` is deactivated (enabled by default), newly allocated buffers are not deallocated anymore. In such a case, the missing deallocations can be inserted by the existing "BufferDeallocation" pass.

This change is needed for unifying core bufferization and Comprehensive Bufferize. Core bufferization has a separate pass for generating deallocations.

Note: In the future, this will evolve towards generating deallocation ops only for buffer allocations that do not escape block boundaries (i.e., that are in destination passing style).

Differential Revision: https://reviews.llvm.org/D116450

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp

index ccc939f..0bd42e3 100644 (file)
@@ -133,6 +133,10 @@ struct BufferizationOptions {
   /// the boundaries.
   bool allowUnknownOps = false;
 
+  /// Specifies whether dealloc ops should be generated along with alloc ops. If
+  /// not, new memory allocations will leak.
+  bool createDeallocs = true;
+
   /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
   /// Should be used only with `testAnalysisOnly = true`.
   unsigned analysisFuzzerSeed = 0;
@@ -368,10 +372,12 @@ public:
   Optional<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
                               ArrayRef<Value> dynShape) const;
 
-  /// Creates an alloc-dealloc pair. This function may perform additional
-  /// optimizations such as buffer allocation hoisting.
-  Value createAllocDeallocPair(OpBuilder &builder, Location loc,
-                               Value shapedValue) const;
+  /// Creates a memref allocation for the given shaped value. This function may
+  /// perform additional optimizations such as buffer allocation hoisting. If
+  /// `createDealloc`, a deallocation op is inserted at the point where the
+  /// allocation goes out of scope.
+  Value createAlloc(OpBuilder &b, Location loc, Value shapedValue,
+                    bool deallocMemref) const;
 
   /// Creates a memref deallocation. The given memref buffer must have been
   /// allocated using `createAlloc`.
index 785c49e..84dba99 100644 (file)
@@ -392,7 +392,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     // allocation should be inserted (in the absence of allocation hoisting).
     setInsertionPointAfter(rewriter, operandBuffer);
     // Allocate the result buffer.
-    Value resultBuffer = createAllocDeallocPair(rewriter, loc, operandBuffer);
+    Value resultBuffer =
+        createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
     bool skipCopy = false;
     // Do not copy if the last preceding write of `operand` is an op that does
     // not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -534,12 +535,11 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
   return allocMemRefType;
 }
 
-/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
+/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    createAllocDeallocPair(OpBuilder &b, Location loc,
-                           Value shapedValue) const {
+Value mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
+    OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -559,9 +559,12 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
   }
 
-  // 2. Create memory deallocation.
-  b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  createDealloc(b, loc, allocated.getValue());
+  if (deallocMemref) {
+    // 2. Create memory deallocation.
+    b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
+    createDealloc(b, loc, allocated.getValue());
+  }
+
   return casted;
 }
 
index 9922130..5c54625 100644 (file)
@@ -199,8 +199,9 @@ struct InitTensorOpInterface
     if (initTensorOp->getUses().empty())
       return success();
 
-    Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
-                                               initTensorOp.result());
+    Value alloc = state.createAlloc(rewriter, initTensorOp->getLoc(),
+                                    initTensorOp.result(),
+                                    state.getOptions().createDeallocs);
     replaceOpWithBufferizedValues(rewriter, op, alloc);
     return success();
   }
index 894b5c6..e1cd933 100644 (file)
@@ -142,8 +142,8 @@ struct ExtractSliceOpInterface
     bool inplace = state.isInPlace(extractSliceOp->getResult(0));
     Value alloc;
     if (!inplace)
-      alloc =
-          state.createAllocDeallocPair(rewriter, loc, extractSliceOp.result());
+      alloc = state.createAlloc(rewriter, loc, extractSliceOp.result(),
+                                state.getOptions().createDeallocs);
 
     // Bufferize to subview.
     auto subviewMemRefType =