[Refactor] IslAst and payload struct
authorJohannes Doerfert <jdoerfert@codeaurora.org>
Wed, 23 Jul 2014 20:17:28 +0000 (20:17 +0000)
committerJohannes Doerfert <jdoerfert@codeaurora.org>
Wed, 23 Jul 2014 20:17:28 +0000 (20:17 +0000)
  + Renamed context into build when it's the isl_ast_build
  + Use the IslAstInfo functions to extract the schedule of a node
  + Use the IslAstInfo functions to extract the build/context of a node
  + Move the payload struct into the IslAstInfo class
  + Use a constructor and destructor (also new and delete) to
    allocate/initialize the payload struct

llvm-svn: 213792

polly/include/polly/CodeGen/IslAst.h
polly/lib/CodeGen/IslAst.cpp
polly/lib/CodeGen/IslCodeGeneration.cpp

index ab10a54..341d56c 100644 (file)
@@ -40,20 +40,32 @@ namespace polly {
 class Scop;
 class IslAst;
 
-// Information about an ast node.
-struct IslAstUserPayload {
-  struct isl_ast_build *Context;
-  // The node is the outermost parallel loop.
-  int IsOutermostParallel;
+class IslAstInfo : public ScopPass {
+public:
+  /// @brief Payload information used to annoate an ast node.
+  struct IslAstUserPayload {
+    /// @brief Construct and initialize the payload.
+    IslAstUserPayload()
+        : IsInnermostParallel(false), IsOutermostParallel(false),
+          IsReductionParallel(false), Build(nullptr) {}
 
-  // The node is the innermost parallel loop.
-  int IsInnermostParallel;
+    /// @brief Cleanup all isl structs on destruction.
+    ~IslAstUserPayload();
 
-  // The node is only parallel because of reductions
-  bool IsReductionParallel;
-};
+    /// @brief Flag to mark innermost parallel loops.
+    bool IsInnermostParallel;
 
-class IslAstInfo : public ScopPass {
+    /// @brief Flag to mark outermost parallel loops.
+    bool IsOutermostParallel;
+
+    /// @brief Flag to mark reduction parallel loops.
+    bool IsReductionParallel;
+
+    /// @brief The build environment at the time this node was constructed.
+    isl_ast_build *Build;
+  };
+
+private:
   Scop *S;
   IslAst *Ast;
 
@@ -97,6 +109,9 @@ public:
   /// @brief Is this loop a reduction parallel loop?
   static bool isReductionParallel(__isl_keep isl_ast_node *Node);
 
+  /// @brief Get the nodes schedule or a nullptr if not available.
+  static __isl_give isl_union_map *getSchedule(__isl_keep isl_ast_node *Node);
+
   ///}
 
   virtual void getAnalysisUsage(AnalysisUsage &AU) const;
index ac41c91..3805c9a 100644 (file)
 #include "isl/map.h"
 #include "isl/aff.h"
 
+#define DEBUG_TYPE "polly-ast"
+
 using namespace llvm;
 using namespace polly;
 
-#define DEBUG_TYPE "polly-ast"
+using IslAstUserPayload = IslAstInfo::IslAstUserPayload;
 
 static cl::opt<bool> UseContext("polly-ast-use-context",
                                 cl::desc("Use context"), cl::Hidden,
@@ -69,10 +71,19 @@ private:
   isl_ast_node *Root;
   isl_ast_expr *RunCondition;
 
-  void buildRunCondition(__isl_keep isl_ast_build *Context);
+  void buildRunCondition(__isl_keep isl_ast_build *Build);
 };
 } // End namespace polly.
 
+/// @brief Free an IslAstUserPayload object pointed to by @p Ptr
+static void freeIslAstUserPayload(void *Ptr) {
+  delete ((IslAstInfo::IslAstUserPayload *)Ptr);
+}
+
+IslAstInfo::IslAstUserPayload::~IslAstUserPayload() {
+  isl_ast_build_free(Build);
+}
+
 // Temporary information used when building the ast.
 struct AstBuildUserInfo {
   // The dependence information.
@@ -115,32 +126,12 @@ printFor(__isl_take isl_printer *Printer,
   if (!Id)
     return isl_ast_node_for_print(Node, Printer, PrintOptions);
 
-  struct IslAstUserPayload *Info =
-      (struct IslAstUserPayload *)isl_id_get_user(Id);
+  IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
   Printer = printParallelFor(Node, Printer, PrintOptions, Info);
   isl_id_free(Id);
   return Printer;
 }
 
-// Allocate an AstNodeInfo structure and initialize it with default values.
-static struct IslAstUserPayload *allocateIslAstUser() {
-  struct IslAstUserPayload *NodeInfo;
-  NodeInfo =
-      (struct IslAstUserPayload *)malloc(sizeof(struct IslAstUserPayload));
-  NodeInfo->Context = 0;
-  NodeInfo->IsOutermostParallel = 0;
-  NodeInfo->IsInnermostParallel = 0;
-  NodeInfo->IsReductionParallel = false;
-  return NodeInfo;
-}
-
-// Free the AstNodeInfo structure.
-static void freeIslAstUser(void *Ptr) {
-  struct IslAstUserPayload *UserStruct = (struct IslAstUserPayload *)Ptr;
-  isl_ast_build_free(UserStruct->Context);
-  free(UserStruct);
-}
-
 // Check if the current scheduling dimension is parallel.
 //
 // We check for parallelism by verifying that the loop does not carry any
@@ -221,8 +212,8 @@ static bool astScheduleDimIsParallel(__isl_keep isl_ast_build *Build,
 
 // Mark a for node openmp parallel, if it is the outermost parallel for node.
 static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
-                               struct AstBuildUserInfo *BuildInfo,
-                               struct IslAstUserPayload *NodeInfo) {
+                               AstBuildUserInfo *BuildInfo,
+                               IslAstUserPayload *NodeInfo) {
   if (BuildInfo->InParallelFor)
     return;
 
@@ -242,10 +233,10 @@ static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
 //
 static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
                                             void *User) {
-  struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
-  struct IslAstUserPayload *NodeInfo = allocateIslAstUser();
+  AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
+  IslAstUserPayload *NodeInfo = new IslAstUserPayload();
   isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo);
-  Id = isl_id_set_free_user(Id, freeIslAstUser);
+  Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
 
   markOpenmpParallel(Build, BuildInfo, NodeInfo);
 
@@ -305,9 +296,8 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
   isl_id *Id = isl_ast_node_get_annotation(Node);
   if (!Id)
     return Node;
-  struct IslAstUserPayload *Info =
-      (struct IslAstUserPayload *)isl_id_get_user(Id);
-  struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *)User;
+  IslAstUserPayload *Info = (IslAstUserPayload *)isl_id_get_user(Id);
+  AstBuildUserInfo *BuildInfo = (AstBuildUserInfo *)User;
 
   if (Info) {
     if (Info->IsOutermostParallel)
@@ -316,8 +306,8 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
       if (astScheduleDimIsParallel(Build, BuildInfo->Deps,
                                    Info->IsReductionParallel))
         Info->IsInnermostParallel = 1;
-    if (!Info->Context)
-      Info->Context = isl_ast_build_copy(Build);
+    if (!Info->Build)
+      Info->Build = isl_ast_build_copy(Build);
   }
 
   isl_id_free(Id);
@@ -325,29 +315,29 @@ astBuildAfterFor(__isl_take isl_ast_node *Node, __isl_keep isl_ast_build *Build,
 }
 
 static __isl_give isl_ast_node *AtEachDomain(__isl_take isl_ast_node *Node,
-                                             __isl_keep isl_ast_build *Context,
+                                             __isl_keep isl_ast_build *Build,
                                              void *User) {
-  struct IslAstUserPayload *Info = nullptr;
+  IslAstUserPayload *Info = nullptr;
   isl_id *Id = isl_ast_node_get_annotation(Node);
 
   if (Id)
-    Info = (struct IslAstUserPayload *)isl_id_get_user(Id);
+    Info = (IslAstUserPayload *)isl_id_get_user(Id);
 
   if (!Info) {
     // Allocate annotations once: parallel for detection might have already
     // allocated the annotations for this node.
-    Info = allocateIslAstUser();
+    Info = new IslAstUserPayload();
     Id = isl_id_alloc(isl_ast_node_get_ctx(Node), nullptr, Info);
-    Id = isl_id_set_free_user(Id, &freeIslAstUser);
+    Id = isl_id_set_free_user(Id, freeIslAstUserPayload);
   }
 
-  if (!Info->Context)
-    Info->Context = isl_ast_build_copy(Context);
+  if (!Info->Build)
+    Info->Build = isl_ast_build_copy(Build);
 
   return isl_ast_node_set_annotation(Node, Id);
 }
 
-void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) {
+void IslAst::buildRunCondition(__isl_keep isl_ast_build *Build) {
   // The conditions that need to be checked at run-time for this scop are
   // available as an isl_set in the AssumedContext. We generate code for this
   // check as follows. First, we generate an isl_pw_aff that is 1, if a certain
@@ -373,21 +363,21 @@ void IslAst::buildRunCondition(__isl_keep isl_ast_build *Context) {
 
   isl_pw_aff *Cond = isl_pw_aff_union_max(PwOne, PwZero);
 
-  RunCondition = isl_ast_build_expr_from_pw_aff(Context, Cond);
+  RunCondition = isl_ast_build_expr_from_pw_aff(Build, Cond);
 }
 
 IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
   isl_ctx *Ctx = S->getIslCtx();
   isl_options_set_ast_build_atomic_upper_bound(Ctx, true);
-  isl_ast_build *Context;
-  struct AstBuildUserInfo BuildInfo;
+  isl_ast_build *Build;
+  AstBuildUserInfo BuildInfo;
 
   if (UseContext)
-    Context = isl_ast_build_from_context(S->getContext());
+    Build = isl_ast_build_from_context(S->getContext());
   else
-    Context = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
+    Build = isl_ast_build_from_context(isl_set_universe(S->getParamSpace()));
 
-  Context = isl_ast_build_set_at_each_domain(Context, AtEachDomain, nullptr);
+  Build = isl_ast_build_set_at_each_domain(Build, AtEachDomain, nullptr);
 
   isl_union_map *Schedule =
       isl_union_map_intersect_domain(S->getSchedule(), S->getDomains());
@@ -396,17 +386,17 @@ IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
     BuildInfo.Deps = &D;
     BuildInfo.InParallelFor = 0;
 
-    Context = isl_ast_build_set_before_each_for(Context, &astBuildBeforeFor,
-                                                &BuildInfo);
-    Context = isl_ast_build_set_after_each_for(Context, &astBuildAfterFor,
-                                               &BuildInfo);
+    Build = isl_ast_build_set_before_each_for(Build, &astBuildBeforeFor,
+                                              &BuildInfo);
+    Build =
+        isl_ast_build_set_after_each_for(Build, &astBuildAfterFor, &BuildInfo);
   }
 
-  buildRunCondition(Context);
+  buildRunCondition(Build);
 
-  Root = isl_ast_build_ast_from_schedule(Context, Schedule);
+  Root = isl_ast_build_ast_from_schedule(Build, Schedule);
 
-  isl_ast_build_free(Context);
+  isl_ast_build_free(Build);
 }
 
 IslAst::~IslAst() {
@@ -476,6 +466,11 @@ bool IslAstInfo::isReductionParallel(__isl_keep isl_ast_node *Node) {
   return Payload && Payload->IsReductionParallel;
 }
 
+isl_union_map *IslAstInfo::getSchedule(__isl_keep isl_ast_node *Node) {
+  IslAstUserPayload *Payload = getNodePayload(Node);
+  return Payload ? isl_ast_build_get_schedule(Payload->Build) : nullptr;
+}
+
 void IslAstInfo::printScop(raw_ostream &OS) const {
   isl_ast_print_options *Options;
   isl_ast_node *RootNode = getAst();
index 4eb39cc..142324c 100644 (file)
@@ -778,20 +778,8 @@ IslNodeBuilder::getUpperBound(__isl_keep isl_ast_node *For,
 }
 
 unsigned IslNodeBuilder::getNumberOfIterations(__isl_keep isl_ast_node *For) {
-  isl_id *Annotation = isl_ast_node_get_annotation(For);
-  if (!Annotation)
-    return -1;
-
-  struct IslAstUserPayload *Info =
-      (struct IslAstUserPayload *)isl_id_get_user(Annotation);
-  if (!Info) {
-    isl_id_free(Annotation);
-    return -1;
-  }
-
-  isl_union_map *Schedule = isl_ast_build_get_schedule(Info->Context);
+  isl_union_map *Schedule = IslAstInfo::getSchedule(Build);
   isl_set *LoopDomain = isl_set_from_union_set(isl_union_map_range(Schedule));
-  isl_id_free(Annotation);
   int NumberOfIterations = polly::getNumberOfIterations(LoopDomain);
   if (NumberOfIterations == -1)
     return -1;
@@ -848,14 +836,7 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
   for (int i = 1; i < VectorWidth; i++)
     IVS[i] = Builder.CreateAdd(IVS[i - 1], ValueInc, "p_vector_iv");
 
-  isl_id *Annotation = isl_ast_node_get_annotation(For);
-  assert(Annotation && "For statement is not annotated");
-
-  struct IslAstUserPayload *Info =
-      (struct IslAstUserPayload *)isl_id_get_user(Annotation);
-  assert(Info && "For statement annotation does not contain info");
-
-  isl_union_map *Schedule = isl_ast_build_get_schedule(Info->Context);
+  isl_union_map *Schedule = IslAstInfo::getSchedule(Build);
   assert(Schedule && "For statement annotation does not contain its schedule");
 
   IDToValue[IteratorID] = ValueLB;
@@ -883,7 +864,6 @@ void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
 
   IDToValue.erase(IteratorID);
   isl_id_free(IteratorID);
-  isl_id_free(Annotation);
   isl_union_map_free(Schedule);
 
   isl_ast_node_free(For);