compute set difference using a backtracking algorithm
[platform/upstream/isl.git] / isl_map_subtract.c
1 #include "isl_seq.h"
2 #include "isl_set.h"
3 #include "isl_map.h"
4 #include "isl_map_private.h"
5 #include "isl_tab.h"
6
7 /* Add all constraints of bmap to tab.  The equalities of bmap
8  * are added as a pair of inequalities.
9  */
10 static int tab_add_constraints(struct isl_tab *tab,
11         __isl_keep isl_basic_map *bmap)
12 {
13         int i;
14         unsigned total;
15
16         if (!tab || !bmap)
17                 return -1;
18
19         total = isl_basic_map_total_dim(bmap);
20
21         if (isl_tab_extend_cons(tab, 2 * bmap->n_eq + bmap->n_ineq) < 0)
22                 return -1;
23
24         for (i = 0; i < bmap->n_eq; ++i) {
25                 if (isl_tab_add_ineq(tab, bmap->eq[i]) < 0)
26                         return -1;
27                 isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + total);
28                 if (isl_tab_add_ineq(tab, bmap->eq[i]) < 0)
29                         return -1;
30                 isl_seq_neg(bmap->eq[i], bmap->eq[i], 1 + total);
31                 if (tab->empty)
32                         return 0;
33         }
34
35         for (i = 0; i < bmap->n_ineq; ++i) {
36                 if (isl_tab_add_ineq(tab, bmap->ineq[i]) < 0)
37                         return -1;
38                 if (tab->empty)
39                         return 0;
40         }
41
42         return 0;
43 }
44
45 /* Add a specific constraint of bmap (or its opposite) to tab.
46  * The position of the constraint is specified by "c", where
47  * the equalities of bmap are counted twice, once for the inequality
48  * that is equal to the equality, and once for its negation.
49  */
50 static int tab_add_constraint(struct isl_tab *tab,
51         __isl_keep isl_basic_map *bmap, int c, int oppose)
52 {
53         unsigned total;
54         int r;
55
56         if (!tab || !bmap)
57                 return -1;
58
59         total = isl_basic_map_total_dim(bmap);
60
61         if (c < 2 * bmap->n_eq) {
62                 if ((c % 2) != oppose)
63                         isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2], 1 + total);
64                 if (oppose)
65                         isl_int_sub_ui(bmap->eq[c/2][0], bmap->eq[c/2][0], 1);
66                 r = isl_tab_add_ineq(tab, bmap->eq[c/2]);
67                 if (oppose)
68                         isl_int_add_ui(bmap->eq[c/2][0], bmap->eq[c/2][0], 1);
69                 if ((c % 2) != oppose)
70                         isl_seq_neg(bmap->eq[c/2], bmap->eq[c/2], 1 + total);
71         } else {
72                 c -= 2 * bmap->n_eq;
73                 if (oppose) {
74                         isl_seq_neg(bmap->ineq[c], bmap->ineq[c], 1 + total);
75                         isl_int_sub_ui(bmap->ineq[c][0], bmap->ineq[c][0], 1);
76                 }
77                 r = isl_tab_add_ineq(tab, bmap->ineq[c]);
78                 if (oppose) {
79                         isl_int_add_ui(bmap->ineq[c][0], bmap->ineq[c][0], 1);
80                         isl_seq_neg(bmap->ineq[c], bmap->ineq[c], 1 + total);
81                 }
82         }
83
84         return r;
85 }
86
87 /* Freeze all constraints of tableau tab.
88  */
89 static int tab_freeze_constraints(struct isl_tab *tab)
90 {
91         int i;
92
93         for (i = 0; i < tab->n_con; ++i)
94                 if (isl_tab_freeze_constraint(tab, i) < 0)
95                         return -1;
96
97         return 0;
98 }
99
100 /* Check for redundant constraints starting at offset.
101  * Put the indices of the redundant constraints in index
102  * and return the number of redundant constraints.
103  */
104 static int n_non_redundant(struct isl_tab *tab, int offset, int **index)
105 {
106         int i, n;
107         int n_test = tab->n_con - offset;
108
109         if (isl_tab_detect_redundant(tab) < 0)
110                 return -1;
111
112         if (!*index)
113                 *index = isl_alloc_array(tab->mat->ctx, int, n_test);
114         if (!*index)
115                 return -1;
116
117         for (n = 0, i = 0; i < n_test; ++i) {
118                 int r;
119                 r = isl_tab_is_redundant(tab, offset + i);
120                 if (r < 0)
121                         return -1;
122                 if (r)
123                         continue;
124                 (*index)[n++] = i;
125         }
126
127         return n;
128 }
129
130 /* basic_map_collect_diff calls add on each of the pieces of
131  * the set difference between bmap and map until the add method
132  * return a negative value.
133  */
134 struct isl_diff_collector {
135         int (*add)(struct isl_diff_collector *dc,
136                     __isl_take isl_basic_map *bmap);
137 };
138
139 /* Compute the set difference between bmap and map and call
140  * dc->add on each of the piece until this function returns
141  * a negative value.
142  * Return 0 on success and -1 on error.  dc->add returning
143  * a negative value is treated as an error, but the calling
144  * function can interpret the results based on the state of dc.
145  *
146  * Assumes that both bmap and map have known divs.
147  *
148  * The difference is computed by a backtracking algorithm.
149  * Each level corresponds to a basic map in "map".
150  * When a node in entered for the first time, we check
151  * if the corresonding basic map intersect the current piece
152  * of "bmap".  If not, we move to the next level.
153  * Otherwise, we split the current piece into as many
154  * pieces as there are non-redundant constraints of the current
155  * basic map in the intersection.  Each of these pieces is
156  * handled by a child of the current node.
157  * In particular, if there are n non-redundant constraints,
158  * then for each 0 <= i < n, a piece is cut off by adding
159  * constraints 0 <= j < i and adding the opposite of constrain i.
160  * If there are no non-redundant constraints, meaning that the current
161  * piece is a subset of the current basic map, then we simply backtrack.
162  *
163  * In the leaves, we check if the remaining piece has any integer points
164  * and if so, pass it along to dc->add.  As a special case, if nothing
165  * has been removed when we end up in a leaf, we simply pass along
166  * the original basic map.
167  */
168 static int basic_map_collect_diff(__isl_take isl_basic_map *bmap,
169         __isl_take isl_map *map, struct isl_diff_collector *dc)
170 {
171         int i;
172         int modified;
173         int level;
174         int init;
175         int empty;
176         struct isl_tab *tab = NULL;
177         struct isl_tab_undo **snap = NULL;
178         int *k = NULL;
179         int *n = NULL;
180         int **index = NULL;
181
182         empty = isl_basic_map_is_empty(bmap);
183         if (empty) {
184                 isl_basic_map_free(bmap);
185                 isl_map_free(map);
186                 return empty < 0 ? -1 : 0;
187         }
188
189         bmap = isl_basic_map_cow(bmap);
190         map = isl_map_cow(map);
191
192         if (!bmap || !map)
193                 goto error;
194
195         snap = isl_alloc_array(map->ctx, struct isl_tab_undo *, map->n);
196         k = isl_alloc_array(map->ctx, int, map->n);
197         n = isl_alloc_array(map->ctx, int, map->n);
198         index = isl_calloc_array(map->ctx, int *, map->n);
199         if (!snap || !k || !n || !index)
200                 goto error;
201
202         for (i = 0; i < map->n; ++i) {
203                 bmap = isl_basic_map_align_divs(bmap, map->p[i]);
204                 if (!bmap)
205                         goto error;
206         }
207         for (i = 0; i < map->n; ++i) {
208                 map->p[i] = isl_basic_map_align_divs(map->p[i], bmap);
209                 if (!map->p[i])
210                         goto error;
211         }
212
213         tab = isl_tab_from_basic_map(bmap);
214         if (isl_tab_track_bmap(tab, isl_basic_map_copy(bmap)) < 0)
215                 goto error;
216
217         modified = 0;
218         level = 0;
219         init = 1;
220
221         while (level >= 0) {
222                 if (level >= map->n) {
223                         int empty;
224                         struct isl_basic_map *bm;
225                         if (!modified) {
226                                 if (dc->add(dc, isl_basic_map_copy(bmap)) < 0)
227                                         goto error;
228                                 break;
229                         }
230                         bm = isl_basic_map_copy(tab->bmap);
231                         bm = isl_basic_map_cow(bm);
232                         bm = isl_basic_map_update_from_tab(bm, tab);
233                         bm = isl_basic_map_simplify(bm);
234                         bm = isl_basic_map_finalize(bm);
235                         empty = isl_basic_map_is_empty(bm);
236                         if (empty)
237                                 isl_basic_map_free(bm);
238                         else if (dc->add(dc, bm) < 0)
239                                 goto error;
240                         if (empty < 0)
241                                 goto error;
242                         level--;
243                         init = 0;
244                         continue;
245                 }
246                 if (init) {
247                         int offset = tab->n_con;
248                         snap[level] = isl_tab_snap(tab);
249                         if (tab_freeze_constraints(tab) < 0)
250                                 goto error;
251                         if (tab_add_constraints(tab, map->p[level]) < 0)
252                                 goto error;
253                         k[level] = 0;
254                         n[level] = 0;
255                         if (tab->empty) {
256                                 if (isl_tab_rollback(tab, snap[level]) < 0)
257                                         goto error;
258                                 level++;
259                                 continue;
260                         }
261                         modified = 1;
262                         n[level] = n_non_redundant(tab, offset, &index[level]);
263                         if (n[level] < 0)
264                                 goto error;
265                         if (n[level] == 0) {
266                                 level--;
267                                 init = 0;
268                                 continue;
269                         }
270                         if (isl_tab_rollback(tab, snap[level]) < 0)
271                                 goto error;
272                         if (tab_add_constraint(tab, map->p[level],
273                                                 index[level][0], 1) < 0)
274                                 goto error;
275                         level++;
276                         continue;
277                 } else {
278                         if (k[level] + 1 >= n[level]) {
279                                 level--;
280                                 continue;
281                         }
282                         if (isl_tab_rollback(tab, snap[level]) < 0)
283                                 goto error;
284                         if (tab_add_constraint(tab, map->p[level],
285                                                 index[level][k[level]], 0) < 0)
286                                 goto error;
287                         snap[level] = isl_tab_snap(tab);
288                         k[level]++;
289                         if (tab_add_constraint(tab, map->p[level],
290                                                 index[level][k[level]], 1) < 0)
291                                 goto error;
292                         level++;
293                         init = 1;
294                         continue;
295                 }
296         }
297
298         isl_tab_free(tab);
299         free(snap);
300         free(n);
301         free(k);
302         for (i = 0; index && i < map->n; ++i)
303                 free(index[i]);
304         free(index);
305
306         isl_basic_map_free(bmap);
307         isl_map_free(map);
308
309         return 0;
310 error:
311         isl_tab_free(tab);
312         free(snap);
313         free(n);
314         free(k);
315         for (i = 0; index && i < map->n; ++i)
316                 free(index[i]);
317         free(index);
318         isl_basic_map_free(bmap);
319         isl_map_free(map);
320         return -1;
321 }
322
323 /* A diff collector that actually collects all parts of the
324  * set difference in the field diff.
325  */
326 struct isl_subtract_diff_collector {
327         struct isl_diff_collector dc;
328         struct isl_map *diff;
329 };
330
331 /* isl_subtract_diff_collector callback.
332  */
333 int basic_map_subtract_add(struct isl_diff_collector *dc,
334                             __isl_take isl_basic_map *bmap)
335 {
336         struct isl_subtract_diff_collector *sdc;
337         sdc = (struct isl_subtract_diff_collector *)dc;
338
339         sdc->diff = isl_map_union_disjoint(sdc->diff,
340                         isl_map_from_basic_map(bmap));
341
342         return sdc->diff ? 0 : -1;
343 }
344
345 /* Return the set difference between bmap and map.
346  */
347 static __isl_give isl_map *basic_map_subtract(__isl_take isl_basic_map *bmap,
348         __isl_take isl_map *map)
349 {
350         struct isl_subtract_diff_collector sdc;
351         sdc.dc.add = &basic_map_subtract_add;
352         sdc.diff = isl_map_empty_like_basic_map(bmap);
353         if (basic_map_collect_diff(bmap, map, &sdc.dc) < 0) {
354                 isl_map_free(sdc.diff);
355                 sdc.diff = NULL;
356         }
357         return sdc.diff;
358 }
359
360 /* Return the set difference between map1 and map2.
361  * (U_i A_i) \ (U_j B_j) is computed as U_i (A_i \ (U_j B_j))
362  */
363 struct isl_map *isl_map_subtract(struct isl_map *map1, struct isl_map *map2)
364 {
365         int i;
366         struct isl_map *diff;
367
368         if (!map1 || !map2)
369                 goto error;
370
371         isl_assert(map1->ctx, isl_dim_equal(map1->dim, map2->dim), goto error);
372
373         if (isl_map_is_empty(map2)) {
374                 isl_map_free(map2);
375                 return map1;
376         }
377
378         map1 = isl_map_compute_divs(map1);
379         map2 = isl_map_compute_divs(map2);
380         if (!map1 || !map2)
381                 goto error;
382
383         map1 = isl_map_remove_empty_parts(map1);
384         map2 = isl_map_remove_empty_parts(map2);
385
386         diff = isl_map_empty_like(map1);
387         for (i = 0; i < map1->n; ++i) {
388                 struct isl_map *d;
389                 d = basic_map_subtract(isl_basic_map_copy(map1->p[i]),
390                                        isl_map_copy(map2));
391                 if (ISL_F_ISSET(map1, ISL_MAP_DISJOINT))
392                         diff = isl_map_union_disjoint(diff, d);
393                 else
394                         diff = isl_map_union(diff, d);
395         }
396
397         isl_map_free(map1);
398         isl_map_free(map2);
399
400         return diff;
401 error:
402         isl_map_free(map1);
403         isl_map_free(map2);
404         return NULL;
405 }
406
407 struct isl_set *isl_set_subtract(struct isl_set *set1, struct isl_set *set2)
408 {
409         return (struct isl_set *)
410                 isl_map_subtract(
411                         (struct isl_map *)set1, (struct isl_map *)set2);
412 }
413
414 /* Return 1 if "bmap" contains a single element.
415  */
416 int isl_basic_map_is_singleton(__isl_keep isl_basic_map *bmap)
417 {
418         if (!bmap)
419                 return -1;
420         if (bmap->n_div)
421                 return 0;
422         if (bmap->n_ineq)
423                 return 0;
424         return bmap->n_eq == isl_basic_map_total_dim(bmap);
425 }
426
427 /* Return 1 if "map" contains a single element.
428  */
429 int isl_map_is_singleton(__isl_keep isl_map *map)
430 {
431         if (!map)
432                 return -1;
433         if (map->n != 1)
434                 return 0;
435
436         return isl_basic_map_is_singleton(map->p[0]);
437 }
438
439 /* Given a singleton basic map, extract the single element
440  * as an isl_vec.
441  */
442 static __isl_give isl_vec *singleton_extract_point(__isl_keep isl_basic_map *bmap)
443 {
444         int i, j;
445         unsigned dim;
446         struct isl_vec *point;
447         isl_int m;
448
449         if (!bmap)
450                 return NULL;
451
452         dim = isl_basic_map_total_dim(bmap);
453         isl_assert(bmap->ctx, bmap->n_eq == dim, return NULL);
454         point = isl_vec_alloc(bmap->ctx, 1 + dim);
455         if (!point)
456                 return NULL;
457
458         isl_int_init(m);
459
460         isl_int_set_si(point->el[0], 1);
461         for (j = 0; j < bmap->n_eq; ++j) {
462                 int s;
463                 int i = dim - 1 - j;
464                 isl_assert(bmap->ctx,
465                     isl_seq_first_non_zero(bmap->eq[j] + 1, i) == -1,
466                     goto error);
467                 isl_assert(bmap->ctx,
468                     isl_int_is_one(bmap->eq[j][1 + i]) ||
469                     isl_int_is_negone(bmap->eq[j][1 + i]),
470                     goto error);
471                 isl_assert(bmap->ctx,
472                     isl_seq_first_non_zero(bmap->eq[j]+1+i+1, dim-i-1) == -1,
473                     goto error);
474
475                 isl_int_gcd(m, point->el[0], bmap->eq[j][1 + i]);
476                 isl_int_divexact(m, bmap->eq[j][1 + i], m);
477                 isl_int_abs(m, m);
478                 isl_seq_scale(point->el, point->el, m, 1 + i);
479                 isl_int_divexact(m, point->el[0], bmap->eq[j][1 + i]);
480                 isl_int_neg(m, m);
481                 isl_int_mul(point->el[1 + i], m, bmap->eq[j][0]);
482         }
483
484         isl_int_clear(m);
485         return point;
486 error:
487         isl_int_clear(m);
488         isl_vec_free(point);
489         return NULL;
490 }
491
492 /* Return 1 if "bmap" contains the point "point".
493  * "bmap" is assumed to have known divs.
494  * The point is first extended with the divs and then passed
495  * to basic_map_contains.
496  */
497 static int basic_map_contains_point(__isl_keep isl_basic_map *bmap,
498         __isl_keep isl_vec *point)
499 {
500         int i;
501         struct isl_vec *vec;
502         unsigned dim;
503         int contains;
504
505         if (!bmap || !point)
506                 return -1;
507         if (bmap->n_div == 0)
508                 return isl_basic_map_contains(bmap, point);
509
510         dim = isl_basic_map_total_dim(bmap) - bmap->n_div;
511         vec = isl_vec_alloc(bmap->ctx, 1 + dim + bmap->n_div);
512         if (!vec)
513                 return -1;
514
515         isl_seq_cpy(vec->el, point->el, point->size);
516         for (i = 0; i < bmap->n_div; ++i) {
517                 isl_seq_inner_product(bmap->div[i] + 1, vec->el,
518                                         1 + dim + i, &vec->el[1+dim+i]);
519                 isl_int_fdiv_q(vec->el[1+dim+i], vec->el[1+dim+i],
520                                 bmap->div[i][0]);
521         }
522
523         contains = isl_basic_map_contains(bmap, vec);
524
525         isl_vec_free(vec);
526         return contains;
527 }
528
529 /* Return 1 is the singleton map "map1" is a subset of "map2",
530  * i.e., if the single element of "map1" is also an element of "map2".
531  */
532 static int map_is_singleton_subset(__isl_keep isl_map *map1,
533         __isl_keep isl_map *map2)
534 {
535         int i;
536         int is_subset = 0;
537         struct isl_vec *point;
538
539         if (!map1 || !map2)
540                 return -1;
541         if (map1->n != 1)
542                 return -1;
543
544         point = singleton_extract_point(map1->p[0]);
545         if (!point)
546                 return -1;
547
548         for (i = 0; i < map2->n; ++i) {
549                 is_subset = basic_map_contains_point(map2->p[i], point);
550                 if (is_subset)
551                         break;
552         }
553
554         isl_vec_free(point);
555         return is_subset;
556 }
557
558 int isl_map_is_subset(struct isl_map *map1, struct isl_map *map2)
559 {
560         int is_subset = 0;
561         struct isl_map *diff;
562
563         if (!map1 || !map2)
564                 return -1;
565
566         if (isl_map_is_empty(map1))
567                 return 1;
568
569         if (isl_map_is_empty(map2))
570                 return 0;
571
572         if (isl_map_fast_is_universe(map2))
573                 return 1;
574
575         map2 = isl_map_compute_divs(isl_map_copy(map2));
576         if (isl_map_is_singleton(map1)) {
577                 is_subset = map_is_singleton_subset(map1, map2);
578                 isl_map_free(map2);
579                 return is_subset;
580         }
581         diff = isl_map_subtract(isl_map_copy(map1), map2);
582         if (!diff)
583                 return -1;
584
585         is_subset = isl_map_is_empty(diff);
586         isl_map_free(diff);
587
588         return is_subset;
589 }
590
591 int isl_set_is_subset(struct isl_set *set1, struct isl_set *set2)
592 {
593         return isl_map_is_subset(
594                         (struct isl_map *)set1, (struct isl_map *)set2);
595 }