isl_map_subtract: only add divs when needed
authorSven Verdoolaege <skimo@kotnet.org>
Thu, 4 Mar 2010 09:08:47 +0000 (10:08 +0100)
committerSven Verdoolaege <skimo@kotnet.org>
Thu, 4 Mar 2010 14:33:12 +0000 (15:33 +0100)
Before, the divs from all basic maps were aligne with each
other, resulting in all divs being added to all basic maps,
even if the arguments of the subtract operation are disjoint.
Now, divs are only added at the point where they are needed
and removed again if it turns out that they are not needed
in the difference.

isl_map_subtract.c

index 650c196..5a383fd 100644 (file)
 #include "isl_map_private.h"
 #include "isl_tab.h"
 
+static void expand_constraint(isl_vec *v, unsigned dim,
+       isl_int *c, int *div_map, unsigned n_div)
+{
+       int i;
+
+       isl_seq_cpy(v->el, c, 1 + dim);
+       isl_seq_clr(v->el + 1 + dim, v->size - (1 + dim));
+
+       for (i = 0; i < n_div; ++i)
+               isl_int_set(v->el[1 + dim + div_map[i]], c[1 + dim + i]);
+}
+
 /* Add all constraints of bmap to tab.  The equalities of bmap
  * are added as a pair of inequalities.
  */
 static int tab_add_constraints(struct isl_tab *tab,
-       __isl_keep isl_basic_map *bmap)
+       __isl_keep isl_basic_map *bmap, int *div_map)
 {
        int i;
-       unsigned total;
+       unsigned dim;
+       unsigned tab_total;
+       unsigned bmap_total;
+       isl_vec *v;
 
        if (!tab || !bmap)
                return -1;
 
-       total = isl_basic_map_total_dim(bmap);
+       tab_total = isl_basic_map_total_dim(tab->bmap);
+       bmap_total = isl_basic_map_total_dim(bmap);
+       dim = isl_dim_total(tab->bmap->dim);
 
        if (isl_tab_extend_cons(tab, 2 * bmap->n_eq + bmap->n_ineq) < 0)
                return -1;
 
+       v = isl_vec_alloc(bmap->ctx, 1 + tab_total);
+
        for (i = 0; i < bmap->n_eq; ++i) {
-               if (isl_tab_add_ineq(tab, bmap->eq[i]) < 0)
-                       return -1;
-               isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + total);
-               if (isl_tab_add_ineq(tab, bmap->eq[i]) < 0)
-                       return -1;
-               isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + total);
+               expand_constraint(v, dim, bmap->eq[i], div_map, bmap->n_div);
+               if (isl_tab_add_ineq(tab, v->el) < 0)
+                       goto error;
+               isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + bmap_total);
+               expand_constraint(v, dim, bmap->eq[i], div_map, bmap->n_div);
+               if (isl_tab_add_ineq(tab, v->el) < 0)
+                       goto error;
+               isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + bmap_total);
                if (tab->empty)
-                       return 0;
+                       break;
        }
 
        for (i = 0; i < bmap->n_ineq; ++i) {
-               if (isl_tab_add_ineq(tab, bmap->ineq[i]) < 0)
-                       return -1;
+               expand_constraint(v, dim, bmap->ineq[i], div_map, bmap->n_div);
+               if (isl_tab_add_ineq(tab, v->el) < 0)
+                       goto error;
                if (tab->empty)
-                       return 0;
+                       break;
        }
 
+       isl_vec_free(v);
        return 0;
+error:
+       isl_vec_free(v);
+       return -1;
 }
 
 /* Add a specific constraint of bmap (or its opposite) to tab.
@@ -57,42 +83,109 @@ static int tab_add_constraints(struct isl_tab *tab,
  * that is equal to the equality, and once for its negation.
  */
 static int tab_add_constraint(struct isl_tab *tab,
-       __isl_keep isl_basic_map *bmap, int c, int oppose)
+       __isl_keep isl_basic_map *bmap, int *div_map, int c, int oppose)
 {
-       unsigned total;
+       unsigned dim;
+       unsigned tab_total;
+       unsigned bmap_total;
+       isl_vec *v;
        int r;
 
        if (!tab || !bmap)
                return -1;
 
-       total = isl_basic_map_total_dim(bmap);
+       tab_total = isl_basic_map_total_dim(tab->bmap);
+       bmap_total = isl_basic_map_total_dim(bmap);
+       dim = isl_dim_total(tab->bmap->dim);
+
+       v = isl_vec_alloc(bmap->ctx, 1 + tab_total);
+       if (!v)
+               return -1;
 
        if (c < 2 * bmap->n_eq) {
                if ((c % 2) != oppose)
-                       isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2], 1 + total);
+                       isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2],
+                                       1 + bmap_total);
                if (oppose)
                        isl_int_sub_ui(bmap->eq[c/2][0], bmap->eq[c/2][0], 1);
