/// Transformation information returned after tiling.
struct SCFTilingResult {
- /// The tiled operation generated.
- Operation *tiledOp;
+ /// Tiled operations that are generated during tiling. The order does not
+ /// matter except the last op. The replacements are expected to be the results
+ /// of the last op.
+ SmallVector<Operation *> tiledOps;
/// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
/// Values to use as replacements for the untiled op. Is the same size as the
if (failed(maybeTilingResult))
return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
- results.push_back(maybeTilingResult->tiledOp);
+ results.append(maybeTilingResult->tiledOps);
return DiagnosedSilenceableFailure(success());
}
rewriter.replaceOp(linalgOp,
maybeTilingResult->loops.front()->getResults());
- tiled.push_back(maybeTilingResult->tiledOp);
+ tiled.append(maybeTilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
loops[en2.index()].push_back(en2.value());
}
rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
- tiled.push_back(tilingResult->tiledOp);
+ tiled.append(tilingResult->tiledOps);
for (const auto &en2 : llvm::enumerate(tilingResult->loops))
loops[en2.index()].push_back(en2.value());
}
tilingResult.loops.back().getBody()->getTerminator());
SmallVector<Operation *> tiledImplementation =
op.getTiledImplementation(rewriter, offsets, sizes);
- if (tiledImplementation.size() != 1) {
- return rewriter.notifyMatchFailure(
- op, "expected tiled implementation to return a single op");
- }
- tilingResult.tiledOp = tiledImplementation[0];
+ tilingResult.tiledOps.append(tiledImplementation);
if (op->getNumResults() == 0) {
// nothing more to do.
return tilingResult;
}
FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
- rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
+ rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
resultOffsetsList, resultSizesList, tilingResult.loops);
if (failed(replacementOr))
return rewriter.notifyMatchFailure(op, "failed to yield replacement");
if (auto dstOp =
- dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
+ dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
auto innerMostLoop = tilingResult.loops.back();
SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
assert(destinationTensors.size() ==
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
if (failed(tilingResult))
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
- tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
+ for (auto tiledOp : tilingResult->tiledOps)
+ tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
tileAndFuseResult.loops = std::move(tilingResult->loops);
for (const auto &result : llvm::enumerate(
llvm::zip(consumer->getResults(), tilingResult->replacements))) {
tileAndFuseResult.replacements[std::get<0>(result.value())] =
std::get<1>(result.value());
- yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
+ yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
result.index())] = result.index();
}
}