[mlir][Linalg] Add pattern to tile and fuse Linalg operations on buffers.
authorMaheshRavishankar <ravishankarm@google.com>
Wed, 30 Sep 2020 21:55:59 +0000 (14:55 -0700)
committerMaheshRavishankar <ravishankarm@google.com>
Wed, 30 Sep 2020 21:56:58 +0000 (14:56 -0700)
The pattern is structured similar to other patterns like
LinalgTilingPattern. The fusion patterns takes options that allows you
to fuse with producers of multiple operands at once.
- The pattern fuses only at the level that is known to be legal, i.e
  if a reduction loop in the consumer is tiled, then fusion should
  happen "before" this loop. Some refactoring of the fusion code is
  needed to fuse only where it is legal.
- Since the fusion on buffers uses the LinalgDependenceGraph that is
  not mutable in place the fusion pattern keeps the original
  operations in the IR, but are tagged with a marker that can be later
  used to find the original operations.

This change also fixes an issue with tiling and
distribution/interchange where if the tile size of a loop were 0 it
wasnt account for in these.

Differential Revision: https://reviews.llvm.org/D88435

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/fusion-pattern.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/CMakeLists.txt
mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp

index 23d296c..f51f7b9 100644 (file)
@@ -459,6 +459,24 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
             }));
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the position of buffer in inputs + outputs list
+      }],
+      /*retTy=*/"Optional<unsigned>",
+      /*methodName=*/"getIndexOfInputAndOutputBuffer",
+      /*args=*/(ins "Value":$value),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        Optional<unsigned> inputIndex = getIndexOfInput(value);
+        if (inputIndex.hasValue()) return inputIndex.getValue();
+        Optional<unsigned> outputIndex = getIndexOfOutputBuffer(value);
+        if (outputIndex.hasValue()) {
+          return $_op.getNumInputs() + outputIndex.getValue();
+        }
+        return llvm::None;
+      }]
+    >,
 
     //===------------------------------------------------------------------===//
     // Other interface methods.
index 00a094d..a7f8c31 100644 (file)
@@ -18,6 +18,7 @@
 namespace mlir {
 namespace linalg {
 
+struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
 //===----------------------------------------------------------------------===//
@@ -30,6 +31,14 @@ struct TiledLinalgOp {
   SmallVector<Operation *, 8> loops;
 };
 
+struct TiledAndFusedLinalgOps {
+  LinalgOp op;
+  SmallVector<LinalgOp, 1> fusedProducers;
+  SmallVector<LinalgOp, 1> originalProducers;
+  SmallVector<Operation *, 4> fusedLoops;
+  SmallVector<Operation *, 4> unfusedLoops;
+};
+
 /// Populates patterns for vectorization of all ConvN-D ops.
 void populateConvVectorizationPatterns(
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
@@ -53,6 +62,71 @@ void populateConvVectorizationPatterns(
 Optional<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
                                      const LinalgTilingOptions &options);
 
+/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in
+/// three steps
+/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile
+///   + fuse loops).
+/// - Tile just these loops of the consumer (root operation) and fuse with
+///   the producer.
+/// - Tile again the tiled consumer operation produced above to do rest of
+///   the tiling specified by the `tilingOptions`.
+///
+/// For example, consider the sequence of matmul below
+///
+///   linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
+///                 outs(%arg2 : memref<256x32xf32>)
+///   linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
+///                 outs(%arg4 : memref<256x32xf32>)
+///
+/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
+/// matmuls row-wise. For example, the fused computation for the above is shown
+/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
+/// along the rows of the matrix. The entire rows of the first matmul operation
+/// need to be computed before they can be used for the second matmul. The
+/// second matmul is further tiled (similar to normal tiling).
+///
+/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
+/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
+///   %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
+///     : memref<256x32xf32> to memref<16x32xf32, #map0>
+///   %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
+///     : memref<256x32xf32> to memref<16x32xf32, #map0>
+///   %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
+///     : memref<256x32xf32> to memref<16x32xf32, #map0>
+///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
+///     : memref<32x32xf32> to memref<32x32xf32, #map1>
+///   linalg.matmul
+///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
+///     outs(%0 : memref<16x32xf32, #map0>)
+///   scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) {
+///   scf.for %arg7 = %c0 to %c32 step %c4 {
+///     %4 = subview %0[0, %arg7] [16, 4] [1, 1]
+///       : memref<16x32xf32, #map0> to memref<16x4xf32, #map0>
+///     %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1]
+///       : memref<32x32xf32> to memref<4x8xf32, #map0>
+///     %6 = subview %1[0, %arg6] [16, 8] [1, 1]
+///       : memref<16x32xf32, #map0> to memref<16x8xf32, #map0>
+///     linalg.matmul
+///       ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>)
+///       outs(%6 : memref<16x8xf32, #map0>)
+///     }
+///     scf.yield
+///   }
+///   scf.yield
+/// }
+///
+/// The following tiling options are handled differently in tile+fuse (compared
+/// to tile only)
+/// - Interchange of the tiling loops is not supported right now.
+/// - Distribution is only done for the tile+fuse loops. The tiled loops
+///   generated by the second tiling is not distributed.
+Optional<TiledAndFusedLinalgOps>
+tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+                     const LinalgDependenceGraph &dependenceGraph,
+                     const LinalgTilingOptions &tilingOptions,
+                     const LinalgFusionOptions &fusionOptions);
+
 /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`.
 /// This is an in-place transformation controlled by `interchangeVector`.
 /// An empty vector is interpreted as the identity permutation and the
@@ -323,6 +397,63 @@ struct LinalgTilingPattern : public LinalgBaseTilingPattern {
   }
 };
 
+struct LinalgFusionOptions {
+  /// Optional list of operands indices to use for fusion. When unspecified,
+  /// only one fusion is done, i.e., the pattern returns after the first fusion.
+  Optional<DenseSet<unsigned>> indicesToFuse = None;
+  LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
+    indicesToFuse = DenseSet<unsigned>();
+    indicesToFuse->insert(operands.begin(), operands.end());
+    return *this;
+  }
+};
+
+struct LinalgBaseTileAndFusePattern : public RewritePattern {
+  LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context,
+                               const LinalgDependenceGraph &dependenceGraph,
+                               LinalgTilingOptions tilingOptions,
+                               LinalgFusionOptions fusionOptions,
+                               LinalgMarker marker = LinalgMarker(),
+                               LinalgMarker fusedOpMarker = LinalgMarker(),
+                               LinalgMarker originalOpMarker = LinalgMarker(),
+                               PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Dependence graph needed for fusion.
+  const LinalgDependenceGraph &dependenceGraph;
+  /// Options to control tiling.
+  LinalgTilingOptions tilingOptions;
+  /// Options to control fusion.
+  LinalgFusionOptions fusionOptions;
+  /// Marker to control application of the pattern.
+  LinalgMarker marker;
+  /// Marker set on the fused op after tile and fuse.
+  LinalgMarker fusedOpMarker;
+  /// The dependenceGraph is not modifiable, i.e. if the Linalg operations used
+  /// to build the dependence graph changes then the dependenceGraph needs to be
+  /// recomputed right now. To not invalidate the dependenceGraph as
+  /// transformation happens, the original producer can be tagged with a marker
+  /// that can be later used to delete the original operations.
+  LinalgMarker originalOpMarker;
+};
+
+template <typename OpTy>
+struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
+  LinalgTileAndFusePattern(MLIRContext *context,
+                           const LinalgDependenceGraph &dependenceGraph,
+                           LinalgTilingOptions tilingOptions,
+                           LinalgFusionOptions fusionOptions,
+                           LinalgMarker marker = LinalgMarker(),
+                           LinalgMarker fusedOpMarker = LinalgMarker(),
+                           LinalgMarker originalOpMarker = LinalgMarker(),
+                           PatternBenefit benefit = 1)
+      : LinalgBaseTileAndFusePattern(
+            OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
+            fusionOptions, marker, fusedOpMarker, originalOpMarker, benefit) {}
+};
+
 ///
 /// Linalg interchange patterns.
 ///
index aca5a98..76ce4eb 100644 (file)
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_UTILS_H_
 
 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/EDSC/Builders.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
index dfc977d..8dadfe6 100644 (file)
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
@@ -154,9 +155,9 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   llvm_unreachable("Expect to be able to extract a view defining loop range");
 }
 
-static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
-                     unsigned consumerIdx, unsigned producerIdx,
-                     OperationFolder *folder) {
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
+                     LinalgOp consumer, unsigned consumerIdx,
+                     OperationFolder *folder = nullptr) {
   assert(producer.hasBufferSemantics() &&
          "expected linalg op with buffer semantics");
   assert(consumer.hasBufferSemantics() &&
@@ -174,9 +175,7 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   //   we can always identify a data dimension with a (at least one) loop
   //   dimension.
   AffineMap producerMap =
-      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
-          .cast<AffineMapAttr>()
-          .getValue();
+      producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                     << ", producer map: " << producerMap << "\n");
 
@@ -185,10 +184,9 @@ static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
 
-  OpBuilder b(consumer.getOperation());
-  auto loc = consumer.getLoc();
   // Iterate over dimensions identified by the producer map for `producerIdx`.
   // This defines a subset of the loop ranges that we need to complete later.
+  auto loc = consumer.getLoc();
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
     loopRanges[posInProducerLoop] =
@@ -319,71 +317,380 @@ static bool isSameSubView(Value a, Value b) {
   return true;
 }
 
-static Optional<FusionInfo>
-fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
-                  const LinalgDependenceGraph &graph, OperationFolder *folder,
-                  LinalgDependenceGraph::DependenceType depType) {
-  assert(consumer.hasBufferSemantics() &&
-         "expected linalg op with buffer semantics");
-  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
-                    << *consumer.getOperation());
-  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
-    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
-                      << *dependence.dependentOpView.op << "\n");
-    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
-
-    // Check that the dependence is indeed on the input `consumerIdx` view.
-    auto consumedView = dependence.indexingView;
-    if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
-      continue;
-
-    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
-    // whether it is a strict subview of the producer view.
-    auto producedView = dependence.dependentOpView.view;
-    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
-    // `consumerIdx` and `producerIdx` exist by construction.
-    LLVM_DEBUG(dbgs() << "\n"
-                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
-                      << "producer: " << *producer.getOperation() << " view: "
-                      << producedView << " output index: " << producerIdx);
-
-    // Must be a subview or a slice to guarantee there are loops we can fuse
-    // into.
-    auto subView = consumedView.getDefiningOp<SubViewOp>();
-    auto slice = consumedView.getDefiningOp<SliceOp>();
-    if (!subView && !slice) {
-      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
-      continue;
-    }
+static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
+                    const LinalgDependenceGraph &dependenceGraph) {
+  // Only consider RAW and WAW atm.
+  for (auto depType : {
+           LinalgDependenceGraph::DependenceType::RAW,
+           LinalgDependenceGraph::DependenceType::WAW,
+       }) {
+    for (auto dependence :
+         dependenceGraph.getDependencesInto(consumer, depType)) {
+      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
 
-    // Simple fusability checks.
-    if (!isFusableInto(graph, consumer, consumedView, producer))
-      continue;
+      // Check that the dependence is indeed on the input `consumerIdx` view.
+      auto consumedView = dependence.indexingView;
+      if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
+        continue;
 
-    // Fuse `producer` just before `consumer`.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(consumer.getOperation());
-    ScopedContext scope(b, consumer.getLoc());
-    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
-    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
-                              producerIdx, folder);
+      // Consumer consumes this view, `isStructurallyFusableProducer` also
+      // checks whether it is a strict subview of the producer view.
+      auto producedView = dependence.dependentOpView.view;
+      auto producerIdx =
+          producer.getIndexOfOutputBuffer(producedView).getValue();
+      // `consumerIdx` and `producerIdx` exist by construction.
+      LLVM_DEBUG(dbgs() << "\n"
+                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
+                        << "producer: " << *producer.getOperation() << " view: "
+                        << producedView << " output index: " << producerIdx);
+      (void)producerIdx;
+
+      // Simple fusability checks.
+      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
+        continue;
 
-    return FusionInfo{producer, fusedProducer};
+      return dependence;
+    }
   }
-  return llvm::None;
+  return {};
 }
 
-// Only consider RAW and WAW atm.
 Optional<FusionInfo> mlir::linalg::fuseProducerOf(
     OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
     const LinalgDependenceGraph &graph, OperationFolder *folder) {
-  for (auto dep : {
-           LinalgDependenceGraph::DependenceType::RAW,
-           LinalgDependenceGraph::DependenceType::WAW,
-       }) {
-    if (auto res =
-            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
-      return res;
+  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
+      findFusableProducer(consumer, consumerIdx, graph);
+  if (!fusableDependence)
+    return {};
+
+  LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+  Value producerView = fusableDependence->dependentOpView.view;
+  Value consumerView = fusableDependence->indexingView;
+
+  // Must be a subview or a slice to guarantee there are loops we can fuse
+  // into.
+  auto subView = consumerView.getDefiningOp<SubViewOp>();
+  auto slice = consumerView.getDefiningOp<SliceOp>();
+  if (!subView && !slice) {
+    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
+    return {};
+  }
+
+  // Fuse `producer` just before `consumer`.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(consumer.getOperation());
+  ScopedContext scope(b, consumer.getLoc());
+  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  Optional<unsigned> producerIdxOpt =
+      producerOp.getIndexOfInputAndOutputBuffer(producerView);
+  assert(producerIdxOpt.hasValue() && "incorrect operand index");
+  unsigned producerIdx = producerIdxOpt.getValue();
+
+  auto fusedProducer =
+      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
+  return FusionInfo{producerOp, fusedProducer};
+}
+
+/// Returns the positions of the loop in `op` that can be tiled based on the
+/// operations that are to be fused with it. For example, in a
+///
+///   linalg. matmul ins(%a, %b : ...) outs(%c : ...)
+///
+/// if the producer of %a needs to be fused with this op, only the `i` loop of
+/// the matmul can be tiled while fusing. If producer of %a, and %b are to be
+/// fused, then no loops can be tiled while fusing.
+static DenseSet<unsigned> collectTileAndFuseLoops(
+    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
+                     fusableDependences) {
+  // 1. Only parallel loops can be used for tile + fuse. Find the number of
+  // common outer parallel loops between the op and its producers being fused.
+  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
+    return linalgOp.iterator_types()
+        .getValue()
+        .take_while([](Attribute attr) -> bool {
+          return attr.cast<StringAttr>().getValue() ==
+                 getParallelIteratorTypeName();
+        })
+        .size();
+  };
+
+  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
+  for (auto dependence : fusableDependences) {
+    numOuterParallelLoops =
+        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
+                                            dependence.dependentOpView.op)));
+  }
+
+  // Need to compute what tiled loops can be "fused". Given the precondition
+  // that all indexing map for the producer view is a projected permutation, we
+  // can assert that the producer iterates over the dimensions of the "fused
+  // view" only once. To be used a fused loop the producer should use this loop
+  // to access the fused view. For example, consider
+  //
+  // ```
+  //   linalg.add ins(%a, %b) outs(%c)
+  //   linalg.matmul ins(%d, %c) outs(%e)
+  // ```
+  //
+  // if `linalg.add` has the semantics of `c = a + b`, then the following
+  // tile+fuse code is correct.
+  //
+  // ```
+  // for j ... += TSj
+  //   %sa = subview %a[0, %j][...]
+  //   %sb = subview %b[0, %j][...]
+  //   %sc = subview %c[0, %j][...]
+  //   %sd = subview %d[0, 0][...]
+  //   %se = subview %e[0, %j][...]
+  //   linalg.add ins(%sa, %sb) outs(%sc)
+  //   linalg.matmul ins(%sd, %sc) outs(%se)
+  // ```
+  //
+  // On the other hand tiling along i would be incorrect
+  //
+  // ```
+  // for %i .. += TSi
+  //   %sa = subview %a[%i, 0][...]
+  //   %sb = subview %b[%i, 0][...]
+  //   %sc = subview %c[%i, 0][...]
+  //   %sc2 = subview %c[0, 0][...]
+  //   %sd = subview %d[%i, 0][...]
+  //   %se = subview %e[%i, 0][...]
+  //   linalg.add ins(%sa, %sb) outs(%sc)
+  //   linalg.matmul ins(%sd, %sc2) outs(%se)
+  // ```
+  //
+  // The write to the subview `%sc` in `linalg.add` is performed after the read
+  // from it using `%sc2` violating the RAW dependence of the original code. To
+  // find such loops indexing map of the fused view in the consumer op is
+  // used. For the above example, this indexing map is
+  //
+  //   affine_map<(d0, d1, d2) -> (d2, d1)>
+  //
+  // Since d0 is not in the result expressions of this map, it is not treated as
+  // tile + fuse loop, (but d1 is).
+  //
+  // TODO: The above is probably restrictive and there might be a generalization
+  // of these that might allow for more fusion opportunities. Explore based on
+  // needs.
+  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
+  for (auto dependence : fusableDependences) {
+    unsigned consumerIdx =
+        op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue();
+    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
+    // Previously asserted that the consumerAccess map is a projected
+    // permutation, so all results are known to be AffineDimExprs. To remove
+    // this restriction walk the expression to find which dimensions of the
+    // consumer loop appear in the `consumerAccess`.
+    DenseSet<unsigned> positions;
+    for (auto expr : consumerAccess.getResults())
+      positions.insert(expr.cast<AffineDimExpr>().getPosition());
+    commonTilableLoops.emplace_back(std::move(positions));
+  }
+
+  // 2. Of the outer parallel loops, only those loops can be tiled + fused as
+  // computed above for all the fused dependences can be used to tile and fuse.
+  DenseSet<unsigned> tilableParallelLoops;
+  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
+    if (llvm::all_of(commonTilableLoops,
+                     [&](const DenseSet<unsigned> &tilableLoops) {
+                       return tilableLoops.count(index);
+                     }))
+      tilableParallelLoops.insert(index);
+  }
+  return tilableParallelLoops;
+}
+
+/// Find all dependences that are to be fusable.
+static Optional<
+    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
+findAllFusableDependences(LinalgOp op,
+                          const LinalgDependenceGraph &dependenceGraph,
+                          const LinalgFusionOptions &fusionOptions) {
+  SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
+      fusableDependences;
+  for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
+    if (fusionOptions.indicesToFuse &&
+        !fusionOptions.indicesToFuse->count(operand.index()))
+      continue;
+    Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
+        fusableDependence =
+            findFusableProducer(op, operand.index(), dependenceGraph);
+    if (!fusableDependence)
+      continue;
+    // Make sure that the indexing map of the view used for fusion in the
+    // producer is a projected permutation.
+    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
+    Value producerView = fusableDependence->dependentOpView.view;
+    unsigned producerIdx =
+        producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue();
+    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
+    if (!producerMap.isProjectedPermutation()) {
+      op.emitError("unhandled non permutation indexing map for fused view in "
+                   "producer for operand at index ")
+          << operand.index();
+      return llvm::None;
+    }
+    Value consumerView = fusableDependence->indexingView;
+    unsigned consumerIdx =
+        op.getIndexOfInputAndOutputBuffer(consumerView).getValue();
+    if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
+      op.emitError(
+          "unhandled case where indexing map for fused view in the consumer is "
+          "not a projected permuration while fusing at index ")
+          << operand.index();
+      return llvm::None;
+    }
+    fusableDependences.push_back(*fusableDependence);
+    if (!fusionOptions.indicesToFuse)
+      break;
+  }
+  return fusableDependences;
+}
+
+static bool isZero(Value v) {
+  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
+    return cst.getValue() == 0;
+  return false;
+}
+
+template <typename LoopType>
+static Optional<TiledAndFusedLinalgOps>
+tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
+                         const LinalgDependenceGraph &dependenceGraph,
+                         const LinalgTilingOptions &tilingOptions,
+                         const LinalgFusionOptions &fusionOptions) {
+  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  // Some of the tiling options might not be supportable with tile and fuse.
+  // TODO: Support interchange with tile + fuse.
+  if (!tilingOptions.interchangeVector.empty()) {
+    op.emitError("unable to handle tile and fuse with interchange");
+    return llvm::None;
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+  ScopedContext scope(rewriter, op.getLoc());
+
+  // Find all the producers.
+  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
+      fusableDependencesOpt =
+          findAllFusableDependences(op, dependenceGraph, fusionOptions);
+  if (!fusableDependencesOpt)
+    return llvm::None;
+  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
+      *fusableDependencesOpt);
+
+  // Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle instead of
+  // adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumLoops();
+  SmallVector<Value, 4> tileSizeVector =
+      tilingOptions.tileSizeComputationFunction(rewriter, op);
+  if (tileSizeVector.size() < nLoops) {
+    auto zero = std_constant_index(0);
+    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+  }
+
+  TiledAndFusedLinalgOps ret;
+
+  // Find the loops that can be tiled and fused.
+  DenseSet<unsigned> tileFuseLoops =
+      collectTileAndFuseLoops(op, fusableDependences);
+
+  // If there are no fusable dependences or there are no tile+fusable loops,
+  // just return.
+  if (fusableDependences.empty() || tileFuseLoops.empty()) {
+    return llvm::None;
+  }
+
+  // Get the tile sizes for the first and second tiling steps. For the first
+  // step the tile size are set to zero for the loops that arent
+  // fused. Similarly for the second step, the tile sizes are set to zero for
+  // the loops that are fused. For example, if for the following input
+  //
+  // ```
+  //   linalg.add ins(%a, %b) outs(%c)
+  //   linalg.matmul ins(%d, %c) outs(%e)
+  // ```
+  //
+  // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}`
+  // respectively, and since only `j` can be tiled and fused. The tile sizes
+  // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable
+  // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile
+  // the tiled matmul generated by the first tiling step.
+  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
+  for (auto tileSize : enumerate(tileSizeVector)) {
+    auto zero = std_constant_index(0);
+    if (tileFuseLoops.count(tileSize.index())) {
+      tileAndFuseSizes.push_back(tileSize.value());
+      tileSizes.push_back(zero);
+    } else {
+      tileSizes.push_back(tileSize.value());
+      tileAndFuseSizes.push_back(zero);
+    }
+  }
+
+  // Tile for the loops that can be fused.
+  LinalgTilingOptions firstTilingOptions = tilingOptions;
+  firstTilingOptions.setTileSizes(tileAndFuseSizes);
+  Optional<TiledLinalgOp> firstTiledOp =
+      tileLinalgOp(rewriter, op, firstTilingOptions);
+  if (!firstTiledOp)
+    return llvm::None;
+  ret.op = firstTiledOp->op;
+  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());
+
+  rewriter.setInsertionPoint(ret.op);
+  // Fuse the operands.
+  for (auto producer : enumerate(fusableDependences)) {
+    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
+    unsigned producerIdx = producerOp
+                               .getIndexOfInputAndOutputBuffer(
+                                   producer.value().dependentOpView.view)
+                               .getValue();
+    unsigned consumerIdx =
+        op.getIndexOfInputAndOutputBuffer(producer.value().indexingView)
+            .getValue();
+    LinalgOp fusedOp =
+        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
+    ret.fusedProducers.push_back(fusedOp);
+    ret.originalProducers.push_back(producerOp);
+  }
+
+  if (!llvm::all_of(tileSizes, isZero)) {
+    // Tile the remaining loops of the root operation.
+    LinalgTilingOptions secondTilingOptions = tilingOptions;
+    // The distribution is done only for the tile+fused loops.
+    secondTilingOptions.distribution = llvm::None;
+    secondTilingOptions.setTileSizes(tileSizes);
+    Optional<TiledLinalgOp> secondTiledOp =
+        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
+    if (!secondTiledOp)
+      return llvm::None;
+    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
+                            secondTiledOp->loops.end());
+    rewriter.eraseOp(ret.op);
+    ret.op = secondTiledOp->op;
+  }
+
+  return ret;
+}
+
+Optional<TiledAndFusedLinalgOps>
+mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
+                                   const LinalgDependenceGraph &dependenceGraph,
+                                   const LinalgTilingOptions &tilingOptions,
+                                   const LinalgFusionOptions &fusionOptions) {
+  switch (tilingOptions.loopType) {
+  case LinalgTilingLoopType::Loops:
+    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
+                                                tilingOptions, fusionOptions);
+  case LinalgTilingLoopType::ParallelLoops:
+    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
+        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+  default:;
   }
   return llvm::None;
 }
index 3db801b..68d6954 100644 (file)
@@ -318,25 +318,10 @@ static SmallVector<Value, 4> makeTiledViews(OpBuilder &b, Location loc,
 }
 
 template <typename LoopTy>
-Optional<TiledLinalgOp> static tileLinalgOpImpl(
-    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(op);
-  ScopedContext scope(b, op.getLoc());
-
-  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
-  // 1. Enforce the convention that "tiling by zero" skips tiling a particular
-  // dimension. This convention is significantly simpler to handle instead of
-  // adjusting affine maps to account for missing dimensions.
+static Optional<TiledLinalgOp>
+tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizes,
+                 const LinalgTilingOptions &options) {
   auto nLoops = op.getNumLoops();
-  SmallVector<Value, 4> tileSizeVector =
-      options.tileSizeComputationFunction(b, op);
-  if (tileSizeVector.size() < nLoops) {
-    auto zero = std_constant_index(0);
-    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
-  }
-
-  ArrayRef<Value> tileSizes = tileSizeVector;
   // Initial tile sizes may be too big, only take the first nLoops.
   tileSizes = tileSizes.take_front(nLoops);
 
@@ -350,17 +335,7 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
       return llvm::None;
   }
 
-  // If interchangeVector is empty, use the identity. Build the permutation map
-  // otherwise.
-  auto invPermutationMap =
-      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
-  if (!options.interchangeVector.empty())
-    invPermutationMap = inversePermutation(AffineMap::getPermutationMap(
-        options.interchangeVector, b.getContext()));
-  if (!invPermutationMap)
-    return llvm::None;
-
-  // 2. Build the tiled loop ranges.
+  // 1. Build the tiled loop ranges.
   auto allViewSizes = getViewSizes(b, op);
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (asserted in the inverse calculation).
@@ -374,17 +349,39 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   SmallVector<SubViewOp::Range, 4> loopRanges;
   LoopIndexToRangeIndexMap loopIndexToRangeIndex;
   std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges(
-      b, scope.getLocation(), viewSizesToLoopsMap, allViewSizes, tileSizes);
-  if (!options.interchangeVector.empty())
-    applyPermutationToVector(loopRanges, options.interchangeVector);
+      b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes);
+  SmallVector<Attribute, 4> iteratorTypes;
+  for (auto attr :
+       enumerate(op.iterator_types().cast<ArrayAttr>().getValue())) {
+    if (loopIndexToRangeIndex.count(attr.index()))
+      iteratorTypes.push_back(attr.value());
+  }
+  // If interchangeVector is empty, use the identity. Build the permutation map
+  // otherwise.
+  auto invPermutationMap =
+      AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext());
+  if (!options.interchangeVector.empty()) {
+    // Based on the pruned iterations (due to zero tile size), recompute the
+    // interchange vector.
+    SmallVector<unsigned, 4> interchangeVector;
+    interchangeVector.reserve(options.interchangeVector.size());
+    for (auto pos : options.interchangeVector) {
+      auto it = loopIndexToRangeIndex.find(pos);
+      if (it == loopIndexToRangeIndex.end())
+        continue;
+      interchangeVector.push_back(it->second);
+    }
+    invPermutationMap = inversePermutation(
+        AffineMap::getPermutationMap(interchangeVector, b.getContext()));
+    if (!invPermutationMap)
+      return llvm::None;
+    applyPermutationToVector(loopRanges, interchangeVector);
+    applyPermutationToVector(iteratorTypes, interchangeVector);
+  }
 
