[clang] Fix consteval operators in template contexts
authorMariya Podchishchaeva <mariya.podchishchaeva@intel.com>
Wed, 7 Jun 2023 14:28:04 +0000 (10:28 -0400)
committerMariya Podchishchaeva <mariya.podchishchaeva@intel.com>
Thu, 8 Jun 2023 08:26:45 +0000 (04:26 -0400)
Clang used to reject consteval operators if they're used inside a
template due to TreeTransform putting two different `DeclRefExpr`
expressions for the same reference of the same operator declaration into
`ReferenceToConsteval` set.
It seems there was an attempt to not rebuild the whole operator that
never succeeded, so this patch just removes this attempt and
problemating referencing of a `DeclRefExpr` that always ended up
discarded.

Fixes https://github.com/llvm/llvm-project/issues/62886

Reviewed By: cor3ntin

Differential Revision: https://reviews.llvm.org/D151553

clang/docs/ReleaseNotes.rst
clang/lib/Sema/TreeTransform.h
clang/test/SemaCXX/consteval-operators.cpp [new file with mode: 0644]
clang/test/SemaCXX/overloaded-operator.cpp

index a30d9f8..b17e746 100644 (file)
@@ -485,6 +485,8 @@ Bug Fixes in This Version
 - Fix assertion and quality of diagnostic messages in a for loop
   containing multiple declarations and a range specifier
   (`#63010 <https://github.com/llvm/llvm-project/issues/63010>`_).
+- Fix rejects-valid when consteval operator appears inside of a template.
+  (`#62886 <https://github.com/llvm/llvm-project/issues/62886>`_).
 
 Bug Fixes to Compiler Builtins
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
index bb6c0aa..f0401f0 100644 (file)
@@ -3053,10 +3053,11 @@ public:
   /// argument-dependent lookup, etc. Subclasses may override this routine to
   /// provide different behavior.
   ExprResult RebuildCXXOperatorCallExpr(OverloadedOperatorKind Op,
-                                              SourceLocation OpLoc,
-                                              Expr *Callee,
-                                              Expr *First,
-                                              Expr *Second);
+                                        SourceLocation OpLoc,
+                                        SourceLocation CalleeLoc,
+                                        bool RequiresADL,
+                                        const UnresolvedSetImpl &Functions,
+                                        Expr *First, Expr *Second);
 
   /// Build a new C++ "named" cast expression, such as static_cast or
   /// reinterpret_cast.
@@ -11962,10 +11963,6 @@ TreeTransform<Derived>::TransformCXXOperatorCallExpr(CXXOperatorCallExpr *E) {
     llvm_unreachable("not an overloaded operator?");
   }
 
-  ExprResult Callee = getDerived().TransformExpr(E->getCallee());
-  if (Callee.isInvalid())
-    return ExprError();
-
   ExprResult First;
   if (E->getOperator() == OO_Amp)
     First = getDerived().TransformAddressOfOperand(E->getArg(0));
@@ -11982,23 +11979,39 @@ TreeTransform<Derived>::TransformCXXOperatorCallExpr(CXXOperatorCallExpr *E) {
       return ExprError();
   }
 
-  if (!getDerived().AlwaysRebuild() &&
-      Callee.get() == E->getCallee() &&
-      First.get() == E->getArg(0) &&
-      (E->getNumArgs() != 2 || Second.get() == E->getArg(1)))
-    return SemaRef.MaybeBindToTemporary(E);
-
   Sema::FPFeaturesStateRAII FPFeaturesState(getSema());
   FPOptionsOverride NewOverrides(E->getFPFeatures());
   getSema().CurFPFeatures =
       NewOverrides.applyOverrides(getSema().getLangOpts());
   getSema().FpPragmaStack.CurrentValue = NewOverrides;
 
