[clang] Expose CoawaitExpr's operand in the AST
authorNathan Ridge <zeratul976@hotmail.com>
Mon, 4 Apr 2022 06:29:21 +0000 (02:29 -0400)
committerNathan Ridge <zeratul976@hotmail.com>
Tue, 17 May 2022 12:13:37 +0000 (08:13 -0400)
Previously the Expr returned by getOperand() was actually the
subexpression common to the "ready", "suspend", and "resume"
expressions, which often isn't just the operand but e.g.
await_transform() called on the operand.

It's important for the AST to expose the operand as written
in the source for traversals and tools like clangd to work
correctly.

Fixes https://github.com/clangd/clangd/issues/939

Differential Revision: https://reviews.llvm.org/D115187

clang-tools-extra/clangd/unittests/FindTargetTests.cpp
clang/include/clang/AST/ExprCXX.h
clang/include/clang/Sema/Sema.h
clang/lib/Sema/SemaChecking.cpp
clang/lib/Sema/SemaCoroutine.cpp
clang/lib/Sema/TreeTransform.h
clang/test/AST/coroutine-locals-cleanup-exp-namespace.cpp
clang/test/AST/coroutine-locals-cleanup.cpp
clang/test/AST/coroutine-source-location-crash-exp-namespace.cpp
clang/test/AST/coroutine-source-location-crash.cpp
clang/test/SemaCXX/co_await-ast.cpp [new file with mode: 0644]

index c21114f..f7d547b 100644 (file)
@@ -548,6 +548,50 @@ TEST_F(TargetDeclTest, Concept) {
                {"template <typename T, typename U> concept Fooable = true"});
 }
 
+TEST_F(TargetDeclTest, Coroutine) {
+  Flags.push_back("-std=c++20");
+
+  Code = R"cpp(
+    namespace std::experimental {
+    template <typename, typename...> struct coroutine_traits;
+    template <typename> struct coroutine_handle {
+      template <typename U>
+      coroutine_handle(coroutine_handle<U>&&) noexcept;
+      static coroutine_handle from_address(void* __addr) noexcept;
+    };
+    } // namespace std::experimental
+
+    struct executor {};
+    struct awaitable {};
+    struct awaitable_frame {
+      awaitable get_return_object();
+      void return_void();
+      void unhandled_exception();
+      struct result_t {
+        ~result_t();
+        bool await_ready() const noexcept;
+        void await_suspend(std::experimental::coroutine_handle<void>) noexcept;
+        void await_resume() const noexcept;
+      };
+      result_t initial_suspend() noexcept;
+      result_t final_suspend() noexcept;
+      result_t await_transform(executor) noexcept;
+    };
+
+    namespace std::experimental {
+    template <>
+    struct coroutine_traits<awaitable> {
+      typedef awaitable_frame promise_type;
+    };
+    } // namespace std::experimental
+
+    awaitable foo() {
+      co_await [[executor]]();
+    }
+  )cpp";
+  EXPECT_DECLS("RecordTypeLoc", "struct executor");
+}
+
 TEST_F(TargetDeclTest, FunctionTemplate) {
   Code = R"cpp(
     // Implicit specialization.
index 3da9290..967a74d 100644 (file)
@@ -4698,18 +4698,19 @@ class CoroutineSuspendExpr : public Expr {
 
   SourceLocation KeywordLoc;
 
-  enum SubExpr { Common, Ready, Suspend, Resume, Count };
+  enum SubExpr { Operand, Common, Ready, Suspend, Resume, Count };
 
   Stmt *SubExprs[SubExpr::Count];
   OpaqueValueExpr *OpaqueValue = nullptr;
 
 public:
-  CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Common,
-                       Expr *Ready, Expr *Suspend, Expr *Resume,
+  CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand,
+                       Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume,
                        OpaqueValueExpr *OpaqueValue)
       : Expr(SC, Resume->getType(), Resume->getValueKind(),
              Resume->getObjectKind()),
         KeywordLoc(KeywordLoc), OpaqueValue(OpaqueValue) {
+    SubExprs[SubExpr::Operand] = Operand;
     SubExprs[SubExpr::Common] = Common;
     SubExprs[SubExpr::Ready] = Ready;
     SubExprs[SubExpr::Suspend] = Suspend;
@@ -4718,10 +4719,11 @@ public:
   }
 
   CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, QualType Ty,
-                       Expr *Common)
+                       Expr *Operand, Expr *Common)
       : Expr(SC, Ty, VK_PRValue, OK_Ordinary), KeywordLoc(KeywordLoc) {
     assert(Common->isTypeDependent() && Ty->isDependentType() &&
            "wrong constructor for non-dependent co_await/co_yield expression");
+    SubExprs[SubExpr::Operand] = Operand;
     SubExprs[SubExpr::Common] = Common;
     SubExprs[SubExpr::Ready] = nullptr;
     SubExprs[SubExpr::Suspend] = nullptr;
@@ -4730,14 +4732,13 @@ public:
   }
 
   CoroutineSuspendExpr(StmtClass SC, EmptyShell Empty) : Expr(SC, Empty) {
+    SubExprs[SubExpr::Operand] = nullptr;
     SubExprs[SubExpr::Common] = nullptr;
     SubExprs[SubExpr::Ready] = nullptr;
     SubExprs[SubExpr::Suspend] = nullptr;
     SubExprs[SubExpr::Resume] = nullptr;
   }
 