-               r = isl_tab_add_ineq(tab, bmap->eq[c/2]);
+               expand_constraint(v, dim, bmap->eq[c/2], div_map, bmap->n_div);
+               r = isl_tab_add_ineq(tab, v->el);
                if (oppose)
                        isl_int_add_ui(bmap->eq[c/2][0], bmap->eq[c/2][0], 1);
                if ((c % 2) != oppose)
-                       isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2], 1 + total);
+                       isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2],
+                                       1 + bmap_total);
        } else {
                c -= 2 * bmap->n_eq;
                if (oppose) {
-                       isl_seq_neg(bmap->ineq[c], bmap->ineq[c], 1 + total);
+                       isl_seq_neg(bmap->ineq[c], bmap->ineq[c],
+                                       1 + bmap_total);
                        isl_int_sub_ui(bmap->ineq[c][0], bmap->ineq[c][0], 1);
                }
-               r = isl_tab_add_ineq(tab, bmap->ineq[c]);
+               expand_constraint(v, dim, bmap->ineq[c], div_map, bmap->n_div);
+               r = isl_tab_add_ineq(tab, v->el);
                if (oppose) {
                        isl_int_add_ui(bmap->ineq[c][0], bmap->ineq[c][0], 1);
-                       isl_seq_neg(bmap->ineq[c], bmap->ineq[c], 1 + total);
+                       isl_seq_neg(bmap->ineq[c], bmap->ineq[c],
+                                       1 + bmap_total);
                }
        }
 
+       isl_vec_free(v);
        return r;
 }
 
+static int tab_add_divs(struct isl_tab *tab, __isl_keep isl_basic_map *bmap,
+       int **div_map)
+{
+       int i, j;
+       struct isl_vec *vec;
+       unsigned total;
+       unsigned dim;
+
+       if (!bmap)
+               return -1;
+       if (!bmap->n_div)
+               return 0;
+
+       if (!*div_map)
+               *div_map = isl_alloc_array(bmap->ctx, int, bmap->n_div);
+       if (!*div_map)
+               return -1;
+
+       total = isl_basic_map_total_dim(tab->bmap);
+       dim = total - tab->bmap->n_div;
+       vec = isl_vec_alloc(bmap->ctx, 2 + total + bmap->n_div);
+       if (!vec)
+               return -1;
+
+       for (i = 0; i < bmap->n_div; ++i) {
+               isl_seq_cpy(vec->el, bmap->div[i], 2 + dim);
+               isl_seq_clr(vec->el + 2 + dim, tab->bmap->n_div);
+               for (j = 0; j < i; ++j)
+                       isl_int_set(vec->el[2 + dim + (*div_map)[j]],
+                                       bmap->div[i][2 + dim + j]);
+               for (j = 0; j < tab->bmap->n_div; ++j)
+                       if (isl_seq_eq(tab->bmap->div[j],
+                                       vec->el, 2 + dim + tab->bmap->n_div))
+                               break;
+               (*div_map)[i] = j;
+               if (j == tab->bmap->n_div) {
+                       vec->size = 2 + dim + tab->bmap->n_div;
+                       if (isl_tab_add_div(tab, vec, NULL, NULL) < 0)
+                               goto error;
+               }
+       }
+
+       isl_vec_free(vec);
+
+       return 0;
+error:
+       isl_vec_free(vec);
+
+       return -1;
+}
+
 /* Freeze all constraints of tableau tab.
  */
 static int tab_freeze_constraints(struct isl_tab *tab)
