Enable multi-level Linalg fusion
authorNicolas Vasilache <ntv@google.com>
Wed, 24 Jul 2019 12:10:26 +0000 (05:10 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 24 Jul 2019 12:10:54 +0000 (05:10 -0700)
This CL adds support for SubViewOp in the alias analysis to permit multiple Linalg fusion passes to compose. The debugging messages are also improved for better readability. The readability benefits came in handy when tracking this issue.

A 2-level fusion test is added to capture the new behavior.

PiperOrigin-RevId: 259720246

mlir/lib/Linalg/Analysis/DependenceAnalysis.cpp
mlir/lib/Linalg/Transforms/Fusion.cpp
mlir/test/Linalg/fusion-2-level.mlir [new file with mode: 0644]

index 10b5284..f44bea3 100644 (file)
@@ -44,15 +44,25 @@ Value *Aliases::find(Value *v) {
            "Buffer or block argument expected");
     return it->getSecond();
   }
-  if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
-    auto it = aliases.insert(std::make_pair(v, find(slice.getBaseView())));
-    return it.first->second;
-  }
-  if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
-    auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
-    return it.first->second;
+
+  while (true) {
+    if (isa<BlockArgument>(v))
+      return v;
+    if (auto slice = dyn_cast_or_null<SliceOp>(v->getDefiningOp())) {
+      auto it = aliases.insert(std::make_pair(v, find(slice.getBaseView())));
+      return it.first->second;
+    }
+    if (auto view = dyn_cast_or_null<ViewOp>(v->getDefiningOp())) {
+      auto it = aliases.insert(std::make_pair(v, view.getSupportingBuffer()));
+      return it.first->second;
+    }
+    if (auto view = dyn_cast_or_null<SubViewOp>(v->getDefiningOp())) {
+      v = view.getView();
+      continue;
+    }
+    llvm::errs() << "View alias analysis reduces to: " << *v << "\n";
+    llvm_unreachable("unsupported view alias case");
   }
-  llvm_unreachable("unsupported view alias case");
 }
 
 LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases,
index 480d19f..4864f39 100644 (file)
@@ -98,7 +98,8 @@ static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
       // loopToOperandRangesMaps are permutations-only.
       unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
       viewRanges[d] = loopRanges[loopPos];
-      LLVM_DEBUG(dbgs() << "i,j: " << en.index() << ", " << en2.index() << "\t"
+      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
+                        << "\t"
                         << "loopPos: " << loopPos << "\t" << viewRanges[d]);
     }
     // TODO(ntv) opportunities for folding/CSE here rather than build new IR.
@@ -124,12 +125,18 @@ static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
   for (auto en : llvm::enumerate(ios)) {
     unsigned idx = en.index();
     auto map = maps[idx];
-    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
+    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
+    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
     Value *view = en.value();
     SmallVector<Value *, 8> viewRanges(map.getNumResults(), nullptr);
     for (auto en2 : llvm::enumerate(map.getResults())) {
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition())
+      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
+                          << "\n");
+        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << *view
+                          << "\n");
         return ViewDimension{view, static_cast<unsigned>(en2.index())};
+      }
     }
   }
   llvm_unreachable("Expect to be able to extract a view defining loop range");
@@ -148,44 +155,57 @@ static Optional<LinalgOp> fuse(Value *producedView, LinalgOp producer,
     return llvm::None;
   unsigned producerIdx = maybeProducerIdx.getValue();
 
-  auto sliceOp = dyn_cast_or_null<SubViewOp>(
+  // If the view is the same between consumer and tiledConsumer, this means we
+  // don't have loops and the producer cannot be fused at this level.
+  if (consumer.getInput(consumerIdx) == tiledConsumer.getInput(consumerIdx))
+    return llvm::None;
+
+  auto tiledConsumerSubView = dyn_cast_or_null<SubViewOp>(
       tiledConsumer.getInput(consumerIdx)->getDefiningOp());
+
   // If we don't have a slice, this also means we don't have loops and the
   // producer cannot be fused at this level.
-  if (!sliceOp)
+  if (!tiledConsumerSubView)
     return llvm::None;
 
+  // loopToOperandRangesMaps are permutations-only by construction:
+  //   we can always identify a data dimension with a (at least one) loop
+  //   dimension.
   AffineMap producerMap =
       loopToOperandRangesMaps(producer)[producer.getNumInputs() + producerIdx];
-  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << "\tmap: "
+  LLVM_DEBUG(dbgs() << "Consumer Idx: " << consumerIdx << ", consumer map: "
                     << loopToOperandRangesMaps(consumer)[consumerIdx] << "\n");
   LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
-                    << "\tmap: " << producerMap << "\n");
+                    << ", producer map: " << producerMap << "\n");
 
   unsigned nPar = producer.getNumParallelLoops();
   unsigned nRed = producer.getNumReductionLoops();
   unsigned nWin = producer.getNumWindowLoops();
   SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);
