isl_dim: allow specification of tuple names
[platform/upstream/isl.git] / isl_map.c
index b9f8b01..43ccf43 100644 (file)
--- a/isl_map.c
+++ b/isl_map.c
@@ -229,22 +229,37 @@ unsigned isl_map_n_param(const struct isl_map *map)
 
 int isl_map_compatible_domain(struct isl_map *map, struct isl_set *set)
 {
-       return map->dim->n_in == set->dim->n_out &&
-              map->dim->nparam == set->dim->nparam;
+       int m;
+       if (!map || !set)
+               return -1;
+       m = isl_dim_match(map->dim, isl_dim_param, set->dim, isl_dim_param);
+       if (m < 0 || !m)
+               return m;
+       return isl_dim_tuple_match(map->dim, isl_dim_in, set->dim, isl_dim_set);
 }
 
 int isl_basic_map_compatible_domain(struct isl_basic_map *bmap,
                struct isl_basic_set *bset)
 {
-       return bmap->dim->n_in == bset->dim->n_out &&
-              bmap->dim->nparam == bset->dim->nparam;
+       int m;
+       if (!bmap || !bset)
+               return -1;
+       m = isl_dim_match(bmap->dim, isl_dim_param, bset->dim, isl_dim_param);
+       if (m < 0 || !m)
+               return m;
+       return isl_dim_tuple_match(bmap->dim, isl_dim_in, bset->dim, isl_dim_set);
 }
 
 int isl_basic_map_compatible_range(struct isl_basic_map *bmap,
                struct isl_basic_set *bset)
 {
-       return bmap->dim->n_out == bset->dim->n_out &&
-              bmap->dim->nparam == bset->dim->nparam;
+       int m;
+       if (!bmap || !bset)
+               return -1;
+       m = isl_dim_match(bmap->dim, isl_dim_param, bset->dim, isl_dim_param);
+       if (m < 0 || !m)
+               return m;
+       return isl_dim_tuple_match(bmap->dim, isl_dim_out, bset->dim, isl_dim_set);
 }
 
 isl_ctx *isl_map_get_ctx(__isl_keep isl_map *map)
@@ -280,6 +295,46 @@ struct isl_dim *isl_set_get_dim(struct isl_set *set)
        return isl_dim_copy(set->dim);
 }
 
+__isl_give isl_basic_map *isl_basic_map_set_tuple_name(
+       __isl_take isl_basic_map *bmap, enum isl_dim_type type, const char *s)
+{
+       bmap = isl_basic_map_cow(bmap);
+       if (!bmap)
+               return NULL;
+       bmap->dim = isl_dim_set_tuple_name(bmap->dim, type, s);
+       if (!bmap->dim)
+               goto error;
+       return bmap;
+error:
+       isl_basic_map_free(bmap);
+       return NULL;
+}
+
+__isl_give isl_map *isl_map_set_tuple_name(__isl_take isl_map *map,
+       enum isl_dim_type type, const char *s)
+{
+       int i;
+
+       map = isl_map_cow(map);
+       if (!map)
+               return NULL;
+
+       map->dim = isl_dim_set_tuple_name(map->dim, type, s);
+       if (!map->dim)
+               goto error;
+
+       for (i = 0; i < map->n; ++i) {
+               map->p[i] = isl_basic_map_set_tuple_name(map->p[i], type, s);
+               if (!map->p[i])
+                       goto error;
+       }
+
+       return map;
+error:
+       isl_map_free(map);
+       return NULL;
+}
+
 __isl_give isl_basic_map *isl_basic_map_set_dim_name(
        __isl_take isl_basic_map *bmap,
        enum isl_dim_type type, unsigned pos, const char *s)
@@ -3152,11 +3207,9 @@ struct isl_basic_set *isl_basic_set_from_basic_map(struct isl_basic_map *bmap)
        bmap = isl_basic_map_cow(bmap);
        if (!bmap)
                goto error;
