[mlir][TilingInterface] Add an option to tile and fuse to yield replacement for the...
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 5 Jan 2023 00:57:50 +0000 (00:57 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Mon, 16 Jan 2023 18:30:13 +0000 (18:30 +0000)
This patch adds an option to the method that fuses a producer with a
tiled consumer, to also yield from the tiled loops a value that can be
used to replace the original producer. This is only valid if it can be
ascertained that the slice of the producer computed within each
iteration of the tiled loop nest does not compute slices of the
producer redundantly. The analysis to derive this is very involved. So
this is left to the caller to ascertain. A test is added that mimics
the `scf::tileConsumerAndFuseProducersGreedilyUsingSCFForOp`, but also
yields the values of all fused producers. This can be used as a
reference for how a caller could use this functionality.

Differential Revision: https://reviews.llvm.org/D141028

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir [new file with mode: 0644]
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

index 151993c..63d3da7 100644 (file)
@@ -89,6 +89,72 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Result of `tileAndFuseProducerOfSlice`, which fuses the producer of the
+/// source of a slice by computing the required slice of the producer in-place.
+struct SCFFuseProducerOfSliceResult {
+  OpResult origProducer;       // Original untiled producer.
+  Value tiledAndFusedProducer; // Tiled and fused producer value.
+};
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                           tensor::ExtractSliceOp candidateSliceOp,
+                           MutableArrayRef<scf::ForOp> loops);
+
+/// Reconstruct the fused producer from within the tiled-and-fused code. Based
+/// on the slice of the producer computed in place it is possible that within
+/// the loop nest a slice of the producer is computed multiple times. It is
+/// in general not possible to recompute the value of the fused producer from
+/// the tiled loop code in such cases. For the cases where no slice of the
+/// producer is computed in a redundant fashion it is possible to reconstruct
+/// the value of the original producer from within the tiled loop. It is up to
+/// the caller to ensure that the producer is not computed redundantly within
+/// the tiled loop nest. For example, consider
+///
+/// ```mlir
+/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?xf32>
+/// ```
+///
+/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
+/// is,
+///
+/// ```mlir
+/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
+///   %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
+///     ...
+///     %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+///     %t1_3 = linalg.matmul ins(%t1_2, ...)
+///     %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
+///     scf.yield %t1_4
+///   }
+///   scf.yield %t1_1
+/// }
+/// ```
+///
+/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
+/// if `%1` were tiled only along the rows, the resultant code would be
+///
+/// ```mlir
+/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
+///   ...
+///   %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+///   %t2_2 = linalg.matmul ins(%t2_1, ...)
+///   %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
+///   scf.yield %t2_3
+/// }
+/// ```
+///
+/// Here there is no intersection in the different slices of `%t2_1` computed
+/// across iterations of the `scf.for`. In such cases, the value of the original
+/// `%0` can be reconstructed from within the loop body. This is useful in cases
+/// where `%0` had other uses as well. If not reconstructed from within the loop
+/// body, uses of `%0` could not be replaced, making it still live and the
+/// fusion immaterial.
+void yieldReplacementForFusedProducer(
+    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
+    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
+    MutableArrayRef<scf::ForOp> loops);
+
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
   /// List of untiled operations that were fused with the tiled consumer.
index dd0ed44..0c3da3b 100644 (file)
@@ -505,10 +505,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
 }
 
-static std::optional<Operation *>
-tileAndFuseProducerOfSlice(RewriterBase &rewriter,
-                           tensor::ExtractSliceOp candidateSliceOp,
-                           MutableArrayRef<scf::ForOp> loops) {
+/// Implementation of fusing producer of a single slice by computing the
+/// slice of the producer in-place.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                                      tensor::ExtractSliceOp candidateSliceOp,
+                                      MutableArrayRef<scf::ForOp> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationIterArg] =
@@ -597,7 +599,34 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
           innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
     }
   }
