From 54a5f606281d05203dca1d81d135e691b10bc513 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 12 Sep 2022 23:01:25 -0700 Subject: [PATCH] [mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion. This revision revisits the implementation of `transform.fuse_into_containing_op` so that it iterates on producers one use at a time. Support is added to fuse a producer through a foreach_thread shared tensor argument, in which case we tile and fuse the op inside the containing op and update the shared tensor argument to the unique destination operand. If one cannot find such a unique destination operand the transform fails. --- .../Linalg/TransformOps/LinalgTransformOps.cpp | 260 +++++++++++++++------ 1 file changed, 184 insertions(+), 76 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 29b13e2..49328a6 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -17,9 +17,12 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringSet.h" using namespace mlir; @@ -226,78 +229,168 @@ LogicalResult transform::FuseOp::verify() { // FuseIntoContainingOp //===----------------------------------------------------------------------===// -static FailureOr> tileAndFuse(Operation *producerOp, - Operation *containingOp, - RewriterBase &rewriter) { +/// Find the first "extract" user of `producerOp` and tile it right before its +/// use. The tiled op is now fused under the `containingOp`. +/// Return this fused op on success or nullptr if anything fails. +static Operation *tileAndFuseFirstExtractUse(Operation *producerOp, + Operation *containingOp, + RewriterBase &rewriter) { auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) - return failure(); + return nullptr; // Search the producer slices accessed within the containing operation. - // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe - // evolve into an interface. - SmallVector sliceOps; - for (Operation *user : tileableProducer->getUsers()) { + // TODO: Generalize to more extract/insert/parallel_insert triples. + // Maybe evolve into an interface. + auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); - if (!sliceOp) - continue; - if (!containingOp->isProperAncestor(sliceOp)) + return sliceOp && containingOp->isProperAncestor(sliceOp); + }); + + // Check for a non-empty fusion opportunity. + if (it == tileableProducer->getUsers().end()) + return nullptr; + auto sliceOpToTile = cast(*it); + + // Try to fuse the producer in-place. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sliceOpToTile); + + // Tile the producer. + FailureOr tiledProducer = tileableProducer.generateResultTileValue( + rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tiledProducer)) + return nullptr; + + // Replace the extract op. + Operation *fusedOp = tiledProducer->getDefiningOp(); + rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0)); + return fusedOp; +} + +/// Find the first "extract" user of `producerOp` and tile it right before its +/// use. The tiled op is now fused under the `containingOp`. +/// Return this fused op on success or nullptr if anything fails. +static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( + Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) { + + auto foreachThreadOp = dyn_cast(containingOp); + if (!foreachThreadOp) + return nullptr; + + auto tileableProducer = dyn_cast(producerOp); + if (!tileableProducer) + return nullptr; + + // Search the producer slices accessed within the containing + // operation. + // TODO: Generalize to more extract/insert/parallel_insert triples. + // Maybe evolve into an interface. + OpOperand *pUse; + BlockArgument bbArg; + tensor::ExtractSliceOp sliceOpToTile; + // Only consider slices that may come from the containingOp args. + for (OpOperand &use : tileableProducer->getUses()) { + if (use.getOwner() != containingOp) continue; - sliceOps.push_back(sliceOp); + pUse = &use; + bbArg = foreachThreadOp.getTiedBlockArgument(&use); + for (Operation *user : bbArg.getUsers()) { + auto sliceOp = dyn_cast(user); + if (!sliceOp) + continue; + if (!containingOp->isAncestor(sliceOp)) + continue; + sliceOpToTile = sliceOp; + break; + } + if (sliceOpToTile) + break; } // Check for a non-empty list of fusion opportunities. - if (sliceOps.empty()) - return failure(); + if (!sliceOpToTile || !pUse) + return nullptr; - // Try to fuse the producer in-place. - SmallVector fusedOps; - for (tensor::ExtractSliceOp sliceOp : sliceOps) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(sliceOp); + // Ensure there is exactly one destination operand that we can replace the + // ForeachThreadOp bbArg with. + auto destinationOperands = tileableProducer.getDestinationOperands(rewriter); + if (destinationOperands.size() != 1) + return nullptr; - // Tile the producer. - FailureOr tiledProducer = tileableProducer.generateResultTileValue( - rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes()); - if (failed(tiledProducer)) - return failure(); - fusedOps.push_back(tiledProducer->getDefiningOp()); - } + // Try to fuse the producer in-place. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sliceOpToTile); + + // Replace the use in the tileableProducer before tiling, replace and then + // tile. + BlockAndValueMapping bvm; + bvm.map(destinationOperands.front(), bbArg); + auto tileableProducerClone = + cast(rewriter.clone(*tileableProducer, bvm)); + auto scopeGuard = + llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); + + // Tile the producer. + FailureOr tiledProducer = + tileableProducerClone.generateResultTileValue( + rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tiledProducer)) + return nullptr; // Replace the extract op. - for (const auto &en : enumerate(sliceOps)) - rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0)); - return fusedOps; + Operation *fusedOp = tiledProducer->getDefiningOp(); + rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0)); + + // Replace the use in containingOp. + rewriter.startRootUpdate(fusedOp); + containingOp->setOperand(pUse->getOperandNumber(), + destinationOperands.front()); + rewriter.finalizeRootUpdate(fusedOp); + + return fusedOp; } -static FailureOr> -cloneAndFuse(Operation *producerOp, Operation *containingOp, - RewriterBase &rewriter) { +static Operation *cloneAndFuseFirstUse(Operation *producerOp, + Operation *containingOp, + RewriterBase &rewriter) { // Gather all uses inside the containing op. SmallVector uses; - for (OpResult result : producerOp->getOpResults()) - for (OpOperand &use : result.getUses()) - if (containingOp->isProperAncestor(use.getOwner())) + for (OpResult result : producerOp->getOpResults()) { + for (OpOperand &use : result.getUses()) { + if (containingOp->isProperAncestor(use.getOwner())) { uses.push_back(&use); + continue; + } + // Cannot clone and fuse if the use is fom the containing op itself: fail. + if (containingOp == use.getOwner()) + return nullptr; + } + } // Check for a non-empty list of fusion opportunities. if (uses.empty()) - return failure(); + return nullptr; // Clone and fuse inside the containing op. - SmallVector fusedOps; + Operation *fusedOp = nullptr; for (OpOperand *use : uses) { + // Parallel insert slice is not a valid clone destination. + // TODO: Generalize to other type of ops. + assert(!isa(use->getOwner()) && + "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = use->get().cast().getResultNumber(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); - Operation *cloned = rewriter.clone(*producerOp); + fusedOp = rewriter.clone(*producerOp); rewriter.updateRootInPlace( - use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); }); - fusedOps.push_back(cloned); + use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); + break; } - return fusedOps; + return fusedOp; } DiagnosedSilenceableFailure @@ -312,7 +405,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, } for (Operation *producerOp : producerOps) { if (producerOp->getNumResults() != 1) { - Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); diag << "op with != 1 results not supported"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } @@ -331,15 +424,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); - bool hasUseInContainingOp = - any_of(producerOp->getUsers(), [&](Operation *op) { - return containingOp->isProperAncestor(op); + // The containing op may be a user of producerOp: use isAncestor. + int64_t numUsesInContainingOp = + llvm::count_if(producerOp->getUsers(), [&](Operation *op) { + return containingOp->isAncestor(op); }); - // TODO: When resolving the TODO below (no duplicate ops), take an op that - // has no use among the remaining producers. This is a topological + // TODO: When resolving the TODO below (no duplicate ops), take an op + // that has no use among the remaining producers. This is a topological // sorting. - if (hasUseInContainingOp) { - remainingProducers.erase(remainingProducers.begin() + it.index()); + if (numUsesInContainingOp > 0) { + if (numUsesInContainingOp == 1) + remainingProducers.erase(remainingProducers.begin() + it.index()); return producerOp; } } @@ -350,29 +445,42 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { - Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note); + Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not fuse ops into container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } Operation *producerOp = *nextProducer; - // TODO: If there are multiple uses of the producer in the containing op, we - // currently tile/clone the op multiple times (once per use). In some cases, - // we can tile/clone once and reuse the value for each use. Futhermore, - // producers should then be traversed according to a topological sorting. - auto tiled = tileAndFuse(producerOp, containingOp, rewriter); - if (succeeded(tiled)) - fusedOps.append(*tiled); - - auto cloned = cloneAndFuse(producerOp, containingOp, rewriter); - if (succeeded(cloned)) - fusedOps.append(*cloned); - - if (failed(tiled) && failed(cloned)) { - Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); - diag << "could not fuse into containing op"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + // TODO: If there are multiple uses of the producer in the containing op, + // we currently tile/clone the op multiple times (once per use). In some + // cases, we can tile/clone once and reuse the value for each use. + // Futhermore, producers should then be traversed according to a + // topological sorting. + Operation *tiled = + tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter); + if (tiled) { + fusedOps.push_back(tiled); + continue; + } + + Operation *tiledContainingOpOperand = + tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( + producerOp, containingOp, rewriter); + if (tiledContainingOpOperand) { + fusedOps.push_back(tiledContainingOpOperand); + continue; + } + + Operation *cloned = + cloneAndFuseFirstUse(producerOp, containingOp, rewriter); + if (cloned) { + fusedOps.push_back(cloned); + continue; } + + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); + diag << "could not fuse " << *producerOp << "into " << *containingOp; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } results.set(getFusedOp().cast(), fusedOps); @@ -626,9 +734,9 @@ LogicalResult transform::PadOp::verify() { extractFromI64ArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { - return emitOpError() - << "expects padding_dimensions to contain positive integers, found " - << getPaddingDimensions(); + return emitOpError() << "expects padding_dimensions to contain positive " + "integers, found " + << getPaddingDimensions(); } SmallVector hoistPaddings = @@ -699,8 +807,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, transform::TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); - // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile - // sizes and asserts that it is not already set. + // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the + // tile sizes and asserts that it is not already set. SmallVector emptyTileSizes; LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(getContext()); @@ -847,8 +955,8 @@ LogicalResult SplitOp::verify() { if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamicSize) ^ (getDynamicSplitPoint() == nullptr)) { - return emitOpError() - << "expects either a dynamic or a static split point to be provided"; + return emitOpError() << "expects either a dynamic or a static split " + "point to be provided"; } return success(); } @@ -1202,8 +1310,8 @@ transform::VectorizeOp::applyToOne(Operation *target, //===----------------------------------------------------------------------===// namespace { -/// Registers new ops and declares PDL as dependent dialect since the additional -/// ops are using PDL types for operands and results. +/// Registers new ops and declares PDL as dependent dialect since the +/// additional ops are using PDL types for operands and results. class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { -- 2.7.4