-  SourceLocation getKeywordLoc() const { return KeywordLoc; }
-
   Expr *getCommonExpr() const {
     return static_cast<Expr*>(SubExprs[SubExpr::Common]);
   }
@@ -4757,10 +4758,17 @@ public:
     return static_cast<Expr*>(SubExprs[SubExpr::Resume]);
   }
 
+  // The syntactic operand written in the code
+  Expr *getOperand() const {
+    return static_cast<Expr *>(SubExprs[SubExpr::Operand]);
+  }
+
+  SourceLocation getKeywordLoc() const { return KeywordLoc; }
+
   SourceLocation getBeginLoc() const LLVM_READONLY { return KeywordLoc; }
 
   SourceLocation getEndLoc() const LLVM_READONLY {
-    return getCommonExpr()->getEndLoc();
+    return getOperand()->getEndLoc();
   }
 
   child_range children() {
@@ -4782,28 +4790,24 @@ class CoawaitExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
 
 public:
-  CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Ready,
-              Expr *Suspend, Expr *Resume, OpaqueValueExpr *OpaqueValue,
-              bool IsImplicit = false)
-      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Ready,
-                             Suspend, Resume, OpaqueValue) {
+  CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Common,
+              Expr *Ready, Expr *Suspend, Expr *Resume,
+              OpaqueValueExpr *OpaqueValue, bool IsImplicit = false)
+      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Common,
+                             Ready, Suspend, Resume, OpaqueValue) {
     CoawaitBits.IsImplicit = IsImplicit;
   }
 
   CoawaitExpr(SourceLocation CoawaitLoc, QualType Ty, Expr *Operand,
-              bool IsImplicit = false)
-      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand) {
+              Expr *Common, bool IsImplicit = false)
+      : CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Ty, Operand,
+                             Common) {
     CoawaitBits.IsImplicit = IsImplicit;
   }
 
   CoawaitExpr(EmptyShell Empty)
       : CoroutineSuspendExpr(CoawaitExprClass, Empty) {}
 
-  Expr *getOperand() const {
-    // FIXME: Dig out the actual operand or store it.
-    return getCommonExpr();
-  }
-
   bool isImplicit() const { return CoawaitBits.IsImplicit; }
   void setIsImplicit(bool value = true) { CoawaitBits.IsImplicit = value; }
 
@@ -4867,20 +4871,18 @@ class CoyieldExpr : public CoroutineSuspendExpr {
   friend class ASTStmtReader;
 
 public:
-  CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Ready,
-              Expr *Suspend, Expr *Resume, OpaqueValueExpr *OpaqueValue)
-      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Ready,
-                             Suspend, Resume, OpaqueValue) {}
-  CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand)
-      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand) {}
+  CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Common,
+              Expr *Ready, Expr *Suspend, Expr *Resume,
+              OpaqueValueExpr *OpaqueValue)
+      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Common,
+                             Ready, Suspend, Resume, OpaqueValue) {}
+  CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand,
+              Expr *Common)
+      : CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand,
+                             Common) {}
   CoyieldExpr(EmptyShell Empty)
       : CoroutineSuspendExpr(CoyieldExprClass, Empty) {}
 
-  Expr *getOperand() const {
-    // FIXME: Dig out the actual operand or store it.
-    return getCommonExpr();
-  }
-
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == CoyieldExprClass;
   }
