From e00a8d081d789cac606cf0749c332c4632132820 Mon Sep 17 00:00:00 2001 From: Matthias Braun Date: Thu, 23 Mar 2023 13:38:16 -0700 Subject: [PATCH] Fix codegen for coroutine with function-try-block This fixes an assertion error when writing a coroutine with a function-try-block. In this case the function body is not a `CompoundStmt` so the code constructing an artificial CXXTryStmt must also construct a `CompoundStmt` for it. While on it adjust the `CXXStmt::Create` function to only accept `CompoundStmt*`. Differential Revision: https://reviews.llvm.org/D146758 --- clang/include/clang/AST/StmtCXX.h | 5 +++-- clang/lib/AST/ASTImporter.cpp | 4 ++-- clang/lib/AST/StmtCXX.cpp | 5 +++-- clang/lib/CodeGen/CGCoroutine.cpp | 16 +++++++++++++-- clang/lib/Sema/SemaStmt.cpp | 3 ++- .../CodeGenCoroutines/coro-function-try-block.cpp | 23 ++++++++++++++++++++++ 6 files changed, 47 insertions(+), 9 deletions(-) create mode 100644 clang/test/CodeGenCoroutines/coro-function-try-block.cpp diff --git a/clang/include/clang/AST/StmtCXX.h b/clang/include/clang/AST/StmtCXX.h index 8ba667c..288ffc2 100644 --- a/clang/include/clang/AST/StmtCXX.h +++ b/clang/include/clang/AST/StmtCXX.h @@ -75,7 +75,8 @@ class CXXTryStmt final : public Stmt, unsigned NumHandlers; size_t numTrailingObjects(OverloadToken) const { return NumHandlers; } - CXXTryStmt(SourceLocation tryLoc, Stmt *tryBlock, ArrayRef handlers); + CXXTryStmt(SourceLocation tryLoc, CompoundStmt *tryBlock, + ArrayRef handlers); CXXTryStmt(EmptyShell Empty, unsigned numHandlers) : Stmt(CXXTryStmtClass), NumHandlers(numHandlers) { } @@ -84,7 +85,7 @@ class CXXTryStmt final : public Stmt, public: static CXXTryStmt *Create(const ASTContext &C, SourceLocation tryLoc, - Stmt *tryBlock, ArrayRef handlers); + CompoundStmt *tryBlock, ArrayRef handlers); static CXXTryStmt *Create(const ASTContext &C, EmptyShell Empty, unsigned numHandlers); diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp index db6ebbb..7120802 100644 --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -6793,8 +6793,8 @@ ExpectedStmt ASTNodeImporter::VisitCXXTryStmt(CXXTryStmt *S) { return ToHandlerOrErr.takeError(); } - return CXXTryStmt::Create( - Importer.getToContext(), *ToTryLocOrErr,*ToTryBlockOrErr, ToHandlers); + return CXXTryStmt::Create(Importer.getToContext(), *ToTryLocOrErr, + cast(*ToTryBlockOrErr), ToHandlers); } ExpectedStmt ASTNodeImporter::VisitCXXForRangeStmt(CXXForRangeStmt *S) { diff --git a/clang/lib/AST/StmtCXX.cpp b/clang/lib/AST/StmtCXX.cpp index a3ae539..0d6fc84 100644 --- a/clang/lib/AST/StmtCXX.cpp +++ b/clang/lib/AST/StmtCXX.cpp @@ -23,7 +23,8 @@ QualType CXXCatchStmt::getCaughtType() const { } CXXTryStmt *CXXTryStmt::Create(const ASTContext &C, SourceLocation tryLoc, - Stmt *tryBlock, ArrayRef handlers) { + CompoundStmt *tryBlock, + ArrayRef handlers) { const size_t Size = totalSizeToAlloc(handlers.size() + 1); void *Mem = C.Allocate(Size, alignof(CXXTryStmt)); return new (Mem) CXXTryStmt(tryLoc, tryBlock, handlers); @@ -36,7 +37,7 @@ CXXTryStmt *CXXTryStmt::Create(const ASTContext &C, EmptyShell Empty, return new (Mem) CXXTryStmt(Empty, numHandlers); } -CXXTryStmt::CXXTryStmt(SourceLocation tryLoc, Stmt *tryBlock, +CXXTryStmt::CXXTryStmt(SourceLocation tryLoc, CompoundStmt *tryBlock, ArrayRef handlers) : Stmt(CXXTryStmtClass), TryLoc(tryLoc), NumHandlers(handlers.size()) { Stmt **Stmts = getStmts(); diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp index da3da5e..90ab82e 100644 --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -593,6 +593,18 @@ static void emitBodyAndFallthrough(CodeGenFunction &CGF, CGF.EmitStmt(OnFallthrough); } +static CompoundStmt *CoroutineStmtBuilder(ASTContext &Context, + const CoroutineBodyStmt &S) { + Stmt *Stmt = S.getBody(); + if (CompoundStmt *Body = dyn_cast(Stmt)) + return Body; + // We are about to create a `CXXTryStmt` which requires a `CompoundStmt`. + // If the function body is not a `CompoundStmt` yet then we have to create + // a new one. This happens for cases like the "function-try-block" syntax. + return CompoundStmt::Create(Context, {Stmt}, FPOptionsOverride(), + SourceLocation(), SourceLocation()); +} + void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) { auto *NullPtr = llvm::ConstantPointerNull::get(Builder.getInt8PtrTy()); auto &TI = CGM.getContext().getTargetInfo(); @@ -721,8 +733,8 @@ void CodeGenFunction::EmitCoroutineBody(const CoroutineBodyStmt &S) { auto Loc = S.getBeginLoc(); CXXCatchStmt Catch(Loc, /*exDecl=*/nullptr, CurCoro.Data->ExceptionHandler); - auto *TryStmt = - CXXTryStmt::Create(getContext(), Loc, S.getBody(), &Catch); + CompoundStmt *Body = CoroutineStmtBuilder(getContext(), S); + auto *TryStmt = CXXTryStmt::Create(getContext(), Loc, Body, &Catch); EnterCXXTryStmt(*TryStmt); emitBodyAndFallthrough(*this, S, TryStmt->getTryBlock()); diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index 3e0b3c2..1af5f73 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -4546,7 +4546,8 @@ StmtResult Sema::ActOnCXXTryBlock(SourceLocation TryLoc, Stmt *TryBlock, FSI->setHasCXXTry(TryLoc); - return CXXTryStmt::Create(Context, TryLoc, TryBlock, Handlers); + return CXXTryStmt::Create(Context, TryLoc, cast(TryBlock), + Handlers); } StmtResult Sema::ActOnSEHTryBlock(bool IsCXXTry, SourceLocation TryLoc, diff --git a/clang/test/CodeGenCoroutines/coro-function-try-block.cpp b/clang/test/CodeGenCoroutines/coro-function-try-block.cpp new file mode 100644 index 0000000..f609eb5 --- /dev/null +++ b/clang/test/CodeGenCoroutines/coro-function-try-block.cpp @@ -0,0 +1,23 @@ +// RUN: %clang_cc1 -std=c++20 -triple=x86_64-- -emit-llvm -fcxx-exceptions \ +// RUN: -disable-llvm-passes %s -o - | FileCheck %s + +#include "Inputs/coroutine.h" + +struct task { + struct promise_type { + task get_return_object(); + std::suspend_never initial_suspend(); + std::suspend_never final_suspend() noexcept; + void return_void(); + void unhandled_exception() noexcept; + }; +}; + +task f() try { + co_return; +} catch(...) { +} + +// CHECK-LABEL: define{{.*}} void @_Z1fv( +// CHECK: call void @_ZNSt13suspend_never13await_suspendESt16coroutine_handleIvE( +// CHECK: call void @_ZN4task12promise_type11return_voidEv( -- 2.7.4