a5381dd1538df1db8f900b3b731f77c3ab59efa2
[platform/upstream/isl.git] / isl_dim.c
1 #include "isl_dim.h"
2 #include "isl_name.h"
3
4 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
5                         unsigned nparam, unsigned n_in, unsigned n_out)
6 {
7         struct isl_dim *dim;
8
9         dim = isl_alloc_type(ctx, struct isl_dim);
10         if (!dim)
11                 return NULL;
12
13         dim->ctx = ctx;
14         isl_ctx_ref(ctx);
15         dim->ref = 1;
16         dim->nparam = nparam;
17         dim->n_in = n_in;
18         dim->n_out = n_out;
19
20         dim->n_name = 0;
21         dim->names = NULL;
22
23         return dim;
24 }
25
/*
 * Allocate a dimension specification with "nparam" parameters,
 * no input dimensions and "dim" output dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	const unsigned n_in = 0;

	return isl_dim_alloc(ctx, nparam, n_in, dim);
}
31
32 static unsigned global_pos(struct isl_dim *dim,
33                                  enum isl_dim_type type, unsigned pos)
34 {
35         struct isl_ctx *ctx = dim->ctx;
36
37         switch (type) {
38         case isl_dim_param:
39                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
40                 return pos;
41         case isl_dim_in:
42                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
43                 return pos + dim->nparam;
44         case isl_dim_out:
45                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
46                 return pos + dim->nparam + dim->n_in;
47         default:
48                 isl_assert(ctx, 0, goto error);
49         }
50         return isl_dim_total(dim);
51 }
52
53 static struct isl_dim *set_name(struct isl_dim *dim,
54                                  enum isl_dim_type type, unsigned pos,
55                                  struct isl_name *name)
56 {
57         struct isl_ctx *ctx = dim->ctx;
58         dim = isl_dim_cow(dim);
59
60         if (!dim)
61                 goto error;
62
63         pos = global_pos(dim, type, pos);
64         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
65
66         if (pos >= dim->n_name) {
67                 if (!name)
68                         return dim;
69                 if (!dim->names) {
70                         dim->names = isl_calloc_array(dim->ctx,
71                                         struct isl_name *, isl_dim_total(dim));
72                         if (!dim->names)
73                                 goto error;
74                 } else {
75                         int i;
76                         dim->names = isl_realloc_array(dim->ctx, dim->names,
77                                         struct isl_name *, isl_dim_total(dim));
78                         if (!dim->names)
79                                 goto error;
80                         for (i = dim->n_name; i < isl_dim_total(dim); ++i)
81                                 dim->names[i] = NULL;
82                 }
83                 dim->n_name = isl_dim_total(dim);
84         }
85
86         dim->names[pos] = name;
87
88         return dim;
89 error:
90         isl_name_free(ctx, name);
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_name *get_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos)
97 {
98         if (!dim)
99                 return NULL;
100
101         pos = global_pos(dim, type, pos);
102         if (pos == isl_dim_total(dim))
103                 return NULL;
104         if (pos >= dim->n_name)
105                 return NULL;
106         return dim->names[pos];
107 }
108
109 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
110 {
111         switch (type) {
112         case isl_dim_param:     return dim->nparam;
113         case isl_dim_in:        return dim->n_in;
114         case isl_dim_out:       return dim->n_out;
115         }
116 }
117
118 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
119 {
120         return n(dim, type);
121 }
122
123 static struct isl_dim *copy_names(struct isl_dim *dst,
124         enum isl_dim_type dst_type, unsigned offset, struct isl_dim *src,
125         enum isl_dim_type src_type)
126 {
127         int i;
128         struct isl_name *name;
129
130         for (i = 0; i < n(src, src_type); ++i) {
131                 name = get_name(src, src_type, i);
132                 if (!name)
133                         continue;
134                 dst = set_name(dst, dst_type, offset + i,
135                                         isl_name_copy(dst->ctx, name));
136                 if (!dst)
137                         return NULL;
138         }
139         return dst;
140 }
141
142 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
143 {
144         struct isl_dim *dup;
145         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
146         if (!dim->names)
147                 return dup;
148         dup = copy_names(dup, isl_dim_param, 0, dim, isl_dim_param);
149         dup = copy_names(dup, isl_dim_in, 0, dim, isl_dim_in);
150         dup = copy_names(dup, isl_dim_out, 0, dim, isl_dim_out);
151         return dup;
152 }
153
154 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
155 {
156         if (!dim)
157                 return NULL;
158
159         if (dim->ref == 1)
160                 return dim;
161         dim->ref--;
162         return isl_dim_dup(dim);
163 }
164
165 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
166 {
167         if (!dim)
168                 return NULL;
169
170         dim->ref++;
171         return dim;
172 }
173
174 void isl_dim_free(struct isl_dim *dim)
175 {
176         int i;
177
178         if (!dim)
179                 return;
180
181         if (--dim->ref > 0)
182                 return;
183
184         for (i = 0; i < dim->n_name; ++i)
185                 isl_name_free(dim->ctx, dim->names[i]);
186         free(dim->names);
187         isl_ctx_deref(dim->ctx);
188         
189         free(dim);
190 }
191
192 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
193                                  enum isl_dim_type type, unsigned pos,
194                                  const char *s)
195 {
196         struct isl_name *name;
197         if (!dim)
198                 return NULL;
199         name = isl_name_get(dim->ctx, s);
200         if (!name)
201                 goto error;
202         return set_name(dim, type, pos, name);
203 error:
204         isl_dim_free(dim);
205         return NULL;
206 }
207
208 const char *isl_dim_get_name(struct isl_dim *dim,
209                                  enum isl_dim_type type, unsigned pos)
210 {
211         struct isl_name *name = get_name(dim, type, pos);
212         return name ? name->name : NULL;
213 }
214
215 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
216                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
217 {
218         int i;
219
220         if (n(dim1, dim1_type) != n(dim2, dim2_type))
221                 return 0;
222
223         if (!dim1->names && !dim2->names)
224                 return 1;
225
226         for (i = 0; i < n(dim1, dim1_type); ++i) {
227                 if (get_name(dim1, dim1_type, i) !=
228                     get_name(dim2, dim2_type, i))
229                         return 0;
230         }
231         return 1;
232 }
233
234 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
235                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
236 {
237         return match(dim1, dim1_type, dim2, dim2_type);
238 }
239
240 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
241         unsigned first, unsigned n, struct isl_name **names)
242 {
243         int i;
244
245         for (i = 0; i < n ; ++i)
246                 names[i] = get_name(dim, type, first+i);
247 }
248
249 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
250                         unsigned nparam, unsigned n_in, unsigned n_out)
251 {
252         struct isl_name **names = NULL;
253
254         if (!dim)
255                 return NULL;
256         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
257                 return dim;
258
259         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
260         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
261         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
262
263         dim = isl_dim_cow(dim);
264
265         if (dim->names) {
266                 names = isl_calloc_array(dim->ctx, struct isl_name *,
267                                          nparam + n_in + n_out);
268                 if (!names)
269                         goto error;
270                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
271                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
272                 get_names(dim, isl_dim_out, 0, dim->n_out,
273                                 names + nparam + n_in);
274                 free(dim->names);
275                 dim->names = names;
276                 dim->n_name = nparam + n_in + n_out;
277         }
278         dim->nparam = nparam;
279         dim->n_in = n_in;
280         dim->n_out = n_out;
281
282         return dim;
283 error:
284         free(names);
285         isl_dim_free(dim);
286         return NULL;
287 }
288
289 struct isl_dim *isl_dim_add(struct isl_dim *dim, enum isl_dim_type type,
290         unsigned n)
291 {
292         switch (type) {
293         case isl_dim_param:
294                 return isl_dim_extend(dim,
295                                         dim->nparam + n, dim->n_in, dim->n_out);
296         case isl_dim_in:
297                 return isl_dim_extend(dim,
298                                         dim->nparam, dim->n_in + n, dim->n_out);
299         case isl_dim_out:
300                 return isl_dim_extend(dim,
301                                         dim->nparam, dim->n_in, dim->n_out + n);
302         }
303         return dim;
304 }
305
306 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
307 {
308         struct isl_dim *dim;
309
310         if (!left || !right)
311                 goto error;
312
313         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
314                         goto error);
315         isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
316                         goto error);
317
318         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
319         if (!dim)
320                 goto error;
321
322         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
323         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
324         dim = copy_names(dim, isl_dim_out, 0, right, isl_dim_out);
325
326         isl_dim_free(left);
327         isl_dim_free(right);
328
329         return dim;
330 error:
331         isl_dim_free(left);
332         isl_dim_free(right);
333         return NULL;
334 }
335
336 struct isl_dim *isl_dim_product(struct isl_dim *left, struct isl_dim *right)
337 {
338         struct isl_dim *dim;
339
340         if (!left || !right)
341                 goto error;
342
343         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
344                         goto error);
345
346         dim = isl_dim_alloc(left->ctx, left->nparam,
347                         left->n_in + right->n_in, left->n_out + right->n_out);
348         if (!dim)
349                 goto error;
350
351         dim = copy_names(dim, isl_dim_param, 0, left, isl_dim_param);
352         dim = copy_names(dim, isl_dim_in, 0, left, isl_dim_in);
353         dim = copy_names(dim, isl_dim_in, left->n_in, right, isl_dim_in);
354         dim = copy_names(dim, isl_dim_out, 0, left, isl_dim_out);
355         dim = copy_names(dim, isl_dim_out, left->n_out, right, isl_dim_out);
356
357         isl_dim_free(left);
358         isl_dim_free(right);
359
360         return dim;
361 error:
362         isl_dim_free(left);
363         isl_dim_free(right);
364         return NULL;
365 }
366
367 struct isl_dim *isl_dim_map(struct isl_dim *dim)
368 {
369         struct isl_name **names = NULL;
370
371         if (!dim)
372                 return NULL;
373         isl_assert(dim->ctx, dim->n_in == 0, goto error);
374         if (dim->n_out == 0)
375                 return dim;
376         dim = isl_dim_cow(dim);
377         if (!dim)
378                 return NULL;
379         if (dim->names) {
380                 names = isl_calloc_array(dim->ctx, struct isl_name *,
381                                         dim->nparam + dim->n_out + dim->n_out);
382                 if (!names)
383                         goto error;
384                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
385                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
386         }
387         dim->n_in = dim->n_out;
388         if (names) {
389                 copy_names(dim, isl_dim_out, 0, dim, isl_dim_in);
390                 free(dim->names);
391                 dim->names = names;
392                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
393         }
394         return dim;
395 error:
396         isl_dim_free(dim);
397         return NULL;
398 }
399
400 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
401         unsigned first, unsigned n, struct isl_name **names)
402 {
403         int i;
404
405         for (i = 0; i < n ; ++i)
406                 dim = set_name(dim, type, first+i, names[i]);
407
408         return dim;
409 }
410
411 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
412 {
413         unsigned t;
414         struct isl_name **names = NULL;
415
416         if (!dim)
417                 return NULL;
418         if (match(dim, isl_dim_in, dim, isl_dim_out))
419                 return dim;
420
421         dim = isl_dim_cow(dim);
422         if (!dim)
423                 return NULL;
424
425         if (dim->names) {
426                 names = isl_alloc_array(dim->ctx, struct isl_name *,
427                                         dim->n_in + dim->n_out);
428                 if (!names)
429                         goto error;
430                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
431                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
432         }
433
434         t = dim->n_in;
435         dim->n_in = dim->n_out;
436         dim->n_out = t;
437
438         if (dim->names) {
439                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
440                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
441                 free(names);
442         }
443
444         return dim;
445 error:
446         free(names);
447         isl_dim_free(dim);
448         return NULL;
449 }
450
451 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
452                 unsigned first, unsigned n)
453 {
454         int i;
455
456         if (!dim)
457                 return NULL;
458
459         if (n == 0)
460                 return dim;
461
462         isl_assert(dim->ctx, first + n <= dim->n_in, goto error);
463         dim = isl_dim_cow(dim);
464         if (!dim)
465                 goto error;
466         if (dim->names) {
467                 for (i = 0; i < n; ++i) {
468                         isl_name_free(dim->ctx,
469                                         get_name(dim, isl_dim_in, first+i));
470                 }
471                 for (i = first+n; i < dim->n_in; ++i)
472                         set_name(dim, isl_dim_in, i - n,
473                                 get_name(dim, isl_dim_in, i));
474                 get_names(dim, isl_dim_out, 0, dim->n_out,
475                                 dim->names + dim->nparam + dim->n_in - n);
476         }
477         dim->n_in -= n;
478         return dim;
479 error:
480         isl_dim_free(dim);
481         return NULL;
482 }
483
484 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
485                 unsigned first, unsigned n)
486 {
487         int i;
488
489         if (!dim)
490                 return NULL;
491
492         if (n == 0)
493                 return dim;
494
495         isl_assert(dim->ctx, first + n <= dim->n_out, goto error);
496         dim = isl_dim_cow(dim);
497         if (!dim)
498                 goto error;
499         if (dim->names) {
500                 for (i = 0; i < n; ++i) {
501                         isl_name_free(dim->ctx,
502                                         get_name(dim, isl_dim_out, first+i));
503                 }
504                 for (i = first+n; i < dim->n_out; ++i)
505                         set_name(dim, isl_dim_out, i - n,
506                                 get_name(dim, isl_dim_out, i));
507         }
508         dim->n_out -= n;
509         return dim;
510 error:
511         isl_dim_free(dim);
512         return NULL;
513 }
514
515 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
516 {
517         if (!dim)
518                 return NULL;
519         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
520         return isl_dim_reverse(dim);
521 }
522
523 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
524 {
525         int i;
526
527         if (!dim)
528                 return NULL;
529         if (n_div == 0 &&
530             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
531                 return dim;
532         dim = isl_dim_cow(dim);
533         if (!dim)
534                 return NULL;
535         dim->n_out += dim->nparam + dim->n_in + n_div;
536         dim->nparam = 0;
537         dim->n_in = 0;
538
539         for (i = 0; i < dim->n_name; ++i)
540                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
541         dim->n_name = 0;
542
543         return dim;
544 }
545
546 unsigned isl_dim_total(struct isl_dim *dim)
547 {
548         return dim->nparam + dim->n_in + dim->n_out;
549 }
550
551 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
552 {
553         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
554                match(dim1, isl_dim_in, dim2, isl_dim_in) &&
555                match(dim1, isl_dim_out, dim2, isl_dim_out);
556 }
557
558 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
559 {
560         return dim1->nparam == dim2->nparam &&
561                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
562 }