add before_each_for/after_each_for callbacks
authorSven Verdoolaege <skimo@kotnet.org>
Tue, 20 Nov 2012 19:22:35 +0000 (20:22 +0100)
committerSven Verdoolaege <skimo@kotnet.org>
Wed, 21 Nov 2012 15:43:59 +0000 (16:43 +0100)
doc/user.pod
include/isl/ast_build.h
isl_ast_build.c
isl_ast_build_private.h
isl_ast_codegen.c
isl_test.c

index 78fa57e..e8969c6 100644 (file)
@@ -5975,9 +5975,34 @@ user defined node created using the following function.
                        __isl_take isl_ast_node *node,
                        __isl_keep isl_ast_build *build,
                        void *user), void *user);
+       __isl_give isl_ast_build *
+       isl_ast_build_set_before_each_for(
+               __isl_take isl_ast_build *build,
+               __isl_give isl_id *(*fn)(
+                       __isl_keep isl_ast_build *build,
+                       void *user), void *user);
+       __isl_give isl_ast_build *
+       isl_ast_build_set_after_each_for(
+               __isl_take isl_ast_build *build,
+               __isl_give isl_ast_node *(*fn)(
+                       __isl_take isl_ast_node *node,
+                       __isl_keep isl_ast_build *build,
+                       void *user), void *user);
 
 The callback set by C<isl_ast_build_set_at_each_domain> will
 be called for each domain AST node.
+The callbacks set by C<isl_ast_build_set_before_each_for>
+and C<isl_ast_build_set_after_each_for> will be called
+for each for AST node.  The first will be called in depth-first
+pre-order, while the second will be called in depth-first post-order.
+Since C<isl_ast_build_set_before_each_for> is called before the for
+node is actually constructed, it is only passed an C<isl_ast_build>.
+The returned C<isl_id> will be added as an annotation (using
+C<isl_ast_node_set_annotation>) to the constructed for node.
+In particular, if the user has also specified an C<after_each_for>
+callback, then the annotation can be retrieved from the node passed to
+that callback using C<isl_ast_node_get_annotation>.
+All callbacks should C<NULL> on failure.
 The given C<isl_ast_build> can be used to create new
 C<isl_ast_expr> objects using C<isl_ast_build_expr_from_pw_aff>
 or C<isl_ast_build_call_from_pw_multi_aff>.
