improve isl_pw_qpolynomial_move
[platform/upstream/isl.git] / isl_dim.c
index 13a68b8..04c0a66 100644 (file)
--- a/isl_dim.c
+++ b/isl_dim.c
@@ -357,6 +357,8 @@ __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
        isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
 
        dim = isl_dim_cow(dim);
+       if (!dim)
+               return NULL;
 
        if (dim->names) {
                enum isl_dim_type t;
@@ -409,19 +411,53 @@ __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
        isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
                goto error);
 
-       /* just the simple case for now */
-       isl_assert(dim->ctx,
-               offset(dim, dst_type) + dst_pos ==
-               offset(dim, src_type) + src_pos + ((src_type < dst_type) ? n : 0),
-               goto error);
-
-       if (dst_type == src_type)
+       if (dst_type == src_type && dst_pos == src_pos)
                return dim;
 
+       isl_assert(dim->ctx, dst_type != src_type, goto error);
+
        dim = isl_dim_cow(dim);
        if (!dim)
                return NULL;
 
+       if (dim->names) {
+               struct isl_name **names;
+               enum isl_dim_type t;
+               int off;
+               int size[3];
+               names = isl_calloc_array(dim->ctx, struct isl_name *,
+                                        dim->nparam + dim->n_in + dim->n_out);
+               if (!names)
+                       goto error;
+               off = 0;
+               size[isl_dim_param] = dim->nparam;
+               size[isl_dim_in] = dim->n_in;
+               size[isl_dim_out] = dim->n_out;
+               for (t = isl_dim_param; t <= isl_dim_out; ++t) {
+                       if (t == dst_type) {
+                               get_names(dim, t, 0, dst_pos, names + off);
+                               off += dst_pos;
+                               get_names(dim, src_type, src_pos, n, names+off);
+                               off += n;
+                               get_names(dim, t, dst_pos, size[t] - dst_pos,
+                                               names + off);
+                               off += size[t] - dst_pos;
+                       } else if (t == src_type) {
+                               get_names(dim, t, 0, src_pos, names + off);
+                               off += src_pos;
+                               get_names(dim, t, src_pos + n,
+                                           size[t] - src_pos - n, names + off);
+                               off += size[t] - src_pos - n;
+                       } else {
+                               get_names(dim, t, 0, size[t], names + off);
+                               off += size[t];
+                       }
+               }
+               free(dim->names);
+               dim->names = names;
+               dim->n_name = dim->nparam + dim->n_in + dim->n_out;
+       }
+
        switch (dst_type) {
        case isl_dim_param:     dim->nparam += n; break;
        case isl_dim_in:        dim->n_in += n; break;