isl_basic_map_eliminate_vars: avoid NULL pointer dereference
[platform/upstream/isl.git] / isl_dim.c
index a5381dd..a08f316 100644 (file)
--- a/isl_dim.c
+++ b/isl_dim.c
@@ -1,4 +1,13 @@
-#include "isl_dim.h"
+/*
+ * Copyright 2008-2009 Katholieke Universiteit Leuven
+ *
+ * Use of this software is governed by the GNU LGPLv2.1 license
+ *
+ * Written by Sven Verdoolaege, K.U.Leuven, Departement
+ * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
+ */
+
+#include <isl_dim.h>
 #include "isl_name.h"
 
 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
@@ -45,11 +54,44 @@ static unsigned global_pos(struct isl_dim *dim,
                isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
                return pos + dim->nparam + dim->n_in;
        default:
-               isl_assert(ctx, 0, goto error);
+               isl_assert(ctx, 0, return isl_dim_total(dim));
        }
        return isl_dim_total(dim);
 }
 
+/* Extend length of names array to the total number of dimensions.
+ */
+static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
+{
+       struct isl_name **names;
+       int i;
+
+       if (isl_dim_total(dim) <= dim->n_name)
+               return dim;
+
+       if (!dim->names) {
+               dim->names = isl_calloc_array(dim->ctx,
+                               struct isl_name *, isl_dim_total(dim));
+               if (!dim->names)
+                       goto error;
+       } else {
+               names = isl_realloc_array(dim->ctx, dim->names,
+                               struct isl_name *, isl_dim_total(dim));
+               if (!names)
+                       goto error;
+               dim->names = names;
+               for (i = dim->n_name; i < isl_dim_total(dim); ++i)
+                       dim->names[i] = NULL;
+       }
+
+       dim->n_name = isl_dim_total(dim);
+
+       return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
+}
+
 static struct isl_dim *set_name(struct isl_dim *dim,
                                 enum isl_dim_type type, unsigned pos,
                                 struct isl_name *name)
