export isl_set_split_dims
[platform/upstream/isl.git] / isl_dim.c
index 1aa66c4..ea77233 100644 (file)
--- a/isl_dim.c
+++ b/isl_dim.c
@@ -7,7 +7,7 @@
  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
  */
 
-#include "isl_dim.h"
+#include <isl_dim.h>
 #include "isl_name.h"
 
 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
@@ -344,6 +344,138 @@ struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
        return dim;
 }
 
+__isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
+       enum isl_dim_type type, unsigned pos, unsigned n)
+{
+       struct isl_name **names = NULL;
+
+       if (!dim)
+               return NULL;
+       if (n == 0)
+               return dim;
+
+       isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
+
+       dim = isl_dim_cow(dim);
+       if (!dim)
+               return NULL;
+
+       if (dim->names) {
+               enum isl_dim_type t;
+               int off;
+               int size[3];
+               names = isl_calloc_array(dim->ctx, struct isl_name *,
+                                    dim->nparam + dim->n_in + dim->n_out + n);
+               if (!names)
+                       goto error;
+               off = 0;
+               size[isl_dim_param] = dim->nparam;
+               size[isl_dim_in] = dim->n_in;
+               size[isl_dim_out] = dim->n_out;
+               for (t = isl_dim_param; t <= isl_dim_out; ++t) {
+                       if (t != type) {
+                               get_names(dim, t, 0, size[t], names + off);
+                               off += size[t];
+                       } else {
+                               get_names(dim, t, 0, pos, names + off);
+                               off += pos + n;
+                               get_names(dim, t, pos, size[t]-pos, names+off);
+                               off += size[t] - pos;
+                       }
+               }
+               free(dim->names);
+               dim->names = names;
+               dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
+       }
+       switch (type) {
+       case isl_dim_param:     dim->nparam += n; break;
+       case isl_dim_in:        dim->n_in += n; break;
+       case isl_dim_out:       dim->n_out += n; break;
+       }
+
+       return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
+}
+
+__isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
+       enum isl_dim_type dst_type, unsigned dst_pos,
+       enum isl_dim_type src_type, unsigned src_pos, unsigned n)
+{
+       if (!dim)
+               return NULL;
+       if (n == 0)
+               return dim;
+
+       isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
+               goto error);
+
+       if (dst_type == src_type && dst_pos == src_pos)
+               return dim;
+
+       isl_assert(dim->ctx, dst_type != src_type, goto error);
+
+       dim = isl_dim_cow(dim);
+       if (!dim)
+               return NULL;
+
+       if (dim->names) {
+               struct isl_name **names;
+               enum isl_dim_type t;
+               int off;
+               int size[3];
+               names = isl_calloc_array(dim->ctx, struct isl_name *,
+                                        dim->nparam + dim->n_in + dim->n_out);
+               if (!names)
+                       goto error;
+               off = 0;
+               size[isl_dim_param] = dim->nparam;
+               size[isl_dim_in] = dim->n_in;
+               size[isl_dim_out] = dim->n_out;
+               for (t = isl_dim_param; t <= isl_dim_out; ++t) {
+                       if (t == dst_type) {
+                               get_names(dim, t, 0, dst_pos, names + off);
+                               off += dst_pos;
+                               get_names(dim, src_type, src_pos, n, names+off);
+                               off += n;
+                               get_names(dim, t, dst_pos, size[t] - dst_pos,
+                                               names + off);
+                               off += size[t] - dst_pos;
+                       } else if (t == src_type) {
+                               get_names(dim, t, 0, src_pos, names + off);
+                               off += src_pos;
+                               get_names(dim, t, src_pos + n,
+                                           size[t] - src_pos - n, names + off);
+                               off += size[t] - src_pos - n;
+                       } else {
+                               get_names(dim, t, 0, size[t], names + off);
+                               off += size[t];
+                       }
+               }
+               free(dim->names);
+               dim->names = names;
+               dim->n_name = dim->nparam + dim->n_in + dim->n_out;
+       }
+
+       switch (dst_type) {
+       case isl_dim_param:     dim->nparam += n; break;
+       case isl_dim_in:        dim->n_in += n; break;
+       case isl_dim_out:       dim->n_out += n; break;
+       }
+
+       switch (src_type) {
+       case isl_dim_param:     dim->nparam -= n; break;
+       case isl_dim_in:        dim->n_in -= n; break;
+       case isl_dim_out:       dim->n_out -= n; break;
+       }
+
+       return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
+}
+
 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
 {
        struct isl_dim *dim;
@@ -353,7 +485,7 @@ struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
 
        isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
                        goto error);
-       isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
+       isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
                        goto error);
 
        dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
@@ -427,10 +559,10 @@ struct isl_dim *isl_dim_map(struct isl_dim *dim)
        }
        dim->n_in = dim->n_out;
        if (names) {
-               copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
                free(dim->names);
                dim->names = names;
                dim->n_name = dim->nparam + dim->n_out + dim->n_out;
+               dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
        }
        return dim;
 error:
@@ -593,8 +725,8 @@ unsigned isl_dim_total(struct isl_dim *dim)
 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
 {
        return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
-              match(dim1, isl_dim_in, dim2, isl_dim_in) &&
-              match(dim1, isl_dim_out, dim2, isl_dim_out);
+              n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
+              n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
 }
 
 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)