[mlir] Add support for multiple uses in transform.structured.fuse_into_containing_op
authorHarsh Menon <harsh@nod-labs.com>
Wed, 24 May 2023 01:03:59 +0000 (18:03 -0700)
committerHarsh Menon <harsh@nod-labs.com>
Wed, 24 May 2023 15:36:41 +0000 (08:36 -0700)
In the tile and fuse of the first extract use, we add support
for scenarios where the results of the tiled op have uses
that are dominated by the scf.for_all. Specifically, we replace
the scf.for_all with a new scf.for_all that has an additional
shared_out and add the appropriate parallel insert slice op.

Differential Revision: https://reviews.llvm.org/D151275

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

index c1a24a4..4f476d1 100644 (file)
@@ -347,10 +347,82 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
   result.addTypes(transform::AnyOpType::get(builder.getContext()));
 }
 
+/// Add new operands to the forall op for users of the producerOp
+/// that are dominated by the containing scf.forall op.
+static Operation *replaceForAllWithNewSignature(
+    RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+    Operation *containingOp, TilingResult &tileAndFuseResult,
+    int64_t resultNumber, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes) {
+
+  // Count number of users not including the containing op
+  SetVector<Operation *> dominatedUsers;
+  DominanceInfo domInfo(containingOp);
+  for (Operation *user : producerOp->getResult(resultNumber).getUsers()) {
+    if ((user != containingOp) && (domInfo.dominates(containingOp, user))) {
+      dominatedUsers.insert(user);
+    }
+  }
+  if (dominatedUsers.size() == 0)
+    return nullptr;
+
+  // Create new scf.forall op
+  auto forallOp = cast<scf::ForallOp>(containingOp);
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+
+  // Get new output
+  Location loc = forallOp.getLoc();
+  auto genericOp = dyn_cast<linalg::GenericOp>(producerOp);
+  if (!genericOp)
+    return nullptr;
+  SmallVector<Value> outputs = genericOp.getOutputs();
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.push_back(outputs[resultNumber]);
+
+  // Create new scf.forall op
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  // Add additional block argument for new value being returned
+  newforallOp.getBody()->addArgument(newOuts.back().getType(),
+                                     newOuts.back().getLoc());
+
+  // Fix terminator
+  scf::InParallelOp terminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+      terminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  Value src = tileAndFuseResult.tiledValues[0];
+  Value dst = newforallOp.getOutputBlockArguments().back();
+  SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
+  rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
+                                                 dst, offsets, sizes, strides);
+
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+  rewriter.replaceUsesWithIf(producerOp->getResult(resultNumber),
+                             newforallOp->getResults().back(),
+                             [&](OpOperand &use) {
+                               Operation *user = use.getOwner();
+                               return dominatedUsers.contains(user);
+                             });
+  return newforallOp;
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
-static SmallVector<Operation *>
+/// If tiled op has uses that are dominated by `containingOp`, return
+/// a new `containingOp` with results of the fused op appended to
+/// results of the `containingOp` or nullptr if there are no dominated uses.
+static std::tuple<SmallVector<Operation *>, Operation *>
 tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
                            Operation *producerOp, Operation *containingOp) {
   LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n");
@@ -386,10 +458,13 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
       cast<OpResult>(sliceOpToTile.getSource()).getResultNumber();
   LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
 
+  SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes();
+
   FailureOr<TilingResult> tileAndFuseResult =
-      tileableProducer.generateResultTileValue(rewriter, resultNumber,
-                                               sliceOpToTile.getMixedOffsets(),
-                                               sliceOpToTile.getMixedSizes());
+      tileableProducer.generateResultTileValue(rewriter, resultNumber, offsets,
+                                               sizes);
+
   if (failed(tileAndFuseResult)) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to tile producer op: " << *tileableProducer;
@@ -408,7 +483,13 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
       cast<RankedTensorType>(sliceOpToTile->getResult(0).getType()).getShape());
   assert(succeeded(maybeRankReduced) && "unexpected shape");
   rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
