From 0d5cb90f6c44730a74f49feb6f5b624b3414e459 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 26 Sep 2022 22:09:31 +0000 Subject: [PATCH] [mlir][scf] Simplify the logic for `replaceLoopWithNewYields` for perfectly nested loops. Based on discussion in https://reviews.llvm.org/D134411, instead of first modifying the inner most loop first followed by modifying the outer loops from inside out, this patch restructures the logic to start the modification from the outer most loop. Differential Revision: https://reviews.llvm.org/D134832 --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 108 +++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 49 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index c6e416b..e99510e 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -111,56 +111,66 @@ SmallVector mlir::replaceLoopNestWithNewYields( bool replaceIterOperandsUsesInLoop) { if (loopNest.empty()) return {}; - SmallVector newLoopNest(loopNest.size()); - - newLoopNest.back() = replaceLoopWithNewYields( - builder, loopNest.back(), newIterOperands, newYieldValueFn); - - for (unsigned loopDepth : - llvm::reverse(llvm::seq(0, loopNest.size() - 1))) { - NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc, - ArrayRef innerNewBBArgs) { - SmallVector newYields( - newLoopNest[loopDepth + 1]->getResults().take_back( - newIterOperands.size())); - return newYields; - }; - newLoopNest[loopDepth] = - replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands, - fn, replaceIterOperandsUsesInLoop); - if (!replaceIterOperandsUsesInLoop) { - /// The yield is expected to producer the following structure - /// ``` - /// %0 = scf.for ... iter_args(%arg0 = %init) { - /// %1 = scf.for ... iter_args(%arg1 = %arg0) { - /// scf.yield %yield - /// } - /// } - /// ``` - /// - /// since the yield is propagated from inside out, after the inner - /// loop is processed the IR is in this form - /// - /// ``` - /// scf.for ... iter_args { - /// %1 = scf.for ... iter_args(%arg1 = %init) { - /// scf.yield %yield - /// } - /// ``` - /// - /// If `replaceIterOperandUsesInLoops` is true, there is nothing to do. - /// `%init` will be replaced with `%arg0` when it is created for the - /// outer loop. But without that this has to be done explicitly. - unsigned subLen = newIterOperands.size(); - unsigned subStart = - newLoopNest[loopDepth + 1].getNumIterOperands() - subLen; - auto resetOperands = - newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart, - subLen); - resetOperands.assign( - newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen)); - } + // This method is recursive (to make it more readable). Adding an + // assertion here to limit the recursion. (See + // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235) + assert(loopNest.size() <= 6 && + "exceeded recursion limit when yielding value from loop nest"); + + // To yield a value from a perfectly nested loop nest, the following + // pattern needs to be created, i.e. starting with + // + // ```mlir + // scf.for .. { + // scf.for .. { + // scf.for .. { + // %value = ... + // } + // } + // } + // ``` + // + // needs to be modified to + // + // ```mlir + // %0 = scf.for .. iter_args(%arg0 = %init) { + // %1 = scf.for .. iter_args(%arg1 = %arg0) { + // %2 = scf.for .. iter_args(%arg2 = %arg1) { + // %value = ... + // scf.yield %value + // } + // scf.yield %2 + // } + // scf.yield %1 + // } + // ``` + // + // The inner most loop is handled using the `replaceLoopWithNewYields` + // that works on a single loop. + if (loopNest.size() == 1) { + auto innerMostLoop = replaceLoopWithNewYields( + builder, loopNest.back(), newIterOperands, newYieldValueFn, + replaceIterOperandsUsesInLoop); + return {innerMostLoop}; } + // The outer loops are modified by calling this method recursively + // - The return value of the inner loop is the value yielded by this loop. + // - The region iter args of this loop are the init_args for the inner loop. + SmallVector newLoopNest; + NewYieldValueFn fn = + [&](OpBuilder &innerBuilder, Location loc, + ArrayRef innerNewBBArgs) -> SmallVector { + newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(), + innerNewBBArgs, newYieldValueFn, + replaceIterOperandsUsesInLoop); + return llvm::to_vector(llvm::map_range( + newLoopNest.front().getResults().take_back(innerNewBBArgs.size()), + [](OpResult r) -> Value { return r; })); + }; + scf::ForOp outerMostLoop = + replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn, + replaceIterOperandsUsesInLoop); + newLoopNest.insert(newLoopNest.begin(), outerMostLoop); return newLoopNest; } -- 2.7.4