61926ca609aba4c78595c84546f8b431209aaa33
[platform/upstream/isl.git] / isl_dim.c
1 #include "isl_dim.h"
2 #include "isl_name.h"
3
4 struct isl_dim *isl_dim_alloc(struct isl_ctx *ctx,
5                         unsigned nparam, unsigned n_in, unsigned n_out)
6 {
7         struct isl_dim *dim;
8
9         dim = isl_alloc_type(ctx, struct isl_dim);
10         if (!dim)
11                 return NULL;
12
13         dim->ctx = ctx;
14         isl_ctx_ref(ctx);
15         dim->ref = 1;
16         dim->nparam = nparam;
17         dim->n_in = n_in;
18         dim->n_out = n_out;
19
20         dim->n_name = 0;
21         dim->names = NULL;
22
23         return dim;
24 }
25
/* Allocate a dimension specification for a set: "nparam" parameters,
 * no input dimensions, and "dim" set dimensions stored as outputs.
 */
struct isl_dim *isl_dim_set_alloc(struct isl_ctx *ctx,
			unsigned nparam, unsigned dim)
{
	unsigned n_in = 0;

	return isl_dim_alloc(ctx, nparam, n_in, dim);
}
31
32 static unsigned global_pos(struct isl_dim *dim,
33                                  enum isl_dim_type type, unsigned pos)
34 {
35         struct isl_ctx *ctx = dim->ctx;
36
37         switch (type) {
38         case isl_dim_param:
39                 isl_assert(ctx, pos < dim->nparam, return isl_dim_total(dim));
40                 return pos;
41         case isl_dim_in:
42                 isl_assert(ctx, pos < dim->n_in, return isl_dim_total(dim));
43                 return pos + dim->nparam;
44         case isl_dim_out:
45                 isl_assert(ctx, pos < dim->n_out, return isl_dim_total(dim));
46                 return pos + dim->nparam + dim->n_in;
47         default:
48                 isl_assert(ctx, 0, goto error);
49         }
50         return isl_dim_total(dim);
51 }
52
53 static struct isl_dim *set_name(struct isl_dim *dim,
54                                  enum isl_dim_type type, unsigned pos,
55                                  struct isl_name *name)
56 {
57         struct isl_ctx *ctx = dim->ctx;
58         dim = isl_dim_cow(dim);
59
60         if (!dim)
61                 goto error;
62
63         pos = global_pos(dim, type, pos);
64         isl_assert(ctx, pos != isl_dim_total(dim), goto error);
65
66         if (pos >= dim->n_name) {
67                 if (!name)
68                         return dim;
69                 if (!dim->names) {
70                         dim->names = isl_calloc_array(dim->ctx,
71                                         struct isl_name *, isl_dim_total(dim));
72                         if (!dim->names)
73                                 goto error;
74                 } else {
75                         int i;
76                         dim->names = isl_realloc_array(dim->ctx, dim->names,
77                                         struct isl_name *, isl_dim_total(dim));
78                         if (!dim->names)
79                                 goto error;
80                         for (i = dim->n_name; i < isl_dim_total(dim); ++i)
81                                 dim->names[i] = NULL;
82                 }
83         }
84         dim->n_name = isl_dim_total(dim);
85
86         dim->names[pos] = name;
87
88         return dim;
89 error:
90         isl_name_free(ctx, name);
91         isl_dim_free(dim);
92         return NULL;
93 }
94
95 static struct isl_name *get_name(struct isl_dim *dim,
96                                  enum isl_dim_type type, unsigned pos)
97 {
98         if (!dim)
99                 return NULL;
100
101         pos = global_pos(dim, type, pos);
102         if (pos == isl_dim_total(dim))
103                 return NULL;
104         if (pos >= dim->n_name)
105                 return NULL;
106         return dim->names[pos];
107 }
108
109 static unsigned n(struct isl_dim *dim, enum isl_dim_type type)
110 {
111         switch (type) {
112         case isl_dim_param:     return dim->nparam;
113         case isl_dim_in:        return dim->n_in;
114         case isl_dim_out:       return dim->n_out;
115         }
116 }
117
118 static struct isl_dim *copy_names(struct isl_dim *dst,
119         enum isl_dim_type dst_type, struct isl_dim *src,
120         enum isl_dim_type src_type)
121 {
122         int i;
123         struct isl_name *name;
124
125         for (i = 0; i < n(dst, dst_type); ++i) {
126                 name = get_name(src, src_type, i);
127                 if (!name)
128                         continue;
129                 dst = set_name(dst, dst_type, i, isl_name_copy(dst->ctx, name));
130                 if (!dst)
131                         return NULL;
132         }
133         return dst;
134 }
135
136 struct isl_dim *isl_dim_dup(struct isl_dim *dim)
137 {
138         struct isl_dim *dup;
139         dup = isl_dim_alloc(dim->ctx, dim->nparam, dim->n_in, dim->n_out);
140         if (!dim->names)
141                 return dup;
142         dup = copy_names(dup, isl_dim_param, dim, isl_dim_param);
143         dup = copy_names(dup, isl_dim_in, dim, isl_dim_in);
144         dup = copy_names(dup, isl_dim_out, dim, isl_dim_out);
145         return dup;
146 }
147
148 struct isl_dim *isl_dim_cow(struct isl_dim *dim)
149 {
150         if (!dim)
151                 return NULL;
152
153         if (dim->ref == 1)
154                 return dim;
155         dim->ref--;
156         return isl_dim_dup(dim);
157 }
158
159 struct isl_dim *isl_dim_copy(struct isl_dim *dim)
160 {
161         if (!dim)
162                 return NULL;
163
164         dim->ref++;
165         return dim;
166 }
167
168 void isl_dim_free(struct isl_dim *dim)
169 {
170         int i;
171
172         if (!dim)
173                 return;
174
175         if (--dim->ref > 0)
176                 return;
177
178         for (i = 0; i < dim->n_name; ++i)
179                 isl_name_free(dim->ctx, dim->names[i]);
180         free(dim->names);
181         isl_ctx_deref(dim->ctx);
182         
183         free(dim);
184 }
185
186 struct isl_dim *isl_dim_set_name(struct isl_dim *dim,
187                                  enum isl_dim_type type, unsigned pos,
188                                  const char *s)
189 {
190         struct isl_name *name;
191         if (!dim)
192                 return NULL;
193         name = isl_name_get(dim->ctx, s);
194         if (!name)
195                 goto error;
196         return set_name(dim, type, pos, name);
197 error:
198         isl_dim_free(dim);
199         return NULL;
200 }
201
202 const char *isl_dim_get_name(struct isl_dim *dim,
203                                  enum isl_dim_type type, unsigned pos)
204 {
205         struct isl_name *name = get_name(dim, type, pos);
206         return name ? name->name : NULL;
207 }
208
209 static int match(struct isl_dim *dim1, enum isl_dim_type dim1_type,
210                 struct isl_dim *dim2, enum isl_dim_type dim2_type)
211 {
212         int i;
213
214         if (n(dim1, dim1_type) != n(dim2, dim2_type))
215                 return 0;
216
217         if (!dim1->names && !dim2->names)
218                 return 1;
219
220         for (i = 0; i < n(dim1, dim1_type); ++i) {
221                 if (get_name(dim1, dim1_type, i) !=
222                     get_name(dim2, dim2_type, i))
223                         return 0;
224         }
225         return 1;
226 }
227
228 static void get_names(struct isl_dim *dim, enum isl_dim_type type,
229         unsigned first, unsigned n, struct isl_name **names)
230 {
231         int i;
232
233         for (i = 0; i < n ; ++i)
234                 names[i] = get_name(dim, type, first+i);
235 }
236
237 struct isl_dim *isl_dim_extend(struct isl_dim *dim,
238                         unsigned nparam, unsigned n_in, unsigned n_out)
239 {
240         struct isl_name **names = NULL;
241
242         if (!dim)
243                 return NULL;
244         if (dim->nparam == nparam && dim->n_in == n_in && dim->n_out == n_out)
245                 return dim;
246
247         isl_assert(dim->ctx, dim->nparam <= nparam, goto error);
248         isl_assert(dim->ctx, dim->n_in <= n_in, goto error);
249         isl_assert(dim->ctx, dim->n_out <= n_out, goto error);
250
251         dim = isl_dim_cow(dim);
252
253         if (dim->names) {
254                 names = isl_calloc_array(dim->ctx, struct isl_name *,
255                                          nparam + n_in + n_out);
256                 if (!names)
257                         goto error;
258                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
259                 get_names(dim, isl_dim_in, 0, dim->n_in, names + nparam);
260                 get_names(dim, isl_dim_out, 0, dim->n_out,
261                                 names + nparam + n_in);
262                 free(dim->names);
263                 dim->names = names;
264         }
265         dim->nparam = nparam;
266         dim->n_in = n_in;
267         dim->n_out = n_out;
268
269         return dim;
270 error:
271         free(names);
272         isl_dim_free(dim);
273         return NULL;
274 }
275
276 struct isl_dim *isl_dim_join(struct isl_dim *left, struct isl_dim *right)
277 {
278         struct isl_dim *dim;
279
280         if (!left || !right)
281                 goto error;
282
283         isl_assert(left->ctx, match(left, isl_dim_param, right, isl_dim_param),
284                         goto error);
285         isl_assert(left->ctx, match(left, isl_dim_out, right, isl_dim_in),
286                         goto error);
287
288         dim = isl_dim_alloc(left->ctx, left->nparam, left->n_in, right->n_out);
289         if (!dim)
290                 goto error;
291
292         dim = copy_names(dim, isl_dim_param, left, isl_dim_param);
293         dim = copy_names(dim, isl_dim_in, left, isl_dim_in);
294         dim = copy_names(dim, isl_dim_out, right, isl_dim_out);
295
296         isl_dim_free(left);
297         isl_dim_free(right);
298
299         return dim;
300 error:
301         isl_dim_free(left);
302         isl_dim_free(right);
303         return NULL;
304 }
305
306 struct isl_dim *isl_dim_map(struct isl_dim *dim)
307 {
308         struct isl_name **names = NULL;
309
310         if (!dim)
311                 return NULL;
312         isl_assert(dim->ctx, dim->n_in == 0, goto error);
313         if (dim->n_out == 0)
314                 return dim;
315         dim = isl_dim_cow(dim);
316         if (!dim)
317                 return NULL;
318         if (dim->names) {
319                 names = isl_calloc_array(dim->ctx, struct isl_name *,
320                                         dim->nparam + dim->n_out + dim->n_out);
321                 if (!names)
322                         goto error;
323                 get_names(dim, isl_dim_param, 0, dim->nparam, names);
324                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->nparam);
325         }
326         dim->n_in = dim->n_out;
327         if (names) {
328                 copy_names(dim, isl_dim_out, dim, isl_dim_in);
329                 free(dim->names);
330                 dim->names = names;
331         }
332         return dim;
333 error:
334         isl_dim_free(dim);
335         return NULL;
336 }
337
338 static struct isl_dim *set_names(struct isl_dim *dim, enum isl_dim_type type,
339         unsigned first, unsigned n, struct isl_name **names)
340 {
341         int i;
342
343         for (i = 0; i < n ; ++i)
344                 dim = set_name(dim, type, first+i, names[i]);
345
346         return dim;
347 }
348
349 struct isl_dim *isl_dim_reverse(struct isl_dim *dim)
350 {
351         unsigned t;
352         struct isl_name **names = NULL;
353
354         if (!dim)
355                 return NULL;
356         if (match(dim, isl_dim_in, dim, isl_dim_out))
357                 return dim;
358
359         dim = isl_dim_cow(dim);
360         if (!dim)
361                 return NULL;
362
363         if (dim->names) {
364                 names = isl_alloc_array(dim->ctx, struct isl_name *,
365                                         dim->n_in + dim->n_out);
366                 if (!names)
367                         goto error;
368                 get_names(dim, isl_dim_in, 0, dim->n_in, names);
369                 get_names(dim, isl_dim_out, 0, dim->n_out, names + dim->n_in);
370         }
371
372         t = dim->n_in;
373         dim->n_in = dim->n_out;
374         dim->n_out = t;
375
376         if (dim->names) {
377                 dim = set_names(dim, isl_dim_out, 0, dim->n_out, names);
378                 dim = set_names(dim, isl_dim_in, 0, dim->n_in, names + dim->n_out);
379                 free(names);
380         }
381
382         return dim;
383 error:
384         free(names);
385         isl_dim_free(dim);
386         return NULL;
387 }
388
389 struct isl_dim *isl_dim_drop_inputs(struct isl_dim *dim,
390                 unsigned first, unsigned n)
391 {
392         int i;
393
394         if (!dim)
395                 return NULL;
396
397         if (n == 0)
398                 return dim;
399
400         isl_assert(dim->ctx, first + n <= dim->n_in, goto error);
401         dim = isl_dim_cow(dim);
402         if (!dim)
403                 goto error;
404         if (dim->names) {
405                 for (i = 0; i < n; ++i) {
406                         isl_name_free(dim->ctx,
407                                         get_name(dim, isl_dim_in, first+i));
408                 }
409                 for (i = first+n; i < dim->n_in; ++i)
410                         set_name(dim, isl_dim_in, i - n,
411                                 get_name(dim, isl_dim_in, i));
412                 get_names(dim, isl_dim_out, 0, dim->n_out,
413                                 dim->names + dim->nparam + dim->n_in - n);
414         }
415         dim->n_in -= n;
416         return dim;
417 error:
418         isl_dim_free(dim);
419         return NULL;
420 }
421
422 struct isl_dim *isl_dim_drop_outputs(struct isl_dim *dim,
423                 unsigned first, unsigned n)
424 {
425         int i;
426
427         if (!dim)
428                 return NULL;
429
430         if (n == 0)
431                 return dim;
432
433         isl_assert(dim->ctx, first + n <= dim->n_out, goto error);
434         dim = isl_dim_cow(dim);
435         if (!dim)
436                 goto error;
437         if (dim->names) {
438                 for (i = 0; i < n; ++i) {
439                         isl_name_free(dim->ctx,
440                                         get_name(dim, isl_dim_out, first+i));
441                 }
442                 for (i = first+n; i < dim->n_out; ++i)
443                         set_name(dim, isl_dim_out, i - n,
444                                 get_name(dim, isl_dim_out, i));
445         }
446         dim->n_out -= n;
447         return dim;
448 error:
449         isl_dim_free(dim);
450         return NULL;
451 }
452
453 struct isl_dim *isl_dim_domain(struct isl_dim *dim)
454 {
455         if (!dim)
456                 return NULL;
457         dim = isl_dim_drop_outputs(dim, 0, dim->n_out);
458         return isl_dim_reverse(dim);
459 }
460
461 struct isl_dim *isl_dim_underlying(struct isl_dim *dim, unsigned n_div)
462 {
463         int i;
464
465         if (!dim)
466                 return NULL;
467         if (n_div == 0 &&
468             dim->nparam == 0 && dim->n_in == 0 && dim->n_name == 0)
469                 return dim;
470         dim = isl_dim_cow(dim);
471         if (!dim)
472                 return NULL;
473         dim->n_out += dim->nparam + dim->n_in + n_div;
474         dim->nparam = 0;
475         dim->n_in = 0;
476
477         for (i = 0; i < dim->n_name; ++i)
478                 isl_name_free(dim->ctx, get_name(dim, isl_dim_out, i));
479         dim->n_name = 0;
480
481         return dim;
482 }
483
484 unsigned isl_dim_total(struct isl_dim *dim)
485 {
486         return dim->nparam + dim->n_in + dim->n_out;
487 }
488
489 int isl_dim_equal(struct isl_dim *dim1, struct isl_dim *dim2)
490 {
491         return match(dim1, isl_dim_param, dim2, isl_dim_param) &&
492                match(dim1, isl_dim_in, dim2, isl_dim_in) &&
493                match(dim1, isl_dim_out, dim2, isl_dim_out);
494 }
495
496 int isl_dim_compatible(struct isl_dim *dim1, struct isl_dim *dim2)
497 {
498         return dim1->nparam == dim2->nparam &&
499                dim1->n_in + dim1->n_out == dim2->n_in + dim2->n_out;
500 }