isl_tab_rollback: avoid double free on error path
[platform/upstream/isl.git] / isl_dim.c
1 /*
2  * Copyright 2008-2009 Katholieke Universiteit Leuven
3  *
4  * Use of this software is governed by the GNU LGPLv2.1 license
5  *
6  * Written by Sven Verdoolaege, K.U.Leuven, Departement
7  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8  */
9
10 #include <isl_dim.h>
11 #include "isl_name.h"
12
13 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
14                         unsigned nparam, unsigned n_in, unsigned n_out)
15 {
16         struct isl_dim *dim;
17
18         dim = isl_alloc_type(ctx, struct isl_dim);
19         if (!dim)
20                 return NULL;
21
22         dim->ctx = ctx;
23         isl_ctx_ref(ctx);
24         dim->ref = 1;
25         dim->nparam = nparam;
26         dim->n_in = n_in;
27         dim->n_out = n_out;
28
29         dim->n_name = 0;
30         dim->names = NULL;
31
32         return dim;
33 }
34
/* Allocate a set space: "dim" set dimensions, stored as output
 * dimensions of a description without inputs.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	unsigned no_inputs = 0;

	return isl_dim_alloc(ctx, nparam, no_inputs, dim);
}
40
41 static unsigned global_pos(struct isl_dim *dim,
42                                  enum isl_dim_type type, unsigned pos)
43 {
44         struct isl_ctx *ctx = dim->ctx;
45
46         switch (type) {
47         case isl_dim_param:
48                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
49                 return pos;
50         case isl_dim_in:
51                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
52                 return pos + dim->nparam;
53         case isl_dim_out:
54                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
55                 return pos + dim->nparam + dim->n_in;
56         default:
57                 isl_assert(ctx, 0, return isl_dim_total(dim));
58         }
59         return isl_dim_total(dim);
60 }
61
62 /* Extend length of names array to the total number of dimensions.
63  */
64 static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
65 {
66         struct isl_name **names;
67         int i;
68
69         if (isl_dim_total(dim) <= dim->n_name)
70                 return dim;
71
72         if (!dim->names) {
73                 dim->names = isl_calloc_array(dim->ctx,
74                                 struct isl_name *, isl_dim_total(dim));
75                 if (!dim->names)
76                         goto error;
77         } else {
78                 names = isl_realloc_array(dim->ctx, dim->names,
79                                 struct isl_name *, isl_dim_total(dim));
80                 if (!names)
81                         goto error;
82                 dim->names = names;
83                 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
84                         dim->names[i] = NULL;
85         }
86
87         dim->n_name = isl_dim_total(dim);
88
89         return dim;
90 error:
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_dim *set_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos,
97                                  struct isl_name *name)
98 {
99         struct isl_ctx *ctx = dim->ctx;
100         dim = isl_dim_cow(dim);
101
102         if (!dim)
103                 goto error;
104
105         pos = global_pos(dim, type, pos);
106         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
107
108         if (pos >= dim->n_name) {
109                 if (!name)
110                         return dim;
111                 dim = extend_names(dim);
112                 if (!dim)
113                         goto error;
114         }
115
116         dim->names[pos] = name;
117
118         return dim;
119 error:
120         isl_name_free(ctx, name);
121         isl_dim_free(dim);
122         return NULL;
123 }
124
125 static struct isl_name *get_name(struct isl_dim *dim,
126                                  enum isl_dim_type type, unsigned pos)
127 {
128         if (!dim)
129                 return NULL;
130
131         pos = global_pos(dim, type, pos);
132         if (pos == isl_dim_total(dim))
133                 return NULL;
134         if (pos >= dim->n_name)
135                 return NULL;
136         return dim->names[pos];
137 }
138
139 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
140 {
141         switch (type) {
142         case isl_dim_param:     return 0;
143         case isl_dim_in:        return dim->nparam;
144         case isl_dim_out:       return dim->nparam + dim->n_in;
145         default:                return 0;
146         }
147 }
148
149 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
150 {
151         switch (type) {
152         case isl_dim_param:     return dim->nparam;
153         case isl_dim_in:        return dim->n_in;
154         case isl_dim_out:       return dim->n_out;
155         default:                return 0;
156         }
157 }
158
159 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
160 {
161         if (!dim)
162                 return 0;
163         return n(dim, type);
164 }
165
166 unsigned isl_dim_offset(__isl_keep isl_dim *dim, enum isl_dim_type type)
167 {
168         if (!dim)
169                 return 0;
170         return offset(dim, type);
171 }
172
173 static struct isl_dim *copy_names(struct isl_dim *dst,
174         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
175         enum isl_dim_type src_type)
176 {
177         int i;
178         struct isl_name *name;
179
180         if (!dst)
181                 return NULL;
182
183         for (i = 0; i < n(src, src_type); ++i) {
184                 name = get_name(src, src_type, i);
185                 if (!name)
186                         continue;
187                 dst = set_name(dst, dst_type, offset + i,
188                                         isl_name_copy(dst->ctx, name));
189                 if (!dst)
190                         return NULL;
191         }
192         return dst;
193 }
194
195 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
196 {
197         struct isl_dim *dup;
198         if (!dim)
199                 return NULL;
200         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
201         if (!dim->names)
202                 return dup;
203         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
204         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
205         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
206         return dup;
207 }
208
209 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
210 {
211         if (!dim)
212                 return NULL;
213
214         if (dim->ref == 1)
215                 return dim;
216         dim->ref--;
217         return isl_dim_dup(dim);
218 }
219
220 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
221 {
222         if (!dim)
223                 return NULL;
224
225         dim->ref++;
226         return dim;
227 }
228
229 void isl_dim_free(struct isl_dim *dim)
230 {
231         int i;
232
233         if (!dim)
234                 return;
235
236         if (--dim->ref > 0)
237                 return;
238
239         for (i = 0; i < dim->n_name; ++i)
240                 isl_name_free(dim->ctx, dim->names[i]);
241         free(dim->names);
242         isl_ctx_deref(dim->ctx);
243         
244         free(dim);
245 }
246
247 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
248                                  enum isl_dim_type type, unsigned pos,
249                                  const char *s)
250 {
251         struct isl_name *name;
252         if (!dim)
253                 return NULL;
254         name = isl_name_get(dim->ctx, s);
255         if (!name)
256                 goto error;
257         return set_name(dim, type, pos, name);
258 error:
259         isl_dim_free(dim);
260         return NULL;
261 }
262
263 const char *isl_dim_get_name(struct isl_dim *dim,
264                                  enum isl_dim_type type, unsigned pos)
265 {
266         struct isl_name *name = get_name(dim, type, pos);
267         return name ? name->name : NULL;
268 }
269
270 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
271                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
272 {
273         int i;
274
275         if (n(dim1, dim1_type) != n(dim2, dim2_type))
276                 return 0;
277
278         if (!dim1->names && !dim2->names)
279                 return 1;
280
281         for (i = 0; i < n(dim1, dim1_type); ++i) {
282                 if (get_name(dim1, dim1_type, i) !=
283                     get_name(dim2, dim2_type, i))
284                         return 0;
285         }
286         return 1;
287 }
288
289 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
290                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
291 {
292         return match(dim1, dim1_type, dim2, dim2_type);
293 }
294
295 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
296         unsigned first, unsigned n, struct isl_name **names)
297 {
298         int i;
299
300         for (i = 0; i < n ; ++i)
301                 names[i] = get_name(dim, type, first+i);
302 }
303
304 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
305                         unsigned nparam, unsigned n_in, unsigned n_out)
306 {
307         struct isl_name **names = NULL;
308
309         if (!dim)
310                 return NULL;
311         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
312                 return dim;
313
314         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
315         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
316         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
317
318         dim = isl_dim_cow(dim);
319
320         if (dim->names) {
321                 names = isl_calloc_array(dim->ctx, struct isl_name *,
322                                          nparam + n_in + n_out);
323                 if (!names)
324                         goto error;
325                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
326                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
327                 get_names(dim, isl_dim_out, 0, dim->n_out,
328                                 names + nparam + n_in);
329                 free(dim->names);
330                 dim->names = names;
331                 dim->n_name = nparam + n_in + n_out;
332         }
333         dim->nparam = nparam;
334         dim->n_in = n_in;
335         dim->n_out = n_out;
336
337         return dim;
338 error:
339         free(names);
340         isl_dim_free(dim);
341         return NULL;
342 }
343
344 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
345         unsigned n)
346 {
347         switch (type) {
348         case isl_dim_param:
349                 return isl_dim_extend(dim,
350                                         dim->nparam + n, dim->n_in, dim->n_out);
351         case isl_dim_in:
352                 return isl_dim_extend(dim,
353                                         dim->nparam, dim->n_in + n, dim->n_out);
354         case isl_dim_out:
355                 return isl_dim_extend(dim,
356                                         dim->nparam, dim->n_in, dim->n_out + n);
357         }
358         return dim;
359 }
360
361 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
362         enum isl_dim_type type, unsigned pos, unsigned n)
363 {
364         struct isl_name **names = NULL;
365
366         if (!dim)
367                 return NULL;
368         if (n == 0)
369                 return dim;
370
371         isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
372
373         dim = isl_dim_cow(dim);
374         if (!dim)
375                 return NULL;
376
377         if (dim->names) {
378                 enum isl_dim_type t;
379                 int off;
380                 int size[3];
381                 names = isl_calloc_array(dim->ctx, struct isl_name *,
382                                      dim->nparam + dim->n_in + dim->n_out + n);
383                 if (!names)
384                         goto error;
385                 off = 0;
386                 size[isl_dim_param] = dim->nparam;
387                 size[isl_dim_in] = dim->n_in;
388                 size[isl_dim_out] = dim->n_out;
389                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
390                         if (t != type) {
391                                 get_names(dim, t, 0, size[t], names + off);
392                                 off += size[t];
393                         } else {
394                                 get_names(dim, t, 0, pos, names + off);
395                                 off += pos + n;
396                                 get_names(dim, t, pos, size[t]-pos, names+off);
397                                 off += size[t] - pos;
398                         }
399                 }
400                 free(dim->names);
401                 dim->names = names;
402                 dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
403         }
404         switch (type) {
405         case isl_dim_param:     dim->nparam += n; break;
406         case isl_dim_in:        dim->n_in += n; break;
407         case isl_dim_out:       dim->n_out += n; break;
408         }
409
410         return dim;
411 error:
412         isl_dim_free(dim);
413         return NULL;
414 }
415
416 __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
417         enum isl_dim_type dst_type, unsigned dst_pos,
418         enum isl_dim_type src_type, unsigned src_pos, unsigned n)
419 {
420         if (!dim)
421                 return NULL;
422         if (n == 0)
423                 return dim;
424
425         isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
426                 goto error);
427
428         if (dst_type == src_type && dst_pos == src_pos)
429                 return dim;
430
431         isl_assert(dim->ctx, dst_type != src_type, goto error);
432
433         dim = isl_dim_cow(dim);
434         if (!dim)
435                 return NULL;
436
437         if (dim->names) {
438                 struct isl_name **names;
439                 enum isl_dim_type t;
440                 int off;
441                 int size[3];
442                 names = isl_calloc_array(dim->ctx, struct isl_name *,
443                                          dim->nparam + dim->n_in + dim->n_out);
444                 if (!names)
445                         goto error;
446                 off = 0;
447                 size[isl_dim_param] = dim->nparam;
448                 size[isl_dim_in] = dim->n_in;
449                 size[isl_dim_out] = dim->n_out;
450                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
451                         if (t == dst_type) {
452                                 get_names(dim, t, 0, dst_pos, names + off);
453                                 off += dst_pos;
454                                 get_names(dim, src_type, src_pos, n, names+off);
455                                 off += n;
456                                 get_names(dim, t, dst_pos, size[t] - dst_pos,
457                                                 names + off);
458                                 off += size[t] - dst_pos;
459                         } else if (t == src_type) {
460                                 get_names(dim, t, 0, src_pos, names + off);
461                                 off += src_pos;
462                                 get_names(dim, t, src_pos + n,
463                                             size[t] - src_pos - n, names + off);
464                                 off += size[t] - src_pos - n;
465                         } else {
466                                 get_names(dim, t, 0, size[t], names + off);
467                                 off += size[t];
468                         }
469                 }
470                 free(dim->names);
471                 dim->names = names;
472                 dim->n_name = dim->nparam + dim->n_in + dim->n_out;
473         }
474
475         switch (dst_type) {
476         case isl_dim_param:     dim->nparam += n; break;
477         case isl_dim_in:        dim->n_in += n; break;
478         case isl_dim_out:       dim->n_out += n; break;
479         }
480
481         switch (src_type) {
482         case isl_dim_param:     dim->nparam -= n; break;
483         case isl_dim_in:        dim->n_in -= n; break;
484         case isl_dim_out:       dim->n_out -= n; break;
485         }
486
487         return dim;
488 error:
489         isl_dim_free(dim);
490         return NULL;
491 }
492
493 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
494 {
495         struct isl_dim *dim;
496
497         if (!left || !right)
498                 goto error;
499
500         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
501                         goto error);
502         isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
503                         goto error);
504
505         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
506         if (!dim)
507                 goto error;
508
509         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
510         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
511         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
512
513         isl_dim_free(left);
514         isl_dim_free(right);
515
516         return dim;
517 error:
518         isl_dim_free(left);
519         isl_dim_free(right);
520         return NULL;
521 }
522
523 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
524 {
525         struct isl_dim *dim;
526
527         if (!left || !right)
528                 goto error;
529
530         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
531                         goto error);
532
533         dim = isl_dim_alloc(left->ctx, left->nparam,
534                         left->n_in + right->n_in, left->n_out + right->n_out);
535         if (!dim)
536                 goto error;
537
538         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
539         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
540         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
541         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
542         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
543
544         isl_dim_free(left);
545         isl_dim_free(right);
546
547         return dim;
548 error:
549         isl_dim_free(left);
550         isl_dim_free(right);
551         return NULL;
552 }
553
554 struct isl_dim *isl_dim_map(struct isl_dim *dim)
555 {
556         struct isl_name **names = NULL;
557
558         if (!dim)
559                 return NULL;
560         isl_assert(dim->ctx, dim->n_in == 0, goto error);
561         if (dim->n_out == 0)
562                 return dim;
563         dim = isl_dim_cow(dim);
564         if (!dim)
565                 return NULL;
566         if (dim->names) {
567                 names = isl_calloc_array(dim->ctx, struct isl_name *,
568                                         dim->nparam + dim->n_out + dim->n_out);
569                 if (!names)
570                         goto error;
571                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
572                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
573         }
574         dim->n_in = dim->n_out;
575         if (names) {
576                 free(dim->names);
577                 dim->names = names;
578                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
579                 dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
580         }
581         return dim;
582 error:
583         isl_dim_free(dim);
584         return NULL;
585 }
586
587 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
588         unsigned first, unsigned n, struct isl_name **names)
589 {
590         int i;
591
592         for (i = 0; i < n ; ++i)
593                 dim = set_name(dim, type, first+i, names[i]);
594
595         return dim;
596 }
597
598 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
599 {
600         unsigned t;
601         struct isl_name **names = NULL;
602
603         if (!dim)
604                 return NULL;
605         if (match(dim, isl_dim_in, dim, isl_dim_out))
606                 return dim;
607
608         dim = isl_dim_cow(dim);
609         if (!dim)
610                 return NULL;
611
612         if (dim->names) {
613                 names = isl_alloc_array(dim->ctx, struct isl_name *,
614                                         dim->n_in + dim->n_out);
615                 if (!names)
616                         goto error;
617                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
618                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
619         }
620
621         t = dim->n_in;
622         dim->n_in = dim->n_out;
623         dim->n_out = t;
624
625         if (dim->names) {
626                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
627                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
628                 free(names);
629         }
630
631         return dim;
632 error:
633         free(names);
634         isl_dim_free(dim);
635         return NULL;
636 }
637
638 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
639                 unsigned first, unsigned num)
640 {
641         int i;
642
643         if (!dim)
644                 return NULL;
645
646         if (n == 0)
647                 return dim;
648
649         isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
650         dim = isl_dim_cow(dim);
651         if (!dim)
652                 goto error;
653         if (dim->names) {
654                 dim = extend_names(dim);
655                 if (!dim)
656                         goto error;
657                 for (i = 0; i < num; ++i)
658                         isl_name_free(dim->ctx, get_name(dim, type, first+i));
659                 for (i = first+num; i < n(dim, type); ++i)
660                         set_name(dim, type, i - num, get_name(dim, type, i));
661                 switch (type) {
662                 case isl_dim_param:
663                         get_names(dim, isl_dim_in, 0, dim->n_in,
664                                 dim->names + offset(dim, isl_dim_in) - num);
665                 case isl_dim_in:
666                         get_names(dim, isl_dim_out, 0, dim->n_out,
667                                 dim->names + offset(dim, isl_dim_out) - num);
668                 case isl_dim_out:
669                         ;
670                 }
671                 dim->n_name -= num;
672         }
673         switch (type) {
674         case isl_dim_param:     dim->nparam -= num; break;
675         case isl_dim_in:        dim->n_in -= num; break;
676         case isl_dim_out:       dim->n_out -= num; break;
677         }
678         return dim;
679 error:
680         isl_dim_free(dim);
681         return NULL;
682 }
683
684 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
685                 unsigned first, unsigned n)
686 {
687         return isl_dim_drop(dim, isl_dim_in, first, n);
688 }
689
690 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
691                 unsigned first, unsigned n)
692 {
693         return isl_dim_drop(dim, isl_dim_out, first, n);
694 }
695
696 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
697 {
698         if (!dim)
699                 return NULL;
700         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
701         return isl_dim_reverse(dim);
702 }
703
704 struct isl_dim *isl_dim_range(struct isl_dim *dim)
705 {
706         if (!dim)
707                 return NULL;
708         return isl_dim_drop_inputs(dim, 0, dim->n_in);
709 }
710
711 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
712 {
713         int i;
714
715         if (!dim)
716                 return NULL;
717         if (n_div == 0 &&
718             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
719                 return dim;
720         dim = isl_dim_cow(dim);
721         if (!dim)
722                 return NULL;
723         dim->n_out += dim->nparam + dim->n_in + n_div;
724         dim->nparam = 0;
725         dim->n_in = 0;
726
727         for (i = 0; i < dim->n_name; ++i)
728                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
729         dim->n_name = 0;
730
731         return dim;
732 }
733
734 unsigned isl_dim_total(struct isl_dim *dim)
735 {
736         return dim->nparam + dim->n_in + dim->n_out;
737 }
738
739 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
740 {
741         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
742                n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
743                n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
744 }
745
746 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
747 {
748         return dim1->nparam == dim2->nparam &&
749                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
750 }