NFC. Refactor/update some affine fusion pass code for readability
authorUday Bondhugula <uday@polymagelabs.com>
Sun, 11 Dec 2022 19:12:48 +0000 (00:42 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Sun, 11 Dec 2022 19:48:17 +0000 (01:18 +0530)
NFC. Refactor some affine fusion pass code for readability. Some of its
methods are too long. This is the first among some NFC changes before new
features/related updates are posted. Add missing code comments at a couple of
places.

Differential Revision: https://reviews.llvm.org/D139255

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp

index 63b7ad7..df6b30e 100644 (file)
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Affine/Passes.h"
 
-#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
@@ -224,7 +223,7 @@ public:
   // Returns the graph node for 'forOp'.
   Node *getForOpNode(AffineForOp forOp) {
     for (auto &idAndNode : nodes)
-      if (idAndNode.second.op == forOp.getOperation())
+      if (idAndNode.second.op == forOp)
         return &idAndNode.second;
     return nullptr;
   }
@@ -711,27 +710,35 @@ gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
                                 producerConsumerMemrefs);
 }
 
+/// A memref escapes the function if either:
+///   1. it is a function argument, or
+///   2. it is used by a non-affine op (e.g., std load/store, std
+///   call, etc.)
+/// FIXME: Support alias creating ops like memref view ops.
+static bool isEscapingMemref(Value memref) {
+  // Check if 'memref' escapes because it's a block argument.
+  if (memref.isa<BlockArgument>())
+    return true;
+
+  // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
+  // call op, etc.). This already covers aliases created from this.
+  for (Operation *user : memref.getUsers())
+    if (!isa<AffineMapAccessInterface>(*user))
+      return true;
+  return false;
+}
+
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
-/// that escape the function. A memref escapes the function if either:
-///   1. It's a function argument, or
-///   2. It's used by a non-affine op (e.g., std load/store, std call, etc.)
+/// that escape the function.
 void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
                            DenseSet<Value> &escapingMemRefs) {
   auto *node = mdg->getNode(id);
-  for (auto *storeOpInst : node->stores) {
-    auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+  for (Operation *storeOp : node->stores) {
+    auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
     if (escapingMemRefs.count(memref))
       continue;
-    // Check if 'memref' escapes because it's a block argument.
-    if (memref.isa<BlockArgument>()) {
+    if (isEscapingMemref(memref))
       escapingMemRefs.insert(memref);
-      continue;
-    }
-    // Check if 'memref' escapes through a non-affine op (e.g., std load/store,
-    // call op, etc.).
-    for (Operation *user : memref.getUsers())
-      if (!isa<AffineMapAccessInterface>(*user))
-        escapingMemRefs.insert(memref);
   }
 }
 
@@ -743,6 +750,8 @@ void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
 // dependence graph at a different depth.
 bool MemRefDependenceGraph::init(func::FuncOp f) {
   LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n");
+  // Map from a memref to the set of ids of the nodes that have ops accessing
+  // the memref.
   DenseMap<Value, SetVector<unsigned>> memrefAccesses;
 
   // TODO: support multi-block functions.
@@ -832,8 +841,8 @@ bool MemRefDependenceGraph::init(func::FuncOp f) {
         getLoopIVs(*user, &loops);
         if (loops.empty())
           continue;
-        assert(forToNodeMap.count(loops[0].getOperation()) > 0);
-        unsigned userLoopNestId = forToNodeMap[loops[0].getOperation()];
+        assert(forToNodeMap.count(loops[0]) > 0);
+        unsigned userLoopNestId = forToNodeMap[loops[0]];
         addEdge(node.id, userLoopNestId, value);
       }
     }
@@ -866,7 +875,7 @@ bool MemRefDependenceGraph::init(func::FuncOp f) {
 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   assert(isa<AffineForOp>(node->op));
   AffineForOp newRootForOp = sinkSequentialLoops(cast<AffineForOp>(node->op));
-  node->op = newRootForOp.getOperation();
+  node->op = newRootForOp;
 }
 
 //  TODO: improve/complete this when we have target data.
@@ -893,7 +902,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
                                  unsigned dstLoopDepth,
                                  Optional<unsigned> fastMemorySpace,
                                  uint64_t localBufSizeThreshold) {
-  auto *forInst = forOp.getOperation();
+  Operation *forInst = forOp.getOperation();
 
   // Create builder to insert alloc op just before 'forOp'.
   OpBuilder b(forInst);
@@ -1418,6 +1427,10 @@ public:
     eraseUnusedMemRefAllocations();
   }
 
