#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
-#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
-static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
- Operation *containingOp,
- RewriterBase &rewriter) {
+/// Find the first "extract" user of `producerOp` and tile it right before its
+/// use. The tiled op is fused under the `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
+ Operation *containingOp,
+ RewriterBase &rewriter) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
- return failure();
+ return nullptr;
// Search the producer slices accessed within the containing operation.
- // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
+ // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
// evolve into an interface.
- SmallVector<tensor::ExtractSliceOp> sliceOps;
- for (Operation *user : tileableProducer->getUsers()) {
+ auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
- if (!sliceOp)
- continue;
- if (!containingOp->isProperAncestor(sliceOp))
- continue;
- sliceOps.push_back(sliceOp);
- }
+ return sliceOp && containingOp->isProperAncestor(sliceOp);
+ });
- // Check for a non-empty list of fusion opportunities.
- if (sliceOps.empty())
- return failure();
+ // Find a fusion opportunity.
+ if (it == tileableProducer->getUsers().end())
+ return nullptr;
+ auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
// Try to fuse the producer in-place.
- SmallVector<Operation *> fusedOps;
- for (tensor::ExtractSliceOp sliceOp : sliceOps) {
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(sliceOp);
-
- // Tile the producer.
- FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes());
- if (failed(tiledProducer))
- return failure();
- fusedOps.push_back(tiledProducer->getDefiningOp());
- }
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(sliceOpToTile);
+
+ // Tile the producer.
+ FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+ rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+ sliceOpToTile.getMixedSizes());
+ if (failed(tiledProducer))
+ return nullptr;
+
+ // Replace the extract op.
+ Operation *fusedOp = tiledProducer->getDefiningOp();
+ rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+ return fusedOp;
+}
+
+/// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure
+/// it is exactly the `containingOp`, otherwise bail.
+/// Then, find the first "extract" user of the tied block argument and tile it
+/// right before its "extract" use. The tiled op is fused under the
+/// `containingOp`.
+/// Return this fused op on success or nullptr if anything fails.
+static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+ Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
+
+ auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
+ if (!tileableProducer)
+ return nullptr;
+
+ // Search the first use by a "scf::ForeachThreadOp" user.
+ scf::ForeachThreadOp foreachThreadOp;
+ auto itProducerUses =
+ llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
+ foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner());
+ return foreachThreadOp;
+ });
+ // If it's not from the containing op, return.
+ if (!foreachThreadOp || foreachThreadOp != containingOp)
+ return nullptr;
+
+ // Search the producer slices accessed within the containing
+ // operation.
+ // TODO: Generalize to more extract/insert/parallel_insert triples.
+ // Maybe evolve into an interface.
+ OpOperand *pUse = &(*itProducerUses);
+ BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse);
+
+ // Search the producer slices accessed within the containing operation.
+ // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
+ // evolve into an interface.
+ auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
+ auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+ return sliceOp && containingOp->isProperAncestor(sliceOp);
+ });
+
+ // Find a fusion opportunity.
+ if (itBBArgUsers == bbArg.getUsers().end())
+ return nullptr;
+ auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+
+ // Ensure `tileableProducer` has exactly one destination operand that we can
+ // replace the ForeachThreadOp bbArg with.
+ auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+ if (destinationOperands.size() != 1)
+ return nullptr;
+
+ // Try to fuse the producer in-place.
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(sliceOpToTile);
+
+ // Replace the use in the tileableProducer before tiling: clone, replace and
+ // then tile.
+ BlockAndValueMapping bvm;
+ bvm.map(destinationOperands.front(), bbArg);
+ auto tileableProducerClone =
+ cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
+ auto scopeGuard =
+ llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+ // Tile the producer.
+ FailureOr<Value> tiledProducer =
+ tileableProducerClone.generateResultTileValue(
+ rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+ sliceOpToTile.getMixedSizes());
+ if (failed(tiledProducer))
+ return nullptr;
// Replace the extract op.
- for (const auto &en : enumerate(sliceOps))
- rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
- return fusedOps;
+ Operation *fusedOp = tiledProducer->getDefiningOp();
+ rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+
+ // Replace the use in containingOp.
+ rewriter.updateRootInPlace(containingOp, [&]() {
+ containingOp->setOperand(pUse->getOperandNumber(),
+ destinationOperands.front());
+ });
+
+ return fusedOp;
}
-static FailureOr<SmallVector<Operation *>>
-cloneAndFuse(Operation *producerOp, Operation *containingOp,
- RewriterBase &rewriter) {
+static Operation *cloneAndFuseFirstUse(Operation *producerOp,
+ Operation *containingOp,
+ RewriterBase &rewriter) {
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
- for (OpResult result : producerOp->getOpResults())
- for (OpOperand &use : result.getUses())
- if (containingOp->isProperAncestor(use.getOwner()))
+ for (OpResult result : producerOp->getOpResults()) {
+ for (OpOperand &use : result.getUses()) {
+ if (containingOp->isProperAncestor(use.getOwner())) {
uses.push_back(&use);
+ continue;
+ }
+ // Cannot clone and fuse if the use is by the containing op itself: fail
+ // immediately.
+ if (containingOp == use.getOwner())
+ return nullptr;
+ }
+ }
// Check for a non-empty list of fusion opportunities.
if (uses.empty())
- return failure();
+ return nullptr;
// Clone and fuse inside the containing op.
- SmallVector<Operation *> fusedOps;
- for (OpOperand *use : uses) {
- unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(use->getOwner());
- Operation *cloned = rewriter.clone(*producerOp);
- rewriter.updateRootInPlace(
- use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
- fusedOps.push_back(cloned);
- }
-
- return fusedOps;
+ Operation *fusedOp = nullptr;
+ OpOperand *use = uses.front();
+ // Parallel insert slice is not a valid clone destination.
+ // TODO: Generalize to other type of ops.
+ assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
+ "Parallel insert slice is not a valid clone destination");
+ unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(use->getOwner());
+ fusedOp = rewriter.clone(*producerOp);
+ rewriter.updateRootInPlace(
+ use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
+
+ return fusedOp;
}
DiagnosedSilenceableFailure
}
for (Operation *producerOp : producerOps) {
if (producerOp->getNumResults() != 1) {
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "op with != 1 results not supported";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value();
- bool hasUseInContainingOp =
- any_of(producerOp->getUsers(), [&](Operation *op) {
- return containingOp->isProperAncestor(op);
+ // The containing op may be a user of producerOp: use isAncestor.
+ int64_t numUsesInContainingOp =
+ llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+ return containingOp->isAncestor(op);
});
- // TODO: When resolving the TODO below (no duplicate ops), take an op that
- // has no use among the remaining producers. This is a topological
+ // TODO: When resolving the TODO below (no duplicate ops), take an op
+ // that has no use among the remaining producers. This is a topological
// sorting.
- if (hasUseInContainingOp) {
- remainingProducers.erase(remainingProducers.begin() + it.index());
+ if (numUsesInContainingOp > 0) {
+ if (numUsesInContainingOp == 1)
+ remainingProducers.erase(remainingProducers.begin() + it.index());
return producerOp;
}
}
while (!remainingProducers.empty()) {
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
- Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
+ Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse ops into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
Operation *producerOp = *nextProducer;
- // TODO: If there are multiple uses of the producer in the containing op, we
- // currently tile/clone the op multiple times (once per use). In some cases,
- // we can tile/clone once and reuse the value for each use. Futhermore,
- // producers should then be traversed according to a topological sorting.
- auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
- if (succeeded(tiled))
- fusedOps.append(*tiled);
-
- auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
- if (succeeded(cloned))
- fusedOps.append(*cloned);
-
- if (failed(tiled) && failed(cloned)) {
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
- diag << "could not fuse into containing op";
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+ // TODO: If there are multiple uses of the producer in the containing op,
+ // we currently tile/clone the op multiple times (once per use). In some
+ // cases, we can tile/clone once and reuse the value for each use.
+ // Futhermore, producers should then be traversed according to a
+ // topological sorting.
+ Operation *tiled =
+ tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
+ if (tiled) {
+ fusedOps.push_back(tiled);
+ continue;
+ }
+
+ Operation *tiledContainingOpOperand =
+ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
+ producerOp, containingOp, rewriter);
+ if (tiledContainingOpOperand) {
+ fusedOps.push_back(tiledContainingOpOperand);
+ continue;
}
+
+ Operation *cloned =
+ cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
+ if (cloned) {
+ fusedOps.push_back(cloned);
+ continue;
+ }
+
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+ diag << "could not fuse " << *producerOp << "into " << *containingOp;
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
results.set(getFusedOp().cast<OpResult>(), fusedOps);
extractFromI64ArrayAttr(getPaddingDimensions());
if (any_of(paddingDimensions,
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
- return emitOpError()
- << "expects padding_dimensions to contain positive integers, found "
- << getPaddingDimensions();
+ return emitOpError() << "expects padding_dimensions to contain positive "
+ "integers, found "
+ << getPaddingDimensions();
}
SmallVector<int64_t> hoistPaddings =
transform::TransformState &state) {
LinalgTilingOptions tilingOptions;
tilingOptions.scalarizeDynamicDims();
- // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
- // sizes and asserts that it is not already set.
+ // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
+ // tile sizes and asserts that it is not already set.
SmallVector<int64_t> emptyTileSizes;
LinalgTilingPattern pattern(getContext(), tilingOptions);
SimpleRewriter rewriter(getContext());
if ((static_cast<int64_t>(getStaticSplitPoint()) !=
ShapedType::kDynamicSize) ^
(getDynamicSplitPoint() == nullptr)) {
- return emitOpError()
- << "expects either a dynamic or a static split point to be provided";
+ return emitOpError() << "expects either a dynamic or a static split "
+ "point to be provided";
}
return success();
}
//===----------------------------------------------------------------------===//
namespace {
-/// Registers new ops and declares PDL as dependent dialect since the additional
-/// ops are using PDL types for operands and results.
+/// Registers new ops and declares PDL as dependent dialect since the
+/// additional ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {