Decouple tiling from fusion in Linalg.
authorNicolas Vasilache <ntv@google.com>
Thu, 26 Sep 2019 15:43:58 +0000 (08:43 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 26 Sep 2019 15:44:31 +0000 (08:44 -0700)
This CL modifies the linalg-fusion pass such that it does not tile anymore as part of the pass. Tiling is a separate concern that enables linalg fusion but should happen before.
This makes fusion more composable with other decisions.
In particular the fusion pass now becomes greedy and only applies the transformation on a best-effort basis.

This should also let fusion work in a multi-hop fashion with chains of producer/consumers.

Since the fusion pass does not perform tiling anymore, tests are rewritten to be in pretiled form and make the intent of the test clearer (albeit more verbose).

PiperOrigin-RevId: 271357741

mlir/include/mlir/Dialect/Linalg/Passes.h
mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Dialect/Linalg/CMakeLists.txt
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
mlir/test/Dialect/Linalg/fusion-2-level.mlir
mlir/test/Dialect/Linalg/fusion.mlir

index 2b58df7..890e98d 100644 (file)
@@ -31,8 +31,7 @@ class ModuleOp;
 template <typename T> class OpPassBase;
 
 namespace linalg {
-std::unique_ptr<OpPassBase<FuncOp>>
-createLinalgFusionPass(ArrayRef<int64_t> tileSizes = {});
+std::unique_ptr<OpPassBase<FuncOp>> createLinalgFusionPass();
 
 std::unique_ptr<OpPassBase<FuncOp>>
 createLinalgTilingPass(ArrayRef<int64_t> tileSizes = {},
index 3fab843..ed904e0 100644 (file)
@@ -32,6 +32,22 @@ using namespace mlir::linalg;
 
 using llvm::dbgs;
 
+static StringRef toStringRef(LinalgDependenceGraph::DependenceType dt) {
+  switch (dt) {
+  case LinalgDependenceGraph::DependenceType::RAW:
+    return "RAW";
+  case LinalgDependenceGraph::DependenceType::RAR:
+    return "RAR";
+  case LinalgDependenceGraph::DependenceType::WAR:
+    return "WAR";
+  case LinalgDependenceGraph::DependenceType::WAW:
+    return "WAW";
+  default:
+    break;
+  }
+  llvm_unreachable("Unexpected DependenceType");
+}
+
 Value *Aliases::find(Value *v) {
   if (isa<BlockArgument>(v))
     return v;
@@ -82,8 +98,8 @@ LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
 void LinalgDependenceGraph::addDependenceElem(DependenceType dt,
                                               LinalgOpView indexingOpView,
                                               LinalgOpView dependentOpView) {
-  LLVM_DEBUG(dbgs() << "\nAdd dep type " << dt << ":\t" << *indexingOpView.op
-                    << " -> " << *dependentOpView.op);
+  LLVM_DEBUG(dbgs() << "\nAdd dep type " << toStringRef(dt) << ":\t"
+                    << *indexingOpView.op << " -> " << *dependentOpView.op);
   dependencesFromGraphs[dt][indexingOpView.op].push_back(
       LinalgDependenceGraphElem{dependentOpView, indexingOpView.view});
   dependencesIntoGraphs[dt][dependentOpView.op].push_back(
@@ -202,9 +218,9 @@ LinalgDependenceGraph::findOperationsWithCoveringDependences(
       if (view && !aliases.alias(view, dependence.indexingView))
         continue;
       auto *op = dependence.dependentOpView.op;
-      LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " << dt
-                        << ": " << *src << " -> " << *op << " on "
-                        << *dependence.indexingView);
+      LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type "
+                        << toStringRef(dt) << ": " << *src << " -> " << *op
+                        << " on " << *dependence.indexingView);
       res.push_back(op);
     }
   }
index 8eea5dc..bd86893 100644 (file)
@@ -18,6 +18,7 @@ add_llvm_library(MLIRLinalg
 add_dependencies(MLIRLinalg
 
   MLIRAffineOps
+  MLIRAnalysis
   MLIRLinalgOpsIncGen
   MLIRLinalgLibraryOpsIncGen
   MLIRStandardToLLVM
index 731d3b8..42ff5ce 100644 (file)
@@ -19,6 +19,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/Dominance.h"
 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -51,14 +52,12 @@ using llvm::dbgs;
 /// Implements a simple high-level fusion pass of linalg library operations.
 ///
 /// In each block, linalg ops are processed in reverse textual order.
-/// Given a linalg op, fusion occurs by:
-///   1. tiling the op by a given multi-dimensional tile size;
-///   2. inspecting the linalg ops that write into the views read by the op in
-///      step 1. This uses the SSA value of the views to determine producer-
-///      consumer dependences: only identical SSA views are considered for
-///      fusion at this point;
-///   3. greedily fuse the producing linalg ops into the consuming loop tiles;
-///   4. inspect the fused ops and determine whether they have other remaining
+/// Given a linalg op `O`, fusion occurs by:
+///   1. inspecting the linalg ops that write into the views read by `O`. This
+///      uses the SSA value of the views and a simple subview/slice analysis to
+///      determine producer-consumer dependences;
+///   2. greedily fuse the linalg ops that produce subview
+///   3. inspect the fused ops and determine whether they have other remaining
 ///      LinalgOp uses. If not, then erase the original producing linalg op.
 ///
 /// More advanced use cases, analyses as well as profitability heuristics are
@@ -102,7 +101,7 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                         << "\t"
                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
     }
-    // TODO(ntv) opportunities for folding/CSE here rather than build new IR.
+    // TODO(ntv): opportunities for folding/CSE here rather than build new IR.
     clonedViews.push_back(b.create<SubViewOp>(loc, view, viewRanges));
   }
   auto operands = getAssumedNonViewOperands(op);
@@ -115,10 +114,13 @@ struct ViewDimension {
   unsigned dimension;
 };
 
+// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
+// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
+// guarantees at least one such dimension is found. If multiple candidates exist
+// they must agree by construction (i.e. have the same size) and we just return
+// the first one.
 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   auto maps = loopToOperandRangesMaps(op);
-  SmallVector<Value *, 8> clonedViews;
-  clonedViews.reserve(op.getNumInputsAndOutputs());
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
   SmallVector<Value *, 8> ios(op.getInputsAndOutputs());
@@ -142,39 +144,22 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   llvm_unreachable("Expect to be able to extract a view defining loop range");
 }
 
-static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
-                               LinalgOp consumer, LinalgOp tiledConsumer,
-                               OperationFolder &state) {
-  auto maybeConsumerIdx = consumer.getIndexOfInput(producedView);
-  if (!maybeConsumerIdx.hasValue())
-    return llvm::None;
-  unsigned consumerIdx = maybeConsumerIdx.getValue();
-
-  auto maybeProducerIdx = producer.getIndexOfOutput(producedView);
-  if (!maybeProducerIdx.hasValue())
-    return llvm::None;
-  unsigned producerIdx = maybeProducerIdx.getValue();
-
-  // If the view is the same between consumer and tiledConsumer, this means we
-  // don't have loops and the producer cannot be fused at this level.
-  if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx))
-    return llvm::None;
-
-  auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>(
-      tiledConsumer.getInput(consumerIdx)->getDefiningOp());
-
-  // If we don't have a slice, this also means we don't have loops and the
-  // producer cannot be fused at this level.
-  if (!tiledConsumerSubView)
-    return llvm::None;
+static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer,
+                     unsigned consumerIdx, unsigned producerIdx,
+                     OperationFolder &state) {
+  auto subView = dyn_cast_or_null<SubViewOp>(
+      consumer.getInput(consumerIdx)->getDefiningOp());
+  auto slice = dyn_cast_or_null<SliceOp>(
+      consumer.getInput(consumerIdx)->getDefiningOp());
+  assert(subView || slice);
+  (void)subView;
+  (void)slice;
 
   // loopToOperandRangesMaps are permutations-only by construction:
   //   we can always identify a data dimension with a (at least one) loop
   //   dimension.
   AffineMap producerMap =
       loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
-  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
-                    << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                     << ", producer map: " << producerMap << "\n");
 