@@ -66,21 +108,9 @@ static struct isl_dim *set_name(struct isl_dim *dim,
        if (pos >= dim->n_name) {
                if (!name)
                        return dim;
-               if (!dim->names) {
-                       dim->names = isl_calloc_array(dim->ctx,
-                                       struct isl_name *, isl_dim_total(dim));
-                       if (!dim->names)
-                               goto error;
-               } else {
-                       int i;
-                       dim->names = isl_realloc_array(dim->ctx, dim->names,
-                                       struct isl_name *, isl_dim_total(dim));
-                       if (!dim->names)
-                               goto error;
-                       for (i = dim->n_name; i < isl_dim_total(dim); ++i)
-                               dim->names[i] = NULL;
-               }
-               dim->n_name = isl_dim_total(dim);
+               dim = extend_names(dim);
+               if (!dim)
+                       goto error;
        }
 
        dim->names[pos] = name;
@@ -106,20 +136,40 @@ static struct isl_name *get_name(struct isl_dim *dim,
        return dim->names[pos];
 }
 
+static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
+{
+       switch (type) {
+       case isl_dim_param:     return 0;
+       case isl_dim_in:        return dim->nparam;
+       case isl_dim_out:       return dim->nparam + dim->n_in;
+       default:                return 0;
+       }
+}
+
 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
 {
        switch (type) {
        case isl_dim_param:     return dim->nparam;
        case isl_dim_in:        return dim->n_in;
        case isl_dim_out:       return dim->n_out;
+       default:                return 0;
        }
 }
 
 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
 {
+       if (!dim)
+               return 0;
        return n(dim, type);
 }
 
+unsigned isl_dim_offset(__isl_keep isl_dim *dim, enum isl_dim_type type)
+{
+       if (!dim)
+               return 0;
+       return offset(dim, type);
+}
+
 static struct isl_dim *copy_names(struct isl_dim *dst,
        enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
        enum isl_dim_type src_type)
@@ -127,6 +177,9 @@ static struct isl_dim *copy_names(struct isl_dim *dst,
        int i;
        struct isl_name *name;
 
+       if (!dst)
+               return NULL;
+
        for (i = 0; i < n(src, src_type); ++i) {
                name = get_name(src, src_type, i);
                if (!name)
@@ -142,6 +195,8 @@ static struct isl_dim *copy_names(struct isl_dim *dst,
 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
 {
        struct isl_dim *dup;
+       if (!dim)
+               return NULL;
        dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
        if (!dim->names)
                return dup;
@@ -303,6 +358,138 @@ struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
        return dim;
 }
 
+__isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
+       enum isl_dim_type type, unsigned pos, unsigned n)
+{
+       struct isl_name **names = NULL;
+
+       if (!dim)
+               return NULL;
+       if (n == 0)
+               return dim;
+
+       isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
+
+       dim = isl_dim_cow(dim);
+       if (!dim)
+               return NULL;
+
+       if (dim->names) {
+               enum isl_dim_type t;
+               int off;
+               int size[3];
+               names = isl_calloc_array(dim->ctx, struct isl_name *,
+                                    dim->nparam + dim->n_in + dim->n_out + n);
+               if (!names)
+                       goto error;
+               off = 0;
+               size[isl_dim_param] = dim->nparam;
+               size[isl_dim_in] = dim->n_in;
+               size[isl_dim_out] = dim->n_out;
+               for (t = isl_dim_param; t <= isl_dim_out; ++t) {
+                       if (t != type) {
+                               get_names(dim, t, 0, size[t], names + off);
+                               off += size[t];
+                       } else {
+                               get_names(dim, t, 0, pos, names + off);
+                               off += pos + n;
+                               get_names(dim, t, pos, size[t]-pos, names+off);
+                               off += size[t] - pos;
+                       }
+               }
+               free(dim->names);
+               dim->names = names;
+               dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
+       }
+       switch (type) {
+       case isl_dim_param:     dim->nparam += n; break;
+       case isl_dim_in:        dim->n_in += n; break;
+       case isl_dim_out:       dim->n_out += n; break;
+       }
+
+       return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
+}
+
+__isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
+       enum isl_dim_type dst_type, unsigned dst_pos,
+       enum isl_dim_type src_type, unsigned src_pos, unsigned n)
+{
+       if (!dim)
+               return NULL;
+       if (n == 0)
+               return dim;
+
+       isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
+               goto error);
+
+       if (dst_type == src_type && dst_pos == src_pos)
+               return dim;
+
+       isl_assert(dim->ctx, dst_type != src_type, goto error);
+
+       dim = isl_dim_cow(dim);
+       if (!dim)
+               return NULL;
+
+       if (dim->names) {
+               struct isl_name **names;
+               enum isl_dim_type t;
+               int off;
+               int size[3];
+               names = isl_calloc_array(dim->ctx, struct isl_name *,
+                                        dim->nparam + dim->n_in + dim->n_out);
+               if (!names)
+                       goto error;
+               off = 0;
+               size[isl_dim_param] = dim->nparam;
+               size[isl_dim_in] = dim->n_in;
+               size[isl_dim_out] = dim->n_out;
+               for (t = isl_dim_param; t <= isl_dim_out; ++t) {
+                       if (t == dst_type) {
+                               get_names(dim, t, 0, dst_pos, names + off);
+                               off += dst_pos;
+                               get_names(dim, src_type, src_pos, n, names+off);
+                               off += n;
+                               get_names(dim, t, dst_pos, size[t] - dst_pos,
+                                               names + off);
+                               off += size[t] - dst_pos;
+                       } else if (t == src_type) {
+                               get_names(dim, t, 0, src_pos, names + off);
+                               off += src_pos;
+                               get_names(dim, t, src_pos + n,
+                                           size[t] - src_pos - n, names + off);
+                               off += size[t] - src_pos - n;
+                       } else {
+                               get_names(dim, t, 0, size[t], names + off);
+                               off += size[t];
+                       }
+               }
+               free(dim->names);
+               dim->names = names;
+               dim->n_name = dim->nparam + dim->n_in + dim->n_out;
+       }
+
+       switch (dst_type) {
+       case isl_dim_param:     dim->nparam += n; break;
+       case isl_dim_in:        dim->n_in += n; break;
+       case isl_dim_out:       dim->n_out += n; break;
+       }
+
+       switch (src_type) {
+       case isl_dim_param:     dim->nparam -= n; break;
+       case isl_dim_in:        dim->n_in -= n; break;
+       case isl_dim_out:       dim->n_out -= n; break;
+       }
+
+       return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
+}
+
 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
 {
        struct isl_dim *dim;
@@ -312,7 +499,7 @@ struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
 
        isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
                        goto error);
-       isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
+       isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
                        goto error);
 
        dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