-  // 3. Create the tiled loops.
+  // 2. Create the tiled loops.
   LinalgOp res = op;
   SmallVector<Value, 4> ivs;
-  SmallVector<Attribute, 4> iteratorTypes =
-      llvm::to_vector<4>(op.iterator_types().cast<ArrayAttr>().getValue());
-  if (!options.interchangeVector.empty())
-    applyPermutationToVector(iteratorTypes, options.interchangeVector);
   GenerateLoopNest<LoopTy>::doit(
       loopRanges, /*iterArgInitValues*/ {}, iteratorTypes,
       [&](ValueRange localIvs, ValueRange iterArgs) -> scf::ValueVector {
@@ -410,10 +407,10 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
       },
       options.distribution);
 
-  // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
+  // 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
   transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
 
-  // 5. Gather the newly created loops and return them with the new op.
+  // 4. Gather the newly created loops and return them with the new op.
   SmallVector<Operation *, 8> loops;
   loops.reserve(ivs.size());
   for (auto iv : ivs) {
@@ -429,14 +426,38 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(
   return TiledLinalgOp{res, loops};
 }
 
+template <typename LoopTy>
+Optional<TiledLinalgOp> static tileLinalgOpImpl(
+    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(op);
+  ScopedContext scope(b, op.getLoc());
+
+  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
+  // Enforce the convention that "tiling by zero" skips tiling a particular
+  // dimension. This convention is significantly simpler to handle instead of
+  // adjusting affine maps to account for missing dimensions.
+  auto nLoops = op.getNumLoops();
+  SmallVector<Value, 4> tileSizeVector =
+      options.tileSizeComputationFunction(b, op);
+  if (tileSizeVector.size() < nLoops) {
+    auto zero = std_constant_index(0);
+    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
+  }
+
+  return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
+}
+
 Optional<TiledLinalgOp>
 mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
                            const LinalgTilingOptions &options) {
-  if (options.loopType == LinalgTilingLoopType::Loops)
+  switch (options.loopType) {
+  case LinalgTilingLoopType::Loops:
     return tileLinalgOpImpl<scf::ForOp>(b, op, options);
-  if (options.loopType == LinalgTilingLoopType::ParallelLoops)
+  case LinalgTilingLoopType::ParallelLoops:
     return tileLinalgOpImpl<scf::ParallelOp>(b, op, options);
-  // TODO: Impl tiling to affine loops when it makes sense.
+  default:;
+  }
   return llvm::None;
 }
 
