[mlir][scf] Simplify the logic for `replaceLoopWithNewYields` for perfectly nested...
authorMahesh Ravishankar <ravishankarm@google.com>
Mon, 26 Sep 2022 22:09:31 +0000 (22:09 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Thu, 29 Sep 2022 16:52:02 +0000 (16:52 +0000)
Based on discussion in https://reviews.llvm.org/D134411, instead of
first modifying the inner most loop first followed by modifying the
outer loops from inside out, this patch restructures the logic to
start the modification from the outer most loop.

Differential Revision: https://reviews.llvm.org/D134832

mlir/lib/Dialect/SCF/Utils/Utils.cpp

index c6e416b..e99510e 100644 (file)
@@ -111,56 +111,66 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     bool replaceIterOperandsUsesInLoop) {
   if (loopNest.empty())
     return {};
-  SmallVector<scf::ForOp> newLoopNest(loopNest.size());
-
-  newLoopNest.back() = replaceLoopWithNewYields(
-      builder, loopNest.back(), newIterOperands, newYieldValueFn);
-
-  for (unsigned loopDepth :
-       llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
-    NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
-                             ArrayRef<BlockArgument> innerNewBBArgs) {
-      SmallVector<Value> newYields(
-          newLoopNest[loopDepth + 1]->getResults().take_back(
-              newIterOperands.size()));
-      return newYields;
-    };
-    newLoopNest[loopDepth] =
-        replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands,
-                                 fn, replaceIterOperandsUsesInLoop);
-    if (!replaceIterOperandsUsesInLoop) {
-      /// The yield is expected to producer the following structure
-      /// ```
-      /// %0 = scf.for ... iter_args(%arg0 = %init) {
-      ///   %1 = scf.for ... iter_args(%arg1 = %arg0) {
-      ///     scf.yield %yield
-      ///   }
-      /// }
-      /// ```
-      ///
-      /// since the yield is propagated from inside out, after the inner
-      /// loop is processed the IR is in this form
-      ///
-      /// ```
-      /// scf.for ... iter_args {
-      ///   %1 = scf.for ... iter_args(%arg1 = %init) {
-      ///     scf.yield %yield
-      ///   }
-      /// ```
-      ///
-      /// If `replaceIterOperandUsesInLoops` is true, there is nothing to do.
-      /// `%init` will be replaced with `%arg0` when it is created for the
-      /// outer loop. But without that this has to be done explicitly.
-      unsigned subLen = newIterOperands.size();
-      unsigned subStart =
-          newLoopNest[loopDepth + 1].getNumIterOperands() - subLen;
-      auto resetOperands =
-          newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart,
-                                                                subLen);
-      resetOperands.assign(
-          newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen));
-    }
+  // This method is recursive (to make it more readable). Adding an
+  // assertion here to limit the recursion. (See
+  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
+  assert(loopNest.size() <= 6 &&
+         "exceeded recursion limit when yielding value from loop nest");
+
+  // To yield a value from a perfectly nested loop nest, the following
+  // pattern needs to be created, i.e. starting with
+  //
+  // ```mlir
+  //  scf.for .. {
+  //    scf.for .. {
+  //      scf.for .. {
+  //        %value = ...
+  //      }
+  //    }
+  //  }
+  // ```
+  //
+  // needs to be modified to
+  //
+  // ```mlir
+  // %0 = scf.for .. iter_args(%arg0 = %init) {
+  //   %1 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %2 = scf.for .. iter_args(%arg2 = %arg1) {
+  //       %value = ...
+  //       scf.yield %value
+  //     }
+  //     scf.yield %2
+  //   }
+  //   scf.yield %1
+  // }
+  // ```
+  //
+  // The inner most loop is handled using the `replaceLoopWithNewYields`
+  // that works on a single loop.
+  if (loopNest.size() == 1) {
+    auto innerMostLoop = replaceLoopWithNewYields(
+        builder, loopNest.back(), newIterOperands, newYieldValueFn,
+        replaceIterOperandsUsesInLoop);
+    return {innerMostLoop};
   }
+  // The outer loops are modified by calling this method recursively
+  // - The return value of the inner loop is the value yielded by this loop.
+  // - The region iter args of this loop are the init_args for the inner loop.
+  SmallVector<scf::ForOp> newLoopNest;
+  NewYieldValueFn fn =
+      [&](OpBuilder &innerBuilder, Location loc,
+          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+    newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(),
+                                               innerNewBBArgs, newYieldValueFn,
+                                               replaceIterOperandsUsesInLoop);
+    return llvm::to_vector(llvm::map_range(
+        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
+        [](OpResult r) -> Value { return r; }));
+  };
+  scf::ForOp outerMostLoop =
+      replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn,
+                               replaceIterOperandsUsesInLoop);
+  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
   return newLoopNest;
 }