index e9bd756..976ee59 100644 (file)
@@ -10426,10 +10426,13 @@ public:
   ExprResult ActOnCoyieldExpr(Scope *S, SourceLocation KwLoc, Expr *E);
   StmtResult ActOnCoreturnStmt(Scope *S, SourceLocation KwLoc, Expr *E);
 
-  ExprResult BuildResolvedCoawaitExpr(SourceLocation KwLoc, Expr *E,
-                                      bool IsImplicit = false);
-  ExprResult BuildUnresolvedCoawaitExpr(SourceLocation KwLoc, Expr *E,
-                                        UnresolvedLookupExpr* Lookup);
+  ExprResult BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc);
+  ExprResult BuildOperatorCoawaitCall(SourceLocation Loc, Expr *E,
+                                      UnresolvedLookupExpr *Lookup);
+  ExprResult BuildResolvedCoawaitExpr(SourceLocation KwLoc, Expr *Operand,
+                                      Expr *Awaiter, bool IsImplicit = false);
+  ExprResult BuildUnresolvedCoawaitExpr(SourceLocation KwLoc, Expr *Operand,
+                                        UnresolvedLookupExpr *Lookup);
   ExprResult BuildCoyieldExpr(SourceLocation KwLoc, Expr *E);
   StmtResult BuildCoreturnStmt(SourceLocation KwLoc, Expr *E,
                                bool IsImplicit = false);
index 4c4041a..453364c 100644 (file)
@@ -14095,6 +14095,13 @@ static void AnalyzeImplicitConversions(
     if (!ChildExpr)
       continue;
 
+    if (auto *CSE = dyn_cast<CoroutineSuspendExpr>(E))
+      if (ChildExpr == CSE->getOperand())
+        // Do not recurse over a CoroutineSuspendExpr's operand.
+        // The operand is also a subexpression of getCommonExpr(), and
+        // recursing into it directly would produce duplicate diagnostics.
+        continue;
+
     if (IsLogicalAndOperator &&
         isa<StringLiteral>(ChildExpr->IgnoreParenImpCasts()))
       // Ignore checking string literals that are in logical and operators.
index 8709e7b..b43b0b3 100644 (file)
@@ -246,44 +246,22 @@ static bool isValidCoroutineContext(Sema &S, SourceLocation Loc,
   return !Diagnosed;
 }
 
-static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
-                                                 SourceLocation Loc) {
-  DeclarationName OpName =
-      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
-  LookupResult Operators(SemaRef, OpName, SourceLocation(),
-                         Sema::LookupOperatorName);
-  SemaRef.LookupName(Operators, S);
-
-  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
-  const auto &Functions = Operators.asUnresolvedSet();
-  bool IsOverloaded =
-      Functions.size() > 1 ||
-      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
-  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
-      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
-      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
-      Functions.begin(), Functions.end());
-  assert(CoawaitOp);
-  return CoawaitOp;
-}
-
 /// Build a call to 'operator co_await' if there is a suitable operator for
 /// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
-                                           Expr *E,
-                                           UnresolvedLookupExpr *Lookup) {
+ExprResult Sema::BuildOperatorCoawaitCall(SourceLocation Loc, Expr *E,
+                                          UnresolvedLookupExpr *Lookup) {
   UnresolvedSet<16> Functions;
   Functions.append(Lookup->decls_begin(), Lookup->decls_end());
-  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
+  return CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
 }
 
 static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
                                            SourceLocation Loc, Expr *E) {
-  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
+  ExprResult R = SemaRef.BuildOperatorCoawaitLookupExpr(S, Loc);
   if (R.isInvalid())
     return ExprError();
-  return buildOperatorCoawaitCall(SemaRef, Loc, E,
-                                  cast<UnresolvedLookupExpr>(R.get()));
+  return SemaRef.BuildOperatorCoawaitCall(Loc, E,
+                                          cast<UnresolvedLookupExpr>(R.get()));
 }
 
 static ExprResult buildCoroutineHandle(Sema &S, QualType PromiseType,
@@ -727,14 +705,15 @@ bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
   SourceLocation Loc = Fn->getLocation();
   // Build the initial suspend point
   auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
-    ExprResult Suspend =
+    ExprResult Operand =
         buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
-    if (Suspend.isInvalid())
+    if (Operand.isInvalid())
       return StmtError();
-    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
+    ExprResult Suspend =
+        buildOperatorCoawaitCall(*this, SC, Loc, Operand.get());
     if (Suspend.isInvalid())
       return StmtError();
-    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
+    Suspend = BuildResolvedCoawaitExpr(Loc, Operand.get(), Suspend.get(),
                                        /*IsImplicit*/ true);
     Suspend = ActOnFinishFullExpr(Suspend.get(), /*DiscardedValue*/ false);
     if (Suspend.isInvalid()) {
@@ -815,88 +794,112 @@ ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
     if (R.isInvalid()) return ExprError();
     E = R.get();
   }
-  ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
+  ExprResult Lookup = BuildOperatorCoawaitLookupExpr(S, Loc);
   if (Lookup.isInvalid())
     return ExprError();
   return BuildUnresolvedCoawaitExpr(Loc, E,
                                    cast<UnresolvedLookupExpr>(Lookup.get()));
 }
 
-ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
+ExprResult Sema::BuildOperatorCoawaitLookupExpr(Scope *S, SourceLocation Loc) {
+  DeclarationName OpName =
+      Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
+  LookupResult Operators(*this, OpName, SourceLocation(),
+                         Sema::LookupOperatorName);
+  LookupName(Operators, S);
+
+  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+  const auto &Functions = Operators.asUnresolvedSet();
+  bool IsOverloaded =
+      Functions.size() > 1 ||
+      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
+  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
+      Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
+      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
+      Functions.begin(), Functions.end());
+  assert(CoawaitOp);
+  return CoawaitOp;
+}
+
+// Attempts to resolve and build a CoawaitExpr from "raw" inputs, bailing out to
+// DependentCoawaitExpr if needed.
+ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
                                             UnresolvedLookupExpr *Lookup) {
   auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
   if (!FSI)
     return ExprError();
 
-  if (E->hasPlaceholderType()) {
-    ExprResult R = CheckPlaceholderExpr(E);
+  if (Operand->hasPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(Operand);
     if (R.isInvalid())
       return ExprError();
-    E = R.get();
+    Operand = R.get();
   }
 
   auto *Promise = FSI->CoroutinePromise;
   if (Promise->getType()->isDependentType()) {
-    Expr *Res =
-        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
+    Expr *Res = new (Context)
+        DependentCoawaitExpr(Loc, Context.DependentTy, Operand, Lookup);
     return Res;
   }
 
   auto *RD = Promise->getType()->getAsCXXRecordDecl();
+  auto *Transformed = Operand;
   if (lookupMember(*this, "await_transform", RD, Loc)) {
-    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
+    ExprResult R =
+        buildPromiseCall(*this, Promise, Loc, "await_transform", Operand);
     if (R.isInvalid()) {
       Diag(Loc,
            diag::note_coroutine_promise_implicit_await_transform_required_here)
-          << E->getSourceRange();
+          << Operand->getSourceRange();
       return ExprError();
     }
-    E = R.get();
+    Transformed = R.get();
   }
-  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
-  if (Awaitable.isInvalid())
+  ExprResult Awaiter = BuildOperatorCoawaitCall(Loc, Transformed, Lookup);
+  if (Awaiter.isInvalid())
     return ExprError();
 
-  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
+  return BuildResolvedCoawaitExpr(Loc, Operand, Awaiter.get());
 }
 
-ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
-                                  bool IsImplicit) {
+ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *Operand,
+                                          Expr *Awaiter, bool IsImplicit) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await", IsImplicit);
   if (!Coroutine)
     return ExprError();
 
