isl_basic_set_expand_divs: avoid invalid access on error
[platform/upstream/isl.git] / isl_ast_codegen.c
index 89bc512..9d4d93b 100644 (file)
@@ -269,6 +269,38 @@ error:             data.list = NULL;
        return data.list;
 }
 
+/* Call the before_each_for callback, if requested by the user.
+ */
+static __isl_give isl_ast_node *before_each_for(__isl_take isl_ast_node *node,
+       __isl_keep isl_ast_build *build)
+{
+       isl_id *id;
+
+       if (!node || !build)
+               return isl_ast_node_free(node);
+       if (!build->before_each_for)
+               return node;
+       id = build->before_each_for(build, build->before_each_for_user);
+       node = isl_ast_node_set_annotation(node, id);
+       return node;
+}
+
+/* Call the after_each_for callback, if requested by the user.
+ */
+static __isl_give isl_ast_graft *after_each_for(__isl_keep isl_ast_graft *graft,
+       __isl_keep isl_ast_build *build)
+{
+       if (!graft || !build)
+               return isl_ast_graft_free(graft);
+       if (!build->after_each_for)
+               return graft;
+       graft->node = build->after_each_for(graft->node, build,
+                                               build->after_each_for_user);
+       if (!graft->node)
+               return isl_ast_graft_free(graft);
+       return graft;
+}
+
 /* Eliminate the schedule dimension "pos" from "executed" and return
  * the result.
  */
@@ -865,6 +897,8 @@ static __isl_give isl_ast_expr *for_inc(__isl_keep isl_ast_build *build)
        isl_ctx *ctx;
        isl_ast_expr *inc;
 
+       if (!build)
+               return NULL;
        ctx = isl_ast_build_get_ctx(build);
        depth = isl_ast_build_get_depth(build);
 
@@ -1177,6 +1211,9 @@ static __isl_give isl_ast_node *create_for(__isl_keep isl_ast_build *build,
  * we performed separation with explicit bounds.
  * The very first step is then to copy these constraints to "bounds".
  *
+ * Since we may be calling before_each_for and after_each_for
+ * callbacks, we record the current inverse schedule in the build.
+ *
  * We consider three builds,
  * "build" is the one in which the current level is created,
  * "body_build" is the build in which the next level is created,
@@ -1230,6 +1267,7 @@ static __isl_give isl_ast_graft *create_node_scaled(
        domain = isl_set_detect_equalities(domain);
        hull = isl_set_unshifted_simple_hull(isl_set_copy(domain));
        bounds = isl_basic_set_intersect(bounds, hull);
+       build = isl_ast_build_set_executed(build, isl_union_map_copy(executed));
 
        depth = isl_ast_build_get_depth(build);
        sub_build = isl_ast_build_copy(build);
@@ -1247,6 +1285,8 @@ static __isl_give isl_ast_graft *create_node_scaled(
 
        body_build = isl_ast_build_copy(sub_build);
        body_build = isl_ast_build_increase_depth(body_build);
+       if (!eliminated)
+               node = before_each_for(node, body_build);
        children = generate_next_level(executed,
                                    isl_ast_build_copy(body_build));
 
@@ -1259,6 +1299,8 @@ static __isl_give isl_ast_graft *create_node_scaled(
                graft = refine_degenerate(graft, bounds, build, sub_build);
        else
                graft = refine_generic(graft, bounds, domain, build);
+       if (!eliminated)
+               graft = after_each_for(graft, body_build);
 
        isl_ast_build_free(body_build);
        isl_ast_build_free(sub_build);
@@ -1416,7 +1458,7 @@ static __isl_give isl_ast_graft *create_node(__isl_take isl_union_map *executed,
        if (isl_aff_get_denominator(offset, &data.d) < 0)
                executed = isl_union_map_free(executed);
 
-       if (isl_int_is_divisible_by(data.m, data.d))
+       if (executed && isl_int_is_divisible_by(data.m, data.d))
                isl_int_divexact(data.m, data.m, data.d);
        else
                isl_int_set_si(data.m, 1);
@@ -1695,6 +1737,8 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains(
        data.depth = isl_ast_build_get_depth(build);
        data.piece = domain_list->p;
        g = isl_tarjan_graph_init(ctx, n, &domain_follows_at_depth, &data);
+       if (!g)
+               goto error;
 
        i = 0;
        while (list && n) {
@@ -2545,6 +2589,9 @@ static __isl_give isl_basic_set_list *compute_domains(
        enum isl_ast_build_domain_type type;
        int empty;
 
+       if (!executed)
+               return NULL;
+
        ctx = isl_union_map_get_ctx(executed);
        domains.list = isl_basic_set_list_alloc(ctx, 0);