isl_basic_set_expand_divs: avoid invalid access on error
[platform/upstream/isl.git] / isl_ast_codegen.c
index c0afa0a..9d4d93b 100644 (file)
@@ -151,9 +151,11 @@ static __isl_give isl_ast_graft *at_each_domain(__isl_take isl_ast_graft *graft,
  * in generate_non_single_valued.
  * Note that the inverse schedule being single-valued may depend
  * on constraints that are only available in the original context
- * domain specified by the user.  If the bare inverse schedule
- * is not single-valued, we double-check after introducing the constraints
- * from data->build->domain.
+ * domain specified by the user.  We therefore first introduce
+ * the constraints from data->build->domain.
+ * On the other hand, we only perform the test after having taken the gist
+ * of the domain as the resulting map is the one from which the call
+ * expression is constructed.
  *
  * Otherwise, we generate a call expression for the single executed
  * domain element and put a guard around it based on the (simplified)
@@ -171,22 +173,19 @@ static int generate_domain(__isl_take isl_map *executed, void *user)
        isl_map *map;
        int sv;
 
-       sv = isl_map_is_single_valued(executed);
-       if (sv < 0)
-               goto error;
-       if (!sv) {
-               map = isl_map_copy(executed);
-               map = isl_map_intersect_domain(map,
+       executed = isl_map_intersect_domain(executed,
                                            isl_set_copy(data->build->domain));
-               sv = isl_map_is_single_valued(map);
-               isl_map_free(map);
-       }
-       if (!sv)
-               return generate_non_single_valued(executed, data);
 
        executed = isl_map_coalesce(executed);
        map = isl_map_copy(executed);
        map = isl_ast_build_compute_gist_map_domain(data->build, map);
+       sv = isl_map_is_single_valued(map);
+       if (sv < 0)
+               goto error;
+       if (!sv) {
+               isl_map_free(map);
+               return generate_non_single_valued(executed, data);
+       }
        guard = isl_map_domain(isl_map_copy(map));
        guard = isl_set_coalesce(guard);
        guard = isl_ast_build_compute_gist(data->build, guard);
@@ -201,6 +200,7 @@ static int generate_domain(__isl_take isl_map *executed, void *user)
 
        return 0;
 error:
+       isl_map_free(map);
        isl_map_free(executed);
        return -1;
 }
@@ -269,6 +269,38 @@ error:             data.list = NULL;
        return data.list;
 }
 
+/* Call the before_each_for callback, if requested by the user.
+ */
+static __isl_give isl_ast_node *before_each_for(__isl_take isl_ast_node *node,
+       __isl_keep isl_ast_build *build)
+{
+       isl_id *id;
+
+       if (!node || !build)
+               return isl_ast_node_free(node);
+       if (!build->before_each_for)
+               return node;
+       id = build->before_each_for(build, build->before_each_for_user);
+       node = isl_ast_node_set_annotation(node, id);
+       return node;
+}
+
+/* Call the after_each_for callback, if requested by the user.
+ */
+static __isl_give isl_ast_graft *after_each_for(__isl_keep isl_ast_graft *graft,
+       __isl_keep isl_ast_build *build)
+{
+       if (!graft || !build)
+               return isl_ast_graft_free(graft);
+       if (!build->after_each_for)
+               return graft;
+       graft->node = build->after_each_for(graft->node, build,
+                                               build->after_each_for_user);
+       if (!graft->node)
+               return isl_ast_graft_free(graft);
+       return graft;
+}
+
 /* Eliminate the schedule dimension "pos" from "executed" and return
  * the result.
  */
@@ -716,27 +748,121 @@ static __isl_give isl_ast_graft *set_enforced_from_list(
        return graft;
 }
 