index daf78dc..7294ed7 100644 (file)
@@ -64,8 +64,8 @@ __isl_give isl_ast_build *isl_ast_build_set_at_each_domain(
                __isl_keep isl_ast_build *build, void *user), void *user);
 __isl_give isl_ast_build *isl_ast_build_set_before_each_for(
        __isl_take isl_ast_build *build,
-       __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node,
-               __isl_keep isl_ast_build *build, void *user), void *user);
+       __isl_give isl_id *(*fn)(__isl_keep isl_ast_build *build,
+               void *user), void *user);
 __isl_give isl_ast_build *isl_ast_build_set_after_each_for(
        __isl_take isl_ast_build *build,
        __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node,
index 0673726..26aabab 100644 (file)
@@ -180,6 +180,10 @@ __isl_give isl_ast_build *isl_ast_build_dup(__isl_keep isl_ast_build *build)
        dup->options = isl_union_map_copy(build->options);
        dup->at_each_domain = build->at_each_domain;
        dup->at_each_domain_user = build->at_each_domain_user;
+       dup->before_each_for = build->before_each_for;
+       dup->before_each_for_user = build->before_each_for_user;
+       dup->after_each_for = build->after_each_for;
+       dup->after_each_for_user = build->after_each_for_user;
        dup->create_leaf = build->create_leaf;
        dup->create_leaf_user = build->create_leaf_user;
 
@@ -338,6 +342,42 @@ __isl_give isl_ast_build *isl_ast_build_set_at_each_domain(
        return build;
 }
 
+/* Set the "before_each_for" callback of "build" to "fn".
+ */
+__isl_give isl_ast_build *isl_ast_build_set_before_each_for(
+       __isl_take isl_ast_build *build,
+       __isl_give isl_id *(*fn)(__isl_keep isl_ast_build *build,
+               void *user), void *user)
+{
+       build = isl_ast_build_cow(build);
+
+       if (!build)
+               return NULL;
+
+       build->before_each_for = fn;
+       build->before_each_for_user = user;
+
+       return build;
+}
+
+/* Set the "after_each_for" callback of "build" to "fn".
+ */
+__isl_give isl_ast_build *isl_ast_build_set_after_each_for(
+       __isl_take isl_ast_build *build,
+       __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node,
+               __isl_keep isl_ast_build *build, void *user), void *user)
+{
+       build = isl_ast_build_cow(build);
+
+       if (!build)
+               return NULL;
+
+       build->after_each_for = fn;
+       build->after_each_for_user = user;
+
+       return build;
+}
+
 /* Set the "create_leaf" callback of "build" to "fn".
  */
 __isl_give isl_ast_build *isl_ast_build_set_create_leaf(
@@ -374,6 +414,10 @@ __isl_give isl_ast_build *isl_ast_build_clear_local_info(
 
        build->at_each_domain = NULL;
        build->at_each_domain_user = NULL;
+       build->before_each_for = NULL;
+       build->before_each_for_user = NULL;
+       build->after_each_for = NULL;
+       build->after_each_for_user = NULL;
        build->create_leaf = NULL;
        build->create_leaf_user = NULL;
 
index 0ee88c5..5881eff 100644 (file)
@@ -98,6 +98,12 @@ enum isl_ast_build_domain_type {
  * an element of the domain.  Each of these nodes is a user node
  * with as expression a call expression.
  *
+ * The "before_each_for" callback is called on each for node before
+ * its children have been created.
+ *
+ * The "after_each_for" callback is called on each for node after
+ * its children have been created.
+ *
  * "executed" contains the inverse schedule at this point
  * of the AST generation.
  * It is currently only used in isl_ast_build_get_schedule, which is
@@ -131,6 +137,14 @@ struct isl_ast_build {
                __isl_keep isl_ast_build *build, void *user);
        void *at_each_domain_user;
 
+       __isl_give isl_id *(*before_each_for)(
+               __isl_keep isl_ast_build *context, void *user);
+       void *before_each_for_user;
+       __isl_give isl_ast_node *(*after_each_for)(
+               __isl_take isl_ast_node *node,
+               __isl_keep isl_ast_build *context, void *user);
+       void *after_each_for_user;
+
        __isl_give isl_ast_node *(*create_leaf)(
                __isl_take isl_ast_build *build, void *user);
        void *create_leaf_user;
index 89bc512..e0e217d 100644 (file)
@@ -269,6 +269,38 @@ error:             data.list = NULL;
        return data.list;
 }
 
+/* Call the before_each_for callback, if requested by the user.
+ */
+static __isl_give isl_ast_node *before_each_for(__isl_take isl_ast_node *node,
+       __isl_keep isl_ast_build *build)
+{
+       isl_id *id;
+
+       if (!node || !build)
+               return isl_ast_node_free(node);
+       if (!build->before_each_for)
+               return node;
+       id = build->before_each_for(build, build->before_each_for_user);
+       node = isl_ast_node_set_annotation(node, id);
+       return node;
+}
+
+/* Call the after_each_for callback, if requested by the user.
+ */
+static __isl_give isl_ast_graft *after_each_for(__isl_keep isl_ast_graft *graft,
+       __isl_keep isl_ast_build *build)
+{
+       if (!graft || !build)
+               isl_ast_graft_free(graft);
+       if (!build->after_each_for)
+               return graft;
+       graft->node = build->after_each_for(graft->node, build,
+                                               build->after_each_for_user);
+       if (!graft->node)
+               return isl_ast_graft_free(graft);
+       return graft;
+}
+
 /* Eliminate the schedule dimension "pos" from "executed" and return
  * the result.
  */
@@ -1177,6 +1209,9 @@ static __isl_give isl_ast_node *create_for(__isl_keep isl_ast_build *build,
  * we performed separation with explicit bounds.
  * The very first step is then to copy these constraints to "bounds".
  *
+ * Since we may be calling before_each_for and after_each_for
+ * callbacks, we record the current inverse schedule in the build.
+ *
  * We consider three builds,
  * "build" is the one in which the current level is created,
  * "body_build" is the build in which the next level is created,
@@ -1230,6 +1265,7 @@ static __isl_give isl_ast_graft *create_node_scaled(
        domain = isl_set_detect_equalities(domain);
        hull = isl_set_unshifted_simple_hull(isl_set_copy(domain));
        bounds = isl_basic_set_intersect(bounds, hull);
+       build = isl_ast_build_set_executed(build, isl_union_map_copy(executed));
 
        depth = isl_ast_build_get_depth(build);
        sub_build = isl_ast_build_copy(build);
@@ -1247,6 +1283,8 @@ static __isl_give isl_ast_graft *create_node_scaled(
 
        body_build = isl_ast_build_copy(sub_build);
        body_build = isl_ast_build_increase_depth(body_build);
+       if (!eliminated)
+               node = before_each_for(node, body_build);
        children = generate_next_level(executed,
                                    isl_ast_build_copy(body_build));
 
@@ -1259,6 +1297,8 @@ static __isl_give isl_ast_graft *create_node_scaled(
                graft = refine_degenerate(graft, bounds, build, sub_build);
        else
                graft = refine_generic(graft, bounds, domain, build);
+       if (!eliminated)
+               graft = after_each_for(graft, body_build);
 
        isl_ast_build_free(body_build);
        isl_ast_build_free(sub_build);
index 09c88f9..b39c682 100644 (file)
@@ -3438,6 +3438,158 @@ static int test_ast(isl_ctx *ctx)
        return 0;
 }
 
+/* Internal data structure for before_for and after_for callbacks.
+ *
+ * depth is the current depth
+ * before is the number of times before_for has been called
+ * after is the number of times after_for has been called
+ */
+struct isl_test_codegen_data {
+       int depth;
+       int before;
+       int after;
+};
+
+/* This function is called before each for loop in the AST generated
+ * from test_ast_gen1.
+ *
+ * Increment the number of calls and the depth.
+ * Check that the space returned by isl_ast_build_get_schedule_space
+ * matches the target space of the schedule returned by
+ * isl_ast_build_get_schedule.
+ * Return an isl_id that is checked by the corresponding call
+ * to after_for.
+ */
+static __isl_give isl_id *before_for(__isl_keep isl_ast_build *build,
+       void *user)
+{
+       struct isl_test_codegen_data *data = user;
+       isl_ctx *ctx;
+       isl_space *space;
+       isl_union_map *schedule;
+       isl_union_set *uset;
+       isl_set *set;
+       int empty;
+       char name[] = "d0";
+
+       ctx = isl_ast_build_get_ctx(build);
+
+       if (data->before >= 3)
+               isl_die(ctx, isl_error_unknown,
+                       "unexpected number of for nodes", return NULL);
+       if (data->depth >= 2)
+               isl_die(ctx, isl_error_unknown,
+                       "unexpected depth", return NULL);
+
+       snprintf(name, sizeof(name), "d%d", data->depth);
+       data->before++;
+       data->depth++;
+
+       schedule = isl_ast_build_get_schedule(build);
+       uset = isl_union_map_range(schedule);
+       if (isl_union_set_n_set(uset) != 1) {
+               isl_union_set_free(uset);
+               isl_die(ctx, isl_error_unknown,
+                       "expecting single range space", return NULL);
+       }
+
+       space = isl_ast_build_get_schedule_space(build);
+       set = isl_union_set_extract_set(uset, space);
+       isl_union_set_free(uset);
+       empty = isl_set_is_empty(set);
+       isl_set_free(set);
+
+       if (empty < 0)
+               return NULL;
+       if (empty)
+               isl_die(ctx, isl_error_unknown,
+                       "spaces don't match", return NULL);
+
+       return isl_id_alloc(ctx, name, NULL);
+}
+
+/* This function is called after each for loop in the AST generated
+ * from test_ast_gen1.
+ *
+ * Increment the number of calls and decrement the depth.
+ * Check that the annotation attached to the node matches
+ * the isl_id returned by the corresponding call to before_for.
+ */
+static __isl_give isl_ast_node *after_for(__isl_take isl_ast_node *node,
+       __isl_keep isl_ast_build *build, void *user)
+{
+       struct isl_test_codegen_data *data = user;
+       isl_id *id;
+       const char *name;
+       int valid;
+
+       data->after++;
+       data->depth--;
+
+       if (data->after > data->before)
+               isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+                       "mismatch in number of for nodes",
+                       return isl_ast_node_free(node));
+
+       id = isl_ast_node_get_annotation(node);
+       if (!id)
+               isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+                       "missing annotation", return isl_ast_node_free(node));
+
+       name = isl_id_get_name(id);
+       valid = name && atoi(name + 1) == data->depth;
+       isl_id_free(id);
+
+       if (!valid)
+               isl_die(isl_ast_node_get_ctx(node), isl_error_unknown,
+                       "wrong annotation", return isl_ast_node_free(node));
+
+       return node;
+}
+
+/* Check that the before_each_for and after_each_for callbacks
+ * are called for each for loop in the generated code,
+ * that they are called in the right order and that the isl_id
+ * returned from the before_each_for callback is attached to
+ * the isl_ast_node passed to the corresponding after_each_for call.
+ */
+static int test_ast_gen1(isl_ctx *ctx)
+{
+       const char *str;
+       isl_set *set;
+       isl_union_map *schedule;
+       isl_ast_build *build;
+       isl_ast_node *tree;
+       struct isl_test_codegen_data data;
+
+       str = "[N] -> { : N >= 10 }";
+       set = isl_set_read_from_str(ctx, str);
+       str = "[N] -> { A[i,j] -> S[8,i,3,j] : 0 <= i,j <= N; "
+                   "B[i,j] -> S[8,j,9,i] : 0 <= i,j <= N }";
+       schedule = isl_union_map_read_from_str(ctx, str);
+
+       data.before = 0;
+       data.after = 0;
+       data.depth = 0;
+       build = isl_ast_build_from_context(set);
+       build = isl_ast_build_set_before_each_for(build,
+                       &before_for, &data);
+       build = isl_ast_build_set_after_each_for(build,
+                       &after_for, &data);
+       tree = isl_ast_build_ast_from_schedule(build, schedule);
+       isl_ast_build_free(build);
+       if (!tree)
+               return -1;
+
+       isl_ast_node_free(tree);
+
+       if (data.before != 3 || data.after != 3)
+               isl_die(ctx, isl_error_unknown,
+                       "unexpected number of for nodes", return -1);
+
+       return 0;
+}
+
 /* Check that the AST generator handles domains that are integrally disjoint
  * but not ratinoally disjoint.
  */
@@ -3576,6 +3728,8 @@ static int test_ast_gen4(isl_ctx *ctx)
 
 static int test_ast_gen(isl_ctx *ctx)
 {
+       if (test_ast_gen1(ctx) < 0)
+               return -1;
        if (test_ast_gen2(ctx) < 0)
                return -1;
        if (test_ast_gen3(ctx) < 0)