isl_dim_dup: check input dim
[platform/upstream/isl.git] / isl_dim.c
1 /*
2  * Copyright 2008-2009 Katholieke Universiteit Leuven
3  *
4  * Use of this software is governed by the GNU LGPLv2.1 license
5  *
6  * Written by Sven Verdoolaege, K.U.Leuven, Departement
7  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8  */
9
10 #include <isl_dim.h>
11 #include "isl_name.h"
12
13 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
14                         unsigned nparam, unsigned n_in, unsigned n_out)
15 {
16         struct isl_dim *dim;
17
18         dim = isl_alloc_type(ctx, struct isl_dim);
19         if (!dim)
20                 return NULL;
21
22         dim->ctx = ctx;
23         isl_ctx_ref(ctx);
24         dim->ref = 1;
25         dim->nparam = nparam;
26         dim->n_in = n_in;
27         dim->n_out = n_out;
28
29         dim->n_name = 0;
30         dim->names = NULL;
31
32         return dim;
33 }
34
/* Allocate a dim describing a set: "nparam" parameters, no input
 * dimensions and "dim" set dimensions (stored as output dimensions).
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	const unsigned no_inputs = 0;

	return isl_dim_alloc(ctx, nparam, no_inputs, dim);
}
40
41 static unsigned global_pos(struct isl_dim *dim,
42                                  enum isl_dim_type type, unsigned pos)
43 {
44         struct isl_ctx *ctx = dim->ctx;
45
46         switch (type) {
47         case isl_dim_param:
48                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
49                 return pos;
50         case isl_dim_in:
51                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
52                 return pos + dim->nparam;
53         case isl_dim_out:
54                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
55                 return pos + dim->nparam + dim->n_in;
56         default:
57                 isl_assert(ctx, 0, return isl_dim_total(dim));
58         }
59         return isl_dim_total(dim);
60 }
61
62 /* Extend length of names array to the total number of dimensions.
63  */
64 static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
65 {
66         struct isl_name **names;
67         int i;
68
69         if (isl_dim_total(dim) <= dim->n_name)
70                 return dim;
71
72         if (!dim->names) {
73                 dim->names = isl_calloc_array(dim->ctx,
74                                 struct isl_name *, isl_dim_total(dim));
75                 if (!dim->names)
76                         goto error;
77         } else {
78                 names = isl_realloc_array(dim->ctx, dim->names,
79                                 struct isl_name *, isl_dim_total(dim));
80                 if (!names)
81                         goto error;
82                 dim->names = names;
83                 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
84                         dim->names[i] = NULL;
85         }
86
87         dim->n_name = isl_dim_total(dim);
88
89         return dim;
90 error:
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_dim *set_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos,
97                                  struct isl_name *name)
98 {
99         struct isl_ctx *ctx = dim->ctx;
100         dim = isl_dim_cow(dim);
101
102         if (!dim)
103                 goto error;
104
105         pos = global_pos(dim, type, pos);
106         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
107
108         if (pos >= dim->n_name) {
109                 if (!name)
110                         return dim;
111                 dim = extend_names(dim);
112                 if (!dim)
113                         goto error;
114         }
115
116         dim->names[pos] = name;
117
118         return dim;
119 error:
120         isl_name_free(ctx, name);
121         isl_dim_free(dim);
122         return NULL;
123 }
124
125 static struct isl_name *get_name(struct isl_dim *dim,
126                                  enum isl_dim_type type, unsigned pos)
127 {
128         if (!dim)
129                 return NULL;
130
131         pos = global_pos(dim, type, pos);
132         if (pos == isl_dim_total(dim))
133                 return NULL;
134         if (pos >= dim->n_name)
135                 return NULL;
136         return dim->names[pos];
137 }
138
139 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
140 {
141         switch (type) {
142         case isl_dim_param:     return 0;
143         case isl_dim_in:        return dim->nparam;
144         case isl_dim_out:       return dim->nparam + dim->n_in;
145         default:                return 0;
146         }
147 }
148
149 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
150 {
151         switch (type) {
152         case isl_dim_param:     return dim->nparam;
153         case isl_dim_in:        return dim->n_in;
154         case isl_dim_out:       return dim->n_out;
155         default:                return 0;
156         }
157 }
158
159 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
160 {
161         if (!dim)
162                 return 0;
163         return n(dim, type);
164 }
165
166 unsigned isl_dim_offset(__isl_keep isl_dim *dim, enum isl_dim_type type)
167 {
168         if (!dim)
169                 return 0;
170         return offset(dim, type);
171 }
172
173 static struct isl_dim *copy_names(struct isl_dim *dst,
174         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
175         enum isl_dim_type src_type)
176 {
177         int i;
178         struct isl_name *name;
179
180         for (i = 0; i < n(src, src_type); ++i) {
181                 name = get_name(src, src_type, i);
182                 if (!name)
183                         continue;
184                 dst = set_name(dst, dst_type, offset + i,
185                                         isl_name_copy(dst->ctx, name));
186                 if (!dst)
187                         return NULL;
188         }
189         return dst;
190 }
191
192 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
193 {
194         struct isl_dim *dup;
195         if (!dim)
196                 return NULL;
197         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
198         if (!dim->names)
199                 return dup;
200         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
201         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
202         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
203         return dup;
204 }
205
206 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
207 {
208         if (!dim)
209                 return NULL;
210
211         if (dim->ref == 1)
212                 return dim;
213         dim->ref--;
214         return isl_dim_dup(dim);
215 }
216
217 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
218 {
219         if (!dim)
220                 return NULL;
221
222         dim->ref++;
223         return dim;
224 }
225
226 void isl_dim_free(struct isl_dim *dim)
227 {
228         int i;
229
230         if (!dim)
231                 return;
232
233         if (--dim->ref > 0)
234                 return;
235
236         for (i = 0; i < dim->n_name; ++i)
237                 isl_name_free(dim->ctx, dim->names[i]);
238         free(dim->names);
239         isl_ctx_deref(dim->ctx);
240         
241         free(dim);
242 }
243
244 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
245                                  enum isl_dim_type type, unsigned pos,
246                                  const char *s)
247 {
248         struct isl_name *name;
249         if (!dim)
250                 return NULL;
251         name = isl_name_get(dim->ctx, s);
252         if (!name)
253                 goto error;
254         return set_name(dim, type, pos, name);
255 error:
256         isl_dim_free(dim);
257         return NULL;
258 }
259
260 const char *isl_dim_get_name(struct isl_dim *dim,
261                                  enum isl_dim_type type, unsigned pos)
262 {
263         struct isl_name *name = get_name(dim, type, pos);
264         return name ? name->name : NULL;
265 }
266
267 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
268                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
269 {
270         int i;
271
272         if (n(dim1, dim1_type) != n(dim2, dim2_type))
273                 return 0;
274
275         if (!dim1->names && !dim2->names)
276                 return 1;
277
278         for (i = 0; i < n(dim1, dim1_type); ++i) {
279                 if (get_name(dim1, dim1_type, i) !=
280                     get_name(dim2, dim2_type, i))
281                         return 0;
282         }
283         return 1;
284 }
285
286 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
287                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
288 {
289         return match(dim1, dim1_type, dim2, dim2_type);
290 }
291
292 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
293         unsigned first, unsigned n, struct isl_name **names)
294 {
295         int i;
296
297         for (i = 0; i < n ; ++i)
298                 names[i] = get_name(dim, type, first+i);
299 }
300
301 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
302                         unsigned nparam, unsigned n_in, unsigned n_out)
303 {
304         struct isl_name **names = NULL;
305
306         if (!dim)
307                 return NULL;
308         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
309                 return dim;
310
311         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
312         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
313         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
314
315         dim = isl_dim_cow(dim);
316
317         if (dim->names) {
318                 names = isl_calloc_array(dim->ctx, struct isl_name *,
319                                          nparam + n_in + n_out);
320                 if (!names)
321                         goto error;
322                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
323                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
324                 get_names(dim, isl_dim_out, 0, dim->n_out,
325                                 names + nparam + n_in);
326                 free(dim->names);
327                 dim->names = names;
328                 dim->n_name = nparam + n_in + n_out;
329         }
330         dim->nparam = nparam;
331         dim->n_in = n_in;
332         dim->n_out = n_out;
333
334         return dim;
335 error:
336         free(names);
337         isl_dim_free(dim);
338         return NULL;
339 }
340
341 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
342         unsigned n)
343 {
344         switch (type) {
345         case isl_dim_param:
346                 return isl_dim_extend(dim,
347                                         dim->nparam + n, dim->n_in, dim->n_out);
348         case isl_dim_in:
349                 return isl_dim_extend(dim,
350                                         dim->nparam, dim->n_in + n, dim->n_out);
351         case isl_dim_out:
352                 return isl_dim_extend(dim,
353                                         dim->nparam, dim->n_in, dim->n_out + n);
354         }
355         return dim;
356 }
357
358 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
359         enum isl_dim_type type, unsigned pos, unsigned n)
360 {
361         struct isl_name **names = NULL;
362
363         if (!dim)
364                 return NULL;
365         if (n == 0)
366                 return dim;
367
368         isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
369
370         dim = isl_dim_cow(dim);
371         if (!dim)
372                 return NULL;
373
374         if (dim->names) {
375                 enum isl_dim_type t;
376                 int off;
377                 int size[3];
378                 names = isl_calloc_array(dim->ctx, struct isl_name *,
379                                      dim->nparam + dim->n_in + dim->n_out + n);
380                 if (!names)
381                         goto error;
382                 off = 0;
383                 size[isl_dim_param] = dim->nparam;
384                 size[isl_dim_in] = dim->n_in;
385                 size[isl_dim_out] = dim->n_out;
386                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
387                         if (t != type) {
388                                 get_names(dim, t, 0, size[t], names + off);
389                                 off += size[t];
390                         } else {
391                                 get_names(dim, t, 0, pos, names + off);
392                                 off += pos + n;
393                                 get_names(dim, t, pos, size[t]-pos, names+off);
394                                 off += size[t] - pos;
395                         }
396                 }
397                 free(dim->names);
398                 dim->names = names;
399                 dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
400         }
401         switch (type) {
402         case isl_dim_param:     dim->nparam += n; break;
403         case isl_dim_in:        dim->n_in += n; break;
404         case isl_dim_out:       dim->n_out += n; break;
405         }
406
407         return dim;
408 error:
409         isl_dim_free(dim);
410         return NULL;
411 }
412
413 __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
414         enum isl_dim_type dst_type, unsigned dst_pos,
415         enum isl_dim_type src_type, unsigned src_pos, unsigned n)
416 {
417         if (!dim)
418                 return NULL;
419         if (n == 0)
420                 return dim;
421
422         isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
423                 goto error);
424
425         if (dst_type == src_type && dst_pos == src_pos)
426                 return dim;
427
428         isl_assert(dim->ctx, dst_type != src_type, goto error);
429
430         dim = isl_dim_cow(dim);
431         if (!dim)
432                 return NULL;
433
434         if (dim->names) {
435                 struct isl_name **names;
436                 enum isl_dim_type t;
437                 int off;
438                 int size[3];
439                 names = isl_calloc_array(dim->ctx, struct isl_name *,
440                                          dim->nparam + dim->n_in + dim->n_out);
441                 if (!names)
442                         goto error;
443                 off = 0;
444                 size[isl_dim_param] = dim->nparam;
445                 size[isl_dim_in] = dim->n_in;
446                 size[isl_dim_out] = dim->n_out;
447                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
448                         if (t == dst_type) {
449                                 get_names(dim, t, 0, dst_pos, names + off);
450                                 off += dst_pos;
451                                 get_names(dim, src_type, src_pos, n, names+off);
452                                 off += n;
453                                 get_names(dim, t, dst_pos, size[t] - dst_pos,
454                                                 names + off);
455                                 off += size[t] - dst_pos;
456                         } else if (t == src_type) {
457                                 get_names(dim, t, 0, src_pos, names + off);
458                                 off += src_pos;
459                                 get_names(dim, t, src_pos + n,
460                                             size[t] - src_pos - n, names + off);
461                                 off += size[t] - src_pos - n;
462                         } else {
463                                 get_names(dim, t, 0, size[t], names + off);
464                                 off += size[t];
465                         }
466                 }
467                 free(dim->names);
468                 dim->names = names;
469                 dim->n_name = dim->nparam + dim->n_in + dim->n_out;
470         }
471
472         switch (dst_type) {
473         case isl_dim_param:     dim->nparam += n; break;
474         case isl_dim_in:        dim->n_in += n; break;
475         case isl_dim_out:       dim->n_out += n; break;
476         }
477
478         switch (src_type) {
479         case isl_dim_param:     dim->nparam -= n; break;
480         case isl_dim_in:        dim->n_in -= n; break;
481         case isl_dim_out:       dim->n_out -= n; break;
482         }
483
484         return dim;
485 error:
486         isl_dim_free(dim);
487         return NULL;
488 }
489
490 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
491 {
492         struct isl_dim *dim;
493
494         if (!left || !right)
495                 goto error;
496
497         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
498                         goto error);
499         isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
500                         goto error);
501
502         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
503         if (!dim)
504                 goto error;
505
506         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
507         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
508         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
509
510         isl_dim_free(left);
511         isl_dim_free(right);
512
513         return dim;
514 error:
515         isl_dim_free(left);
516         isl_dim_free(right);
517         return NULL;
518 }
519
520 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
521 {
522         struct isl_dim *dim;
523
524         if (!left || !right)
525                 goto error;
526
527         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
528                         goto error);
529
530         dim = isl_dim_alloc(left->ctx, left->nparam,
531                         left->n_in + right->n_in, left->n_out + right->n_out);
532         if (!dim)
533                 goto error;
534
535         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
536         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
537         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
538         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
539         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
540
541         isl_dim_free(left);
542         isl_dim_free(right);
543
544         return dim;
545 error:
546         isl_dim_free(left);
547         isl_dim_free(right);
548         return NULL;
549 }
550
551 struct isl_dim *isl_dim_map(struct isl_dim *dim)
552 {
553         struct isl_name **names = NULL;
554
555         if (!dim)
556                 return NULL;
557         isl_assert(dim->ctx, dim->n_in == 0, goto error);
558         if (dim->n_out == 0)
559                 return dim;
560         dim = isl_dim_cow(dim);
561         if (!dim)
562                 return NULL;
563         if (dim->names) {
564                 names = isl_calloc_array(dim->ctx, struct isl_name *,
565                                         dim->nparam + dim->n_out + dim->n_out);
566                 if (!names)
567                         goto error;
568                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
569                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
570         }
571         dim->n_in = dim->n_out;
572         if (names) {
573                 free(dim->names);
574                 dim->names = names;
575                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
576                 dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
577         }
578         return dim;
579 error:
580         isl_dim_free(dim);
581         return NULL;
582 }
583
584 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
585         unsigned first, unsigned n, struct isl_name **names)
586 {
587         int i;
588
589         for (i = 0; i < n ; ++i)
590                 dim = set_name(dim, type, first+i, names[i]);
591
592         return dim;
593 }
594
595 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
596 {
597         unsigned t;
598         struct isl_name **names = NULL;
599
600         if (!dim)
601                 return NULL;
602         if (match(dim, isl_dim_in, dim, isl_dim_out))
603                 return dim;
604
605         dim = isl_dim_cow(dim);
606         if (!dim)
607                 return NULL;
608
609         if (dim->names) {
610                 names = isl_alloc_array(dim->ctx, struct isl_name *,
611                                         dim->n_in + dim->n_out);
612                 if (!names)
613                         goto error;
614                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
615                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
616         }
617
618         t = dim->n_in;
619         dim->n_in = dim->n_out;
620         dim->n_out = t;
621
622         if (dim->names) {
623                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
624                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
625                 free(names);
626         }
627
628         return dim;
629 error:
630         free(names);
631         isl_dim_free(dim);
632         return NULL;
633 }
634
635 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
636                 unsigned first, unsigned num)
637 {
638         int i;
639
640         if (!dim)
641                 return NULL;
642
643         if (n == 0)
644                 return dim;
645
646         isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
647         dim = isl_dim_cow(dim);
648         if (!dim)
649                 goto error;
650         if (dim->names) {
651                 dim = extend_names(dim);
652                 if (!dim)
653                         goto error;
654                 for (i = 0; i < num; ++i)
655                         isl_name_free(dim->ctx, get_name(dim, type, first+i));
656                 for (i = first+num; i < n(dim, type); ++i)
657                         set_name(dim, type, i - num, get_name(dim, type, i));
658                 switch (type) {
659                 case isl_dim_param:
660                         get_names(dim, isl_dim_in, 0, dim->n_in,
661                                 dim->names + offset(dim, isl_dim_in) - num);
662                 case isl_dim_in:
663                         get_names(dim, isl_dim_out, 0, dim->n_out,
664                                 dim->names + offset(dim, isl_dim_out) - num);
665                 case isl_dim_out:
666                         ;
667                 }
668                 dim->n_name -= num;
669         }
670         switch (type) {
671         case isl_dim_param:     dim->nparam -= num; break;
672         case isl_dim_in:        dim->n_in -= num; break;
673         case isl_dim_out:       dim->n_out -= num; break;
674         }
675         return dim;
676 error:
677         isl_dim_free(dim);
678         return NULL;
679 }
680
681 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
682                 unsigned first, unsigned n)
683 {
684         return isl_dim_drop(dim, isl_dim_in, first, n);
685 }
686
687 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
688                 unsigned first, unsigned n)
689 {
690         return isl_dim_drop(dim, isl_dim_out, first, n);
691 }
692
693 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
694 {
695         if (!dim)
696                 return NULL;
697         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
698         return isl_dim_reverse(dim);
699 }
700
701 struct isl_dim *isl_dim_range(struct isl_dim *dim)
702 {
703         if (!dim)
704                 return NULL;
705         return isl_dim_drop_inputs(dim, 0, dim->n_in);
706 }
707
708 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
709 {
710         int i;
711
712         if (!dim)
713                 return NULL;
714         if (n_div == 0 &&
715             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
716                 return dim;
717         dim = isl_dim_cow(dim);
718         if (!dim)
719                 return NULL;
720         dim->n_out += dim->nparam + dim->n_in + n_div;
721         dim->nparam = 0;
722         dim->n_in = 0;
723
724         for (i = 0; i < dim->n_name; ++i)
725                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
726         dim->n_name = 0;
727
728         return dim;
729 }
730
731 unsigned isl_dim_total(struct isl_dim *dim)
732 {
733         return dim->nparam + dim->n_in + dim->n_out;
734 }
735
736 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
737 {
738         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
739                n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
740                n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
741 }
742
743 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
744 {
745         return dim1->nparam == dim2->nparam &&
746                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
747 }