BufferizationState(const AnalysisState &analysisState)
: analysisState(analysisState) {}
+ /// Creates a memref allocation with the given type and dynamic extents.
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape);
+
+ /// Creates a memref allocation for the given shaped value. This function may
+ /// perform additional optimizations such as buffer allocation hoisting.
+ // TODO: Allocation hoisting should be a cleanup pass.
+ FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue);
+
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
FailureOr<Value>
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
bool forceInPlace = false,
- Optional<Operation *> customCopyInsertionPoint = None) const;
+ Optional<Operation *> customCopyInsertionPoint = None);
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const {
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
-/// Creates a memref allocation with the given type and dynamic extents.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape,
- const BufferizationOptions &options);
-
-/// Creates a memref allocation with the given type and dynamic extents. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape, bool deallocMemref,
- const BufferizationOptions &options);
-
-/// Creates a memref allocation for the given shaped value. This function may
-/// perform additional optimizations such as buffer allocation hoisting. If
-/// `createDealloc`, a deallocation op is inserted at the point where the
-/// allocation goes out of scope.
-// TODO: Allocation hoisting should be a cleanup pass.
-FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
- bool deallocMemref,
- const BufferizationOptions &options);
-
/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
const BufferizationOptions &options);
+/// Finalize all buffer allocations, i.e., create alloc ops as specified in the
+/// bufferization options and deallocate all buffers.
+LogicalResult finalizeBuffers(Operation *op,
+ const BufferizationOptions &options);
} // namespace bufferization
} // namespace mlir
constexpr const ::llvm::StringLiteral
bufferization::BufferizableOpInterface::kInplaceableAttrName;
+static const char *kBufferAllocationAttr = "bufferization.allocation";
+
//===----------------------------------------------------------------------===//
// BufferizationOptions
//===----------------------------------------------------------------------===//
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
-FailureOr<Value> BufferizationState::getBuffer(
- RewriterBase &rewriter, OpOperand &opOperand, bool forceInPlace,
- Optional<Operation *> customCopyInsertionPoint) const {
+FailureOr<Value>
+BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
+ bool forceInPlace,
+ Optional<Operation *> customCopyInsertionPoint) {
const BufferizationOptions &options = analysisState.getOptions();
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = opOperand.getOwner();
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
- FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
- options.createDeallocs, options);
+ FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding writes of `operand` are ops that do
// Bufferization-specific scoped alloc/dealloc insertion support.
//===----------------------------------------------------------------------===//
+/// Create a memref allocation with the given type and dynamic extents.
+static FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape,
+ const BufferizationOptions &options) {
+ if (options.allocationFn)
+ return (*options.allocationFn)(b, loc, type, dynShape,
+ options.bufferAlignment);
+
+ // Default bufferallocation via AllocOp.
+ Value allocated = b.create<memref::AllocOp>(
+ loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
+ return allocated;
+}
+
+/// Creates a memref deallocation. The given memref buffer must have been
+/// allocated using `createAlloc`.
+LogicalResult
+bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
+ const BufferizationOptions &options) {
+ if (options.deallocationFn)
+ return (*options.deallocationFn)(b, loc, allocatedBuffer);
+
+ // Default buffer deallocation via DeallocOp.
+ b.create<memref::DeallocOp>(loc, allocatedBuffer);
+ return success();
+}
+
/// Move the insertion point of the given builder to the beginning of a
/// surrounding block as much as possible, while not crossing any allocation
/// hoisting barriers.
return allocMemRefType;
}
-/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
-/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
-/// bbArg) and the DeallocOp is at the end of the block.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, Value shapedValue,
- bool deallocMemref,
- const BufferizationOptions &options) {
+static Value createBufferAllocation(OpBuilder &b, Location loc, MemRefType type,
+ ValueRange dynShape) {
+ auto allocaOp = b.create<memref::AllocaOp>(loc, type, dynShape);
+ allocaOp->setAttr(kBufferAllocationAttr, b.getUnitAttr());
+ return allocaOp.getResult();
+}
+
+/// Create an allocation after `shapedValue.getDefiningOp` (or at the top of the
+/// block in case of a bbArg).
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+ Value shapedValue) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
-
- // 1. Create memory allocation.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
SmallVector<Value> dynShape;
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
- FailureOr<Value> allocated =
- createAlloc(b, loc, allocMemRefType, dynShape, options);
- if (failed(allocated))
- return failure();
- Value casted = allocated.getValue();
+ Value alloc = createBufferAllocation(b, loc, allocMemRefType, dynShape);
if (memRefType && memRefType != allocMemRefType) {
- assert(memref::CastOp::areCastCompatible(allocated.getValue().getType(),
- memRefType) &&
+ assert(memref::CastOp::areCastCompatible(alloc.getType(), memRefType) &&
"createAlloc: cast incompatible");
- casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
- }
-
- if (deallocMemref) {
- // 2. Create memory deallocation.
- b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
- if (failed(createDealloc(b, loc, allocated.getValue(), options)))
- return failure();
- }
-
- return casted;
-}
-
-/// Create a memref allocation with the given type and dynamic extents.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape,
- const BufferizationOptions &options) {
- if (options.allocationFn)
- return (*options.allocationFn)(b, loc, type, dynShape,
- options.bufferAlignment);
-
- // Default bufferallocation via AllocOp.
- Value allocated = b.create<memref::AllocOp>(
- loc, type, dynShape, b.getI64IntegerAttr(options.bufferAlignment));
- return allocated;
-}
-
-/// Create a memref allocation with the given type and dynamic extents. May also
-/// deallocate the memref again.
-FailureOr<Value>
-bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type,
- ValueRange dynShape, bool deallocMemref,
- const BufferizationOptions &options) {
- OpBuilder::InsertionGuard g(b);
-
- FailureOr<Value> alloc = createAlloc(b, loc, type, dynShape, options);
- if (failed(alloc))
- return failure();
-
- if (deallocMemref) {
- // Dealloc at the end of the block.
- b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator());
- if (failed(createDealloc(b, loc, *alloc, options)))
- return failure();
+ alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
}
-
return alloc;
}
-/// Create a memref deallocation.
-LogicalResult
-bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
- const BufferizationOptions &options) {
- if (options.deallocationFn)
- return (*options.deallocationFn)(b, loc, allocatedBuffer);
-
- // Default buffer deallocation via DeallocOp.
- b.create<memref::DeallocOp>(loc, allocatedBuffer);
- return success();
+/// Create a memref allocation with the given type and dynamic extents.
+FailureOr<Value> BufferizationState::createAlloc(OpBuilder &b, Location loc,
+ MemRefType type,
+ ValueRange dynShape) {
+ return createBufferAllocation(b, loc, type, dynShape);
}
/// Create a memory copy between two memref buffers.
return success();
}
+LogicalResult
+bufferization::finalizeBuffers(Operation *op,
+ const BufferizationOptions &options) {
+ IRRewriter rewriter(op->getContext());
+
+ // Bufferization creates memref.alloca ops. After bufferization, these must be
+ // rewritten to alloc/dealloc ops as specified in the bufferization options.
+ WalkResult status = op->walk([&](memref::AllocaOp allocaOp) {
+ // Ignore memref.alloca ops that were not created by the bufferization.
+ if (!allocaOp->hasAttr(kBufferAllocationAttr))
+ return WalkResult::skip();
+
+ Block *block = allocaOp->getBlock();
+ rewriter.setInsertionPoint(allocaOp);
+ FailureOr<Value> alloc =
+ createAlloc(rewriter, allocaOp->getLoc(), allocaOp.getType(),
+ allocaOp.dynamicSizes(), options);
+ if (failed(alloc))
+ return WalkResult::interrupt();
+ rewriter.replaceOp(allocaOp, *alloc);
+
+ // Stop here if deallocations are deactivated.
+ if (!options.createDeallocs)
+ return WalkResult::advance();
+
+ rewriter.setInsertionPoint(block->getTerminator());
+ if (failed(createDealloc(rewriter, alloc->getLoc(), *alloc, options)))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+
+ return success(!status.wasInterrupted());
+}
+
//===----------------------------------------------------------------------===//
// Bufferization-specific BlockAndValueMapping support with debugging.
//===----------------------------------------------------------------------===//