[MLIR][Affine] Make fusion helper check method significantly more efficient
authorUday Bondhugula <uday@polymagelabs.com>
Wed, 28 Dec 2022 20:06:21 +0000 (01:36 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Wed, 28 Dec 2022 20:20:29 +0000 (01:50 +0530)
The `hasDependencePath` method in affine fusion is quite inefficient as
it does a DFS on the complete graph for what is a small part of the
checks before fusion can be performed. Make this efficient by using the
fact that the nodes involved are all at the top-level of the same block.
With this change, for large graphs with about 10,000 nodes, the check
runs in a few seconds instead of not terminating even in a few hours.

This is NFC from a functionality standpoint; it only leads to an
improvement in pass running time on large IR.

Differential Revision: https://reviews.llvm.org/D140522

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

index 90e2f75..5628b82 100644 (file)
@@ -328,11 +328,13 @@ public:
   }
 
   // Returns true if there is a path in the dependence graph from node 'srcId'
-  // to node 'dstId'. Returns false otherwise.
+  // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
+  // operations that the edges connected are expected to be from the same block.
   bool hasDependencePath(unsigned srcId, unsigned dstId) {
     // Worklist state is: <node-id, next-output-edge-index-to-visit>
     SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
     worklist.push_back({srcId, 0});
+    Operation *dstOp = getNode(dstId)->op;
     // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
     while (!worklist.empty()) {
       auto &idAndIndex = worklist.back();
@@ -350,8 +352,12 @@ public:
       Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
       // Increment next output edge index for 'idAndIndex'.
       ++idAndIndex.second;
-      // Add node at 'edge.id' to worklist.
-      worklist.push_back({edge.id, 0});
+      // Add node at 'edge.id' to the worklist. We don't need to consider
+      // nodes that are "after" dstId in the containing block; one can't have a
+      // path to `dstId` from any of those nodes.
+      bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
+      if (!afterDst && edge.id != idAndIndex.first)
+        worklist.push_back({edge.id, 0});
     }
     return false;
   }