-       bmap->dim = isl_dim_cow(bmap->dim);
+       bmap->dim = isl_dim_as_set_dim(bmap->dim);
        if (!bmap->dim)
                goto error;
-       bmap->dim->n_out += bmap->dim->n_in;
-       bmap->dim->n_in = 0;
        bmap = isl_basic_map_finalize(bmap);
        return (struct isl_basic_set *)bmap;
 error:
@@ -3215,7 +3268,8 @@ struct isl_basic_set *isl_basic_map_underlying_set(
 {
        if (!bmap)
                goto error;
-       if (bmap->dim->nparam == 0 && bmap->dim->n_in == 0 && bmap->n_div == 0)
+       if (bmap->dim->nparam == 0 && bmap->dim->n_in == 0 &&
+           bmap->n_div == 0 && !isl_dim_get_tuple_name(bmap->dim, isl_dim_out))
                return (struct isl_basic_set *)bmap;
        bmap = isl_basic_map_cow(bmap);
        if (!bmap)
@@ -3549,11 +3603,9 @@ struct isl_set *isl_set_from_map(struct isl_map *map)
        map = isl_map_cow(map);
        if (!map)
                return NULL;
-       map->dim = isl_dim_cow(map->dim);
+       map->dim = isl_dim_as_set_dim(map->dim);
        if (!map->dim)
                goto error;
-       map->dim->n_out += map->dim->n_in;
-       map->dim->n_in = 0;
        set = (struct isl_set *)map;
        for (i = 0; i < map->n; ++i) {
                set->p[i] = isl_basic_set_from_basic_map(map->p[i]);
@@ -4962,7 +5014,7 @@ error:
  */
 struct isl_basic_set *isl_basic_map_deltas(struct isl_basic_map *bmap)
 {
-       isl_dim *dims;
+       isl_dim *dims, *target_dim;
        struct isl_basic_set *bset;
        unsigned dim;
        unsigned nparam;
@@ -4970,9 +5022,12 @@ struct isl_basic_set *isl_basic_map_deltas(struct isl_basic_map *bmap)
 
        if (!bmap)
                goto error;
+       isl_assert(bmap->ctx, isl_dim_tuple_match(bmap->dim, isl_dim_in,
+                                                 bmap->dim, isl_dim_out),
+                  goto error);
+       target_dim = isl_dim_domain(isl_basic_map_get_dim(bmap));
        dim = isl_basic_map_n_in(bmap);
        nparam = isl_basic_map_n_param(bmap);
-       isl_assert(bmap->ctx, dim == isl_basic_map_n_out(bmap), goto error);
        bset = isl_basic_set_from_basic_map(bmap);
        bset = isl_basic_set_cow(bset);
        dims = isl_basic_set_get_dim(bset);
@@ -4989,7 +5044,9 @@ struct isl_basic_set *isl_basic_map_deltas(struct isl_basic_map *bmap)
                isl_int_set_si(bset->eq[j][1+nparam+dim+i], 1);
                isl_int_set_si(bset->eq[j][1+nparam+2*dim+i], -1);
        }
-       return isl_basic_set_project_out(bset, isl_dim_set, dim, 2*dim);
+       bset = isl_basic_set_project_out(bset, isl_dim_set, dim, 2*dim);
+       bset = isl_basic_set_reset_dim(bset, target_dim);
+       return bset;
 error:
        isl_basic_map_free(bmap);
        return NULL;
@@ -5007,7 +5064,9 @@ struct isl_set *isl_map_deltas(struct isl_map *map)
        if (!map)
                return NULL;
 
-       isl_assert(map->ctx, isl_map_n_in(map) == isl_map_n_out(map), goto error);
+       isl_assert(map->ctx, isl_dim_tuple_match(map->dim, isl_dim_in,
+                                                map->dim, isl_dim_out),
+                  goto error);
        dim = isl_map_get_dim(map);
        dim = isl_dim_domain(dim);
        result = isl_set_alloc_dim(dim, map->n, map->flags);