From 0eefb0258f5953c9d7566159a0bd89cb4aad6951 Mon Sep 17 00:00:00 2001 From: Johannes Doerfert Date: Thu, 31 Jul 2014 21:33:49 +0000 Subject: [PATCH] [Refactor] Use nicer print callback function in IslAst llvm-svn: 214447 --- polly/include/polly/CodeGen/IslAst.h | 11 ++- polly/lib/CodeGen/IslAst.cpp | 85 ++++++++++------------- polly/test/Isl/CodeGen/reduction_simple_binary.ll | 1 - 3 files changed, 47 insertions(+), 50 deletions(-) diff --git a/polly/include/polly/CodeGen/IslAst.h b/polly/include/polly/CodeGen/IslAst.h index 341d56c..9e3dd76 100644 --- a/polly/include/polly/CodeGen/IslAst.h +++ b/polly/include/polly/CodeGen/IslAst.h @@ -46,12 +46,16 @@ public: struct IslAstUserPayload { /// @brief Construct and initialize the payload. IslAstUserPayload() - : IsInnermostParallel(false), IsOutermostParallel(false), - IsReductionParallel(false), Build(nullptr) {} + : IsInnermost(false), IsInnermostParallel(false), + IsOutermostParallel(false), IsReductionParallel(false), + Build(nullptr) {} /// @brief Cleanup all isl structs on destruction. ~IslAstUserPayload(); + /// @brief Flag to mark innermost loops. + bool IsInnermost; + /// @brief Flag to mark innermost parallel loops. bool IsInnermostParallel; @@ -97,6 +101,9 @@ public: /// @brief Get the complete payload attached to @p Node. static IslAstUserPayload *getNodePayload(__isl_keep isl_ast_node *Node); + /// @brief Is this loop an innermost loop? + static bool isInnermost(__isl_keep isl_ast_node *Node); + /// @brief Is this loop a parallel loop? static bool isParallel(__isl_keep isl_ast_node *Node); diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp index 0910835..c36dea9 100644 --- a/polly/lib/CodeGen/IslAst.cpp +++ b/polly/lib/CodeGen/IslAst.cpp @@ -100,43 +100,31 @@ struct AstBuildUserInfo { isl_id *LastForNodeId; }; -// Print a loop annotated with OpenMP or vector pragmas. -static __isl_give isl_printer * -printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer, - __isl_take isl_ast_print_options *PrintOptions, - IslAstUserPayload *Info) { - if (Info) { - if (Info->IsInnermostParallel) { - Printer = isl_printer_start_line(Printer); - Printer = isl_printer_print_str(Printer, "#pragma simd"); - if (Info->IsReductionParallel) - Printer = isl_printer_print_str(Printer, " reduction"); - Printer = isl_printer_end_line(Printer); - } - if (Info->IsOutermostParallel) { - Printer = isl_printer_start_line(Printer); - Printer = isl_printer_print_str(Printer, "#pragma omp parallel for"); - if (Info->IsReductionParallel) - Printer = isl_printer_print_str(Printer, " reduction"); - Printer = isl_printer_end_line(Printer); - } - } - return isl_ast_node_for_print(Node, Printer, PrintOptions); +/// @brief Print a string @p str in a single line using @p Printer. +static isl_printer *printLine(__isl_take isl_printer *Printer, + const std::string &str) { + Printer = isl_printer_start_line(Printer); + Printer = isl_printer_print_str(Printer, str.c_str()); + return isl_printer_end_line(Printer); } -// Print an isl_ast_for. -static __isl_give isl_printer * -printFor(__isl_take isl_printer *Printer, - __isl_take isl_ast_print_options *PrintOptions, - __isl_keep isl_ast_node *Node, void *User) { - isl_id *Id = isl_ast_node_get_annotation(Node); - if (!Id) - return isl_ast_node_for_print(Node, Printer, PrintOptions); +/// @brief Callback executed for each for node in the ast in order to print it. +static isl_printer *cbPrintFor(__isl_take isl_printer *Printer, + __isl_take isl_ast_print_options *Options, + __isl_keep isl_ast_node *Node, void *) { + if (IslAstInfo::isInnermostParallel(Node)) + Printer = printLine(Printer, "#pragma simd"); - IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id); - Printer = printParallelFor(Node, Printer, PrintOptions, Info); - isl_id_free(Id); - return Printer; + if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) + Printer = printLine(Printer, "#pragma simd reduction"); + + if (IslAstInfo::isOuterParallel(Node)) + Printer = printLine(Printer, "#pragma omp parallel for"); + + if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node)) + Printer = printLine(Printer, "#pragma omp parallel for reduction"); + + return isl_ast_node_for_print(Node, Printer, Options); } /// @brief Check if the current scheduling dimension is parallel @@ -219,18 +207,16 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build, IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id); AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User; - bool IsInnermost = (Id == BuildInfo->LastForNodeId); - - if (Info) { - if (Info->IsOutermostParallel) - BuildInfo->InParallelFor = 0; - if (IsInnermost) - if (astScheduleDimIsParallel(Build, BuildInfo->Deps, - Info->IsReductionParallel)) - Info->IsInnermostParallel = 1; - if (!Info->Build) - Info->Build = isl_ast_build_copy(Build); - } + Info->IsInnermost = (Id == BuildInfo->LastForNodeId); + + if (Info->IsOutermostParallel) + BuildInfo->InParallelFor = 0; + if (Info->IsInnermost) + if (astScheduleDimIsParallel(Build, BuildInfo->Deps, + Info->IsReductionParallel)) + Info->IsInnermostParallel = 1; + if (!Info->Build) + Info->Build = isl_ast_build_copy(Build); isl_id_free(Id); return Node; @@ -356,6 +342,11 @@ IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) { return Payload; } +bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) { + IslAstUserPayload *Payload = getNodePayload(Node); + return Payload && Payload->IsInnermost; +} + bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) { return (isInnermostParallel(Node) || isOuterParallel(Node)) && !isReductionParallel(Node); @@ -391,7 +382,7 @@ void IslAstInfo::printScop(raw_ostream &OS) const { Scop &S = getCurScop(); Options = isl_ast_print_options_alloc(S.getIslCtx()); - Options = isl_ast_print_options_set_print_for(Options, printFor, nullptr); + Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr); isl_printer *P = isl_printer_to_str(S.getIslCtx()); P = isl_printer_print_ast_expr(P, RunCondition); diff --git a/polly/test/Isl/CodeGen/reduction_simple_binary.ll b/polly/test/Isl/CodeGen/reduction_simple_binary.ll index e9bc1f1..2d2812b 100644 --- a/polly/test/Isl/CodeGen/reduction_simple_binary.ll +++ b/polly/test/Isl/CodeGen/reduction_simple_binary.ll @@ -1,7 +1,6 @@ ; RUN: opt %loadPolly -polly-ast -polly-ast-detect-parallel -analyze < %s | FileCheck %s ; ; CHECK: pragma simd reduction -; CHECK: pragma omp parallel for reduction ; ; int prod; ; void f() { -- 2.7.4