[Coroutines] Modify CoroFrame materializable into a callback
authorDavid Stuttard <david.stuttard@amd.com>
Thu, 19 Jan 2023 13:55:49 +0000 (13:55 +0000)
committerDavid Stuttard <david.stuttard@amd.com>
Mon, 13 Feb 2023 11:02:25 +0000 (11:02 +0000)
This change makes it possible to optionally provide a different callback to
determine if an instruction is materializable.

By default the behaviour is unchanged.

Differential Revision: https://reviews.llvm.org/D142621

llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
llvm/lib/Transforms/Coroutines/CoroFrame.cpp
llvm/lib/Transforms/Coroutines/CoroInternal.h
llvm/lib/Transforms/Coroutines/CoroSplit.cpp
llvm/unittests/Transforms/CMakeLists.txt
llvm/unittests/Transforms/Coroutines/CMakeLists.txt [new file with mode: 0644]
llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp [new file with mode: 0644]

index 7623c9c..a2be109 100644 (file)
 namespace llvm {
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
-  CoroSplitPass(bool OptimizeFrame = false) : OptimizeFrame(OptimizeFrame) {}
+  const std::function<bool(Instruction &)> MaterializableCallback;
+
+  CoroSplitPass(bool OptimizeFrame = false);
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                bool OptimizeFrame = false)
+      : MaterializableCallback(MaterializableCallback),
+        OptimizeFrame(OptimizeFrame) {}
 
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
index dc14f68..9f9b45c 100644 (file)
@@ -318,8 +318,6 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
   LLVM_DEBUG(dump());
 }
 
