[OPENMP5.0]Introduce attribute for declare variant directive.
authorAlexey Bataev <a.bataev@hotmail.com>
Tue, 17 Sep 2019 17:36:49 +0000 (17:36 +0000)
committerAlexey Bataev <a.bataev@hotmail.com>
Tue, 17 Sep 2019 17:36:49 +0000 (17:36 +0000)
Added attribute for declare variant directive. It will allow to handle
declare variant directive at the codegen and will allow to add extra
checks.

llvm-svn: 372147

clang/include/clang/Basic/Attr.td
clang/include/clang/Basic/AttrDocs.td
clang/include/clang/Basic/DiagnosticSemaKinds.td
clang/include/clang/Sema/Sema.h
clang/lib/Sema/SemaExpr.cpp
clang/lib/Sema/SemaOpenMP.cpp
clang/lib/Sema/SemaTemplateInstantiateDecl.cpp
clang/test/OpenMP/declare_variant_ast_print.c [new file with mode: 0644]
clang/test/OpenMP/declare_variant_ast_print.cpp [new file with mode: 0644]
clang/test/OpenMP/declare_variant_messages.c

index 4cb1618..dfc8590 100644 (file)
@@ -3265,6 +3265,29 @@ def OMPAllocateDecl : InheritableAttr {
   let Documentation = [Undocumented];
 }
 
+def OMPDeclareVariant : Attr {
+  let Spellings = [Pragma<"omp", "declare variant">];
+  let Subjects = SubjectList<[Function]>;
+  let SemaHandler = 0;
+  let HasCustomParsing = 1;
+  let Documentation = [OMPDeclareVariantDocs];
+  let Args = [
+    ExprArgument<"VariantFuncRef">
+  ];
+  let AdditionalMembers = [{
+    void printPrettyPragma(raw_ostream & OS, const PrintingPolicy &Policy)
+        const {
+      if (const Expr *E = getVariantFuncRef()) {
+        OS << "(";
+        E->printPretty(OS, nullptr, Policy);
+        OS << ")";
+      }
+      // TODO: add printing of real context selectors.
+      OS << " match(unknown={})";
+    }
+  }];
+}
+
 def InternalLinkage : InheritableAttr {
   let Spellings = [Clang<"internal_linkage">];
   let Subjects = SubjectList<[Var, Function, CXXRecord]>;
index 1379da8..2e9ec57 100644 (file)
@@ -3208,6 +3208,34 @@ where clause is one of the following:
   }];
 }
 
+def OMPDeclareVariantDocs : Documentation {
+  let Category = DocCatFunction;
+  let Heading = "#pragma omp declare variant";
+  let Content = [{
+The `declare variant` directive declares a specialized variant of a base
+ function and specifies the context in which that specialized variant is used.
+ The declare variant directive is a declarative directive.
+The syntax of the `declare variant` construct is as follows:
+
+  .. code-block:: none
+
+    #pragma omp declare variant(variant-func-id) clause new-line
+    [#pragma omp declare variant(variant-func-id) clause new-line]
+    [...]
+    function definition or declaration
+
+where clause is one of the following:
+
+  .. code-block:: none
+
+    match(context-selector-specification)
+
+and where `variant-func-id` is the name of a function variant that is either a
+ base language identifier or, for C++, a template-id.
+
+  }];
+}
+
 def NoStackProtectorDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
index e550955..b80cda5 100644 (file)
@@ -9429,6 +9429,10 @@ def err_omp_declare_variant_diff : Error<
 def err_omp_declare_variant_incompat_types : Error<
   "variant in '#pragma omp declare variant' with type %0 is incompatible with type %1"
   >;
+def warn_omp_declare_variant_marked_as_declare_variant : Warning<
+  "variant function in '#pragma omp declare variant' is itself marked as '#pragma omp declare variant'"
+  >, InGroup<SourceUsesOpenMP>;
+def note_omp_marked_declare_variant_here : Note<"marked as 'declare variant' here">;
 } // end of OpenMP category
 
 let CategoryName = "Related Result Type Issue" in {
index 59fc120..c7f15f3 100644 (file)
@@ -9089,6 +9089,12 @@ private:
                                      MapT &Map, unsigned Selector = 0,
                                      SourceRange SrcRange = SourceRange());
 