+/* Does "aff" have a negative constant term?
+ */
+static int aff_constant_is_negative(__isl_take isl_set *set,
+       __isl_take isl_aff *aff, void *user)
+{
+       int *neg = user;
+       isl_int v;
+
+       isl_int_init(v);
+       isl_aff_get_constant(aff, &v);
+       *neg = isl_int_is_neg(v);
+       isl_int_clear(v);
+       isl_set_free(set);
+       isl_aff_free(aff);
+
+       return *neg ? 0 : -1;
+}
+
+/* Does "pa" have a negative constant term over its entire domain?
+ */
+static int pw_aff_constant_is_negative(__isl_take isl_pw_aff *pa, void *user)
+{
+       int r;
+       int *neg = user;
+
+       r = isl_pw_aff_foreach_piece(pa, &aff_constant_is_negative, user);
+       isl_pw_aff_free(pa);
+
+       return *neg ? 0 : -1;
+}
+
+/* Does each element in "list" have a negative constant term?
+ *
+ * The callback terminates the iteration as soon an element has been
+ * found that does not have a negative constant term.
+ */
+static int list_constant_is_negative(__isl_keep isl_pw_aff_list *list)
+{
+       int neg = 1;
+
+       if (isl_pw_aff_list_foreach(list,
+                               &pw_aff_constant_is_negative, &neg) < 0 && neg)
+               return -1;
+
+       return neg;
+}
+
+/* Add 1 to each of the elements in "list", where each of these elements
+ * is defined over the internal schedule space of "build".
+ */
+static __isl_give isl_pw_aff_list *list_add_one(
+       __isl_take isl_pw_aff_list *list, __isl_keep isl_ast_build *build)
+{
+       int i, n;
+       isl_space *space;
+       isl_aff *aff;
+       isl_pw_aff *one;
+
+       space = isl_ast_build_get_space(build, 1);
+       aff = isl_aff_zero_on_domain(isl_local_space_from_space(space));
+       aff = isl_aff_add_constant_si(aff, 1);
+       one = isl_pw_aff_from_aff(aff);
+
+       n = isl_pw_aff_list_n_pw_aff(list);
+       for (i = 0; i < n; ++i) {
+               isl_pw_aff *pa;
+               pa = isl_pw_aff_list_get_pw_aff(list, i);
+               pa = isl_pw_aff_add(pa, isl_pw_aff_copy(one));
+               list = isl_pw_aff_list_set_pw_aff(list, i, pa);
+       }
+
+       isl_pw_aff_free(one);
+
+       return list;
+}
+
 /* Set the condition part of the for node graft->node in case
  * the upper bound is represented as a list of piecewise affine expressions.
  *
  * In particular, set the condition to
  *
  *     iterator <= min(list of upper bounds)
+ *
+ * If each of the upper bounds has a negative constant term, then
+ * set the condition to
+ *
+ *     iterator < min(list of (upper bound + 1)s)
+ *
  */
 static __isl_give isl_ast_graft *set_for_cond_from_list(
        __isl_take isl_ast_graft *graft, __isl_keep isl_pw_aff_list *list,
        __isl_keep isl_ast_build *build)
 {
+       int neg;
        isl_ast_expr *bound, *iterator, *cond;
+       enum isl_ast_op_type type = isl_ast_op_le;
 
        if (!graft || !list)
                return isl_ast_graft_free(graft);
 
+       neg = list_constant_is_negative(list);
+       if (neg < 0)
+               return isl_ast_graft_free(graft);
+       list = isl_pw_aff_list_copy(list);
+       if (neg) {
+               list = list_add_one(list, build);
+               type = isl_ast_op_lt;
+       }
+
        bound = reduce_list(isl_ast_op_min, list, build);
        iterator = isl_ast_expr_copy(graft->node->u.f.iterator);
-       cond = isl_ast_expr_alloc_binary(isl_ast_op_le, iterator, bound);
+       cond = isl_ast_expr_alloc_binary(type, iterator, bound);
        graft->node->u.f.cond = cond;
 
+       isl_pw_aff_list_free(list);
        if (!graft->node->u.f.cond)
                return isl_ast_graft_free(graft);
        return graft;
@@ -771,6 +897,8 @@ static __isl_give isl_ast_expr *for_inc(__isl_keep isl_ast_build *build)
        isl_ctx *ctx;
        isl_ast_expr *inc;
 
+       if (!build)
+               return NULL;
        ctx = isl_ast_build_get_ctx(build);
        depth = isl_ast_build_get_depth(build);
 
@@ -1083,6 +1211,9 @@ static __isl_give isl_ast_node *create_for(__isl_keep isl_ast_build *build,
  * we performed separation with explicit bounds.
  * The very first step is then to copy these constraints to "bounds".
  *
+ * Since we may be calling before_each_for and after_each_for
+ * callbacks, we record the current inverse schedule in the build.
+ *
  * We consider three builds,
  * "build" is the one in which the current level is created,
  * "body_build" is the build in which the next level is created,
@@ -1136,6 +1267,7 @@ static __isl_give isl_ast_graft *create_node_scaled(
        domain = isl_set_detect_equalities(domain);
        hull = isl_set_unshifted_simple_hull(isl_set_copy(domain));
        bounds = isl_basic_set_intersect(bounds, hull);
+       build = isl_ast_build_set_executed(build, isl_union_map_copy(executed));
 
        depth = isl_ast_build_get_depth(build);
        sub_build = isl_ast_build_copy(build);
@@ -1153,6 +1285,8 @@ static __isl_give isl_ast_graft *create_node_scaled(
 
        body_build = isl_ast_build_copy(sub_build);
        body_build = isl_ast_build_increase_depth(body_build);
+       if (!eliminated)
+               node = before_each_for(node, body_build);
        children = generate_next_level(executed,
                                    isl_ast_build_copy(body_build));
 
@@ -1165,6 +1299,8 @@ static __isl_give isl_ast_graft *create_node_scaled(
                graft = refine_degenerate(graft, bounds, build, sub_build);
        else
                graft = refine_generic(graft, bounds, domain, build);
+       if (!eliminated)
+               graft = after_each_for(graft, body_build);
 
        isl_ast_build_free(body_build);
        isl_ast_build_free(sub_build);
@@ -1322,7 +1458,7 @@ static __isl_give isl_ast_graft *create_node(__isl_take isl_union_map *executed,
        if (isl_aff_get_denominator(offset, &data.d) < 0)
                executed = isl_union_map_free(executed);
 
-       if (isl_int_is_divisible_by(data.m, data.d))
+       if (executed && isl_int_is_divisible_by(data.m, data.d))
                isl_int_divexact(data.m, data.m, data.d);
        else
                isl_int_set_si(data.m, 1);
@@ -1601,6 +1737,8 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains(
        data.depth = isl_ast_build_get_depth(build);
        data.piece = domain_list->p;
        g = isl_tarjan_graph_init(ctx, n, &domain_follows_at_depth, &data);
+       if (!g)
+               goto error;
 
        i = 0;
        while (list && n) {
@@ -2129,10 +2267,6 @@ static __isl_give isl_basic_set_list *do_unroll(__isl_take isl_set *domain,
  * the user, except that inner dimensions have been eliminated and
  * that they have been made pair-wise disjoint.
  *
- * "includes_schedule_domain" is set if the "class_domain" (not stored
- * in this structure, but passed to the various functions) has been
- * intersected with "schedule_domain".
- *
  * "sep_class" contains the user-specified split into separation classes
  * specialized to the current depth.
  * "done" contains the union of th separation domains that have already
@@ -2147,8 +2281,6 @@ struct isl_codegen_domains {
 
        isl_set *option[3];
 
-       int includes_schedule_domain;
-
        isl_map *sep_class;
        isl_set *done;
 };
@@ -2297,6 +2429,10 @@ static int compute_separate_domain(struct isl_codegen_domains *domains,
  * basic sets for which code should be generated separately
  * for the given separation class domain.
  *
+ * If any separation classes have been defined, then "class_domain"
+ * is the domain of the current class and does not refer to inner dimensions.
+ * Otherwise, "class_domain" is the universe domain.
+ *
  * We first make sure that the class domain is disjoint from
  * previously considered class domains.
  *
@@ -2311,6 +2447,9 @@ static int compute_separate_domain(struct isl_codegen_domains *domains,
  *
  * For atomic and remainder domains, inner dimensions and divs involving
  * the current dimensions should be eliminated.
+ * In case we are working within a separation class, we need to intersect
+ * the result with the current "class_domain" to ensure that the domains
+ * are disjoint from those generated from other class domains.
  *
  * If anything is left after handling separate, unroll and atomic,
  * we split it up into basic sets and append the basic sets to domains->list.
@@ -2319,42 +2458,47 @@ static int compute_partial_domains(struct isl_codegen_domains *domains,
        __isl_take isl_set *class_domain)
 {
        isl_basic_set_list *list;
+       isl_set *domain;
 
        class_domain = isl_set_subtract(class_domain,
                                        isl_set_copy(domains->done));
        domains->done = isl_set_union(domains->done,
                                        isl_set_copy(class_domain));
 
-       if (compute_separate_domain(domains, class_domain) < 0)
+       domain = isl_set_copy(class_domain);
+
+       if (compute_separate_domain(domains, domain) < 0)
                goto error;
-       class_domain = isl_set_subtract(class_domain,
+       domain = isl_set_subtract(domain,
                                    isl_set_copy(domains->option[separate]));
 
-       if (!domains->includes_schedule_domain)
-               class_domain = isl_set_intersect(class_domain,
-                                       isl_set_copy(domains->schedule_domain));
+       domain = isl_set_intersect(domain,
+                               isl_set_copy(domains->schedule_domain));
 
-       if (compute_unroll_domains(domains, class_domain) < 0)
+       if (compute_unroll_domains(domains, domain) < 0)
                goto error;
-       class_domain = isl_set_subtract(class_domain,
+       domain = isl_set_subtract(domain,
                                    isl_set_copy(domains->option[unroll]));
 
-       class_domain = isl_ast_build_eliminate(domains->build,
-                                       class_domain);
+       domain = isl_ast_build_eliminate(domains->build, domain);
+       domain = isl_set_intersect(domain, isl_set_copy(class_domain));
 
-       if (compute_atomic_domain(domains, class_domain) < 0)
+       if (compute_atomic_domain(domains, domain) < 0)
                goto error;
-       class_domain = isl_set_subtract(class_domain,
+       domain = isl_set_subtract(domain,
                                    isl_set_copy(domains->option[atomic]));
 
-       class_domain = isl_set_coalesce(class_domain);
-       class_domain = isl_set_make_disjoint(class_domain);
+       domain = isl_set_coalesce(domain);
+       domain = isl_set_make_disjoint(domain);
 
-       list = isl_basic_set_list_from_set(class_domain);
+       list = isl_basic_set_list_from_set(domain);
        domains->list = isl_basic_set_list_concat(domains->list, list);
 
+       isl_set_free(class_domain);
+
        return 0;
 error:
+       isl_set_free(domain);
        isl_set_free(class_domain);
        return -1;
 }
@@ -2386,7 +2530,6 @@ static int compute_class_domains(__isl_take isl_point *pnt, void *user)
                return 0;
        }
 
-       domains->includes_schedule_domain = 0;
        return compute_partial_domains(domains, domain);
 }
 
@@ -2428,9 +2571,8 @@ static void compute_domains_init_options(isl_set *option[3],
  * and split up the domain for each of them separately.
  * Finally, we consider the remainder.  If no separation classes were
  * specified, then we call compute_partial_domains with the universe
- * "class_domain".  Otherwise, we take the "schedule_domain" as "class_domain"
- * and set includes_schedule_domain to reflect that the schedule domain
- * has already been taken into account.  We do this because we want to
+ * "class_domain".  Otherwise, we take the "schedule_domain" as "class_domain",
+ * with inner dimensions removed.  We do this because we want to
  * avoid computing the complement of the class domains (i.e., the difference
  * between the universe and domains->done).
  */
@@ -2445,6 +2587,10 @@ static __isl_give isl_basic_set_list *compute_domains(
        isl_space *space;
        int n_param;
        enum isl_ast_build_domain_type type;
+       int empty;
+
+       if (!executed)
+               return NULL;
 
        ctx = isl_union_map_get_ctx(executed);
        domains.list = isl_basic_set_list_alloc(ctx, 0);
@@ -2469,12 +2615,15 @@ static __isl_give isl_basic_set_list *compute_domains(
                domains.list = isl_basic_set_list_free(domains.list);
        isl_set_free(classes);
 
-       if (!domains.done)
+       empty = isl_set_is_empty(domains.done);
+       if (empty < 0) {
                domains.list = isl_basic_set_list_free(domains.list);
-       domains.includes_schedule_domain = !isl_set_is_empty(domains.done);
-       if (!domains.includes_schedule_domain) {
+               domain = isl_set_free(domain);
+       } else if (empty) {
                isl_set_free(domain);
                domain = isl_set_universe(isl_set_get_space(domains.done));
+       } else {
+               domain = isl_ast_build_eliminate(build, domain);
        }
        if (compute_partial_domains(&domains, domain) < 0)
                domains.list = isl_basic_set_list_free(domains.list);