-static bool materializable(Instruction &V);
-
 namespace {
 
 // RematGraph is used to construct a DAG for rematerializable instructions
@@ -342,9 +340,12 @@ struct RematGraph {
   using RematNodeMap =
       SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
   RematNodeMap Remats;
+  const std::function<bool(Instruction &)> &MaterializableCallback;
   SuspendCrossingInfo &Checker;
 
-  RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) {
+  RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
+             Instruction *I, SuspendCrossingInfo &Checker)
+      : MaterializableCallback(MaterializableCallback), Checker(Checker) {
     std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
     EntryNode = FirstNode.get();
     std::deque<std::unique_ptr<RematNode>> WorkList;
@@ -367,7 +368,7 @@ struct RematGraph {
     Remats[N->Node] = std::move(NUPtr);
     for (auto &Def : N->Node->operands()) {
       Instruction *D = dyn_cast<Instruction>(Def.get());
-      if (!D || !materializable(*D) ||
+      if (!D || !MaterializableCallback(*D) ||
           !Checker.isDefinitionAcrossSuspend(*D, FirstUse))
         continue;
 
@@ -2211,11 +2212,12 @@ static void rewritePHIs(Function &F) {
     rewritePHIs(*BB);
 }
 
+/// Default materializable callback
 // Check for instructions that we can recreate on resume as opposed to spill
 // the result into a coroutine frame.
-static bool materializable(Instruction &V) {
-  return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
-         isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
+bool coro::defaultMaterializable(Instruction &V) {
+  return (isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
+          isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V));
 }
 
 // Check for structural coroutine intrinsics that should not be spilled into
@@ -2887,14 +2889,16 @@ void coro::salvageDebugInfo(
   }
 }
 
-static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
+static void doRematerializations(
+    Function &F, SuspendCrossingInfo &Checker,
+    const std::function<bool(Instruction &)> &MaterializableCallback) {
   SpillInfo Spills;
 
   // See if there are materializable instructions across suspend points
   // We record these as the starting point to also identify materializable
   // defs of uses in these operations
   for (Instruction &I : instructions(F)) {
-    if (!materializable(I))
+    if (!MaterializableCallback(I))
       continue;
     for (User *U : I.users())
       if (Checker.isDefinitionAcrossSuspend(I, U))
@@ -2925,7 +2929,8 @@ static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
         continue;
 
       // Constructor creates the whole RematGraph for the given Use
-      auto RematUPtr = std::make_unique<RematGraph>(U, Checker);
+      auto RematUPtr =
+          std::make_unique<RematGraph>(MaterializableCallback, U, Checker);
 
       LLVM_DEBUG(dbgs() << "***** Next remat group *****\n";
                  ReversePostOrderTraversal<RematGraph *> RPOT(RematUPtr.get());
@@ -2943,7 +2948,9 @@ static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
   rewriteMaterializableInstructions(AllRemats);
 }
 
-void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
+void coro::buildCoroutineFrame(
+    Function &F, Shape &Shape,
+    const std::function<bool(Instruction &)> &MaterializableCallback) {
   // Don't eliminate swifterror in async functions that won't be split.
   if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
     eliminateSwiftError(F, Shape);
@@ -2994,7 +3001,7 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
   // Build suspend crossing info.
   SuspendCrossingInfo Checker(F, Shape);
 
-  doRematerializations(F, Checker);
+  doRematerializations(F, Checker, MaterializableCallback);
 
   FrameDataInfo FrameData;
   SmallVector<CoroAllocaAllocInst*, 4> LocalAllocas;
index 032361c..13bdaaf 100644 (file)
@@ -261,7 +261,10 @@ struct LLVM_LIBRARY_VISIBILITY Shape {
   void buildFrom(Function &F);
 };
 
-void buildCoroutineFrame(Function &F, Shape &Shape);
+bool defaultMaterializable(Instruction &V);
+void buildCoroutineFrame(
+    Function &F, Shape &Shape,
+    const std::function<bool(Instruction &)> &MaterializableCallback);
 CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
                              ArrayRef<Value *> Arguments, IRBuilder<> &);
 } // End namespace coro.
index 1171878..fd33a46 100644 (file)
@@ -1929,10 +1929,10 @@ namespace {
   };
 }
 
-static coro::Shape splitCoroutine(Function &F,
-                                  SmallVectorImpl<Function *> &Clones,
-                                  TargetTransformInfo &TTI,
-                                  bool OptimizeFrame) {
+static coro::Shape
+splitCoroutine(Function &F, SmallVectorImpl<Function *> &Clones,
+               TargetTransformInfo &TTI, bool OptimizeFrame,
+               std::function<bool(Instruction &)> MaterializableCallback) {
   PrettyStackTraceFunction prettyStackTrace(F);
 
   // The suspend-crossing algorithm in buildCoroutineFrame get tripped
@@ -1944,7 +1944,7 @@ static coro::Shape splitCoroutine(Function &F,
     return Shape;
 
   simplifySuspendPoints(Shape);
-  buildCoroutineFrame(F, Shape);
+  buildCoroutineFrame(F, Shape, MaterializableCallback);
   replaceFrameSizeAndAlignment(Shape);
 
   // If there are no suspend points, no split required, just remove
@@ -2104,6 +2104,10 @@ static void addPrepareFunction(const Module &M,
     Fns.push_back(PrepareFn);
 }
 
+CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
+    : MaterializableCallback(coro::defaultMaterializable),
+      OptimizeFrame(OptimizeFrame) {}
+
 PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
                                      CGSCCAnalysisManager &AM,
                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
@@ -2142,8 +2146,9 @@ PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
     F.setSplittedCoroutine();
 
     SmallVector<Function *, 4> Clones;
-    const coro::Shape Shape = splitCoroutine(
-        F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame);
+    const coro::Shape Shape =
+        splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
+                       OptimizeFrame, MaterializableCallback);
     updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);
 
     if (!Shape.CoroSuspends.empty()) {
index b7f1817..98c821a 100644 (file)
@@ -1,3 +1,4 @@
+add_subdirectory(Coroutines)
 add_subdirectory(IPO)
 add_subdirectory(Scalar)
 add_subdirectory(Utils)
diff --git a/llvm/unittests/Transforms/Coroutines/CMakeLists.txt b/llvm/unittests/Transforms/Coroutines/CMakeLists.txt
new file mode 100644 (file)
index 0000000..0913e82
--- /dev/null
@@ -0,0 +1,18 @@
+set(LLVM_LINK_COMPONENTS
+  Analysis
+  AsmParser
+  Core
+  Coroutines
+  Passes
+  Support
+  TargetParser
+  TransformUtils
+  )
+
+add_llvm_unittest(CoroTests
+  ExtraRematTest.cpp
+  )
+
+target_link_libraries(CoroTests PRIVATE LLVMTestingSupport)
+
+set_property(TARGET CoroTests PROPERTY FOLDER "Tests/UnitTests/TransformTests")
diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp
new file mode 100644 (file)
index 0000000..df4d6c3
--- /dev/null
@@ -0,0 +1,184 @@
+//===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Passes/PassBuilder.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/Error.h"
+#include "llvm/Transforms/Coroutines/CoroSplit.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+struct ExtraRematTest : public testing::Test {
+  LLVMContext Ctx;
+  ModulePassManager MPM;
+  PassBuilder PB;
+  LoopAnalysisManager LAM;
+  FunctionAnalysisManager FAM;
+  CGSCCAnalysisManager CGAM;
+  ModuleAnalysisManager MAM;
+  LLVMContext Context;
+  std::unique_ptr<Module> M;
+
+  ExtraRematTest() {
+    PB.registerModuleAnalyses(MAM);
+    PB.registerCGSCCAnalyses(CGAM);
+    PB.registerFunctionAnalyses(FAM);
+    PB.registerLoopAnalyses(LAM);
+    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
+  }
+
+  BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
+    for (BasicBlock &BB : *F) {
+      if (BB.getName() == Name)
+        return &BB;
+    }
+    return nullptr;
+  }
+
+  CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
+    for (Instruction &I : *BB) {
+      if (CallInst *CI = dyn_cast<CallInst>(&I))
+        if (CI->getCalledFunction()->getName() == Name)
+          return CI;
+    }
+    return nullptr;
+  }
+
+  void ParseAssembly(const StringRef IR) {
+    SMDiagnostic Error;
+    M = parseAssemblyString(IR, Error, Context);
+    std::string errMsg;
+    raw_string_ostream os(errMsg);
+    Error.print("", os);
+
+    // A failure here means that the test itself is buggy.
+    if (!M)
+      report_fatal_error(os.str().c_str());
+  }
+};
+
+StringRef Text = R"(
+    define ptr @f(i32 %n) presplitcoroutine {
+    entry:
+      %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+      %size = call i32 @llvm.coro.size.i32()
+      %alloc = call ptr @malloc(i32 %size)
+      %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
+
+      %inc1 = add i32 %n, 1
+      %val2 = call i32 @should.remat(i32 %inc1)
+      %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume1
+                                      i8 1, label %cleanup]
+    resume1:
+      %inc2 = add i32 %val2, 1
+      %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
+      switch i8 %sp1, label %suspend [i8 0, label %resume2
+                                      i8 1, label %cleanup]
+
+    resume2:
+      call void @print(i32 %val2)
+      call void @print(i32 %inc2)
+      br label %cleanup
+
+    cleanup:
+      %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
+      call void @free(ptr %mem)
+      br label %suspend
+    suspend:
+      call i1 @llvm.coro.end(ptr %hdl, i1 0)
+      ret ptr %hdl
+    }
+
+    declare ptr @llvm.coro.free(token, ptr)
+    declare i32 @llvm.coro.size.i32()
+    declare i8  @llvm.coro.suspend(token, i1)
+    declare void @llvm.coro.resume(ptr)
+    declare void @llvm.coro.destroy(ptr)
+
+    declare token @llvm.coro.id(i32, ptr, ptr, ptr)
+    declare i1 @llvm.coro.alloc(token)
+    declare ptr @llvm.coro.begin(token, ptr)
+    declare i1 @llvm.coro.end(ptr, i1)
+
+    declare i32 @should.remat(i32)
+
+    declare noalias ptr @malloc(i32)
+    declare void @print(i32)
+    declare void @free(ptr)
+  )";
+
+// Materializable callback with extra rematerialization
+bool ExtraMaterializable(Instruction &I) {
+  if (isa<CastInst>(&I) || isa<GetElementPtrInst>(&I) ||
+      isa<BinaryOperator>(&I) || isa<CmpInst>(&I) || isa<SelectInst>(&I))
+    return true;
+
+  if (auto *CI = dyn_cast<CallInst>(&I)) {
+    auto *CalledFunc = CI->getCalledFunction();
+    if (CalledFunc && CalledFunc->getName().startswith("should.remat"))
+      return true;
+  }
+
+  return false;
+}
+
+TEST_F(ExtraRematTest, TestCoroRematDefault) {
+  ParseAssembly(Text);
+
+  ASSERT_TRUE(M);
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(CoroSplitPass());
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that extra rematerializable instruction has been rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // With default materialization the intrinsic should not have been
+  // rematerialized
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_FALSE(CI);
+}
+
+TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
+  ParseAssembly(Text);
+
+  ASSERT_TRUE(M);
+
+  CGSCCPassManager CGPM;
+  CGPM.addPass(
+      CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable)));
+  MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM)));
+  MPM.run(*M, MAM);
+
+  // Verify that extra rematerializable instruction has been rematerialized
+  Function *F = M->getFunction("f.resume");
+  ASSERT_TRUE(F) << "could not find split function f.resume";
+
+  BasicBlock *Resume1 = getBasicBlockByName(F, "resume1");
+  ASSERT_TRUE(Resume1)
+      << "could not find expected BB resume1 in split function";
+
+  // With callback the extra rematerialization of the function should have
+  // happened
+  CallInst *CI = getCallByName(Resume1, "should.remat");
+  ASSERT_TRUE(CI);
+}
+} // namespace