namespace llvm {
struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
- CoroSplitPass(bool OptimizeFrame = false) : OptimizeFrame(OptimizeFrame) {}
+ /// Predicate that decides whether an instruction may be recreated after a
+ /// suspend point instead of having its result spilled to the coroutine frame.
+ const std::function<bool(Instruction &)> MaterializableCallback;
+
+ /// Builds the pass with the built-in predicate (cast, GEP, binary operator,
+ /// compare and select instructions are treated as rematerializable).
+ CoroSplitPass(bool OptimizeFrame = false);
+ /// Builds the pass with a caller-supplied rematerialization predicate.
+ /// \param MaterializableCallback copied into the member of the same name.
+ /// \param OptimizeFrame stored in the OptimizeFrame member; defaults to false.
+ CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+ bool OptimizeFrame = false)
+ : MaterializableCallback(MaterializableCallback),
+ OptimizeFrame(OptimizeFrame) {}
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR);
LLVM_DEBUG(dump());
}
-static bool materializable(Instruction &V);
-
namespace {
// RematGraph is used to construct a DAG for rematerializable instructions
using RematNodeMap =
SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
RematNodeMap Remats;
+ const std::function<bool(Instruction &)> &MaterializableCallback;
SuspendCrossingInfo &Checker;
- RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) {
+ RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
+ Instruction *I, SuspendCrossingInfo &Checker)
+ : MaterializableCallback(MaterializableCallback), Checker(Checker) {
std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
EntryNode = FirstNode.get();
std::deque<std::unique_ptr<RematNode>> WorkList;
Remats[N->Node] = std::move(NUPtr);
for (auto &Def : N->Node->operands()) {
Instruction *D = dyn_cast<Instruction>(Def.get());
- if (!D || !materializable(*D) ||
+ if (!D || !MaterializableCallback(*D) ||
!Checker.isDefinitionAcrossSuspend(*D, FirstUse))
continue;
rewritePHIs(*BB);
}
+/// Built-in rematerialization predicate, used when no custom callback is given.
// Check for instructions that we can recreate on resume as opposed to spill
// the result into a coroutine frame.
-static bool materializable(Instruction &V) {
- return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
- isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
+bool coro::defaultMaterializable(Instruction &V) {
+  // Variadic isa<> is true when V is any one of the listed instruction kinds.
+  return isa<CastInst, GetElementPtrInst, BinaryOperator, CmpInst, SelectInst>(
+      &V);
}
// Check for structural coroutine intrinsics that should not be spilled into
}
}
-static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
+static void doRematerializations(
+ Function &F, SuspendCrossingInfo &Checker,
+ const std::function<bool(Instruction &)> &MaterializableCallback) {
SpillInfo Spills;
// See if there are materializable instructions across suspend points
// We record these as the starting point to also identify materializable
// defs of uses in these operations
for (Instruction &I : instructions(F)) {
- if (!materializable(I))
+ if (!MaterializableCallback(I))
continue;
for (User *U : I.users())
if (Checker.isDefinitionAcrossSuspend(I, U))
continue;
// Constructor creates the whole RematGraph for the given Use
- auto RematUPtr = std::make_unique<RematGraph>(U, Checker);
+ auto RematUPtr =
+ std::make_unique<RematGraph>(MaterializableCallback, U, Checker);
LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
rewriteMaterializableInstructions(AllRemats);
}
-void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
+void coro::buildCoroutineFrame(
+ Function &F, Shape &Shape,
+ const std::function<bool(Instruction &)> &MaterializableCallback) {
// Don't eliminate swifterror in async functions that won't be split.
if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
eliminateSwiftError(F, Shape);
// Build suspend crossing info.
SuspendCrossingInfo Checker(F, Shape);
- doRematerializations(F, Checker);
+ doRematerializations(F, Checker, MaterializableCallback);
FrameDataInfo FrameData;
SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
void buildFrom(Function &F);
};
-void buildCoroutineFrame(Function &F, Shape &Shape);
+/// Built-in rematerialization predicate: true for cast, GEP, binary operator,
+/// compare and select instructions.
+bool defaultMaterializable(Instruction &V);
+/// \param MaterializableCallback predicate handed on to the rematerialization
+/// step; instructions it accepts may be recreated after a suspend point
+/// instead of being spilled to the coroutine frame.
+void buildCoroutineFrame(
+ Function &F, Shape &Shape,
+ const std::function<bool(Instruction &)> &MaterializableCallback);
CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
ArrayRef<Value *> Arguments, IRBuilder<> &);
} // End namespace coro.
};
}
-static coro::Shape splitCoroutine(Function &F,
- SmallVectorImpl<Function *> &Clones,
- TargetTransformInfo &TTI,
- bool OptimizeFrame) {
+static coro::Shape
+splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
+ TargetTransformInfo &TTI, bool OptimizeFrame,
+ std::function<bool(Instruction &)> MaterializableCallback) {
PrettyStackTraceFunction prettyStackTrace(F);
// The suspend-crossing algorithm in buildCoroutineFrame get tripped
return Shape;
simplifySuspendPoints(Shape);
- buildCoroutineFrame(F, Shape);
+ buildCoroutineFrame(F, Shape, MaterializableCallback);
replaceFrameSizeAndAlignment(Shape);
// If there are no suspend points, no split required, just remove
Fns.push_back(PrepareFn);
}
+/// Default configuration: rematerialization decisions are delegated to
+/// coro::defaultMaterializable.
+CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
+    : MaterializableCallback(&coro::defaultMaterializable),
+      OptimizeFrame(OptimizeFrame) {}
+
PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
CGSCCAnalysisManager &AM,
LazyCallGraph &CG, CGSCCUpdateResult &UR) {
F.setSplittedCoroutine();
SmallVector<Function *, 4> Clones;
- const coro::Shape Shape = splitCoroutine(
- F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame);
+ const coro::Shape Shape =
+ splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
+ OptimizeFrame, MaterializableCallback);
updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);
if (!Shape.CoroSuspends.empty()) {
+add_subdirectory(Coroutines)
add_subdirectory(IPO)
add_subdirectory(Scalar)
add_subdirectory(Utils)
--- /dev/null
+# LLVM components the coroutine unit tests are linked against.
+set(LLVM_LINK_COMPONENTS
+ Analysis
+ AsmParser
+ Core
+ Coroutines
+ Passes
+ Support
+ TargetParser
+ TransformUtils
+ )
+
+# Unit-test executable built from the coroutine test sources.
+add_llvm_unittest(CoroTests
+ ExtraRematTest.cpp
+ )
+
+# The test includes llvm/Testing/Support/Error.h, so link the testing support
+# library as well.
+target_link_libraries(CoroTests PRIVATE LLVMTestingSupport)
+
+# FOLDER only affects how the target is grouped in IDE project views.
+set_property(TARGET CoroTests PROPERTY FOLDER "Tests/UnitTests/TransformTests")
--- /dev/null
+//===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/Error.h"
+#include "llvm/Transforms/Coroutines/CoroSplit.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+/// Fixture for the coroutine rematerialization tests. It owns the module under
+/// test together with the pass and analysis managers needed to run a
+/// module-level pipeline over it.
+struct ExtraRematTest : public testing::Test {
+  ModulePassManager MPM;
+  PassBuilder PB;
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+
+  /// Registers the analyses with each manager and cross-registers the proxies
+  /// so passes at one level can query analyses at another.
+  ExtraRematTest() {
+    PB.registerModuleAnalyses(MAM);
+    PB.registerCGSCCAnalyses(CGAM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerLoopAnalyses(LAM);
+    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+  }
+
+  /// Returns the first basic block of \p F named \p Name, or nullptr if there
+  /// is none.
+  BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
+    for (BasicBlock &BB : *F) {
+      if (BB.getName() == Name)
+        return &BB;
+    }
+    return nullptr;
+  }
+
+  /// Returns the first call in \p BB whose direct callee is named \p Name, or
+  /// nullptr if there is none.
+  CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
+    for (Instruction &I : *BB) {
+      auto *CI = dyn_cast<CallInst>(&I);
+      if (!CI)
+        continue;
+      // getCalledFunction() is null for indirect calls; such calls have no
+      // name to compare against, so skip them instead of dereferencing null.
+      const Function *Callee = CI->getCalledFunction();
+      if (Callee && Callee->getName() == Name)
+        return CI;
+    }
+    return nullptr;
+  }
+
+  /// Parses \p IR into M using the fixture's context. Aborts with the parser
+  /// diagnostic when the text does not parse.
+  void ParseAssembly(const StringRef IR) {
+    SMDiagnostic Error;
+    M = parseAssemblyString(IR, Error, Context);
+    if (M)
+      return;
+
+    // A failure here means that the test itself is buggy. The diagnostic is
+    // only rendered on this path since it is not needed on success.
+    std::string errMsg;
+    raw_string_ostream os(errMsg);
+    Error.print("", os);
+    report_fatal_error(os.str().c_str());
+  }
+};
+
+// Test input: coroutine @f with two suspend points. %val2, the result of
+// @should.remat, is defined before the first suspend and used after each of
+// them, so it must be either spilled or rematerialized. The switch following
+// each suspend dispatches on that suspend's own result (%sp1, then %sp2).
+StringRef Text = R"(
+ define ptr @f(i32 %n) presplitcoroutine {
+ entry:
+ %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+ %size = call i32 @llvm.coro.size.i32()
+ %alloc = call ptr @malloc(i32 %size)
+ %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
+
+ %inc1 = add i32 %n, 1
+ %val2 = call i32 @should.remat(i32 %inc1)
+ %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+ switch i8 %sp1, label %suspend [i8 0, label %resume1
+ i8 1, label %cleanup]
+ resume1:
+ %inc2 = add i32 %val2, 1
+ %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+ switch i8 %sp2, label %suspend [i8 0, label %resume2
+ i8 1, label %cleanup]
+
+ resume2:
+ call void @print(i32 %val2)
+ call void @print(i32 %inc2)
+ br label %cleanup
+
+ cleanup:
+ %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+ call void @free(ptr %mem)
+ br label %suspend
+ suspend:
+ call i1 @llvm.coro.end(ptr %hdl, i1 0)
+ ret ptr %hdl
+ }
+
+ declare ptr @llvm.coro.free(token, ptr)
+ declare i32 @llvm.coro.size.i32()
+ declare i8 @llvm.coro.suspend(token, i1)
+ declare void @llvm.coro.resume(ptr)
+ declare void @llvm.coro.destroy(ptr)
+
+ declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+ declare i1 @llvm.coro.alloc(token)
+ declare ptr @llvm.coro.begin(token, ptr)
+ declare i1 @llvm.coro.end(ptr, i1)
+
+ declare i32 @should.remat(i32)
+
+ declare noalias ptr @malloc(i32)
+ declare void @print(i32)
+ declare void @free(ptr)
+ )";
+
+// Rematerialization predicate used by the callback test: accepts the same
+// instruction kinds as the default predicate and, in addition, direct calls to
+// functions whose name starts with "should.remat".
+bool ExtraMaterializable(Instruction &I) {
+  if (isa<CastInst, GetElementPtrInst, BinaryOperator, CmpInst, SelectInst>(&I))
+    return true;
+
+  const auto *Call = dyn_cast<CallInst>(&I);
+  if (!Call)
+    return false;
+
+  // Indirect calls have no callee function and are never accepted.
+  const Function *Callee = Call->getCalledFunction();
+  return Callee && Callee->getName().startswith("should.remat");
+}
+
+// Baseline: with the default predicate the call to @should.remat must not be
+// recreated in the resume function.
+TEST_F(ExtraRematTest, TestCoroRematDefault) {
+  ParseAssembly(Text);
+
+  ASSERT_TRUE(M);
+
+  // Run CoroSplit with its default configuration over the whole module.
+  CGSCCPassManager SplitPipeline;
+  SplitPipeline.addPass(CoroSplitPass());
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(SplitPipeline)));
+  MPM.run(*M, MAM);
+
+  // The split must have produced the resume clone.
+  Function *ResumeFn = M->getFunction("f.resume");
+  ASSERT_TRUE(ResumeFn) << "could not find split function f.resume";
+
+  BasicBlock *ResumeBB = getBasicBlockByName(ResumeFn, "resume1");
+  ASSERT_TRUE(ResumeBB)
+      << "could not find expected BB resume1 in split function";
+
+  // The default predicate does not accept calls, so no call to @should.remat
+  // may appear in the resume block.
+  CallInst *RematCall = getCallByName(ResumeBB, "should.remat");
+  ASSERT_FALSE(RematCall);
+}
+
+// With ExtraMaterializable installed, the call to @should.remat must be
+// recreated in the resume function.
+TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
+  ParseAssembly(Text);
+
+  ASSERT_TRUE(M);
+
+  // Wrap the predicate in std::function explicitly: a bare function name would
+  // convert to bool and select the CoroSplitPass(bool) overload instead.
+  std::function<bool(Instruction &)> Callback = ExtraMaterializable;
+
+  CGSCCPassManager SplitPipeline;
+  SplitPipeline.addPass(CoroSplitPass(Callback));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(SplitPipeline)));
+  MPM.run(*M, MAM);
+
+  // The split must have produced the resume clone.
+  Function *ResumeFn = M->getFunction("f.resume");
+  ASSERT_TRUE(ResumeFn) << "could not find split function f.resume";
+
+  BasicBlock *ResumeBB = getBasicBlockByName(ResumeFn, "resume1");
+  ASSERT_TRUE(ResumeBB)
+      << "could not find expected BB resume1 in split function";
+
+  // The custom predicate accepts the call, so it must have been recreated in
+  // the resume block.
+  CallInst *RematCall = getCallByName(ResumeBB, "should.remat");
+  ASSERT_TRUE(RematCall);
+}
+} // namespace