#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;
+#define DEBUG_TYPE "linalg-transforms"
+
/// Extracts a vector of unsigned from an array attribute. Asserts if the
/// attribute contains values other than intergers. May truncate.
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
rewriter.setInsertionPoint(sliceOpToTile);
// Tile the producer.
+ int64_t resultNumber =
+ sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
+ LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+ rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
+ LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+ rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
return fusedOp;
}
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
+ LLVM_DEBUG(
+ llvm::dbgs() << "Try to fuse an extract use through block argument\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
return nullptr;
}
- // Ensure `tileableProducer` has exactly one destination operand that we can
- // replace the ForeachThreadOp bbArg with.
- auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
- if (destinationOperands.size() != 1) {
- diag.attachNote(tileableProducer->getLoc())
- << "tileableProducer must have exactly one destination operand: "
- << *tileableProducer;
- return nullptr;
- }
-
// Search the first use by a "scf::ForeachThreadOp" user.
scf::ForeachThreadOp foreachThreadOp;
auto itProducerUses =
// Replace the use in the tileableProducer before tiling: clone, replace and
// then tile.
+ int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
+ LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
+ auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+
BlockAndValueMapping bvm;
- bvm.map(destinationOperands.front(), bbArg);
+ bvm.map(destinationOperands[resultNumber], bbArg);
auto tileableProducerClone =
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
auto scopeGuard =
// Tile the producer.
FailureOr<Value> tiledProducer =
tileableProducerClone.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+ rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
+ LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+ rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
// Replace the use in containingOp.
rewriter.updateRootInPlace(containingOp, [&]() {
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n");
+
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
for (OpResult result : producerOp->getOpResults()) {
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+ LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
fusedOp = rewriter.clone(*producerOp);
ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
// If nothing to fuse, propagate success.
if (producerOps.empty()) {
- results.set(getResult().cast<OpResult>(), SmallVector<mlir::Operation *>{});
+ results.set(getFusedOp().cast<OpResult>(),
+ SmallVector<mlir::Operation *>{});
return DiagnosedSilenceableFailure::success();
}
- for (Operation *producerOp : producerOps) {
- if (producerOp->getNumResults() != 1) {
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
- diag << "op with != 1 results not supported";
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
- }
- }
ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
- if (containingOps.size() != 1)
+ if (containingOps.size() != 1) {
+ // Definite failure.
return DiagnosedSilenceableFailure(
this->emitOpError("requires exactly one containing_op handle (got ")
<< containingOps.size() << ")");
+ }
Operation *containingOp = containingOps.front();
// Helper function to find the next producer that should be fused. Take any
while (!remainingProducers.empty()) {
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
+ results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not find next producer to fuse into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
Operation *producerOp = *nextProducer;
- // Detaul diagnostic, to be complemented with more failure information.
+ // Default diagnostic, to be complemented with more failure information.
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse " << *producerOp << " into " << *containingOp;
Operation *tiled =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
+ LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
+ << *containingOp);
fusedOps.push_back(tiled);
continue;
}
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "\nFused an extract use through block argument\n"
+ << *containingOp);
fusedOps.push_back(tiledContainingOpOperand);
continue;
}
Operation *cloned =
cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
if (cloned) {
+ LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
+ << *containingOp);
fusedOps.push_back(cloned);
continue;
}
-
+ results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
transform.structured.fuse_into_containing_op %0 into %1
}
}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ %0:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %d = arith.addf %a, %b : f32
+ %e = arith.addf %d, %c : f32
+ linalg.yield %d, %e : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: scf.foreach_thread {{.*}} {
+ %2 = scf.foreach_thread (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ %3 = affine.apply #map1(%i)[%idx]
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
+ %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %2 : tensor<?xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ }
+}