}
};
+/// Result of `tileAndFuseProducerOfSlice`, which fuses the producer of the
+/// source of `candidateSliceOp` by computing its required slice in-place.
+struct SCFFuseProducerOfSliceResult {
+ OpResult origProducer; // Original untiled producer.
+ Value tiledAndFusedProducer; // Tiled and fused producer value.
+};
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+ tensor::ExtractSliceOp candidateSliceOp,
+ MutableArrayRef<scf::ForOp> loops);
+
+/// Reconstruct the fused producer from within the tiled-and-fused code. Based
+/// on the slice of the producer computed in place it is possible that within
+/// the loop nest the same slice of the producer is computed multiple times.
+/// It is in general not possible to recompute the value of the fused producer
+/// from the tiled loop code in such cases. For the cases where no slice of the
+/// producer is computed in a redundant fashion it is possible to reconstruct
+/// the value of the original producer from within the tiled loop. It is up to
+/// the caller to ensure that the producer is not computed redundantly within
+/// the tiled loop nest. For example, consider
+///
+/// ```mlir
+/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?xf32>
+/// ```
+///
+/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
+/// is,
+///
+/// ```mlir
+/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
+///   %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
+///     ...
+///     %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+///     %t1_3 = linalg.matmul ins(%t1_2, ...)
+///     %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
+///     scf.yield %t1_4
+///   }
+///   scf.yield %t1_1
+/// }
+/// ```
+///
+/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
+/// if `%1` were tiled only along the rows, the resultant code would be
+///
+/// ```mlir
+/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
+///   ...
+///   %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+///   %t2_2 = linalg.matmul ins(%t2_1, ...)
+///   %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
+///   scf.yield %t2_3
+/// }
+/// ```
+///
+/// Here there is no intersection in the different slices of `%t2_1` computed
+/// across iterations of the `scf.for`. In such cases, the value of the original
+/// `%0` can be reconstructed from within the loop body. This is useful in cases
+/// where `%0` had other uses as well. If not reconstructed from within the loop
+/// body, uses of `%0` could not be replaced, making it still live and the
+/// fusion immaterial.
+void yieldReplacementForFusedProducer(
+ RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
+ scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
+ MutableArrayRef<scf::ForOp> loops);
+
/// Transformation information returned after tile and fuse.
struct SCFTileAndFuseResult {
/// List of untiled operations that were fused with the tiled consumer.
return {source->get().dyn_cast<OpResult>(), destinationIterArg};
}
-static std::optional<Operation *>
-tileAndFuseProducerOfSlice(RewriterBase &rewriter,
- tensor::ExtractSliceOp candidateSliceOp,
- MutableArrayRef<scf::ForOp> loops) {
+/// Implementation of fusing producer of a single slice by computing the
+/// slice of the producer in-place.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+ tensor::ExtractSliceOp candidateSliceOp,
+ MutableArrayRef<scf::ForOp> loops) {
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationIterArg] =
innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
}
}
- return fusedProducerValue->getDefiningOp();
+ return scf::SCFFuseProducerOfSliceResult{fusableProducer,
+ fusedProducerValue.value()};
+}
+
+/// Reconstruct the fused producer from within the tiled-and-fused code.
+///
+/// Obtains (or creates) the destination tensor of the original, untiled
+/// producer and passes it, together with the tiled producer value and the
+/// offsets/sizes of `sliceOp`, to `yieldTiledValues` for `loops`. If the tiled
+/// producer is a destination-style op, its init operand for the fused result
+/// is then handed to `updateDestinationOperandsForTiledOp` along with the last
+/// region iter_arg of the last loop in `loops`.
+///
+/// If no destination can be obtained for the original producer, the function
+/// returns without modifying the loops. The caller is responsible for ensuring
+/// that the producer is not computed redundantly within the loop nest.
+void mlir::scf::yieldReplacementForFusedProducer(
+    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
+    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
+    MutableArrayRef<scf::ForOp> loops) {
+  auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
+  FailureOr<Value> initValue = tensor::getOrCreateDestination(
+      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
+  // The destination rewiring below picks the *last* iter_arg of the last
+  // loop, which is presumably the one introduced by `yieldTiledValues`. If
+  // nothing is yielded that iter_arg would belong to some other value, so
+  // bail out before touching the tiled producer.
+  if (failed(initValue))
+    return;
+
+  SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
+  // The values returned by `yieldTiledValues` are not needed here; callers
+  // read the replacements off the results of the loops themselves.
+  yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
+                   resultOffsets, resultSizes, loops);
+
+  if (auto dstStyleProducer =
+          fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
+    Value dstValue =
+        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
+            ->get();
+    updateDestinationOperandsForTiledOp(
+        rewriter, dstValue, loops.back().getRegionIterArgs().back());
+  }
}
/// Implementation of tile consumer and fuse producer greedily.
// The operands of the fused producer might themselved be slices of
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
- Optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
- rewriter, candidateSliceOp, tileAndFuseResult.loops);
+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+ tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
+ tileAndFuseResult.loops);
if (!fusedProducer)
continue;
- tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
- addCandidateSlices(fusedProducer.value(), candidates);
+ if (Operation *tiledAndFusedOp =
+ fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
+ tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+ addCandidateSlices(tiledAndFusedOp, candidates);
+ }
}
return tileAndFuseResult;
}
--- /dev/null
+// RUN: mlir-opt -test-tiling-interface=tile-consumer-fuse-and-yield-producer-using-scf-for -cse -split-input-file %s | FileCheck %s
+
+// `%gemm0` is fused into the loop that tiles `%gemm1`, but it is also returned
+// from the function. The checks below expect its value to be rebuilt inside
+// the loop (a second iter_arg and a second tensor.insert_slice) and yielded as
+// a second loop result.
+func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
+ %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
+ -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.0 : f32
+ %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
+ %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %gemm0 = linalg.matmul
+ ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
+ %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
+ ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK: func.func @gemm_gemm_fusion_yield_both(
+// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
+// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+// CHECK: %[[FILL0_TILE:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[INIT0_TILE]] :
+// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME: outs(%[[FILL0_TILE]] :
+// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
+// CHECK-SAME: outs(%[[INIT1_TILE]] :
+// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
+// CHECK-SAME: outs(%[[FILL1_TILE]] :
+// CHECK: %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+// CHECK: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
+// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]]
//
//===----------------------------------------------------------------------===//
-#include <utility>
#include <optional>
+#include <utility>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
LinalgTransformationFilter filter;
};
+/// Pattern to tile a consumer and fuse producer with it
+/// while reconstructing the value of the fused producer
+/// from within the loop nest to replace any external
+/// uses of the producer. In general yielding the producer
+/// this way requires a guarantee that the slice of the producer
+/// is not computed redundantly within the tiled loops. An analysis that
+/// figures this out has proven to be very complex. So this is left as a
+/// caller-side determination. In this test pattern it is assumed that the tile
+/// sizes are selected such that all producers when fused into the tiled loops
+/// do not have redundant computation.
+struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+
+ TestTileConsumerFuseAndYieldProducerUsingSCFForOp(
+ MLIRContext *context, scf::SCFTilingOptions options,
+ LinalgTransformationFilter filter = LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
+
+ LogicalResult matchAndRewrite(TilingInterface rootOp,
+ PatternRewriter &rewriter) const override {
+ if (failed(filter.checkAndNotify(rewriter, rootOp)))
+ return failure();
+
+ // Collect list of operations that can be tiled and fused.
+ llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
+ collectTiledAndFusedOps(rootOp);
+ auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
+ return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
+ outerMostTiledLoop->isAncestor(user);
+ }; // Ignored: fusion candidates, tensor.dim ops, ops inside the loop.
+
+ // The rest of this method is similar to
+ // scf::tileAndFuseGreedilyUsingSCFForOp, except that it also yields
+ // replacements for values of the fused producer.
+
+ // 1. Tile the consumer.
+ SmallVector<OpResult> yieldedValuesToOrigValues; // Entry i <-> loop result i.
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ scf::tileUsingSCFForOp(rewriter, rootOp, options);
+ if (failed(tilingResult)) {
+ return rewriter.notifyMatchFailure(rootOp,
+ "failed to tile base operation");
+ }
+ yieldedValuesToOrigValues.append(rootOp->result_begin(),
+ rootOp->result_end()); // Root results come first.
+
+ // 2. Tiling each operation results in generation of slices. The source of
+ // these slices could be producers that can be fused into the tiled loops by
+ // computing the slices of these producers in-place. This results in more
+ // slices created for operands of the "fused producer". This opens up more
+ // opportunities for fusion. Use a worklist to fuse greedily.
+ auto addCandidateSlices =
+ [](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
+ for (Value operand : fusedOp->getOperands())
+ if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+ candidates.push_back(sliceOp);
+ };
+
+ std::deque<tensor::ExtractSliceOp> candidates;
+ addCandidateSlices(tilingResult->tiledOps.back(), candidates);
+ OpBuilder::InsertionGuard g(rewriter);
+ while (!candidates.empty()) {
+ // Traverse the slices in BFS fashion.
+ tensor::ExtractSliceOp candidateSliceOp = candidates.front();
+ candidates.pop_front();
+
+ // Materialize the slice of the producer in place.
+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+ tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
+ tilingResult->loops);
+ if (!fusedProducer)
+ continue;
+
+ // Check if the fused producer has other uses that require the value
+ // to be yielded from within the tiled loop.
+ OpResult untiledProducer = fusedProducer->origProducer;
+ if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
+ return !isIgnoredUser(user, tilingResult->loops.front());
+ })) {
+ yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
+ fusedProducer.value(),
+ tilingResult->loops);
+ // NOTE(review): assumes the call above added exactly one loop result.
+ yieldedValuesToOrigValues.push_back(untiledProducer);
+ }
+
+ // Add more fusion candidates to the worklist.
+ if (auto fusedProducerOp =
+ fusedProducer->tiledAndFusedProducer.getDefiningOp())
+ addCandidateSlices(fusedProducerOp, candidates);
+ }
+
+ scf::ForOp outermostLoop = tilingResult->loops.front();
+ for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
+ Value replacement = outermostLoop.getResult(index);
+ rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
+ return !isIgnoredUser(use.getOwner(), outermostLoop);
+ });
+ }
+ rewriter.eraseOp(rootOp);
+ filter.replaceLinalgTransformationFilter(rewriter,
+ tilingResult->tiledOps.back());
+ return success();
+ }
+
+private:
+ /// Starting from `op` walk all operands backwards to find all
+ /// potentially fusable operations, i.e. operations that implement
+ /// the `TilingInterface`.
+ llvm::SmallDenseSet<Operation *>
+ collectTiledAndFusedOps(Operation *op) const {
+ SmallVector<Operation *> worklist;
+ llvm::SmallDenseSet<Operation *> producers;
+ worklist.push_back(op);
+ producers.insert(op);
+ while (!worklist.empty()) {
+ Operation *current = worklist.pop_back_val();
+ for (OpOperand &operand : current->getOpOperands()) {
+ Operation *producer = operand.get().getDefiningOp();
+ if (!producer || !isa<TilingInterface>(producer) ||
+ producers.count(producer))
+ continue;
+ worklist.push_back(producer);
+ producers.insert(producer);
+ }
+ }
+ return producers;
+ }
+
+ scf::SCFTilingOptions options;
+ LinalgTransformationFilter filter;
+};
+
/// Pattern to lower operations that implement the `TilingInterface` to
/// loops/scalar IR using `scf.for`.
struct LowerToLoopsUsingSCFForOp
"Test tiling using TilingInterface with scf.for operations"),
llvm::cl::init(false)};
+ Option<bool> testTileConsumerFuseAndYieldProducer{
+ *this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
+ llvm::cl::desc(
+ "Test tile and fuse transformation while yielding fused producer "
+ "replacements using TilingInterface with scf.for operations"),
+ llvm::cl::init(false)};
+
Option<bool> testTileConsumerAndFuseProducer{
*this, "tile-consumer-and-fuse-producer-using-scf-for",
llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
}
+/// Registers `TestTileConsumerFuseAndYieldProducerUsingSCFForOp` in `patterns`.
+/// The pattern is configured with the given `tileSizes` and `interchange`, and
+/// with a filter built from the `filterName` marker and the "tiled" marker.
+static void addPatternForTileFuseAndYield(MLIRContext *context,
+                                          RewritePatternSet &patterns,
+                                          StringRef filterName,
+                                          ArrayRef<int64_t> tileSizes,
+                                          ArrayRef<int64_t> interchange = {}) {
+  // Marker filter: first attribute is what to match, second the replacement.
+  StringAttr matchMarker = StringAttr::get(context, filterName);
+  StringAttr replacementMarker = StringAttr::get(context, "tiled");
+  LinalgTransformationFilter markerFilter(matchMarker, replacementMarker);
+
+  // Tiling configuration shared by the consumer and the fused producers.
+  scf::SCFTilingOptions opts;
+  opts.setTileSizes(tileSizes);
+  opts.setInterchange(interchange);
+
+  patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
+      context, opts, markerFilter);
+}
+
static void addPatternForTileAndFuse(MLIRContext *context,
RewritePatternSet &patterns,
StringRef filterName,
{10});
return;
}
+ if (testTileConsumerFuseAndYieldProducer) {
+ // 1. Fusion of back-to-back-reduction ops
+ addPatternForTileFuseAndYield(context, patterns,
+ "gemm_sequence_fusion_and_yield", {10});
+ return;
+ }
if (testLoweringToScalar) {
patterns.add<LowerToLoopsUsingSCFForOp>(context);
}