Create a LoopUtil function to return perfectly nested loop set
authorMLIR Team <no-reply@google.com>
Thu, 4 Apr 2019 22:19:17 +0000 (15:19 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 5 Apr 2019 14:42:01 +0000 (07:42 -0700)
--

PiperOrigin-RevId: 242019230

mlir/include/mlir/Transforms/LoopUtils.h
mlir/lib/Transforms/LoopFusion.cpp
mlir/lib/Transforms/LoopTiling.cpp
mlir/lib/Transforms/Utils/LoopUtils.cpp

index f1e7b50..2aecdce 100644 (file)
@@ -37,14 +37,23 @@ class Value;
 /// Unrolls this for operation completely if the trip count is known to be
 /// constant. Returns failure otherwise.
 LogicalResult loopUnrollFull(AffineForOp forOp);
+
 /// Unrolls this for operation by the specified unroll factor. Returns failure
 /// if the loop cannot be unrolled either due to restrictions or due to invalid
 /// unroll factors.
 LogicalResult loopUnrollByFactor(AffineForOp forOp, uint64_t unrollFactor);
+
 /// Unrolls this loop by the specified unroll factor or its trip count,
 /// whichever is lower.
 LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
 
+/// Get perfectly nested sequence of loops starting at root of loop nest
+/// (the first op being another AffineFor, and the second op - a terminator).
+/// A loop is perfectly nested iff: the first op in the loop's body is another
+/// AffineForOp, and the second op is a terminator).
+void getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
+                             AffineForOp root);
+
 /// Unrolls and jams this loop by the specified factor. Returns success if the
 /// loop is successfully unroll-jammed.
 LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
index 2ed159c..39ed5a1 100644 (file)
@@ -1062,18 +1062,9 @@ computeLoopInterchangePermutation(ArrayRef<Operation *> ops,
 // pushing loop carried dependence to a greater depth in the loop nest.
 static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   assert(node->op->isa<AffineForOp>());
-  // Get perfectly nested sequence of loops starting at root of loop nest
-  // (the first op being another AffineFor, and the second op - a terminator).
-  // TODO(andydavis,bondhugula) Share this with similar code in loop tiling.
   SmallVector<AffineForOp, 4> loops;
   AffineForOp curr = node->op->cast<AffineForOp>();
-  loops.push_back(curr);
-  auto *currBody = curr.getBody();
-  while (currBody->begin() == std::prev(currBody->end(), 2) &&
-         (curr = curr.getBody()->front().dyn_cast<AffineForOp>())) {
-    loops.push_back(curr);
-    currBody = curr.getBody();
-  }
+  getPerfectlyNestedLoops(loops, curr);
   if (loops.size() < 2)
     return;
 
index 956d50e..c215fa3 100644 (file)
@@ -270,11 +270,7 @@ static void getTileableBands(Function &f,
   // (inclusive).
   auto getMaximalPerfectLoopNest = [&](AffineForOp root) {
     SmallVector<AffineForOp, 6> band;
-    AffineForOp currInst = root;
-    do {
-      band.push_back(currInst);
-    } while (currInst.getBody()->getOperations().size() == 2 &&
-             (currInst = currInst.getBody()->front().dyn_cast<AffineForOp>()));
+    getPerfectlyNestedLoops(band, root);
     bands->push_back(band);
   };
 
index 2b17f4b..1e9697a 100644 (file)
@@ -353,6 +353,22 @@ LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
   return success();
 }
 
+/// Get perfectly nested sequence of loops starting at root of loop nest
+/// (the first op being another AffineFor, and the second op - a terminator).
+/// A loop is perfectly nested iff: the first op in the loop's body is another
+/// AffineForOp, and the second op is a terminator).
+void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
+                                   AffineForOp root) {
+  AffineForOp curr = root;
+  nestedLoops.push_back(curr);
+  auto *currBody = curr.getBody();
+  while (currBody->begin() == std::prev(currBody->end(), 2) &&
+         (curr = curr.getBody()->front().dyn_cast<AffineForOp>())) {
+    nestedLoops.push_back(curr);
+    currBody = curr.getBody();
+  }
+}
+
 /// Unrolls this loop completely.
 LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
   Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);