-  return getDerived().RebuildCXXOperatorCallExpr(E->getOperator(),
-                                                 E->getOperatorLoc(),
-                                                 Callee.get(),
-                                                 First.get(),
-                                                 Second.get());
+  Expr *Callee = E->getCallee();
+  if (UnresolvedLookupExpr *ULE = dyn_cast<UnresolvedLookupExpr>(Callee)) {
+    LookupResult R(SemaRef, ULE->getName(), ULE->getNameLoc(),
+                   Sema::LookupOrdinaryName);
+    if (getDerived().TransformOverloadExprDecls(ULE, ULE->requiresADL(), R))
+      return ExprError();
+
+    return getDerived().RebuildCXXOperatorCallExpr(
+        E->getOperator(), E->getOperatorLoc(), Callee->getBeginLoc(),
+        ULE->requiresADL(), R.asUnresolvedSet(), First.get(), Second.get());
+  }
+
+  UnresolvedSet<1> Functions;
+  if (ImplicitCastExpr *ICE = dyn_cast<ImplicitCastExpr>(Callee))
+    Callee = ICE->getSubExprAsWritten();
+  NamedDecl *DR = cast<DeclRefExpr>(Callee)->getDecl();
+  ValueDecl *VD = cast_or_null<ValueDecl>(
+      getDerived().TransformDecl(DR->getLocation(), DR));
+  if (!VD)
+    return ExprError();
+
+  if (!isa<CXXMethodDecl>(VD))
+    Functions.addDecl(VD);
+
+  return getDerived().RebuildCXXOperatorCallExpr(
+      E->getOperator(), E->getOperatorLoc(), Callee->getBeginLoc(),
+      /*RequiresADL=*/false, Functions, First.get(), Second.get());
 }
 
 template<typename Derived>
@@ -14108,13 +14121,17 @@ TreeTransform<Derived>::TransformCXXFoldExpr(CXXFoldExpr *E) {
       // We've got down to a single element; build a binary operator.
       Expr *LHS = LeftFold ? Result.get() : Out.get();
       Expr *RHS = LeftFold ? Out.get() : Result.get();
-      if (Callee)
+      if (Callee) {
+        UnresolvedSet<16> Functions;
+        Functions.append(Callee->decls_begin(), Callee->decls_end());
         Result = getDerived().RebuildCXXOperatorCallExpr(
             BinaryOperator::getOverloadedOperator(E->getOperator()),
-            E->getEllipsisLoc(), Callee, LHS, RHS);
-      else
+            E->getEllipsisLoc(), Callee->getBeginLoc(), Callee->requiresADL(),
+            Functions, LHS, RHS);
+      } else {
         Result = getDerived().RebuildBinaryOperator(E->getEllipsisLoc(),
                                                     E->getOperator(), LHS, RHS);
+      }
     } else
       Result = Out;
 
@@ -15118,14 +15135,11 @@ TreeTransform<Derived>::RebuildTemplateName(CXXScopeSpec &SS,
   return Template.get();
 }
 
