From 567fd523bf538523f58779e5af9d20c3e48838a2 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Fri, 6 May 2022 21:44:26 +0000 Subject: [PATCH] [mlir][SCF] Add utility method to add new yield values to a loop. The current implementation of `cloneWithNewYields` has a few issues: - It clones the loop body of the original loop to create a new loop. This is very expensive. - It performs `erase` operations, which are incompatible with calling this method from within a pattern rewrite. All erases need to go through `PatternRewriter`. To address these, a new utility method `replaceLoopWithNewYields` is added, which - moves the operations from the original loop into the new loop. - replaces all uses of the original loop with the corresponding results of the new loop. - uses a callback to allow the caller to generate the new yield values. - modifies the original loop to just yield the basic block arguments corresponding to the iter_args of the loop. This represents a no-op loop. The loop itself is dead (since all its uses are replaced), but is not removed. The caller is expected to erase the op. Consequently, this method can be called from within a `matchAndRewrite` method of a `PatternRewriter`. The `cloneWithNewYields` could be replaced with `replaceLoopWithNewYields`, but that seems to trigger a failure during walks, potentially due to the operations being moved. That is left as a TODO. 
Differential Revision: https://reviews.llvm.org/D125147 --- mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 22 +++++++ mlir/lib/Dialect/SCF/Utils/Utils.cpp | 64 ++++++++++++++++++ mlir/test/Transforms/scf-loop-utils.mlir | 2 +- .../Transforms/scf-replace-with-new-yields.mlir | 21 ++++++ mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 77 ++++++++++++++++------ 5 files changed, 166 insertions(+), 20 deletions(-) create mode 100644 mlir/test/Transforms/scf-replace-with-new-yields.mlir diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index bb5a484..70fc084 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -61,6 +61,28 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop, ValueRange newYieldedValues, bool replaceLoopResults = true); +/// Replace the `loop` with `newIterOperands` added as new initialization +/// values. `newYieldValuesFn` is a callback that can be used to specify +/// the additional values to be yielded by the loop. The number of +/// values returned by the callback should match the number of new +/// initialization values. This function +/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly created +/// loop +/// - Replaces the uses of `loop` with the new loop. +/// - `loop` isn't erased, but is left in a "no-op" state where the body of the +/// loop just yields the basic block arguments that correspond to the +/// initialization values of a loop. The loop is dead after this method. +/// - All uses of the `newIterOperands` within the generated new loop +/// are replaced with the corresponding `BlockArgument` in the loop body. +/// TODO: This method could be used instead of `cloneWithNewYields`. Making +/// this change though hits assertions in the walk mechanism that are unrelated +/// to this method itself. 
+using NewYieldValueFn = std::function( + OpBuilder &b, Location loc, ArrayRef newBBArgs)>; +scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, + ValueRange newIterOperands, + NewYieldValueFn newYieldValuesFn); + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 5d69a3f..a2a751a 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -91,6 +91,70 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop, return newLoop; } +scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, + ValueRange newIterOperands, + NewYieldValueFn newYieldValuesFn) { + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(loop); + auto operands = llvm::to_vector(loop.getIterOperands()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + scf::ForOp newLoop = builder.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), + operands, [](OpBuilder &, Location, Value, ValueRange) {}); + + Block *loopBody = loop.getBody(); + Block *newLoopBody = newLoop.getBody(); + + // Move the body of the original loop to the new loop. + newLoopBody->getOperations().splice(newLoopBody->end(), + loopBody->getOperations()); + + // Generate the new yield values to use by using the callback and append the + // yield values to the scf.yield operation. 
+ auto yield = cast(newLoopBody->getTerminator()); + ArrayRef newBBArgs = + newLoopBody->getArguments().take_back(newIterOperands.size()); + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(yield); + SmallVector newYieldedValues = + newYieldValuesFn(builder, loop.getLoc(), newBBArgs); + assert(newIterOperands.size() == newYieldedValues.size() && + "expected as many new yield values as new iter operands"); + yield.getResultsMutable().append(newYieldedValues); + } + + // Remap the BlockArguments from the original loop to the new loop + // BlockArguments. + ArrayRef bbArgs = loopBody->getArguments(); + for (auto it : + llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + + // Replace all uses of `newIterOperands` with the corresponding basic block + // arguments. + for (auto it : llvm::zip(newIterOperands, newBBArgs)) { + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + + // Replace all uses of the original loop with corresponding values from the + // new loop. + loop.replaceAllUsesWith( + newLoop.getResults().take_front(loop.getNumResults())); + + // Add a fake yield to the original loop body that just returns the + // BlockArguments corresponding to the iter_args. This makes it a no-op loop. + // The loop is dead. The caller is expected to erase it. + builder.setInsertionPointToEnd(loopBody); + builder.create(loop->getLoc(), loop.getRegionIterArgs()); + + return newLoop; +} + /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the /// single block. This constraint makes it easy to determine the result. 
diff --git a/mlir/test/Transforms/scf-loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir index 32e42b9..3e03d3a 100644 --- a/mlir/test/Transforms/scf-loop-utils.mlir +++ b/mlir/test/Transforms/scf-loop-utils.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-clone-with-new-yields -mlir-disable-threading %s | FileCheck %s // CHECK-LABEL: @hoist // CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index, diff --git a/mlir/test/Transforms/scf-replace-with-new-yields.mlir b/mlir/test/Transforms/scf-replace-with-new-yields.mlir new file mode 100644 index 0000000..802f86e --- /dev/null +++ b/mlir/test/Transforms/scf-replace-with-new-yields.mlir @@ -0,0 +1,21 @@ + +// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-replace-with-new-yields -mlir-disable-threading %s | FileCheck %s + +func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 { + %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) { + %1 = arith.addf %iter, %iter : f32 + scf.yield %1: f32 + } + return %0: f32 +} +// CHECK-LABEL: func @doubleup +// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]: f32 +// CHECK: %[[NEWLOOP:.+]]:2 = scf.for +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[ARG]], %[[INIT2:.+]] = %[[ARG]] +// CHECK: %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]] +// CHECK: %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]] +// CHECK: scf.yield %[[DOUBLE]], %[[DOUBLE2]] +// CHECK: %[[OLDLOOP:.+]] = scf.for +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ARG]]) +// CHECK: scf.yield %[[INIT]] +// CHECK: return %[[NEWLOOP]]#0 diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index f354a30..858b0b1 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -32,29 +32,67 @@ struct TestSCFForUtilsPass StringRef 
getArgument() const final { return "test-scf-for-utils"; } StringRef getDescription() const final { return "test scf.for utils"; } explicit TestSCFForUtilsPass() = default; + TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} + + Option testCloneWithNewYields{ + *this, "test-clone-with-new-yields", + llvm::cl::desc( + "Test cloning of a loop while returning additional yield values"), + llvm::cl::init(false)}; + + Option testReplaceWithNewYields{ + *this, "test-replace-with-new-yields", + llvm::cl::desc("Test replacing a loop with a new loop that returns new " + "additional yeild values"), + llvm::cl::init(false)}; void runOnOperation() override { func::FuncOp func = getOperation(); SmallVector toErase; - func.walk([&](Operation *fakeRead) { - if (fakeRead->getName().getStringRef() != "fake_read") - return; - auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner(); - auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner(); - auto loop = fakeRead->getParentOfType(); - - OpBuilder b(loop); - loop.moveOutOfLoop(fakeRead); - fakeWrite->moveAfter(loop); - auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), - fakeCompute->getResult(0)); - fakeCompute->getResult(0).replaceAllUsesWith( - newLoop.getResults().take_back()[0]); - toErase.push_back(loop); - }); - for (auto loop : llvm::reverse(toErase)) - loop.erase(); + if (testCloneWithNewYields) { + func.walk([&](Operation *fakeRead) { + if (fakeRead->getName().getStringRef() != "fake_read") + return; + auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner(); + auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner(); + auto loop = fakeRead->getParentOfType(); + + OpBuilder b(loop); + loop.moveOutOfLoop(fakeRead); + fakeWrite->moveAfter(loop); + auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), + fakeCompute->getResult(0)); + fakeCompute->getResult(0).replaceAllUsesWith( + newLoop.getResults().take_back()[0]); + 
toErase.push_back(loop); + }); + for (auto loop : llvm::reverse(toErase)) + loop.erase(); + } + + if (testReplaceWithNewYields) { + func.walk([&](scf::ForOp forOp) { + if (forOp.getNumResults() == 0) + return; + auto newInitValues = forOp.getInitArgs(); + if (newInitValues.empty()) + return; + NewYieldValueFn fn = [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + Block *block = newBBArgs.front().getOwner(); + SmallVector newYieldValues; + for (auto yieldVal : + cast(block->getTerminator()).getResults()) { + newYieldValues.push_back( + b.create(loc, yieldVal, yieldVal)); + } + return newYieldValues; + }; + OpBuilder b(forOp); + replaceLoopWithNewYields(b, forOp, newInitValues, fn); + }); + } } }; @@ -88,7 +126,8 @@ static const StringLiteral kTestPipeliningLoopMarker = "__test_pipelining_loop__"; static const StringLiteral kTestPipeliningStageMarker = "__test_pipelining_stage__"; -/// Marker to express the order in which operations should be after pipelining. +/// Marker to express the order in which operations should be after +/// pipelining. static const StringLiteral kTestPipeliningOpOrderMarker = "__test_pipelining_op_order__"; -- 2.7.4