+  /// Visit each node in the graph, and for each node, attempt to fuse it with
+  /// producer-consumer candidates. No fusion is performed when producers with a
+  /// user count greater than `maxSrcUserCount` for any of the memrefs involved
+  /// are encountered.
   void fuseProducerConsumerNodes(unsigned maxSrcUserCount) {
     LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n");
     init();
@@ -1628,8 +1641,8 @@ public:
                      << dstAffineForOp << "\n");
 
           // Move 'dstAffineForOp' before 'insertPointInst' if needed.
-          if (fusedLoopInsPoint != dstAffineForOp.getOperation())
-            dstAffineForOp.getOperation()->moveBefore(fusedLoopInsPoint);
+          if (fusedLoopInsPoint != dstAffineForOp)
+            dstAffineForOp->moveBefore(fusedLoopInsPoint);
 
           // Update edges between 'srcNode' and 'dstNode'.
           mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs,
@@ -1642,8 +1655,7 @@ public:
             dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
               Value storeMemRef = storeOp.getMemRef();
               if (privateMemrefs.count(storeMemRef) > 0)
-                privateMemRefToStores[storeMemRef].push_back(
-                    storeOp.getOperation());
+                privateMemRefToStores[storeMemRef].push_back(storeOp);
             });
 
             // Replace original memrefs with private memrefs. Note that all the
@@ -1672,7 +1684,7 @@ public:
 
           // Collect dst loop stats after memref privatization transformation.
           LoopNestStateCollector dstLoopCollector;
-          dstLoopCollector.collect(dstAffineForOp.getOperation());
+          dstLoopCollector.collect(dstAffineForOp);
 
           // Clear and add back loads and stores.
           mdg->clearNodeLoadAndStores(dstNode->id);
@@ -1798,7 +1810,7 @@ public:
 
       auto dstForInst = cast<AffineForOp>(dstNode->op);
       // Update operation position of fused loop nest (if needed).
-      if (insertPointInst != dstForInst.getOperation()) {
+      if (insertPointInst != dstForInst) {
         dstForInst->moveBefore(insertPointInst);
       }
       // Update data dependence graph state post fusion.
@@ -1939,7 +1951,7 @@ public:
     // Collect dst loop stats after memref privatization transformation.
     auto dstForInst = cast<AffineForOp>(dstNode->op);
     LoopNestStateCollector dstLoopCollector;
-    dstLoopCollector.collect(dstForInst.getOperation());
+    dstLoopCollector.collect(dstForInst);
     // Clear and add back loads and stores
     mdg->clearNodeLoadAndStores(dstNode->id);
     mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
index 7378418..889c907 100644 (file)
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BlockAndValueMapping.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Operation.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -113,7 +106,7 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
         }
         return WalkResult::advance();
       }
-      for (auto value : op->getResults()) {
+      for (Value value : op->getResults()) {
         for (Operation *user : value.getUsers()) {
           SmallVector<AffineForOp, 4> loops;
           // Check if any loop in loop nest surrounding 'user' is 'opB'.
@@ -137,15 +130,12 @@ static Operation *getLastDependentOpInRange(Operation *opA, Operation *opB) {
 // dependences. Returns nullptr if no such insertion point is found.
 static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
                                                  AffineForOp dstForOp) {
-  bool isSrcForOpBeforeDstForOp =
-      srcForOp->isBeforeInBlock(dstForOp.getOperation());
+  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
 
-  auto *firstDepOpA =
-      getFirstDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
-  auto *lastDepOpB =
-      getLastDependentOpInRange(forOpA.getOperation(), forOpB.getOperation());
+  Operation *firstDepOpA = getFirstDependentOpInRange(forOpA, forOpB);
+  Operation *lastDepOpB = getLastDependentOpInRange(forOpA, forOpB);
   // Block:
   //      ...
   //  |-- opA
@@ -170,7 +160,7 @@ static Operation *getFusedLoopNestInsertionPoint(AffineForOp srcForOp,
   }
   // No dependences from 'opA' to operation in range ('opA', 'opB'), return
   // 'opB' insertion point.
-  return forOpB.getOperation();
+  return forOpB;
 }
 
 // Gathers all load and store ops in loop nest rooted at 'forOp' into
@@ -281,8 +271,7 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   }
 
   // Check if 'srcForOp' precedes 'dstForOp' in 'block'.
-  bool isSrcForOpBeforeDstForOp =
-      srcForOp->isBeforeInBlock(dstForOp.getOperation());
+  bool isSrcForOpBeforeDstForOp = srcForOp->isBeforeInBlock(dstForOp);
   // 'forOpA' executes before 'forOpB' in 'block'.
   auto forOpA = isSrcForOpBeforeDstForOp ? srcForOp : dstForOp;
   auto forOpB = isSrcForOpBeforeDstForOp ? dstForOp : srcForOp;
@@ -315,8 +304,8 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
   }
 
   // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'.
-  unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops(
-      *srcForOp.getOperation(), *dstForOp.getOperation());
+  unsigned numCommonLoops =
+      mlir::getNumCommonSurroundingLoops(*srcForOp, *dstForOp);
 
   // Filter out ops in 'opsA' to compute the slice union based on the
   // assumptions made by the fusion strategy.
@@ -539,8 +528,8 @@ static int64_t getComputeCostHelper(
   int64_t opCount = stats.opCountMap[forOp] - 1;
   if (stats.loopMap.count(forOp) > 0) {
     for (auto childForOp : stats.loopMap[forOp]) {
-      opCount += getComputeCostHelper(childForOp.getOperation(), stats,
-                                      tripCountOverrideMap, computeCostMap);
+      opCount += getComputeCostHelper(childForOp, stats, tripCountOverrideMap,
+                                      computeCostMap);
     }
   }
   // Add in additional op instances from slice (if specified in map).
@@ -567,7 +556,7 @@ static int64_t getComputeCostHelper(
 /// instance count (i.e. total number of operations in the loop body * loop
 /// trip count) for the entire loop nest.
 int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) {
-  return getComputeCostHelper(forOp.getOperation(), stats,
+  return getComputeCostHelper(forOp, stats,
                               /*tripCountOverrideMap=*/nullptr,
                               /*computeCostMap=*/nullptr);
 }
@@ -611,8 +600,8 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
       computeCostMap[insertPointParent] = -storeCount;
     // Subtract out any load users of 'storeMemrefs' nested below
     // 'insertPointParent'.
-    for (auto value : storeMemrefs) {
-      for (auto *user : value.getUsers()) {
+    for (Value memref : storeMemrefs) {
+      for (auto *user : memref.getUsers()) {
         if (auto loadOp = dyn_cast<AffineReadOpInterface>(user)) {
           SmallVector<AffineForOp, 4> loops;
           // Check if any loop in loop nest surrounding 'user' is
@@ -633,13 +622,13 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats,
 
   // Compute op instance count for the src loop nest with iteration slicing.
   int64_t sliceComputeCost = getComputeCostHelper(
-      srcForOp.getOperation(), srcStats, &sliceTripCountMap, &computeCostMap);
+      srcForOp, srcStats, &sliceTripCountMap, &computeCostMap);
 
   // Compute cost of fusion for this depth.
   computeCostMap[insertPointParent] = sliceComputeCost;
 
   *computeCost =
-      getComputeCostHelper(dstForOp.getOperation(), dstStats,
+      getComputeCostHelper(dstForOp, dstStats,
                            /*tripCountOverrideMap=*/nullptr, &computeCostMap);
   return true;
 }