From: Michael Kruse Date: Wed, 6 Oct 2021 15:02:27 +0000 (-0500) Subject: [Clang][OpenMP] Infix OMPLoopTransformationDirective abstract class. NFC. X-Git-Tag: upstream/15.0.7~29462 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f37e8b0b831e61d3b6033829fff05d6d193ab735;p=platform%2Fupstream%2Fllvm.git [Clang][OpenMP] Infix OMPLoopTransformationDirective abstract class. NFC. Insert OMPLoopTransformationDirective between OMPLoopBasedDirective and the loop transformations OMPTileDirective and OMPUnrollDirective. This simplifies handling of loop transformations not requiring distinguishing between OMPTileDirective and OMPUnrollDirective anymore. Reviewed By: ABataev Differential Revision: https://reviews.llvm.org/D111119 --- diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index f028c3b..285426d 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -889,22 +889,23 @@ public: /// Calls the specified callback function for all the loops in \p CurStmt, /// from the outermost to the innermost. - static bool doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, - unsigned NumLoops, - llvm::function_ref Callback, - llvm::function_ref - OnTransformationCallback); + static bool + doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, + unsigned NumLoops, + llvm::function_ref Callback, + llvm::function_ref + OnTransformationCallback); static bool doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref Callback, - llvm::function_ref + llvm::function_ref OnTransformationCallback) { auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) { return Callback(Cnt, CurStmt); }; auto &&NewTransformCb = - [OnTransformationCallback](OMPLoopBasedDirective *A) { + [OnTransformationCallback](OMPLoopTransformationDirective *A) { OnTransformationCallback(A); }; return doForAllLoops(const_cast(CurStmt), TryImperfectlyNestedLoops, @@ -917,7 +918,7 @@ public: doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref Callback) { - auto &&TransformCb = [](OMPLoopBasedDirective *) {}; + auto &&TransformCb = [](OMPLoopTransformationDirective *) {}; return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback, TransformCb); } @@ -954,6 +955,38 @@ public: } }; +/// The base class for all loop transformation directives. +class OMPLoopTransformationDirective : public OMPLoopBasedDirective { + friend class ASTStmtReader; + +protected: + explicit OMPLoopTransformationDirective(StmtClass SC, + OpenMPDirectiveKind Kind, + SourceLocation StartLoc, + SourceLocation EndLoc, + unsigned NumAssociatedLoops) + : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} + +public: + /// Return the number of associated (consumed) loops. + unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } + + /// Get the de-sugared statements after after the loop transformation. + /// + /// Might be nullptr if either the directive generates no loops and is handled + /// directly in CodeGen, or resolving a template-dependence context is + /// required. + Stmt *getTransformedStmt() const; + + /// Return preinits statement. + Stmt *getPreInits() const; + + static bool classof(const Stmt *T) { + return T->getStmtClass() == OMPTileDirectiveClass || + T->getStmtClass() == OMPUnrollDirectiveClass; + } +}; + /// This is a common base class for loop directives ('omp simd', 'omp /// for', 'omp for simd' etc.). It is responsible for the loop code generation. /// @@ -5011,7 +5044,7 @@ public: }; /// This represents the '#pragma omp tile' loop transformation directive. -class OMPTileDirective final : public OMPLoopBasedDirective { +class OMPTileDirective final : public OMPLoopTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5023,8 +5056,9 @@ class OMPTileDirective final : public OMPLoopBasedDirective { explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopBasedDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile, - StartLoc, EndLoc, NumLoops) {} + : OMPLoopTransformationDirective(OMPTileDirectiveClass, + llvm::omp::OMPD_tile, StartLoc, EndLoc, + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5061,8 +5095,6 @@ public: static OMPTileDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned NumLoops); - unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } - /// Gets/sets the associated loops after tiling. /// /// This is in de-sugared format stored as a CompoundStmt. @@ -5092,7 +5124,7 @@ public: /// #pragma omp unroll /// for (int i = 0; i < 64; ++i) /// \endcode -class OMPUnrollDirective final : public OMPLoopBasedDirective { +class OMPUnrollDirective final : public OMPLoopTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5103,8 +5135,9 @@ class OMPUnrollDirective final : public OMPLoopBasedDirective { }; explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc) - : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll, - StartLoc, EndLoc, 1) {} + : OMPLoopTransformationDirective(OMPUnrollDirectiveClass, + llvm::omp::OMPD_unroll, StartLoc, EndLoc, + 1) {} /// Set the pre-init statements. void setPreInits(Stmt *PreInits) { diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index c5540a1..f002c36 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -224,8 +224,9 @@ def OMPLoopBasedDirective : StmtNode; def OMPLoopDirective : StmtNode; def OMPParallelDirective : StmtNode; def OMPSimdDirective : StmtNode; -def OMPTileDirective : StmtNode; -def OMPUnrollDirective : StmtNode; +def OMPLoopTransformationDirective : StmtNode; +def OMPTileDirective : StmtNode; +def OMPUnrollDirective : StmtNode; def OMPForDirective : StmtNode; def OMPForSimdDirective : StmtNode; def OMPSectionsDirective : StmtNode; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index f461d9b..c615463 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -125,28 +125,25 @@ OMPLoopBasedDirective::tryToFindNextInnerLoop(Stmt *CurStmt, bool OMPLoopBasedDirective::doForAllLoops( Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref Callback, - llvm::function_ref + llvm::function_ref OnTransformationCallback) { CurStmt = CurStmt->IgnoreContainers(); for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) { while (true) { - auto *OrigStmt = CurStmt; - if (auto *Dir = dyn_cast(OrigStmt)) { - OnTransformationCallback(Dir); - CurStmt = Dir->getTransformedStmt(); - } else if (auto *Dir = dyn_cast(OrigStmt)) { - OnTransformationCallback(Dir); - CurStmt = Dir->getTransformedStmt(); - } else { + auto *Dir = dyn_cast(CurStmt); + if (!Dir) break; - } - if (!CurStmt) { - // May happen if the loop transformation does not result in a generated - // loop (such as full unrolling). - CurStmt = OrigStmt; + OnTransformationCallback(Dir); + + Stmt *TransformedStmt = Dir->getTransformedStmt(); + if (!TransformedStmt) { + // May happen if the loop transformation does not result in a + // generated loop (such as full unrolling). break; } + + CurStmt = TransformedStmt; } if (auto *CanonLoop = dyn_cast(CurStmt)) CurStmt = CanonLoop->getLoopStmt(); @@ -363,6 +360,32 @@ OMPForDirective *OMPForDirective::Create( return Dir; } +Stmt *OMPLoopTransformationDirective::getTransformedStmt() const { + switch (getStmtClass()) { +#define STMT(CLASS, PARENT) +#define ABSTRACT_STMT(CLASS) +#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ + case Stmt::CLASS##Class: \ + return static_cast(this)->getTransformedStmt(); +#include "clang/AST/StmtNodes.inc" + default: + llvm_unreachable("Not a loop transformation"); + } +} + +Stmt *OMPLoopTransformationDirective::getPreInits() const { + switch (getStmtClass()) { +#define STMT(CLASS, PARENT) +#define ABSTRACT_STMT(CLASS) +#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ + case Stmt::CLASS##Class: \ + return static_cast(this)->getPreInits(); +#include "clang/AST/StmtNodes.inc" + default: + llvm_unreachable("Not a loop transformation"); + } +} + OMPForDirective *OMPForDirective::CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 1afa377..9c48491 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -915,12 +915,17 @@ void StmtProfiler::VisitOMPSimdDirective(const OMPSimdDirective *S) { VisitOMPLoopDirective(S); } -void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) { +void StmtProfiler::VisitOMPLoopTransformationDirective( + const OMPLoopTransformationDirective *S) { VisitOMPLoopBasedDirective(S); } +void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) { + VisitOMPLoopTransformationDirective(S); +} + void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) { - VisitOMPLoopBasedDirective(S); + VisitOMPLoopTransformationDirective(S); } void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) { diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index 68ab136..431cc62 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -1829,9 +1829,7 @@ static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, return; } if (SimplifiedS == NextLoop) { - if (auto *Dir = dyn_cast(SimplifiedS)) - SimplifiedS = Dir->getTransformedStmt(); - if (auto *Dir = dyn_cast(SimplifiedS)) + if (auto *Dir = dyn_cast(SimplifiedS)) SimplifiedS = Dir->getTransformedStmt(); if (const auto *CanonLoop = dyn_cast(SimplifiedS)) SimplifiedS = CanonLoop->getLoopStmt(); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 5e6fa2c..af70a18 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -3823,13 +3823,8 @@ public: VisitSubCaptures(S); } - void VisitOMPTileDirective(OMPTileDirective *S) { - // #pragma omp tile does not introduce data sharing. - VisitStmt(S); - } - - void VisitOMPUnrollDirective(OMPUnrollDirective *S) { - // #pragma omp unroll does not introduce data sharing. + void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) { + // Loop transformation directives do not introduce data sharing VisitStmt(S); } @@ -9050,15 +9045,8 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr, } return false; }, - [&SemaRef, &Captures](OMPLoopBasedDirective *Transform) { - Stmt *DependentPreInits; - if (auto *Dir = dyn_cast(Transform)) { - DependentPreInits = Dir->getPreInits(); - } else if (auto *Dir = dyn_cast(Transform)) { - DependentPreInits = Dir->getPreInits(); - } else { - llvm_unreachable("Unexpected loop transformation"); - } + [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) { + Stmt *DependentPreInits = Transform->getPreInits(); if (!DependentPreInits) return; for (Decl *C : cast(DependentPreInits)->getDeclGroup()) { diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 6dcb5d5..34a5883 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2324,12 +2324,17 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) { VisitOMPLoopDirective(D); } -void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { +void ASTStmtReader::VisitOMPLoopTransformationDirective( + OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); } +void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { + VisitOMPLoopTransformationDirective(D); +} + void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) { - VisitOMPLoopBasedDirective(D); + VisitOMPLoopTransformationDirective(D); } void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index a12c726..bf32294 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2223,13 +2223,18 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) { Code = serialization::STMT_OMP_SIMD_DIRECTIVE; } -void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { +void ASTStmtWriter::VisitOMPLoopTransformationDirective( + OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); +} + +void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { + VisitOMPLoopTransformationDirective(D); Code = serialization::STMT_OMP_TILE_DIRECTIVE; } void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) { - VisitOMPLoopBasedDirective(D); + VisitOMPLoopTransformationDirective(D); Code = serialization::STMT_OMP_UNROLL_DIRECTIVE; } diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 96a7cb4..f76f041 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2046,6 +2046,8 @@ public: void VisitOMPLoopDirective(const OMPLoopDirective *D); void VisitOMPParallelDirective(const OMPParallelDirective *D); void VisitOMPSimdDirective(const OMPSimdDirective *D); + void + VisitOMPLoopTransformationDirective(const OMPLoopTransformationDirective *D); void VisitOMPTileDirective(const OMPTileDirective *D); void VisitOMPUnrollDirective(const OMPUnrollDirective *D); void VisitOMPForDirective(const OMPForDirective *D); @@ -2901,12 +2903,17 @@ void EnqueueVisitor::VisitOMPSimdDirective(const OMPSimdDirective *D) { VisitOMPLoopDirective(D); } -void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) { +void EnqueueVisitor::VisitOMPLoopTransformationDirective( + const OMPLoopTransformationDirective *D) { VisitOMPLoopBasedDirective(D); } +void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) { + VisitOMPLoopTransformationDirective(D); +} + void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) { - VisitOMPLoopBasedDirective(D); + VisitOMPLoopTransformationDirective(D); } void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) {