-template<typename Derived>
-ExprResult
-TreeTransform<Derived>::RebuildCXXOperatorCallExpr(OverloadedOperatorKind Op,
-                                                   SourceLocation OpLoc,
-                                                   Expr *OrigCallee,
-                                                   Expr *First,
-                                                   Expr *Second) {
-  Expr *Callee = OrigCallee->IgnoreParenCasts();
+template <typename Derived>
+ExprResult TreeTransform<Derived>::RebuildCXXOperatorCallExpr(
+    OverloadedOperatorKind Op, SourceLocation OpLoc, SourceLocation CalleeLoc,
+    bool RequiresADL, const UnresolvedSetImpl &Functions, Expr *First,
+    Expr *Second) {
   bool isPostIncDec = Second && (Op == OO_PlusPlus || Op == OO_MinusMinus);
 
   if (First->getObjectKind() == OK_ObjCProperty) {
@@ -15150,8 +15164,8 @@ TreeTransform<Derived>::RebuildCXXOperatorCallExpr(OverloadedOperatorKind Op,
   if (Op == OO_Subscript) {
     if (!First->getType()->isOverloadableType() &&
         !Second->getType()->isOverloadableType())
-      return getSema().CreateBuiltinArraySubscriptExpr(
-          First, Callee->getBeginLoc(), Second, OpLoc);
+      return getSema().CreateBuiltinArraySubscriptExpr(First, CalleeLoc, Second,
+                                                       OpLoc);
   } else if (Op == OO_Arrow) {
     // It is possible that the type refers to a RecoveryExpr created earlier
     // in the tree transformation.
@@ -15185,27 +15199,6 @@ TreeTransform<Derived>::RebuildCXXOperatorCallExpr(OverloadedOperatorKind Op,
     }
   }
 
-  // Compute the transformed set of functions (and function templates) to be
-  // used during overload resolution.
-  UnresolvedSet<16> Functions;
-  bool RequiresADL;
-
-  if (UnresolvedLookupExpr *ULE = dyn_cast<UnresolvedLookupExpr>(Callee)) {
-    Functions.append(ULE->decls_begin(), ULE->decls_end());
-    // If the overload could not be resolved in the template definition
-    // (because we had a dependent argument), ADL is performed as part of
-    // template instantiation.
-    RequiresADL = ULE->requiresADL();
-  } else {
-    // If we've resolved this to a particular non-member function, just call
-    // that function. If we resolved it to a member function,
-    // CreateOverloaded* will find that function for us.
-    NamedDecl *ND = cast<DeclRefExpr>(Callee)->getDecl();
-    if (!isa<CXXMethodDecl>(ND))
-      Functions.addDecl(ND);
-    RequiresADL = false;
-  }
-
   // Add any functions found via argument-dependent lookup.
   Expr *Args[2] = { First, Second };
   unsigned NumArgs = 1 + (Second != nullptr);
@@ -15218,23 +15211,6 @@ TreeTransform<Derived>::RebuildCXXOperatorCallExpr(OverloadedOperatorKind Op,
                                            RequiresADL);
   }
 
-  if (Op == OO_Subscript) {
-    SourceLocation LBrace;
-    SourceLocation RBrace;
-
-    if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Callee)) {
-      DeclarationNameLoc NameLoc = DRE->getNameInfo().getInfo();
-      LBrace = NameLoc.getCXXOperatorNameBeginLoc();
-      RBrace = NameLoc.getCXXOperatorNameEndLoc();
-    } else {
-      LBrace = Callee->getBeginLoc();
-      RBrace = OpLoc;
-    }
-
-    return SemaRef.CreateOverloadedArraySubscriptExpr(LBrace, RBrace,
-                                                      First, Second);
-  }
-
   // Create the overloaded operator invocation for binary operators.
   BinaryOperatorKind Opc = BinaryOperator::getOverloadedOpcode(Op);
   ExprResult Result = SemaRef.CreateOverloadedBinOp(
diff --git a/clang/test/SemaCXX/consteval-operators.cpp b/clang/test/SemaCXX/consteval-operators.cpp
new file mode 100644 (file)
index 0000000..addb4d6
--- /dev/null
@@ -0,0 +1,46 @@
+// RUN: %clang_cc1 -std=c++2a -emit-llvm-only -Wno-unused-value %s -verify
+
+// expected-no-diagnostics
+
+struct A {
+  consteval A operator+() { return {}; }
+};
+consteval A operator~(A) { return {}; }
+consteval A operator+(A, A) { return {}; }
+
+template <class> void f() {
+  A a;
+  A b = ~a;
+  A c = a + a;
+  A d = +a;
+}
+template void f<int>();
+
+template <class T> void foo() {
+  T a;
+  T b = ~a;
+  T c = a + a;
+  T d = +a;
+}
+
+template void foo<A>();
+
+template <typename DataT> struct B { DataT D; };
+
+template <typename DataT>
+consteval B<DataT> operator+(B<DataT> lhs, B<DataT> rhs) {
+  return B<DataT>{lhs.D + rhs.D};
+}
+
+template <class T> consteval T template_add(T a, T b) { return a + b; }
+
+consteval B<int> non_template_add(B<int> a, B<int> b) { return a + b; }
+
+void bar() {
+  constexpr B<int> a{};
+  constexpr B<int> b{};
+  auto constexpr c = a + b;
+}
+
+static_assert((template_add(B<int>{7}, B<int>{3})).D == 10);
+static_assert((non_template_add(B<int>{7}, B<int>{3})).D == 10);
index 3290656..83a7e65 100644 (file)
@@ -585,3 +585,16 @@ namespace LateADLInNonDependentExpressions {
   float &operator->*(B, B);
   template void f<int>();
 }
+
+namespace test {
+namespace A {
+template<typename T> T f(T t) {
+  T operator+(T, T);
+  return t + t;
+}
+}
+namespace B {
+  struct X {};
+}
+void g(B::X x) { A::f(x); }
+}