isl_dim_move: update parameters of nested spaces
[platform/upstream/isl.git] / isl_dim.c
index 2f06824..2d19663 100644 (file)
--- a/isl_dim.c
+++ b/isl_dim.c
@@ -118,7 +118,8 @@ static struct isl_dim *set_name(struct isl_dim *dim,
                goto error;
 
        pos = global_pos(dim, type, pos);
-       isl_assert(ctx, pos != isl_dim_total(dim), goto error);
+       if (pos == isl_dim_total(dim))
+               goto error;
 
        if (pos >= dim->n_name) {
                if (!name)
@@ -344,6 +345,7 @@ struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
                return NULL;
        if (!name_ok(dim->ctx, s))
                goto error;
+       isl_name_free(dim->ctx, get_name(dim, type, pos));
        name = isl_name_get(dim->ctx, s);
        if (!name)
                goto error;
@@ -489,8 +491,17 @@ struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
        dim = isl_dim_reset(dim, type);
        switch (type) {
        case isl_dim_param:
-               return isl_dim_extend(dim,
+               dim = isl_dim_extend(dim,
                                        dim->nparam + n, dim->n_in, dim->n_out);
+               if (dim && dim->nested[0] &&
+                   !(dim->nested[0] = isl_dim_add(dim->nested[0],
+                                                   isl_dim_param, n)))
+                       goto error;
+               if (dim && dim->nested[1] &&
+                   !(dim->nested[1] = isl_dim_add(dim->nested[1],
+                                                   isl_dim_param, n)))
+                       goto error;
+               return dim;
        case isl_dim_in:
                return isl_dim_extend(dim,
                                        dim->nparam, dim->n_in + n, dim->n_out);
@@ -499,6 +510,9 @@ struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
                                        dim->nparam, dim->n_in, dim->n_out + n);
        }
        return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
 }
 
 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
@@ -520,7 +534,8 @@ __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
        if (dim->names) {
                enum isl_dim_type t;
                int off;
-               int size[3];
+               int s[3];
+               int *size = s - isl_dim_param;
                names = isl_calloc_array(dim->ctx, struct isl_name *,
                                     dim->nparam + dim->n_in + dim->n_out + n);
                if (!names)
@@ -561,6 +576,8 @@ __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
        enum isl_dim_type dst_type, unsigned dst_pos,
        enum isl_dim_type src_type, unsigned src_pos, unsigned n)
 {
+       int i;
+
        if (!dim)
                return NULL;
        if (n == 0)
@@ -585,7 +602,8 @@ __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
                struct isl_name **names;
                enum isl_dim_type t;
                int off;
-               int size[3];
+               int s[3];
+               int *size = s - isl_dim_param;
                names = isl_calloc_array(dim->ctx, struct isl_name *,
                                         dim->nparam + dim->n_in + dim->n_out);
                if (!names)
@@ -631,6 +649,18 @@ __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
        case isl_dim_out:       dim->n_out -= n; break;
        }
 
+       if (dst_type != isl_dim_param && src_type != isl_dim_param)
+               return dim;
+
+       for (i = 0; i < 2; ++i) {
+               if (!dim->nested[i])
+                       continue;
+               dim->nested[i] = isl_dim_replace(dim->nested[i],
+                                                isl_dim_param, dim);
+               if (!dim->nested[i])
+                       goto error;
+       }
+
        return dim;
 error:
        isl_dim_free(dim);
@@ -683,7 +713,7 @@ error:
 
 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
 {
-       struct isl_dim *dim;
+       isl_dim *dom1, *dom2, *nest1, *nest2;
 
        if (!left || !right)
                goto error;
@@ -691,21 +721,15 @@ struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
        isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
                        goto error);
 
-       dim = isl_dim_alloc(left->ctx, left->nparam,
-                       left->n_in + right->n_in, left->n_out + right->n_out);
-       if (!dim)
-               goto error;
+       dom1 = isl_dim_domain(isl_dim_copy(left));
+       dom2 = isl_dim_domain(isl_dim_copy(right));
+       nest1 = isl_dim_wrap(isl_dim_join(isl_dim_reverse(dom1), dom2));
 
-       dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
-       dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
-       dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
-       dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
-       dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
+       dom1 = isl_dim_range(left);
+       dom2 = isl_dim_range(right);
+       nest2 = isl_dim_wrap(isl_dim_join(isl_dim_reverse(dom1), dom2));
 
-       isl_dim_free(left);
-       isl_dim_free(right);
-
-       return dim;
+       return isl_dim_join(isl_dim_reverse(nest1), nest2);
 error:
        isl_dim_free(left);
        isl_dim_free(right);
@@ -851,6 +875,16 @@ struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
        case isl_dim_out:       dim->n_out -= num; break;
        }
        dim = isl_dim_reset(dim, type);
