isl_tab_rollback: avoid double free on error path
[platform/upstream/isl.git] / isl_tab.c
index 1b9a2f0..b8a80be 100644 (file)
--- a/isl_tab.c
+++ b/isl_tab.c
@@ -64,6 +64,7 @@ struct isl_tab *isl_tab_alloc(struct isl_ctx *ctx,
        tab->n_div = 0;
        tab->n_dead = 0;
        tab->n_redundant = 0;
+       tab->strict_redundant = 0;
        tab->need_undo = 0;
        tab->rational = 0;
        tab->empty = 0;
@@ -86,11 +87,13 @@ error:
 
 int isl_tab_extend_cons(struct isl_tab *tab, unsigned n_new)
 {
-       unsigned off = 2 + tab->M;
+       unsigned off;
 
        if (!tab)
                return -1;
 
+       off = 2 + tab->M;
+
        if (tab->max_con < tab->n_con + n_new) {
                struct isl_tab_var *con;
 
@@ -208,34 +211,34 @@ struct isl_tab *isl_tab_dup(struct isl_tab *tab)
                return NULL;
 
        off = 2 + tab->M;
-       dup = isl_calloc_type(tab->ctx, struct isl_tab);
+       dup = isl_calloc_type(tab->mat->ctx, struct isl_tab);
        if (!dup)
                return NULL;
        dup->mat = isl_mat_dup(tab->mat);
        if (!dup->mat)
                goto error;
-       dup->var = isl_alloc_array(tab->ctx, struct isl_tab_var, tab->max_var);
+       dup->var = isl_alloc_array(tab->mat->ctx, struct isl_tab_var, tab->max_var);
        if (!dup->var)
                goto error;
        for (i = 0; i < tab->n_var; ++i)
                dup->var[i] = tab->var[i];
-       dup->con = isl_alloc_array(tab->ctx, struct isl_tab_var, tab->max_con);
+       dup->con = isl_alloc_array(tab->mat->ctx, struct isl_tab_var, tab->max_con);
        if (!dup->con)
                goto error;
        for (i = 0; i < tab->n_con; ++i)
                dup->con[i] = tab->con[i];
-       dup->col_var = isl_alloc_array(tab->ctx, int, tab->mat->n_col - off);
+       dup->col_var = isl_alloc_array(tab->mat->ctx, int, tab->mat->n_col - off);
        if (!dup->col_var)
                goto error;
        for (i = 0; i < tab->n_col; ++i)
                dup->col_var[i] = tab->col_var[i];
-       dup->row_var = isl_alloc_array(tab->ctx, int, tab->mat->n_row);
+       dup->row_var = isl_alloc_array(tab->mat->ctx, int, tab->mat->n_row);
        if (!dup->row_var)
                goto error;
        for (i = 0; i < tab->n_row; ++i)
                dup->row_var[i] = tab->row_var[i];
        if (tab->row_sign) {
-               dup->row_sign = isl_alloc_array(tab->ctx, enum isl_tab_row_sign,
+               dup->row_sign = isl_alloc_array(tab->mat->ctx, enum isl_tab_row_sign,
                                                tab->mat->n_row);
                if (!dup->row_sign)
                        goto error;
@@ -266,6 +269,7 @@ struct isl_tab *isl_tab_dup(struct isl_tab *tab)
        dup->n_redundant = tab->n_redundant;
        dup->rational = tab->rational;
        dup->empty = tab->empty;
+       dup->strict_redundant = 0;
        dup->need_undo = 0;
        dup->in_undo = 0;
        dup->M = tab->M;
@@ -308,6 +312,8 @@ static struct isl_mat *tab_mat_product(struct isl_mat *mat1,
 
        prod = isl_mat_alloc(mat1->ctx, mat1->n_row + mat2->n_row,
                                        off + col1 + col2);
+       if (!prod)
+               return NULL;
 
        n = 0;
        for (i = 0; i < r1; ++i) {
@@ -516,6 +522,7 @@ struct isl_tab *isl_tab_product(struct isl_tab *tab1, struct isl_tab *tab2)
        prod->n_redundant = tab1->n_redundant + tab2->n_redundant;
        prod->rational = tab1->rational;
        prod->empty = tab1->empty || tab2->empty;
+       prod->strict_redundant = tab1->strict_redundant || tab2->strict_redundant;
        prod->need_undo = 0;
        prod->in_undo = 0;
        prod->M = tab1->M;
@@ -723,6 +730,8 @@ int isl_tab_row_is_redundant(struct isl_tab *tab, int row)
 
        if (isl_int_is_neg(tab->mat->row[row][1]))
                return 0;
+       if (tab->strict_redundant && isl_int_is_zero(tab->mat->row[row][1]))
+               return 0;
        if (tab->M && isl_int_is_neg(tab->mat->row[row][2]))
                return 0;
 
@@ -742,6 +751,8 @@ int isl_tab_row_is_redundant(struct isl_tab *tab, int row)
 static void swap_rows(struct isl_tab *tab, int row1, int row2)
 {
        int t;
+       enum isl_tab_row_sign s;
+
        t = tab->row_var[row1];
        tab->row_var[row1] = tab->row_var[row2];
        tab->row_var[row2] = t;
@@ -751,9 +762,9 @@ static void swap_rows(struct isl_tab *tab, int row1, int row2)
 
        if (!tab->row_sign)
                return;
-       t = tab->row_sign[row1];
+       s = tab->row_sign[row1];
        tab->row_sign[row1] = tab->row_sign[row2];
-       tab->row_sign[row2] = t;
+       tab->row_sign[row2] = s;
 }
 
 static int push_union(struct isl_tab *tab,
@@ -1165,10 +1176,19 @@ static void check_table(struct isl_tab *tab)
 
        if (tab->empty)
                return;
-       for (i = 0; i < tab->n_row; ++i) {
-               if (!isl_tab_var_from_row(tab, i)->is_nonneg)
+       for (i = tab->n_redundant; i < tab->n_row; ++i) {
+               struct isl_tab_var *var;
+               var = isl_tab_var_from_row(tab, i);
+               if (!var->is_nonneg)
                        continue;
-               assert(!isl_int_is_neg(tab->mat->row[i][1]));
+               if (tab->M) {
+                       isl_assert(tab->mat->ctx,
+                               !isl_int_is_neg(tab->mat->row[i][2]), abort());
+                       if (isl_int_is_pos(tab->mat->row[i][2]))
+                               continue;
+               }
+               isl_assert(tab->mat->ctx, !isl_int_is_neg(tab->mat->row[i][1]),
+                               abort());
        }
 }
 
@@ -1203,6 +1223,20 @@ static int sign_of_max(struct isl_tab *tab, struct isl_tab_var *var)
        return 1;
 }
 
+int isl_tab_sign_of_max(struct isl_tab *tab, int con)
+{
+       struct isl_tab_var *var;
+
+       if (!tab)
+               return -2;
+
+       var = &tab->con[con];
+       isl_assert(tab->mat->ctx, !var->is_redundant, return -2);
+       isl_assert(tab->mat->ctx, !var->is_zero, return -2);
+
+       return sign_of_max(tab, var);
+}
+
 static int row_is_neg(struct isl_tab *tab, int row)
 {
        if (!tab->M)
@@ -1354,7 +1388,8 @@ static int row_at_most_neg_one(struct isl_tab *tab, int row)
  * Return 0 otherwise.
  *
  * The sample value of "var" is assumed to be non-negative when the
- * the function is called and will be made non-negative again before
+ * the function is called.  If 1 is returned then the constraint
+ * is not redundant and the sample value is made non-negative again before
  * the function returns.
  */
 int isl_tab_min_at_most_neg_one(struct isl_tab *tab, struct isl_tab_var *var)
@@ -1389,8 +1424,11 @@ int isl_tab_min_at_most_neg_one(struct isl_tab *tab, struct isl_tab_var *var)
                return 0;
        do {
                find_pivot(tab, var, var, -1, &row, &col);
-               if (row == var->index)
+               if (row == var->index) {
+                       if (restore_row(tab, var) < -1)
+                               return -1;
                        return 1;
+               }
                if (row == -1)
                        return 0;
                pivot_var = var_from_col(tab, col);
@@ -1500,11 +1538,15 @@ static int close_row(struct isl_tab *tab, struct isl_tab_var *var)
                if (isl_tab_push_var(tab, isl_tab_undo_zero, var) < 0)
                        return -1;
        for (j = tab->n_dead; j < tab->n_col; ++j) {
+               int recheck;
                if (isl_int_is_zero(mat->row[var->index][off + j]))
                        continue;
                isl_assert(tab->mat->ctx,
                    isl_int_is_neg(mat->row[var->index][off + j]), return -1);
-               if (isl_tab_kill_col(tab, j))
+               recheck = isl_tab_kill_col(tab, j);
+               if (recheck < 0)
+                       return -1;
+               if (recheck)
                        --j;
        }
        if (isl_tab_mark_redundant(tab, var->index) < 0)
@@ -1632,7 +1674,7 @@ int isl_tab_add_row(struct isl_tab *tab, isl_int *line)
        isl_int_clear(b);
 
        if (tab->row_sign)
-               tab->row_sign[tab->con[r].index] = 0;
+               tab->row_sign[tab->con[r].index] = isl_tab_row_unknown;
 
        return r;
 }
@@ -1791,24 +1833,24 @@ static int row_is_manifestly_zero(struct isl_tab *tab, int row)
 
 /* Add an equality that is known to be valid for the given tableau.
  */
-struct isl_tab *isl_tab_add_valid_eq(struct isl_tab *tab, isl_int *eq)
+int isl_tab_add_valid_eq(struct isl_tab *tab, isl_int *eq)
 {
        struct isl_tab_var *var;
        int r;
 
        if (!tab)
-               return NULL;
+               return -1;
        r = isl_tab_add_row(tab, eq);
        if (r < 0)
-               goto error;
+               return -1;
 
        var = &tab->con[r];
        r = var->index;
        if (row_is_manifestly_zero(tab, r)) {
                var->is_zero = 1;
                if (isl_tab_mark_redundant(tab, r) < 0)
-                       goto error;
-               return tab;
+                       return -1;
+               return 0;
        }
 
        if (isl_int_is_neg(tab->mat->row[r][1])) {
@@ -1818,15 +1860,12 @@ struct isl_tab *isl_tab_add_valid_eq(struct isl_tab *tab, isl_int *eq)
        }
        var->is_nonneg = 1;
        if (to_col(tab, var) < 0)
-               goto error;
+               return -1;
        var->is_nonneg = 0;
        if (isl_tab_kill_col(tab, var->index) < 0)
-               goto error;
+               return -1;
 
-       return tab;
-error:
-       isl_tab_free(tab);
-       return NULL;
+       return 0;
 }
 
 static int add_zero_row(struct isl_tab *tab)
@@ -1848,7 +1887,7 @@ static int add_zero_row(struct isl_tab *tab)
 /* Add equality "eq" and check if it conflicts with the
  * previously added constraints or if it is obviously redundant.
  */
-struct isl_tab *isl_tab_add_eq(struct isl_tab *tab, isl_int *eq)
+int isl_tab_add_eq(struct isl_tab *tab, isl_int *eq)
 {
        struct isl_tab_undo *snap = NULL;
        struct isl_tab_var *var;
@@ -1858,8 +1897,8 @@ struct isl_tab *isl_tab_add_eq(struct isl_tab *tab, isl_int *eq)
        isl_int cst;
 
        if (!tab)
-               return NULL;
-       isl_assert(tab->mat->ctx, !tab->M, goto error);
+               return -1;
+       isl_assert(tab->mat->ctx, !tab->M, return -1);
 
        if (tab->need_undo)
                snap = isl_tab_snap(tab);
@@ -1874,32 +1913,32 @@ struct isl_tab *isl_tab_add_eq(struct isl_tab *tab, isl_int *eq)
                isl_int_clear(cst);
        }
        if (r < 0)
-               goto error;
+               return -1;
 
        var = &tab->con[r];
        row = var->index;
        if (row_is_manifestly_zero(tab, row)) {
                if (snap) {
                        if (isl_tab_rollback(tab, snap) < 0)
-                               goto error;
+                               return -1;
                } else
                        drop_row(tab, row);
-               return tab;
+               return 0;
        }
 
        if (tab->bmap) {
                tab->bmap = isl_basic_map_add_ineq(tab->bmap, eq);
                if (isl_tab_push(tab, isl_tab_undo_bmap_ineq) < 0)
-                       goto error;
+                       return -1;
                isl_seq_neg(eq, eq, 1 + tab->n_var);
                tab->bmap = isl_basic_map_add_ineq(tab->bmap, eq);
                isl_seq_neg(eq, eq, 1 + tab->n_var);
                if (isl_tab_push(tab, isl_tab_undo_bmap_ineq) < 0)
-                       goto error;
+                       return -1;
                if (!tab->bmap)
-                       goto error;
+                       return -1;
                if (add_zero_row(tab) < 0)
-                       goto error;
+                       return -1;
        }
 
        sgn = isl_int_sgn(tab->mat->row[row][1]);
@@ -1914,25 +1953,22 @@ struct isl_tab *isl_tab_add_eq(struct isl_tab *tab, isl_int *eq)
        if (sgn < 0) {
                sgn = sign_of_max(tab, var);
                if (sgn < -1)
-                       goto error;
+                       return -1;
                if (sgn < 0) {
                        if (isl_tab_mark_empty(tab) < 0)
-                               goto error;
-                       return tab;
+                               return -1;
+                       return 0;
                }
        }
 
        var->is_nonneg = 1;
        if (to_col(tab, var) < 0)
-               goto error;
+               return -1;
        var->is_nonneg = 0;
        if (isl_tab_kill_col(tab, var->index) < 0)
-               goto error;
+               return -1;
 
-       return tab;
-error:
-       isl_tab_free(tab);
-       return NULL;
+       return 0;
 }
 
 /* Construct and return an inequality that expresses an upper bound
@@ -2121,16 +2157,20 @@ struct isl_tab *isl_tab_from_basic_set(struct isl_basic_set *bset)
 
 /* Construct a tableau corresponding to the recession cone of "bset".
  */
-struct isl_tab *isl_tab_from_recession_cone(struct isl_basic_set *bset)
+struct isl_tab *isl_tab_from_recession_cone(__isl_keep isl_basic_set *bset,
+       int parametric)
 {
        isl_int cst;
        int i;
        struct isl_tab *tab;
+       unsigned offset = 0;
 
        if (!bset)
                return NULL;
+       if (parametric)
+               offset = isl_basic_set_dim(bset, isl_dim_param);
        tab = isl_tab_alloc(bset->ctx, bset->n_eq + bset->n_ineq,
-                               isl_basic_set_total_dim(bset), 0);
+                               isl_basic_set_total_dim(bset) - offset, 0);
        if (!tab)
                return NULL;
        tab->rational = ISL_F_ISSET(bset, ISL_BASIC_SET_RATIONAL);
@@ -2138,17 +2178,21 @@ struct isl_tab *isl_tab_from_recession_cone(struct isl_basic_set *bset)
 
        isl_int_init(cst);
        for (i = 0; i < bset->n_eq; ++i) {
-               isl_int_swap(bset->eq[i][0], cst);
-               tab = add_eq(tab, bset->eq[i]);
-               isl_int_swap(bset->eq[i][0], cst);
+               isl_int_swap(bset->eq[i][offset], cst);
+               if (offset > 0) {
+                       if (isl_tab_add_eq(tab, bset->eq[i] + offset) < 0)
+                               goto error;
+               } else
+                       tab = add_eq(tab, bset->eq[i]);
+               isl_int_swap(bset->eq[i][offset], cst);
                if (!tab)
                        goto done;
        }
        for (i = 0; i < bset->n_ineq; ++i) {
                int r;
-               isl_int_swap(bset->ineq[i][0], cst);
-               r = isl_tab_add_row(tab, bset->ineq[i]);
-               isl_int_swap(bset->ineq[i][0], cst);
+               isl_int_swap(bset->ineq[i][offset], cst);
+               r = isl_tab_add_row(tab, bset->ineq[i] + offset);
+               isl_int_swap(bset->ineq[i][offset], cst);
                if (r < 0)
                        goto error;
                tab->con[r].is_nonneg = 1;
@@ -2331,8 +2375,7 @@ struct isl_basic_set *isl_basic_set_update_from_tab(struct isl_basic_set *bset,
  * the resulting tableau is empty.
  * Otherwise, we know the value will be zero and we close the row.
  */
-static struct isl_tab *cut_to_hyperplane(struct isl_tab *tab,
-       struct isl_tab_var *var)
+static int cut_to_hyperplane(struct isl_tab *tab, struct isl_tab_var *var)
 {
        unsigned r;
        isl_int *row;
@@ -2340,12 +2383,12 @@ static struct isl_tab *cut_to_hyperplane(struct isl_tab *tab,
        unsigned off = 2 + tab->M;
 
        if (var->is_zero)
-               return tab;
-       isl_assert(tab->mat->ctx, !var->is_redundant, goto error);
-       isl_assert(tab->mat->ctx, var->is_nonneg, goto error);
+               return 0;
+       isl_assert(tab->mat->ctx, !var->is_redundant, return -1);
+       isl_assert(tab->mat->ctx, var->is_nonneg, return -1);
 
        if (isl_tab_extend_cons(tab, 1) < 0)
-               goto error;
+               return -1;
 
        r = tab->n_con;
        tab->con[r].index = tab->n_row;
@@ -2371,27 +2414,24 @@ static struct isl_tab *cut_to_hyperplane(struct isl_tab *tab,
        tab->n_row++;
        tab->n_con++;
        if (isl_tab_push_var(tab, isl_tab_undo_allocate, &tab->con[r]) < 0)
-               goto error;
+               return -1;
 
        sgn = sign_of_max(tab, &tab->con[r]);
        if (sgn < -1)
-               goto error;
+               return -1;
        if (sgn < 0) {
                if (isl_tab_mark_empty(tab) < 0)
-                       goto error;
-               return tab;
+                       return -1;
+               return 0;
        }
        tab->con[r].is_nonneg = 1;
        if (isl_tab_push_var(tab, isl_tab_undo_nonneg, &tab->con[r]) < 0)
-               goto error;
+               return -1;
        /* sgn == 0 */
        if (close_row(tab, &tab->con[r]) < 0)
-               goto error;
+               return -1;
 
-       return tab;
-error:
-       isl_tab_free(tab);
-       return NULL;
+       return 0;
 }
 
 /* Given a tableau "tab" and an inequality constraint "con" of the tableau,
@@ -2454,10 +2494,10 @@ error:
        return NULL;
 }
 
-struct isl_tab *isl_tab_select_facet(struct isl_tab *tab, int con)
+int isl_tab_select_facet(struct isl_tab *tab, int con)
 {
        if (!tab)
-               return NULL;
+               return -1;
 
        return cut_to_hyperplane(tab, &tab->con[con]);
 }
@@ -2465,11 +2505,9 @@ struct isl_tab *isl_tab_select_facet(struct isl_tab *tab, int con)
 static int may_be_equality(struct isl_tab *tab, int row)
 {
        unsigned off = 2 + tab->M;
-       return (tab->rational ? isl_int_is_zero(tab->mat->row[row][1])
-                             : isl_int_lt(tab->mat->row[row][1],
-                                           tab->mat->row[row][0])) &&
-               isl_seq_first_non_zero(tab->mat->row[row] + off + tab->n_dead,
-                                       tab->n_col - tab->n_dead) != -1;
+       return tab->rational ? isl_int_is_zero(tab->mat->row[row][1])
+                            : isl_int_lt(tab->mat->row[row][1],
+                                           tab->mat->row[row][0]);
 }
 
 /* Check for (near) equalities among the constraints.
@@ -2488,17 +2526,17 @@ static int may_be_equality(struct isl_tab *tab, int row)
  * tableau is integer), then we restrict the value to being zero
  * by adding an opposite non-negative variable.
  */
-struct isl_tab *isl_tab_detect_implicit_equalities(struct isl_tab *tab)
+int isl_tab_detect_implicit_equalities(struct isl_tab *tab)
 {
        int i;
        unsigned n_marked;
 
        if (!tab)
-               return NULL;
+               return -1;
        if (tab->empty)
-               return tab;
+               return 0;
        if (tab->n_dead == tab->n_col)
-               return tab;
+               return 0;
 
        n_marked = 0;
        for (i = tab->n_redundant; i < tab->n_row; ++i) {
@@ -2535,12 +2573,13 @@ struct isl_tab *isl_tab_detect_implicit_equalities(struct isl_tab *tab)
                n_marked--;
                sgn = sign_of_max(tab, var);
                if (sgn < 0)
-                       goto error;
+                       return -1;
                if (sgn == 0) {
                        if (close_row(tab, var) < 0)
-                               goto error;
+                               return -1;
                } else if (!tab->rational && !at_least_one(tab, var)) {
-                       tab = cut_to_hyperplane(tab, var);
+                       if (cut_to_hyperplane(tab, var) < 0)
+                               return -1;
                        return isl_tab_detect_implicit_equalities(tab);
                }
                for (i = tab->n_redundant; i < tab->n_row; ++i) {
@@ -2554,10 +2593,7 @@ struct isl_tab *isl_tab_detect_implicit_equalities(struct isl_tab *tab)
                }
        }
 
-       return tab;
-error:
-       isl_tab_free(tab);
-       return NULL;
+       return 0;
 }
 
 static int con_is_redundant(struct isl_tab *tab, struct isl_tab_var *var)
@@ -2693,6 +2729,9 @@ enum isl_lp_result isl_tab_min(struct isl_tab *tab,
        struct isl_tab_var *var;
        struct isl_tab_undo *snap;
 
+       if (!tab)
+               return isl_lp_error;
+
        if (tab->empty)
                return isl_lp_empty;
 
@@ -2817,6 +2856,7 @@ static int perform_undo_var(struct isl_tab *tab, struct isl_tab_undo *undo)
        case isl_tab_undo_redundant:
                var->is_redundant = 0;
                tab->n_redundant--;
+               restore_row(tab, isl_tab_var_from_row(tab, tab->n_redundant));
                break;
        case isl_tab_undo_freeze:
                var->frozen = 0;
@@ -2990,6 +3030,7 @@ int isl_tab_rollback(struct isl_tab *tab, struct isl_tab_undo *snap)
                if (undo == snap)
                        break;
                if (perform_undo(tab, undo) < 0) {
+                       tab->top = undo;
                        free_undo(tab);
                        tab->in_undo = 0;
                        return -1;