isl_pw_qpolynomial_fold_bound: avoid access to freed memory
[platform/upstream/isl.git] / isl_dim.c
index d198fa5..4b5d2c0 100644 (file)
--- a/isl_dim.c
+++ b/isl_dim.c
@@ -118,7 +118,8 @@ static struct isl_dim *set_name(struct isl_dim *dim,
                goto error;
 
        pos = global_pos(dim, type, pos);
-       isl_assert(ctx, pos != isl_dim_total(dim), goto error);
+       if (pos == isl_dim_total(dim))
+               goto error;
 
        if (pos >= dim->n_name) {
                if (!name)
@@ -489,8 +490,17 @@ struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
        dim = isl_dim_reset(dim, type);
        switch (type) {
        case isl_dim_param:
-               return isl_dim_extend(dim,
+               dim = isl_dim_extend(dim,
                                        dim->nparam + n, dim->n_in, dim->n_out);
+               if (dim && dim->nested[0] &&
+                   !(dim->nested[0] = isl_dim_add(dim->nested[0],
+                                                   isl_dim_param, n)))
+                       goto error;
+               if (dim && dim->nested[1] &&
+                   !(dim->nested[1] = isl_dim_add(dim->nested[1],
+                                                   isl_dim_param, n)))
+                       goto error;
+               return dim;
        case isl_dim_in:
                return isl_dim_extend(dim,
                                        dim->nparam, dim->n_in + n, dim->n_out);
@@ -499,6 +509,9 @@ struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
                                        dim->nparam, dim->n_in, dim->n_out + n);
        }
        return dim;
+error:
+       isl_dim_free(dim);
+       return NULL;
 }
 
 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
@@ -847,6 +860,16 @@ struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
        case isl_dim_out:       dim->n_out -= num; break;
        }
        dim = isl_dim_reset(dim, type);
+       if (type == isl_dim_param) {
+               if (dim && dim->nested[0] &&
+                   !(dim->nested[0] = isl_dim_drop(dim->nested[0],
+                                                   isl_dim_param, first, num)))
+                       goto error;
+               if (dim && dim->nested[1] &&
+                   !(dim->nested[1] = isl_dim_drop(dim->nested[1],
+                                                   isl_dim_param, first, num)))
+                       goto error;
+       }
        return dim;
 error:
        isl_dim_free(dim);
@@ -935,7 +958,7 @@ struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
 
 unsigned isl_dim_total(struct isl_dim *dim)
 {
-       return dim->nparam + dim->n_in + dim->n_out;
+       return dim ? dim->nparam + dim->n_in + dim->n_out : 0;
 }
 
 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
@@ -1085,10 +1108,10 @@ __isl_give isl_dim *isl_dim_flatten(__isl_take isl_dim *dim)
        if (!dim->nested[0] && !dim->nested[1])
                return dim;
 
-       isl_dim_free(dim->nested[0]);
-       dim->nested[0] = NULL;
-       isl_dim_free(dim->nested[1]);
-       dim->nested[1] = NULL;
+       if (dim->nested[0])
+               dim = isl_dim_reset(dim, isl_dim_in);
+       if (dim && dim->nested[1])
+               dim = isl_dim_reset(dim, isl_dim_out);
 
        return dim;
 }
@@ -1103,7 +1126,11 @@ __isl_give isl_dim *isl_dim_replace(__isl_take isl_dim *dst,
        if (!dst || !src)
                goto error;
 
-       if (type == isl_dim_param) {
+       dst = isl_dim_drop(dst, type, 0, isl_dim_size(dst, type));
+       dst = isl_dim_add(dst, type, isl_dim_size(src, type));
+       dst = copy_names(dst, type, 0, src, type);
+
+       if (dst && type == isl_dim_param) {
                int i;
                for (i = 0; i <= 1; ++i) {
                        if (!dst->nested[i])
@@ -1115,10 +1142,6 @@ __isl_give isl_dim *isl_dim_replace(__isl_take isl_dim *dst,
                }
        }
 
-       dst = isl_dim_drop(dst, type, 0, isl_dim_size(dst, type));
-       dst = isl_dim_add(dst, type, isl_dim_size(src, type));
-       dst = copy_names(dst, type, 0, src, type);
-
        return dst;
 error:
        isl_dim_free(dst);