// TODO(bondhugula): handle non hyper-rectangular spaces.
LogicalResult mlir::tileCodeGen(MutableArrayRef<AffineForOp> band,
ArrayRef<unsigned> tileSizes) {
- assert(!band.empty());
- assert(band.size() == tileSizes.size() && "Incorrect number of tile sizes");
+ assert(!band.empty() && "no loops in band");
+ assert(band.size() == tileSizes.size() && "Too few/many tile sizes");
// Check if the supplied for op's are all successively nested.
- for (unsigned i = 1, e = band.size(); i < e; i++) {
- assert(band[i].getParentOp() == band[i - 1].getOperation());
- }
+ for (unsigned i = 1, e = band.size(); i < e; i++)
+ assert(band[i].getParentOp() == band[i - 1] && "not a perfect nest / band");
auto origLoops = band;
// Note that width is at least one since band isn't empty.
unsigned width = band.size();
- SmallVector<AffineForOp, 12> newLoops(2 * width);
- AffineForOp innermostPointLoop;
+ SmallVector<AffineForOp, 6> tiledLoops(2 * width);
// The outermost among the loops as we add more..
auto *topLoop = rootAffineForOp.getOperation();
+ AffineForOp innermostPointLoop;
// Add intra-tile (or point) loops.
for (unsigned i = 0; i < width; i++) {
pointLoop.getBody()->getOperations().splice(
pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
topLoop);
- newLoops[2 * width - 1 - i] = pointLoop;
+ tiledLoops[2 * width - 1 - i] = pointLoop;
topLoop = pointLoop.getOperation();
if (i == 0)
innermostPointLoop = pointLoop;
tileSpaceLoop.getBody()->getOperations().splice(
tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(),
topLoop);
- newLoops[2 * width - i - 1] = tileSpaceLoop;
+ tiledLoops[2 * width - i - 1] = tileSpaceLoop;
topLoop = tileSpaceLoop.getOperation();
}
getIndexSet(band, &cst);
if (!cst.isHyperRectangular(0, width)) {
- rootAffineForOp.emitError("tiled code generation unimplemented for the "
- "non-hyperrectangular case");
+ llvm::dbgs() << "tiled code generation unimplemented for the "
+ "non-hyperrectangular case, op:"
+ << *rootAffineForOp << "\n";
return failure();
}
- constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
- // In this case, the point loop IVs just replace the original ones.
- for (unsigned i = 0; i < width; i++) {
- origLoopIVs[i].replaceAllUsesWith(newLoops[i + width].getInductionVar());
- }
+ constructTiledIndexSetHyperRect(origLoops, tiledLoops, tileSizes);
+
+ // Replace original IVs with intra-tile loop IVs.
+ for (unsigned i = 0; i < width; i++)
+ origLoopIVs[i].replaceAllUsesWith(tiledLoops[i + width].getInductionVar());
// Erase the old loop nest.
rootAffineForOp.erase();
std::vector<SmallVector<AffineForOp, 6>> bands;
getTileableBands(getFunction(), &bands);
+ // Tile each band.
for (auto &band : bands) {
// Set up tile sizes; fill missing tile sizes at the end with default tile
// size or clTileSize if one was provided.
if (llvm::DebugFlag) {
auto diag = band[0].emitRemark("using tile sizes [");
for (auto tSize : tileSizes)
- diag << tSize << " ";
+ diag << tSize << ' ';
diag << "]\n";
}
if (failed(tileCodeGen(band, tileSizes)))