Insert OMPLoopTransformationDirective between OMPLoopBasedDirective and the loop transformations OMPTileDirective and OMPUnrollDirective. This simplifies handling of loop transformations not requiring distinguishing between OMPTileDirective and OMPUnrollDirective anymore.
Reviewed By: ABataev
Differential Revision: https://reviews.llvm.org/D111119
/// Calls the specified callback function for all the loops in \p CurStmt,
/// from the outermost to the innermost.
- static bool doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
- unsigned NumLoops,
- llvm::function_ref<bool(unsigned, Stmt *)> Callback,
- llvm::function_ref<void(OMPLoopBasedDirective *)>
- OnTransformationCallback);
+ static bool
+ doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
+ unsigned NumLoops,
+ llvm::function_ref<bool(unsigned, Stmt *)> Callback,
+ llvm::function_ref<void(OMPLoopTransformationDirective *)>
+ OnTransformationCallback);
static bool
doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops,
unsigned NumLoops,
llvm::function_ref<bool(unsigned, const Stmt *)> Callback,
- llvm::function_ref<void(const OMPLoopBasedDirective *)>
+ llvm::function_ref<void(const OMPLoopTransformationDirective *)>
OnTransformationCallback) {
auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) {
return Callback(Cnt, CurStmt);
};
auto &&NewTransformCb =
- [OnTransformationCallback](OMPLoopBasedDirective *A) {
+ [OnTransformationCallback](OMPLoopTransformationDirective *A) {
OnTransformationCallback(A);
};
return doForAllLoops(const_cast<Stmt *>(CurStmt), TryImperfectlyNestedLoops,
doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops,
unsigned NumLoops,
llvm::function_ref<bool(unsigned, Stmt *)> Callback) {
- auto &&TransformCb = [](OMPLoopBasedDirective *) {};
+ auto &&TransformCb = [](OMPLoopTransformationDirective *) {};
return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback,
TransformCb);
}
}
};
+/// The base class for all loop transformation directives.
+class OMPLoopTransformationDirective : public OMPLoopBasedDirective {
+ friend class ASTStmtReader;
+
+protected:
+ explicit OMPLoopTransformationDirective(StmtClass SC,
+ OpenMPDirectiveKind Kind,
+ SourceLocation StartLoc,
+ SourceLocation EndLoc,
+ unsigned NumAssociatedLoops)
+ : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
+
+public:
+ /// Return the number of associated (consumed) loops.
+ unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
+
+ /// Get the de-sugared statements after after the loop transformation.
+ ///
+ /// Might be nullptr if either the directive generates no loops and is handled
+ /// directly in CodeGen, or resolving a template-dependence context is
+ /// required.
+ Stmt *getTransformedStmt() const;
+
+ /// Return preinits statement.
+ Stmt *getPreInits() const;
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == OMPTileDirectiveClass ||
+ T->getStmtClass() == OMPUnrollDirectiveClass;
+ }
+};
+
/// This is a common base class for loop directives ('omp simd', 'omp
/// for', 'omp for simd' etc.). It is responsible for the loop code generation.
///
};
/// This represents the '#pragma omp tile' loop transformation directive.
-class OMPTileDirective final : public OMPLoopBasedDirective {
+class OMPTileDirective final : public OMPLoopTransformationDirective {
friend class ASTStmtReader;
friend class OMPExecutableDirective;
explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc,
unsigned NumLoops)
- : OMPLoopBasedDirective(OMPTileDirectiveClass, llvm::omp::OMPD_tile,
- StartLoc, EndLoc, NumLoops) {}
+ : OMPLoopTransformationDirective(OMPTileDirectiveClass,
+ llvm::omp::OMPD_tile, StartLoc, EndLoc,
+ NumLoops) {}
void setPreInits(Stmt *PreInits) {
Data->getChildren()[PreInitsOffset] = PreInits;
static OMPTileDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
unsigned NumLoops);
- unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
-
/// Gets/sets the associated loops after tiling.
///
/// This is in de-sugared format stored as a CompoundStmt.
/// #pragma omp unroll
/// for (int i = 0; i < 64; ++i)
/// \endcode
-class OMPUnrollDirective final : public OMPLoopBasedDirective {
+class OMPUnrollDirective final : public OMPLoopTransformationDirective {
friend class ASTStmtReader;
friend class OMPExecutableDirective;
};
explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc)
- : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll,
- StartLoc, EndLoc, 1) {}
+ : OMPLoopTransformationDirective(OMPUnrollDirectiveClass,
+ llvm::omp::OMPD_unroll, StartLoc, EndLoc,
+ 1) {}
/// Set the pre-init statements.
void setPreInits(Stmt *PreInits) {
def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>;
def OMPParallelDirective : StmtNode<OMPExecutableDirective>;
def OMPSimdDirective : StmtNode<OMPLoopDirective>;
-def OMPTileDirective : StmtNode<OMPLoopBasedDirective>;
-def OMPUnrollDirective : StmtNode<OMPLoopBasedDirective>;
+def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>;
+def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>;
+def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>;
def OMPForDirective : StmtNode<OMPLoopDirective>;
def OMPForSimdDirective : StmtNode<OMPLoopDirective>;
def OMPSectionsDirective : StmtNode<OMPExecutableDirective>;
bool OMPLoopBasedDirective::doForAllLoops(
Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops,
llvm::function_ref<bool(unsigned, Stmt *)> Callback,
- llvm::function_ref<void(OMPLoopBasedDirective *)>
+ llvm::function_ref<void(OMPLoopTransformationDirective *)>
OnTransformationCallback) {
CurStmt = CurStmt->IgnoreContainers();
for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) {
while (true) {
- auto *OrigStmt = CurStmt;
- if (auto *Dir = dyn_cast<OMPTileDirective>(OrigStmt)) {
- OnTransformationCallback(Dir);
- CurStmt = Dir->getTransformedStmt();
- } else if (auto *Dir = dyn_cast<OMPUnrollDirective>(OrigStmt)) {
- OnTransformationCallback(Dir);
- CurStmt = Dir->getTransformedStmt();
- } else {
+ auto *Dir = dyn_cast<OMPLoopTransformationDirective>(CurStmt);
+ if (!Dir)
break;
- }
- if (!CurStmt) {
- // May happen if the loop transformation does not result in a generated
- // loop (such as full unrolling).
- CurStmt = OrigStmt;
+ OnTransformationCallback(Dir);
+
+ Stmt *TransformedStmt = Dir->getTransformedStmt();
+ if (!TransformedStmt) {
+ // May happen if the loop transformation does not result in a
+ // generated loop (such as full unrolling).
break;
}
+
+ CurStmt = TransformedStmt;
}
if (auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(CurStmt))
CurStmt = CanonLoop->getLoopStmt();
return Dir;
}
+Stmt *OMPLoopTransformationDirective::getTransformedStmt() const {
+ switch (getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \
+ case Stmt::CLASS##Class: \
+ return static_cast<const CLASS *>(this)->getTransformedStmt();
+#include "clang/AST/StmtNodes.inc"
+ default:
+ llvm_unreachable("Not a loop transformation");
+ }
+}
+
+Stmt *OMPLoopTransformationDirective::getPreInits() const {
+ switch (getStmtClass()) {
+#define STMT(CLASS, PARENT)
+#define ABSTRACT_STMT(CLASS)
+#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \
+ case Stmt::CLASS##Class: \
+ return static_cast<const CLASS *>(this)->getPreInits();
+#include "clang/AST/StmtNodes.inc"
+ default:
+ llvm_unreachable("Not a loop transformation");
+ }
+}
+
OMPForDirective *OMPForDirective::CreateEmpty(const ASTContext &C,
unsigned NumClauses,
unsigned CollapsedNum,
VisitOMPLoopDirective(S);
}
-void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) {
+void StmtProfiler::VisitOMPLoopTransformationDirective(
+ const OMPLoopTransformationDirective *S) {
VisitOMPLoopBasedDirective(S);
}
+void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) {
+ VisitOMPLoopTransformationDirective(S);
+}
+
void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) {
- VisitOMPLoopBasedDirective(S);
+ VisitOMPLoopTransformationDirective(S);
}
void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) {
return;
}
if (SimplifiedS == NextLoop) {
- if (auto *Dir = dyn_cast<OMPTileDirective>(SimplifiedS))
- SimplifiedS = Dir->getTransformedStmt();
- if (auto *Dir = dyn_cast<OMPUnrollDirective>(SimplifiedS))
+ if (auto *Dir = dyn_cast<OMPLoopTransformationDirective>(SimplifiedS))
SimplifiedS = Dir->getTransformedStmt();
if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS))
SimplifiedS = CanonLoop->getLoopStmt();
VisitSubCaptures(S);
}
- void VisitOMPTileDirective(OMPTileDirective *S) {
- // #pragma omp tile does not introduce data sharing.
- VisitStmt(S);
- }
-
- void VisitOMPUnrollDirective(OMPUnrollDirective *S) {
- // #pragma omp unroll does not introduce data sharing.
+ void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) {
+ // Loop transformation directives do not introduce data sharing
VisitStmt(S);
}
}
return false;
},
- [&SemaRef, &Captures](OMPLoopBasedDirective *Transform) {
- Stmt *DependentPreInits;
- if (auto *Dir = dyn_cast<OMPTileDirective>(Transform)) {
- DependentPreInits = Dir->getPreInits();
- } else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform)) {
- DependentPreInits = Dir->getPreInits();
- } else {
- llvm_unreachable("Unexpected loop transformation");
- }
+ [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) {
+ Stmt *DependentPreInits = Transform->getPreInits();
if (!DependentPreInits)
return;
for (Decl *C : cast<DeclStmt>(DependentPreInits)->getDeclGroup()) {
VisitOMPLoopDirective(D);
}
-void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
+void ASTStmtReader::VisitOMPLoopTransformationDirective(
+ OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
}
+void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
+ VisitOMPLoopTransformationDirective(D);
+}
+
void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) {
- VisitOMPLoopBasedDirective(D);
+ VisitOMPLoopTransformationDirective(D);
}
void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) {
Code = serialization::STMT_OMP_SIMD_DIRECTIVE;
}
-void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
+void ASTStmtWriter::VisitOMPLoopTransformationDirective(
+ OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
+}
+
+void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {
+ VisitOMPLoopTransformationDirective(D);
Code = serialization::STMT_OMP_TILE_DIRECTIVE;
}
void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) {
- VisitOMPLoopBasedDirective(D);
+ VisitOMPLoopTransformationDirective(D);
Code = serialization::STMT_OMP_UNROLL_DIRECTIVE;
}
void VisitOMPLoopDirective(const OMPLoopDirective *D);
void VisitOMPParallelDirective(const OMPParallelDirective *D);
void VisitOMPSimdDirective(const OMPSimdDirective *D);
+ void
+ VisitOMPLoopTransformationDirective(const OMPLoopTransformationDirective *D);
void VisitOMPTileDirective(const OMPTileDirective *D);
void VisitOMPUnrollDirective(const OMPUnrollDirective *D);
void VisitOMPForDirective(const OMPForDirective *D);
VisitOMPLoopDirective(D);
}
-void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) {
+void EnqueueVisitor::VisitOMPLoopTransformationDirective(
+ const OMPLoopTransformationDirective *D) {
VisitOMPLoopBasedDirective(D);
}
+void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) {
+ VisitOMPLoopTransformationDirective(D);
+}
+
void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) {
- VisitOMPLoopBasedDirective(D);
+ VisitOMPLoopTransformationDirective(D);
}
void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) {