index c1aad62..56652cb 100644 (file)
@@ -129,6 +129,43 @@ LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewrite(
   return success();
 }
 
+mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
+    StringRef opName, MLIRContext *context,
+    const LinalgDependenceGraph &dependenceGraph,
+    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
+    LinalgMarker marker, LinalgMarker fusedOpMarker,
+    LinalgMarker originalOpMarker, PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context),
+      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
+      fusionOptions(fusionOptions), marker(marker),
+      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}
+
+LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  if (!linalgOp.hasBufferSemantics())
+    return failure();
+
+  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
+      rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
+  if (!tiledAndFusedOps)
+    return failure();
+  marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
+  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
+    fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
+  }
+  for (auto origProducerOp : tiledAndFusedOps->originalProducers)
+    originalOpMarker.replaceLinalgMarker(rewriter,
+                                         origProducerOp.getOperation());
+  rewriter.updateRootInPlace(
+      op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
+  return success();
+}
+
 /// Linalg base interchange pattern.
 mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern(
     StringRef opName, MLIRContext *context,
diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir
new file mode 100644 (file)
index 0000000..61e5b74
--- /dev/null
@@ -0,0 +1,297 @@
+// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s
+
+module {
+  func @basic_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                     %arg2: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    linalg.fill(%arg2, %cst) : memref<?x?xf32>, f32
+    linalg.matmul {__internal_linalg_transform__ = "basic_fusion"}
+      ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+    return
+  }
+}
+
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//      CHECK: func @basic_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[CST:.+]] = constant 0.0{{.*}} : f32
+//  CHECK-DAG:   linalg.fill(%[[ARG2]], %[[CST]])
+// CHECK-SAME:   __internal_linalg_transform__ = "after_basic_fusion_original"
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:   %[[N:.+]] = dim %[[ARG1]], %[[C1]]
+//      CHECK:   scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) =
+// CHECK-SAME:     to (%[[M]], %[[N]])
+// CHECK-SAME:     step (%[[C32]], %[[C64]]) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[K:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     %[[K_2:.+]] = dim %[[ARG1]], %[[C0]]
+//      CHECK:     %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG1]][0, %[[IV1]]]
+// CHECK-SAME:       %[[K_2]], %[[TILE_N]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//      CHECK:     %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:     linalg.fill(%[[SV3]], %[[CST]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_basic_fusion_producer"
+//      CHECK:     scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
+//      CHECK:       %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
+//      CHECK:       %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]]
+// CHECK-SAME:         [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:       %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]]
+//      CHECK:       %[[SV5:.+]] = subview %[[SV2]][%[[IV2]], 0]
+// CHECK-SAME:         [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:       linalg.matmul
+// CHECK-SAME:         __internal_linalg_transform__ = "after_basic_fusion"
+// CHECK-SAME:         ins(%[[SV4]], %[[SV5]]
+// CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         outs(%[[SV3]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_basic_fusion_original"
+
+// -----
+
+module {
+  func @rhs_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                              %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    linalg.copy(%arg1, %arg2) : memref<?x?xf32>, memref<?x?xf32>
+    linalg.fill(%arg3, %cst) : memref<?x?xf32>, f32
+    linalg.matmul {__internal_linalg_transform__ = "rhs_fusion"}
+      ins(%arg0, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg3 : memref<?x?xf32>)
+    return
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//      CHECK: func @rhs_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[CST:.+]] = constant 0.0{{.*}} : f32
+//  CHECK-DAG:   linalg.copy(%[[ARG1]], %[[ARG2]])
+// CHECK-SAME:   __internal_linalg_transform__ = "after_rhs_fusion_original"
+//  CHECK-DAG:   %[[N:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:   scf.parallel (%[[IV0:.+]]) =
+// CHECK-SAME:     (%[[C0]]) to (%[[N]]) step (%[[C64]]) {
+//      CHECK:     %[[K:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:     %[[TILE_N:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N]]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG2]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
+//      CHECK:     %[[M:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:     %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[M]], %[[TILE_N_2]]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]]
+// CHECK-SAME:       [%[[K]], %[[TILE_N]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_rhs_fusion_producer"
+//  CHECK-NOT:     linalg.fill
+//  CHECK-DAG:     %[[M_2:.+]] = dim %[[ARG0]], %[[C0]]
+//  CHECK-DAG:     %[[K_2:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     scf.parallel (%[[IV1:.+]]) =
+// CHECK-SAME:       (%[[C0]]) to (%[[M_2]]) step (%[[C32]]) {
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K_2]] step %[[C16]] {
+//      CHECK:         %[[TILE_M:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M_2]]]
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]]
+//      CHECK:         %[[SV4:.+]] = subview %[[ARG0]][%[[IV1]], %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]]
+//      CHECK:         %[[SV5:.+]] = subview %[[SV1]][%[[IV2]], 0]
+// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:         %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M]]]
+//      CHECK:         %[[SV6:.+]] = subview %[[SV2]][%[[IV1]], 0]
+// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_rhs_fusion"
+// CHECK-SAME:           ins(%[[SV4]], %[[SV5]]
+// CHECK-SAME:             : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:           outs(%[[SV6]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:       }
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_rhs_fusion_original"
+
+
+// -----
+
+module {
+  func @two_operand_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                              %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>) {
+    %cst = constant 0.000000e+00 : f32
+    linalg.copy(%arg0, %arg1) : memref<?x?xf32>, memref<?x?xf32>
+    linalg.fill(%arg3, %cst) : memref<?x?xf32>, f32
+    linalg.matmul {__internal_linalg_transform__ = "two_operand_fusion"}
+      ins(%arg1, %arg2 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg3 : memref<?x?xf32>)
+    return
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//      CHECK: func @two_operand_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//  CHECK-DAG:   %[[CST:.+]] = constant 0.0{{.*}} : f32
+//      CHECK:   linalg.copy(%[[ARG0]], %[[ARG1]])
+// CHECK-SAME:     __internal_linalg_transform__ = "after_two_operand_fusion_original"
+//      CHECK:   linalg.fill(%[[ARG3]], %[[CST]])
+// CHECK-SAME:     __internal_linalg_transform__ = "after_two_operand_fusion_original"
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG1]], %[[C0]]
+//      CHECK:   scf.parallel (%[[IV0:.+]]) =
+// CHECK-SAME:     (%[[C0]]) to (%[[M]]) step (%[[C32]]) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[K:.+]] = dim %[[ARG1]], %[[C1]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG1]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//      CHECK:     %[[N:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K]]]
+//      CHECK:     linalg.copy(%[[SV3]], %[[SV1]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
+//      CHECK:     linalg.fill(%[[SV2]], %[[CST]])
+// CHECK-SAME:       __internal_linalg_transform__ = "after_two_operand_fusion_producer"
+//  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:     scf.parallel (%[[IV1:.+]]) =
+// CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
+//      CHECK:         %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:         %[[K_2:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]]
+//      CHECK:         %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
+//      CHECK:         %[[SV5:.+]] = subview %[[ARG2]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:         %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
+//      CHECK:         %[[SV6:.+]] = subview %[[SV2]][0, %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_two_operand_fusion"
+// CHECK-SAME:           ins(%[[SV4]], %[[SV5]]
+// CHECK-SAME:             : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:           outs(%[[SV6]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:       }
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_two_operand_fusion_original"
+
+// -----
+
+module {
+  func @matmul_fusion(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>,
+                      %arg2: memref<?x?xf32>, %arg3: memref<?x?xf32>,
+                      %arg4: memref<?x?xf32>) {
+    linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg2 : memref<?x?xf32>)
+    linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"}
+      ins(%arg2, %arg3 : memref<?x?xf32>, memref<?x?xf32>)
+      outs(%arg4 : memref<?x?xf32>)
+    return
+  }
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
+//      CHECK: func @matmul_fusion
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK-SAME:   %[[ARG4:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+//  CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//  CHECK-DAG:   %[[C32:.+]] = constant 32 : index
+//  CHECK-DAG:   %[[C64:.+]] = constant 64 : index
+//  CHECK-DAG:   %[[C16:.+]] = constant 16 : index
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_lhs_fusion_original"
+//  CHECK-DAG:   %[[M:.+]] = dim %[[ARG2]], %[[C0]]
+//      CHECK:   scf.parallel (%[[IV0:.+]]) =
+// CHECK-SAME:     (%[[C0]]) to (%[[M]]) step (%[[C32]]) {
+//      CHECK:     %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
+//      CHECK:     %[[K2:.+]] = dim %[[ARG2]], %[[C1]]
+//      CHECK:     %[[SV1:.+]] = subview %[[ARG2]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K2]]]
+//      CHECK:     %[[M_2:.+]] = dim %[[ARG4]], %[[C0]]
+//      CHECK:     %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
+//      CHECK:     %[[N:.+]] = dim %[[ARG4]], %[[C1]]
+//      CHECK:     %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M_2]], %[[N]]]
+//      CHECK:     %[[K1:.+]] = dim %[[ARG0]], %[[C1]]
+//      CHECK:     %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0]
+// CHECK-SAME:       [%[[TILE_M]], %[[K1]]]
+//      CHECK:     %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]]
+//      CHECK:     linalg.matmul
+// CHECK-SAME:         __internal_linalg_transform__ = "after_lhs_fusion_producer"
+// CHECK-SAME:         ins(%[[SV3]], %[[SV4]]
+// CHECK-SAME:           : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:         outs(%[[SV1]] : memref<?x?xf32, #[[MAP1]]>)
+//  CHECK-DAG:     %[[N_2:.+]] = dim %[[ARG3]], %[[C1]]
+//      CHECK:     scf.parallel (%[[IV1:.+]]) =
+// CHECK-SAME:       (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
+// CHECK-NEXT:       scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
+//      CHECK:         %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
+//      CHECK:         %[[SV6:.+]] = subview %[[SV1]][0, %[[IV2]]]
+// CHECK-SAME:           [%[[TILE_M]], %[[TILE_K]]]
+//      CHECK:         %[[K_2:.+]] = dim %[[ARG3]], %[[C0]]
+//      CHECK:         %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]]
+//      CHECK:         %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]]
+//      CHECK:         %[[SV7:.+]] = subview %[[ARG3]][%[[IV2]], %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_K_2]], %[[TILE_N]]]
+//      CHECK:         %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
+//      CHECK:         %[[SV8:.+]] = subview %[[SV2]][0, %[[IV1]]]
+// CHECK-SAME:           [%[[TILE_M_2]], %[[TILE_N_2]]]
+//      CHECK:         linalg.matmul
+// CHECK-SAME:           __internal_linalg_transform__ = "after_lhs_fusion"
+// CHECK-SAME:           ins(%[[SV6]], %[[SV7]]
+// CHECK-SAME:             : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
+// CHECK-SAME:           outs(%[[SV8]] : memref<?x?xf32, #[[MAP1]]>)
+//      CHECK:       }
+//      CHECK:     }
+//      CHECK:   }
+//      CHECK:   linalg.matmul
+// CHECK-SAME:     __internal_linalg_transform__ = "after_lhs_fusion_original"
index 3c82554..5bf6062 100644 (file)
@@ -16,6 +16,7 @@ add_mlir_library(MLIRTestTransforms
   TestGpuMemoryPromotion.cpp
   TestGpuParallelLoopMapping.cpp
   TestInlining.cpp
+  TestLinalgFusionTransforms.cpp
   TestLinalgHoisting.cpp
   TestLinalgTransforms.cpp
   TestLiveness.cpp
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
new file mode 100644 (file)
index 0000000..9a376c5
--- /dev/null
@@ -0,0 +1,112 @@
+//===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing Linalg fusion patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+struct TestLinalgFusionTransforms
+    : public PassWrapper<TestLinalgFusionTransforms, FunctionPass> {
+  TestLinalgFusionTransforms() = default;
+  TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect,
+                    StandardOpsDialect>();
+  }
+
+  void runOnFunction() override;
+};
+} // namespace
+
+static void fillFusionPatterns(MLIRContext *context,
+                               const LinalgDependenceGraph &dependenceGraph,
+                               OwningRewritePatternList &patterns) {
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions(),
+      LinalgMarker(Identifier::get("basic_fusion", context),
+                   Identifier::get("after_basic_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_basic_fusion_producer", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_basic_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({0}),
+      LinalgMarker(Identifier::get("lhs_fusion", context),
+                   Identifier::get("after_lhs_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_lhs_fusion_producer", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_lhs_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({1}),
+      LinalgMarker(Identifier::get("rhs_fusion", context),
+                   Identifier::get("after_rhs_fusion", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_rhs_fusion_producer", context)),
+      LinalgMarker(ArrayRef<Identifier>(),
+                   Identifier::get("after_rhs_fusion_original", context)));
+
+  patterns.insert<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions()
+          .setTileSizes({32, 64, 16})
+          .setLoopType(LinalgTilingLoopType::ParallelLoops),
+      LinalgFusionOptions().setIndicesToFuse({0, 2}),
+      LinalgMarker(Identifier::get("two_operand_fusion", context),
+                   Identifier::get("after_two_operand_fusion", context)),
+      LinalgMarker(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_two_operand_fusion_producer", context)),
+      LinalgMarker(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_two_operand_fusion_original", context)));
+}
+
+static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) {
+  OwningRewritePatternList fusionPatterns;
+  Aliases alias;
+  LinalgDependenceGraph dependenceGraph =
+      LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
+  fillFusionPatterns(context, dependenceGraph, fusionPatterns);
+  applyPatternsAndFoldGreedily(funcOp, fusionPatterns);
+}
+
+void TestLinalgFusionTransforms::runOnFunction() {
+  applyFusionPatterns(&getContext(), getFunction());
+}
+
+namespace mlir {
+void registerTestLinalgFusionTransforms() {
+  PassRegistration<TestLinalgFusionTransforms> testFusionTransformsPass(
+      "test-linalg-fusion-transform-patterns",
+      "Test Linalg fusion transformation patterns by applying them greedily.");
+}
+} // namespace mlir
index aed8b0a..0389c70 100644 (file)
@@ -58,6 +58,7 @@ void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestInterfaces();
+void registerTestLinalgFusionTransforms();
 void registerTestLinalgHoisting();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
@@ -114,6 +115,7 @@ void registerTestPasses() {
   registerTestExpandTanhPass();
   registerTestGpuMemoryPromotionPass();
   registerTestInterfaces();
+  registerTestLinalgFusionTransforms();
   registerTestLinalgHoisting();
   registerTestLinalgTransforms();
   registerTestLivenessPass();