-  return fusedProducerValue->getDefiningOp();
+  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
+                                           fusedProducerValue.value()};
+}
+
+/// Reconstruct the fused producer from within the tiled-and-fused code.
+void mlir::scf::yieldReplacementForFusedProducer(
+    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
+    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
+    MutableArrayRef<scf::ForOp> loops) {
+  auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
+  FailureOr<Value> initValue = tensor::getOrCreateDestination(
+      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
+  // No destination: nothing can be yielded and there is no iter_arg to use.
+  if (failed(initValue))
+    return;
+  SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
+  // The returned values are unused; the call is kept for its effect on `loops`.
+  (void)yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
+                         resultOffsets, resultSizes, loops);
+  if (auto dstStyleProducer =
+          fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
+    Value dstValue =
+        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
+            ->get();
+    updateDestinationOperandsForTiledOp(
+        rewriter, dstValue, loops.back().getRegionIterArgs().back());
+  }
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -661,13 +690,17 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
-    Optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
-        rewriter, candidateSliceOp, tileAndFuseResult.loops);
+    std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
+                                   tileAndFuseResult.loops);
     if (!fusedProducer)
       continue;
 
-    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
-    addCandidateSlices(fusedProducer.value(), candidates);
+    if (Operation *tiledAndFusedOp =
+            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
+      tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+      addCandidateSlices(tiledAndFusedOp, candidates);
+    }
   }
   return tileAndFuseResult;
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
new file mode 100644 (file)
index 0000000..f47850a
--- /dev/null
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -test-tiling-interface=tile-consumer-fuse-and-yield-producer-using-scf-for -cse -split-input-file %s | FileCheck %s
+
+// `%gemm0` feeds `%gemm1` and is also returned, so it has a use outside the
+// tiled loop. The checks below expect the fused `%gemm0` tile to be inserted
+// into a second iter_arg and yielded alongside the `%gemm1` tile.
+func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
+    %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
+  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm0 = linalg.matmul
+      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm1 = linalg.matmul  {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
+      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+//      CHECK: func.func @gemm_gemm_fusion_yield_both(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT0_TILE]] :
+//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[FILL0_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT1_TILE]] :
+//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[FILL1_TILE]] :
+//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
index d862d0b..97fd634 100644 (file)
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -239,6 +239,140 @@ private:
   LinalgTransformationFilter filter;
 };
 
+/// Pattern to tile a consumer and fuse producer with it
+/// while reconstructing the value of the fused producer
+/// from within the loop nest to replace any external
+/// uses of the producer. In general yielding the producer
+/// this way requires a guarantee that the slice of the producer
+/// is not computed redundantly within the tiled loops. An analysis that
+/// figures it out has proven to be very complex. So this is left as a caller
+/// side determination. In this test pattern it is assumed that the tile sizes
+/// are selected such that all producers when fused into the tiled loops do not
+/// have redundant computation.
+struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+
+  TestTileConsumerFuseAndYieldProducerUsingSCFForOp(
+      MLIRContext *context, scf::SCFTilingOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
+
+  LogicalResult matchAndRewrite(TilingInterface rootOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, rootOp)))
+      return failure();
+
+    // Collect list of operations that can be tiled and fused.
+    llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
+        collectTiledAndFusedOps(rootOp);
+    auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
+      return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
+             outerMostTiledLoop->isAncestor(user);
+    };
+
+    // The rest of this method is similar to
+    // scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that it
+    // also yields replacements for values of the fused producer.
+
+    // 1. Tile the consumer.
+    SmallVector<OpResult> yieldedValuesToOrigValues;
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        scf::tileUsingSCFForOp(rewriter, rootOp, options);
+    if (failed(tilingResult)) {
+      return rewriter.notifyMatchFailure(rootOp,
+                                         "failed to tile base operation");
+    }
+    yieldedValuesToOrigValues.append(rootOp->result_begin(),
+                                     rootOp->result_end());
+
+    // 2. Tiling each operation results in generation of slices. The source of
+    // these slices could be producers that can be fused into the tiled loops by
+    // computing the slices of these producers in-place. This results in more
+    // slices created for operands of the "fused producer". This opens up more
+    // opportunities for fusion. Use a worklist to fuse greedily.
+    auto addCandidateSlices =
+        [](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
+          for (Value operand : fusedOp->getOperands())
+            if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+              candidates.push_back(sliceOp);
+        };
+
+    std::deque<tensor::ExtractSliceOp> candidates;
+    addCandidateSlices(tilingResult->tiledOps.back(), candidates);
+    OpBuilder::InsertionGuard g(rewriter);
+    while (!candidates.empty()) {
+      // Traverse the slices in BFS fashion.
+      tensor::ExtractSliceOp candidateSliceOp = candidates.front();
+      candidates.pop_front();
+
+      // Materialize the slice of the producer in place.
+      std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+          tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
+                                     tilingResult->loops);
+      if (!fusedProducer)
+        continue;
+
+      // Check if the fused producer has other uses that require the value
+      // to be yielded from within the tiled loop.
+      OpResult untiledProducer = fusedProducer->origProducer;
+      if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
+            return !isIgnoredUser(user, tilingResult->loops.front());
+          })) {
+        yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
+                                         fusedProducer.value(),
+                                         tilingResult->loops);
+        yieldedValuesToOrigValues.push_back(untiledProducer);
+      }
+
+      // Add more fusion candidates to the worklist.
+      if (auto fusedProducerOp =
+              fusedProducer->tiledAndFusedProducer.getDefiningOp())
+        addCandidateSlices(fusedProducerOp, candidates);
+    }
+
+    scf::ForOp outermostLoop = tilingResult->loops.front();
+    for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
+      Value replacement = outermostLoop.getResult(index);
+      rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
+        return !isIgnoredUser(use.getOwner(), outermostLoop);
+      });
+    }
+    rewriter.eraseOp(rootOp);
+    filter.replaceLinalgTransformationFilter(rewriter,
+                                             tilingResult->tiledOps.back());
+    return success();
+  }
+
+private:
+  /// Starting from `op` walk all operands backwards to find all
+  /// potentially fusable operations, i.e. operations that implement
+  /// the `TilingInterface`.
+  llvm::SmallDenseSet<Operation *>
+  collectTiledAndFusedOps(Operation *op) const {
+    SmallVector<Operation *> worklist;
+    llvm::SmallDenseSet<Operation *> producers;
+    worklist.push_back(op);
+    producers.insert(op);
+    while (!worklist.empty()) {
+      Operation *current = worklist.pop_back_val();
+      for (OpOperand &operand : current->getOpOperands()) {
+        Operation *producer = operand.get().getDefiningOp();
+        if (!producer || !isa<TilingInterface>(producer) ||
+            producers.count(producer))
+          continue;
+        worklist.push_back(producer);
+        producers.insert(producer);
+      }
+    }
+    return producers;
+  }
+
+  scf::SCFTilingOptions options;
+  LinalgTransformationFilter filter;
+};
+
 /// Pattern to lower operations that implement the `TilingInterface` to
 /// loops/scalar IR using `scf.for`.
 struct LowerToLoopsUsingSCFForOp
