[C++20] [Coroutines] Handle function-try-block in SemaCoroutine
authorChuanqi Xu <yedeng.yd@linux.alibaba.com>
Thu, 6 Apr 2023 07:10:24 +0000 (15:10 +0800)
committerChuanqi Xu <yedeng.yd@linux.alibaba.com>
Thu, 6 Apr 2023 07:11:34 +0000 (15:11 +0800)
In https://reviews.llvm.org/D146758, we handled the rare case that the
coroutine has a function-try-block. But it will be better to handle it
in the Sema part. This patch handles the preprocess.

clang/include/clang/AST/StmtCXX.h
clang/lib/CodeGen/CGCoroutine.cpp
clang/lib/Sema/SemaCoroutine.cpp
clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp

index 288ffc2..60fc3f3 100644 (file)
@@ -375,9 +375,10 @@ public:
   }
 
   /// Retrieve the body of the coroutine as written. This will be either
-  /// a CompoundStmt or a TryStmt.
-  Stmt *getBody() const {
-    return getStoredStmts()[SubStmt::Body];
+  /// a CompoundStmt. If the coroutine is in function-try-block, we will
+  /// wrap the CXXTryStmt into a CompoundStmt to keep consistency.
+  CompoundStmt *getBody() const {
+    return cast<CompoundStmt>(getStoredStmts()[SubStmt::Body]);
   }
 
   Stmt *getPromiseDeclStmt() const {
index 90ab82e..da3da5e 100644 (file)
@@ -593,18 +593,6 @@ static void emitBodyAndFallthrough(CodeGenFunction &CGF,
       CGF.EmitStmt(OnFallthrough);
 }
 
-static CompoundStmt *CoroutineStmtBuilder(ASTContext &Context,
-                                          const CoroutineBodyStmt &S) {
-  Stmt *Stmt = S.getBody();
-  if (CompoundStmt *Body = dyn_cast<CompoundStmt>(Stmt))
-    return Body;
-  // We are about to create a `CXXTryStmt` which requires a `CompoundStmt`.
-  // If the function body is not a `CompoundStmt` yet then we have to create
-  // a new one. This happens for cases like the "function-try-block" syntax.
-  return CompoundStmt::Create(Context, {Stmt}, FPOptionsOverride(),
-                              SourceLocation(), SourceLocation());
-}
-
 void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
   auto *NullPtr = llvm::ConstantPointerNull::get(Builder.getInt8PtrTy());
   auto &TI = CGM.getContext().getTargetInfo();
@@ -733,8 +721,8 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) {
       auto Loc = S.getBeginLoc();
       CXXCatchStmt Catch(Loc, /*exDecl=*/nullptr,
                          CurCoro.Data->ExceptionHandler);
-      CompoundStmt *Body = CoroutineStmtBuilder(getContext(), S);
-      auto *TryStmt = CXXTryStmt::Create(getContext(), Loc, Body, &Catch);
+      auto *TryStmt =
+          CXXTryStmt::Create(getContext(), Loc, S.getBody(), &Catch);
 
       EnterCXXTryStmt(*TryStmt);
       emitBodyAndFallthrough(*this, S, TryStmt->getTryBlock());
index e87f2a7..deb6733 100644 (file)
@@ -1137,6 +1137,18 @@ void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
   Body = CoroutineBodyStmt::Create(Context, Builder);
 }
 
+static CompoundStmt *buildCoroutineBody(Stmt *Body, ASTContext &Context) {
+  if (auto *CS = dyn_cast<CompoundStmt>(Body))
+    return CS;
+
+  // The body of the coroutine may be a try statement if it is in
+  // 'function-try-block' syntax. Here we wrap it into a compound
+  // statement for consistency.
+  assert(isa<CXXTryStmt>(Body) && "Unimaged coroutine body type");
+  return CompoundStmt::Create(Context, {Body}, FPOptionsOverride(),
+                              SourceLocation(), SourceLocation());
+}
+
 CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
                                            sema::FunctionScopeInfo &Fn,
                                            Stmt *Body)
@@ -1144,7 +1156,7 @@ CoroutineStmtBuilder::CoroutineStmtBuilder(Sema &S, FunctionDecl &FD,
       IsPromiseDependentType(
           !Fn.CoroutinePromise ||
           Fn.CoroutinePromise->getType()->isDependentType()) {
-  this->Body = Body;
+  this->Body = buildCoroutineBody(Body, S.getASTContext());
 
   for (auto KV : Fn.CoroutineParameterMoves)
     this->ParamMovesVector.push_back(KV.second);
index e6d2bb4..8995471 100644 (file)
@@ -717,8 +717,8 @@ void coro() try {
 )cpp";
   EXPECT_TRUE(matchesConditionally(
       CoroWithTryCatchDeclCode,
-      coroutineBodyStmt(hasBody(cxxTryStmt(has(compoundStmt(has(
-          declStmt(containsDeclaration(0, varDecl(hasName("thevar")))))))))),
+      coroutineBodyStmt(hasBody(compoundStmt(has(cxxTryStmt(has(compoundStmt(has(
+          declStmt(containsDeclaration(0, varDecl(hasName("thevar")))))))))))),
       true, {"-std=c++20", "-I/"}, M));
 }