[mlir][Linalg] Support multi-output fusion in FuseIntoContainingOp
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 30 Sep 2022 11:09:37 +0000 (04:09 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Fri, 14 Oct 2022 10:54:54 +0000 (03:54 -0700)
This revision adds the ability to fuse tileable ops with multiple results to
the transform.fuse_into_containing_op.

Differential Revision: https://reviews.llvm.org/D135955

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

index be4efaa..5c304f5 100644 (file)
@@ -71,9 +71,8 @@ def FuseIntoContainingOp :
   let description = [{Fuse a producer into a containing operation.}];
 
   let summary = [{
-    Fuses the `producer_op` into the `containing_op`. Only producers with a
-    single result are supported at the moment. Returns a handle to the fused
-    ops.
+    Fuses the `producer_op` into the `containing_op`.
+    Returns a handle to the fused ops.
 
     The producer is typically a slice of a tileable op (i.e., implements
     TilingInterface). In that case, this transform computes the accessed
@@ -98,8 +97,10 @@ def FuseIntoContainingOp :
     This is the case when tiling fails or when no producer op could be found
     among the remaining producers that has at least one use within the
     containing op. I.e., "producers" that are not consumed within the containing
-    op are rejected by this operation. This operation reads and frees the
-    producer handle. It reads the containing op handle.
+    op are rejected by this operation.
+
+    This operation reads and frees the producer handle.
+    This operation reads the containing op handle.
   }];
 
   let arguments = (ins Arg<PDL_Operation, "",
index ed74de7..e47e8e5 100644 (file)
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
 using namespace mlir::transform;
 
+#define DEBUG_TYPE "linalg-transforms"
+
 /// Extracts a vector of unsigned from an array attribute. Asserts if the
 /// attribute contains values other than intergers. May truncate.
 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
@@ -258,6 +261,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
                                              Diagnostic &diag,
                                              Operation *producerOp,
                                              Operation *containingOp) {
+  LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n");
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer) {
     diag.attachNote(producerOp->getLoc())
@@ -286,18 +290,23 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
   rewriter.setInsertionPoint(sliceOpToTile);
 
   // Tile the producer.
+  int64_t resultNumber =
+      sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
+  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
   FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
-      rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+      rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
       sliceOpToTile.getMixedSizes());
   if (failed(tiledProducer)) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to tile producer op: " << *tileableProducer;
     return nullptr;
   }
+  LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
   return fusedOp;
 }
 
