[Coroutines] Store the index for final suspend point in the exception path
authorChuanqi Xu <yedeng.yd@linux.alibaba.com>
Tue, 20 Jun 2023 10:35:58 +0000 (18:35 +0800)
committerChuanqi Xu <yedeng.yd@linux.alibaba.com>
Tue, 20 Jun 2023 10:38:05 +0000 (18:38 +0800)
Try to address part of
https://github.com/llvm/llvm-project/issues/61900.

It is not completely addressed since the original reproducer is not
fixed due to the final suspend point is optimized out in its special
case. But that is a relatively independent issue.

llvm/lib/Transforms/Coroutines/CoroSplit.cpp
llvm/test/Transforms/Coroutines/coro-split-final-suspend.ll

index e3c4e94..ca25fd0 100644 (file)
@@ -300,6 +300,26 @@ static void markCoroutineAsDone(IRBuilder<> &Builder, const coro::Shape &Shape,
   auto *NullPtr = ConstantPointerNull::get(cast<PointerType>(
       Shape.FrameTy->getTypeAtIndex(coro::Shape::SwitchFieldIndex::Resume)));
   Builder.CreateStore(NullPtr, GepIndex);
+
+  // If the coroutine don't have unwind coro end, we could omit the store to
+  // the final suspend point since we could infer the coroutine is suspended
+  // at the final suspend point by the nullness of ResumeFnAddr.
+  // However, we can't skip it if the coroutine have unwind coro end. Since
+  // the coroutine reaches unwind coro end is considered suspended at the
+  // final suspend point (the ResumeFnAddr is null) but in fact the coroutine
+  // didn't complete yet. We need the IndexVal for the final suspend point
+  // to make the states clear.
+  if (Shape.SwitchLowering.HasUnwindCoroEnd &&
+      Shape.SwitchLowering.HasFinalSuspend) {
+    assert(cast<CoroSuspendInst>(Shape.CoroSuspends.back())->isFinal() &&
+           "The final suspend should only live in the last position of "
+           "CoroSuspends.");
+    ConstantInt *IndexVal = Shape.getIndex(Shape.CoroSuspends.size() - 1);
+    auto *FinalIndex = Builder.CreateStructGEP(
+        Shape.FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
+
+    Builder.CreateStore(IndexVal, FinalIndex);
+  }
 }
 
 /// Replace an unwind call to llvm.coro.end.
@@ -397,17 +417,7 @@ static void createResumeEntryBlock(Function &F, coro::Shape &Shape) {
       // The coroutine should be marked done if it reaches the final suspend
       // point.
       markCoroutineAsDone(Builder, Shape, FramePtr);
-    }
-
-    // If the coroutine don't have unwind coro end, we could omit the store to
-    // the final suspend point since we could infer the coroutine is suspended
-    // at the final suspend point by the nullness of ResumeFnAddr.
-    // However, we can't skip it if the coroutine have unwind coro end. Since
-    // the coroutine reaches unwind coro end is considered suspended at the
-    // final suspend point (the ResumeFnAddr is null) but in fact the coroutine
-    // didn't complete yet. We need the IndexVal for the final suspend point
-    // to make the states clear.
-    if (!S->isFinal() || Shape.SwitchLowering.HasUnwindCoroEnd) {
+    } else {
       auto *GepIndex = Builder.CreateStructGEP(
           FrameTy, FramePtr, Shape.getSwitchIndexField(), "index.addr");
       Builder.CreateStore(IndexVal, GepIndex);
index 73bd3c5..df99e89 100644 (file)
@@ -17,9 +17,10 @@ init_suspend:
   ret ptr %hdl
 
 init_resume:
-  br label %susp
+  invoke void @print(i32 1)
+          to label %final_suspend unwind label %lpad
 
-susp:
+final_suspend:
   %0 = call i8 @llvm.coro.suspend(token none, i1 true)
   switch i8 %0, label %suspend [i8 0, label %resume
                                 i8 1, label %suspend]
@@ -49,6 +50,19 @@ eh.resume:
 
 ; Tests that we need to store the final index if we see unwind coro end.
 ; CHECK: define{{.*}}@unwind_coro_end.resume
+; CHECK: invoke{{.*}}print
+; CHECK-NEXT: to label %[[RESUME:.*]] unwind label %[[LPAD:.*]]
+
+; CHECK: [[RESUME]]:
+; CHECK-NOT: {{.*:}}
+; CHECK: store ptr null, ptr %hdl
+; CHECK-NOT: {{.*:}}
+; CHECK: store i1 true, ptr %index.addr
+
+; CHECK: [[LPAD]]:
+; CHECK-NOT: {{.*:}}
+; CHECK: store ptr null, ptr %hdl
+; CHECK-NOT: {{.*:}}
 ; CHECK: store i1 true, ptr %index.addr
 
 ; Tests the use of final index in the destroy function.