return success();
}
+ BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
+ const BufferizationOptions &options) const {
+ auto forOp = cast<scf::ForOp>(op);
+ return bufferization::getBufferType(
+ forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
+ }
+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto forOp = cast<scf::ForOp>(op);
getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices);
iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar());
- // Erase terminator if present.
- if (iterArgs.size() == 1)
- rewriter.eraseOp(loopBody->getTerminator());
-
// Move loop body to new loop.
rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs);
- // Update scf.yield of new loop.
- auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
- rewriter.setInsertionPoint(yieldOp);
- SmallVector<Value> yieldValues = getYieldedValues(
- rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
- yieldOp.getResultsMutable().assign(yieldValues);
-
// Replace loop results.
replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults());
yieldOp->getParentOp()))
return yieldOp->emitError("unsupported scf::YieldOp parent");
- // TODO: Bufferize scf.yield inside scf.while/scf.for here.
- // (Currently bufferized together with scf.while/scf.for.)
- if (isa<scf::ForOp, scf::WhileOp>(yieldOp->getParentOp()))
+ // TODO: Bufferize scf.yield inside scf.while here. (Currently bufferized
+ // together with scf.while.)
+ if (isa<scf::WhileOp>(yieldOp->getParentOp()))
return success();
SmallVector<Value> newResults;
Value value = it.value();
if (value.getType().isa<TensorType>()) {
Value buffer = getBuffer(rewriter, value, options);
+ if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+ BaseMemRefType resultType =
+ cast<BufferizableOpInterface>(forOp.getOperation())
+ .getBufferType(forOp.getRegionIterArgs()[it.index()],
+ options);
+ buffer = castBuffer(rewriter, buffer, resultType);
+ }
newResults.push_back(buffer);
} else {
newResults.push_back(value);