CoroFrame: Put escaped variables with multiple lifetimes on coroutine frame
authorMatthias Braun <matze@braunis.de>
Fri, 16 Dec 2022 15:21:41 +0000 (07:21 -0800)
committerMatthias Braun <matze@braunis.de>
Wed, 4 Jan 2023 15:30:08 +0000 (07:30 -0800)
The llvm.lifetime.start intrinsic guarantees that the address for a
given alloca is always the same. So variables with escaped addresses
reaching reaching a lifetime start/end block before and after a suspend
must be placed onto the coroutine frame even if the variable itself
is not alive across the suspend point.

This computes a new `LoopKill` flag in the suspend crossing data flow
anaysis to catch the case where a lifetime marker can reach itself
via suspend-crossing path.

This fixes https://llvm.org/PR52501

Differential Revision: https://reviews.llvm.org/D140231

llvm/lib/Transforms/Coroutines/CoroFrame.cpp
llvm/test/Transforms/Coroutines/coro-alloca-loop-carried-address.ll [new file with mode: 0644]

index aa32f51..40bf160 100644 (file)
@@ -77,11 +77,14 @@ public:
 //
 // For every basic block 'i' it maintains a BlockData that consists of:
 //   Consumes:  a bit vector which contains a set of indices of blocks that can
-//              reach block 'i'
+//              reach block 'i'. A block can trivially reach itself.
 //   Kills: a bit vector which contains a set of indices of blocks that can
-//          reach block 'i', but one of the path will cross a suspend point
+//          reach block 'i' but there is a path crossing a suspend point
+//          not repeating 'i' (path to 'i' without cycles containing 'i').
 //   Suspend: a boolean indicating whether block 'i' contains a suspend point.
 //   End: a boolean indicating whether block 'i' contains a coro.end intrinsic.
+//   KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that
+//             crosses a suspend point.
 //
 namespace {
 struct SuspendCrossingInfo {
@@ -92,6 +95,7 @@ struct SuspendCrossingInfo {
     BitVector Kills;
     bool Suspend = false;
     bool End = false;
+    bool KillLoop = false;
   };
   SmallVector<BlockData, SmallVectorThreshold> Block;
 
@@ -109,16 +113,31 @@ struct SuspendCrossingInfo {
 
   SuspendCrossingInfo(Function &F, coro::Shape &Shape);
 
-  bool hasPathCrossingSuspendPoint(BasicBlock *DefBB, BasicBlock *UseBB) const {
-    size_t const DefIndex = Mapping.blockToIndex(DefBB);
-    size_t const UseIndex = Mapping.blockToIndex(UseBB);
-
-    bool const Result = Block[UseIndex].Kills[DefIndex];
-    LLVM_DEBUG(dbgs() << UseBB->getName() << " => " << DefBB->getName()
+  /// Returns true if there is a path from \p From to \p To crossing a suspend
+  /// point without crossing \p From a 2nd time.
+  bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const {
+    size_t const FromIndex = Mapping.blockToIndex(From);
+    size_t const ToIndex = Mapping.blockToIndex(To);
+    bool const Result = Block[ToIndex].Kills[FromIndex];
+    LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
                       << " answer is " << Result << "\n");
     return Result;
   }
 
+  /// Returns true if there is a path from \p From to \p To crossing a suspend
+  /// point without crossing \p From a 2nd time. If \p From is the same as \p To
+  /// this will also check if there is a looping path crossing a suspend point.
+  bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From,
+                                         BasicBlock *To) const {
+    size_t const FromIndex = Mapping.blockToIndex(From);
+    size_t const ToIndex = Mapping.blockToIndex(To);
+    bool Result = Block[ToIndex].Kills[FromIndex] ||
+                  (From == To && Block[ToIndex].KillLoop);
+    LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName()
+                      << " answer is " << Result << " (path or loop)\n");
+    return Result;
+  }
+
   bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const {
     auto *I = cast<Instruction>(U);
 
@@ -271,6 +290,7 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape)
         } else {
           // This is reached when S block it not Suspend nor coro.end and it
           // need to make sure that it is not in the kill set.
+          S.KillLoop |= S.Kills[SuccNo];
           S.Kills.reset(SuccNo);
         }
 
@@ -1440,6 +1460,19 @@ private:
         for (auto *S : LifetimeStarts)
           if (Checker.isDefinitionAcrossSuspend(*S, I))
             return true;
+      // Addresses are guaranteed to be identical after every lifetime.start so
+      // we cannot use the local stack if the address escaped and there is a
+      // suspend point between lifetime markers. This should also cover the
+      // case of a single lifetime.start intrinsic in a loop with suspend point.
+      if (PI.isEscaped()) {
+        for (auto *A : LifetimeStarts) {
+          for (auto *B : LifetimeStarts) {
+            if (Checker.hasPathOrLoopCrossingSuspendPoint(A->getParent(),
+                                                          B->getParent()))
+              return true;
+          }
+        }
+      }
       return false;
     }
     // FIXME: Ideally the isEscaped check should come at the beginning.
