[MLIR][NFC] loop tiling - improve comments / naming
authorUday Bondhugula <uday@polymagelabs.com>
Mon, 23 Mar 2020 15:02:58 +0000 (20:32 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Tue, 24 Mar 2020 02:07:19 +0000 (07:37 +0530)
Improve comments, naming, and other cleanup

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Differential Revision: https://reviews.llvm.org/D76616

mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp

index de818d7..3f08315 100644 (file)
@@ -177,13 +177,12 @@ constructTiledIndexSetHyperRect(MutableArrayRef<AffineForOp> origLoops,
 //  TODO(bondhugula): handle non hyper-rectangular spaces.
 LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
                                 ArrayRef<unsigned> tileSizes) {
-  assert(!band.empty());
-  assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes");
+  assert(!band.empty() && "no loops in band");
+  assert(band.size() == tileSizes.size() && "Too few/many tile sizes");
 
   // Check if the supplied for op's are all successively nested.
-  for (unsigned i = 1, e = band.size(); i < e; i++) {
-    assert(band[i].getParentOp() == band[i - 1].getOperation());
-  }
+  for (unsigned i = 1, e = band.size(); i < e; i++)
+    assert(band[i].getParentOp() == band[i - 1] && "not a perfect nest / band");
 
   auto origLoops = band;
 
@@ -192,11 +191,11 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
   // Note that width is at least one since band isn't empty.
   unsigned width = band.size();
 
-  SmallVector<AffineForOp, 12> newLoops(2 * width);
-  AffineForOp innermostPointLoop;
+  SmallVector<AffineForOp, 6> tiledLoops(2 * width);
 
   // The outermost among the loops as we add more..
   auto *topLoop = rootAffineForOp.getOperation();
+  AffineForOp innermostPointLoop;
 
   // Add intra-tile (or point) loops.
   for (unsigned i = 0; i < width; i++) {
@@ -206,7 +205,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
     pointLoop.getBody()->getOperations().splice(
         pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
         topLoop);
-    newLoops[2 * width - 1 - i] = pointLoop;
+    tiledLoops[2 * width - 1 - i] = pointLoop;
     topLoop = pointLoop.getOperation();
     if (i == 0)
       innermostPointLoop = pointLoop;
@@ -220,7 +219,7 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
     tileSpaceLoop.getBody()->getOperations().splice(
         tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
         topLoop);
-    newLoops[2 * width - i - 1] = tileSpaceLoop;
+    tiledLoops[2 * width - i - 1] = tileSpaceLoop;
     topLoop = tileSpaceLoop.getOperation();
   }
 
@@ -234,16 +233,17 @@ LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
   getIndexSet(band, &cst);
 
   if (!cst.isHyperRectangular(0, width)) {
-    rootAffineForOp.emitError("tiled code generation unimplemented for the "
-                              "non-hyperrectangular case");
+    llvm::dbgs() << "tiled code generation unimplemented for the "
+                    "non-hyperrectangular case, op:"
+                 << *rootAffineForOp << "\n";
     return failure();
   }
 
-  constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
-  // In this case, the point loop IVs just replace the original ones.
-  for (unsigned i = 0; i < width; i++) {
-    origLoopIVs[i].replaceAllUsesWith(newLoops[i + width].getInductionVar());
-  }
+  constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
+
+  // Replace original IVs with intra-tile loop IVs.
+  for (unsigned i = 0; i < width; i++)
+    origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
 
   // Erase the old loop nest.
   rootAffineForOp.erase();
@@ -381,6 +381,7 @@ void LoopTiling::runOnFunction() {
   std::vector<SmallVector<AffineForOp, 6>> bands;
   getTileableBands(getFunction(), &bands);
 
+  // Tile each band.
   for (auto &band : bands) {
     // Set up tile sizes; fill missing tile sizes at the end with default tile
     // size or clTileSize if one was provided.
@@ -389,7 +390,7 @@ void LoopTiling::runOnFunction() {
     if (llvm::DebugFlag) {
       auto diag = band[0].emitRemark("using tile sizes [");
       for (auto tSize : tileSizes)
-        diag << tSize << " ";
+        diag << tSize << ' ';
       diag << "]\n";
     }
     if (failed(tileCodeGen(band, tileSizes)))