298aa80cabeba30b9a39468321250434c1fedc52
[platform/upstream/isl.git] / isl_dim.c
1 #include "isl_dim.h"
2 #include "isl_name.h"
3
4 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
5                         unsigned nparam, unsigned n_in, unsigned n_out)
6 {
7         struct isl_dim *dim;
8
9         dim = isl_alloc_type(ctx, struct isl_dim);
10         if (!dim)
11                 return NULL;
12
13         dim->ctx = ctx;
14         isl_ctx_ref(ctx);
15         dim->ref = 1;
16         dim->nparam = nparam;
17         dim->n_in = n_in;
18         dim->n_out = n_out;
19
20         dim->n_name = 0;
21         dim->names = NULL;
22
23         return dim;
24 }
25
/* Allocate a dimension description for a set: "dim" set dimensions are
 * stored as output dimensions and there are no input dimensions.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	struct isl_dim *space;

	space = isl_dim_alloc(ctx, nparam, /* n_in */ 0, /* n_out */ dim);
	return space;
}
31
32 static unsigned global_pos(struct isl_dim *dim,
33                                  enum isl_dim_type type, unsigned pos)
34 {
35         struct isl_ctx *ctx = dim->ctx;
36
37         switch (type) {
38         case isl_dim_param:
39                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
40                 return pos;
41         case isl_dim_in:
42                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
43                 return pos + dim->nparam;
44         case isl_dim_out:
45                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
46                 return pos + dim->nparam + dim->n_in;
47         default:
48                 isl_assert(ctx, 0, goto error);
49         }
50         return isl_dim_total(dim);
51 }
52
53 static struct isl_dim *set_name(struct isl_dim *dim,
54                                  enum isl_dim_type type, unsigned pos,
55                                  struct isl_name *name)
56 {
57         struct isl_ctx *ctx = dim->ctx;
58         dim = isl_dim_cow(dim);
59
60         if (!dim)
61                 goto error;
62
63         pos = global_pos(dim, type, pos);
64         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
65
66         if (pos >= dim->n_name) {
67                 if (!name)
68                         return dim;
69                 if (!dim->names) {
70                         dim->names = isl_calloc_array(dim->ctx,
71                                         struct isl_name *, isl_dim_total(dim));
72                         if (!dim->names)
73                                 goto error;
74                 } else {
75                         int i;
76                         dim->names = isl_realloc_array(dim->ctx, dim->names,
77                                         struct isl_name *, isl_dim_total(dim));
78                         if (!dim->names)
79                                 goto error;
80                         for (i = dim->n_name; i < isl_dim_total(dim); ++i)
81                                 dim->names[i] = NULL;
82                 }
83                 dim->n_name = isl_dim_total(dim);
84         }
85
86         dim->names[pos] = name;
87
88         return dim;
89 error:
90         isl_name_free(ctx, name);
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_name *get_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos)
97 {
98         if (!dim)
99                 return NULL;
100
101         pos = global_pos(dim, type, pos);
102         if (pos == isl_dim_total(dim))
103                 return NULL;
104         if (pos >= dim->n_name)
105                 return NULL;
106         return dim->names[pos];
107 }
108
109 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
110 {
111         switch (type) {
112         case isl_dim_param:     return dim->nparam;
113         case isl_dim_in:        return dim->n_in;
114         case isl_dim_out:       return dim->n_out;
115         }
116 }
117
118 unsigned isl_dim_size(struct isl_dim *dim, enum isl_dim_type type)
119 {
120         return n(dim, type);
121 }
122
123 static struct isl_dim *copy_names(struct isl_dim *dst,
124         enum isl_dim_type dst_type, struct isl_dim *src,
125         enum isl_dim_type src_type)
126 {
127         int i;
128         struct isl_name *name;
129
130         for (i = 0; i < n(dst, dst_type); ++i) {
131                 name = get_name(src, src_type, i);
132                 if (!name)
133                         continue;
134                 dst = set_name(dst, dst_type, i, isl_name_copy(dst->ctx, name));
135                 if (!dst)
136                         return NULL;
137         }
138         return dst;
139 }
140
141 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
142 {
143         struct isl_dim *dup;
144         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
145         if (!dim->names)
146                 return dup;
147         dup = copy_names(dup, isl_dim_param, dim, isl_dim_param);
148         dup = copy_names(dup, isl_dim_in, dim, isl_dim_in);
149         dup = copy_names(dup, isl_dim_out, dim, isl_dim_out);
150         return dup;
151 }
152
153 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
154 {
155         if (!dim)
156                 return NULL;
157
158         if (dim->ref == 1)
159                 return dim;
160         dim->ref--;
161         return isl_dim_dup(dim);
162 }
163
164 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
165 {
166         if (!dim)
167                 return NULL;
168
169         dim->ref++;
170         return dim;
171 }
172
173 void isl_dim_free(struct isl_dim *dim)
174 {
175         int i;
176
177         if (!dim)
178                 return;
179
180         if (--dim->ref > 0)
181                 return;
182
183         for (i = 0; i < dim->n_name; ++i)
184                 isl_name_free(dim->ctx, dim->names[i]);
185         free(dim->names);
186         isl_ctx_deref(dim->ctx);
187         
188         free(dim);
189 }
190
191 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
192                                  enum isl_dim_type type, unsigned pos,
193                                  const char *s)
194 {
195         struct isl_name *name;
196         if (!dim)
197                 return NULL;
198         name = isl_name_get(dim->ctx, s);
199         if (!name)
200                 goto error;
201         return set_name(dim, type, pos, name);
202 error:
203         isl_dim_free(dim);
204         return NULL;
205 }
206
207 const char *isl_dim_get_name(struct isl_dim *dim,
208                                  enum isl_dim_type type, unsigned pos)
209 {
210         struct isl_name *name = get_name(dim, type, pos);
211         return name ? name->name : NULL;
212 }
213
214 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
215                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
216 {
217         int i;
218
219         if (n(dim1, dim1_type) != n(dim2, dim2_type))
220                 return 0;
221
222         if (!dim1->names && !dim2->names)
223                 return 1;
224
225         for (i = 0; i < n(dim1, dim1_type); ++i) {
226                 if (get_name(dim1, dim1_type, i) !=
227                     get_name(dim2, dim2_type, i))
228                         return 0;
229         }
230         return 1;
231 }
232
233 int isl_dim_match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
234                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
235 {
236         return match(dim1, dim1_type, dim2, dim2_type);
237 }
238
239 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
240         unsigned first, unsigned n, struct isl_name **names)
241 {
242         int i;
243
244         for (i = 0; i < n ; ++i)
245                 names[i] = get_name(dim, type, first+i);
246 }
247
248 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
249                         unsigned nparam, unsigned n_in, unsigned n_out)
250 {
251         struct isl_name **names = NULL;
252
253         if (!dim)
254                 return NULL;
255         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
256                 return dim;
257
258         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
259         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
260         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
261
262         dim = isl_dim_cow(dim);
263
264         if (dim->names) {
265                 names = isl_calloc_array(dim->ctx, struct isl_name *,
266                                          nparam + n_in + n_out);
267                 if (!names)
268                         goto error;
269                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
270                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
271                 get_names(dim, isl_dim_out, 0, dim->n_out,
272                                 names + nparam + n_in);
273                 free(dim->names);
274                 dim->names = names;
275                 dim->n_name = nparam + n_in + n_out;
276         }
277         dim->nparam = nparam;
278         dim->n_in = n_in;
279         dim->n_out = n_out;
280
281         return dim;
282 error:
283         free(names);
284         isl_dim_free(dim);
285         return NULL;
286 }
287
288 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
289 {
290         struct isl_dim *dim;
291
292         if (!left || !right)
293                 goto error;
294
295         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
296                         goto error);
297         isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
298                         goto error);
299
300         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
301         if (!dim)
302                 goto error;
303
304         dim = copy_names(dim, isl_dim_param, left, isl_dim_param);
305         dim = copy_names(dim, isl_dim_in, left, isl_dim_in);
306         dim = copy_names(dim, isl_dim_out, right, isl_dim_out);
307
308         isl_dim_free(left);
309         isl_dim_free(right);
310
311         return dim;
312 error:
313         isl_dim_free(left);
314         isl_dim_free(right);
315         return NULL;
316 }
317
318 struct isl_dim *isl_dim_map(struct isl_dim *dim)
319 {
320         struct isl_name **names = NULL;
321
322         if (!dim)
323                 return NULL;
324         isl_assert(dim->ctx, dim->n_in == 0, goto error);
325         if (dim->n_out == 0)
326                 return dim;
327         dim = isl_dim_cow(dim);
328         if (!dim)
329                 return NULL;
330         if (dim->names) {
331                 names = isl_calloc_array(dim->ctx, struct isl_name *,
332                                         dim->nparam + dim->n_out + dim->n_out);
333                 if (!names)
334                         goto error;
335                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
336                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
337         }
338         dim->n_in = dim->n_out;
339         if (names) {
340                 copy_names(dim, isl_dim_out, dim, isl_dim_in);
341                 free(dim->names);
342                 dim->names = names;
343                 dim->n_name = dim->nparam + dim->n_out + dim->n_out;
344         }
345         return dim;
346 error:
347         isl_dim_free(dim);
348         return NULL;
349 }
350
351 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
352         unsigned first, unsigned n, struct isl_name **names)
353 {
354         int i;
355
356         for (i = 0; i < n ; ++i)
357                 dim = set_name(dim, type, first+i, names[i]);
358
359         return dim;
360 }
361
362 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
363 {
364         unsigned t;
365         struct isl_name **names = NULL;
366
367         if (!dim)
368                 return NULL;
369         if (match(dim, isl_dim_in, dim, isl_dim_out))
370                 return dim;
371
372         dim = isl_dim_cow(dim);
373         if (!dim)
374                 return NULL;
375
376         if (dim->names) {
377                 names = isl_alloc_array(dim->ctx, struct isl_name *,
378                                         dim->n_in + dim->n_out);
379                 if (!names)
380                         goto error;
381                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
382                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
383         }
384
385         t = dim->n_in;
386         dim->n_in = dim->n_out;
387         dim->n_out = t;
388
389         if (dim->names) {
390                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
391                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
392                 free(names);
393         }
394
395         return dim;
396 error:
397         free(names);
398         isl_dim_free(dim);
399         return NULL;
400 }
401
402 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
403                 unsigned first, unsigned n)
404 {
405         int i;
406
407         if (!dim)
408                 return NULL;
409
410         if (n == 0)
411                 return dim;
412
413         isl_assert(dim->ctx, first + n <= dim->n_in, goto error);
414         dim = isl_dim_cow(dim);
415         if (!dim)
416                 goto error;
417         if (dim->names) {
418                 for (i = 0; i < n; ++i) {
419                         isl_name_free(dim->ctx,
420                                         get_name(dim, isl_dim_in, first+i));
421                 }
422                 for (i = first+n; i < dim->n_in; ++i)
423                         set_name(dim, isl_dim_in, i - n,
424                                 get_name(dim, isl_dim_in, i));
425                 get_names(dim, isl_dim_out, 0, dim->n_out,
426                                 dim->names + dim->nparam + dim->n_in - n);
427         }
428         dim->n_in -= n;
429         return dim;
430 error:
431         isl_dim_free(dim);
432         return NULL;
433 }
434
435 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
436                 unsigned first, unsigned n)
437 {
438         int i;
439
440         if (!dim)
441                 return NULL;
442
443         if (n == 0)
444                 return dim;
445
446         isl_assert(dim->ctx, first + n <= dim->n_out, goto error);
447         dim = isl_dim_cow(dim);
448         if (!dim)
449                 goto error;
450         if (dim->names) {
451                 for (i = 0; i < n; ++i) {
452                         isl_name_free(dim->ctx,
453                                         get_name(dim, isl_dim_out, first+i));
454                 }
455                 for (i = first+n; i < dim->n_out; ++i)
456                         set_name(dim, isl_dim_out, i - n,
457                                 get_name(dim, isl_dim_out, i));
458         }
459         dim->n_out -= n;
460         return dim;
461 error:
462         isl_dim_free(dim);
463         return NULL;
464 }
465
466 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
467 {
468         if (!dim)
469                 return NULL;
470         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
471         return isl_dim_reverse(dim);
472 }
473
474 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
475 {
476         int i;
477
478         if (!dim)
479                 return NULL;
480         if (n_div == 0 &&
481             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
482                 return dim;
483         dim = isl_dim_cow(dim);
484         if (!dim)
485                 return NULL;
486         dim->n_out += dim->nparam + dim->n_in + n_div;
487         dim->nparam = 0;
488         dim->n_in = 0;
489
490         for (i = 0; i < dim->n_name; ++i)
491                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
492         dim->n_name = 0;
493
494         return dim;
495 }
496
497 unsigned isl_dim_total(struct isl_dim *dim)
498 {
499         return dim->nparam + dim->n_in + dim->n_out;
500 }
501
502 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
503 {
504         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
505                match(dim1, isl_dim_in, dim2, isl_dim_in) &&
506                match(dim1, isl_dim_out, dim2, isl_dim_out);
507 }
508
509 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
510 {
511         return dim1->nparam == dim2->nparam &&
512                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
513 }