isl_stream: accept "@" token
[platform/upstream/isl.git] / isl_dim.c
1 /*
2  * Copyright 2008-2009 Katholieke Universiteit Leuven
3  *
4  * Use of this software is governed by the GNU LGPLv2.1 license
5  *
6  * Written by Sven Verdoolaege, K.U.Leuven, Departement
7  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8  */
9
10 #include <isl_dim_private.h>
11 #include "isl_name.h"
12
13 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
14                         unsigned nparam, unsigned n_in, unsigned n_out)
15 {
16         struct isl_dim *dim;
17
18         dim = isl_alloc_type(ctx, struct isl_dim);
19         if (!dim)
20                 return NULL;
21
22         dim->ctx = ctx;
23         isl_ctx_ref(ctx);
24         dim->ref = 1;
25         dim->nparam = nparam;
26         dim->n_in = n_in;
27         dim->n_out = n_out;
28
29         dim->n_name = 0;
30         dim->names = NULL;
31
32         return dim;
33 }
34
/* Allocate a dimension specification for a set: "nparam" parameters,
 * no input dimensions and "dim" output (set) dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	unsigned n_in = 0;

	return isl_dim_alloc(ctx, nparam, n_in, dim);
}
40
41 static unsigned global_pos(struct isl_dim *dim,
42                                  enum isl_dim_type type, unsigned pos)
43 {
44         struct isl_ctx *ctx = dim->ctx;
45
46         switch (type) {
47         case isl_dim_param:
48                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
49                 return pos;
50         case isl_dim_in:
51                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
52                 return pos + dim->nparam;
53         case isl_dim_out:
54                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
55                 return pos + dim->nparam + dim->n_in;
56         default:
57                 isl_assert(ctx, 0, return isl_dim_total(dim));
58         }
59         return isl_dim_total(dim);
60 }
61
62 /* Extend length of names array to the total number of dimensions.
63  */
64 static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
65 {
66         struct isl_name **names;
67         int i;
68
69         if (isl_dim_total(dim) <= dim->n_name)
70                 return dim;
71
72         if (!dim->names) {
73                 dim->names = isl_calloc_array(dim->ctx,
74                                 struct isl_name *, isl_dim_total(dim));
75                 if (!dim->names)
76                         goto error;
77         } else {
78                 names = isl_realloc_array(dim->ctx, dim->names,
79                                 struct isl_name *, isl_dim_total(dim));
80                 if (!names)
81                         goto error;
82                 dim->names = names;
83                 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
84                         dim->names[i] = NULL;
85         }
86
87         dim->n_name = isl_dim_total(dim);
88
89         return dim;
90 error:
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_dim *set_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos,
97                                  struct isl_name *name)
98 {
99         struct isl_ctx *ctx = dim->ctx;
100         dim = isl_dim_cow(dim);
101
102         if (!dim)
103                 goto error;
104
105         pos = global_pos(dim, type, pos);
106         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
107
108         if (pos >= dim->n_name) {
109                 if (!name)
110                         return dim;
111                 dim = extend_names(dim);
112                 if (!dim)
113                         goto error;
114         }
115
116         dim->names[pos] = name;
117
118         return dim;
119 error:
120         isl_name_free(ctx, name);
121         isl_dim_free(dim);
122         return NULL;
123 }
124
125 static struct isl_name *get_name(struct isl_dim *dim,
126                                  enum isl_dim_type type, unsigned pos)
127 {
128         if (!dim)
129                 return NULL;
130
131         pos = global_pos(dim, type, pos);
132         if (pos == isl_dim_total(dim))
133                 return NULL;
134         if (pos >= dim->n_name)
135                 return NULL;
136         return dim->names[pos];
137 }
138
139 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
140 {
141         switch (type) {
142         case isl_dim_param:     return 0;
143         case isl_dim_in:        return dim->nparam;
144         case isl_dim_out:       return dim->nparam + dim->n_in;
145         }
146 }
147
148 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
149 {
150         switch (type) {
151         case isl_dim_param:     return dim->nparam;
152         case isl_dim_in:        return dim->n_in;
153         case isl_dim_out:       return dim->n_out;
154         }
155 }
156
157 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
158 {
159         if (!dim)
160                 return 0;
161         return n(dim, type);
162 }
163
164 static struct isl_dim *copy_names(struct isl_dim *dst,
165         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
166         enum isl_dim_type src_type)
167 {
168         int i;
169         struct isl_name *name;
170
171         for (i = 0; i < n(src, src_type); ++i) {
172                 name = get_name(src, src_type, i);
173                 if (!name)
174                         continue;
175                 dst = set_name(dst, dst_type, offset + i,
176                                         isl_name_copy(dst->ctx, name));
177                 if (!dst)
178                         return NULL;
179         }
180         return dst;
181 }
182
183 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
184 {
185         struct isl_dim *dup;
186         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
187         if (!dim->names)
188                 return dup;
189         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
190         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
191         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
192         return dup;
193 }
194
195 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
196 {
197         if (!dim)
198                 return NULL;
199
200         if (dim->ref == 1)
201                 return dim;
202         dim->ref--;
203         return isl_dim_dup(dim);
204 }
205
206 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
207 {
208         if (!dim)
209                 return NULL;
210
211         dim->ref++;
212         return dim;
213 }
214
215 void isl_dim_free(struct isl_dim *dim)
216 {
217         int i;
218
219         if (!dim)
220                 return;
221
222         if (--dim->ref > 0)
223                 return;
224
225         for (i = 0; i < dim->n_name; ++i)
226                 isl_name_free(dim->ctx, dim->names[i]);
227         free(dim->names);
228         isl_ctx_deref(dim->ctx);
229         
230         free(dim);
231 }
232
233 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
234                                  enum isl_dim_type type, unsigned pos,
235                                  const char *s)
236 {
237         struct isl_name *name;
238         if (!dim)
239                 return NULL;
240         name = isl_name_get(dim->ctx, s);
241         if (!name)
242                 goto error;
243         return set_name(dim, type, pos, name);
244 error:
245         isl_dim_free(dim);
246         return NULL;
247 }
248
249 const char *isl_dim_get_name(struct isl_dim *dim,
250                                  enum isl_dim_type type, unsigned pos)
251 {
252         struct isl_name *name = get_name(dim, type, pos);
253         return name ? name->name : NULL;
254 }
255
256 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
257                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
258 {
259         int i;
260
261         if (n(dim1, dim1_type) != n(dim2, dim2_type))
262                 return 0;
263
264         if (!dim1->names && !dim2->names)
265                 return 1;
266
267         for (i = 0; i < n(dim1, dim1_type); ++i) {
268                 if (get_name(dim1, dim1_type, i) !=
269                     get_name(dim2, dim2_type, i))
270                         return 0;
271         }
272         return 1;
273 }
274
275 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
276                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
277 {
278         return match(dim1, dim1_type, dim2, dim2_type);
279 }
280
281 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
282         unsigned first, unsigned n, struct isl_name **names)
283 {
284         int i;
285
286         for (i = 0; i < n ; ++i)
287                 names[i] = get_name(dim, type, first+i);
288 }
289
290 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
291                         unsigned nparam, unsigned n_in, unsigned n_out)
292 {
293         struct isl_name **names = NULL;
294
295         if (!dim)
296                 return NULL;
297         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
298                 return dim;
299
300         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
301         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
302         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
303
304         dim = isl_dim_cow(dim);
305
306         if (dim->names) {
307                 names = isl_calloc_array(dim->ctx, struct isl_name *,
308                                          nparam + n_in + n_out);
309                 if (!names)
310                         goto error;
311                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
312                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
313                 get_names(dim, isl_dim_out, 0, dim->n_out,
314                                 names + nparam + n_in);
315                 free(dim->names);
316                 dim->names = names;
317                 dim->n_name = nparam + n_in + n_out;
318         }
319         dim->nparam = nparam;
320         dim->n_in = n_in;
321         dim->n_out = n_out;
322
323         return dim;
324 error:
325         free(names);
326         isl_dim_free(dim);
327         return NULL;
328 }
329
330 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
331         unsigned n)
332 {
333         switch (type) {
334         case isl_dim_param:
335                 return isl_dim_extend(dim,
336                                         dim->nparam + n, dim->n_in, dim->n_out);
337         case isl_dim_in:
338                 return isl_dim_extend(dim,
339                                         dim->nparam, dim->n_in + n, dim->n_out);
340         case isl_dim_out:
341                 return isl_dim_extend(dim,
342                                         dim->nparam, dim->n_in, dim->n_out + n);
343         }
344         return dim;
345 }
346
347 __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
348         enum isl_dim_type dst_type, unsigned dst_pos,
349         enum isl_dim_type src_type, unsigned src_pos, unsigned n)
350 {
351         if (!dim)
352                 return NULL;
353         if (n == 0)
354                 return dim;
355
356         isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
357                 goto error);
358
359         /* just the simple case for now */
360         isl_assert(dim->ctx,
361                 offset(dim, dst_type) + dst_pos ==
362                 offset(dim, src_type) + src_pos + ((src_type < dst_type) ? n : 0),
363                 goto error);
364
365         if (dst_type == src_type)
366                 return dim;
367
368         dim = isl_dim_cow(dim);
369         if (!dim)
370                 return NULL;
371
372         switch (dst_type) {
373         case isl_dim_param:     dim->nparam += n; break;
374         case isl_dim_in:        dim->n_in += n; break;
375         case isl_dim_out:       dim->n_out += n; break;
376         }
377
378         switch (src_type) {
379         case isl_dim_param:     dim->nparam -= n; break;
380         case isl_dim_in:        dim->n_in -= n; break;
381         case isl_dim_out:       dim->n_out -= n; break;
382         }
383
384         return dim;
385 error:
386         isl_dim_free(dim);
387         return NULL;
388 }
389
390 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
391 {
392         struct isl_dim *dim;
393
394         if (!left || !right)
395                 goto error;
396
397         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
398                         goto error);
399         isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
400                         goto error);
401
402         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
403         if (!dim)
404                 goto error;
405
406         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
407         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
408         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
409
410         isl_dim_free(left);
411         isl_dim_free(right);
412
413         return dim;
414 error:
415         isl_dim_free(left);
416         isl_dim_free(right);
417         return NULL;
418 }
419
420 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
421 {
422         struct isl_dim *dim;
423
424         if (!left || !right)
425                 goto error;
426
427         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
428                         goto error);
429
430         dim = isl_dim_alloc(left->ctx, left->nparam,
431                         left->n_in + right->n_in, left->n_out + right->n_out);
432         if (!dim)
433                 goto error;
434
435         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
436         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
437         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
438         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
439         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
440
441         isl_dim_free(left);
442         isl_dim_free(right);
443
444         return dim;
445 error:
446         isl_dim_free(left);
447         isl_dim_free(right);
448         return NULL;
449 }
450
451 struct isl_dim *isl_dim_map(struct isl_dim *dim)
452 {
453         struct isl_name **names = NULL;
454
455         if (!dim)
456                 return NULL;
457         isl_assert(dim->ctx, dim->n_in == 0, goto error);
458         if (dim->n_out == 0)
459                 return dim;
460         dim = isl_dim_cow(dim);
461         if (!dim)
462                 return NULL;
463         if (dim->names) {
464                 names = isl_calloc_array(dim->ctx, struct isl_name *,
465                                         dim->nparam + dim->n_out + dim->n_out);
466                 if (!names)
467                         goto error;
468                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
469                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
470         }
471         dim->n_in = dim->n_out;
472         if (names) {
473                 free(dim->names);
474                 dim->names = names;
475                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
476                 dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
477         }
478         return dim;
479 error:
480         isl_dim_free(dim);
481         return NULL;
482 }
483
484 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
485         unsigned first, unsigned n, struct isl_name **names)
486 {
487         int i;
488
489         for (i = 0; i < n ; ++i)
490                 dim = set_name(dim, type, first+i, names[i]);
491
492         return dim;
493 }
494
495 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
496 {
497         unsigned t;
498         struct isl_name **names = NULL;
499
500         if (!dim)
501                 return NULL;
502         if (match(dim, isl_dim_in, dim, isl_dim_out))
503                 return dim;
504
505         dim = isl_dim_cow(dim);
506         if (!dim)
507                 return NULL;
508
509         if (dim->names) {
510                 names = isl_alloc_array(dim->ctx, struct isl_name *,
511                                         dim->n_in + dim->n_out);
512                 if (!names)
513                         goto error;
514                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
515                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
516         }
517
518         t = dim->n_in;
519         dim->n_in = dim->n_out;
520         dim->n_out = t;
521
522         if (dim->names) {
523                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
524                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
525                 free(names);
526         }
527
528         return dim;
529 error:
530         free(names);
531         isl_dim_free(dim);
532         return NULL;
533 }
534
535 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
536                 unsigned first, unsigned num)
537 {
538         int i;
539
540         if (!dim)
541                 return NULL;
542
543         if (n == 0)
544                 return dim;
545
546         isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
547         dim = isl_dim_cow(dim);
548         if (!dim)
549                 goto error;
550         if (dim->names) {
551                 dim = extend_names(dim);
552                 if (!dim)
553                         goto error;
554                 for (i = 0; i < num; ++i)
555                         isl_name_free(dim->ctx, get_name(dim, type, first+i));
556                 for (i = first+num; i < n(dim, type); ++i)
557                         set_name(dim, type, i - num, get_name(dim, type, i));
558                 switch (type) {
559                 case isl_dim_param:
560                         get_names(dim, isl_dim_in, 0, dim->n_in,
561                                 dim->names + offset(dim, isl_dim_in) - num);
562                 case isl_dim_in:
563                         get_names(dim, isl_dim_out, 0, dim->n_out,
564                                 dim->names + offset(dim, isl_dim_out) - num);
565                 case isl_dim_out:
566                         ;
567                 }
568                 dim->n_name -= num;
569         }
570         switch (type) {
571         case isl_dim_param:     dim->nparam -= num; break;
572         case isl_dim_in:        dim->n_in -= num; break;
573         case isl_dim_out:       dim->n_out -= num; break;
574         }
575         return dim;
576 error:
577         isl_dim_free(dim);
578         return NULL;
579 }
580
581 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
582                 unsigned first, unsigned n)
583 {
584         return isl_dim_drop(dim, isl_dim_in, first, n);
585 }
586
587 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
588                 unsigned first, unsigned n)
589 {
590         return isl_dim_drop(dim, isl_dim_out, first, n);
591 }
592
593 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
594 {
595         if (!dim)
596                 return NULL;
597         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
598         return isl_dim_reverse(dim);
599 }
600
601 struct isl_dim *isl_dim_range(struct isl_dim *dim)
602 {
603         if (!dim)
604                 return NULL;
605         return isl_dim_drop_inputs(dim, 0, dim->n_in);
606 }
607
608 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
609 {
610         int i;
611
612         if (!dim)
613                 return NULL;
614         if (n_div == 0 &&
615             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
616                 return dim;
617         dim = isl_dim_cow(dim);
618         if (!dim)
619                 return NULL;
620         dim->n_out += dim->nparam + dim->n_in + n_div;
621         dim->nparam = 0;
622         dim->n_in = 0;
623
624         for (i = 0; i < dim->n_name; ++i)
625                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
626         dim->n_name = 0;
627
628         return dim;
629 }
630
631 unsigned isl_dim_total(struct isl_dim *dim)
632 {
633         return dim->nparam + dim->n_in + dim->n_out;
634 }
635
636 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
637 {
638         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
639                n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
640                n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
641 }
642
643 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
644 {
645         return dim1->nparam == dim2->nparam &&
646                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
647 }