@@ -187,11 +172,11 @@ static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
   // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index());
+    loopRanges[posInProducerLoop] = subView.getRange(en.index());
   }
 
-  OpBuilder b(tiledConsumer.getOperation());
-  auto loc = tiledConsumer.getLoc();
+  OpBuilder b(consumer.getOperation());
+  auto loc = consumer.getLoc();
   // Iterate over all dimensions. For the dimensions not identified by the
   // producer map for `producerIdx`, we need to explicitly compute the view that
   // defines the loop ranges using the `producer`.
@@ -216,147 +201,123 @@ static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
 // Some of these will be lifted in the future with better analysis.
 static bool isStructurallyFusableProducer(LinalgOp producer, Value *readView,
                                           LinalgOp consumer) {
-  // If a producer has multiple outputs, the analysis needs to take the tiling
-  // of other outputs into account.
-  if (producer.getNumOutputs() != 1)
+  if (producer.getNumOutputs() != 1) {
+    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
     return false;
-  // Until subview analysis is available, same SSA value is required for fusion.
-  if (producer.getOutput(0) != readView)
+  }
+  // Must be a subview or a slice to guarantee there are loops we can fuse into.
+  auto subView = dyn_cast_or_null<SubViewOp>(readView->getDefiningOp());
+  auto slice = dyn_cast_or_null<SliceOp>(readView->getDefiningOp());
+  if (!subView && !slice) {
+    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (not a subview or slice)");
     return false;
-  // No control-flow divergence supported. Only straightline op fusion allowed.
-  // TODO(ntv) allow fusion when a dominance relation exists.
-  if (producer.getOperation()->getBlock() !=
-      consumer.getOperation()->getBlock())
+  }
+  // Only fuse when the producer block dominates.
+  DominanceInfo dom(producer.getOperation());
+  if (!dom.dominates(producer.getOperation()->getBlock(),
+                     consumer.getOperation()->getBlock())) {
+    LLVM_DEBUG(
+        dbgs()
+        << "\nNot structurally fusable (producer block does not dominate)");
     return false;
+  }
   return true;
 }
 
