This is useful because the result type of an op can sometimes be inferred from its body (e.g., `scf.if`). This will be used in subsequent changes.

Also introduces a new `getBufferType` interface method on BufferizableOpInterface. This method is useful for computing the bufferized type of a block argument with respect to the OpOperand types of the parent op.
Differential Revision: https://reviews.llvm.org/D128420
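
As a rough sketch of the inference alluded to above (not part of this change; `thenYielded`, `elseYielded`, and `tensorType` name values that are assumed to exist in a caller), an scf.if result's buffer type could be derived from the buffers yielded in its two branches rather than from the tensor type alone:

  // If both branches yield buffers of the same type, the scf.if result can
  // use that exact type; otherwise fall back to a fully dynamic layout map.
  BaseMemRefType thenType = bufferization::getBufferType(thenYielded, options);
  BaseMemRefType elseType = bufferization::getBufferType(elseYielded, options);
  BaseMemRefType resultType =
      thenType == elseType ? thenType : getMemRefType(tensorType, options);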
      /*defaultImplementation=*/[{
        return success();
      }]
-   >
+   >,
+   InterfaceMethod<
+     /*desc=*/[{
+       Return the bufferized type of the given tensor block argument. The
+       block argument is guaranteed to belong to a block of this op.
+     }],
+     /*retType=*/"BaseMemRefType",
+     /*methodName=*/"getBufferType",
+     /*args=*/(ins "BlockArgument":$bbArg,
+                   "const BufferizationOptions &":$options),
+     /*methodBody=*/"",
+     /*defaultImplementation=*/[{
+       assert(bbArg.getOwner()->getParentOp() == $_op &&
+              "bbArg must belong to this op");
+       auto tensorType = bbArg.getType().cast<TensorType>();
+       return bufferization::getMemRefType(tensorType, options);
+     }]
+   >,
  ];
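
For illustration, an op can override the default to pick a more precise type. A minimal external-model sketch, assuming a hypothetical `MyLoopOp` whose region iter_args should bufferize to the same memref types as the matching init operands:

  // Inside MyLoopOp's BufferizableOpInterface external model (all names
  // hypothetical). Assumes bbArg i corresponds to init operand i.
  BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
                               const BufferizationOptions &options) const {
    auto loopOp = cast<MyLoopOp>(op);
    Value initArg = loopOp.getInitArgs()[bbArg.getArgNumber()];
    return bufferization::getBufferType(initArg, options);
  }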
Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
                               const BufferizationOptions &options) {
+#ifndef NDEBUG
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");
+#endif // NDEBUG

  // Replace "%t = to_tensor %m" with %m.
  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref();

  // Insert to_memref op.
  OpBuilder::InsertionGuard g(rewriter);
  setInsertionPointAfter(rewriter, value);
- Type memrefType = getMemRefType(tensorType, options);
+ Type memrefType = getBufferType(value, options);
  ensureToMemrefOpIsValid(value, memrefType);
  return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
                                                    value);
}
BaseMemRefType bufferization::getBufferType(Value value,
                                            const BufferizationOptions &options) {
  auto tensorType = value.getType().dyn_cast<TensorType>();
  assert(tensorType && "unexpected non-tensor type");

  if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
    return toTensorOp.getMemref().getType().cast<BaseMemRefType>();

+ if (auto bbArg = value.dyn_cast<BlockArgument>())
+   if (auto bufferizableOp =
+         options.dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+     return bufferizableOp.getBufferType(bbArg, options);
+
  return getMemRefType(tensorType, options);
}
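
With this in place, callers have a single entry point for computing the eventual memref type of any tensor SSA value (usage sketch; `tensorValue` is an assumed name):

  // Resolves uniformly: to_tensor results take their source memref type,
  // block arguments ask the parent op via the new interface method, and
  // everything else falls back to a fully dynamic layout map.
  BaseMemRefType bufferType = bufferization::getBufferType(tensorValue, options);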
  // Otherwise, we have to use a memref type with a fully dynamic layout map to
  // avoid copies. We are currently missing patterns for layout maps to
  // canonicalize away (or canonicalize to more precise layouts).
+ //
+ // FuncOps must be bufferized before their bodies, so add them to the
+ // worklist first.
  SmallVector<Operation *> worklist;
- op->walk<WalkOrder::PreOrder>([&](Operation *op) {
-   if (hasTensorSemantics(op))
+ op->walk([&](func::FuncOp funcOp) {
+   if (hasTensorSemantics(funcOp))
+     worklist.push_back(funcOp);
+ });
+ op->walk<WalkOrder::PostOrder>([&](Operation *op) {
+   if (hasTensorSemantics(op) && !isa<func::FuncOp>(op))
      worklist.push_back(op);
  });
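
A hypothetical illustration of the resulting visitation order:

  // module { func.func @f1 { ... }  func.func @f2 { ... } }
  // Old worklist (pre-order):  [@f1, ops-of-@f1, @f2, ops-of-@f2]
  // New worklist:              [@f1, @f2, ops-of-@f1, ops-of-@f2]
  // (body ops are now also collected post-order, i.e., nested ops first)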
  // The result types of a WhileOp are the same as the "after" bbArg types.
  SmallVector<Type> argsTypesAfter = llvm::to_vector(
      llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-       return getBufferType(bbArg, options).cast<Type>();
+       return bufferization::getBufferType(bbArg, options).cast<Type>();
      }));

  // Construct a new scf.while op with memref instead of tensor values.
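
A sketch of the construction this comment introduces (simplified; `initArgs` and the region handling are assumptions based on the surrounding code):

  auto newWhileOp = rewriter.create<scf::WhileOp>(whileOp.getLoc(),
                                                  argsTypesAfter, initArgs);
  // The before/after regions are then moved into the new op and their bbArg
  // types are updated to the computed memref types.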
  LogicalResult bufferize(Operation *op, RewriterBase &b,
                          const BufferizationOptions &options) const {
    // Will be bufferized as part of ForeachThreadOp.
-   return failure();
+   return success();
  }
  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
  // this code.
  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // Op is bufferized as part of AssumingOp.
-   return failure();
+   return success();
  }
};
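
In both cases the op is rewritten when its enclosing op bufferizes, so a direct visit is now a successful no-op rather than a hard error that would abort bufferization.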
// CHECK: %[[alloc2:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[iter2]], %[[alloc2]]
// CHECK: memref.dealloc %[[iter2]]
-// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
// CHECK: %[[alloc1:.*]] = memref.alloc(%{{.*}})
// CHECK: memref.copy %[[iter1]], %[[alloc1]]
// CHECK: memref.dealloc %[[iter1]]
+// CHECK: %[[casted2:.*]] = memref.cast %[[alloc2]]
// CHECK: %[[casted1:.*]] = memref.cast %[[alloc1]]
// CHECK: %[[cloned1:.*]] = bufferization.clone %[[casted1]]
// CHECK: memref.dealloc %[[alloc1]]
// CHECK: %[[a1:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w1]], %[[a1]]
// CHECK: memref.dealloc %[[w1]]
- // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[a0:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[w0]], %[[a0]]
// CHECK: memref.dealloc %[[w0]]
+ // CHECK: %[[casted1:.*]] = memref.cast %[[a1]]
// CHECK: %[[casted0:.*]] = memref.cast %[[a0]]
// CHECK: %[[cloned0:.*]] = bufferization.clone %[[casted0]]
// CHECK: memref.dealloc %[[a0]]
// CHECK: %[[a3:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b1]], %[[a3]]
// CHECK: memref.dealloc %[[b1]]
- // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[a2:.*]] = memref.alloc() {{.*}} : memref<5xi1>
// CHECK: memref.copy %[[b0]], %[[a2]]
+ // CHECK: %[[casted3:.*]] = memref.cast %[[a3]]
// CHECK: %[[casted2:.*]] = memref.cast %[[a2]]
// CHECK: %[[cloned2:.*]] = bufferization.clone %[[casted2]]
// CHECK: memref.dealloc %[[a2]]