From 52ffc728181bc2d3c889f7f80c252c3433b9e7b6 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Fri, 4 Nov 2022 13:58:59 -0700 Subject: [PATCH] [mlir][tiling] Relax tiling to accept generating multiple operations. Some operations need to generate multiple operations when implementing the tiling interface. Here is a sound example in IREE, see https://github.com/iree-org/iree/pull/10905 for more details. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D137300 --- .../mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 6 ++++-- .../Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 6 +++--- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 15 ++++++--------- .../Interfaces/TilingInterface/TestTilingInterface.cpp | 3 ++- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 9fa4114..151993c 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -62,8 +62,10 @@ struct SCFTilingOptions { /// Transformation information returned after tiling. struct SCFTilingResult { - /// The tiled operation generated. - Operation *tiledOp; + /// Tiled operations that are generated during tiling. The order does not + /// matter except the last op. The replacements are expected to be the results + /// of the last op. + SmallVector tiledOps; /// The `scf.for` operations that iterate over the tiles. SmallVector loops; /// Values to use as replacements for the untiled op. Is the same size as the diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index a35dd14..6b8ca91 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -931,7 +931,7 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, if (failed(maybeTilingResult)) return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); - results.push_back(maybeTilingResult->tiledOp); + results.append(maybeTilingResult->tiledOps); return DiagnosedSilenceableFailure(success()); } @@ -1251,7 +1251,7 @@ transform::TileOp::apply(TransformResults &transformResults, rewriter.replaceOp(linalgOp, maybeTilingResult->loops.front()->getResults()); - tiled.push_back(maybeTilingResult->tiledOp); + tiled.append(maybeTilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) loops[en2.index()].push_back(en2.value()); } @@ -1609,7 +1609,7 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements); - tiled.push_back(tilingResult->tiledOp); + tiled.append(tilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(tilingResult->loops)) loops[en2.index()].push_back(en2.value()); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 0c86bd4..6e59bdb 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -360,11 +360,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, tilingResult.loops.back().getBody()->getTerminator()); SmallVector tiledImplementation = op.getTiledImplementation(rewriter, offsets, sizes); - if (tiledImplementation.size() != 1) { - return rewriter.notifyMatchFailure( - op, "expected tiled implementation to return a single op"); - } - tilingResult.tiledOp = tiledImplementation[0]; + tilingResult.tiledOps.append(tiledImplementation); if (op->getNumResults() == 0) { // nothing more to do. return tilingResult; @@ -396,13 +392,13 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, } FailureOr> replacementOr = yieldTiledValues( - rewriter, destinationTensors, tilingResult.tiledOp->getResults(), + rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(), resultOffsetsList, resultSizesList, tilingResult.loops); if (failed(replacementOr)) return rewriter.notifyMatchFailure(op, "failed to yield replacement"); if (auto dstOp = - dyn_cast(tilingResult.tiledOp)) { + dyn_cast(tilingResult.tiledOps.back())) { auto innerMostLoop = tilingResult.loops.back(); SmallVector destinationTensors = dstOp.getDpsInitOperands(); assert(destinationTensors.size() == @@ -554,13 +550,14 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); if (failed(tilingResult)) return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); - tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); + for (auto tiledOp : tilingResult->tiledOps) + tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); tileAndFuseResult.loops = std::move(tilingResult->loops); for (const auto &result : llvm::enumerate( llvm::zip(consumer->getResults(), tilingResult->replacements))) { tileAndFuseResult.replacements[std::get<0>(result.value())] = std::get<1>(result.value()); - yieldedValueToResultNumber[tilingResult->tiledOp->getResult( + yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( result.index())] = result.index(); } } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp index 31e3c1a..1644179 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -193,7 +193,8 @@ struct TestTileUsingSCFForOp rewriter.eraseOp(op); } - filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); + for (auto tiledOp : tilingResult->tiledOps) + filter.replaceLinalgTransformationFilter(rewriter, tiledOp); return success(); } -- 2.7.4