isl_tab: don't create new undo records during rollback
[platform/upstream/isl.git] / isl_tab.c
index b3e5618..4f2e783 100644 (file)
--- a/isl_tab.c
+++ b/isl_tab.c
@@ -42,6 +42,7 @@ struct isl_tab *isl_tab_alloc(struct isl_ctx *ctx,
        }
        tab->n_row = 0;
        tab->n_con = 0;
+       tab->n_eq = 0;
        tab->max_con = n_row;
        tab->n_col = n_var;
        tab->n_var = n_var;
@@ -50,6 +51,7 @@ struct isl_tab *isl_tab_alloc(struct isl_ctx *ctx,
        tab->need_undo = 0;
        tab->rational = 0;
        tab->empty = 0;
+       tab->in_undo = 0;
        tab->bottom.type = isl_tab_undo_bottom;
        tab->bottom.next = NULL;
        tab->top = &tab->bottom;
@@ -114,6 +116,7 @@ void isl_tab_free(struct isl_ctx *ctx, struct isl_tab *tab)
                return;
        free_undo(ctx, tab);
        isl_mat_free(ctx, tab->mat);
+       isl_vec_free(tab->dual);
        free(tab->var);
        free(tab->con);
        free(tab->row_var);
@@ -234,6 +237,9 @@ static int pivot_row(struct isl_ctx *ctx, struct isl_tab *tab,
 
 /* Find a pivot (row and col) that will increase (sgn > 0) or decrease
  * (sgn < 0) the value of row variable var.
+ * If not NULL, then skip_var is a row variable that should be ignored
+ * while looking for a pivot row.  It is usually equal to var.
+ *
  * As the given row in the tableau is of the form
  *
  *     x_r = a_r0 + \sum_i a_ri x_i
@@ -246,7 +252,8 @@ static int pivot_row(struct isl_ctx *ctx, struct isl_tab *tab,
  * opposite direction.
  */
 static void find_pivot(struct isl_ctx *ctx, struct isl_tab *tab,
-       struct isl_tab_var *var, int sgn, int *row, int *col)
+       struct isl_tab_var *var, struct isl_tab_var *skip_var,
+       int sgn, int *row, int *col)
 {
        int j, r, c;
        isl_int *tr;
@@ -270,7 +277,7 @@ static void find_pivot(struct isl_ctx *ctx, struct isl_tab *tab,
                return;
 
        sgn *= isl_int_sgn(tr[2 + c]);
-       r = pivot_row(ctx, tab, var, sgn, c);
+       r = pivot_row(ctx, tab, skip_var, sgn, c);
        *row = r < 0 ? var->index : r;
        *col = c;
 }
@@ -478,6 +485,8 @@ static void pivot(struct isl_ctx *ctx,
        var = var_from_col(ctx, tab, col);
        var->is_row = 0;
        var->index = col;
+       if (tab->in_undo)
+               return;
        for (i = tab->n_redundant; i < tab->n_row; ++i) {
                if (isl_int_is_zero(mat->row[i][2 + col]))
                        continue;
@@ -538,7 +547,7 @@ static int sign_of_max(struct isl_ctx *ctx,
                return 1;
        to_row(ctx, tab, var, 1);
        while (!isl_int_is_pos(tab->mat->row[var->index][1])) {
-               find_pivot(ctx, tab, var, 1, &row, &col);
+               find_pivot(ctx, tab, var, var, 1, &row, &col);
                if (row == -1)
                        return isl_int_sgn(tab->mat->row[var->index][1]);
                pivot(ctx, tab, row, col);
@@ -559,7 +568,7 @@ static int restore_row(struct isl_ctx *ctx,
        int row, col;
 
        while (isl_int_is_neg(tab->mat->row[var->index][1])) {
-               find_pivot(ctx, tab, var, 1, &row, &col);
+               find_pivot(ctx, tab, var, var, 1, &row, &col);
                if (row == -1)
                        break;
                pivot(ctx, tab, row, col);
@@ -580,7 +589,7 @@ static int at_least_zero(struct isl_ctx *ctx,
        int row, col;
 
        while (isl_int_is_neg(tab->mat->row[var->index][1])) {
-               find_pivot(ctx, tab, var, 1, &row, &col);
+               find_pivot(ctx, tab, var, var, 1, &row, &col);
                if (row == -1)
                        break;
                if (row == var->index) /* manifestly unbounded */
@@ -636,7 +645,7 @@ static int sign_of_min(struct isl_ctx *ctx,
        if (var->is_redundant)
                return 0;
        while (!isl_int_is_neg(tab->mat->row[var->index][1])) {
-               find_pivot(ctx, tab, var, -1, &row, &col);
+               find_pivot(ctx, tab, var, var, -1, &row, &col);
                if (row == var->index)
                        return -1;
                if (row == -1)
@@ -694,7 +703,7 @@ static int min_at_most_neg_one(struct isl_ctx *ctx,
        if (var->is_redundant)
                return 0;
        do {
-               find_pivot(ctx, tab, var, -1, &row, &col);
+               find_pivot(ctx, tab, var, var, -1, &row, &col);
                if (row == var->index)
                        return 1;
                if (row == -1)
@@ -729,7 +738,7 @@ static int at_least_one(struct isl_ctx *ctx,
        to_row(ctx, tab, var, 1);
        r = tab->mat->row[var->index];
        while (isl_int_lt(r[1], r[0])) {
-               find_pivot(ctx, tab, var, 1, &row, &col);
+               find_pivot(ctx, tab, var, var, 1, &row, &col);
                if (row == -1)
                        return isl_int_ge(r[1], r[0]);
                if (row == var->index) /* manifestly unbounded */
@@ -919,6 +928,36 @@ error:
        return NULL;
 }
 
+/* Pivot a non-negative variable down until it reaches the value zero
+ * and then pivot the variable into a column position.
+ */
+static int to_col(struct isl_ctx *ctx,
+       struct isl_tab *tab, struct isl_tab_var *var)
+{
+       int i;
+       int row, col;
+
+       if (!var->is_row)
+               return;
+
+       while (isl_int_is_pos(tab->mat->row[var->index][1])) {
+               find_pivot(ctx, tab, var, NULL, -1, &row, &col);
+               isl_assert(ctx, row != -1, return -1);
+               pivot(ctx, tab, row, col);
+               if (!var->is_row)
+                       return;
+       }
+
+       for (i = tab->n_dead; i < tab->n_col; ++i)
+               if (!isl_int_is_zero(tab->mat->row[var->index][2 + i]))
+                       break;
+
+       isl_assert(ctx, i < tab->n_col, return -1);
+       pivot(ctx, tab, var->index, i);
+
+       return 0;
+}
+
 /* We assume Gaussian elimination has been performed on the equalities.
  * The equalities can therefore never conflict.
  * Adding the equalities is currently only really useful for a later call
@@ -944,6 +983,39 @@ static struct isl_tab *add_eq(struct isl_ctx *ctx,
                kill_col(ctx, tab, i);
                break;
        }
+       tab->n_eq++;
+
+       return tab;
+error:
+       isl_tab_free(ctx, tab);
+       return NULL;
+}
+
+/* Add an equality that is known to be valid for the given tableau.
+ */
+struct isl_tab *isl_tab_add_valid_eq(struct isl_ctx *ctx,
+       struct isl_tab *tab, isl_int *eq)
+{
+       struct isl_tab_var *var;
+       int i;
+       int r;
+
+       if (!tab)
+               return NULL;
+       r = add_row(ctx, tab, eq);
+       if (r < 0)
+               goto error;
+
+       var = &tab->con[r];
+       r = var->index;
+       if (isl_int_is_neg(tab->mat->row[r][1]))
+               isl_seq_neg(tab->mat->row[r] + 1, tab->mat->row[r] + 1,
+                           1 + tab->n_col);
+       var->is_nonneg = 1;
+       if (to_col(ctx, tab, var) < 0)
+               goto error;
+       var->is_nonneg = 0;
+       kill_col(ctx, tab, var->index);
 
        return tab;
 error:
@@ -1098,6 +1170,42 @@ static struct isl_vec *extract_integer_sample(struct isl_ctx *ctx,
        return vec;
 }
 
+struct isl_vec *isl_tab_get_sample_value(struct isl_ctx *ctx,
+                                               struct isl_tab *tab)
+{
+       int i;
+       struct isl_vec *vec;
+       isl_int m;
+
+       if (!tab)
+               return NULL;
+
+       vec = isl_vec_alloc(ctx, 1 + tab->n_var);
+       if (!vec)
+               return NULL;
+
+       isl_int_init(m);
+
+       isl_int_set_si(vec->block.data[0], 1);
+       for (i = 0; i < tab->n_var; ++i) {
+               int row;
+               if (!tab->var[i].is_row) {
+                       isl_int_set_si(vec->block.data[1 + i], 0);
+                       continue;
+               }
+               row = tab->var[i].index;
+               isl_int_gcd(m, vec->block.data[0], tab->mat->row[row][0]);
+               isl_int_divexact(m, tab->mat->row[row][0], m);
+               isl_seq_scale(vec->block.data, vec->block.data, m, 1 + i);
+               isl_int_divexact(m, vec->block.data[0], tab->mat->row[row][0]);
+               isl_int_mul(vec->block.data[1 + i], m, tab->mat->row[row][1]);
+       }
+       isl_seq_normalize(vec->block.data, vec->size);
+
+       isl_int_clear(m);
+       return vec;
+}
+
 /* Update "bmap" based on the results of the tableau "tab".
  * In particular, implicit equalities are made explicit, redundant constraints
  * are removed and if the sample value happens to be integer, it is stored
@@ -1117,7 +1225,7 @@ struct isl_basic_map *isl_basic_map_update_from_tab(struct isl_basic_map *bmap,
        if (!tab)
                return bmap;
 
-       n_eq = bmap->n_eq;
+       n_eq = tab->n_eq;
        if (tab->empty)
                bmap = isl_basic_map_set_to_empty(bmap);
        else
@@ -1443,15 +1551,18 @@ int isl_tab_is_equality(struct isl_ctx *ctx, struct isl_tab *tab, int con)
  * minmimal value returned in *opt).
  */
 enum isl_lp_result isl_tab_min(struct isl_ctx *ctx, struct isl_tab *tab,
-       isl_int *f, isl_int denom, isl_int *opt, isl_int *opt_denom)
+       isl_int *f, isl_int denom, isl_int *opt, isl_int *opt_denom,
+       unsigned flags)
 {
        int r;
        enum isl_lp_result res = isl_lp_ok;
        struct isl_tab_var *var;
+       struct isl_tab_undo *snap;
 
        if (tab->empty)
                return isl_lp_empty;
 
+       snap = isl_tab_snap(ctx, tab);
        r = add_row(ctx, tab, f);
        if (r < 0)
                return isl_lp_error;
@@ -1460,7 +1571,7 @@ enum isl_lp_result isl_tab_min(struct isl_ctx *ctx, struct isl_tab *tab,
                    tab->mat->row[var->index][0], denom);
        for (;;) {
                int row, col;
-               find_pivot(ctx, tab, var, -1, &row, &col);
+               find_pivot(ctx, tab, var, var, -1, &row, &col);
                if (row == var->index) {
                        res = isl_lp_unbounded;
                        break;
@@ -1469,8 +1580,26 @@ enum isl_lp_result isl_tab_min(struct isl_ctx *ctx, struct isl_tab *tab,
                        break;
                pivot(ctx, tab, row, col);
        }
-       if (drop_row(ctx, tab, var->index) < 0)
+       if (isl_tab_rollback(ctx, tab, snap) < 0)
                return isl_lp_error;
+       if (ISL_FL_ISSET(flags, ISL_TAB_SAVE_DUAL)) {
+               int i;
+
+               isl_vec_free(tab->dual);
+               tab->dual = isl_vec_alloc(ctx, 1 + tab->n_con);
+               if (!tab->dual)
+                       return isl_lp_error;
+               isl_int_set(tab->dual->el[0], tab->mat->row[var->index][0]);
+               for (i = 0; i < tab->n_con; ++i) {
+                       if (tab->con[i].is_row)
+                               isl_int_set_si(tab->dual->el[1 + i], 0);
+                       else {
+                               int pos = 2 + tab->con[i].index;
+                               isl_int_set(tab->dual->el[1 + i],
+                                           tab->mat->row[var->index][pos]);
+                       }
+               }
+       }
        if (res == isl_lp_ok) {
                if (opt_denom) {
                        isl_int_set(*opt, tab->mat->row[var->index][1]);
@@ -1575,6 +1704,7 @@ int isl_tab_rollback(struct isl_ctx *ctx, struct isl_tab *tab,
        if (!tab)
                return -1;
 
+       tab->in_undo = 1;
        for (undo = tab->top; undo && undo != &tab->bottom; undo = next) {
                next = undo->next;
                if (undo == snap)
@@ -1582,6 +1712,7 @@ int isl_tab_rollback(struct isl_ctx *ctx, struct isl_tab *tab,
                perform_undo(ctx, tab, undo);
                free(undo);
        }
+       tab->in_undo = 0;
        tab->top = undo;
        if (!undo)
                return -1;
@@ -1629,7 +1760,7 @@ static enum isl_ineq_type separation_type(struct isl_ctx *ctx,
 /* Check the effect of inequality "ineq" on the tableau "tab".
  * The result may be
  *     isl_ineq_redundant:     satisfied by all points in the tableau
- *     isl_ineq_separate:      satisfied by no point in tha tableau
+ *     isl_ineq_separate:      satisfied by no point in the tableau
  *     isl_ineq_cut:           satisfied by some by not all points
  *     isl_ineq_adj_eq:        adjacent to an equality
  *     isl_ineq_adj_ineq:      adjacent to an inequality.