[mlir][SCF] Add utility method to add new yield values to a loop.
authorMahesh Ravishankar <ravishankarm@google.com>
Fri, 6 May 2022 21:44:26 +0000 (21:44 +0000)
committerMahesh Ravishankar <ravishankarm@google.com>
Tue, 10 May 2022 18:44:11 +0000 (18:44 +0000)
The current implementation of `cloneWithNewYields` has a few issues
- It clones the loop body of the original loop to create a new
  loop. This is very expensive.
- It performs `erase` operations which are incompatible when this
  method is called from within a pattern rewrite. All erases need to
  go through `PatternRewriter`.

To address these a new utility method `replaceLoopWithNewYields` is added
which
- moves the operations from the original loop into the new loop.
- replaces all uses of the original loop with the corresponding
  results of the new loop
- use a call back to allow caller to generate the new yield values.
- the original loop is modified to just yield the basic block
  arguments corresponding to the iter_args of the loop. This
  represents a no-op loop. The loop itself is dead (since all its uses
  are replaced), but is not removed. The caller is expected to erase
  the op. Consequently, this method can be called from within a
  `matchAndRewrite` method of a `PatternRewriter`.

The `cloneWithNewYields` could be replaces with
`replaceLoopWithNewYields`, but that seems to trigger a failure during
walks, potentially due to the operations being moved. That is left as
a TODO.

Differential Revision: https://reviews.llvm.org/D125147

mlir/include/mlir/Dialect/SCF/Utils/Utils.h
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/test/Transforms/scf-loop-utils.mlir
mlir/test/Transforms/scf-replace-with-new-yields.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

index bb5a484..70fc084 100644 (file)
@@ -61,6 +61,28 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                               ValueRange newYieldedValues,
                               bool replaceLoopResults = true);
 
+/// Replace the `loop` with `newIterOperands` added as new initialization
+/// values. `newYieldValuesFn` is a callback that can be used to specify
+/// the additional values to be yielded by the loop. The number of
+/// values returned by the callback should match the number of new
+/// initialization values. This function
+/// - Moves (i.e. doesnt clone) operations from the `loop` to the newly created
+///   loop
+/// - Replaces the uses of `loop` with the new loop.
+/// - `loop` isnt erased, but is left in a "no-op" state where the body of the
+///   loop just yields the basic block arguments that correspond to the
+///   initialization values of a loop. The loop is dead after this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// TODO: This method could be used instead of `cloneWithNewYields`. Making
+/// this change though hits assertions in the walk mechanism that is unrelated
+/// to this method itself.
+using NewYieldValueFn = std::function<SmallVector<Value>(
+    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
+scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    NewYieldValueFn newYieldValuesFn);
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
index 5d69a3f..a2a751a 100644 (file)
@@ -91,6 +91,70 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
   return newLoop;
 }
 
+scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                                          ValueRange newIterOperands,
+                                          NewYieldValueFn newYieldValuesFn) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPoint(loop);
+  auto operands = llvm::to_vector(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop = builder.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      operands, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  Block *loopBody = loop.getBody();
+  Block *newLoopBody = newLoop.getBody();
+
+  // Move the body of the original loop to the new loop.
+  newLoopBody->getOperations().splice(newLoopBody->end(),
+                                      loopBody->getOperations());
+
+  // Generate the new yield values to use by using the callback and ppend the
+  // yield values to the scf.yield operation.
+  auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
+  ArrayRef<BlockArgument> newBBArgs =
+      newLoopBody->getArguments().take_back(newIterOperands.size());
+  {
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPoint(yield);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
+    assert(newIterOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    yield.getResultsMutable().append(newYieldedValues);
+  }
+
+  // Remap the BlockArguments from the original loop to the new loop
+  // BlockArguments.
+  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
+  for (auto it :
+       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  // Replace all uses of `newIterOperands` with the corresponding basic block
+  // arguments.
+  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+      Operation *user = use.getOwner();
+      return newLoop->isProperAncestor(user);
+    });
+  }
+
+  // Replace all uses of the original loop with corresponding values from the
+  // new loop.
+  loop.replaceAllUsesWith(
+      newLoop.getResults().take_front(loop.getNumResults()));
+
+  // Add a fake yield to the original loop body that just returns the
+  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
+  // The loop is dead. The caller is expected to erase it.
+  builder.setInsertionPointToEnd(loopBody);
+  builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
+
+  return newLoop;
+}
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
index 32e42b9..3e03d3a 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-clone-with-new-yields -mlir-disable-threading %s | FileCheck %s
 
 // CHECK-LABEL: @hoist
 //  CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