-  if (E->hasPlaceholderType()) {
-    ExprResult R = CheckPlaceholderExpr(E);
+  if (Awaiter->hasPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(Awaiter);
     if (R.isInvalid()) return ExprError();
-    E = R.get();
+    Awaiter = R.get();
   }
 
-  if (E->getType()->isDependentType()) {
+  if (Awaiter->getType()->isDependentType()) {
     Expr *Res = new (Context)
-        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
+        CoawaitExpr(Loc, Context.DependentTy, Operand, Awaiter, IsImplicit);
     return Res;
   }
 
   // If the expression is a temporary, materialize it as an lvalue so that we
   // can use it multiple times.
-  if (E->isPRValue())
-    E = CreateMaterializeTemporaryExpr(E->getType(), E, true);
+  if (Awaiter->isPRValue())
+    Awaiter = CreateMaterializeTemporaryExpr(Awaiter->getType(), Awaiter, true);
 
   // The location of the `co_await` token cannot be used when constructing
   // the member call expressions since it's before the location of `Expr`, which
   // is used as the start of the member call expression.
-  SourceLocation CallLoc = E->getExprLoc();
+  SourceLocation CallLoc = Awaiter->getExprLoc();
 
   // Build the await_ready, await_suspend, await_resume calls.
-  ReadySuspendResumeResult RSS = buildCoawaitCalls(
-      *this, Coroutine->CoroutinePromise, CallLoc, E);
+  ReadySuspendResumeResult RSS =
+      buildCoawaitCalls(*this, Coroutine->CoroutinePromise, CallLoc, Awaiter);
   if (RSS.IsInvalid)
     return ExprError();
 
-  Expr *Res =
-      new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
-                                RSS.Results[2], RSS.OpaqueValue, IsImplicit);
+  Expr *Res = new (Context)
+      CoawaitExpr(Loc, Operand, Awaiter, RSS.Results[0], RSS.Results[1],
+                  RSS.Results[2], RSS.OpaqueValue, IsImplicit);
 
   return Res;
 }
@@ -933,8 +936,10 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
     E = R.get();
   }
 
+  Expr *Operand = E;
+
   if (E->getType()->isDependentType()) {
-    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, E);
+    Expr *Res = new (Context) CoyieldExpr(Loc, Context.DependentTy, Operand, E);
     return Res;
   }
 
@@ -950,7 +955,7 @@ ExprResult Sema::BuildCoyieldExpr(SourceLocation Loc, Expr *E) {
     return ExprError();
 
   Expr *Res =
-      new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
+      new (Context) CoyieldExpr(Loc, Operand, E, RSS.Results[0], RSS.Results[1],
                                 RSS.Results[2], RSS.OpaqueValue);
 
   return Res;
index aef757e..3d7279c 100644 (file)
@@ -1470,9 +1470,28 @@ public:
   ///
   /// By default, performs semantic analysis to build the new expression.
   /// Subclasses may override this routine to provide different behavior.
-  ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Result,
+  ExprResult RebuildCoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand,
+                                UnresolvedLookupExpr *OpCoawaitLookup,
                                 bool IsImplicit) {
-    return getSema().BuildResolvedCoawaitExpr(CoawaitLoc, Result, IsImplicit);
+    // This function rebuilds a coawait-expr given its operator.
+    // For an explicit coawait-expr, the rebuild involves the full set
+    // of transformations performed by BuildUnresolvedCoawaitExpr(),
+    // including calling await_transform().
+    // For an implicit coawait-expr, we need to rebuild the "operator
+    // coawait" but not await_transform(), so use BuildResolvedCoawaitExpr().
+    // This mirrors how the implicit CoawaitExpr is originally created
+    // in Sema::ActOnCoroutineBodyStart().
+    if (IsImplicit) {
+      ExprResult Suspend = getSema().BuildOperatorCoawaitCall(
+          CoawaitLoc, Operand, OpCoawaitLookup);
+      if (Suspend.isInvalid())
+        return ExprError();
+      return getSema().BuildResolvedCoawaitExpr(CoawaitLoc, Operand,
+                                                Suspend.get(), true);
+    }
+
+    return getSema().BuildUnresolvedCoawaitExpr(CoawaitLoc, Operand,
+                                                OpCoawaitLookup);
   }
 
   /// Build a new co_await expression.
@@ -7945,18 +7964,27 @@ TreeTransform<Derived>::TransformCoreturnStmt(CoreturnStmt *S) {
                                           S->isImplicit());
 }
 
