05436e2cc2331a82dad9847db2b42e6b1840067f
[platform/upstream/isl.git] / isl_dim.c
1 /*
2  * Copyright 2008-2009 Katholieke Universiteit Leuven
3  *
4  * Use of this software is governed by the GNU LGPLv2.1 license
5  *
6  * Written by Sven Verdoolaege, K.U.Leuven, Departement
7  * Computerwetenschappen, Celestijnenlaan 200A, B-3001 Leuven, Belgium
8  */
9
10 #include <isl_dim.h>
11 #include "isl_name.h"
12
13 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
14                         unsigned nparam, unsigned n_in, unsigned n_out)
15 {
16         struct isl_dim *dim;
17
18         dim = isl_alloc_type(ctx, struct isl_dim);
19         if (!dim)
20                 return NULL;
21
22         dim->ctx = ctx;
23         isl_ctx_ref(ctx);
24         dim->ref = 1;
25         dim->nparam = nparam;
26         dim->n_in = n_in;
27         dim->n_out = n_out;
28
29         dim->n_name = 0;
30         dim->names = NULL;
31
32         return dim;
33 }
34
/* Allocate a dimension specification for a set: "nparam" parameters,
 * no input dimensions and "dim" set (output) dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	unsigned no_inputs = 0;

	return isl_dim_alloc(ctx, nparam, no_inputs, dim);
}
40
41 static unsigned global_pos(struct isl_dim *dim,
42                                  enum isl_dim_type type, unsigned pos)
43 {
44         struct isl_ctx *ctx = dim->ctx;
45
46         switch (type) {
47         case isl_dim_param:
48                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
49                 return pos;
50         case isl_dim_in:
51                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
52                 return pos + dim->nparam;
53         case isl_dim_out:
54                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
55                 return pos + dim->nparam + dim->n_in;
56         default:
57                 isl_assert(ctx, 0, return isl_dim_total(dim));
58         }
59         return isl_dim_total(dim);
60 }
61
62 /* Extend length of names array to the total number of dimensions.
63  */
64 static __isl_give isl_dim *extend_names(__isl_take isl_dim *dim)
65 {
66         struct isl_name **names;
67         int i;
68
69         if (isl_dim_total(dim) <= dim->n_name)
70                 return dim;
71
72         if (!dim->names) {
73                 dim->names = isl_calloc_array(dim->ctx,
74                                 struct isl_name *, isl_dim_total(dim));
75                 if (!dim->names)
76                         goto error;
77         } else {
78                 names = isl_realloc_array(dim->ctx, dim->names,
79                                 struct isl_name *, isl_dim_total(dim));
80                 if (!names)
81                         goto error;
82                 dim->names = names;
83                 for (i = dim->n_name; i < isl_dim_total(dim); ++i)
84                         dim->names[i] = NULL;
85         }
86
87         dim->n_name = isl_dim_total(dim);
88
89         return dim;
90 error:
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_dim *set_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos,
97                                  struct isl_name *name)
98 {
99         struct isl_ctx *ctx = dim->ctx;
100         dim = isl_dim_cow(dim);
101
102         if (!dim)
103                 goto error;
104
105         pos = global_pos(dim, type, pos);
106         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
107
108         if (pos >= dim->n_name) {
109                 if (!name)
110                         return dim;
111                 dim = extend_names(dim);
112                 if (!dim)
113                         goto error;
114         }
115
116         dim->names[pos] = name;
117
118         return dim;
119 error:
120         isl_name_free(ctx, name);
121         isl_dim_free(dim);
122         return NULL;
123 }
124
125 static struct isl_name *get_name(struct isl_dim *dim,
126                                  enum isl_dim_type type, unsigned pos)
127 {
128         if (!dim)
129                 return NULL;
130
131         pos = global_pos(dim, type, pos);
132         if (pos == isl_dim_total(dim))
133                 return NULL;
134         if (pos >= dim->n_name)
135                 return NULL;
136         return dim->names[pos];
137 }
138
139 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
140 {
141         switch (type) {
142         case isl_dim_param:     return 0;
143         case isl_dim_in:        return dim->nparam;
144         case isl_dim_out:       return dim->nparam + dim->n_in;
145         default:                return 0;
146         }
147 }
148
149 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
150 {
151         switch (type) {
152         case isl_dim_param:     return dim->nparam;
153         case isl_dim_in:        return dim->n_in;
154         case isl_dim_out:       return dim->n_out;
155         default:                return 0;
156         }
157 }
158
159 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
160 {
161         if (!dim)
162                 return 0;
163         return n(dim, type);
164 }
165
166 unsigned isl_dim_offset(__isl_keep isl_dim *dim, enum isl_dim_type type)
167 {
168         if (!dim)
169                 return 0;
170         return offset(dim, type);
171 }
172
173 static struct isl_dim *copy_names(struct isl_dim *dst,
174         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
175         enum isl_dim_type src_type)
176 {
177         int i;
178         struct isl_name *name;
179
180         for (i = 0; i < n(src, src_type); ++i) {
181                 name = get_name(src, src_type, i);
182                 if (!name)
183                         continue;
184                 dst = set_name(dst, dst_type, offset + i,
185                                         isl_name_copy(dst->ctx, name));
186                 if (!dst)
187                         return NULL;
188         }
189         return dst;
190 }
191
192 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
193 {
194         struct isl_dim *dup;
195         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
196         if (!dim->names)
197                 return dup;
198         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
199         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
200         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
201         return dup;
202 }
203
204 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
205 {
206         if (!dim)
207                 return NULL;
208
209         if (dim->ref == 1)
210                 return dim;
211         dim->ref--;
212         return isl_dim_dup(dim);
213 }
214
215 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
216 {
217         if (!dim)
218                 return NULL;
219
220         dim->ref++;
221         return dim;
222 }
223
224 void isl_dim_free(struct isl_dim *dim)
225 {
226         int i;
227
228         if (!dim)
229                 return;
230
231         if (--dim->ref > 0)
232                 return;
233
234         for (i = 0; i < dim->n_name; ++i)
235                 isl_name_free(dim->ctx, dim->names[i]);
236         free(dim->names);
237         isl_ctx_deref(dim->ctx);
238         
239         free(dim);
240 }
241
242 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
243                                  enum isl_dim_type type, unsigned pos,
244                                  const char *s)
245 {
246         struct isl_name *name;
247         if (!dim)
248                 return NULL;
249         name = isl_name_get(dim->ctx, s);
250         if (!name)
251                 goto error;
252         return set_name(dim, type, pos, name);
253 error:
254         isl_dim_free(dim);
255         return NULL;
256 }
257
258 const char *isl_dim_get_name(struct isl_dim *dim,
259                                  enum isl_dim_type type, unsigned pos)
260 {
261         struct isl_name *name = get_name(dim, type, pos);
262         return name ? name->name : NULL;
263 }
264
265 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
266                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
267 {
268         int i;
269
270         if (n(dim1, dim1_type) != n(dim2, dim2_type))
271                 return 0;
272
273         if (!dim1->names && !dim2->names)
274                 return 1;
275
276         for (i = 0; i < n(dim1, dim1_type); ++i) {
277                 if (get_name(dim1, dim1_type, i) !=
278                     get_name(dim2, dim2_type, i))
279                         return 0;
280         }
281         return 1;
282 }
283
284 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
285                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
286 {
287         return match(dim1, dim1_type, dim2, dim2_type);
288 }
289
290 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
291         unsigned first, unsigned n, struct isl_name **names)
292 {
293         int i;
294
295         for (i = 0; i < n ; ++i)
296                 names[i] = get_name(dim, type, first+i);
297 }
298
299 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
300                         unsigned nparam, unsigned n_in, unsigned n_out)
301 {
302         struct isl_name **names = NULL;
303
304         if (!dim)
305                 return NULL;
306         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
307                 return dim;
308
309         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
310         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
311         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
312
313         dim = isl_dim_cow(dim);
314
315         if (dim->names) {
316                 names = isl_calloc_array(dim->ctx, struct isl_name *,
317                                          nparam + n_in + n_out);
318                 if (!names)
319                         goto error;
320                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
321                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
322                 get_names(dim, isl_dim_out, 0, dim->n_out,
323                                 names + nparam + n_in);
324                 free(dim->names);
325                 dim->names = names;
326                 dim->n_name = nparam + n_in + n_out;
327         }
328         dim->nparam = nparam;
329         dim->n_in = n_in;
330         dim->n_out = n_out;
331
332         return dim;
333 error:
334         free(names);
335         isl_dim_free(dim);
336         return NULL;
337 }
338
339 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
340         unsigned n)
341 {
342         switch (type) {
343         case isl_dim_param:
344                 return isl_dim_extend(dim,
345                                         dim->nparam + n, dim->n_in, dim->n_out);
346         case isl_dim_in:
347                 return isl_dim_extend(dim,
348                                         dim->nparam, dim->n_in + n, dim->n_out);
349         case isl_dim_out:
350                 return isl_dim_extend(dim,
351                                         dim->nparam, dim->n_in, dim->n_out + n);
352         }
353         return dim;
354 }
355
356 __isl_give isl_dim *isl_dim_insert(__isl_take isl_dim *dim,
357         enum isl_dim_type type, unsigned pos, unsigned n)
358 {
359         struct isl_name **names = NULL;
360
361         if (!dim)
362                 return NULL;
363         if (n == 0)
364                 return dim;
365
366         isl_assert(dim->ctx, pos <= isl_dim_size(dim, type), goto error);
367
368         dim = isl_dim_cow(dim);
369         if (!dim)
370                 return NULL;
371
372         if (dim->names) {
373                 enum isl_dim_type t;
374                 int off;
375                 int size[3];
376                 names = isl_calloc_array(dim->ctx, struct isl_name *,
377                                      dim->nparam + dim->n_in + dim->n_out + n);
378                 if (!names)
379                         goto error;
380                 off = 0;
381                 size[isl_dim_param] = dim->nparam;
382                 size[isl_dim_in] = dim->n_in;
383                 size[isl_dim_out] = dim->n_out;
384                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
385                         if (t != type) {
386                                 get_names(dim, t, 0, size[t], names + off);
387                                 off += size[t];
388                         } else {
389                                 get_names(dim, t, 0, pos, names + off);
390                                 off += pos + n;
391                                 get_names(dim, t, pos, size[t]-pos, names+off);
392                                 off += size[t] - pos;
393                         }
394                 }
395                 free(dim->names);
396                 dim->names = names;
397                 dim->n_name = dim->nparam + dim->n_in + dim->n_out + n;
398         }
399         switch (type) {
400         case isl_dim_param:     dim->nparam += n; break;
401         case isl_dim_in:        dim->n_in += n; break;
402         case isl_dim_out:       dim->n_out += n; break;
403         }
404
405         return dim;
406 error:
407         isl_dim_free(dim);
408         return NULL;
409 }
410
411 __isl_give isl_dim *isl_dim_move(__isl_take isl_dim *dim,
412         enum isl_dim_type dst_type, unsigned dst_pos,
413         enum isl_dim_type src_type, unsigned src_pos, unsigned n)
414 {
415         if (!dim)
416                 return NULL;
417         if (n == 0)
418                 return dim;
419
420         isl_assert(dim->ctx, src_pos + n <= isl_dim_size(dim, src_type),
421                 goto error);
422
423         if (dst_type == src_type && dst_pos == src_pos)
424                 return dim;
425
426         isl_assert(dim->ctx, dst_type != src_type, goto error);
427
428         dim = isl_dim_cow(dim);
429         if (!dim)
430                 return NULL;
431
432         if (dim->names) {
433                 struct isl_name **names;
434                 enum isl_dim_type t;
435                 int off;
436                 int size[3];
437                 names = isl_calloc_array(dim->ctx, struct isl_name *,
438                                          dim->nparam + dim->n_in + dim->n_out);
439                 if (!names)
440                         goto error;
441                 off = 0;
442                 size[isl_dim_param] = dim->nparam;
443                 size[isl_dim_in] = dim->n_in;
444                 size[isl_dim_out] = dim->n_out;
445                 for (t = isl_dim_param; t <= isl_dim_out; ++t) {
446                         if (t == dst_type) {
447                                 get_names(dim, t, 0, dst_pos, names + off);
448                                 off += dst_pos;
449                                 get_names(dim, src_type, src_pos, n, names+off);
450                                 off += n;
451                                 get_names(dim, t, dst_pos, size[t] - dst_pos,
452                                                 names + off);
453                                 off += size[t] - dst_pos;
454                         } else if (t == src_type) {
455                                 get_names(dim, t, 0, src_pos, names + off);
456                                 off += src_pos;
457                                 get_names(dim, t, src_pos + n,
458                                             size[t] - src_pos - n, names + off);
459                                 off += size[t] - src_pos - n;
460                         } else {
461                                 get_names(dim, t, 0, size[t], names + off);
462                                 off += size[t];
463                         }
464                 }
465                 free(dim->names);
466                 dim->names = names;
467                 dim->n_name = dim->nparam + dim->n_in + dim->n_out;
468         }
469
470         switch (dst_type) {
471         case isl_dim_param:     dim->nparam += n; break;
472         case isl_dim_in:        dim->n_in += n; break;
473         case isl_dim_out:       dim->n_out += n; break;
474         }
475
476         switch (src_type) {
477         case isl_dim_param:     dim->nparam -= n; break;
478         case isl_dim_in:        dim->n_in -= n; break;
479         case isl_dim_out:       dim->n_out -= n; break;
480         }
481
482         return dim;
483 error:
484         isl_dim_free(dim);
485         return NULL;
486 }
487
488 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
489 {
490         struct isl_dim *dim;
491
492         if (!left || !right)
493                 goto error;
494
495         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
496                         goto error);
497         isl_assert(left->ctx, n(left, isl_dim_out) == n(right, isl_dim_in),
498                         goto error);
499
500         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
501         if (!dim)
502                 goto error;
503
504         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
505         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
506         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
507
508         isl_dim_free(left);
509         isl_dim_free(right);
510
511         return dim;
512 error:
513         isl_dim_free(left);
514         isl_dim_free(right);
515         return NULL;
516 }
517
518 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
519 {
520         struct isl_dim *dim;
521
522         if (!left || !right)
523                 goto error;
524
525         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
526                         goto error);
527
528         dim = isl_dim_alloc(left->ctx, left->nparam,
529                         left->n_in + right->n_in, left->n_out + right->n_out);
530         if (!dim)
531                 goto error;
532
533         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
534         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
535         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
536         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
537         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
538
539         isl_dim_free(left);
540         isl_dim_free(right);
541
542         return dim;
543 error:
544         isl_dim_free(left);
545         isl_dim_free(right);
546         return NULL;
547 }
548
549 struct isl_dim *isl_dim_map(struct isl_dim *dim)
550 {
551         struct isl_name **names = NULL;
552
553         if (!dim)
554                 return NULL;
555         isl_assert(dim->ctx, dim->n_in == 0, goto error);
556         if (dim->n_out == 0)
557                 return dim;
558         dim = isl_dim_cow(dim);
559         if (!dim)
560                 return NULL;
561         if (dim->names) {
562                 names = isl_calloc_array(dim->ctx, struct isl_name *,
563                                         dim->nparam + dim->n_out + dim->n_out);
564                 if (!names)
565                         goto error;
566                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
567                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
568         }
569         dim->n_in = dim->n_out;
570         if (names) {
571                 free(dim->names);
572                 dim->names = names;
573                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
574                 dim = copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
575         }
576         return dim;
577 error:
578         isl_dim_free(dim);
579         return NULL;
580 }
581
582 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
583         unsigned first, unsigned n, struct isl_name **names)
584 {
585         int i;
586
587         for (i = 0; i < n ; ++i)
588                 dim = set_name(dim, type, first+i, names[i]);
589
590         return dim;
591 }
592
593 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
594 {
595         unsigned t;
596         struct isl_name **names = NULL;
597
598         if (!dim)
599                 return NULL;
600         if (match(dim, isl_dim_in, dim, isl_dim_out))
601                 return dim;
602
603         dim = isl_dim_cow(dim);
604         if (!dim)
605                 return NULL;
606
607         if (dim->names) {
608                 names = isl_alloc_array(dim->ctx, struct isl_name *,
609                                         dim->n_in + dim->n_out);
610                 if (!names)
611                         goto error;
612                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
613                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
614         }
615
616         t = dim->n_in;
617         dim->n_in = dim->n_out;
618         dim->n_out = t;
619
620         if (dim->names) {
621                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
622                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
623                 free(names);
624         }
625
626         return dim;
627 error:
628         free(names);
629         isl_dim_free(dim);
630         return NULL;
631 }
632
633 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
634                 unsigned first, unsigned num)
635 {
636         int i;
637
638         if (!dim)
639                 return NULL;
640
641         if (n == 0)
642                 return dim;
643
644         isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
645         dim = isl_dim_cow(dim);
646         if (!dim)
647                 goto error;
648         if (dim->names) {
649                 dim = extend_names(dim);
650                 if (!dim)
651                         goto error;
652                 for (i = 0; i < num; ++i)
653                         isl_name_free(dim->ctx, get_name(dim, type, first+i));
654                 for (i = first+num; i < n(dim, type); ++i)
655                         set_name(dim, type, i - num, get_name(dim, type, i));
656                 switch (type) {
657                 case isl_dim_param:
658                         get_names(dim, isl_dim_in, 0, dim->n_in,
659                                 dim->names + offset(dim, isl_dim_in) - num);
660                 case isl_dim_in:
661                         get_names(dim, isl_dim_out, 0, dim->n_out,
662                                 dim->names + offset(dim, isl_dim_out) - num);
663                 case isl_dim_out:
664                         ;
665                 }
666                 dim->n_name -= num;
667         }
668         switch (type) {
669         case isl_dim_param:     dim->nparam -= num; break;
670         case isl_dim_in:        dim->n_in -= num; break;
671         case isl_dim_out:       dim->n_out -= num; break;
672         }
673         return dim;
674 error:
675         isl_dim_free(dim);
676         return NULL;
677 }
678
679 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
680                 unsigned first, unsigned n)
681 {
682         return isl_dim_drop(dim, isl_dim_in, first, n);
683 }
684
685 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
686                 unsigned first, unsigned n)
687 {
688         return isl_dim_drop(dim, isl_dim_out, first, n);
689 }
690
691 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
692 {
693         if (!dim)
694                 return NULL;
695         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
696         return isl_dim_reverse(dim);
697 }
698
699 struct isl_dim *isl_dim_range(struct isl_dim *dim)
700 {
701         if (!dim)
702                 return NULL;
703         return isl_dim_drop_inputs(dim, 0, dim->n_in);
704 }
705
706 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
707 {
708         int i;
709
710         if (!dim)
711                 return NULL;
712         if (n_div == 0 &&
713             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
714                 return dim;
715         dim = isl_dim_cow(dim);
716         if (!dim)
717                 return NULL;
718         dim->n_out += dim->nparam + dim->n_in + n_div;
719         dim->nparam = 0;
720         dim->n_in = 0;
721
722         for (i = 0; i < dim->n_name; ++i)
723                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
724         dim->n_name = 0;
725
726         return dim;
727 }
728
729 unsigned isl_dim_total(struct isl_dim *dim)
730 {
731         return dim->nparam + dim->n_in + dim->n_out;
732 }
733
734 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
735 {
736         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
737                n(dim1, isl_dim_in) == n(dim2, isl_dim_in) &&
738                n(dim1, isl_dim_out) == n(dim2, isl_dim_out);
739 }
740
741 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
742 {
743         return dim1->nparam == dim2->nparam &&
744                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
745 }