isl_basic_set_opt: avoid invalid access on error path
[platform/upstream/isl.git] / isl_ast_codegen.c
index bec2939..a0717a9 100644 (file)
@@ -7,6 +7,7 @@
  * Ecole Normale Superieure, 45 rue d’Ulm, 75230 Paris, France
  */
 
+#include <limits.h>
 #include <isl/aff.h>
 #include <isl/set.h>
 #include <isl/ilp.h>
@@ -108,6 +109,7 @@ static int generate_non_single_valued(__isl_take isl_map *executed,
 
        identity = isl_set_identity(isl_map_range(isl_map_copy(executed)));
        executed = isl_map_domain_product(executed, identity);
+       build = isl_ast_build_set_single_valued(build, 1);
 
        list = generate_code(isl_union_map_from_map(executed), build, 1);
 
@@ -155,7 +157,16 @@ static __isl_give isl_ast_graft *at_each_domain(__isl_take isl_ast_graft *graft,
  * the constraints from data->build->domain.
  * On the other hand, we only perform the test after having taken the gist
  * of the domain as the resulting map is the one from which the call
- * expression is constructed.
+ * expression is constructed.  Using this map to construct the call
+ * expression usually yields simpler results.
+ * Because we perform the single-valuedness test on the gisted map,
+ * we may in rare cases fail to recognize that the inverse schedule
+ * is single-valued.  This becomes problematic if this happens
+ * from the recursive call through generate_non_single_valued
+ * as we would then end up in an infinite recursion.
+ * We therefore check if we are inside a call to generate_non_single_valued
+ * and revert to the ungisted map if the gisted map turns out not to be
+ * single-valued.
  *
  * Otherwise, we generate a call expression for the single executed
  * domain element and put a guard around it based on the (simplified)
@@ -184,7 +195,10 @@ static int generate_domain(__isl_take isl_map *executed, void *user)
                goto error;
        if (!sv) {
                isl_map_free(map);
-               return generate_non_single_valued(executed, data);
+               if (data->build->single_valued)
+                       map = isl_map_copy(executed);
+               else
+                       return generate_non_single_valued(executed, data);
        }
        guard = isl_map_domain(isl_map_copy(map));
        guard = isl_set_coalesce(guard);
@@ -1141,6 +1155,28 @@ static __isl_give isl_ast_graft *refine_generic_split(
        return graft;
 }
 
+/* Add the guard implied by the current stride constraint (if any),
+ * but not (necessarily) enforced by the generated AST to "graft".
+ */
+static __isl_give isl_ast_graft *add_stride_guard(
+       __isl_take isl_ast_graft *graft, __isl_keep isl_ast_build *build)
+{
+       int depth;
+       isl_set *dom;
+
+       depth = isl_ast_build_get_depth(build);
+       if (!isl_ast_build_has_stride(build, depth))
+               return graft;
+
+       dom = isl_ast_build_get_stride_constraint(build);
+       dom = isl_set_eliminate(dom, isl_dim_set, depth, 1);
+       dom = isl_ast_build_compute_gist(build, dom);
+
+       graft = isl_ast_graft_add_guard(graft, dom, build);
+
+       return graft;
+}
+
 /* Update "graft" based on "bounds" and "domain" for the generic,
  * non-degenerate, case.
  *
@@ -1167,6 +1203,7 @@ static __isl_give isl_ast_graft *refine_generic(
        list = isl_constraint_list_from_basic_set(bounds);
 
        graft = refine_generic_split(graft, list, domain, build);
+       graft = add_stride_guard(graft, build);
 
        isl_constraint_list_free(list);
        return graft;
@@ -1290,7 +1327,7 @@ static __isl_give isl_ast_graft *create_node_scaled(
        children = generate_next_level(executed,
                                    isl_ast_build_copy(body_build));
 
-       graft = isl_ast_graft_alloc_level(children, sub_build);
+       graft = isl_ast_graft_alloc_level(children, build, sub_build);
        if (!eliminated)
                graft = isl_ast_graft_insert_for(graft, node);
        if (eliminated)
@@ -1458,7 +1495,7 @@ static __isl_give isl_ast_graft *create_node(__isl_take isl_union_map *executed,
        if (isl_aff_get_denominator(offset, &data.d) < 0)
                executed = isl_union_map_free(executed);
 
-       if (isl_int_is_divisible_by(data.m, data.d))
+       if (executed && isl_int_is_divisible_by(data.m, data.d))
                isl_int_divexact(data.m, data.m, data.d);
        else
                isl_int_set_si(data.m, 1);
@@ -1737,6 +1774,8 @@ static __isl_give isl_ast_graft_list *generate_sorted_domains(
        data.depth = isl_ast_build_get_depth(build);
        data.piece = domain_list->p;
        g = isl_tarjan_graph_init(ctx, n, &domain_follows_at_depth, &data);
+       if (!g)
+               goto error;
 
        i = 0;
        while (list && n) {
@@ -2075,7 +2114,8 @@ static int update_unrolling_lower_bound(struct isl_find_unroll_data *data,
                return 0;
        }
 
-       if (!data->lower || isl_int_cmp_si(data->tmp, *data->n) < 0) {
+       if (isl_int_cmp_si(data->tmp, INT_MAX) <= 0 &&
+           (!data->lower || isl_int_cmp_si(data->tmp, *data->n) < 0)) {
                isl_aff_free(data->lower);
                data->lower = lower;
                *data->n = isl_int_get_si(data->tmp);
@@ -2151,22 +2191,17 @@ error:
        return isl_aff_free(data.lower);
 }
 
-/* Intersect "set" with the constraint
+/* Return the constraint
  *
  *     i_"depth" = aff + offset
  */
-static __isl_give isl_set *at_offset(__isl_take isl_set *set, int depth,
-       __isl_keep isl_aff *aff, int offset)
+static __isl_give isl_constraint *at_offset(int depth, __isl_keep isl_aff *aff,
+       int offset)
 {
-       isl_constraint *eq;
-
        aff = isl_aff_copy(aff);
        aff = isl_aff_add_coefficient_si(aff, isl_dim_in, depth, -1);
        aff = isl_aff_add_constant_si(aff, offset);
-       eq = isl_equality_from_aff(aff);
-       set = isl_set_add_constraint(set, eq);
-
-       return set;
+       return isl_equality_from_aff(aff);
 }
 
 /* Return a list of basic sets, one for each value of the current dimension
@@ -2192,7 +2227,10 @@ static __isl_give isl_set *at_offset(__isl_take isl_set *set, int depth,
  *
  * We compute the unshifted simple hull of each slice to ensure that
  * we have a single basic set per offset.  The slicing constraint
- * is preserved by taking the unshifted simple hull, so these basic sets
+ * may get simplified away before the unshifted simple hull is taken
+ * and may therefore in some rare cases disappear from the result.
+ * We therefore explicitly add the constraint back after computing
+ * the unshifted simple hull to ensure that the basic sets
  * remain disjoint.  The constraints that are dropped by taking the hull
  * will be taken into account at the next level, as in the case of the
  * atomic option.
@@ -2237,9 +2275,13 @@ static __isl_give isl_basic_set_list *do_unroll(__isl_take isl_set *domain,
        for (i = 0; list && i < n; ++i) {
                isl_set *set;
                isl_basic_set *bset;
+               isl_constraint *slice;
 
-               set = at_offset(isl_set_copy(domain), depth, lower, i);
+               slice = at_offset(depth, lower, i);
+               set = isl_set_copy(domain);
+               set = isl_set_add_constraint(set, isl_constraint_copy(slice));
                bset = isl_set_unshifted_simple_hull(set);
+               bset = isl_basic_set_add_constraint(bset, slice);
                bset = isl_basic_set_apply(bset, isl_basic_map_copy(bmap));
                list = isl_basic_set_list_add(list, bset);
        }
@@ -2269,6 +2311,9 @@ static __isl_give isl_basic_set_list *do_unroll(__isl_take isl_set *domain,
  * specialized to the current depth.
  * "done" contains the union of th separation domains that have already
  * been handled.
+ * "atomic" contains the domain that has effectively been made atomic.
+ * This domain may be larger than the intersection of option[atomic]
+ * and the schedule domain.
  */
 struct isl_codegen_domains {
        isl_basic_set_list *list;
@@ -2281,6 +2326,7 @@ struct isl_codegen_domains {
 
        isl_map *sep_class;
        isl_set *done;
+       isl_set *atomic;
 };
 
 /* Add domains to domains->list for each individual value of the current
@@ -2345,7 +2391,8 @@ static int compute_unroll_domains(struct isl_codegen_domains *domains,
 
 /* Construct a single basic set that includes the intersection of
  * the schedule domain, the atomic option domain and the class domain.
- * Add the resulting basic set to domains->list.
+ * Add the resulting basic set to domains->list and save a copy
+ * in domains->atomic for use in compute_partial_domains.
  *
  * We construct a single domain rather than trying to combine
  * the schedule domains of individual domains because we are working
@@ -2371,12 +2418,13 @@ static int compute_atomic_domain(struct isl_codegen_domains *domains,
        atomic_domain = isl_set_intersect(atomic_domain, isl_set_copy(domain));
        empty = isl_set_is_empty(atomic_domain);
        if (empty < 0 || empty) {
-               isl_set_free(atomic_domain);
+               domains->atomic = atomic_domain;
                return empty < 0 ? -1 : 0;
        }
 
        atomic_domain = isl_set_coalesce(atomic_domain);
        bset = isl_set_unshifted_simple_hull(atomic_domain);
+       domains->atomic = isl_set_from_basic_set(isl_basic_set_copy(bset));
        domains->list = isl_basic_set_list_add(domains->list, bset);
 
        return 0;
@@ -2449,6 +2497,11 @@ static int compute_separate_domain(struct isl_codegen_domains *domains,
  * the result with the current "class_domain" to ensure that the domains
  * are disjoint from those generated from other class domains.
  *
+ * The domain that has been made atomic may be larger than specified
+ * by the user since it needs to be representable as a single basic set.
+ * This possibly larger domain is stored in domains->atomic by
+ * compute_atomic_domain.
+ *
  * If anything is left after handling separate, unroll and atomic,
  * we split it up into basic sets and append the basic sets to domains->list.
  */
@@ -2482,9 +2535,8 @@ static int compute_partial_domains(struct isl_codegen_domains *domains,
        domain = isl_set_intersect(domain, isl_set_copy(class_domain));
 
        if (compute_atomic_domain(domains, domain) < 0)
-               goto error;
-       domain = isl_set_subtract(domain,
-                                   isl_set_copy(domains->option[atomic]));
+               domain = isl_set_free(domain);
+       domain = isl_set_subtract(domain, domains->atomic);
 
        domain = isl_set_coalesce(domain);
        domain = isl_set_make_disjoint(domain);
@@ -2518,6 +2570,7 @@ static int compute_class_domains(__isl_take isl_point *pnt, void *user)
        class_set = isl_set_from_point(pnt);
        domain = isl_map_domain(isl_map_intersect_range(
                                isl_map_copy(domains->sep_class), class_set));
+       domain = isl_ast_build_compute_gist(domains->build, domain);
        domain = isl_ast_build_eliminate(domains->build, domain);
 
        disjoint = isl_set_plain_is_disjoint(domain, domains->schedule_domain);
@@ -2587,6 +2640,9 @@ static __isl_give isl_basic_set_list *compute_domains(
        enum isl_ast_build_domain_type type;
        int empty;
 
+       if (!executed)
+               return NULL;
+
        ctx = isl_union_map_get_ctx(executed);
        domains.list = isl_basic_set_list_alloc(ctx, 0);
 
@@ -3623,9 +3679,12 @@ __isl_give isl_ast_node *isl_ast_build_ast_from_schedule(
        isl_ast_node *node;
        isl_union_map *executed;
 
+       build = isl_ast_build_copy(build);
+       build = isl_ast_build_set_single_valued(build, 0);
        executed = isl_union_map_reverse(schedule);
        list = generate_code(executed, isl_ast_build_copy(build), 0);
        node = isl_ast_node_from_graft_list(list, build);
+       isl_ast_build_free(build);
 
        return node;
 }