isl_assert: validate all arguments and fix up fallout
[platform/upstream/isl.git] / isl_dim.c
1 #include "isl_dim.h"
2 #include "isl_name.h"
3
4 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
5                         unsigned nparam, unsigned n_in, unsigned n_out)
6 {
7         struct isl_dim *dim;
8
9         dim = isl_alloc_type(ctx, struct isl_dim);
10         if (!dim)
11                 return NULL;
12
13         dim->ctx = ctx;
14         isl_ctx_ref(ctx);
15         dim->ref = 1;
16         dim->nparam = nparam;
17         dim->n_in = n_in;
18         dim->n_out = n_out;
19
20         dim->n_name = 0;
21         dim->names = NULL;
22
23         return dim;
24 }
25
/* Allocate a dimension specification for a set with "nparam" parameters
 * and "dim" set dimensions.  Set dimensions are stored as output
 * dimensions; the input part is empty.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	const unsigned n_in = 0;

	return isl_dim_alloc(ctx, nparam, n_in, dim);
}
31
32 static unsigned global_pos(struct isl_dim *dim,
33                                  enum isl_dim_type type, unsigned pos)
34 {
35         struct isl_ctx *ctx = dim->ctx;
36
37         switch (type) {
38         case isl_dim_param:
39                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
40                 return pos;
41         case isl_dim_in:
42                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
43                 return pos + dim->nparam;
44         case isl_dim_out:
45                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
46                 return pos + dim->nparam + dim->n_in;
47         default:
48                 isl_assert(ctx, 0, return isl_dim_total(dim));
49         }
50         return isl_dim_total(dim);
51 }
52
53 static struct isl_dim *set_name(struct isl_dim *dim,
54                                  enum isl_dim_type type, unsigned pos,
55                                  struct isl_name *name)
56 {
57         struct isl_ctx *ctx = dim->ctx;
58         dim = isl_dim_cow(dim);
59
60         if (!dim)
61                 goto error;
62
63         pos = global_pos(dim, type, pos);
64         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
65
66         if (pos >= dim->n_name) {
67                 if (!name)
68                         return dim;
69                 if (!dim->names) {
70                         dim->names = isl_calloc_array(dim->ctx,
71                                         struct isl_name *, isl_dim_total(dim));
72                         if (!dim->names)
73                                 goto error;
74                 } else {
75                         int i;
76                         dim->names = isl_realloc_array(dim->ctx, dim->names,
77                                         struct isl_name *, isl_dim_total(dim));
78                         if (!dim->names)
79                                 goto error;
80                         for (i = dim->n_name; i < isl_dim_total(dim); ++i)
81                                 dim->names[i] = NULL;
82                 }
83                 dim->n_name = isl_dim_total(dim);
84         }
85
86         dim->names[pos] = name;
87
88         return dim;
89 error:
90         isl_name_free(ctx, name);
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_name *get_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos)
97 {
98         if (!dim)
99                 return NULL;
100
101         pos = global_pos(dim, type, pos);
102         if (pos == isl_dim_total(dim))
103                 return NULL;
104         if (pos >= dim->n_name)
105                 return NULL;
106         return dim->names[pos];
107 }
108
109 static unsigned offset(struct isl_dim *dim, enum isl_dim_type type)
110 {
111         switch (type) {
112         case isl_dim_param:     return 0;
113         case isl_dim_in:        return dim->nparam;
114         case isl_dim_out:       return dim->nparam + dim->n_in;
115         }
116 }
117
118 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
119 {
120         switch (type) {
121         case isl_dim_param:     return dim->nparam;
122         case isl_dim_in:        return dim->n_in;
123         case isl_dim_out:       return dim->n_out;
124         }
125 }
126
127 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
128 {
129         return n(dim, type);
130 }
131
132 static struct isl_dim *copy_names(struct isl_dim *dst,
133         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
134         enum isl_dim_type src_type)
135 {
136         int i;
137         struct isl_name *name;
138
139         for (i = 0; i < n(src, src_type); ++i) {
140                 name = get_name(src, src_type, i);
141                 if (!name)
142                         continue;
143                 dst = set_name(dst, dst_type, offset + i,
144                                         isl_name_copy(dst->ctx, name));
145                 if (!dst)
146                         return NULL;
147         }
148         return dst;
149 }
150
151 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
152 {
153         struct isl_dim *dup;
154         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
155         if (!dim->names)
156                 return dup;
157         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
158         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
159         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
160         return dup;
161 }
162
163 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
164 {
165         if (!dim)
166                 return NULL;
167
168         if (dim->ref == 1)
169                 return dim;
170         dim->ref--;
171         return isl_dim_dup(dim);
172 }
173
174 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
175 {
176         if (!dim)
177                 return NULL;
178
179         dim->ref++;
180         return dim;
181 }
182
183 void isl_dim_free(struct isl_dim *dim)
184 {
185         int i;
186
187         if (!dim)
188                 return;
189
190         if (--dim->ref > 0)
191                 return;
192
193         for (i = 0; i < dim->n_name; ++i)
194                 isl_name_free(dim->ctx, dim->names[i]);
195         free(dim->names);
196         isl_ctx_deref(dim->ctx);
197         
198         free(dim);
199 }
200
201 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
202                                  enum isl_dim_type type, unsigned pos,
203                                  const char *s)
204 {
205         struct isl_name *name;
206         if (!dim)
207                 return NULL;
208         name = isl_name_get(dim->ctx, s);
209         if (!name)
210                 goto error;
211         return set_name(dim, type, pos, name);
212 error:
213         isl_dim_free(dim);
214         return NULL;
215 }
216
217 const char *isl_dim_get_name(struct isl_dim *dim,
218                                  enum isl_dim_type type, unsigned pos)
219 {
220         struct isl_name *name = get_name(dim, type, pos);
221         return name ? name->name : NULL;
222 }
223
224 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
225                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
226 {
227         int i;
228
229         if (n(dim1, dim1_type) != n(dim2, dim2_type))
230                 return 0;
231
232         if (!dim1->names && !dim2->names)
233                 return 1;
234
235         for (i = 0; i < n(dim1, dim1_type); ++i) {
236                 if (get_name(dim1, dim1_type, i) !=
237                     get_name(dim2, dim2_type, i))
238                         return 0;
239         }
240         return 1;
241 }
242
243 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
244                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
245 {
246         return match(dim1, dim1_type, dim2, dim2_type);
247 }
248
249 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
250         unsigned first, unsigned n, struct isl_name **names)
251 {
252         int i;
253
254         for (i = 0; i < n ; ++i)
255                 names[i] = get_name(dim, type, first+i);
256 }
257
258 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
259                         unsigned nparam, unsigned n_in, unsigned n_out)
260 {
261         struct isl_name **names = NULL;
262
263         if (!dim)
264                 return NULL;
265         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
266                 return dim;
267
268         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
269         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
270         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
271
272         dim = isl_dim_cow(dim);
273
274         if (dim->names) {
275                 names = isl_calloc_array(dim->ctx, struct isl_name *,
276                                          nparam + n_in + n_out);
277                 if (!names)
278                         goto error;
279                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
280                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
281                 get_names(dim, isl_dim_out, 0, dim->n_out,
282                                 names + nparam + n_in);
283                 free(dim->names);
284                 dim->names = names;
285                 dim->n_name = nparam + n_in + n_out;
286         }
287         dim->nparam = nparam;
288         dim->n_in = n_in;
289         dim->n_out = n_out;
290
291         return dim;
292 error:
293         free(names);
294         isl_dim_free(dim);
295         return NULL;
296 }
297
298 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
299         unsigned n)
300 {
301         switch (type) {
302         case isl_dim_param:
303                 return isl_dim_extend(dim,
304                                         dim->nparam + n, dim->n_in, dim->n_out);
305         case isl_dim_in:
306                 return isl_dim_extend(dim,
307                                         dim->nparam, dim->n_in + n, dim->n_out);
308         case isl_dim_out:
309                 return isl_dim_extend(dim,
310                                         dim->nparam, dim->n_in, dim->n_out + n);
311         }
312         return dim;
313 }
314
315 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
316 {
317         struct isl_dim *dim;
318
319         if (!left || !right)
320                 goto error;
321
322         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
323                         goto error);
324         isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
325                         goto error);
326
327         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
328         if (!dim)
329                 goto error;
330
331         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
332         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
333         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
334
335         isl_dim_free(left);
336         isl_dim_free(right);
337
338         return dim;
339 error:
340         isl_dim_free(left);
341         isl_dim_free(right);
342         return NULL;
343 }
344
345 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
346 {
347         struct isl_dim *dim;
348
349         if (!left || !right)
350                 goto error;
351
352         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
353                         goto error);
354
355         dim = isl_dim_alloc(left->ctx, left->nparam,
356                         left->n_in + right->n_in, left->n_out + right->n_out);
357         if (!dim)
358                 goto error;
359
360         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
361         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
362         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
363         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
364         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
365
366         isl_dim_free(left);
367         isl_dim_free(right);
368
369         return dim;
370 error:
371         isl_dim_free(left);
372         isl_dim_free(right);
373         return NULL;
374 }
375
376 struct isl_dim *isl_dim_map(struct isl_dim *dim)
377 {
378         struct isl_name **names = NULL;
379
380         if (!dim)
381                 return NULL;
382         isl_assert(dim->ctx, dim->n_in == 0, goto error);
383         if (dim->n_out == 0)
384                 return dim;
385         dim = isl_dim_cow(dim);
386         if (!dim)
387                 return NULL;
388         if (dim->names) {
389                 names = isl_calloc_array(dim->ctx, struct isl_name *,
390                                         dim->nparam + dim->n_out + dim->n_out);
391                 if (!names)
392                         goto error;
393                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
394                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
395         }
396         dim->n_in = dim->n_out;
397         if (names) {
398                 copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
399                 free(dim->names);
400                 dim->names = names;
401                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
402         }
403         return dim;
404 error:
405         isl_dim_free(dim);
406         return NULL;
407 }
408
409 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
410         unsigned first, unsigned n, struct isl_name **names)
411 {
412         int i;
413
414         for (i = 0; i < n ; ++i)
415                 dim = set_name(dim, type, first+i, names[i]);
416
417         return dim;
418 }
419
420 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
421 {
422         unsigned t;
423         struct isl_name **names = NULL;
424
425         if (!dim)
426                 return NULL;
427         if (match(dim, isl_dim_in, dim, isl_dim_out))
428                 return dim;
429
430         dim = isl_dim_cow(dim);
431         if (!dim)
432                 return NULL;
433
434         if (dim->names) {
435                 names = isl_alloc_array(dim->ctx, struct isl_name *,
436                                         dim->n_in + dim->n_out);
437                 if (!names)
438                         goto error;
439                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
440                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
441         }
442
443         t = dim->n_in;
444         dim->n_in = dim->n_out;
445         dim->n_out = t;
446
447         if (dim->names) {
448                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
449                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
450                 free(names);
451         }
452
453         return dim;
454 error:
455         free(names);
456         isl_dim_free(dim);
457         return NULL;
458 }
459
460 struct isl_dim *isl_dim_drop(struct isl_dim *dim, enum isl_dim_type type,
461                 unsigned first, unsigned num)
462 {
463         int i;
464
465         if (!dim)
466                 return NULL;
467
468         if (n == 0)
469                 return dim;
470
471         isl_assert(dim->ctx, first + num <= n(dim, type), goto error);
472         dim = isl_dim_cow(dim);
473         if (!dim)
474                 goto error;
475         if (dim->names) {
476                 for (i = 0; i < num; ++i)
477                         isl_name_free(dim->ctx, get_name(dim, type, first+i));
478                 for (i = first+num; i < n(dim, type); ++i)
479                         set_name(dim, type, i - num, get_name(dim, type, i));
480                 switch (type) {
481                 case isl_dim_param:
482                         get_names(dim, isl_dim_in, 0, dim->n_in,
483                                 dim->names + offset(dim, isl_dim_in) - num);
484                 case isl_dim_in:
485                         get_names(dim, isl_dim_out, 0, dim->n_out,
486                                 dim->names + offset(dim, isl_dim_out) - num);
487                 case isl_dim_out:
488                         ;
489                 }
490         }
491         switch (type) {
492         case isl_dim_param:     dim->nparam -= num; break;
493         case isl_dim_in:        dim->n_in -= num; break;
494         case isl_dim_out:       dim->n_out -= num; break;
495         }
496         return dim;
497 error:
498         isl_dim_free(dim);
499         return NULL;
500 }
501
502 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
503                 unsigned first, unsigned n)
504 {
505         return isl_dim_drop(dim, isl_dim_in, first, n);
506 }
507
508 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
509                 unsigned first, unsigned n)
510 {
511         return isl_dim_drop(dim, isl_dim_out, first, n);
512 }
513
514 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
515 {
516         if (!dim)
517                 return NULL;
518         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
519         return isl_dim_reverse(dim);
520 }
521
522 struct isl_dim *isl_dim_range(struct isl_dim *dim)
523 {
524         if (!dim)
525                 return NULL;
526         return isl_dim_drop_inputs(dim, 0, dim->n_in);
527 }
528
529 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
530 {
531         int i;
532
533         if (!dim)
534                 return NULL;
535         if (n_div == 0 &&
536             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
537                 return dim;
538         dim = isl_dim_cow(dim);
539         if (!dim)
540                 return NULL;
541         dim->n_out += dim->nparam + dim->n_in + n_div;
542         dim->nparam = 0;
543         dim->n_in = 0;
544
545         for (i = 0; i < dim->n_name; ++i)
546                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
547         dim->n_name = 0;
548
549         return dim;
550 }
551
552 unsigned isl_dim_total(struct isl_dim *dim)
553 {
554         return dim->nparam + dim->n_in + dim->n_out;
555 }
556
557 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
558 {
559         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
560                match(dim1, isl_dim_in, dim2, isl_dim_in) &&
561                match(dim1, isl_dim_out, dim2, isl_dim_out);
562 }
563
564 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
565 {
566         return dim1->nparam == dim2->nparam &&
567                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
568 }