[MLIR] Remove unnecessary Block argument on MemRefDependenceGraph::init
authorUday Bondhugula <uday@polymagelabs.com>
Tue, 31 Jan 2023 01:22:45 +0000 (06:52 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Tue, 31 Jan 2023 01:22:53 +0000 (06:52 +0530)
Remove unnecessary Block argument on MemRefDependenceGraph::init.
`block` is already a field on MDG.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D142597

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

index db39a83..79e8949 100644 (file)
@@ -210,9 +210,9 @@ public:
 
   MemRefDependenceGraph(Block &block) : block(block) {}
 
-  // Initializes the dependence graph based on operations in 'f'.
+  // Initializes the dependence graph based on operations in `block'.
   // Returns true on success, false otherwise.
-  bool init(Block *block);
+  bool init();
 
   // Returns the graph node for 'id'.
   Node *getNode(unsigned id) {
@@ -771,16 +771,14 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
 
 // Initializes the data dependence graph by walking operations in `block`.
 // Assigns each node in the graph a node id based on program order in 'f'.
-// TODO: Add support for taking a Block arg to construct the
-// dependence graph at a different depth.
-bool MemRefDependenceGraph::init(Block *block) {
+bool MemRefDependenceGraph::init() {
   LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
   // Map from a memref to the set of ids of the nodes that have ops accessing
   // the memref.
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
   DenseMap<Operation *, unsigned> forToNodeMap;
-  for (Operation &op : *block) {
+  for (Operation &op : block) {
     if (auto forOp = dyn_cast<AffineForOp>(op)) {
       // Create graph node 'id' to represent top-level 'forOp' and record
       // all loads and store accesses it contains.
@@ -859,8 +857,8 @@ bool MemRefDependenceGraph::init(Block *block) {
     for (Value value : opInst->getResults()) {
       for (Operation *user : value.getUsers()) {
         // Ignore users outside of the block.
-        if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() !=
-            block)
+        if (block.getParent()->findAncestorOpInRegion(*user)->getBlock() !=
+            &block)
           continue;
         SmallVector<AffineForOp, 4> loops;
         getAffineForIVs(*user, &loops);
@@ -1132,7 +1130,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   });
 
   if (maxLegalFusionDepth == 0) {
-    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n");
+    LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n");
     return false;
   }
 
@@ -1170,7 +1168,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
   if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
     LLVM_DEBUG(llvm::dbgs()
-               << "Unable to compute MemRefRegion for source operation\n.");
+               << "Unable to compute MemRefRegion for source operation\n");
     return false;
   }
 
@@ -1195,7 +1193,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
                               dstLoopNestStats, slice,
                               &fusedLoopNestComputeCost)) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n.");
+      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
       continue;
     }
 
@@ -1565,7 +1563,7 @@ public:
         if (!srcEscapingMemRefs.empty() &&
             hasNonAffineUsersOnThePath(srcId, dstId, mdg)) {
           LLVM_DEBUG(llvm::dbgs()
-                     << "Can't fuse: non-affine users in between the loops\n.");
+                     << "Can't fuse: non-affine users in between the loops\n");
           continue;
         }
 
@@ -2032,8 +2030,10 @@ public:
 /// Run fusion on `block`.
 void LoopFusion::runOnBlock(Block *block) {
   MemRefDependenceGraph g(*block);
-  if (!g.init(block))
+  if (!g.init()) {
+    LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n");
     return;
+  }
 
   std::optional<unsigned> fastMemorySpaceOpt;
   if (fastMemorySpace.hasValue())