-template<typename Derived>
-ExprResult
-TreeTransform<Derived>::TransformCoawaitExpr(CoawaitExpr *E) {
-  ExprResult Result = getDerived().TransformInitializer(E->getOperand(),
-                                                        /*NotCopyInit*/false);
-  if (Result.isInvalid())
+template <typename Derived>
+ExprResult TreeTransform<Derived>::TransformCoawaitExpr(CoawaitExpr *E) {
+  ExprResult Operand = getDerived().TransformInitializer(E->getOperand(),
+                                                         /*NotCopyInit*/ false);
+  if (Operand.isInvalid())
     return ExprError();
 
+  // Rebuild the common-expr from the operand rather than transforming it
+  // separately.
+
+  // FIXME: getCurScope() should not be used during template instantiation.
+  // We should pick up the set of unqualified lookup results for operator
+  // co_await during the initial parse.
+  ExprResult Lookup = getSema().BuildOperatorCoawaitLookupExpr(
+      getSema().getCurScope(), E->getKeywordLoc());
+
   // Always rebuild; we don't know if this needs to be injected into a new
   // context or if the promise type has changed.
-  return getDerived().RebuildCoawaitExpr(E->getKeywordLoc(), Result.get(),
-                                         E->isImplicit());
+  return getDerived().RebuildCoawaitExpr(
+      E->getKeywordLoc(), Operand.get(),
+      cast<UnresolvedLookupExpr>(Lookup.get()), E->isImplicit());
 }
 
 template <typename Derived>
index 048c677..3122df9 100644 (file)
@@ -85,7 +85,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready
@@ -97,7 +98,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready
index 4e2fe62..aa04a35 100644 (file)
@@ -85,7 +85,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready
@@ -97,7 +98,8 @@ Task bar() {
 // CHECK:           CaseStmt
 // CHECK:             ExprWithCleanups {{.*}} 'void'
 // CHECK-NEXT:          CoawaitExpr
-// CHECK-NEXT:            MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
+// CHECK-NEXT:            CXXBindTemporaryExpr {{.*}} 'Task' (CXXTemporary {{.*}})
+// CHECK:                 MaterializeTemporaryExpr {{.*}} 'Task::Awaiter':'Task::Awaiter'
 // CHECK:                 ExprWithCleanups {{.*}} 'bool'
 // CHECK-NEXT:              CXXMemberCallExpr {{.*}} 'bool'
 // CHECK-NEXT:                MemberExpr {{.*}} .await_ready
index 9995dee..fb9aaa5 100644 (file)
@@ -36,6 +36,7 @@ coro_t f(int n) {
   A a{};
   // CHECK: CoawaitExpr {{0x[0-9a-fA-F]+}} <col:3, col:12>
   // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
+  // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: CXXMemberCallExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: MemberExpr {{0x[0-9a-fA-F]+}} <col:12>
   co_await a;
index 9b18dc8..fcf23d2 100644 (file)
@@ -36,6 +36,7 @@ coro_t f(int n) {
   A a{};
   // CHECK: CoawaitExpr {{0x[0-9a-fA-F]+}} <col:3, col:12>
   // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
+  // CHECK-NEXT: DeclRefExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: CXXMemberCallExpr {{0x[0-9a-fA-F]+}} <col:12>
   // CHECK-NEXT: MemberExpr {{0x[0-9a-fA-F]+}} <col:12>
   co_await a;
diff --git a/clang/test/SemaCXX/co_await-ast.cpp b/clang/test/SemaCXX/co_await-ast.cpp
new file mode 100644 (file)
index 0000000..1221e86
--- /dev/null
@@ -0,0 +1,97 @@
+// RUN: %clang_cc1 -std=c++20 -fsyntax-only -ast-dump -ast-dump-filter=foo %s | FileCheck %s --strict-whitespace
+
+namespace std {
+template <typename, typename...> struct coroutine_traits;
+template <typename> struct coroutine_handle {
+  template <typename U>
+  coroutine_handle(coroutine_handle<U> &&) noexcept;
+  static coroutine_handle from_address(void *__addr) noexcept;
+};
+} // namespace std
+
+struct executor {};
+struct awaitable {};
+struct awaitable_frame {
+  awaitable get_return_object();
+  void return_void();
+  void unhandled_exception();
+  struct result_t {
+    ~result_t();
+    bool await_ready() const noexcept;
+    void await_suspend(std::coroutine_handle<void>) noexcept;
+    void await_resume() const noexcept;
+  };
+  result_t initial_suspend() noexcept;
+  result_t final_suspend() noexcept;
+  result_t await_transform(executor) noexcept;
+};
+
+namespace std {
+template <>
+struct coroutine_traits<awaitable> {
+  typedef awaitable_frame promise_type;
+};
+} // namespace std
+
+awaitable foo() {
+  co_await executor();
+}
+
+// Check that CoawaitExpr contains the correct subexpressions, including
+// the operand expression as written in the source.
+
+// CHECK-LABEL: Dumping foo:
+// CHECK: FunctionDecl {{.*}} foo 'awaitable ()'
+// CHECK: `-CoroutineBodyStmt {{.*}}
+// CHECK:   |-CompoundStmt {{.*}}
+// CHECK:   | `-ExprWithCleanups {{.*}} 'void'
+// CHECK:   |   `-CoawaitExpr {{.*}} 'void'
+// CHECK:   |     |-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     | `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |     |   `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |     |     |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |     |     | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |     |     `-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |-ExprWithCleanups {{.*}} 'bool'
+// CHECK:   |     | `-CXXMemberCallExpr {{.*}} 'bool'
+// CHECK:   |     |   `-MemberExpr {{.*}} '<bound member function type>' .await_ready {{.*}}
+// CHECK:   |     |     `-ImplicitCastExpr {{.*}} 'const awaitable_frame::result_t' lvalue <NoOp>
+// CHECK:   |     |       `-OpaqueValueExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |         `-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |           `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |     |             `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |     |               |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |     |               | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |     |               `-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |-ExprWithCleanups {{.*}} 'void'
+// CHECK:   |     | `-CXXMemberCallExpr {{.*}} 'void'
+// CHECK:   |     |   |-MemberExpr {{.*}} '<bound member function type>' .await_suspend {{.*}}
+// CHECK:   |     |   | `-OpaqueValueExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |   |   `-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |     |   |     `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |     |   |       `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |     |   |         |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |     |   |         | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |     |   |         `-CXXTemporaryObjectExpr {{.*}} 'executor' 'void () noexcept' zeroing
+// CHECK:   |     |   `-ImplicitCastExpr {{.*}} 'std::coroutine_handle<void>':'std::coroutine_handle<void>' <ConstructorConversion>
+// CHECK:   |     |     `-CXXConstructExpr {{.*}} 'std::coroutine_handle<void>':'std::coroutine_handle<void>' 'void (coroutine_handle<awaitable_frame> &&) noexcept'
+// CHECK:   |     |       `-MaterializeTemporaryExpr {{.*}} 'std::coroutine_handle<awaitable_frame>' xvalue
+// CHECK:   |     |         `-CallExpr {{.*}} 'std::coroutine_handle<awaitable_frame>'
+// CHECK:   |     |           |-ImplicitCastExpr {{.*}} 'std::coroutine_handle<awaitable_frame> (*)(void *) noexcept' <FunctionToPointerDecay>
+// CHECK:   |     |           | `-DeclRefExpr {{.*}} 'std::coroutine_handle<awaitable_frame> (void *) noexcept' lvalue CXXMethod {{.*}} 'from_address' 'std::coroutine_handle<awaitable_frame> (void *) noexcept'
+// CHECK:   |     |           `-CallExpr {{.*}} 'void *'
+// CHECK:   |     |             `-ImplicitCastExpr {{.*}} 'void *(*)() noexcept' <FunctionToPointerDecay>
+// CHECK:   |     |               `-DeclRefExpr {{.*}} 'void *() noexcept' lvalue Function {{.*}} '__builtin_coro_frame' 'void *() noexcept'
+// CHECK:   |     `-CXXMemberCallExpr {{.*}} 'void'
+// CHECK:   |       `-MemberExpr {{.*}} '<bound member function type>' .await_resume {{.*}}
+// CHECK:   |         `-ImplicitCastExpr {{.*}} 'const awaitable_frame::result_t' lvalue <NoOp>
+// CHECK:   |           `-OpaqueValueExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |             `-MaterializeTemporaryExpr {{.*}} 'awaitable_frame::result_t' lvalue
+// CHECK:   |               `-CXXBindTemporaryExpr {{.*}} 'awaitable_frame::result_t' (CXXTemporary {{.*}})
+// CHECK:   |                 `-CXXMemberCallExpr {{.*}} 'awaitable_frame::result_t'
+// CHECK:   |                   |-MemberExpr {{.*}} '<bound member function type>' .await_transform {{.*}}
+// CHECK:   |                   | `-DeclRefExpr {{.*}} 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame' lvalue Var {{.*}} '__promise' 'std::coroutine_traits<awaitable>::promise_type':'awaitable_frame'
+// CHECK:   |                   `-CXXTemporaryObjectExpr {{.*}} <col:12, col:21> 'executor' 'void () noexcept' zeroing
+
+// Rest of the generated coroutine statements omitted.