isl_basic_set_solve_ilp: fix handling of sets with equalities
[platform/upstream/isl.git] / isl_ilp.c
index 67e734f..15110c1 100644 (file)
--- a/isl_ilp.c
+++ b/isl_ilp.c
@@ -110,16 +110,10 @@ error:
        return NULL;
 }
 
-/* Find an integer point in "bset" that minimizes f (if any).
- * If sol_p is not NULL then the integer point is returned in *sol_p.
- * The optimal value of f is returned in *opt.
- *
- * The algorithm maintains a currently best solution and an interval [l, u]
- * of values of f for which integer solutions could potentially still be found.
- * The initial value of the best solution so far is any solution.
- * The initial value of l is minimal value of f over the rationals
- * (rounded up to the nearest integer).
- * The initial value of u is the value of f at the current solution minus 1.
+/* Find an integer point in "bset" that minimizes f (in any) such that
+ * the value of f lies inside the interval [l, u].
+ * Return this integer point if it can be found.
+ * Otherwise, return sol.
  *
  * We perform a number of steps until l > u.
  * In each step, we look for an integer point with value in either
@@ -132,14 +126,72 @@ error:
  * If no point can be found, we update l to the upper bound of the interval
  * we checked (u or l+floor(u-l-1/2)) plus 1.
  */
+static struct isl_vec *solve_ilp_search(struct isl_basic_set *bset,
+       isl_int *f, isl_int *opt, struct isl_vec *sol, isl_int l, isl_int u)
+{
+       isl_int tmp;
+       int divide = 1;
+
+       isl_int_init(tmp);
+
+       while (isl_int_le(l, u)) {
+               struct isl_basic_set *slice;
+               struct isl_vec *sample;
+
+               if (!divide)
+                       isl_int_set(tmp, u);
+               else {
+                       isl_int_sub(tmp, u, l);
+                       isl_int_fdiv_q_ui(tmp, tmp, 2);
+                       isl_int_add(tmp, tmp, l);
+               }
+               slice = add_bounds(isl_basic_set_copy(bset), f, l, tmp);
+               sample = isl_basic_set_sample_vec(slice);
+               if (!sample) {
+                       isl_vec_free(sol);
+                       sol = NULL;
+                       break;
+               }
+               if (sample->size > 0) {
+                       isl_vec_free(sol);
+                       sol = sample;
+                       isl_seq_inner_product(f, sol->el, sol->size, opt);
+                       isl_int_sub_ui(u, *opt, 1);
+                       divide = 1;
+               } else {
+                       isl_vec_free(sample);
+                       if (!divide)
+                               break;
+                       isl_int_add_ui(l, tmp, 1);
+                       divide = 0;
+               }
+       }
+
+       isl_int_clear(tmp);
+
+       return sol;
+}
+
+/* Find an integer point in "bset" that minimizes f (if any).
+ * If sol_p is not NULL then the integer point is returned in *sol_p.
+ * The optimal value of f is returned in *opt.
+ *
+ * The algorithm maintains a currently best solution and an interval [l, u]
+ * of values of f for which integer solutions could potentially still be found.
+ * The initial value of the best solution so far is any solution.
+ * The initial value of l is minimal value of f over the rationals
+ * (rounded up to the nearest integer).
+ * The initial value of u is the value of f at the initial solution minus 1.
+ *
+ * We then call solve_ilp_search to perform a binary search on the interval.
+ */
 static enum isl_lp_result solve_ilp(struct isl_basic_set *bset,
                                      isl_int *f, isl_int *opt,
                                      struct isl_vec **sol_p)
 {
        enum isl_lp_result res;
-       isl_int l, u, tmp;
+       isl_int l, u;
        struct isl_vec *sol;
-       int divide = 1;
 
        res = isl_basic_set_solve_lp(bset, 0, f, bset->ctx->one,
                                        opt, NULL, &sol);
@@ -168,50 +220,18 @@ static enum isl_lp_result solve_ilp(struct isl_basic_set *bset,
 
        isl_int_init(l);
        isl_int_init(u);
-       isl_int_init(tmp);
 
        isl_int_set(l, *opt);
 
        isl_seq_inner_product(f, sol->el, sol->size, opt);
        isl_int_sub_ui(u, *opt, 1);
 
-       while (isl_int_le(l, u)) {
-               struct isl_basic_set *slice;
-               struct isl_vec *sample;
-
-               if (!divide)
-                       isl_int_set(tmp, u);
-               else {
-                       isl_int_sub(tmp, u, l);
-                       isl_int_fdiv_q_ui(tmp, tmp, 2);
-                       isl_int_add(tmp, tmp, l);
-               }
-               slice = add_bounds(isl_basic_set_copy(bset), f, l, tmp);
-               sample = isl_basic_set_sample_vec(slice);
-               if (!sample) {
-                       isl_vec_free(sol);
-                       sol = NULL;
-                       res = isl_lp_error;
-                       break;
-               }
-               if (sample->size > 0) {
-                       isl_vec_free(sol);
-                       sol = sample;
-                       isl_seq_inner_product(f, sol->el, sol->size, opt);
-                       isl_int_sub_ui(u, *opt, 1);
-                       divide = 1;
-               } else {
-                       isl_vec_free(sample);
-                       if (!divide)
-                               break;
-                       isl_int_add_ui(l, tmp, 1);
-                       divide = 0;
-               }
-       }
+       sol = solve_ilp_search(bset, f, opt, sol, l, u);
+       if (!sol)
+               res = isl_lp_error;
 
        isl_int_clear(l);
        isl_int_clear(u);
-       isl_int_clear(tmp);
 
        if (sol_p)
                *sol_p = sol;
@@ -230,6 +250,7 @@ static enum isl_lp_result solve_ilp_with_eq(struct isl_basic_set *bset, int max,
        struct isl_mat *T = NULL;
        struct isl_vec *v;
 
+       bset = isl_basic_set_copy(bset);
        dim = isl_basic_set_total_dim(bset);
        v = isl_vec_alloc(bset->ctx, 1 + dim);
        if (!v)
@@ -241,12 +262,13 @@ static enum isl_lp_result solve_ilp_with_eq(struct isl_basic_set *bset, int max,
                goto error;
        res = isl_basic_set_solve_ilp(bset, max, v->el, opt, sol_p);
        isl_vec_free(v);
-       if (res == isl_lp_ok && *sol_p) {
+       if (res == isl_lp_ok && sol_p) {
                *sol_p = isl_mat_vec_product(T, *sol_p);
                if (!*sol_p)
                        res = isl_lp_error;
        } else
                isl_mat_free(T);
+       isl_basic_set_free(bset);
        return res;
 error:
        isl_mat_free(T);