diff --git a/mlir/test/Transforms/scf-replace-with-new-yields.mlir b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
new file mode 100644 (file)
index 0000000..802f86e
--- /dev/null
@@ -0,0 +1,21 @@
+
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-replace-with-new-yields -mlir-disable-threading %s | FileCheck %s
+
+func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+  %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+    %1 = arith.addf %iter, %iter : f32
+    scf.yield %1: f32
+  }
+  return %0: f32
+}
+// CHECK-LABEL: func @doubleup
+//  CHECK-SAME:   %[[ARG:[a-zA-Z0-9]+]]: f32
+//       CHECK:   %[[NEWLOOP:.+]]:2 = scf.for
+//  CHECK-SAME:       iter_args(%[[INIT1:.+]] = %[[ARG]], %[[INIT2:.+]] = %[[ARG]]
+//       CHECK:     %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]]
+//       CHECK:     %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]]
+//       CHECK:     scf.yield %[[DOUBLE]], %[[DOUBLE2]]
+//       CHECK:   %[[OLDLOOP:.+]] = scf.for
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[ARG]])
+//       CHECK:     scf.yield %[[INIT]]
+//       CHECK:   return %[[NEWLOOP]]#0
index f354a30..858b0b1 100644 (file)
@@ -32,29 +32,67 @@ struct TestSCFForUtilsPass
   StringRef getArgument() const final { return "test-scf-for-utils"; }
   StringRef getDescription() const final { return "test scf.for utils"; }
   explicit TestSCFForUtilsPass() = default;
+  TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
+
+  Option<bool> testCloneWithNewYields{
+      *this, "test-clone-with-new-yields",
+      llvm::cl::desc(
+          "Test cloning of a loop while returning additional yield values"),
+      llvm::cl::init(false)};
+
+  Option<bool> testReplaceWithNewYields{
+      *this, "test-replace-with-new-yields",
+      llvm::cl::desc("Test replacing a loop with a new loop that returns new "
+                     "additional yeild values"),
+      llvm::cl::init(false)};
 
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     SmallVector<scf::ForOp, 4> toErase;
 
-    func.walk([&](Operation *fakeRead) {
-      if (fakeRead->getName().getStringRef() != "fake_read")
-        return;
-      auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
-      auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
-      auto loop = fakeRead->getParentOfType<scf::ForOp>();
-
-      OpBuilder b(loop);
-      loop.moveOutOfLoop(fakeRead);
-      fakeWrite->moveAfter(loop);
-      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
-                                        fakeCompute->getResult(0));
-      fakeCompute->getResult(0).replaceAllUsesWith(
-          newLoop.getResults().take_back()[0]);
-      toErase.push_back(loop);
-    });
-    for (auto loop : llvm::reverse(toErase))
-      loop.erase();
+    if (testCloneWithNewYields) {
+      func.walk([&](Operation *fakeRead) {
+        if (fakeRead->getName().getStringRef() != "fake_read")
+          return;
+        auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+        auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+        auto loop = fakeRead->getParentOfType<scf::ForOp>();
+
+        OpBuilder b(loop);
+        loop.moveOutOfLoop(fakeRead);
+        fakeWrite->moveAfter(loop);
+        auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                          fakeCompute->getResult(0));
+        fakeCompute->getResult(0).replaceAllUsesWith(
+            newLoop.getResults().take_back()[0]);
+        toErase.push_back(loop);
+      });
+      for (auto loop : llvm::reverse(toErase))
+        loop.erase();
+    }
+
+    if (testReplaceWithNewYields) {
+      func.walk([&](scf::ForOp forOp) {
+        if (forOp.getNumResults() == 0)
+          return;
+        auto newInitValues = forOp.getInitArgs();
+        if (newInitValues.empty())
+          return;
+        NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
+                                 ArrayRef<BlockArgument> newBBArgs) {
+          Block *block = newBBArgs.front().getOwner();
+          SmallVector<Value> newYieldValues;
+          for (auto yieldVal :
+               cast<scf::YieldOp>(block->getTerminator()).getResults()) {
+            newYieldValues.push_back(
+                b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
+          }
+          return newYieldValues;
+        };
+        OpBuilder b(forOp);
+        replaceLoopWithNewYields(b, forOp, newInitValues, fn);
+      });
+    }
   }
 };
 
@@ -88,7 +126,8 @@ static const StringLiteral kTestPipeliningLoopMarker =
     "__test_pipelining_loop__";
 static const StringLiteral kTestPipeliningStageMarker =
     "__test_pipelining_stage__";
-/// Marker to express the order in which operations should be after pipelining.
+/// Marker to express the order in which operations should be after
+/// pipelining.
 static const StringLiteral kTestPipeliningOpOrderMarker =
     "__test_pipelining_op_order__";