From b9c55e2760b1be461f05a10ca92696b3381036d2 Mon Sep 17 00:00:00 2001 From: Alexey Bataev Date: Mon, 14 Oct 2019 19:29:52 +0000 Subject: [PATCH] [OPENMP]Allow grainsize clause in combined task-based directives. The expression of the grainsize clause must be captured in the combined task-based directives, like 'parallel master taskloop' directive. llvm-svn: 374810 --- clang/include/clang/AST/OpenMPClause.h | 24 +++-- clang/include/clang/AST/RecursiveASTVisitor.h | 1 + clang/lib/AST/OpenMPClause.cpp | 9 +- clang/lib/AST/StmtProfile.cpp | 1 + clang/lib/Sema/SemaOpenMP.cpp | 105 +++++++++++++++++++-- clang/lib/Serialization/ASTReader.cpp | 1 + clang/lib/Serialization/ASTWriter.cpp | 1 + .../OpenMP/parallel_master_taskloop_codegen.cpp | 9 +- 8 files changed, 129 insertions(+), 22 deletions(-) diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index 6c504c7..db780f7 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -5268,7 +5268,7 @@ public: /// \endcode /// In this example directive '#pragma omp taskloop' has clause 'grainsize' /// with single expression '4'. -class OMPGrainsizeClause : public OMPClause { +class OMPGrainsizeClause : public OMPClause, public OMPClauseWithPreInit { friend class OMPClauseReader; /// Location of '('. @@ -5284,16 +5284,23 @@ public: /// Build 'grainsize' clause. /// /// \param Size Expression associated with this clause. + /// \param HelperSize Helper grainsize for the construct. + /// \param CaptureRegion Innermost OpenMP region where expressions in this + /// clause must be captured. /// \param StartLoc Starting location of the clause. /// \param EndLoc Ending location of the clause.
- OMPGrainsizeClause(Expr *Size, SourceLocation StartLoc, + OMPGrainsizeClause(Expr *Size, Stmt *HelperSize, + OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) - : OMPClause(OMPC_grainsize, StartLoc, EndLoc), LParenLoc(LParenLoc), - Grainsize(Size) {} + : OMPClause(OMPC_grainsize, StartLoc, EndLoc), OMPClauseWithPreInit(this), + LParenLoc(LParenLoc), Grainsize(Size) { + setPreInitStmt(HelperSize, CaptureRegion); + } /// Build an empty clause. explicit OMPGrainsizeClause() - : OMPClause(OMPC_grainsize, SourceLocation(), SourceLocation()) {} + : OMPClause(OMPC_grainsize, SourceLocation(), SourceLocation()), + OMPClauseWithPreInit(this) {} /// Sets the location of '('. void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } @@ -5310,11 +5317,10 @@ public: return const_child_range(&Grainsize, &Grainsize + 1); } - child_range used_children() { - return child_range(child_iterator(), child_iterator()); - } + child_range used_children(); const_child_range used_children() const { - return const_child_range(const_child_iterator(), const_child_iterator()); + auto Children = const_cast<OMPGrainsizeClause *>(this)->used_children(); + return const_child_range(Children.begin(), Children.end()); } static bool classof(const OMPClause *T) { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index cfeaec4..ba5232a 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3275,6 +3275,7 @@ bool RecursiveASTVisitor<Derived>::VisitOMPPriorityClause( template <typename Derived> bool RecursiveASTVisitor<Derived>::VisitOMPGrainsizeClause( OMPGrainsizeClause *C) { + TRY_TO(VisitOMPClauseWithPreInit(C)); TRY_TO(TraverseStmt(C->getGrainsize())); return true; } diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index bfe272b..b97607f 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -84,6 +84,8 @@ const
OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) { return static_cast(C); case OMPC_device: return static_cast<const OMPDeviceClause *>(C); + case OMPC_grainsize: + return static_cast<const OMPGrainsizeClause *>(C); case OMPC_default: case OMPC_proc_bind: case OMPC_final: @@ -113,7 +115,6 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) { case OMPC_simd: case OMPC_map: case OMPC_priority: - case OMPC_grainsize: case OMPC_nogroup: case OMPC_num_tasks: case OMPC_hint: @@ -234,6 +235,12 @@ OMPClause::child_range OMPIfClause::used_children() { return child_range(&Condition, &Condition + 1); } +OMPClause::child_range OMPGrainsizeClause::used_children() { + if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt())) + return child_range(C, C + 1); + return child_range(&Grainsize, &Grainsize + 1); +} + OMPOrderedClause *OMPOrderedClause::Create(const ASTContext &C, Expr *Num, unsigned NumLoops, SourceLocation StartLoc, diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index efc64af..6c65f8a 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -740,6 +740,7 @@ void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) { Profiler->VisitStmt(C->getPriority()); } void OMPClauseProfiler::VisitOMPGrainsizeClause(const OMPGrainsizeClause *C) { + VistOMPClauseWithPreInit(C); if (C->getGrainsize()) Profiler->VisitStmt(C->getGrainsize()); } diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index f717dc7..94cd80a 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -4590,12 +4590,16 @@ StmtResult Sema::ActOnOpenMPExecutableDirective( continue; case OMPC_schedule: break; + case OMPC_grainsize: + // Do not analyze if no parent parallel directive.
+ if (isOpenMPParallelDirective(DSAStack->getCurrentDirective())) + break; + continue; case OMPC_ordered: case OMPC_device: case OMPC_num_teams: case OMPC_thread_limit: case OMPC_priority: - case OMPC_grainsize: case OMPC_num_tasks: case OMPC_hint: case OMPC_collapse: @@ -10773,6 +10777,74 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause( llvm_unreachable("Unknown OpenMP directive"); } break; + case OMPC_grainsize: + switch (DKind) { + case OMPD_task: + case OMPD_taskloop: + case OMPD_taskloop_simd: + case OMPD_master_taskloop: + break; + case OMPD_parallel_master_taskloop: + CaptureRegion = OMPD_parallel; + break; + case OMPD_target_update: + case OMPD_target_enter_data: + case OMPD_target_exit_data: + case OMPD_target: + case OMPD_target_simd: + case OMPD_target_teams: + case OMPD_target_parallel: + case OMPD_target_teams_distribute: + case OMPD_target_teams_distribute_simd: + case OMPD_target_parallel_for: + case OMPD_target_parallel_for_simd: + case OMPD_target_teams_distribute_parallel_for: + case OMPD_target_teams_distribute_parallel_for_simd: + case OMPD_target_data: + case OMPD_teams_distribute_parallel_for: + case OMPD_teams_distribute_parallel_for_simd: + case OMPD_teams: + case OMPD_teams_distribute: + case OMPD_teams_distribute_simd: + case OMPD_distribute_parallel_for: + case OMPD_distribute_parallel_for_simd: + case OMPD_cancel: + case OMPD_parallel: + case OMPD_parallel_sections: + case OMPD_parallel_for: + case OMPD_parallel_for_simd: + case OMPD_threadprivate: + case OMPD_allocate: + case OMPD_taskyield: + case OMPD_barrier: + case OMPD_taskwait: + case OMPD_cancellation_point: + case OMPD_flush: + case OMPD_declare_reduction: + case OMPD_declare_mapper: + case OMPD_declare_simd: + case OMPD_declare_variant: + case OMPD_declare_target: + case OMPD_end_declare_target: + case OMPD_simd: + case OMPD_for: + case OMPD_for_simd: + case OMPD_sections: + case OMPD_section: + case OMPD_single: + case OMPD_master: + case OMPD_critical: + case 
OMPD_taskgroup: + case OMPD_distribute: + case OMPD_ordered: + case OMPD_atomic: + case OMPD_distribute_simd: + case OMPD_requires: + llvm_unreachable("Unexpected OpenMP directive with grainsize-clause"); + case OMPD_unknown: + llvm_unreachable("Unknown OpenMP directive"); + } + break; case OMPC_firstprivate: case OMPC_lastprivate: case OMPC_reduction: @@ -10808,7 +10880,6 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause( case OMPC_simd: case OMPC_map: case OMPC_priority: - case OMPC_grainsize: case OMPC_nogroup: case OMPC_num_tasks: case OMPC_hint: @@ -10926,9 +10997,12 @@ ExprResult Sema::PerformOpenMPImplicitIntegerConversion(SourceLocation Loc, return PerformContextualImplicitConversion(Loc, Op, ConvertDiagnoser); } -static bool isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, - OpenMPClauseKind CKind, - bool StrictlyPositive) { +static bool +isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, OpenMPClauseKind CKind, + bool StrictlyPositive, bool BuildCapture = false, + OpenMPDirectiveKind DKind = OMPD_unknown, + OpenMPDirectiveKind *CaptureRegion = nullptr, + Stmt **HelperValStmt = nullptr) { if (!ValExpr->isTypeDependent() && !ValExpr->isValueDependent() && !ValExpr->isInstantiationDependent()) { SourceLocation Loc = ValExpr->getExprLoc(); @@ -10949,6 +11023,16 @@ static bool isNonNegativeIntegerValue(Expr *&ValExpr, Sema &SemaRef, << ValExpr->getSourceRange(); return false; } + if (!BuildCapture) + return true; + *CaptureRegion = getOpenMPCaptureRegionForClause(DKind, CKind); + if (*CaptureRegion != OMPD_unknown && + !SemaRef.CurContext->isDependentContext()) { + ValExpr = SemaRef.MakeFullExpr(ValExpr).get(); + llvm::MapVector<const Expr *, DeclRefExpr *> Captures; + ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get(); + *HelperValStmt = buildPreInits(SemaRef.Context, Captures); + } } return true; } @@ -15847,15 +15931,20 @@ OMPClause *Sema::ActOnOpenMPGrainsizeClause(Expr *Grainsize, SourceLocation LParenLoc, SourceLocation EndLoc) { Expr *ValExpr =
Grainsize; + Stmt *HelperValStmt = nullptr; + OpenMPDirectiveKind CaptureRegion = OMPD_unknown; // OpenMP [2.9.2, taskloop Constrcut] // The parameter of the grainsize clause must be a positive integer // expression. - if (!isNonNegativeIntegerValue(ValExpr, *this, OMPC_grainsize, - /*StrictlyPositive=*/true)) + if (!isNonNegativeIntegerValue( + ValExpr, *this, OMPC_grainsize, + /*StrictlyPositive=*/true, /*BuildCapture=*/true, + DSAStack->getCurrentDirective(), &CaptureRegion, &HelperValStmt)) return nullptr; - return new (Context) OMPGrainsizeClause(ValExpr, StartLoc, LParenLoc, EndLoc); + return new (Context) OMPGrainsizeClause(ValExpr, HelperValStmt, CaptureRegion, + StartLoc, LParenLoc, EndLoc); } OMPClause *Sema::ActOnOpenMPNumTasksClause(Expr *NumTasks, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index 55f2be3..0a7958f 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -12934,6 +12934,7 @@ void OMPClauseReader::VisitOMPPriorityClause(OMPPriorityClause *C) { } void OMPClauseReader::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) { + VisitOMPClauseWithPreInit(C); C->setGrainsize(Record.readSubExpr()); C->setLParenLoc(Record.readSourceLocation()); } diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index df89e44..57c9242 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -6938,6 +6938,7 @@ void OMPClauseWriter::VisitOMPPriorityClause(OMPPriorityClause *C) { } void OMPClauseWriter::VisitOMPGrainsizeClause(OMPGrainsizeClause *C) { + VisitOMPClauseWithPreInit(C); Record.AddStmt(C->getGrainsize()); Record.AddSourceLocation(C->getLParenLoc()); } diff --git a/clang/test/OpenMP/parallel_master_taskloop_codegen.cpp b/clang/test/OpenMP/parallel_master_taskloop_codegen.cpp index ab15c48..2a2f4eb 100644 --- a/clang/test/OpenMP/parallel_master_taskloop_codegen.cpp +++ 
b/clang/test/OpenMP/parallel_master_taskloop_codegen.cpp @@ -14,7 +14,7 @@ int main(int argc, char **argv) { // CHECK: [[GTID:%.+]] = call i32 @__kmpc_global_thread_num(%struct.ident_t* [[DEFLOC:@.+]]) // CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* [[OMP_OUTLINED1:@.+]] to void (i32*, i32*, ...)*)) -// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 0, void (i32*, i32*, ...)* bitcast (void (i32*, i32*)* [[OMP_OUTLINED2:@.+]] to void (i32*, i32*, ...)*)) +// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 1, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i64)* [[OMP_OUTLINED2:@.+]] to void (i32*, i32*, ...)*), i64 [[GRAINSIZE:%.+]]) // CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, i32*, i8***, i64)* [[OMP_OUTLINED3:@.+]] to void (i32*, i32*, ...)*), i32* [[ARGC:%.+]], i8*** [[ARGV:%.+]], i64 [[COND:%.+]]) // CHECK: call void @__kmpc_serialized_parallel(%struct.ident_t* [[DEFLOC]], i32 [[GTID]]) // CHECK: call void [[OMP_OUTLINED3]](i32* %{{.+}}, i32* %{{.+}}, i32* [[ARGC]], i8*** [[ARGV]], i64 [[COND]]) @@ -77,7 +77,7 @@ int main(int argc, char **argv) { #pragma omp parallel master taskloop priority(4) for (int i = 0; i < 10; ++i) ; -// CHECK: define internal void [[OMP_OUTLINED2]](i32* noalias %{{.+}}, i32* noalias %{{.+}}) +// CHECK: define internal void [[OMP_OUTLINED2]](i32* noalias %{{.+}}, i32* noalias %{{.+}}, i64 %{{.+}}) // CHECK: [[RES:%.+]] = call {{.*}}i32 @__kmpc_master(%struct.ident_t* [[DEFLOC]], i32 [[GTID:%.+]]) // CHECK-NEXT: [[IS_MASTER:%.+]] = icmp ne i32 [[RES]], 0 // CHECK-NEXT: br i1 [[IS_MASTER]], label {{%?}}[[THEN:.+]], label 
{{%?}}[[EXIT:.+]] @@ -92,7 +92,8 @@ int main(int argc, char **argv) { // CHECK: [[ST:%.+]] = getelementptr inbounds [[TD_TY]], [[TD_TY]]* [[TASK_DATA]], i32 0, i32 7 // CHECK: store i64 1, i64* [[ST]], // CHECK: [[ST_VAL:%.+]] = load i64, i64* [[ST]], -// CHECK: call void @__kmpc_taskloop(%struct.ident_t* [[DEFLOC]], i32 [[GTID]], i8* [[TASKV]], i32 1, i64* [[DOWN]], i64* [[UP]], i64 [[ST_VAL]], i32 1, i32 1, i64 4, i8* null) +// CHECK: [[GRAINSIZE:%.+]] = zext i32 %{{.+}} to i64 +// CHECK: call void @__kmpc_taskloop(%struct.ident_t* [[DEFLOC]], i32 [[GTID]], i8* [[TASKV]], i32 1, i64* [[DOWN]], i64* [[UP]], i64 [[ST_VAL]], i32 1, i32 1, i64 [[GRAINSIZE]], i8* null) // CHECK-NEXT: call {{.*}}void @__kmpc_end_master(%struct.ident_t* [[DEFLOC]], i32 [[GTID]]) // CHECK-NEXT: br label {{%?}}[[EXIT]] // CHECK: [[EXIT]] @@ -128,7 +129,7 @@ int main(int argc, char **argv) { // CHECK: br label % // CHECK: ret i32 0 -#pragma omp parallel master taskloop nogroup grainsize(4) +#pragma omp parallel master taskloop nogroup grainsize(argc) for (int i = 0; i < 10; ++i) ; // CHECK: define internal void [[OMP_OUTLINED3]](i32* noalias %{{.+}}, i32* noalias %{{.+}}, i32* dereferenceable(4) %{{.+}}, i8*** dereferenceable(8) %{{.+}}, i64 %{{.+}}) -- 2.7.4