return nullptr;
}
+/// Compute the type of the `memref` to use for allocating the buffer for
+/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
+/// dynamic dimensions in the returned `memref` type. The function also sets the
+/// insertion point of the builder `b` to the position where the allocation is
+/// to be inserted.
+static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
+ Value shapedValue,
+ SmallVectorImpl<Value> &dynShape) {
+ MemRefType allocMemRefType =
+ getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
+ if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
+ b.setInsertionPointToStart(bbArg.getOwner());
+ loc = bbArg.getOwner()->getParentOp()->getLoc();
+ } else {
+ b.setInsertionPoint(shapedValue.getDefiningOp());
+ loc = shapedValue.getDefiningOp()->getLoc();
+ }
+
+ // Compute the dynamic part of the shape.
+ bool foundDynamicShapes = false;
+ if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+ shapedValue.getDefiningOp())) {
+ ReifiedRankedShapedTypeDims resultDims;
+ if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
+ foundDynamicShapes = true;
+ OpResult resultValue = shapedValue.dyn_cast<OpResult>();
+ auto &shape = resultDims[resultValue.getResultNumber()];
+ for (auto dim : enumerate(allocMemRefType.getShape()))
+ if (dim.value() == ShapedType::kDynamicSize)
+ dynShape.push_back(shape[dim.index()]);
+ }
+ }
+ if (!foundDynamicShapes) {
+ for (auto dim : enumerate(allocMemRefType.getShape()))
+ if (dim.value() == ShapedType::kDynamicSize)
+ dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
+ }
+
+ // If the buffer is statically shaped, try to hoist it to the first enclosing
+ // parallel region.
+ // TODO: this concept of parallel region and threadlocal needs interfaces.
+ // TODO: also hoist in the dynamic case. For now this relies on subsequent
+ // calls to LICM and buffer hoisting which will most likely not succeed.
+ // TODO: when packing, allocate a static bounding box which will enable more
+ // hoisting.
+ if (dynShape.empty()) {
+ Operation *parent =
+ getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
+ AffineParallelOp>(shapedValue);
+ if (parent)
+ b.setInsertionPointToStart(&(parent->getRegion(0).front()));
+ }
+ return allocMemRefType;
+}
+
/// Create an Allocop/DeAllocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);
+ // 1. Create memory allocation.
assert(shapedValue.getType().isa<ShapedType>());
MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
-
- Optional<Value> allocated = allocationFns.allocationFn(b, loc, shapedValue);
+ SmallVector<Value> dynShape;
+ // Note: getAllocationTypeAndShape also sets the insertion point.
+ MemRefType allocMemRefType =
+ getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+ Optional<Value> allocated =
+ allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
Value casted = allocated.getValue();
- MemRefType allocMemRefType = allocated->getType().cast<MemRefType>();
if (memRefType && memRefType != allocMemRefType) {
casted = b.create<memref::CastOp>(loc, memRefType, allocated.getValue());
aliasInfo.insertNewBufferEquivalence(casted, allocated.getValue());
}
+ // 2. Create memory deallocation.
+ b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
allocationFns.deallocationFn(b, loc, allocated.getValue());
return casted;
}
// Bufferization entry-point for functions.
//===----------------------------------------------------------------------===//
-/// Compute the type of the `memref` to use for allocating the buffer for
-/// `shapedValue`. Also returns (by reference in `dynShape`), the value for the
-/// dynamic dimensions in the returned `memref` type. The function also sets the
-/// insertion point of the builder `b` to the position where the allocation is
-/// to be inserted.
-static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
- Value shapedValue,
- SmallVectorImpl<Value> &dynShape) {
- MemRefType allocMemRefType =
- getContiguousMemRefType(shapedValue.getType().cast<ShapedType>());
- if (auto bbArg = shapedValue.dyn_cast<BlockArgument>()) {
- b.setInsertionPointToStart(bbArg.getOwner());
- loc = bbArg.getOwner()->getParentOp()->getLoc();
- } else {
- b.setInsertionPoint(shapedValue.getDefiningOp());
- loc = shapedValue.getDefiningOp()->getLoc();
- }
-
- // Compute the dynamic part of the shape.
- bool foundDynamicShapes = false;
- if (auto rankedOp = dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
- shapedValue.getDefiningOp())) {
- ReifiedRankedShapedTypeDims resultDims;
- if (succeeded(rankedOp.reifyResultShapes(b, resultDims))) {
- foundDynamicShapes = true;
- OpResult resultValue = shapedValue.dyn_cast<OpResult>();
- auto &shape = resultDims[resultValue.getResultNumber()];
- for (auto dim : enumerate(allocMemRefType.getShape()))
- if (dim.value() == ShapedType::kDynamicSize)
- dynShape.push_back(shape[dim.index()]);
- }
- }
- if (!foundDynamicShapes) {
- for (auto dim : enumerate(allocMemRefType.getShape()))
- if (dim.value() == ShapedType::kDynamicSize)
- dynShape.push_back(createOrFoldDimOp(b, loc, shapedValue, dim.index()));
- }
-
- // If the buffer is statically shaped, try to hoist it to the first enclosing
- // parallel region.
- // TODO: this concept of parallel region and threadlocal needs interfaces.
- // TODO: also hoist in the dynamic case. For now this relies on subsequent
- // calls to LICM and buffer hoisting which will most likely not succeed.
- // TODO: when packing, allocate a static bounding box which will enable more
- // hoisting.
- if (dynShape.empty()) {
- Operation *parent =
- getFirstParentOfType<FuncOp, TiledLoopOp, scf::ParallelOp,
- AffineParallelOp>(shapedValue);
- if (parent)
- b.setInsertionPointToStart(&(parent->getRegion(0).front()));
- }
- return allocMemRefType;
-}
-
-Optional<Value> mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc,
- Value shapedValue) {
- // Take a guard before anything else.
- OpBuilder::InsertionGuard g(b);
- SmallVector<Value> dynShape;
- MemRefType allocMemRefType =
- getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+Optional<Value>
+mlir::linalg::defaultAllocationFn(OpBuilder &b, Location loc, MemRefType type,
+ const SmallVector<Value> &dynShape) {
Value allocated = b.create<memref::AllocOp>(
- loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+ loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
}
-static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
- Value shapedValue) {
- OpBuilder::InsertionGuard g(b);
- SmallVector<Value> dynShape;
- MemRefType allocMemRefType =
- getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
+static Optional<Value>
+allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type,
+ const SmallVector<Value> &dynShape) {
Value allocated = b.create<memref::AllocaOp>(
- loc, allocMemRefType, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+ loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
}
void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
Value allocatedBuffer) {
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(allocatedBuffer.getParentBlock()->getTerminator());
b.create<memref::DeallocOp>(loc, allocatedBuffer);
}