diff --git a/llvm/test/Transforms/Coroutines/coro-alloca-loop-carried-address.ll b/llvm/test/Transforms/Coroutines/coro-alloca-loop-carried-address.ll
new file mode 100644 (file)
index 0000000..2d52120
--- /dev/null
@@ -0,0 +1,86 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes='cgscc(coro-split),simplifycfg,early-cse' -S | FileCheck %s
+
+@escape_hatch0 = external global i64
+@escape_hatch1 = external global i64
+
+define void @foo() presplitcoroutine {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[STACKVAR0:%.*]] = alloca i64, align 8
+; CHECK-NEXT:    [[ID:%.*]] = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr @foo.resumers)
+; CHECK-NEXT:    [[ALLOC:%.*]] = call ptr @malloc(i64 40)
+; CHECK-NEXT:    [[VFRAME:%.*]] = call noalias nonnull ptr @llvm.coro.begin(token [[ID]], ptr [[ALLOC]])
+; CHECK-NEXT:    store ptr @foo.resume, ptr [[VFRAME]], align 8
+; CHECK-NEXT:    [[DESTROY_ADDR:%.*]] = getelementptr inbounds [[FOO_FRAME:%.*]], ptr [[VFRAME]], i32 0, i32 1
+; CHECK-NEXT:    store ptr @foo.destroy, ptr [[DESTROY_ADDR]], align 8
+; CHECK-NEXT:    [[STACKVAR0_RELOAD_ADDR:%.*]] = getelementptr inbounds [[FOO_FRAME]], ptr [[VFRAME]], i32 0, i32 2
+; CHECK-NEXT:    [[STACKVAR1_RELOAD_ADDR:%.*]] = getelementptr inbounds [[FOO_FRAME]], ptr [[VFRAME]], i32 0, i32 3
+; CHECK-NEXT:    [[STACKVAR0_INT:%.*]] = ptrtoint ptr [[STACKVAR0_RELOAD_ADDR]] to i64
+; CHECK-NEXT:    store i64 [[STACKVAR0_INT]], ptr @escape_hatch0, align 4
+; CHECK-NEXT:    [[STACKVAR1_INT:%.*]] = ptrtoint ptr [[STACKVAR1_RELOAD_ADDR]] to i64
+; CHECK-NEXT:    store i64 [[STACKVAR1_INT]], ptr @escape_hatch1, align 4
+; CHECK-NEXT:    br label [[LOOP:%.*]]
+; CHECK:       loop:
+; CHECK-NEXT:    store i64 1234, ptr [[STACKVAR0_RELOAD_ADDR]], align 4
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    [[INDEX_ADDR1:%.*]] = getelementptr inbounds [[FOO_FRAME]], ptr [[VFRAME]], i32 0, i32 4
+; CHECK-NEXT:    store i1 false, ptr [[INDEX_ADDR1]], align 1
+; CHECK-NEXT:    br i1 false, label [[LOOP]], label [[AFTERCOROEND:%.*]]
+; CHECK:       AfterCoroEnd:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %stackvar0 = alloca i64
+  %stackvar1 = alloca i64
+
+  ; address of %stackvar escapes and may be relied upon even after
+  ; suspending/resuming the coroutine regardless of the lifetime markers.
+  %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
+  %size = call i64 @llvm.coro.size.i64()
+  %alloc = call ptr @malloc(i64 %size)
+  %vFrame = call noalias nonnull ptr @llvm.coro.begin(token %id, ptr %alloc)
+
+  ; %stackvar0 must be rewritten to reference the coroutine Frame!
+  %stackvar0_int = ptrtoint ptr %stackvar0 to i64
+  store i64 %stackvar0_int, ptr @escape_hatch0
+  ; %stackvar1 must be rewritten to reference the coroutine Frame!
+  %stackvar1_int = ptrtoint ptr %stackvar1 to i64
+  store i64 %stackvar1_int, ptr @escape_hatch1
+
+  br label %loop
+
+loop:
+  call void @llvm.lifetime.start(i64 8, ptr %stackvar0)
+
+  store i64 1234, ptr %stackvar0
+
+  ; Call could potentially change value in memory referenced by %stackvar0 /
+  ; %stackvar1 and rely on it staying the same across suspension.
+  call void @bar()
+
+  call void @llvm.lifetime.end(i64 8, ptr %stackvar0)
+
+  %save = call token @llvm.coro.save(ptr null)
+  %suspend = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %suspend, label %exit [
+  i8 0, label %loop
+  i8 1, label %exit
+  ]
+
+exit:
+  call i1 @llvm.coro.end(ptr null, i1 false)
+  ret void
+}
+
+declare void @bar()
+declare ptr @malloc(i64)
+
+declare token @llvm.coro.id(i32, ptr readnone, ptr nocapture readonly, ptr)
+declare i64 @llvm.coro.size.i64()
+declare ptr @llvm.coro.begin(token, ptr writeonly)
+declare token @llvm.coro.save(ptr)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i1 @llvm.coro.end(ptr, i1)
+declare void @llvm.lifetime.start(i64, ptr nocapture)
+declare void @llvm.lifetime.end(i64, ptr nocapture)