ValueRange newYieldedValues,
bool replaceLoopResults = true);
+/// Callback type used by `replaceLoopWithNewYields` to obtain the additional
+/// values to be yielded by the new loop. `newBBArgs` are the block arguments
+/// of the new loop body that correspond to the new initialization values;
+/// the callback returns one value to yield per new initialization value.
+using NewYieldValueFn = std::function<SmallVector<Value>(
+ OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
+
+/// Replace the `loop` with `newIterOperands` added as new initialization
+/// values. `newYieldValuesFn` is a callback that can be used to specify
+/// the additional values to be yielded by the loop. The number of
+/// values returned by the callback should match the number of new
+/// initialization values. This function
+/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly
+/// created loop.
+/// - Replaces the uses of `loop` with the new loop.
+/// - `loop` isn't erased, but is left in a "no-op" state where the body of the
+/// loop just yields the basic block arguments that correspond to the
+/// initialization values of a loop. The loop is dead after this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+/// are replaced with the corresponding `BlockArgument` in the loop body.
+/// Returns the newly created loop.
+/// TODO: This method could be used instead of `cloneWithNewYields`. Making
+/// this change though hits assertions in the walk mechanism that are unrelated
+/// to this method itself.
+scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+ ValueRange newIterOperands,
+ NewYieldValueFn newYieldValuesFn);
+
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
return newLoop;
}
+/// Builds a new `scf.for` immediately before `loop` whose iter operands are
+/// those of `loop` followed by `newIterOperands`, moves the body of `loop`
+/// into it, appends the values produced by `newYieldValuesFn` to its yield,
+/// and redirects all uses of `loop`'s results to the leading results of the
+/// new loop. `loop` itself is not erased: it is left with a body that only
+/// yields its own iter args. Returns the new loop.
+scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+ ValueRange newIterOperands,
+ NewYieldValueFn newYieldValuesFn) {
+ // Create a new loop before the existing one, with the extra operands.
+ // The guard covers every insertion-point change made below, including the
+ // final one into the old loop body.
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPoint(loop);
+ auto operands = llvm::to_vector(loop.getIterOperands());
+ operands.append(newIterOperands.begin(), newIterOperands.end());
+ // The body-builder callback is intentionally a no-op: the body is filled in
+ // by the splice below. NOTE(review): presumably this also prevents a default
+ // terminator from being created in the new body -- confirm.
+ scf::ForOp newLoop = builder.create<scf::ForOp>(
+ loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+ operands, [](OpBuilder &, Location, Value, ValueRange) {});
+
+ Block *loopBody = loop.getBody();
+ Block *newLoopBody = newLoop.getBody();
+
+ // Move the body of the original loop to the new loop. This moves every
+ // operation, including the original terminator, so the old body is left
+ // empty until the fake yield is added at the end of this function.
+ newLoopBody->getOperations().splice(newLoopBody->end(),
+ loopBody->getOperations());
+
+ // Generate the new yield values to use by using the callback and append the
+ // yield values to the scf.yield operation.
+ auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
+ // The trailing block arguments of the new body are the ones that correspond
+ // to `newIterOperands`.
+ ArrayRef<BlockArgument> newBBArgs =
+ newLoopBody->getArguments().take_back(newIterOperands.size());
+ {
+ // Operations created by the callback are inserted right before the yield.
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPoint(yield);
+ SmallVector<Value> newYieldedValues =
+ newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
+ assert(newIterOperands.size() == newYieldedValues.size() &&
+ "expected as many new yield values as new iter operands");
+ yield.getResultsMutable().append(newYieldedValues);
+ }
+
+ // Remap the BlockArguments from the original loop to the new loop
+ // BlockArguments. The old arguments are paired positionally with the leading
+ // arguments of the new body.
+ ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
+ for (auto it :
+ llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
+ std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+ // Replace all uses of `newIterOperands` with the corresponding basic block
+ // arguments. Only uses owned by operations nested inside the new loop are
+ // rewritten; uses outside it (including the new loop's own operands) are
+ // left untouched.
+ for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+ std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return newLoop->isProperAncestor(user);
+ });
+ }
+
+ // Replace all uses of the original loop with corresponding values from the
+ // new loop. The extra trailing results of the new loop are not used here.
+ loop.replaceAllUsesWith(
+ newLoop.getResults().take_front(loop.getNumResults()));
+
+ // Add a fake yield to the original loop body that just returns the
+ // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
+ // The loop is dead. The caller is expected to erase it.
+ builder.setInsertionPointToEnd(loopBody);
+ builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
+
+ return newLoop;
+}
+
/// Outline a region with a single block into a new FuncOp.
/// Assumes the FuncOp result types is the type of the yielded operands of the
/// single block. This constraint makes it easy to determine the result.
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-clone-with-new-yields -mlir-disable-threading %s | FileCheck %s
// CHECK-LABEL: @hoist
// CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,
--- /dev/null
+
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-replace-with-new-yields -mlir-disable-threading %s | FileCheck %s
+
+// Test input: a single-result loop whose iter_arg is initialized with
+// %extra_arg and doubled on every iteration. The CHECK lines below expect
+// the pass to produce a new two-result loop ahead of the (now no-op) old one.
+func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+ %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+ %1 = arith.addf %iter, %iter : f32
+ scf.yield %1: f32
+ }
+ return %0: f32
+}
+// CHECK-LABEL: func @doubleup
+// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]: f32
+// CHECK: %[[NEWLOOP:.+]]:2 = scf.for
+// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[ARG]], %[[INIT2:.+]] = %[[ARG]]
+// CHECK: %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]]
+// CHECK: %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]]
+// CHECK: scf.yield %[[DOUBLE]], %[[DOUBLE2]]
+// CHECK: %[[OLDLOOP:.+]] = scf.for
+// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ARG]])
+// CHECK: scf.yield %[[INIT]]
+// CHECK: return %[[NEWLOOP]]#0
+ // Command-line name of this pass.
StringRef getArgument() const final { return "test-scf-for-utils"; }
+ // Short human-readable description of this pass.
StringRef getDescription() const final { return "test scf.for utils"; }
explicit TestSCFForUtilsPass() = default;
+ // Copy constructor that forwards only to the `PassWrapper` base; the option
+ // members are initialized from their default member initializers rather than
+ // copied from `pass`. NOTE(review): presumably required because the option
+ // members make the implicit copy constructor unavailable -- confirm.
+ TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
+
+ // Boolean pass option `test-clone-with-new-yields` (default: false). When
+ // set, `runOnOperation` runs the `cloneWithNewYields` test path.
+ Option<bool> testCloneWithNewYields{
+ *this, "test-clone-with-new-yields",
+ llvm::cl::desc(
+ "Test cloning of a loop while returning additional yield values"),
+ llvm::cl::init(false)};
+
+ // Boolean pass option `test-replace-with-new-yields` (default: false). When
+ // set, `runOnOperation` runs the `replaceLoopWithNewYields` test path.
+ Option<bool> testReplaceWithNewYields{
+ *this, "test-replace-with-new-yields",
+ llvm::cl::desc("Test replacing a loop with a new loop that returns new "
+ "additional yield values"),
+ llvm::cl::init(false)};
+ // Runs the test paths selected by the pass options on the current function.
+ // Both paths are independent; with neither option set this is a no-op.
void runOnOperation() override {
func::FuncOp func = getOperation();
SmallVector<scf::ForOp, 4> toErase;
- func.walk([&](Operation *fakeRead) {
- if (fakeRead->getName().getStringRef() != "fake_read")
- return;
- auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
- auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
- auto loop = fakeRead->getParentOfType<scf::ForOp>();
-
- OpBuilder b(loop);
- loop.moveOutOfLoop(fakeRead);
- fakeWrite->moveAfter(loop);
- auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
- fakeCompute->getResult(0));
- fakeCompute->getResult(0).replaceAllUsesWith(
- newLoop.getResults().take_back()[0]);
- toErase.push_back(loop);
- });
- for (auto loop : llvm::reverse(toErase))
- loop.erase();
+ if (testCloneWithNewYields) {
+ // For every op named "fake_read": treat the first user of its result as
+ // the compute op and the first user of that op's result as the write op.
+ func.walk([&](Operation *fakeRead) {
+ if (fakeRead->getName().getStringRef() != "fake_read")
+ return;
+ auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+ auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+ auto loop = fakeRead->getParentOfType<scf::ForOp>();
+
+ // Move the read out of the loop and the write after it, then clone the
+ // loop with the read result as an extra iter operand and the compute
+ // result as the extra yielded value.
+ OpBuilder b(loop);
+ loop.moveOutOfLoop(fakeRead);
+ fakeWrite->moveAfter(loop);
+ auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+ fakeCompute->getResult(0));
+ // Users of the compute result now read the last result of the new loop.
+ fakeCompute->getResult(0).replaceAllUsesWith(
+ newLoop.getResults().take_back()[0]);
+ // Erasure is deferred until the walk has finished.
+ toErase.push_back(loop);
+ });
+ for (auto loop : llvm::reverse(toErase))
+ loop.erase();
+ }
+
+ if (testReplaceWithNewYields) {
+ // For every scf.for that has results and init args, add one extra iter
+ // operand per existing init arg (reusing the same init values). The old
+ // loop is not erased by this path.
+ func.walk([&](scf::ForOp forOp) {
+ if (forOp.getNumResults() == 0)
+ return;
+ auto newInitValues = forOp.getInitArgs();
+ if (newInitValues.empty())
+ return;
+ // For each value yielded by the terminator of the block that owns
+ // `newBBArgs`, create an `arith.addf` of that value with itself and use
+ // it as the corresponding new yielded value.
+ NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) {
+ Block *block = newBBArgs.front().getOwner();
+ SmallVector<Value> newYieldValues;
+ for (auto yieldVal :
+ cast<scf::YieldOp>(block->getTerminator()).getResults()) {
+ newYieldValues.push_back(
+ b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
+ }
+ return newYieldValues;
+ };
+ OpBuilder b(forOp);
+ replaceLoopWithNewYields(b, forOp, newInitValues, fn);
+ });
+ }
}
};
"__test_pipelining_loop__";
static const StringLiteral kTestPipeliningStageMarker =
"__test_pipelining_stage__";
-/// Marker to express the order in which operations should be after pipelining.
+/// Marker to express the order in which operations should be after
+/// pipelining.
static const StringLiteral kTestPipeliningOpOrderMarker =
"__test_pipelining_op_order__";