add isl_map_insert
[platform/upstream/isl.git] / isl_dim.c
1 /*
2  * Copyright 2008-2009 Katholieke Universiteit Leuven
3  *
4  * Use of this software is governed by the GNU LGPLv2.1 license
5  *
6  * Written by Sven Verdoolaege, K.U.Leuven, Departement
7  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8  */
9
10 #include <isl_dim_private.h>
11 #include "isl_name.h"
12
13 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
14                         unsigned nparam, unsigned n_in, unsigned n_out)
15 {
16         struct isl_dim *dim;
17
18         dim = isl_alloc_type(ctx, struct isl_dim);
19         if (!dim)
20                 return NULL;
21
22         dim->ctx = ctx;
23         isl_ctx_ref(ctx);
24         dim->ref = 1;
25         dim->nparam = nparam;
26         dim->n_in = n_in;
27         dim->n_out = n_out;
28
29         dim->n_name = 0;
30         dim->names = NULL;
31
32         return dim;
33 }
34
/* Allocate an isl_dim with "nparam" parameters, no input dimensions
 * and "dim" output dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
                        unsigned nparam, unsigned dim)
{
        const unsigned no_inputs = 0;

        return isl_dim_alloc(ctx, nparam, no_inputs, dim);
}
40
41 static unsigned global_pos(struct isl_dim *dim,
42                                  enum isl_dim_type type, unsigned pos)
43 {
44         struct isl_ctx *ctx = dim->ctx;
45
46         switch (type) {
47         case isl_dim_param:
48                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
49                 return pos;
50         case isl_dim_in:
51                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
52                 return pos + dim->nparam;
53         case isl_dim_out:
54                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
55                 return pos + dim->nparam + dim->n_in;
56         default:
57                 isl_assert(ctx, 0, return isl_dim_total(dim));
58         }
59         return isl_dim_total(dim);
60 }
61
62 /* Extend length of names array to the total number of dimensions.
63  */
64 static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
65 {
66         struct isl_name **names;
67         int i;
68
69         if (isl_dim_total(dim) <= dim->n_name)
70                 return dim;
71
72         if (!dim->names) {
73                 dim->names = isl_calloc_array(dim->ctx,
74                                 struct isl_name *, isl_dim_total(dim));
75                 if (!dim->names)
76                         goto error;
77         } else {
78                 names = isl_realloc_array(dim->ctx, dim->names,
79                                 struct isl_name *, isl_dim_total(dim));
80                 if (!names)
81                         goto error;
82                 dim->names = names;
83                 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
84                         dim->names[i] = NULL;
85         }
86
87         dim->n_name = isl_dim_total(dim);
88
89         return dim;
90 error:
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_dim *set_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos,
97                                  struct isl_name *name)
98 {
99         struct isl_ctx *ctx = dim->ctx;
100         dim = isl_dim_cow(dim);
101
102         if (!dim)
103                 goto error;
104
105         pos = global_pos(dim, type, pos);
106         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
107
108         if (pos >= dim->n_name) {
109                 if (!name)
110                         return dim;
111                 dim = extend_names(dim);
112                 if (!dim)
113                         goto error;
114         }
115
116         dim->names[pos] = name;
117
118         return dim;
119 error:
120         isl_name_free(ctx, name);
121         isl_dim_free(dim);
122         return NULL;
123 }
124
125 static struct isl_name *get_name(struct isl_dim *dim,
126                                  enum isl_dim_type type, unsigned pos)
127 {
128         if (!dim)
129                 return NULL;
130
131         pos = global_pos(dim, type, pos);
132         if (pos == isl_dim_total(dim))
133                 return NULL;
134         if (pos >= dim->n_name)
135                 return NULL;
136         return dim->names[pos];
137 }
138
139 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
140 {
141         switch (type) {
142         case isl_dim_param:     return 0;
143         case isl_dim_in:        return dim->nparam;
144         case isl_dim_out:       return dim->nparam + dim->n_in;
145         }
146 }
147
148 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
149 {
150         switch (type) {
151         case isl_dim_param:     return dim->nparam;
152         case isl_dim_in:        return dim->n_in;
153         case isl_dim_out:       return dim->n_out;
154         }
155 }
156
157 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
158 {
159         if (!dim)
160                 return 0;
161         return n(dim, type);
162 }
163
164 static struct isl_dim *copy_names(struct isl_dim *dst,
165         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
166         enum isl_dim_type src_type)
167 {
168         int i;
169         struct isl_name *name;
170
171         for (i = 0; i < n(src, src_type); ++i) {
172                 name = get_name(src, src_type, i);
173                 if (!name)
174                         continue;
175                 dst = set_name(dst, dst_type, offset + i,
176                                         isl_name_copy(dst->ctx, name));
177                 if (!dst)
178                         return NULL;
179         }
180         return dst;
181 }
182
183 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
184 {
185         struct isl_dim *dup;
186         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
187         if (!dim->names)
188                 return dup;
189         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
190         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
191         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
192         return dup;
193 }
194
195 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
196 {
197         if (!dim)
198                 return NULL;
199
200         if (dim->ref == 1)
201                 return dim;
202         dim->ref--;
203         return isl_dim_dup(dim);
204 }
205
206 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
207 {
208         if (!dim)
209                 return NULL;
210
211         dim->ref++;
212         return dim;
213 }
214
215 void isl_dim_free(struct isl_dim *dim)
216 {
217         int i;
218
219         if (!dim)
220                 return;
221
222         if (--dim->ref > 0)
223                 return;
224
225         for (i = 0; i < dim->n_name; ++i)
226                 isl_name_free(dim->ctx, dim->names[i]);
227         free(dim->names);
228         isl_ctx_deref(dim->ctx);
229         
230         free(dim);
231 }
232
233 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
234                                  enum isl_dim_type type, unsigned pos,
235                                  const char *s)
236 {
237         struct isl_name *name;
238         if (!dim)
239                 return NULL;
240         name = isl_name_get(dim->ctx, s);
241         if (!name)
242                 goto error;
243         return set_name(dim, type, pos, name);
244 error:
245         isl_dim_free(dim);
246         return NULL;
247 }
248
249 const char *isl_dim_get_name(struct isl_dim *dim,
250                                  enum isl_dim_type type, unsigned pos)
251 {
252         struct isl_name *name = get_name(dim, type, pos);
253         return name ? name->name : NULL;
254 }
255
256 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
257                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
258 {
259         int i;
260
261         if (n(dim1, dim1_type) != n(dim2, dim2_type))
262                 return 0;
263
264         if (!dim1->names && !dim2->names)
265                 return 1;
266
267         for (i = 0; i < n(dim1, dim1_type); ++i) {
268                 if (get_name(dim1, dim1_type, i) !=
269                     get_name(dim2, dim2_type, i))
270                         return 0;
271         }
272         return 1;
273 }
274
275 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
276                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
277 {
278         return match(dim1, dim1_type, dim2, dim2_type);
279 }
280
281 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
282         unsigned first, unsigned n, struct isl_name **names)
283 {
284         int i;
285
286         for (i = 0; i < n ; ++i)
287                 names[i] = get_name(dim, type, first+i);
288 }
289
290 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
291                         unsigned nparam, unsigned n_in, unsigned n_out)
292 {
293         struct isl_name **names = NULL;
294
295         if (!dim)
296                 return NULL;
297         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
298                 return dim;
299
300         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
301         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
302         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
303
304         dim = isl_dim_cow(dim);
305
306         if (dim->names) {
307                 names = isl_calloc_array(dim->ctx, struct isl_name *,
308                                          nparam + n_in + n_out);
309                 if (!names)
310                         goto error;
311                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
312                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
313                 get_names(dim, isl_dim_out, 0, dim->n_out,
314                                 names + nparam + n_in);
315                 free(dim->names);
316                 dim->names = names;
317                 dim->n_name = nparam + n_in + n_out;
318         }
319         dim->nparam = nparam;
320         dim->n_in = n_in;
321         dim->n_out = n_out;
322
323         return dim;
324 error:
325         free(names);
326         isl_dim_free(dim);
327         return NULL;
328 }
329
330 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
331         unsigned n)
332 {
333         switch (type) {
334         case isl_dim_param:
335                 return isl_dim_extend(dim,
336                                         dim->nparam + n, dim->n_in, dim->n_out);
337         case isl_dim_in:
338                 return isl_dim_extend(dim,
339                                         dim->nparam, dim->n_in + n, dim->n_out);
340         case isl_dim_out:
341                 return isl_dim_extend(dim,
342                                         dim->nparam, dim->n_in, dim->n_out + n);
343         }
344         return dim;
345 }
346
347 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
348         enum isl_dim_type type, unsigned pos, unsigned n)
349 {
350         struct isl_name **names = NULL;
351
352         if (!dim)
353                 return NULL;
354         if (n == 0)
355                 return dim;
356
357         isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
358
359         dim = isl_dim_cow(dim);
360
361         if (dim->names) {
362                 enum isl_dim_type t;
363                 int off;
364                 int size[3];
365                 names = isl_calloc_array(dim->ctx, struct isl_name *,
366                                      dim->nparam + dim->n_in + dim->n_out + n);
367                 if (!names)
368                         goto error;
369                 off = 0;
370                 size[isl_dim_param] = dim->nparam;
371                 size[isl_dim_in] = dim->n_in;
372                 size[isl_dim_out] = dim->n_out;
373                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
374                         if (t != type) {
375                                 get_names(dim, t, 0, size[t], names + off);
376                                 off += size[t];
377                         } else {
378                                 get_names(dim, t, 0, pos, names + off);
379                                 off += pos + n;
380                                 get_names(dim, t, pos, size[t]-pos, names+off);
381                                 off += size[t] - pos;
382                         }
383                 }
384                 free(dim->names);
385                 dim->names = names;
386                 dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
387         }
388         switch (type) {
389         case isl_dim_param:     dim->nparam += n; break;
390         case isl_dim_in:        dim->n_in += n; break;
391         case isl_dim_out:       dim->n_out += n; break;
392         }
393
394         return dim;
395 error:
396         isl_dim_free(dim);
397         return NULL;
398 }
399
400 __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
401         enum isl_dim_type dst_type, unsigned dst_pos,
402         enum isl_dim_type src_type, unsigned src_pos, unsigned n)
403 {
404         if (!dim)
405                 return NULL;
406         if (n == 0)
407                 return dim;
408
409         isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
410                 goto error);
411
412         /* just the simple case for now */
413         isl_assert(dim->ctx,
414                 offset(dim, dst_type) + dst_pos ==
415                 offset(dim, src_type) + src_pos + ((src_type < dst_type) ? n : 0),
416                 goto error);
417
418         if (dst_type == src_type)
419                 return dim;
420
421         dim = isl_dim_cow(dim);
422         if (!dim)
423                 return NULL;
424
425         switch (dst_type) {
426         case isl_dim_param:     dim->nparam += n; break;
427         case isl_dim_in:        dim->n_in += n; break;
428         case isl_dim_out:       dim->n_out += n; break;
429         }
430
431         switch (src_type) {
432         case isl_dim_param:     dim->nparam -= n; break;
433         case isl_dim_in:        dim->n_in -= n; break;
434         case isl_dim_out:       dim->n_out -= n; break;
435         }
436
437         return dim;
438 error:
439         isl_dim_free(dim);
440         return NULL;
441 }
442
443 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
444 {
445         struct isl_dim *dim;
446
447         if (!left || !right)
448                 goto error;
449
450         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
451                         goto error);
452         isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
453                         goto error);
454
455         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
456         if (!dim)
457                 goto error;
458
459         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
460         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
461         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
462
463         isl_dim_free(left);
464         isl_dim_free(right);
465
466         return dim;
467 error:
468         isl_dim_free(left);
469         isl_dim_free(right);
470         return NULL;
471 }
472
473 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
474 {
475         struct isl_dim *dim;
476
477         if (!left || !right)
478                 goto error;
479
480         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
481                         goto error);
482
483         dim = isl_dim_alloc(left->ctx, left->nparam,
484                         left->n_in + right->n_in, left->n_out + right->n_out);
485         if (!dim)
486                 goto error;
487
488         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
489         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
490         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
491         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
492         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
493
494         isl_dim_free(left);
495         isl_dim_free(right);
496
497         return dim;
498 error:
499         isl_dim_free(left);
500         isl_dim_free(right);
501         return NULL;
502 }
503
504 struct isl_dim *isl_dim_map(struct isl_dim *dim)
505 {
506         struct isl_name **names = NULL;
507
508         if (!dim)
509                 return NULL;
510         isl_assert(dim->ctx, dim->n_in == 0, goto error);
511         if (dim->n_out == 0)
512                 return dim;
513         dim = isl_dim_cow(dim);
514         if (!dim)
515                 return NULL;
516         if (dim->names) {
517                 names = isl_calloc_array(dim->ctx, struct isl_name *,
518                                         dim->nparam + dim->n_out + dim->n_out);
519                 if (!names)
520                         goto error;
521                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
522                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
523         }
524         dim->n_in = dim->n_out;
525         if (names) {
526                 free(dim->names);
527                 dim->names = names;
528                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
529                 dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
530         }
531         return dim;
532 error:
533         isl_dim_free(dim);
534         return NULL;
535 }
536
537 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
538         unsigned first, unsigned n, struct isl_name **names)
539 {
540         int i;
541
542         for (i = 0; i < n ; ++i)
543                 dim = set_name(dim, type, first+i, names[i]);
544
545         return dim;
546 }
547
548 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
549 {
550         unsigned t;
551         struct isl_name **names = NULL;
552
553         if (!dim)
554                 return NULL;
555         if (match(dim, isl_dim_in, dim, isl_dim_out))
556                 return dim;
557
558         dim = isl_dim_cow(dim);
559         if (!dim)
560                 return NULL;
561
562         if (dim->names) {
563                 names = isl_alloc_array(dim->ctx, struct isl_name *,
564                                         dim->n_in + dim->n_out);
565                 if (!names)
566                         goto error;
567                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
568                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
569         }
570
571         t = dim->n_in;
572         dim->n_in = dim->n_out;
573         dim->n_out = t;
574
575         if (dim->names) {
576                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
577                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
578                 free(names);
579         }
580
581         return dim;
582 error:
583         free(names);
584         isl_dim_free(dim);
585         return NULL;
586 }
587
588 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
589                 unsigned first, unsigned num)
590 {
591         int i;
592
593         if (!dim)
594                 return NULL;
595
596         if (n == 0)
597                 return dim;
598
599         isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
600         dim = isl_dim_cow(dim);
601         if (!dim)
602                 goto error;
603         if (dim->names) {
604                 dim = extend_names(dim);
605                 if (!dim)
606                         goto error;
607                 for (i = 0; i < num; ++i)
608                         isl_name_free(dim->ctx, get_name(dim, type, first+i));
609                 for (i = first+num; i < n(dim, type); ++i)
610                         set_name(dim, type, i - num, get_name(dim, type, i));
611                 switch (type) {
612                 case isl_dim_param:
613                         get_names(dim, isl_dim_in, 0, dim->n_in,
614                                 dim->names + offset(dim, isl_dim_in) - num);
615                 case isl_dim_in:
616                         get_names(dim, isl_dim_out, 0, dim->n_out,
617                                 dim->names + offset(dim, isl_dim_out) - num);
618                 case isl_dim_out:
619                         ;
620                 }
621                 dim->n_name -= num;
622         }
623         switch (type) {
624         case isl_dim_param:     dim->nparam -= num; break;
625         case isl_dim_in:        dim->n_in -= num; break;
626         case isl_dim_out:       dim->n_out -= num; break;
627         }
628         return dim;
629 error:
630         isl_dim_free(dim);
631         return NULL;
632 }
633
634 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
635                 unsigned first, unsigned n)
636 {
637         return isl_dim_drop(dim, isl_dim_in, first, n);
638 }
639
640 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
641                 unsigned first, unsigned n)
642 {
643         return isl_dim_drop(dim, isl_dim_out, first, n);
644 }
645
646 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
647 {
648         if (!dim)
649                 return NULL;
650         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
651         return isl_dim_reverse(dim);
652 }
653
654 struct isl_dim *isl_dim_range(struct isl_dim *dim)
655 {
656         if (!dim)
657                 return NULL;
658         return isl_dim_drop_inputs(dim, 0, dim->n_in);
659 }
660
661 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
662 {
663         int i;
664
665         if (!dim)
666                 return NULL;
667         if (n_div == 0 &&
668             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
669                 return dim;
670         dim = isl_dim_cow(dim);
671         if (!dim)
672                 return NULL;
673         dim->n_out += dim->nparam + dim->n_in + n_div;
674         dim->nparam = 0;
675         dim->n_in = 0;
676
677         for (i = 0; i < dim->n_name; ++i)
678                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
679         dim->n_name = 0;
680
681         return dim;
682 }
683
684 unsigned isl_dim_total(struct isl_dim *dim)
685 {
686         return dim->nparam + dim->n_in + dim->n_out;
687 }
688
689 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
690 {
691         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
692                n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
693                n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
694 }
695
696 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
697 {
698         return dim1->nparam == dim2->nparam &&
699                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
700 }