@@ -157,7 +250,7 @@ struct isl_diff_collector {
  * The difference is computed by a backtracking algorithm.
  * Each level corresponds to a basic map in "map".
  * When a node in entered for the first time, we check
- * if the corresonding basic map intersect the current piece
+ * if the corresonding basic map intersects the current piece
  * of "bmap".  If not, we move to the next level.
  * Otherwise, we split the current piece into as many
  * pieces as there are non-redundant constraints of the current
@@ -165,7 +258,7 @@ struct isl_diff_collector {
  * handled by a child of the current node.
  * In particular, if there are n non-redundant constraints,
  * then for each 0 <= i < n, a piece is cut off by adding
- * constraints 0 <= j < i and adding the opposite of constrain i.
+ * constraints 0 <= j < i and adding the opposite of constraint i.
  * If there are no non-redundant constraints, meaning that the current
  * piece is a subset of the current basic map, then we simply backtrack.
  *
@@ -187,6 +280,7 @@ static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
        int *k = NULL;
        int *n = NULL;
        int **index = NULL;
+       int **div_map = NULL;
 
        empty = isl_basic_map_is_empty(bmap);
        if (empty) {
@@ -205,19 +299,12 @@ static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
        k = isl_alloc_array(map->ctx, int, map->n);
        n = isl_alloc_array(map->ctx, int, map->n);
        index = isl_calloc_array(map->ctx, int *, map->n);
-       if (!snap || !k || !n || !index)
+       div_map = isl_calloc_array(map->ctx, int *, map->n);
+       if (!snap || !k || !n || !index || !div_map)
                goto error;
 
-       for (i = 0; i < map->n; ++i) {
-               bmap = isl_basic_map_align_divs(bmap, map->p[i]);
-               if (!bmap)
-                       goto error;
-       }
-       for (i = 0; i < map->n; ++i) {
-               map->p[i] = isl_basic_map_align_divs(map->p[i], bmap);
-               if (!map->p[i])
-                       goto error;
-       }
+       bmap = isl_basic_map_order_divs(bmap);
+       map = isl_map_order_divs(map);
 
        tab = isl_tab_from_basic_map(bmap);
        if (isl_tab_track_bmap(tab, isl_basic_map_copy(bmap)) < 0)
@@ -253,16 +340,23 @@ static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
                        continue;
                }
                if (init) {
-                       int offset = tab->n_con;
+                       int offset;
+                       struct isl_tab_undo *snap2;
+                       snap2 = isl_tab_snap(tab);
+                       if (tab_add_divs(tab, map->p[level],
+                                        &div_map[level]) < 0)
+                               goto error;
+                       offset = tab->n_con;
                        snap[level] = isl_tab_snap(tab);
                        if (tab_freeze_constraints(tab) < 0)
                                goto error;
-                       if (tab_add_constraints(tab, map->p[level]) < 0)
+                       if (tab_add_constraints(tab, map->p[level],
+                                               div_map[level]) < 0)
                                goto error;
                        k[level] = 0;
                        n[level] = 0;
                        if (tab->empty) {
-                               if (isl_tab_rollback(tab, snap[level]) < 0)
+                               if (isl_tab_rollback(tab, snap2) < 0)
                                        goto error;
                                level++;
                                continue;
@@ -279,7 +373,7 @@ static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
                        if (isl_tab_rollback(tab, snap[level]) < 0)
                                goto error;
                        if (tab_add_constraint(tab, map->p[level],
-                                               index[level][0], 1) < 0)
+                                       div_map[level], index[level][0], 1) < 0)
                                goto error;
                        level++;
                        continue;
@@ -291,11 +385,13 @@ static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
                        if (isl_tab_rollback(tab, snap[level]) < 0)
                                goto error;
                        if (tab_add_constraint(tab, map->p[level],
+                                               div_map[level],
                                                index[level][k[level]], 0) < 0)
                                goto error;
                        snap[level] = isl_tab_snap(tab);
                        k[level]++;
                        if (tab_add_constraint(tab, map->p[level],
+                                               div_map[level],
                                                index[level][k[level]], 1) < 0)
                                goto error;
                        level++;
@@ -311,6 +407,9 @@ static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
        for (i = 0; index && i < map->n; ++i)
                free(index[i]);
        free(index);
+       for (i = 0; div_map && i < map->n; ++i)
+               free(div_map[i]);
+       free(div_map);
 
        isl_basic_map_free(bmap);
        isl_map_free(map);
@@ -324,6 +423,9 @@ error:
        for (i = 0; index && i < map->n; ++i)
                free(index[i]);
        free(index);
+       for (i = 0; div_map && i < map->n; ++i)
+               free(div_map[i]);
+       free(div_map);
        isl_basic_map_free(bmap);
        isl_map_free(map);
        return -1;