-static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
-  OperationFolder state(f.getContext());
-  DenseSet<Operation *> eraseSet;
-
-  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
-
-  // 1. Record the linalg ops so we can traverse them in reverse order.
-  SmallVector<Operation *, 8> linalgOps;
-  f.walk([&](LinalgOp op) { linalgOps.push_back(op.getOperation()); });
-
-  // 2. Setup the dependences graph, aliases are populated lazily.
-  Aliases aliases;
-  LinalgDependenceGraph G(aliases, linalgOps);
+// Only consider RAW atm.
+struct FusionInfo {
+  LinalgOp originalProducer;
+  LinalgOp fusedProducer;
+};
+static Optional<FusionInfo> fuseProducerOf(LinalgOp consumer,
+                                           unsigned consumerIdx,
+                                           LinalgDependenceGraph &G,
+                                           OperationFolder &state) {
+  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
+                    << *consumer.getOperation());
+  for (auto dependence : G.getDependencesInto(
+           consumer, LinalgDependenceGraph::DependenceType::RAW)) {
+    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
+                      << *dependence.dependentOpView.op << "\n");
+    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
+
+    // Check that the dependence is indeed on the input `consumerIdx` view.
+    auto *readView = dependence.indexingView;
+    if (consumer.getInput(consumerIdx) != readView)
+      continue;
 
-  // 2. For each original linalg op (in reverse order to allow chained
-  // fusions).
-  for (auto *op : llvm::reverse(linalgOps)) {
-    auto consumer = cast<LinalgOp>(op);
-    LLVM_DEBUG(dbgs() << "\n******\nStart processing:\t" << *op);
-    // 3. If marked for erasure, it has already been fused. Skip fusing op.
-    if (eraseSet.count(op) > 0) {
-      LLVM_DEBUG(dbgs() << "\nAlready fused and marked for erasure, skip.");
+    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
+    // whether it is a strict subview of the producer view.
+    auto *producedView = dependence.dependentOpView.view;
+    auto producerIdx = producer.getIndexOfOutput(producedView).getValue();
+    // `consumerIdx` and `producerIdx` exist by construction.
+    LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
+                      << " view: " << *producedView
+                      << " output index: " << producerIdx);
+
+    // Make some simple structural checks that alleviate the need for more
+    // complex analyses.
+    if (!isStructurallyFusableProducer(producer, readView, consumer)) {
+      LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation());
       continue;
     }
 
-    // 4. Apply loop tiling to enable fusion. If unsuccessful, skip fusing op.
-    auto tiledOp = tileLinalgOp(op, tileSizes, state);
-    if (!tiledOp) {
-      LLVM_DEBUG(dbgs() << "\nTile sizes did not produce loops, skip.");
+    // Check for fusion-preventing write that would violate dependences.
+    // `view` is a producer write that cannot bypass any other write or read.
+    if (!G.findCoveringDependences(producer, consumer).empty())
       continue;
-    }
 
-    // 5. For now, we only fuse RAW dependences.
-    SmallVector<Operation *, 8> fusedProducers;
-    SmallVector<Value *, 8> fusedViews;
-    for (auto dependence : G.getDependencesInto(
-             consumer, LinalgDependenceGraph::DependenceType::RAW)) {
-      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
-      LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
-                        << *producer.getOperation() << "\n");
-
-      // a. For now we require fusion on identical SSA values, this allows us to
-      // not worry about partial writes etc.
-      // TODO(ntv) support more elaborate fusion with non identical SSA values.
-      auto *view = dependence.indexingView;
-      if (view != dependence.dependentOpView.view) {
-        LLVM_DEBUG(dbgs() << "\nviews are different SSA values, skip.");
-        continue;
-      }
-      // b. Make some simple structural checks that alleviate the need for more
-      // complex analyses.
-      if (!isStructurallyFusableProducer(producer, view, op)) {
-        LLVM_DEBUG(dbgs() << "\n***Not fusable:\t" << *producer.getOperation());
-        continue;
-      }
-      // c. Check for fusion-preventing write that would violate dependences.
-      // `view` is a producer write that cannot bypass any other write or read.
-      bool preventFusion = false;
-      for (auto *op : G.findCoveringDependences(producer, consumer))
-        if (eraseSet.count(op) == 0) {
-          preventFusion = true;
-          LLVM_DEBUG(dbgs() << "\n***Found fusion preventing dep via: " << *op);
-          break;
-        }
-      if (preventFusion)
-        continue;
-
-      // 6. Try to fuse `producer` just before `tiledOp`.
-      LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
-
-      auto tOp = tiledOp->op;
-      OpBuilder builder(tOp.getOperation());
-      ScopedContext scope(builder, tOp.getLoc());
-      LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
-      auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
-      if (!maybeFusedProducer) {
-        LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
-        continue;
-      }
+    // Fuse `producer` just before `consumer`.
+    OpBuilder builder(consumer.getOperation());
+    ScopedContext scope(builder, consumer.getLoc());
+    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
+    auto fusedProducer =
+        fuse(producedView, producer, consumer, consumerIdx, producerIdx, state);
 