-  return tileAndFuseResult->tiledOps;
+
+  // Add new outputs to containing op, if required
+  Operation *newContainingOp = replaceForAllWithNewSignature(
+      rewriter, diag, producerOp, containingOp, *tileAndFuseResult,
+      resultNumber, offsets, sizes);
+
+  return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
 }
 
 /// First, find the first "scf::ForallOp" user of `producerOp` and ensure
@@ -635,11 +716,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     // cases, we can tile/clone once and reuse the value for each use.
     // Futhermore, producers should then be traversed according to a
     // topological sorting.
-    SmallVector<Operation *> tiledOps =
+    auto [tiledOps, newContainingOp] =
         tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
     if (!tiledOps.empty()) {
       LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp);
       fusedOps.append(tiledOps);
+      if (newContainingOp) {
+        rewriter.eraseOp(containingOp);
+        containingOp = newContainingOp;
+      }
       continue;
     }
 
index 29cacb4..f3f4802 100644 (file)
@@ -116,7 +116,7 @@ module {
     // CHECK: scf.forall {{.*}} -> (tensor<?xf32>) {
     %2 = scf.forall (%arg3) in (%d0) shared_outs(%o = %0) -> (tensor<?xf32>) {
       %5 = tensor.extract_slice %o[%arg3] [1] [1] : tensor<?xf32> to tensor<f32>
-      
+
       // CHECK: tensor.extract_slice %{{.*}}[%{{.*}}] [1] [1] : tensor<?xf32> to tensor<1xf32>
       // CHECK: linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<1xf32>) -> tensor<1xf32>
       // CHECK: tensor.extract_slice %{{.*}}[0] [1] [1] : tensor<1xf32> to tensor<f32>
@@ -288,3 +288,200 @@ module {
     transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
   }
 }
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op_multi_use
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_multi_output_op_multi_use(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    // CHECK: %[[G0:.*]]:2 = linalg.generic
+    %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %d = arith.addf %a, %b : f32
+        %e = arith.addf %d, %c : f32
+        linalg.yield %d, %e : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %3 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}}
+      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T1]]#0 into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1, %[[G0]]#1
+    func.return %2, %0#0, %0#1 : tensor<?xf32>, tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    // linalg.generic is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_mixed_dominating_uses
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_mixed_dominating_uses(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    // CHECK: %[[G0:.*]] = linalg.generic
+    %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
+      ^bb0(%a: f32, %b: f32):
+        %d = arith.addf %a, %b : f32
+        linalg.yield %d : f32
+    } -> tensor<?xf32>
+    // CHECK: %[[D0:.*]] = tensor.dim %[[G0]]
+    %d0 = tensor.dim %0, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %3 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1
+    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    // linalg.generic is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+  }
+}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+#map3 = affine_map<(d0, d1) -> (d0, d1)>
+#map4 = affine_map<(d0, d1) -> (d0)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_reductions
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?x?xf32>
+  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_reductions(%idx: index, %in: tensor<?x?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
+    -> (tensor<?xf32>, tensor<?xf32>) {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    %0 = linalg.generic {
+      indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]
+      } ins(%in : tensor<?x?xf32>) outs(%out_1 : tensor<?xf32>) {
+        ^bb0(%a: f32, %b: f32):
+          %d = arith.maxf %a, %b : f32
+          linalg.yield %d : f32
+        } -> tensor<?xf32>
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
+    // CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
+    %2 = scf.forall (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      // CHECK: %[[I0:.*]] = affine.apply {{.*}}
+      %3 = affine.apply #map1(%i)[%idx]
+      // CHECK: %[[I1:.*]] = affine.min {{.*}}
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
+      %6 = tensor.extract_slice %0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.forall.in_parallel {
+        // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: return %[[R0]]#0, %[[R0]]#1
+    func.return %2, %0 : tensor<?xf32>, tensor<?xf32>
+    // CHECK: }
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
+
+    // linalg.generic is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+      : (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
+  }
+}