@@ -386,10 +573,10 @@ struct isl_dim *isl_dim_map(struct isl_dim *dim)
        }
        dim->n_in = dim->n_out;
        if (names) {
-               copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
                free(dim->names);
                dim->names = names;
                dim->n_name = dim->nparam + dim->n_out + dim->n_out;
+               dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
        }
        return dim;
 error:
@@ -448,8 +635,8 @@ error:
        return NULL;
 }
 
-struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
-               unsigned first, unsigned n)
+struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
+               unsigned first, unsigned num)
 {
        int i;
 
@@ -459,57 +646,51 @@ struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
        if (n == 0)
                return dim;
 
-       isl_assert(dim->ctx, first + n <= dim->n_in, goto error);
+       isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
        dim = isl_dim_cow(dim);
        if (!dim)
                goto error;
        if (dim->names) {
-               for (i = 0; i < n; ++i) {
-                       isl_name_free(dim->ctx,
-                                       get_name(dim, isl_dim_in, first+i));
+               dim = extend_names(dim);
+               if (!dim)
+                       goto error;
+               for (i = 0; i < num; ++i)
+                       isl_name_free(dim->ctx, get_name(dim, type, first+i));
+               for (i = first+num; i < n(dim, type); ++i)
+                       set_name(dim, type, i - num, get_name(dim, type, i));
+               switch (type) {
+               case isl_dim_param:
+                       get_names(dim, isl_dim_in, 0, dim->n_in,
+                               dim->names + offset(dim, isl_dim_in) - num);
+               case isl_dim_in:
+                       get_names(dim, isl_dim_out, 0, dim->n_out,
+                               dim->names + offset(dim, isl_dim_out) - num);
+               case isl_dim_out:
+                       ;
                }
-               for (i = first+n; i < dim->n_in; ++i)
-                       set_name(dim, isl_dim_in, i - n,
-                               get_name(dim, isl_dim_in, i));
-               get_names(dim, isl_dim_out, 0, dim->n_out,
-                               dim->names + dim->nparam + dim->n_in - n);
+               dim->n_name -= num;
+       }
+       switch (type) {
+       case isl_dim_param:     dim->nparam -= num; break;
+       case isl_dim_in:        dim->n_in -= num; break;
+       case isl_dim_out:       dim->n_out -= num; break;
        }
-       dim->n_in -= n;
        return dim;
 error:
        isl_dim_free(dim);
        return NULL;
 }
 
-struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
+struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
                unsigned first, unsigned n)
 {
-       int i;
-
-       if (!dim)
-               return NULL;
-
-       if (n == 0)
-               return dim;
+       return isl_dim_drop(dim, isl_dim_in, first, n);
+}
 
-       isl_assert(dim->ctx, first + n <= dim->n_out, goto error);
-       dim = isl_dim_cow(dim);
-       if (!dim)
-               goto error;
-       if (dim->names) {
-               for (i = 0; i < n; ++i) {
-                       isl_name_free(dim->ctx,
-                                       get_name(dim, isl_dim_out, first+i));
-               }
-               for (i = first+n; i < dim->n_out; ++i)
-                       set_name(dim, isl_dim_out, i - n,
-                               get_name(dim, isl_dim_out, i));
-       }
-       dim->n_out -= n;
-       return dim;
-error:
-       isl_dim_free(dim);
-       return NULL;
+struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
+               unsigned first, unsigned n)
+{
+       return isl_dim_drop(dim, isl_dim_out, first, n);
 }
 
 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
@@ -520,6 +701,13 @@ struct isl_dim *isl_dim_domain(struct isl_dim *dim)
        return isl_dim_reverse(dim);
 }
 
+struct isl_dim *isl_dim_range(struct isl_dim *dim)
+{
+       if (!dim)
+               return NULL;
+       return isl_dim_drop_inputs(dim, 0, dim->n_in);
+}
+
 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
 {
        int i;
@@ -551,8 +739,8 @@ unsigned isl_dim_total(struct isl_dim *dim)
 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
 {
        return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
-              match(dim1, isl_dim_in, dim2, isl_dim_in) &&
-              match(dim1, isl_dim_out, dim2, isl_dim_out);
+              n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
+              n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
 }
 
 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)