+  /// Marks all the functions that might be required for the currently active
+  /// OpenMP context.
+  void markOpenMPDeclareVariantFuncsReferenced(SourceLocation Loc,
+                                               FunctionDecl *Func,
+                                               bool MightBeOdrUse);
+
 public:
   /// Checks if the variant/multiversion functions are compatible.
   bool areMultiversionVariantFunctionsCompatible(
index 99e4389..808c0e4 100644 (file)
@@ -15497,6 +15497,7 @@ void Sema::MarkFunctionReferenced(SourceLocation Loc, FunctionDecl *Func,
   }
 
   if (LangOpts.OpenMP) {
+    markOpenMPDeclareVariantFuncsReferenced(Loc, Func, MightBeOdrUse);
     if (LangOpts.OpenMPIsDevice)
       checkOpenMPDeviceFunction(Loc, Func);
     else
index ef2ef8b..cd343bd 100644 (file)
@@ -4945,8 +4945,12 @@ Sema::ActOnOpenMPDeclareVariantDirective(Sema::DeclGroupPtrTy DG,
   // Do not check templates, wait until instantiation.
   if (VariantRef->isTypeDependent() || VariantRef->isValueDependent() ||
       VariantRef->containsUnexpandedParameterPack() ||
-      VariantRef->isInstantiationDependent() || FD->isDependentContext())
+      VariantRef->isInstantiationDependent() || FD->isDependentContext()) {
+    auto *NewAttr =
+        OMPDeclareVariantAttr::CreateImplicit(Context, VariantRef, SR);
+    FD->addAttr(NewAttr);
     return DG;
+  }
 
   // Convert VariantRef expression to the type of the original function to
   // resolve possible conflicts.
@@ -5025,6 +5029,17 @@ Sema::ActOnOpenMPDeclareVariantDirective(Sema::DeclGroupPtrTy DG,
     return DG;
   }
 
+  // Check if variant function is not marked with declare variant directive.
+  if (NewFD->hasAttrs() && NewFD->hasAttr<OMPDeclareVariantAttr>()) {
+    Diag(VariantRef->getExprLoc(),
+         diag::warn_omp_declare_variant_marked_as_declare_variant)
+        << VariantRef->getSourceRange();
+    SourceRange SR =
+        NewFD->specific_attr_begin<OMPDeclareVariantAttr>()->getRange();
+    Diag(SR.getBegin(), diag::note_omp_marked_declare_variant_here) << SR;
+    return DG;
+  }
+
   enum DoesntSupport {
     VirtFuncs = 1,
     Constructors = 3,
@@ -5087,9 +5102,30 @@ Sema::ActOnOpenMPDeclareVariantDirective(Sema::DeclGroupPtrTy DG,
           /*TemplatesSupported=*/true, /*ConstexprSupported=*/false))
     return DG;
 
+  auto *NewAttr = OMPDeclareVariantAttr::CreateImplicit(Context, DRE, SR);
+  FD->addAttr(NewAttr);
   return DG;
 }
 
+void Sema::markOpenMPDeclareVariantFuncsReferenced(SourceLocation Loc,
+                                                   FunctionDecl *Func,
+                                                   bool MightBeOdrUse) {
+  assert(LangOpts.OpenMP && "Expected OpenMP mode.");
+
+  if (!Func->isDependentContext() && Func->hasAttrs()) {
+    for (OMPDeclareVariantAttr *A :
+         Func->specific_attrs<OMPDeclareVariantAttr>()) {
+      // TODO: add checks for active OpenMP context where possible.
+      Expr *VariantRef = A->getVariantFuncRef();
+      auto *DRE = dyn_cast<DeclRefExpr>(VariantRef->IgnoreParenImpCasts());
+      auto *F = cast<FunctionDecl>(DRE->getDecl());
+      if (!F->isDefined() && F->isTemplateInstantiation())
+        InstantiateFunctionDefinition(Loc, F->getFirstDecl());
+      MarkFunctionReferenced(Loc, F, MightBeOdrUse);
+    }
+  }
+}
+
 StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef<OMPClause *> Clauses,
                                               Stmt *AStmt,
                                               SourceLocation StartLoc,
index 7521cad..6585917 100644 (file)
@@ -348,6 +348,50 @@ static void instantiateOMPDeclareSimdDeclAttr(
       Attr.getRange());
 }
 
