isl_basic_set_opt: avoid invalid access on error path
[platform/upstream/isl.git] / isl_flow.c
index 2059365..8dca5c4 100644 (file)
@@ -4,7 +4,7 @@
  * Copyright 2010      INRIA Saclay
  * Copyright 2012      Universiteit Leiden
  *
- * Use of this software is governed by the GNU LGPLv2.1 license
+ * Use of this software is governed by the MIT license
  *
  * Written by Sven Verdoolaege, Leiden Institute of Advanced Computer Science,
  * Universiteit Leiden, Niels Bohrweg 1, 2333 CA Leiden, The Netherlands
@@ -17,7 +17,7 @@
 #include <isl/set.h>
 #include <isl/map.h>
 #include <isl/flow.h>
-#include <isl_qsort.h>
+#include <isl_sort.h>
 
 enum isl_restriction_type {
        isl_restriction_type_empty,
@@ -33,9 +33,10 @@ struct isl_restriction {
        isl_set *sink;
 };
 
-/* Create a restriction that doesn't restrict anything.
+/* Create a restriction of the given type.
  */
-__isl_give isl_restriction *isl_restriction_none(__isl_keep isl_map *source_map)
+static __isl_give isl_restriction *isl_restriction_alloc(
+       __isl_take isl_map *source_map, enum isl_restriction_type type)
 {
        isl_ctx *ctx;
        isl_restriction *restr;
@@ -46,32 +47,30 @@ __isl_give isl_restriction *isl_restriction_none(__isl_keep isl_map *source_map)
        ctx = isl_map_get_ctx(source_map);
        restr = isl_calloc_type(ctx, struct isl_restriction);
        if (!restr)
-               return NULL;
+               goto error;
 
-       restr->type = isl_restriction_type_none;
+       restr->type = type;
 
+       isl_map_free(source_map);
        return restr;
+error:
+       isl_map_free(source_map);
+       return NULL;
+}
+
+/* Create a restriction that doesn't restrict anything.
+ */
+__isl_give isl_restriction *isl_restriction_none(__isl_take isl_map *source_map)
+{
+       return isl_restriction_alloc(source_map, isl_restriction_type_none);
 }
 
 /* Create a restriction that removes everything.
  */
 __isl_give isl_restriction *isl_restriction_empty(
-       __isl_keep isl_map *source_map)
+       __isl_take isl_map *source_map)
 {
-       isl_ctx *ctx;
-       isl_restriction *restr;
-
-       if (!source_map)
-               return NULL;
-
-       ctx = isl_map_get_ctx(source_map);
-       restr = isl_calloc_type(ctx, struct isl_restriction);
-       if (!restr)
-               return NULL;
-
-       restr->type = isl_restriction_type_empty;
-
-       return restr;
+       return isl_restriction_alloc(source_map, isl_restriction_type_empty);
 }
 
 /* Create a restriction on the input of the maximization problem
@@ -139,6 +138,11 @@ void *isl_restriction_free(__isl_take isl_restriction *restr)
        return NULL;
 }
 
+isl_ctx *isl_restriction_get_ctx(__isl_keep isl_restriction *restr)
+{
+       return restr ? isl_set_get_ctx(restr->source) : NULL;
+}
+
 /* A private structure to keep track of a mapping together with
  * a user-specified identifier and a boolean indicating whether
  * the map represents a must or may access/dependence.
@@ -223,17 +227,18 @@ error:
 
 /* Free the given isl_access_info structure.
  */
-void isl_access_info_free(__isl_take isl_access_info *acc)
+void *isl_access_info_free(__isl_take isl_access_info *acc)
 {
        int i;
 
        if (!acc)
-               return;
+               return NULL;
        isl_map_free(acc->domain_map);
        isl_map_free(acc->sink.map);
        for (i = 0; i < acc->n_must + acc->n_may; ++i)
                isl_map_free(acc->source[i].map);
        free(acc);
+       return NULL;
 }
 
 isl_ctx *isl_access_info_get_ctx(__isl_keep isl_access_info *acc)
@@ -264,7 +269,7 @@ __isl_give isl_access_info *isl_access_info_add_source(
        isl_ctx *ctx;
 
        if (!acc)
-               return NULL;
+               goto error;
        ctx = isl_map_get_ctx(acc->sink.map);
        isl_assert(ctx, acc->n_must + acc->n_may < acc->max_source, goto error);
        
@@ -335,8 +340,9 @@ static __isl_give isl_access_info *isl_access_info_sort_sources(
        if (acc->n_must <= 1)
                return acc;
 
-       isl_quicksort(acc->source, acc->n_must, sizeof(struct isl_labeled_map),
-               access_sort_cmp, acc);
+       if (isl_sort(acc->source, acc->n_must, sizeof(struct isl_labeled_map),
+                   access_sort_cmp, acc) < 0)
+               return isl_access_info_free(acc);
 
        return acc;
 }
@@ -1121,6 +1127,7 @@ static __isl_give struct isl_sched_info *sched_info_alloc(
        isl_space *dim;
        struct isl_sched_info *info;
        int i, n;
+       isl_int v;
 
        if (!map)
                return NULL;
@@ -1140,9 +1147,13 @@ static __isl_give struct isl_sched_info *sched_info_alloc(
        if (!info->is_cst || !info->cst)
                goto error;
 
-       for (i = 0; i < n; ++i)
+       isl_int_init(v);
+       for (i = 0; i < n; ++i) {
                info->is_cst[i] = isl_map_plain_is_fixed(map, isl_dim_in, i,
-                                                       &info->cst->el[i]);
+                                                        &v);
+               info->cst = isl_vec_set_element(info->cst, i, v);
+       }
+       isl_int_clear(v);
 
        return info;
 error:
@@ -1243,22 +1254,36 @@ static int before(void *first, void *second)
        struct isl_sched_info *info2 = second;
        int n1, n2;
        int i;
+       isl_int v1, v2;
 
-       n1 = info1->cst->size;
-       n2 = info2->cst->size;
+       n1 = isl_vec_size(info1->cst);
+       n2 = isl_vec_size(info2->cst);
 
        if (n2 < n1)
                n1 = n2;
 
+       isl_int_init(v1);
+       isl_int_init(v2);
        for (i = 0; i < n1; ++i) {
+               int r;
+
                if (!info1->is_cst[i])
                        continue;
                if (!info2->is_cst[i])
                        continue;
-               if (isl_int_eq(info1->cst->el[i], info2->cst->el[i]))
+               isl_vec_get_element(info1->cst, i, &v1);
+               isl_vec_get_element(info2->cst, i, &v2);
+               if (isl_int_eq(v1, v2))
                        continue;
-               return 2 * i + isl_int_lt(info1->cst->el[i], info2->cst->el[i]);
+
+               r = 2 * i + isl_int_lt(v1, v2);
+
+               isl_int_clear(v1);
+               isl_int_clear(v2);
+               return r;
        }
+       isl_int_clear(v1);
+       isl_int_clear(v2);
 
        return 2 * n1;
 }