isl_hmap_map_basic_set_set: avoid memory leak on error path
[platform/upstream/isl.git] / isl_ast_build.c
index a1941f8..95b4692 100644 (file)
@@ -49,7 +49,7 @@ static __isl_give isl_ast_build *isl_ast_build_init_derived(
        isl_vec *strides;
 
        build = isl_ast_build_cow(build);
-       if (!build)
+       if (!build || !build->domain)
                goto error;
 
        ctx = isl_ast_build_get_ctx(build);
@@ -76,6 +76,24 @@ error:
        return isl_ast_build_free(build);
 }
 
+/* Return an isl_id called "c%d", with "%d" set to "i".
+ * If an isl_id with such a name already appears among the parameters
+ * in build->domain, then adjust the name to "c%d_%d".
+ */
+static __isl_give isl_id *generate_name(isl_ctx *ctx, int i,
+       __isl_keep isl_ast_build *build)
+{
+       int j;
+       char name[16];
+       isl_set *dom = build->domain;
+
+       snprintf(name, sizeof(name), "c%d", i);
+       j = 0;
+       while (isl_set_find_dim_by_name(dom, isl_dim_param, name) >= 0)
+               snprintf(name, sizeof(name), "c%d_%d", i, j++);
+       return isl_id_alloc(ctx, name, NULL);
+}
+
 /* Create an isl_ast_build with "set" as domain.
  *
  * The input set is usually a parameter domain, but we currently allow it to
@@ -108,7 +126,11 @@ __isl_give isl_ast_build *isl_ast_build_from_context(__isl_take isl_set *set)
        build->depth = n;
        build->iterators = isl_id_list_alloc(ctx, n);
        for (i = 0; i < n; ++i) {
-               isl_id *id = isl_set_get_dim_id(set, isl_dim_set, i);
+               isl_id *id;
+               if (isl_set_has_dim_id(set, isl_dim_set, i))
+                       id = isl_set_get_dim_id(set, isl_dim_set, i);
+               else
+                       id = generate_name(ctx, i, build);
                build->iterators = isl_id_list_add(build->iterators, id);
        }
        space = isl_set_get_space(set);
@@ -155,9 +177,14 @@ __isl_give isl_ast_build *isl_ast_build_dup(__isl_keep isl_ast_build *build)
        dup->strides = isl_vec_copy(build->strides);
        dup->offsets = isl_multi_aff_copy(build->offsets);
        dup->executed = isl_union_map_copy(build->executed);
+       dup->single_valued = build->single_valued;
        dup->options = isl_union_map_copy(build->options);
        dup->at_each_domain = build->at_each_domain;
        dup->at_each_domain_user = build->at_each_domain_user;
+       dup->before_each_for = build->before_each_for;
+       dup->before_each_for_user = build->before_each_for_user;
+       dup->after_each_for = build->after_each_for;
+       dup->after_each_for_user = build->after_each_for_user;
        dup->create_leaf = build->create_leaf;
        dup->create_leaf_user = build->create_leaf_user;
 
@@ -316,6 +343,42 @@ __isl_give isl_ast_build *isl_ast_build_set_at_each_domain(
        return build;
 }
 
+/* Set the "before_each_for" callback of "build" to "fn".
+ */
+__isl_give isl_ast_build *isl_ast_build_set_before_each_for(
+       __isl_take isl_ast_build *build,
+       __isl_give isl_id *(*fn)(__isl_keep isl_ast_build *build,
+               void *user), void *user)
+{
+       build = isl_ast_build_cow(build);
+
+       if (!build)
+               return NULL;
+
+       build->before_each_for = fn;
+       build->before_each_for_user = user;
+
+       return build;
+}
+
+/* Set the "after_each_for" callback of "build" to "fn".
+ */
+__isl_give isl_ast_build *isl_ast_build_set_after_each_for(
+       __isl_take isl_ast_build *build,
+       __isl_give isl_ast_node *(*fn)(__isl_take isl_ast_node *node,
+               __isl_keep isl_ast_build *build, void *user), void *user)
+{
+       build = isl_ast_build_cow(build);
+
+       if (!build)
+               return NULL;
+
+       build->after_each_for = fn;
+       build->after_each_for_user = user;
+
+       return build;
+}
+
 /* Set the "create_leaf" callback of "build" to "fn".
  */
 __isl_give isl_ast_build *isl_ast_build_set_create_leaf(
@@ -352,6 +415,10 @@ __isl_give isl_ast_build *isl_ast_build_clear_local_info(
 
        build->at_each_domain = NULL;
        build->at_each_domain_user = NULL;
+       build->before_each_for = NULL;
+       build->before_each_for_user = NULL;
+       build->after_each_for = NULL;
+       build->after_each_for_user = NULL;
        build->create_leaf = NULL;
        build->create_leaf_user = NULL;
 
@@ -666,9 +733,11 @@ __isl_give isl_ast_build *isl_ast_build_set_loop_bounds(
                set = isl_set_compute_divs(set);
                build->pending = isl_set_intersect(build->pending,
                                                        isl_set_copy(set));
-               if (isl_ast_build_has_stride(build, build->depth))
+               if (isl_ast_build_has_stride(build, build->depth)) {
                        build->domain = isl_set_eliminate(build->domain,
                                                isl_dim_set, build->depth, 1);
+                       build->domain = isl_set_compute_divs(build->domain);
+               }
        } else {
                isl_basic_set *generated, *pending;
 
@@ -1502,20 +1571,14 @@ error:
 static __isl_give isl_id_list *generate_names(isl_ctx *ctx, int n, int first,
        __isl_keep isl_ast_build *build)
 {
-       int i, j;
-       char name[16];
+       int i;
        isl_id_list *names;
-       isl_set *dom = build->domain;
 
        names = isl_id_list_alloc(ctx, n);
        for (i = 0; i < n; ++i) {
                isl_id *id;
 
-               snprintf(name, sizeof(name), "c%d", first + i);
-               j = 0;
-               while (isl_set_find_dim_by_name(dom, isl_dim_param, name) >= 0)
-                       snprintf(name, sizeof(name), "c%d_%d", first + i, j++);
-               id = isl_id_alloc(ctx, name, NULL);
+               id = generate_name(ctx, first + i, build);
                names = isl_id_list_add(names, id);
        }
 
@@ -2016,3 +2079,20 @@ __isl_give isl_set *isl_ast_build_eliminate(
        domain = isl_ast_build_eliminate_divs(build, domain);
        return domain;
 }
+
+/* Replace build->single_valued by "sv".
+ */
+__isl_give isl_ast_build *isl_ast_build_set_single_valued(
+       __isl_take isl_ast_build *build, int sv)
+{
+       if (!build)
+               return build;
+       if (build->single_valued == sv)
+               return build;
+       build = isl_ast_build_cow(build);
+       if (!build)
+               return build;
+       build->single_valued = sv;
+
+       return build;
+}