+/// Instantiation of 'declare variant' attribute and its arguments.
+static void instantiateOMPDeclareVariantAttr(
+    Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
+    const OMPDeclareVariantAttr &Attr, Decl *New) {
+  // Allow 'this' in clauses with varlists.
+  if (auto *FTD = dyn_cast<FunctionTemplateDecl>(New))
+    New = FTD->getTemplatedDecl();
+  auto *FD = cast<FunctionDecl>(New);
+  auto *ThisContext = dyn_cast_or_null<CXXRecordDecl>(FD->getDeclContext());
+
+  auto &&SubstExpr = [FD, ThisContext, &S, &TemplateArgs](Expr *E) {
+    if (auto *DRE = dyn_cast<DeclRefExpr>(E->IgnoreParenImpCasts()))
+      if (auto *PVD = dyn_cast<ParmVarDecl>(DRE->getDecl())) {
+        Sema::ContextRAII SavedContext(S, FD);
+        LocalInstantiationScope Local(S);
+        if (FD->getNumParams() > PVD->getFunctionScopeIndex())
+          Local.InstantiatedLocal(
+              PVD, FD->getParamDecl(PVD->getFunctionScopeIndex()));
+        return S.SubstExpr(E, TemplateArgs);
+      }
+    Sema::CXXThisScopeRAII ThisScope(S, ThisContext, Qualifiers(),
+                                     FD->isCXXInstanceMember());
+    return S.SubstExpr(E, TemplateArgs);
+  };
+
+  // Substitute a single OpenMP clause, which is a potentially-evaluated
+  // full-expression.
+  auto &&Subst = [&SubstExpr, &S](Expr *E) {
+    EnterExpressionEvaluationContext Evaluated(
+        S, Sema::ExpressionEvaluationContext::PotentiallyEvaluated);
+    ExprResult Res = SubstExpr(E);
+    if (Res.isInvalid())
+      return Res;
+    return S.ActOnFinishFullExpr(Res.get(), false);
+  };
+
+  ExprResult VariantFuncRef;
+  if (Expr *E = Attr.getVariantFuncRef())
+    VariantFuncRef = Subst(E);
+
+  (void)S.ActOnOpenMPDeclareVariantDirective(
+      S.ConvertDeclToDeclGroup(New), VariantFuncRef.get(), Attr.getRange());
+}
+
 static void instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
     Sema &S, const MultiLevelTemplateArgumentList &TemplateArgs,
     const AMDGPUFlatWorkGroupSizeAttr &Attr, Decl *New) {
@@ -505,6 +549,11 @@ void Sema::InstantiateAttrs(const MultiLevelTemplateArgumentList &TemplateArgs,
       continue;
     }
 
+    if (const auto *OMPAttr = dyn_cast<OMPDeclareVariantAttr>(TmplAttr)) {
+      instantiateOMPDeclareVariantAttr(*this, TemplateArgs, *OMPAttr, New);
+      continue;
+    }
+
     if (const auto *AMDGPUFlatWorkGroupSize =
             dyn_cast<AMDGPUFlatWorkGroupSizeAttr>(TmplAttr)) {
       instantiateDependentAMDGPUFlatWorkGroupSizeAttr(
diff --git a/clang/test/OpenMP/declare_variant_ast_print.c b/clang/test/OpenMP/declare_variant_ast_print.c
new file mode 100644 (file)
index 0000000..e1632bc
--- /dev/null
@@ -0,0 +1,16 @@
+// RUN: %clang_cc1 -verify -fopenmp -x c -std=c99 -ast-print %s -o - | FileCheck %s
+
+// RUN: %clang_cc1 -verify -fopenmp-simd -x c -std=c99 -ast-print %s -o - | FileCheck %s
+
+// expected-no-diagnostics
+
+int foo(void);
+
+#pragma omp declare variant(foo) match(xxx={})
+#pragma omp declare variant(foo) match(xxx={vvv})
+int bar(void);
+
+// CHECK:      int foo();
+// CHECK-NEXT: #pragma omp declare variant(foo) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foo) match(unknown={})
+// CHECK-NEXT: int bar();
diff --git a/clang/test/OpenMP/declare_variant_ast_print.cpp b/clang/test/OpenMP/declare_variant_ast_print.cpp
new file mode 100644 (file)
index 0000000..b4b515d
--- /dev/null
@@ -0,0 +1,161 @@
+// RUN: %clang_cc1 -verify -fopenmp -x c++ -std=c++14 -fexceptions -fcxx-exceptions %s -ast-print -o - | FileCheck %s
+
+// RUN: %clang_cc1 -verify -fopenmp-simd -x c++ -std=c++14 -fexceptions -fcxx-exceptions %s -ast-print -o - | FileCheck %s
+
+// expected-no-diagnostics
+
+// CHECK: int foo();
+int foo();
+
+// CHECK:      template <typename T> T foofoo() {
+// CHECK-NEXT: return T();
+// CHECK-NEXT: }
+template <typename T>
+T foofoo() { return T(); }
+
+// CHECK:      template<> int foofoo<int>() {
+// CHECK-NEXT: return int();
+// CHECK-NEXT: }
+
+// CHECK:      #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: int bar();
+#pragma omp declare variant(foofoo <int>) match(xxx = {})
+#pragma omp declare variant(foofoo <int>) match(xxx = {vvv})
+int bar();
+
+// CHECK:      #pragma omp declare variant(foofoo<T>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<T>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<T>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<T>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<T>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<T>) match(unknown={})
+// CHECK-NEXT: template <typename T> T barbar();
+#pragma omp declare variant(foofoo <T>) match(xxx = {})
+#pragma omp declare variant(foofoo <T>) match(xxx = {vvv})
+#pragma omp declare variant(foofoo <T>) match(user = {score(<expr>) : condition(<expr>)})
+#pragma omp declare variant(foofoo <T>) match(user = {score(<expr>) : condition(<expr>)})
+#pragma omp declare variant(foofoo <T>) match(user = {condition(<expr>)})
+#pragma omp declare variant(foofoo <T>) match(user = {condition(<expr>)})
+template <typename T>
+T barbar();
+
+// CHECK:      #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(foofoo<int>) match(unknown={})
+// CHECK-NEXT: template<> int barbar<int>();
+
+// CHECK-NEXT: int baz() {
+// CHECK-NEXT: return barbar<int>();
+// CHECK-NEXT: }
+int baz() {
+  return barbar<int>();
+}
+
+// CHECK:      template <class C> void h_ref(C *hp, C *hp2, C *hq, C *lin) {
+// CHECK-NEXT: }
+// CHECK-NEXT: template<> void h_ref<double>(double *hp, double *hp2, double *hq, double *lin) {
+// CHECK-NEXT: }
+// CHECK-NEXT: template<> void h_ref<float>(float *hp, float *hp2, float *hq, float *lin) {
+// CHECK-NEXT: }
+template <class C>
+void h_ref(C *hp, C *hp2, C *hq, C *lin) {
+}
+
+// CHECK:      #pragma omp declare variant(h_ref<C>) match(unknown={})
+// CHECK-NEXT: template <class C> void h(C *hp, C *hp2, C *hq, C *lin) {
+// CHECK-NEXT: }
+#pragma omp declare variant(h_ref <C>) match(xxx = {})
+template <class C>
+void h(C *hp, C *hp2, C *hq, C *lin) {
+}
+
+// CHECK:      #pragma omp declare variant(h_ref<float>) match(unknown={})
+// CHECK-NEXT: template<> void h<float>(float *hp, float *hp2, float *hq, float *lin) {
+// CHECK-NEXT: }
+
+// CHECK-NEXT: template<> void h<double>(double *hp, double *hp2, double *hq, double *lin) {
+// CHECK-NEXT:   h((float *)hp, (float *)hp2, (float *)hq, (float *)lin);
+// CHECK-NEXT: }
+#pragma omp declare variant(h_ref <double>) match(xxx = {})
+template <>
+void h(double *hp, double *hp2, double *hq, double *lin) {
+  h((float *)hp, (float *)hp2, (float *)hq, (float *)lin);
+}
+
+// CHECK: int fn();
+int fn();
+// CHECK: int fn(int);
+int fn(int);
+// CHECK:      #pragma omp declare variant(fn) match(unknown={})
+// CHECK-NEXT: int overload();
+#pragma omp declare variant(fn) match(xxx = {})
+int overload(void);
+
+// CHECK:      int fn_deduced_variant() {
+// CHECK-NEXT: return 0;
+// CHECK-NEXT: }
+auto fn_deduced_variant() { return 0; }
+// CHECK:      #pragma omp declare variant(fn_deduced_variant) match(unknown={})
+// CHECK-NEXT: int fn_deduced();
+#pragma omp declare variant(fn_deduced_variant) match(xxx = {})
+int fn_deduced();
+
+// CHECK: int fn_deduced_variant1();
+int fn_deduced_variant1();
+// CHECK:      #pragma omp declare variant(fn_deduced_variant1) match(unknown={})
+// CHECK-NEXT: int fn_deduced1() {
+// CHECK-NEXT: return 0;
+// CHECK-NEXT: }
+#pragma omp declare variant(fn_deduced_variant1) match(xxx = {})
+auto fn_deduced1() { return 0; }
+
+// CHECK:      struct SpecialFuncs {
+// CHECK-NEXT: void vd() {
+// CHECK-NEXT: }
+// CHECK-NEXT: SpecialFuncs();
+// CHECK-NEXT: ~SpecialFuncs() noexcept;
+// CHECK-NEXT: void baz() {
+// CHECK-NEXT: }
+// CHECK-NEXT: void bar() {
+// CHECK-NEXT: }
+// CHECK-NEXT: void bar(int) {
+// CHECK-NEXT: }
+// CHECK-NEXT: #pragma omp declare variant(SpecialFuncs::bar) match(unknown={})
+// CHECK-NEXT: #pragma omp declare variant(SpecialFuncs::baz) match(unknown={})
+// CHECK-NEXT: void foo1() {
+// CHECK-NEXT: }
+// CHECK-NEXT: } s;
+struct SpecialFuncs {
+  void vd() {}
+  SpecialFuncs();
+  ~SpecialFuncs();
+
+  void baz() {}
+  void bar() {}
+  void bar(int) {}
+#pragma omp declare variant(SpecialFuncs::baz) match(xxx = {})
+#pragma omp declare variant(SpecialFuncs::bar) match(xxx = {})
+  void foo1() {}
+} s;
+
+// CHECK:      static void static_f_variant() {
+// CHECK-NEXT: }
+static void static_f_variant() {}
+// CHECK:      #pragma omp declare variant(static_f_variant) match(unknown={})
+// CHECK-NEXT: static void static_f() {
+// CHECK-NEXT: }
+#pragma omp declare variant(static_f_variant) match(xxx = {})
+static void static_f() {}
+
+// CHECK: void bazzzz() {
+// CHECK-NEXT: s.foo1();
+// CHECK-NEXT: static_f();
+// CHECK-NEXT: }
+void bazzzz() {
+  s.foo1();
+  static_f();
+}
index 3d8ae66..93023fd 100644 (file)
@@ -86,6 +86,15 @@ int diff_ret_variant(void);
 #pragma omp declare variant(diff_ret_variant) match(xxx={})
 void diff_ret(void);
 
+void marked(void);
+void not_marked(void);
+// expected-note@+1 {{marked as 'declare variant' here}}
+#pragma omp declare variant(not_marked) match(xxx={})
+void marked_variant(void);
+// expected-warning@+1 {{variant function in '#pragma omp declare variant' is itself marked as '#pragma omp declare variant'}}
+#pragma omp declare variant(marked_variant) match(xxx={})
+void marked(void);
+
 // expected-error@+1 {{function declaration is expected after 'declare variant' directive}}
 #pragma omp declare variant
 // expected-error@+1 {{function declaration is expected after 'declare variant' directive}}