-      fusedProducers.push_back(producer.getOperation());
-      fusedViews.push_back(view);
-    }
+    return FusionInfo{producer, fusedProducer};
+  }
+  return llvm::None;
+}
 
-    // 7. If no fusion occurred, or a drop the outer tiled loop which undoes
-    // everything we did.
-    if (fusedProducers.empty()) {
-      tiledOp->loops[0].erase();
-      continue;
-    }
+static void fuseLinalgOpsGreedily(FuncOp f) {
+  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
 
-    eraseSet.insert(op);
-    eraseSet.insert(fusedProducers.begin(), fusedProducers.end());
-  }
+  OperationFolder state(f.getContext());
+  DenseSet<Operation *> eraseSet;
 
-  for (auto *op : eraseSet)
-    op->erase();
+  // Save original Linalg ops, we only want to make a pass over those.
+  SmallVector<Operation *, 8> linalgOps;
+  f.walk([&](LinalgOp op) { linalgOps.push_back(op); });
 
+  Aliases aliases;
+  LinalgDependenceGraph G(aliases, linalgOps);
+  for (auto *op : llvm::reverse(linalgOps)) {
+    for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
+         consumerIdx < e; ++consumerIdx) {
+      if (auto fusionInfo = fuseProducerOf(op, consumerIdx, G, state))
+        eraseSet.insert(fusionInfo->originalProducer.getOperation());
+    }
+  }
+
+  // The `fuseProducerOf` function performs structural checks and in particular
+  // that no covering read or write exist between the consumer and the producer.
+  // As a consequence, the only fusions that may occur preserve subsequent
+  // dependences and are guaranteed by construction to produce the whole view.
+  // We may thus erase the producer once it is fused.
+  for (auto *e : eraseSet)
+    e->erase();
   LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
 }
 
 namespace {
 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
-  LinalgFusionPass() = default;
-  LinalgFusionPass(ArrayRef<int64_t> sizes);
-
-  void runOnFunction() override { fuseLinalgOps(getFunction(), tileSizes); }
-
-  SmallVector<int64_t, 8> tileSizes;
+  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
 };
 } // namespace
 
-LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
-    : LinalgFusionPass() {
-  if (!sizes.empty())
-    this->tileSizes.assign(sizes.begin(), sizes.end());
-}
-
-std::unique_ptr<OpPassBase<FuncOp>>
-mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) {
-  return std::make_unique<LinalgFusionPass>(tileSizes);
+std::unique_ptr<OpPassBase<FuncOp>> mlir::linalg::createLinalgFusionPass() {
+  return std::make_unique<LinalgFusionPass>();
 }
 
 static PassRegistration<LinalgFusionPass>
-    pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
-      auto pass = std::make_unique<LinalgFusionPass>();
-      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
-      return pass;
-    });
+    pass("linalg-fusion", "Fuse operations in the linalg dialect");
index 29c87c7..1ef4a0b 100644 (file)
@@ -1,14 +1,60 @@
-// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=16 -cse | mlir-opt -linalg-fusion -linalg-fusion-tile-sizes=8 | FileCheck %s
-
+// RUN: mlir-opt %s -linalg-fusion | FileCheck %s
+#map0 = (d0) -> (d0 + 20)
+#map1 = (d0) -> (d0 + 40)
+#map2 = (d0) -> (d0 + 30)
+#map3 = (d0) -> (d0 + 2)
+#map4 = (d0) -> (d0 + 4)
+#map5 = (d0) -> (d0 + 3)
 func @f1(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %c40 = constant 40 : index
+  %c30 = constant 30 : index
+  %c20 = constant 20 : index
+  %0 = linalg.dim %C, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %C, 1 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %0 step %c20 {
+    loop.for %arg6 = %c0 to %2 step %c30 {
+      loop.for %arg7 = %c0 to %1 step %c40 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %9 = linalg.dim %5, 0 : !linalg.view<?x?xf32>
+        %10 = linalg.dim %5, 1 : !linalg.view<?x?xf32>
+        %11 = linalg.dim %7, 1 : !linalg.view<?x?xf32>
+        loop.for %arg8 = %c0 to %9 step %c2 {
+          loop.for %arg9 = %c0 to %11 step %c3 {
+            loop.for %B0 = %c0 to %10 step %c4 {
+              %12 = affine.apply #map3(%arg8)
+              %13 = affine.apply #map4(%B0)
+              %14 = linalg.subview %5[%arg8, %12, %c1, %B0, %13, %c1] : !linalg.view<?x?xf32>
+              %15 = affine.apply #map5(%arg9)
+              %16 = linalg.subview %7[%B0, %13, %c1, %arg9, %15, %c1] : !linalg.view<?x?xf32>
+              %17 = linalg.subview %8[%arg8, %12, %c1, %arg9, %15, %c1] : !linalg.view<?x?xf32>
+              linalg.matmul(%14, %16, %17) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+            }
+          }
+        }
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
 // CHECK-LABEL: func @f1
-//   CHECK-DAG: %[[c8:.*]] = constant 8
-//   CHECK-DAG: %[[c16:.*]] = constant 16
-//       CHECK:   loop.for %{{.*}} step %[[c16]] {
-//       CHECK:     loop.for %{{.*}} %[[c8]] {
-//       CHECK:       linalg.matmul
-//       CHECK:       linalg.matmul
\ No newline at end of file
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//      CHECK: loop.for
+//      CHECK:   loop.for
+//      CHECK:     loop.for
+//      CHECK:      loop.for
+//      CHECK:        loop.for
+//      CHECK:          loop.for
+//      CHECK:            linalg.matmul
+//      CHECK:            linalg.matmul
index c07d0a6..65e8c72 100644 (file)
-// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=0,0,0 | FileCheck %s -check-prefix=FUSE-0
-// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=2 | FileCheck %s -check-prefix=FUSE-2
-// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=2,3 | FileCheck %s -check-prefix=FUSE-23
-// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=2,3,4 | FileCheck %s -check-prefix=FUSE-234
+// RUN: mlir-opt %s -linalg-fusion | FileCheck %s
+
+#map0 = (d0) -> (d0 + 2)
+#map1 = (d0) -> (d0 + 4)
+#map2 = (d0) -> (d0 + 3)
 
 func @f1(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %0 = linalg.dim %A, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %A, 1 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %B, 1 : !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %c1 = constant 1 : index
+  loop.for %arg5 = %c0 to %0 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %1 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
+// CHECK-LABEL: func @f1
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
 // No RAW dependences, the pass does not fuse RAR atm.
-// FUSE-0-LABEL: func @f1
-//   FUSE-0-NOT: loop.for
-// FUSE-2-LABEL: func @f1
-//   FUSE-2-NOT: loop.for
-// FUSE-23-LABEL: func @f1
-//   FUSE-23-NOT: loop.for
-// FUSE-234-LABEL: func @f1
-//   FUSE-234-NOT: loop.for
+//      CHECK: linalg.matmul
+//      CHECK: loop.for
+//      CHECK:   loop.for
+//      CHECK:     loop.for
+//      CHECK:       linalg.matmul
 
 func @f2(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %0 = linalg.dim %C, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %C, 1 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %0 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %1 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f2
-//   FUSE-0-NOT: loop.for
-//
-// FUSE-2-LABEL: func @f2
-//       FUSE-2:   %[[C_0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
-//       FUSE-2:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
-//       FUSE-2:     linalg.matmul
-//       FUSE-2:     linalg.matmul
-//
-// FUSE-23-LABEL: func @f2
-//       FUSE-23:   %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view<?x?xf32>
-//       FUSE-23:   %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view<?x?xf32>
-//       FUSE-23:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
-//       FUSE-23:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:       linalg.matmul
-//
-// FUSE-234-LABEL: func @f2
-//       FUSE-234:   %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
-//       FUSE-234:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-//       FUSE-234:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:         linalg.matmul
+// CHECK-LABEL: func @f2
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//   CHECK-DAG:   %[[C_0:.*]] = linalg.dim %[[C]], 0 : !linalg.view<?x?xf32>
+//   CHECK-DAG:   %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view<?x?xf32>
+//   CHECK-DAG:   %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view<?x?xf32>
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
+//       CHECK:         linalg.matmul
+//       CHECK:         linalg.matmul
 
 func @f3(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%D, %C, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %0 = linalg.dim %D, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %C, 1 : !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %0 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %1 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %D[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %C[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f3
-//   FUSE-0-NOT: loop.for
-//
-// Read to %C does not get tiled along 1st dimension => no fusion
-// FUSE-2-LABEL: func @f3
-//   FUSE-2-NOT:   loop.for
-//
-// FUSE-23-LABEL: func @f3
-//       FUSE-23:   %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view<?x?xf32>
-//       FUSE-23:   %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?xf32>
-//       FUSE-23:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
-//       FUSE-23:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:       linalg.matmul
-//
-// FUSE-234-LABEL: func @f3
-//       FUSE-234:   %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
-//       FUSE-234:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-//       FUSE-234:       loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:         linalg.matmul
+// CHECK-LABEL: func @f3
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//          CHECK:   %[[D_0:.*]] = linalg.dim %[[D]], 0 : !linalg.view<?x?xf32>
+//          CHECK:   %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view<?x?xf32>
+//          CHECK:   %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view<?x?xf32>
+//          CHECK:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
+//          CHECK:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
+//          CHECK:       loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
+//          CHECK:         linalg.matmul
+//          CHECK:         linalg.matmul
 
 func @f4(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %D) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %0 = linalg.dim %C, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %C, 1 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %0 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %1 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f4
-//   FUSE-0-NOT: loop.for
-//
-// Read to %D does not get tiled along 1st dimension => no fusion
-// FUSE-2-LABEL: func @f4
-//       FUSE-2:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//       FUSE-2:   %[[C_0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
-//       FUSE-2:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
-//       FUSE-2:     linalg.matmul
-//       FUSE-2:     linalg.matmul
-//
-// FUSE-23-LABEL: func @f4
-//       FUSE-23:   %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view<?x?xf32>
-//       FUSE-23:   %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view<?x?xf32>
-//       FUSE-23:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
-//       FUSE-23:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:       linalg.matmul
-//
-// FUSE-234-LABEL: func @f4
-//       FUSE-234:   %[[C_0:.*]] = linalg.dim %arg2, 0 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
-//       FUSE-234:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-//       FUSE-234:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:         linalg.matmul
+// CHECK-LABEL: func @f4
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//          CHECK:   %[[C_0:.*]] = linalg.dim %[[C]], 0 : !linalg.view<?x?xf32>
+//          CHECK:   %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view<?x?xf32>
+//          CHECK:   %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view<?x?xf32>
+//          CHECK:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
+//          CHECK:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
+//          CHECK:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
+// Fuse D then fuse C, no false dependence prevent it.
+//          CHECK:         linalg.matmul
+//          CHECK:         linalg.matmul
+//          CHECK:         linalg.matmul
 
 func @f5(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %0 = linalg.dim %B, 1 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %D, 0 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
   linalg.matmul(%C, %B, %D) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%D, %B, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %1 step %c2 {
+    loop.for %arg6 = %c0 to %0 step %c3 {
+      loop.for %arg7 = %c0 to %2 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %D[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f5
-//   FUSE-0-NOT: loop.for
-//
-// FUSE-2-LABEL: func @f5
-//       FUSE-2:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//       FUSE-2:   %[[D_0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
-//       FUSE-2:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
-//       FUSE-2:     linalg.matmul
-//       FUSE-2:     linalg.matmul
-//
-// FUSE-23-LABEL: func @f5
-//       FUSE-23:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//       FUSE-23:   %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view<?x?xf32>
-//       FUSE-23:   %[[B_1:.*]] = linalg.dim %arg1, 1 : !linalg.view<?x?xf32>
-//       FUSE-23:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
-//       FUSE-23:     loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} {
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:       linalg.matmul
-//
-// FUSE-234-LABEL: func @f5
-//       FUSE-234:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//       FUSE-234:   %[[D_0:.*]] = linalg.dim %arg3, 0 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[D_1:.*]] = linalg.dim %arg3, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[B_1:.*]] = linalg.dim %arg1, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
-//       FUSE-234:     loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} {
-//       FUSE-234:       loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:         linalg.matmul
+// CHECK-LABEL: func @f5
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//      CHECK-DAG:   %[[B_1:.*]] = linalg.dim %[[B]], 1 : !linalg.view<?x?xf32>
+//      CHECK-DAG:   %[[D_0:.*]] = linalg.dim %[[D]], 0 : !linalg.view<?x?xf32>
+//      CHECK-DAG:   %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view<?x?xf32>
+// Don't fuse C due to false dependence, note that this is too conservative though.
+//          CHECK:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
+//          CHECK:   loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
+//          CHECK:     loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} {
+//          CHECK:       loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
+//          CHECK:         linalg.matmul
+//          CHECK:         linalg.matmul
 
 func @f6(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %0 = linalg.dim %C, 1 : !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%A, %C, %D) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  linalg.matmul(%A, %C, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %1 = linalg.dim %C, 0 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %1 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %0 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// Write to %C can not be fused because the 2 RAW are not compatible.
-// The current algorithm just bails out on fusion in the case of any write-based
-// interleaved dependence.
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f6
-//   FUSE-0-NOT: loop.for
-//
-// Read to D is not tiled along 1st dimension => no fusion
-// FUSE-2-LABEL: func @f6
-//   FUSE-2-NOT:   loop.for
-//
-// FUSE-23-LABEL: func @f6
-//
-// FUSE-234-LABEL: func @f6
+// CHECK-LABEL: func @f6
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+// Cannot fuse C due to interleaved read of C that would be bypassed.
+// Cannot fuse E (WAW).
+//   CHECK:   linalg.matmul
+//   CHECK:   linalg.matmul
+//   CHECK:   loop.for
+//   CHECK:     loop.for
+//   CHECK:       loop.for
+//   CHECK:         linalg.matmul
+// CHECK-NOT:       linalg.matmul
 
 func @f7(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %0 = linalg.dim %A, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %A, 1 : !linalg.view<?x?xf32>
+  %2 = linalg.dim %C, 1 : !linalg.view<?x?xf32>
+  %3 = linalg.dim %C, 0 : !linalg.view<?x?xf32>
+  %4 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
   linalg.matmul(%A, %C, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%A, %C, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %0 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %1 step %c4 {
+        %5 = affine.apply #map0(%arg5)
+        %6 = affine.apply #map1(%arg7)
+        %7 = linalg.subview %A[%arg5, %5, %c1, %arg7, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = affine.apply #map2(%arg6)
+        %9 = linalg.subview %C[%arg7, %6, %c1, %arg6, %8, %c1] : !linalg.view<?x?xf32>
+        %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%7, %9, %10) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
+  loop.for %arg5 = %c0 to %3 step %c2 {
+    loop.for %arg6 = %c0 to %4 step %c3 {
+      loop.for %arg7 = %c0 to %2 step %c4 {
+        %5 = affine.apply #map0(%arg5)
+        %6 = affine.apply #map1(%arg7)
+        %7 = linalg.subview %C[%arg5, %5, %c1, %arg7, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = affine.apply #map2(%arg6)
+        %9 = linalg.subview %D[%arg7, %6, %c1, %arg6, %8, %c1] : !linalg.view<?x?xf32>
+        %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%7, %9, %10) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// The only fusion that respects dependences is the write to %C into the
-// immediately following read.
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f7
-//   FUSE-0-NOT: loop.for
-//
-// Read to %C (in 3rd matmul) is not tiled along 1st dimension => no fusion
-// FUSE-2-LABEL: func @f7
-//   FUSE-2-NOT:   loop.for
-//
-// FUSE-23-LABEL: func @f7
-//       FUSE-23:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//       FUSE-23:   %[[A_0:.*]] = linalg.dim %arg0, 0 : !linalg.view<?x?xf32>
-//       FUSE-23:   %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?xf32>
-//       FUSE-23:   loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
-//       FUSE-23:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:       linalg.matmul
-//       FUSE-23:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//
-// FUSE-234-LABEL: func @f7
-//       FUSE-234:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-//       FUSE-234:   %[[A_0:.*]] = linalg.dim %arg0, 0 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[A_1:.*]] = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   %[[C_1:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?xf32>
-//       FUSE-234:   loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
-//       FUSE-234:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
-//       FUSE-234:       loop.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:         linalg.matmul
-//       FUSE-234:   linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-LABEL: func @f7
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//       CHECK:   %[[A_0:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
+//       CHECK:   %[[A_1:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
+//       CHECK:   %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view<?x?xf32>
+//       CHECK:   %[[C_0:.*]] = linalg.dim %[[C]], 0 : !linalg.view<?x?xf32>
+//       CHECK:   %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view<?x?xf32>
+//       CHECK:   linalg.matmul(%[[A]], %[[C]], %[[E]])
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[A_1]] step %{{.*}} {
+//       CHECK:         linalg.matmul
+//       CHECK:         linalg.matmul
+//       CHECK:   loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} {
+//       CHECK:     loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
+//       CHECK:       loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} {
+//       CHECK:         linalg.matmul
+//   CHECK-NOT:         linalg.matmul
 
 func @f8(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c4 = constant 4 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  %0 = linalg.dim %A, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %A, 1 : !linalg.view<?x?xf32>
   linalg.matmul(%A, %C, %D) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
   linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matmul(%A, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %2 = linalg.dim %D, 1 : !linalg.view<?x?xf32>
+  loop.for %arg5 = %c0 to %0 step %c2 {
+    loop.for %arg6 = %c0 to %2 step %c3 {
+      loop.for %arg7 = %c0 to %1 step %c4 {
+        %3 = affine.apply #map0(%arg5)
+        %4 = affine.apply #map1(%arg7)
+        %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view<?x?xf32>
+        %6 = affine.apply #map2(%arg6)
+        %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view<?x?xf32>
+        linalg.matmul(%5, %7, %8) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }
+    }
+  }
   return %E : !linalg.view<?x?xf32>
 }
-// In this example, %D can never be fused because the WAR on %C would be violated
-// No tiling => no fusion
-// FUSE-0-LABEL: func @f8
-//   FUSE-0-NOT: loop.for
-//
-// FUSE-2-LABEL: func @f8
-//   FUSE-2-NOT:   loop.for
-//
-// FUSE-23-LABEL: func @f8
-//   FUSE-23-NOT:   loop.for
-//
-// FUSE-234-LABEL: func @f8
-//   FUSE-234-NOT:   loop.for
+// CHECK-LABEL: func @f8
+//       CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
+//   CHECK:   linalg.matmul
+//   CHECK:   linalg.matmul
+//   CHECK:   loop.for
+//   CHECK:     loop.for
+//   CHECK:       loop.for
+//   CHECK:         linalg.matmul
+// CHECK-NOT:       linalg.matmul
 
 #id_2d = (i, j) -> (i, j)
 #pointwise_2d_trait = {
@@ -243,51 +328,39 @@ func @f8(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<
   n_loop_types = [2, 0, 0],
   n_views = [2, 1]
 }
-
-func @pointwise(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?xf32>,
-                %arg2: !linalg.view<?x?xf32>, %arg3: !linalg.view<?x?xf32>) {
-  linalg.generic #pointwise_2d_trait %arg0, %arg0, %arg1 {
-  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
-    %4 = addf %arg4, %arg5 : f32
-    linalg.yield %4 : f32
-  }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.generic #pointwise_2d_trait %arg1, %arg2, %arg3 {
-  ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
-    %4 = mulf %arg4, %arg5 : f32
-    linalg.yield %4 : f32
+func @pointwise(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>) {
+  %c1 = constant 1 : index
+  %c0 = constant 0 : index
+  %c3 = constant 3 : index
+  %c2 = constant 2 : index
+  linalg.generic #pointwise_2d_trait %A, %A, %B {
+  ^bb0(%E: f32, %arg5: f32, %arg6: f32):   // no predecessors
+    %2 = addf %E, %arg5 : f32
+    linalg.yield %2 : f32
   }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %0 = linalg.dim %B, 0 : !linalg.view<?x?xf32>
+  %1 = linalg.dim %B, 1 : !linalg.view<?x?xf32>
+  loop.for %E = %c0 to %0 step %c2 {
+    loop.for %arg5 = %c0 to %1 step %c3 {
+      %2 = affine.apply #map0(%E)
+      %3 = affine.apply #map1(%arg5)
+      %4 = linalg.subview %B[%E, %2, %c1, %arg5, %3, %c1] : !linalg.view<?x?xf32>
+      %5 = linalg.subview %C[%E, %2, %c1, %arg5, %3, %c1] : !linalg.view<?x?xf32>
+      %6 = linalg.subview %D[%E, %2, %c1, %arg5, %3, %c1] : !linalg.view<?x?xf32>
+      linalg.generic #pointwise_2d_trait %4, %5, %6 {
+      ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):       // no predecessors
+        %7 = mulf %arg6, %arg7 : f32
+        linalg.yield %7 : f32
+      }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+    }
+  }
   return
 }
-// No tiling => no fusion
-// FUSE-0-LABEL: func @pointwise
-//   FUSE-0-NOT: loop.for
-//       FUSE-0: linalg.generic
-//       FUSE-0:   addf
-//       FUSE-0: linalg.generic
-//       FUSE-0:   mulf
-//
-// FUSE-2-LABEL: func @pointwise
-//       FUSE-2:   loop.for
-//   FUSE-2-NOT:   loop.for
-//       FUSE-2:     linalg.generic
-//       FUSE-2:       addf
-//       FUSE-2:     linalg.generic
-//       FUSE-2:       mulf
-//
-// FUSE-23-LABEL: func @pointwise
-//       FUSE-23:   loop.for
-//       FUSE-23:     loop.for
-//   FUSE-23-NOT:   loop.for
-//       FUSE-23:       linalg.generic
-//       FUSE-23:         addf
-//       FUSE-23:       linalg.generic
-//       FUSE-23:         mulf
-//
-// FUSE-234-LABEL: func @pointwise
-//       FUSE-234:   loop.for
-//       FUSE-234:     loop.for
-//   FUSE-234-NOT:   loop.for
-//       FUSE-234:       linalg.generic
-//       FUSE-234:         addf
-//       FUSE-234:       linalg.generic
-//       FUSE-234:         mulf
+// CHECK-LABEL: func @pointwise
+//       CHECK:   loop.for
+//       CHECK:     loop.for
+//   CHECK-NOT:   loop.for
+//       CHECK:       linalg.generic
+//       CHECK:         addf
+//       CHECK:       linalg.generic
+//       CHECK:         mulf