@@ -310,6 +319,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
 static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
     Operation *containingOp) {
+  LLVM_DEBUG(
+      llvm::dbgs() << "Try to fuse an extract use through block argument\n");
 
   auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
   if (!tileableProducer) {
@@ -318,16 +329,6 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
     return nullptr;
   }
 
-  // Ensure `tileableProducer` has exactly one destination operand that we can
-  // replace the ForeachThreadOp bbArg with.
-  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
-  if (destinationOperands.size() != 1) {
-    diag.attachNote(tileableProducer->getLoc())
-        << "tileableProducer must have exactly one destination operand: "
-        << *tileableProducer;
-    return nullptr;
-  }
-
   // Search the first use by a "scf::ForeachThreadOp" user.
   scf::ForeachThreadOp foreachThreadOp;
   auto itProducerUses =
@@ -371,8 +372,13 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
 
   // Replace the use in the tileableProducer before tiling: clone, replace and
   // then tile.
+  int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
+  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
+  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+
   BlockAndValueMapping bvm;
-  bvm.map(destinationOperands.front(), bbArg);
+  bvm.map(destinationOperands[resultNumber], bbArg);
   auto tileableProducerClone =
       cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
   auto scopeGuard =
@@ -381,17 +387,18 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
   // Tile the producer.
   FailureOr<Value> tiledProducer =
       tileableProducerClone.generateResultTileValue(
-          rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+          rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
           sliceOpToTile.getMixedSizes());
   if (failed(tiledProducer)) {
     diag.attachNote(tileableProducer->getLoc())
         << "failed to tile producer op: " << *tileableProducer;
     return nullptr;
   }
+  LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
 
   // Replace the extract op.
   Operation *fusedOp = tiledProducer->getDefiningOp();
-  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
 
   // Replace the use in containingOp.
   rewriter.updateRootInPlace(containingOp, [&]() {
@@ -405,6 +412,8 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
 static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
                                        Operation *producerOp,
                                        Operation *containingOp) {
+  LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n");
+
   // Gather all uses inside the containing op.
   SmallVector<OpOperand *> uses;
   for (OpResult result : producerOp->getOpResults()) {
@@ -437,6 +446,8 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
   assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
          "Parallel insert slice is not a valid clone destination");
   unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+  LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
   OpBuilder::InsertionGuard guard(rewriter);
   rewriter.setInsertionPoint(use->getOwner());
   fusedOp = rewriter.clone(*producerOp);
@@ -453,21 +464,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
   // If nothing to fuse, propagate success.
   if (producerOps.empty()) {
-    results.set(getResult().cast<OpResult>(), SmallVector<mlir::Operation *>{});
+    results.set(getFusedOp().cast<OpResult>(),
+                SmallVector<mlir::Operation *>{});
     return DiagnosedSilenceableFailure::success();
   }
-  for (Operation *producerOp : producerOps) {
-    if (producerOp->getNumResults() != 1) {
-      Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
-      diag << "op with != 1 results not supported";
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
-    }
-  }
   ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
-  if (containingOps.size() != 1)
+  if (containingOps.size() != 1) {
+    // Definite failure.
     return DiagnosedSilenceableFailure(
         this->emitOpError("requires exactly one containing_op handle (got ")
         << containingOps.size() << ")");
+  }
   Operation *containingOp = containingOps.front();
 
   // Helper function to find the next producer that should be fused. Take any
@@ -498,6 +505,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
   while (!remainingProducers.empty()) {
     auto nextProducer = getNextProducer();
     if (failed(nextProducer)) {
+      results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
       Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
       diag << "could not find next producer to fuse into container";
       return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
@@ -505,7 +513,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
 
     Operation *producerOp = *nextProducer;
 
-    // Detaul diagnostic, to be complemented with more failure information.
+    // Default diagnostic, to be complemented with more failure information.
     Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
     diag << "could not fuse " << *producerOp << " into " << *containingOp;
 
@@ -517,6 +525,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     Operation *tiled =
         tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
     if (tiled) {
+      LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
+                              << *containingOp);
       fusedOps.push_back(tiled);
       continue;
     }
@@ -525,6 +535,9 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
         tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
             rewriter, diag, producerOp, containingOp);
     if (tiledContainingOpOperand) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "\nFused an extract use through block argument\n"
+                 << *containingOp);
       fusedOps.push_back(tiledContainingOpOperand);
       continue;
     }
@@ -532,10 +545,12 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
     Operation *cloned =
         cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
     if (cloned) {
+      LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
+                              << *containingOp);
       fusedOps.push_back(cloned);
       continue;
     }
-
+    results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
     return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
   }
 
index b1af4ef..141e8f5 100644 (file)
@@ -141,3 +141,63 @@ module {
     transform.structured.fuse_into_containing_op %0 into %1
   }
 }
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+  // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
+  //  CHECK-SAME:   %[[CHUNK_SIZE:[0-9a-z]+]]: index
+  //  CHECK-SAME:   %[[IN:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+  //  CHECK-SAME:   %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+  func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) -> tensor<?xf32> {
+    %cst = arith.constant 4.200000e+01 : f32
+    %c0 = arith.constant 0 : index
+
+    %0:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]
+    } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+      ^bb0(%a: f32, %b: f32, %c: f32):
+        %d = arith.addf %a, %b : f32
+        %e = arith.addf %d, %c : f32
+        linalg.yield %d, %e : f32, f32
+    } -> (tensor<?xf32>, tensor<?xf32>)
+    %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+    %1 = affine.apply #map0()[%d0, %idx]
+
+    // CHECK: scf.foreach_thread {{.*}} {
+    %2 = scf.foreach_thread (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+      %3 = affine.apply #map1(%i)[%idx]
+      %4 = affine.min #map2(%i)[%d0, %idx]
+      %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
+      // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
+      %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+      // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
+      %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+      scf.foreach_thread.perform_concurrently {
+        tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+      }
+    }
+    // CHECK: }
+    func.return %2 : tensor<?xf32>
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+    // linalg.generic is tileable. The op is tiled and fused.
+    transform.structured.fuse_into_containing_op %0 into %1
+  }
+}