-  DenseSet<unsigned> fromSlice;
+
+  // Iterate over dimensions identified by the producer map for `producerIdx`.
+  // This defines a subset of the loop ranges that we need to complete later.
   for (auto en : llvm::enumerate(producerMap.getResults())) {
-    // loopToOperandRangesMaps are permutations-only.
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    loopRanges[posInProducerLoop] = sliceOp.getRange(en.index());
-    fromSlice.insert(posInProducerLoop);
+    loopRanges[posInProducerLoop] = tiledConsumerSubView.getRange(en.index());
   }
 
   OpBuilder b(tiledConsumer.getOperation());
   auto loc = tiledConsumer.getLoc();
-  for (unsigned i = 0; i < loopRanges.size(); ++i) {
-    if (fromSlice.count(i))
-      LLVM_DEBUG(llvm::dbgs() << "LR: " << loopRanges[i] << "\n");
+  // Iterate over all dimensions. For the dimensions not identified by the
+  // producer map for `producerIdx`, we need to explicitly compute the view that
+  // defines the loop ranges using the `producer`.
+  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
+    if (loopRanges[i].min)
+      LLVM_DEBUG(llvm::dbgs()
+                 << "existing LoopRange: " << loopRanges[i] << "\n");
     else {
       auto viewDim = getViewDefiningLoopRange(producer, i);
       loopRanges[i] = SubViewOp::Range{
           state.create<ConstantIndexOp>(b, loc, 0),
           linalg::intrinsics::dim(viewDim.view, viewDim.dimension),
           state.create<ConstantIndexOp>(b, loc, 1)};
-      LLVM_DEBUG(llvm::dbgs() << "new LR: " << loopRanges[i] << "\n");
+      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
     }
   }
 
@@ -215,6 +235,8 @@ static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
   OperationFolder state;
   DenseSet<Operation *> eraseSet;
 
+  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));
+
   // 1. Record the linalg ops so we can traverse them in reverse order.
   SmallVector<Operation *, 8> linalgOps;
   f.walk<LinalgOp>(
@@ -249,7 +271,7 @@ static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
              consumer, LinalgDependenceGraph::DependenceType::RAW)) {
       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
       LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
-                        << *producer.getOperation());
+                        << *producer.getOperation() << "\n");
 
       // a. For now we require fusion on identical SSA values, this allows us to
       // not worry about partial writes etc.
@@ -278,9 +300,12 @@ static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
         continue;
 
       // 6. Try to fuse `producer` just before `tiledOp`.
+      LLVM_DEBUG(f.print(dbgs() << "\nBefore tiledOp-fusion: \n"));
+
       auto tOp = tiledOp->op;
       OpBuilder builder(tOp.getOperation());
       ScopedContext scope(builder, tOp.getLoc());
+      LLVM_DEBUG(dbgs() << "Try fuse into tiled consumer: " << *tOp << "\n");
       auto maybeFusedProducer = fuse(view, producer, op, tOp, state);
       if (!maybeFusedProducer) {
         LLVM_DEBUG(dbgs() << "\nFusion did not do anything, skip.");
@@ -310,7 +335,7 @@ static void fuseLinalgOps(FuncOp f, ArrayRef<int64_t> tileSizes) {
 
 namespace {
 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
-  LinalgFusionPass();
+  LinalgFusionPass() = default;
   LinalgFusionPass(ArrayRef<int64_t> sizes);
 
   void runOnFunction() { fuseLinalgOps(getFunction(), tileSizes); }
@@ -319,9 +344,6 @@ struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
 };
 } // namespace
 
-LinalgFusionPass::LinalgFusionPass()
-    : tileSizes(clTileSizes.begin(), clTileSizes.end()) {}
-
 LinalgFusionPass::LinalgFusionPass(ArrayRef<int64_t> sizes)
     : LinalgFusionPass() {
   if (!sizes.empty())
@@ -334,4 +356,8 @@ mlir::linalg::createLinalgFusionPass(ArrayRef<int64_t> tileSizes) {
 }
 
 static PassRegistration<LinalgFusionPass>
-    pass("linalg-fusion", "Fuse operations in the linalg dialect");
+    pass("linalg-fusion", "Fuse operations in the linalg dialect", [] {
+      auto *pass = new LinalgFusionPass();
+      pass->tileSizes.assign(clTileSizes.begin(), clTileSizes.end());
+      return pass;
+    });
diff --git a/mlir/test/Linalg/fusion-2-level.mlir b/mlir/test/Linalg/fusion-2-level.mlir
new file mode 100644 (file)
index 0000000..29c87c7
--- /dev/null
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -linalg-fusion -linalg-fusion-tile-sizes=16 -cse | mlir-opt -linalg-fusion -linalg-fusion-tile-sizes=8 | FileCheck %s
+
+func @f1(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>, %E: !linalg.view<?x?xf32>) -> !linalg.view<?x?xf32> {
+  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  linalg.matmul(%C, %D, %E) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  return %E : !linalg.view<?x?xf32>
+}
+// CHECK-LABEL: func @f1
+//   CHECK-DAG: %[[c8:.*]] = constant 8
+//   CHECK-DAG: %[[c16:.*]] = constant 16
+//       CHECK:   loop.for %{{.*}} step %[[c16]] {
+//       CHECK:     loop.for %{{.*}} %[[c8]] {
+//       CHECK:       linalg.matmul
+//       CHECK:       linalg.matmul
\ No newline at end of file