@@ -283,6 +417,13 @@ struct TestTilingInterfacePass
           "Test tiling using TilingInterface with scf.for operations"),
       llvm::cl::init(false)};
 
+  Option<bool> testTileConsumerFuseAndYieldProducer{
+      *this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
+      llvm::cl::desc(
+          "Test tile and fuse transformation while yielding fused producer "
+          "replacements using TilingInterface with scf.for operations"),
+      llvm::cl::init(false)};
+
   Option<bool> testTileConsumerAndFuseProducer{
       *this, "tile-consumer-and-fuse-producer-using-scf-for",
       llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
@@ -314,6 +455,19 @@ static void addPatternForTiling(MLIRContext *context,
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
+/// Registers `TestTileConsumerFuseAndYieldProducerUsingSCFForOp` in
+/// `patterns`, configured with the given tile sizes and loop interchange. The
+/// pattern's filter is built from `filterName` and the marker "tiled".
+static void addPatternForTileFuseAndYield(MLIRContext *context,
+                                          RewritePatternSet &patterns,
+                                          StringRef filterName,
+                                          ArrayRef<int64_t> tileSizes,
+                                          ArrayRef<int64_t> interchange = {}) {
+  scf::SCFTilingOptions fuseAndYieldOptions;
+  fuseAndYieldOptions.setTileSizes(tileSizes);
+  fuseAndYieldOptions.setInterchange(interchange);
+
+  LinalgTransformationFilter markerFilter(StringAttr::get(context, filterName),
+                                          StringAttr::get(context, "tiled"));
+  patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
+      context, fuseAndYieldOptions, markerFilter);
+}
+
 static void addPatternForTileAndFuse(MLIRContext *context,
                                      RewritePatternSet &patterns,
                                      StringRef filterName,
@@ -375,6 +529,12 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
                              {10});
     return;
   }
+  if (testTileConsumerFuseAndYieldProducer) {
+    // 1. Fusion of back-to-back-reduction ops
+    addPatternForTileFuseAndYield(context, patterns,
+                                  "gemm_sequence_fusion_and_yield", {10});
+    return;
+  }
   if (testLoweringToScalar) {
     patterns.add<LowerToLoopsUsingSCFForOp>(context);
   }