[Coroutines] Also check lifetime intrinsic for local variable when build
authorJun Ma <JunMa@linux.alibaba.com>
Tue, 10 Mar 2020 10:32:55 +0000 (18:32 +0800)
committerJun Ma <JunMa@linux.alibaba.com>
Tue, 24 Mar 2020 05:41:55 +0000 (13:41 +0800)
coroutine frame

Currently we move all allocas into the frame when build coroutine frame in
CoroSplit pass. However, this can be relaxed.

Since CoroSplit pass run after Inline pass, we can use lifetime intrinsic to
do such analysis: If the scope of lifetime intrinsic is not across any suspend
point, rather than move the allocas to frame, we can just move them to entry bb
of corresponding function. This reduce the frame size.

More importantly, this also avoid data race in multithread environment.
Consider one inline function by coroutine: it starts a thread which access
local variables, while after inline the movement of allocs to frame also access
them. cause data race.

Differential Revision: https://reviews.llvm.org/D75664

llvm/lib/Transforms/Coroutines/CoroFrame.cpp
llvm/lib/Transforms/Coroutines/CoroSplit.cpp
llvm/test/Transforms/Coroutines/coro-split-02.ll

index 0016597..c85f480 100644 (file)
@@ -108,7 +108,6 @@ struct SuspendCrossingInfo {
     size_t const DefIndex = Mapping.blockToIndex(DefBB);
     size_t const UseIndex = Mapping.blockToIndex(UseBB);
 
-    assert(Block[UseIndex].Consumes[DefIndex] && "use must consume def");
     bool const Result = Block[UseIndex].Kills[DefIndex];
     LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName()
                       << " answer is " << Result << "\n");
@@ -1396,6 +1395,24 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
     Spills.clear();
   }
 
+  // Collect lifetime.start info for each alloca.
+  using LifetimeStart = SmallPtrSet<Instruction *, 2>;
+  llvm::DenseMap<Instruction *, std::unique_ptr<LifetimeStart>> LifetimeMap;
+  for (Instruction &I : instructions(F)) {
+    auto *II = dyn_cast<IntrinsicInst>(&I);
+    if (!II || II->getIntrinsicID() != Intrinsic::lifetime_start)
+      continue;
+
+    if (auto *OpInst = dyn_cast<BitCastInst>(I.getOperand(1)))
+      if (auto *AI = dyn_cast<AllocaInst>(OpInst->getOperand(0))) {
+
+        if (LifetimeMap.find(AI) == LifetimeMap.end())
+          LifetimeMap[AI] = std::make_unique<LifetimeStart>();
+
+        LifetimeMap[AI]->insert(OpInst);
+      }
+  }
+
   // Collect the spills for arguments and other not-materializable values.
   for (Argument &A : F.args())
     for (User *U : A.users())
@@ -1441,14 +1458,27 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
       continue;
     }
 
-    for (User *U : I.users())
-      if (Checker.isDefinitionAcrossSuspend(I, U)) {
+    auto Iter = LifetimeMap.find(&I);
+    for (User *U : I.users()) {
+      bool NeedSpill = false;
+
+      // Check against lifetime.start if the instruction has the info.
+      if (Iter != LifetimeMap.end())
+        for (auto *S : *Iter->second) {
+          if ((NeedSpill = Checker.isDefinitionAcrossSuspend(*S, U)))
+            break;
+        }
+      else
+        NeedSpill = Checker.isDefinitionAcrossSuspend(I, U);
+
+      if (NeedSpill) {
         // We cannot spill a token.
         if (I.getType()->isTokenTy())
           report_fatal_error(
               "token definition is separated from the use by a suspend point");
         Spills.emplace_back(&I, U);
       }
+    }
   }
   LLVM_DEBUG(dump("Spills", Spills));
   Shape.FrameTy = buildFrameType(F, Shape, Spills);
index 8aca27b..465b659 100644 (file)
@@ -567,8 +567,9 @@ void CoroCloner::replaceEntryBlock() {
   // branching to the original beginning of the coroutine.  Make this 
   // the entry block of the cloned function.
   auto *Entry = cast<BasicBlock>(VMap[Shape.AllocaSpillBlock]);
+  auto *OldEntry = &NewF->getEntryBlock();
   Entry->setName("entry" + Suffix);
-  Entry->moveBefore(&NewF->getEntryBlock());
+  Entry->moveBefore(OldEntry);
   Entry->getTerminator()->eraseFromParent();
 
   // Clear all predecessors of the new entry block.  There should be
@@ -581,8 +582,14 @@ void CoroCloner::replaceEntryBlock() {
   Builder.CreateUnreachable();
   BranchToEntry->eraseFromParent();
 
-  // TODO: move any allocas into Entry that weren't moved into the frame.
-  // (Currently we move all allocas into the frame.)
+  // Move any allocas into Entry that weren't moved into the frame.
+  for (auto IT = OldEntry->begin(), End = OldEntry->end(); IT != End;) {
+    Instruction &I = *IT++;
+    if (!isa<AllocaInst>(&I) || I.getNumUses() == 0)
+      continue;
+
+    I.moveBefore(*Entry, Entry->getFirstInsertionPt());
+  }
 
   // Branch from the entry to the appropriate place.
   Builder.SetInsertPoint(Entry);
index 0309c0d..9933742 100644 (file)
@@ -14,6 +14,7 @@ declare void @print(i32)
 define void @a() "coroutine.presplit"="1" {
 entry:
   %ref.tmp7 = alloca %"struct.lean_future<int>::Awaiter", align 8
+  %testval = alloca i32
   %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
   %alloc = call i8* @malloc(i64 16) #3
   %vFrame = call noalias nonnull i8* @llvm.coro.begin(token %id, i8* %alloc)
@@ -28,6 +29,9 @@ entry:
 await.ready:
   %StrayCoroSave = call token @llvm.coro.save(i8* null)
   %val = load i32, i32* %Result.i19
+  %cast = bitcast i32* %testval to i8*
+  call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast)
+  call void @llvm.lifetime.end.p0i8(i64 4, i8*  %cast)
   call void @print(i32 %val)
   br label %exit
 exit:
@@ -36,10 +40,14 @@ exit:
 }
 
 ; CHECK-LABEL: @a.resume(
+; CHECK:         %testval = alloca i32
 ; CHECK:         getelementptr inbounds %a.Frame
 ; CHECK-NEXT:    getelementptr inbounds %"struct.lean_future<int>::Awaiter"
 ; CHECK-NOT:     call token @llvm.coro.save(i8* null)
 ; CHECK-NEXT:    %val = load i32, i32* %Result
+; CHECK-NEXT:    %cast = bitcast i32* %testval to i8*
+; CHECK-NEXT:    call void @llvm.lifetime.start.p0i8(i64 4, i8* %cast)
+; CHECK-NEXT:    call void @llvm.lifetime.end.p0i8(i64 4, i8* %cast)
 ; CHECK-NEXT:    call void @print(i32 %val)
 ; CHECK-NEXT:    ret void
 
@@ -55,4 +63,6 @@ declare i8 @llvm.coro.suspend(token, i1) #3
 declare void @"\01??3@YAXPEAX@Z"(i8*) local_unnamed_addr #10
 declare i8* @llvm.coro.free(token, i8* nocapture readonly) #2
 declare i1 @llvm.coro.end(i8*, i1) #3
+declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #4
+declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #4