[Refactor] Use nicer print callback function in IslAst
authorJohannes Doerfert <jdoerfert@codeaurora.org>
Thu, 31 Jul 2014 21:33:49 +0000 (21:33 +0000)
committerJohannes Doerfert <jdoerfert@codeaurora.org>
Thu, 31 Jul 2014 21:33:49 +0000 (21:33 +0000)
llvm-svn: 214447

polly/include/polly/CodeGen/IslAst.h
polly/lib/CodeGen/IslAst.cpp
polly/test/Isl/CodeGen/reduction_simple_binary.ll

index 341d56c..9e3dd76 100644 (file)
@@ -46,12 +46,16 @@ public:
   struct IslAstUserPayload {
     /// @brief Construct and initialize the payload.
     IslAstUserPayload()
-        : IsInnermostParallel(false), IsOutermostParallel(false),
-          IsReductionParallel(false), Build(nullptr) {}
+        : IsInnermost(false), IsInnermostParallel(false),
+          IsOutermostParallel(false), IsReductionParallel(false),
+          Build(nullptr) {}
 
     /// @brief Cleanup all isl structs on destruction.
     ~IslAstUserPayload();
 
+    /// @brief Flag to mark innermost loops.
+    bool IsInnermost;
+
     /// @brief Flag to mark innermost parallel loops.
     bool IsInnermostParallel;
 
@@ -97,6 +101,9 @@ public:
   /// @brief Get the complete payload attached to @p Node.
   static IslAstUserPayload *getNodePayload(__isl_keep isl_ast_node *Node);
 
+  /// @brief Is this loop an innermost loop?
+  static bool isInnermost(__isl_keep isl_ast_node *Node);
+
   /// @brief Is this loop a parallel loop?
   static bool isParallel(__isl_keep isl_ast_node *Node);
 
index 0910835..c36dea9 100644 (file)
@@ -100,43 +100,31 @@ struct AstBuildUserInfo {
   isl_id *LastForNodeId;
 };
 
-// Print a loop annotated with OpenMP or vector pragmas.
-static __isl_give isl_printer *
-printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
-                 __isl_take isl_ast_print_options *PrintOptions,
-                 IslAstUserPayload *Info) {
-  if (Info) {
-    if (Info->IsInnermostParallel) {
-      Printer = isl_printer_start_line(Printer);
-      Printer = isl_printer_print_str(Printer, "#pragma simd");
-      if (Info->IsReductionParallel)
-        Printer = isl_printer_print_str(Printer, " reduction");
-      Printer = isl_printer_end_line(Printer);
-    }
-    if (Info->IsOutermostParallel) {
-      Printer = isl_printer_start_line(Printer);
-      Printer = isl_printer_print_str(Printer, "#pragma omp parallel for");
-      if (Info->IsReductionParallel)
-        Printer = isl_printer_print_str(Printer, " reduction");
-      Printer = isl_printer_end_line(Printer);
-    }
-  }
-  return isl_ast_node_for_print(Node, Printer, PrintOptions);
+/// @brief Print a string @p str in a single line using @p Printer.
+static isl_printer *printLine(__isl_take isl_printer *Printer,
+                              const std::string &str) {
+  Printer = isl_printer_start_line(Printer);
+  Printer = isl_printer_print_str(Printer, str.c_str());
+  return isl_printer_end_line(Printer);
 }
 
-// Print an isl_ast_for.
-static __isl_give isl_printer *
-printFor(__isl_take isl_printer *Printer,
-         __isl_take isl_ast_print_options *PrintOptions,
-         __isl_keep isl_ast_node *Node, void *User) {
-  isl_id *Id = isl_ast_node_get_annotation(Node);
-  if (!Id)
-    return isl_ast_node_for_print(Node, Printer, PrintOptions);
+/// @brief Callback executed for each for node in the ast in order to print it.
+static isl_printer *cbPrintFor(__isl_take isl_printer *Printer,
+                               __isl_take isl_ast_print_options *Options,
+                               __isl_keep isl_ast_node *Node, void *) {
+  if (IslAstInfo::isInnermostParallel(Node))
+    Printer = printLine(Printer, "#pragma simd");
 
-  IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
-  Printer = printParallelFor(Node, Printer, PrintOptions, Info);
-  isl_id_free(Id);
-  return Printer;
+  if (IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
+    Printer = printLine(Printer, "#pragma simd reduction");
+
+  if (IslAstInfo::isOuterParallel(Node))
+    Printer = printLine(Printer, "#pragma omp parallel for");
+
+  if (!IslAstInfo::isInnermost(Node) && IslAstInfo::isReductionParallel(Node))
+    Printer = printLine(Printer, "#pragma omp parallel for reduction");
+
+  return isl_ast_node_for_print(Node, Printer, Options);
 }
 
 /// @brief Check if the current scheduling dimension is parallel
@@ -219,18 +207,16 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
   IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
   AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
 
-  bool IsInnermost = (Id == BuildInfo->LastForNodeId);
-
-  if (Info) {
-    if (Info->IsOutermostParallel)
-      BuildInfo->InParallelFor = 0;
-    if (IsInnermost)
-      if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
-                                   Info->IsReductionParallel))
-        Info->IsInnermostParallel = 1;
-    if (!Info->Build)
-      Info->Build = isl_ast_build_copy(Build);
-  }
+  Info->IsInnermost = (Id == BuildInfo->LastForNodeId);
+
+  if (Info->IsOutermostParallel)
+    BuildInfo->InParallelFor = 0;
+  if (Info->IsInnermost)
+    if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
+                                  Info->IsReductionParallel))
+      Info->IsInnermostParallel = 1;
+  if (!Info->Build)
+    Info->Build = isl_ast_build_copy(Build);
 
   isl_id_free(Id);
   return Node;
@@ -356,6 +342,11 @@ IslAstUserPayload *IslAstInfo::getNodePayload(__isl_keep isl_ast_node *Node) {
   return Payload;
 }
 
+bool IslAstInfo::isInnermost(__isl_keep isl_ast_node *Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
+  return Payload && Payload->IsInnermost;
+}
+
 bool IslAstInfo::isParallel(__isl_keep isl_ast_node *Node) {
   return (isInnermostParallel(Node) || isOuterParallel(Node)) &&
          !isReductionParallel(Node);
@@ -391,7 +382,7 @@ void IslAstInfo::printScop(raw_ostream &OS) const {
 
   Scop &S = getCurScop();
   Options = isl_ast_print_options_alloc(S.getIslCtx());
-  Options = isl_ast_print_options_set_print_for(Options, printFor, nullptr);
+  Options = isl_ast_print_options_set_print_for(Options, cbPrintFor, nullptr);
 
   isl_printer *P = isl_printer_to_str(S.getIslCtx());
   P = isl_printer_print_ast_expr(P, RunCondition);
index e9bc1f1..2d2812b 100644 (file)
@@ -1,7 +1,6 @@
 ; RUN: opt %loadPolly -polly-ast -polly-ast-detect-parallel -analyze < %s | FileCheck %s
 ;
 ; CHECK: pragma simd reduction
-; CHECK: pragma omp parallel for reduction
 ;
 ; int prod;
 ; void f() {