result.addTypes(transform::AnyOpType::get(builder.getContext()));
}
+/// Add new operands to the forall op for users of the producerOp
+/// that are dominated by the containing scf.forall op.
+static Operation *replaceForAllWithNewSignature(
+ RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+ Operation *containingOp, TilingResult &tileAndFuseResult,
+ int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
+ SmallVector<OpFoldResult> &sizes) {
+
+ // Count number of users not including the containing op
+ SetVector<Operation *> dominatedUsers;
+ DominanceInfo domInfo(containingOp);
+ for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
+ if ((user != containingOp) && (domInfo.dominates(containingOp, user))) {
+ dominatedUsers.insert(user);
+ }
+ }
+ if (dominatedUsers.size() == 0)
+ return nullptr;
+
+ // Create new scf.forall op
+ auto forallOp = cast<scf::ForallOp>(containingOp);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(forallOp);
+
+ // Get new output
+ Location loc = forallOp.getLoc();
+ auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
+ if (!genericOp)
+ return nullptr;
+ SmallVector<Value> outputs = genericOp.getOutputs();
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.push_back(outputs[resultNumber]);
+
+ // Create new scf.forall op
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ rewriter.eraseBlock(newforallOp.getBody());
+ newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+ // Add additional block argument for new value being returned
+ newforallOp.getBody()->addArgument(newOuts.back().getType(),
+ newOuts.back().getLoc());
+
+ // Fix terminator
+ scf::InParallelOp terminatorOp = newforallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+ terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+ Operation *firstYieldOp = yieldingOps.front();
+ rewriter.setInsertionPoint(firstYieldOp);
+ Value src = tileAndFuseResult.tiledValues[0];
+ Value dst = newforallOp.getOutputBlockArguments().back();
+ SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
+ dst, offsets, sizes, strides);
+
+ for (auto result : llvm::enumerate(forallOp.getResults())) {
+ rewriter.replaceAllUsesWith(result.value(),
+ newforallOp->getResult(result.index()));
+ }
+ rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
+ newforallOp->getResults().back(),
+ [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return dominatedUsers.contains(user);
+ });
+ return newforallOp;
+}
+
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
-static SmallVector<Operation *>
+/// If tiled op has uses that are dominated by `containingOp`, return
+/// a new `containingOp` with results of the fused op appended to
+/// results of the `containingOp` or nullptr if there are no dominated uses.
+static std::tuple<SmallVector<Operation *>, Operation *>
tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp, Operation *containingOp) {
LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
+ SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
+
FailureOr<TilingResult> tileAndFuseResult =
- tileableProducer.generateResultTileValue(rewriter, resultNumber,
- sliceOpToTile.getMixedOffsets(),
- sliceOpToTile.getMixedSizes());
+ tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
+ sizes);
+
if (failed(tileAndFuseResult)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
assert(succeeded(maybeRankReduced) && "unexpected shape");
rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
- return tileAndFuseResult->tiledOps;
+
+ // Add new outputs to containing op, if required
+ Operation *newContainingOp = replaceForAllWithNewSignature(
+ rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
+ resultNumber, offsets, sizes);
+
+ return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
}
/// First, find the first "scf::ForallOp" user of `producerOp` and ensure
// cases, we can tile/clone once and reuse the value for each use.
// Futhermore, producers should then be traversed according to a
// topological sorting.
- SmallVector<Operation *> tiledOps =
+ auto [tiledOps, newContainingOp] =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (!tiledOps.empty()) {
LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
fusedOps.append(tiledOps);
+ if (newContainingOp) {
+ rewriter.eraseOp(containingOp);
+ containingOp = newContainingOp;
+ }
continue;
}
// CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
%2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
%5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
-
+
// CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
// CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
// CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
}
}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[G0:.*]]:2 = linalg.generic
+ %0:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %d = arith.addf %a, %b : f32
+ %e = arith.addf %d, %c : f32
+ linalg.yield %d, %e : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %3 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
+ %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
+ func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ // CHECK: %[[G0:.*]] = linalg.generic
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %d = arith.addf %a, %b : f32
+ linalg.yield %d : f32
+ } -> tensor<?xf32>
+ // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
+ %d0 = tensor.dim %0, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %3 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+ %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1
+ func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+#map4 = affine_map<(d0, d1) -> (d0)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_reductions
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+ -> (tensor<?xf32>, tensor<?xf32>) {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ %0 = linalg.generic {
+ indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]
+ } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32):
+ %d = arith.maxf %a, %b : f32
+ linalg.yield %d : f32
+ } -> tensor<?xf32>
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+ // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+ %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+ %3 = affine.apply #map1(%i)[%idx]
+ // CHECK: %[[I1:.*]] = affine.min {{.*}}
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+ %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: return %[[R0]]#0, %[[R0]]#1
+ func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+ // CHECK: }
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+ }
+}