+       if (type == isl_dim_param) {
+               if (dim && dim->nested[0] &&
+                   !(dim->nested[0] = isl_dim_drop(dim->nested[0],
+                                                   isl_dim_param, first, num)))
+                       goto error;
+               if (dim && dim->nested[1] &&
+                   !(dim->nested[1] = isl_dim_drop(dim->nested[1],
+                                                   isl_dim_param, first, num)))
+                       goto error;
+       }
        return dim;
 error:
        isl_dim_free(dim);
@@ -881,6 +915,11 @@ struct isl_dim *isl_dim_domain(struct isl_dim *dim)
        return isl_dim_reverse(dim);
 }
 
+__isl_give isl_dim *isl_dim_from_domain(__isl_take isl_dim *dim)
+{
+       return isl_dim_reverse(dim);
+}
+
 struct isl_dim *isl_dim_range(struct isl_dim *dim)
 {
        if (!dim)
@@ -888,6 +927,11 @@ struct isl_dim *isl_dim_range(struct isl_dim *dim)
        return isl_dim_drop_inputs(dim, 0, dim->n_in);
 }
 
+__isl_give isl_dim *isl_dim_from_range(__isl_take isl_dim *dim)
+{
+       return dim;
+}
+
 __isl_give isl_dim *isl_dim_as_set_dim(__isl_take isl_dim *dim)
 {
        dim = isl_dim_cow(dim);
@@ -929,7 +973,7 @@ struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
 
 unsigned isl_dim_total(struct isl_dim *dim)
 {
-       return dim->nparam + dim->n_in + dim->n_out;
+       return dim ? dim->nparam + dim->n_in + dim->n_out : 0;
 }
 
 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
@@ -959,13 +1003,13 @@ static uint32_t isl_hash_dim(uint32_t hash, __isl_keep isl_dim *dim)
 
        for (i = 0; i < dim->nparam; ++i) {
                name = get_name(dim, isl_dim_param, i);
-               hash = isl_hash_builtin(hash, name);
+               hash = isl_hash_name(hash, name);
        }
 
        name = tuple_name(dim, isl_dim_in);
-       hash = isl_hash_builtin(hash, name);
+       hash = isl_hash_name(hash, name);
        name = tuple_name(dim, isl_dim_out);
-       hash = isl_hash_builtin(hash, name);
+       hash = isl_hash_name(hash, name);
 
        hash = isl_hash_dim(hash, dim->nested[0]);
        hash = isl_hash_dim(hash, dim->nested[1]);
@@ -1071,3 +1115,50 @@ __isl_give isl_dim *isl_dim_reset(__isl_take isl_dim *dim,
 
        return dim;
 }
+
+__isl_give isl_dim *isl_dim_flatten(__isl_take isl_dim *dim)
+{
+       if (!dim)
+               return NULL;
+       if (!dim->nested[0] && !dim->nested[1])
+               return dim;
+
+       if (dim->nested[0])
+               dim = isl_dim_reset(dim, isl_dim_in);
+       if (dim && dim->nested[1])
+               dim = isl_dim_reset(dim, isl_dim_out);
+
+       return dim;
+}
+
+/* Replace the dimensions of the given type of dst by those of src.
+ */
+__isl_give isl_dim *isl_dim_replace(__isl_take isl_dim *dst,
+       enum isl_dim_type type, __isl_keep isl_dim *src)
+{
+       dst = isl_dim_cow(dst);
+
+       if (!dst || !src)
+               goto error;
+
+       dst = isl_dim_drop(dst, type, 0, isl_dim_size(dst, type));
+       dst = isl_dim_add(dst, type, isl_dim_size(src, type));
+       dst = copy_names(dst, type, 0, src, type);
+
+       if (dst && type == isl_dim_param) {
+               int i;
+               for (i = 0; i <= 1; ++i) {
+                       if (!dst->nested[i])
+                               continue;
+                       dst->nested[i] = isl_dim_replace(dst->nested[i],
+                                                        type, src);
+                       if (!dst->nested[i])
+                               goto error;
+               }
+       }
+
+       return dst;
+error:
+       isl_dim_free(dst);
+       return NULL;
+}