add isl_pw_aff_union_opt
[platform/upstream/isl.git] / isl_union_map.c
index 7da5141..ec53f93 100644 (file)
@@ -154,6 +154,11 @@ __isl_give isl_union_map *isl_union_map_align_params(
                return umap;
        }
 
+       model = isl_dim_drop(model, isl_dim_in,
+                               0, isl_dim_size(model, isl_dim_in));
+       model = isl_dim_drop(model, isl_dim_out,
+                               0, isl_dim_size(model, isl_dim_out));
+
        data.exp = isl_parameter_alignment_reordering(umap->dim, model);
        if (!data.exp)
                goto error;
@@ -225,24 +230,25 @@ __isl_give isl_union_set *isl_union_set_copy(__isl_keep isl_union_set *uset)
        return isl_union_map_copy(uset);
 }
 
-void isl_union_map_free(__isl_take isl_union_map *umap)
+void *isl_union_map_free(__isl_take isl_union_map *umap)
 {
        if (!umap)
-               return;
+               return NULL;
 
        if (--umap->ref > 0)
-               return;
+               return NULL;
 
        isl_hash_table_foreach(umap->dim->ctx, &umap->table,
                               &free_umap_entry, NULL);
        isl_hash_table_clear(&umap->table);
        isl_dim_free(umap->dim);
        free(umap);
+       return NULL;
 }
 
-void isl_union_set_free(__isl_take isl_union_set *uset)
+void *isl_union_set_free(__isl_take isl_union_set *uset)
 {
-       isl_union_map_free(uset);
+       return isl_union_map_free(uset);
 }
 
 static int has_dim(const void *entry, const void *val)
@@ -259,19 +265,24 @@ __isl_give isl_union_map *isl_union_map_add_map(__isl_take isl_union_map *umap,
        uint32_t hash;
        struct isl_hash_table_entry *entry;
 
+       if (!map || !umap)
+               goto error;
+
        if (isl_map_plain_is_empty(map)) {
                isl_map_free(map);
                return umap;
        }
 
+       if (!isl_dim_match(map->dim, isl_dim_param, umap->dim, isl_dim_param)) {
+               umap = isl_union_map_align_params(umap, isl_map_get_dim(map));
+               map = isl_map_align_params(map, isl_union_map_get_dim(umap));
+       }
+
        umap = isl_union_map_cow(umap);
 
        if (!map || !umap)
                goto error;
 
-       isl_assert(map->ctx, isl_dim_match(map->dim, isl_dim_param, umap->dim,
-                                          isl_dim_param), goto error);
-
        hash = isl_dim_get_hash(map->dim);
        entry = isl_hash_table_find(umap->dim->ctx, &umap->table, hash,
                                    &has_dim, map->dim, 1);
@@ -369,21 +380,29 @@ static int copy_map(void **entry, void *user)
        return -1;
 }
 
-__isl_give isl_map *isl_union_map_copy_map(__isl_keep isl_union_map *umap)
+__isl_give isl_map *isl_map_from_union_map(__isl_take isl_union_map *umap)
 {
+       isl_ctx *ctx;
        isl_map *map = NULL;
 
-       if (!umap || umap->table.n == 0)
+       if (!umap)
                return NULL;
+       ctx = isl_union_map_get_ctx(umap);
+       if (umap->table.n != 1)
+               isl_die(ctx, isl_error_invalid,
+                       "union map needs to contain elements in exactly "
+                       "one space", return isl_union_map_free(umap));
 
-       isl_hash_table_foreach(umap->dim->ctx, &umap->table, &copy_map, &map);
+       isl_hash_table_foreach(ctx, &umap->table, &copy_map, &map);
+
+       isl_union_map_free(umap);
 
        return map;
 }
 
-__isl_give isl_set *isl_union_set_copy_set(__isl_keep isl_union_set *uset)
+__isl_give isl_set *isl_set_from_union_set(__isl_take isl_union_set *uset)
 {
-       return isl_union_map_copy_map(uset);
+       return isl_map_from_union_map(uset);
 }
 
 __isl_give isl_map *isl_union_map_extract_map(__isl_keep isl_union_map *umap,
@@ -932,6 +951,10 @@ static int range_product_entry(void **entry, void *user)
        struct isl_union_map_bin_data *data = user;
        isl_map *map2 = *entry;
 
+       if (!isl_dim_tuple_match(data->map->dim, isl_dim_in,
+                                map2->dim, isl_dim_in))
+               return 0;
+
        map2 = isl_map_range_product(isl_map_copy(data->map),
                                     isl_map_copy(map2));
 
@@ -946,6 +969,29 @@ __isl_give isl_union_map *isl_union_map_range_product(
        return bin_op(umap1, umap2, &range_product_entry);
 }
 
+static int flat_range_product_entry(void **entry, void *user)
+{
+       struct isl_union_map_bin_data *data = user;
+       isl_map *map2 = *entry;
+
+       if (!isl_dim_tuple_match(data->map->dim, isl_dim_in,
+                                map2->dim, isl_dim_in))
+               return 0;
+
+       map2 = isl_map_flat_range_product(isl_map_copy(data->map),
+                                         isl_map_copy(map2));
+
+       data->res = isl_union_map_add_map(data->res, map2);
+
+       return 0;
+}
+
+__isl_give isl_union_map *isl_union_map_flat_range_product(
+       __isl_take isl_union_map *umap1, __isl_take isl_union_map *umap2)
+{
+       return bin_op(umap1, umap2, &flat_range_product_entry);
+}
+
 __isl_give isl_union_map *isl_union_map_from_range(
        __isl_take isl_union_set *uset)
 {
@@ -1624,7 +1670,9 @@ int isl_union_map_is_single_valued(__isl_keep isl_union_map *umap)
        int sv;
 
        if (isl_union_map_n_map(umap) == 1) {
-               isl_map *map = isl_union_map_copy_map(umap);
+               isl_map *map;
+               umap = isl_union_map_copy(umap);
+               map = isl_map_from_union_map(umap);
                sv = isl_map_is_single_valued(map);
                isl_map_free(map);
                return sv;
@@ -1980,8 +2028,6 @@ static int solutions_entry(void **entry, void *user)
 __isl_give isl_union_set *isl_union_set_solutions(
        __isl_take isl_union_set *uset)
 {
-       isl_ctx *ctx;
-       isl_dim *dim;
        isl_union_set *res = NULL;
 
        if (!uset)