add isl_set_complement
[platform/upstream/isl.git] / isl_map_subtract.c
1 /*
2  * Copyright 2008-2009 Katholieke Universiteit Leuven
3  *
4  * Use of this software is governed by the GNU LGPLv2.1 license
5  *
6  * Written by Sven Verdoolaege, K.U.Leuven, Departement
7  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8  */
9
10 #include "isl_seq.h"
11 #include "isl_set.h"
12 #include "isl_map.h"
13 #include "isl_map_private.h"
14 #include "isl_tab.h"
15
16 /* Add all constraints of bmap to tab.  The equalities of bmap
17  * are added as a pair of inequalities.
18  */
19 static int tab_add_constraints(struct isl_tab *tab,
20         __isl_keep isl_basic_map *bmap)
21 {
22         int i;
23         unsigned total;
24
25         if (!tab || !bmap)
26                 return -1;
27
28         total = isl_basic_map_total_dim(bmap);
29
30         if (isl_tab_extend_cons(tab, 2 * bmap->n_eq + bmap->n_ineq) < 0)
31                 return -1;
32
33         for (i = 0; i < bmap->n_eq; ++i) {
34                 if (isl_tab_add_ineq(tab, bmap->eq[i]) < 0)
35                         return -1;
36                 isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + total);
37                 if (isl_tab_add_ineq(tab, bmap->eq[i]) < 0)
38                         return -1;
39                 isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + total);
40                 if (tab->empty)
41                         return 0;
42         }
43
44         for (i = 0; i < bmap->n_ineq; ++i) {
45                 if (isl_tab_add_ineq(tab, bmap->ineq[i]) < 0)
46                         return -1;
47                 if (tab->empty)
48                         return 0;
49         }
50
51         return 0;
52 }
53
54 /* Add a specific constraint of bmap (or its opposite) to tab.
55  * The position of the constraint is specified by "c", where
56  * the equalities of bmap are counted twice, once for the inequality
57  * that is equal to the equality, and once for its negation.
58  */
59 static int tab_add_constraint(struct isl_tab *tab,
60         __isl_keep isl_basic_map *bmap, int c, int oppose)
61 {
62         unsigned total;
63         int r;
64
65         if (!tab || !bmap)
66                 return -1;
67
68         total = isl_basic_map_total_dim(bmap);
69
70         if (c < 2 * bmap->n_eq) {
71                 if ((c % 2) != oppose)
72                         isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2], 1 + total);
73                 if (oppose)
74                         isl_int_sub_ui(bmap->eq[c/2][0], bmap->eq[c/2][0], 1);
75                 r = isl_tab_add_ineq(tab, bmap->eq[c/2]);
76                 if (oppose)
77                         isl_int_add_ui(bmap->eq[c/2][0], bmap->eq[c/2][0], 1);
78                 if ((c % 2) != oppose)
79                         isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2], 1 + total);
80         } else {
81                 c -= 2 * bmap->n_eq;
82                 if (oppose) {
83                         isl_seq_neg(bmap->ineq[c], bmap->ineq[c], 1 + total);
84                         isl_int_sub_ui(bmap->ineq[c][0], bmap->ineq[c][0], 1);
85                 }
86                 r = isl_tab_add_ineq(tab, bmap->ineq[c]);
87                 if (oppose) {
88                         isl_int_add_ui(bmap->ineq[c][0], bmap->ineq[c][0], 1);
89                         isl_seq_neg(bmap->ineq[c], bmap->ineq[c], 1 + total);
90                 }
91         }
92
93         return r;
94 }
95
96 /* Freeze all constraints of tableau tab.
97  */
98 static int tab_freeze_constraints(struct isl_tab *tab)
99 {
100         int i;
101
102         for (i = 0; i < tab->n_con; ++i)
103                 if (isl_tab_freeze_constraint(tab, i) < 0)
104                         return -1;
105
106         return 0;
107 }
108
109 /* Check for redundant constraints starting at offset.
110  * Put the indices of the redundant constraints in index
111  * and return the number of redundant constraints.
112  */
113 static int n_non_redundant(struct isl_tab *tab, int offset, int **index)
114 {
115         int i, n;
116         int n_test = tab->n_con - offset;
117
118         if (isl_tab_detect_redundant(tab) < 0)
119                 return -1;
120
121         if (!*index)
122                 *index = isl_alloc_array(tab->mat->ctx, int, n_test);
123         if (!*index)
124                 return -1;
125
126         for (n = 0, i = 0; i < n_test; ++i) {
127                 int r;
128                 r = isl_tab_is_redundant(tab, offset + i);
129                 if (r < 0)
130                         return -1;
131                 if (r)
132                         continue;
133                 (*index)[n++] = i;
134         }
135
136         return n;
137 }
138
139 /* basic_map_collect_diff calls add on each of the pieces of
140  * the set difference between bmap and map until the add method
141  * return a negative value.
142  */
143 struct isl_diff_collector {
144         int (*add)(struct isl_diff_collector *dc,
145                     __isl_take isl_basic_map *bmap);
146 };
147
148 /* Compute the set difference between bmap and map and call
149  * dc->add on each of the piece until this function returns
150  * a negative value.
151  * Return 0 on success and -1 on error.  dc->add returning
152  * a negative value is treated as an error, but the calling
153  * function can interpret the results based on the state of dc.
154  *
155  * Assumes that both bmap and map have known divs.
156  *
157  * The difference is computed by a backtracking algorithm.
158  * Each level corresponds to a basic map in "map".
159  * When a node in entered for the first time, we check
160  * if the corresonding basic map intersect the current piece
161  * of "bmap".  If not, we move to the next level.
162  * Otherwise, we split the current piece into as many
163  * pieces as there are non-redundant constraints of the current
164  * basic map in the intersection.  Each of these pieces is
165  * handled by a child of the current node.
166  * In particular, if there are n non-redundant constraints,
167  * then for each 0 <= i < n, a piece is cut off by adding
168  * constraints 0 <= j < i and adding the opposite of constrain i.
169  * If there are no non-redundant constraints, meaning that the current
170  * piece is a subset of the current basic map, then we simply backtrack.
171  *
172  * In the leaves, we check if the remaining piece has any integer points
173  * and if so, pass it along to dc->add.  As a special case, if nothing
174  * has been removed when we end up in a leaf, we simply pass along
175  * the original basic map.
176  */
177 static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
178         __isl_take isl_map *map, struct isl_diff_collector *dc)
179 {
180         int i;
181         int modified;
182         int level;
183         int init;
184         int empty;
185         struct isl_tab *tab = NULL;
186         struct isl_tab_undo **snap = NULL;
187         int *k = NULL;
188         int *n = NULL;
189         int **index = NULL;
190
191         empty = isl_basic_map_is_empty(bmap);
192         if (empty) {
193                 isl_basic_map_free(bmap);
194                 isl_map_free(map);
195                 return empty < 0 ? -1 : 0;
196         }
197
198         bmap = isl_basic_map_cow(bmap);
199         map = isl_map_cow(map);
200
201         if (!bmap || !map)
202                 goto error;
203
204         snap = isl_alloc_array(map->ctx, struct isl_tab_undo *, map->n);
205         k = isl_alloc_array(map->ctx, int, map->n);
206         n = isl_alloc_array(map->ctx, int, map->n);
207         index = isl_calloc_array(map->ctx, int *, map->n);
208         if (!snap || !k || !n || !index)
209                 goto error;
210
211         for (i = 0; i < map->n; ++i) {
212                 bmap = isl_basic_map_align_divs(bmap, map->p[i]);
213                 if (!bmap)
214                         goto error;
215         }
216         for (i = 0; i < map->n; ++i) {
217                 map->p[i] = isl_basic_map_align_divs(map->p[i], bmap);
218                 if (!map->p[i])
219                         goto error;
220         }
221
222         tab = isl_tab_from_basic_map(bmap);
223         if (isl_tab_track_bmap(tab, isl_basic_map_copy(bmap)) < 0)
224                 goto error;
225
226         modified = 0;
227         level = 0;
228         init = 1;
229
230         while (level >= 0) {
231                 if (level >= map->n) {
232                         int empty;
233                         struct isl_basic_map *bm;
234                         if (!modified) {
235                                 if (dc->add(dc, isl_basic_map_copy(bmap)) < 0)
236                                         goto error;
237                                 break;
238                         }
239                         bm = isl_basic_map_copy(tab->bmap);
240                         bm = isl_basic_map_cow(bm);
241                         bm = isl_basic_map_update_from_tab(bm, tab);
242                         bm = isl_basic_map_simplify(bm);
243                         bm = isl_basic_map_finalize(bm);
244                         empty = isl_basic_map_is_empty(bm);
245                         if (empty)
246                                 isl_basic_map_free(bm);
247                         else if (dc->add(dc, bm) < 0)
248                                 goto error;
249                         if (empty < 0)
250                                 goto error;
251                         level--;
252                         init = 0;
253                         continue;
254                 }
255                 if (init) {
256                         int offset = tab->n_con;
257                         snap[level] = isl_tab_snap(tab);
258                         if (tab_freeze_constraints(tab) < 0)
259                                 goto error;
260                         if (tab_add_constraints(tab, map->p[level]) < 0)
261                                 goto error;
262                         k[level] = 0;
263                         n[level] = 0;
264                         if (tab->empty) {
265                                 if (isl_tab_rollback(tab, snap[level]) < 0)
266                                         goto error;
267                                 level++;
268                                 continue;
269                         }
270                         modified = 1;
271                         n[level] = n_non_redundant(tab, offset, &index[level]);
272                         if (n[level] < 0)
273                                 goto error;
274                         if (n[level] == 0) {
275                                 level--;
276                                 init = 0;
277                                 continue;
278                         }
279                         if (isl_tab_rollback(tab, snap[level]) < 0)
280                                 goto error;
281                         if (tab_add_constraint(tab, map->p[level],
282                                                 index[level][0], 1) < 0)
283                                 goto error;
284                         level++;
285                         continue;
286                 } else {
287                         if (k[level] + 1 >= n[level]) {
288                                 level--;
289                                 continue;
290                         }
291                         if (isl_tab_rollback(tab, snap[level]) < 0)
292                                 goto error;
293                         if (tab_add_constraint(tab, map->p[level],
294                                                 index[level][k[level]], 0) < 0)
295                                 goto error;
296                         snap[level] = isl_tab_snap(tab);
297                         k[level]++;
298                         if (tab_add_constraint(tab, map->p[level],
299                                                 index[level][k[level]], 1) < 0)
300                                 goto error;
301                         level++;
302                         init = 1;
303                         continue;
304                 }
305         }
306
307         isl_tab_free(tab);
308         free(snap);
309         free(n);
310         free(k);
311         for (i = 0; index && i < map->n; ++i)
312                 free(index[i]);
313         free(index);
314
315         isl_basic_map_free(bmap);
316         isl_map_free(map);
317
318         return 0;
319 error:
320         isl_tab_free(tab);
321         free(snap);
322         free(n);
323         free(k);
324         for (i = 0; index && i < map->n; ++i)
325                 free(index[i]);
326         free(index);
327         isl_basic_map_free(bmap);
328         isl_map_free(map);
329         return -1;
330 }
331
332 /* A diff collector that actually collects all parts of the
333  * set difference in the field diff.
334  */
335 struct isl_subtract_diff_collector {
336         struct isl_diff_collector dc;
337         struct isl_map *diff;
338 };
339
340 /* isl_subtract_diff_collector callback.
341  */
342 static int basic_map_subtract_add(struct isl_diff_collector *dc,
343                             __isl_take isl_basic_map *bmap)
344 {
345         struct isl_subtract_diff_collector *sdc;
346         sdc = (struct isl_subtract_diff_collector *)dc;
347
348         sdc->diff = isl_map_union_disjoint(sdc->diff,
349                         isl_map_from_basic_map(bmap));
350
351         return sdc->diff ? 0 : -1;
352 }
353
354 /* Return the set difference between bmap and map.
355  */
356 static __isl_give isl_map *basic_map_subtract(__isl_take isl_basic_map *bmap,
357         __isl_take isl_map *map)
358 {
359         struct isl_subtract_diff_collector sdc;
360         sdc.dc.add = &basic_map_subtract_add;
361         sdc.diff = isl_map_empty_like_basic_map(bmap);
362         if (basic_map_collect_diff(bmap, map, &sdc.dc) < 0) {
363                 isl_map_free(sdc.diff);
364                 sdc.diff = NULL;
365         }
366         return sdc.diff;
367 }
368
369 /* Return the set difference between map1 and map2.
370  * (U_i A_i) \ (U_j B_j) is computed as U_i (A_i \ (U_j B_j))
371  */
372 struct isl_map *isl_map_subtract(struct isl_map *map1, struct isl_map *map2)
373 {
374         int i;
375         struct isl_map *diff;
376
377         if (!map1 || !map2)
378                 goto error;
379
380         isl_assert(map1->ctx, isl_dim_equal(map1->dim, map2->dim), goto error);
381
382         if (isl_map_is_empty(map2)) {
383                 isl_map_free(map2);
384                 return map1;
385         }
386
387         map1 = isl_map_compute_divs(map1);
388         map2 = isl_map_compute_divs(map2);
389         if (!map1 || !map2)
390                 goto error;
391
392         map1 = isl_map_remove_empty_parts(map1);
393         map2 = isl_map_remove_empty_parts(map2);
394
395         diff = isl_map_empty_like(map1);
396         for (i = 0; i < map1->n; ++i) {
397                 struct isl_map *d;
398                 d = basic_map_subtract(isl_basic_map_copy(map1->p[i]),
399                                        isl_map_copy(map2));
400                 if (ISL_F_ISSET(map1, ISL_MAP_DISJOINT))
401                         diff = isl_map_union_disjoint(diff, d);
402                 else
403                         diff = isl_map_union(diff, d);
404         }
405
406         isl_map_free(map1);
407         isl_map_free(map2);
408
409         return diff;
410 error:
411         isl_map_free(map1);
412         isl_map_free(map2);
413         return NULL;
414 }
415
416 struct isl_set *isl_set_subtract(struct isl_set *set1, struct isl_set *set2)
417 {
418         return (struct isl_set *)
419                 isl_map_subtract(
420                         (struct isl_map *)set1, (struct isl_map *)set2);
421 }
422
423 /* A diff collector that aborts as soon as its add function is called,
424  * setting empty to 0.
425  */
426 struct isl_is_empty_diff_collector {
427         struct isl_diff_collector dc;
428         int empty;
429 };
430
431 /* isl_is_empty_diff_collector callback.
432  */
433 static int basic_map_is_empty_add(struct isl_diff_collector *dc,
434                             __isl_take isl_basic_map *bmap)
435 {
436         struct isl_is_empty_diff_collector *edc;
437         edc = (struct isl_is_empty_diff_collector *)dc;
438
439         edc->empty = 0;
440
441         isl_basic_map_free(bmap);
442         return -1;
443 }
444
445 /* Check if bmap \ map is empty by computing this set difference
446  * and breaking off as soon as the difference is known to be non-empty.
447  */
448 static int basic_map_diff_is_empty(__isl_keep isl_basic_map *bmap,
449         __isl_keep isl_map *map)
450 {
451         int r;
452         struct isl_is_empty_diff_collector edc;
453
454         r = isl_basic_map_fast_is_empty(bmap);
455         if (r)
456                 return r;
457
458         edc.dc.add = &basic_map_is_empty_add;
459         edc.empty = 1;
460         r = basic_map_collect_diff(isl_basic_map_copy(bmap),
461                                    isl_map_copy(map), &edc.dc);
462         if (!edc.empty)
463                 return 0;
464
465         return r < 0 ? -1 : 1;
466 }
467
468 /* Check if map1 \ map2 is empty by checking if the set difference is empty
469  * for each of the basic maps in map1.
470  */
471 static int map_diff_is_empty(__isl_keep isl_map *map1, __isl_keep isl_map *map2)
472 {
473         int i;
474         int is_empty = 1;
475
476         if (!map1 || !map2)
477                 return -1;
478         
479         for (i = 0; i < map1->n; ++i) {
480                 is_empty = basic_map_diff_is_empty(map1->p[i], map2);
481                 if (is_empty < 0 || !is_empty)
482                          break;
483         }
484
485         return is_empty;
486 }
487
488 /* Return 1 if "bmap" contains a single element.
489  */
490 int isl_basic_map_is_singleton(__isl_keep isl_basic_map *bmap)
491 {
492         if (!bmap)
493                 return -1;
494         if (bmap->n_div)
495                 return 0;
496         if (bmap->n_ineq)
497                 return 0;
498         return bmap->n_eq == isl_basic_map_total_dim(bmap);
499 }
500
501 /* Return 1 if "map" contains a single element.
502  */
503 int isl_map_is_singleton(__isl_keep isl_map *map)
504 {
505         if (!map)
506                 return -1;
507         if (map->n != 1)
508                 return 0;
509
510         return isl_basic_map_is_singleton(map->p[0]);
511 }
512
513 /* Given a singleton basic map, extract the single element
514  * as an isl_vec.
515  */
516 static __isl_give isl_vec *singleton_extract_point(__isl_keep isl_basic_map *bmap)
517 {
518         int i, j;
519         unsigned dim;
520         struct isl_vec *point;
521         isl_int m;
522
523         if (!bmap)
524                 return NULL;
525
526         dim = isl_basic_map_total_dim(bmap);
527         isl_assert(bmap->ctx, bmap->n_eq == dim, return NULL);
528         point = isl_vec_alloc(bmap->ctx, 1 + dim);
529         if (!point)
530                 return NULL;
531
532         isl_int_init(m);
533
534         isl_int_set_si(point->el[0], 1);
535         for (j = 0; j < bmap->n_eq; ++j) {
536                 int s;
537                 int i = dim - 1 - j;
538                 isl_assert(bmap->ctx,
539                     isl_seq_first_non_zero(bmap->eq[j] + 1, i) == -1,
540                     goto error);
541                 isl_assert(bmap->ctx,
542                     isl_int_is_one(bmap->eq[j][1 + i]) ||
543                     isl_int_is_negone(bmap->eq[j][1 + i]),
544                     goto error);
545                 isl_assert(bmap->ctx,
546                     isl_seq_first_non_zero(bmap->eq[j]+1+i+1, dim-i-1) == -1,
547                     goto error);
548
549                 isl_int_gcd(m, point->el[0], bmap->eq[j][1 + i]);
550                 isl_int_divexact(m, bmap->eq[j][1 + i], m);
551                 isl_int_abs(m, m);
552                 isl_seq_scale(point->el, point->el, m, 1 + i);
553                 isl_int_divexact(m, point->el[0], bmap->eq[j][1 + i]);
554                 isl_int_neg(m, m);
555                 isl_int_mul(point->el[1 + i], m, bmap->eq[j][0]);
556         }
557
558         isl_int_clear(m);
559         return point;
560 error:
561         isl_int_clear(m);
562         isl_vec_free(point);
563         return NULL;
564 }
565
566 /* Return 1 if "bmap" contains the point "point".
567  * "bmap" is assumed to have known divs.
568  * The point is first extended with the divs and then passed
569  * to basic_map_contains.
570  */
571 static int basic_map_contains_point(__isl_keep isl_basic_map *bmap,
572         __isl_keep isl_vec *point)
573 {
574         int i;
575         struct isl_vec *vec;
576         unsigned dim;
577         int contains;
578
579         if (!bmap || !point)
580                 return -1;
581         if (bmap->n_div == 0)
582                 return isl_basic_map_contains(bmap, point);
583
584         dim = isl_basic_map_total_dim(bmap) - bmap->n_div;
585         vec = isl_vec_alloc(bmap->ctx, 1 + dim + bmap->n_div);
586         if (!vec)
587                 return -1;
588
589         isl_seq_cpy(vec->el, point->el, point->size);
590         for (i = 0; i < bmap->n_div; ++i) {
591                 isl_seq_inner_product(bmap->div[i] + 1, vec->el,
592                                         1 + dim + i, &vec->el[1+dim+i]);
593                 isl_int_fdiv_q(vec->el[1+dim+i], vec->el[1+dim+i],
594                                 bmap->div[i][0]);
595         }
596
597         contains = isl_basic_map_contains(bmap, vec);
598
599         isl_vec_free(vec);
600         return contains;
601 }
602
603 /* Return 1 is the singleton map "map1" is a subset of "map2",
604  * i.e., if the single element of "map1" is also an element of "map2".
605  */
606 static int map_is_singleton_subset(__isl_keep isl_map *map1,
607         __isl_keep isl_map *map2)
608 {
609         int i;
610         int is_subset = 0;
611         struct isl_vec *point;
612
613         if (!map1 || !map2)
614                 return -1;
615         if (map1->n != 1)
616                 return -1;
617
618         point = singleton_extract_point(map1->p[0]);
619         if (!point)
620                 return -1;
621
622         for (i = 0; i < map2->n; ++i) {
623                 is_subset = basic_map_contains_point(map2->p[i], point);
624                 if (is_subset)
625                         break;
626         }
627
628         isl_vec_free(point);
629         return is_subset;
630 }
631
632 int isl_map_is_subset(struct isl_map *map1, struct isl_map *map2)
633 {
634         int is_subset = 0;
635         struct isl_map *diff;
636
637         if (!map1 || !map2)
638                 return -1;
639
640         if (isl_map_is_empty(map1))
641                 return 1;
642
643         if (isl_map_is_empty(map2))
644                 return 0;
645
646         if (isl_map_fast_is_universe(map2))
647                 return 1;
648
649         map1 = isl_map_compute_divs(isl_map_copy(map1));
650         map2 = isl_map_compute_divs(isl_map_copy(map2));
651         if (isl_map_is_singleton(map1)) {
652                 is_subset = map_is_singleton_subset(map1, map2);
653                 isl_map_free(map1);
654                 isl_map_free(map2);
655                 return is_subset;
656         }
657         is_subset = map_diff_is_empty(map1, map2);
658         isl_map_free(map1);
659         isl_map_free(map2);
660
661         return is_subset;
662 }
663
664 int isl_set_is_subset(struct isl_set *set1, struct isl_set *set2)
665 {
666         return isl_map_is_subset(
667                         (struct isl_map *)set1, (struct isl_map *)set2);
668 }
669
670 __isl_give isl_map *isl_map_make_disjoint(__isl_take isl_map *map)
671 {
672         int i;
673         struct isl_subtract_diff_collector sdc;
674         sdc.dc.add = &basic_map_subtract_add;
675
676         if (!map)
677                 return NULL;
678         if (ISL_F_ISSET(map, ISL_MAP_DISJOINT))
679                 return map;
680         if (map->n <= 1)
681                 return map;
682
683         map = isl_map_compute_divs(map);
684         map = isl_map_remove_empty_parts(map);
685
686         if (!map || map->n <= 1)
687                 return map;
688
689         sdc.diff = isl_map_from_basic_map(isl_basic_map_copy(map->p[0]));
690
691         for (i = 1; i < map->n; ++i) {
692                 struct isl_basic_map *bmap = isl_basic_map_copy(map->p[i]);
693                 struct isl_map *copy = isl_map_copy(sdc.diff);
694                 if (basic_map_collect_diff(bmap, copy, &sdc.dc) < 0) {
695                         isl_map_free(sdc.diff);
696                         sdc.diff = NULL;
697                         break;
698                 }
699         }
700
701         isl_map_free(map);
702
703         return sdc.diff;
704 }
705
706 __isl_give isl_set *isl_set_make_disjoint(__isl_take isl_set *set)
707 {
708         return (struct isl_set *)isl_map_make_disjoint((struct isl_map *)set);
709 }
710
711 __isl_give isl_set *isl_set_complement(__isl_take isl_set *set)
712 {
713         isl_set *universe;
714
715         if (!set)
716                 return NULL;
717
718         universe = isl_set_universe(isl_set_get_dim(set));
719
720         return isl_set_subtract(universe, set);
721 }