add isl_aff_mod_val
[platform/upstream/isl.git] / isl_ast.c
1 #include <isl_ast_private.h>
2
3 #undef BASE
4 #define BASE ast_expr
5
6 #include <isl_list_templ.c>
7
8 #undef BASE
9 #define BASE ast_node
10
11 #include <isl_list_templ.c>
12
13 isl_ctx *isl_ast_print_options_get_ctx(
14         __isl_keep isl_ast_print_options *options)
15 {
16         return options ? options->ctx : NULL;
17 }
18
19 __isl_give isl_ast_print_options *isl_ast_print_options_alloc(isl_ctx *ctx)
20 {
21         isl_ast_print_options *options;
22
23         options = isl_calloc_type(ctx, isl_ast_print_options);
24         if (!options)
25                 return NULL;
26
27         options->ctx = ctx;
28         isl_ctx_ref(ctx);
29         options->ref = 1;
30
31         return options;
32 }
33
34 __isl_give isl_ast_print_options *isl_ast_print_options_dup(
35         __isl_keep isl_ast_print_options *options)
36 {
37         isl_ctx *ctx;
38         isl_ast_print_options *dup;
39
40         if (!options)
41                 return NULL;
42
43         ctx = isl_ast_print_options_get_ctx(options);
44         dup = isl_ast_print_options_alloc(ctx);
45         if (!dup)
46                 return NULL;
47
48         dup->print_for = options->print_for;
49         dup->print_for_user = options->print_for_user;
50         dup->print_user = options->print_user;
51         dup->print_user_user = options->print_user_user;
52
53         return dup;
54 }
55
56 __isl_give isl_ast_print_options *isl_ast_print_options_cow(
57         __isl_take isl_ast_print_options *options)
58 {
59         if (!options)
60                 return NULL;
61
62         if (options->ref == 1)
63                 return options;
64         options->ref--;
65         return isl_ast_print_options_dup(options);
66 }
67
68 __isl_give isl_ast_print_options *isl_ast_print_options_copy(
69         __isl_keep isl_ast_print_options *options)
70 {
71         if (!options)
72                 return NULL;
73
74         options->ref++;
75         return options;
76 }
77
78 void *isl_ast_print_options_free(__isl_take isl_ast_print_options *options)
79 {
80         if (!options)
81                 return NULL;
82
83         if (--options->ref > 0)
84                 return NULL;
85
86         isl_ctx_deref(options->ctx);
87
88         free(options);
89         return NULL;
90 }
91
92 /* Set the print_user callback of "options" to "print_user".
93  *
94  * If this callback is set, then it used to print user nodes in the AST.
95  * Otherwise, the expression associated to the user node is printed.
96  */
97 __isl_give isl_ast_print_options *isl_ast_print_options_set_print_user(
98         __isl_take isl_ast_print_options *options,
99         __isl_give isl_printer *(*print_user)(__isl_take isl_printer *p,
100                 __isl_take isl_ast_print_options *options,
101                 __isl_keep isl_ast_node *node, void *user),
102         void *user)
103 {
104         options = isl_ast_print_options_cow(options);
105         if (!options)
106                 return NULL;
107
108         options->print_user = print_user;
109         options->print_user_user = user;
110
111         return options;
112 }
113
114 /* Set the print_for callback of "options" to "print_for".
115  *
116  * If this callback is set, then it used to print for nodes in the AST.
117  */
118 __isl_give isl_ast_print_options *isl_ast_print_options_set_print_for(
119         __isl_take isl_ast_print_options *options,
120         __isl_give isl_printer *(*print_for)(__isl_take isl_printer *p,
121                 __isl_take isl_ast_print_options *options,
122                 __isl_keep isl_ast_node *node, void *user),
123         void *user)
124 {
125         options = isl_ast_print_options_cow(options);
126         if (!options)
127                 return NULL;
128
129         options->print_for = print_for;
130         options->print_for_user = user;
131
132         return options;
133 }
134
135 __isl_give isl_ast_expr *isl_ast_expr_copy(__isl_keep isl_ast_expr *expr)
136 {
137         if (!expr)
138                 return NULL;
139
140         expr->ref++;
141         return expr;
142 }
143
144 __isl_give isl_ast_expr *isl_ast_expr_dup(__isl_keep isl_ast_expr *expr)
145 {
146         int i;
147         isl_ctx *ctx;
148         isl_ast_expr *dup;
149
150         if (!expr)
151                 return NULL;
152
153         ctx = isl_ast_expr_get_ctx(expr);
154         switch (expr->type) {
155         case isl_ast_expr_int:
156                 dup = isl_ast_expr_alloc_int(ctx, expr->u.i);
157                 break;
158         case isl_ast_expr_id:
159                 dup = isl_ast_expr_from_id(isl_id_copy(expr->u.id));
160                 break;
161         case isl_ast_expr_op:
162                 dup = isl_ast_expr_alloc_op(ctx,
163                                             expr->u.op.op, expr->u.op.n_arg);
164                 if (!dup)
165                         return NULL;
166                 for (i = 0; i < expr->u.op.n_arg; ++i)
167                         dup->u.op.args[i] =
168                                 isl_ast_expr_copy(expr->u.op.args[i]);
169                 break;
170         case isl_ast_expr_error:
171                 dup = NULL;
172         }
173
174         if (!dup)
175                 return NULL;
176
177         return dup;
178 }
179
180 __isl_give isl_ast_expr *isl_ast_expr_cow(__isl_take isl_ast_expr *expr)
181 {
182         if (!expr)
183                 return NULL;
184
185         if (expr->ref == 1)
186                 return expr;
187         expr->ref--;
188         return isl_ast_expr_dup(expr);
189 }
190
191 void *isl_ast_expr_free(__isl_take isl_ast_expr *expr)
192 {
193         int i;
194
195         if (!expr)
196                 return NULL;
197
198         if (--expr->ref > 0)
199                 return NULL;
200
201         isl_ctx_deref(expr->ctx);
202
203         switch (expr->type) {
204         case isl_ast_expr_int:
205                 isl_int_clear(expr->u.i);
206                 break;
207         case isl_ast_expr_id:
208                 isl_id_free(expr->u.id);
209                 break;
210         case isl_ast_expr_op:
211                 for (i = 0; i < expr->u.op.n_arg; ++i)
212                         isl_ast_expr_free(expr->u.op.args[i]);
213                 free(expr->u.op.args);
214                 break;
215         case isl_ast_expr_error:
216                 break;
217         }
218
219         free(expr);
220         return NULL;
221 }
222
223 isl_ctx *isl_ast_expr_get_ctx(__isl_keep isl_ast_expr *expr)
224 {
225         return expr ? expr->ctx : NULL;
226 }
227
228 enum isl_ast_expr_type isl_ast_expr_get_type(__isl_keep isl_ast_expr *expr)
229 {
230         return expr ? expr->type : isl_ast_expr_error;
231 }
232
233 int isl_ast_expr_get_int(__isl_keep isl_ast_expr *expr, isl_int *v)
234 {
235         if (!expr)
236                 return -1;
237         if (expr->type != isl_ast_expr_int)
238                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
239                         "expression not an int", return -1);
240         isl_int_set(*v, expr->u.i);
241         return 0;
242 }
243
244 __isl_give isl_id *isl_ast_expr_get_id(__isl_keep isl_ast_expr *expr)
245 {
246         if (!expr)
247                 return NULL;
248         if (expr->type != isl_ast_expr_id)
249                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
250                         "expression not an identifier", return NULL);
251
252         return isl_id_copy(expr->u.id);
253 }
254
255 enum isl_ast_op_type isl_ast_expr_get_op_type(__isl_keep isl_ast_expr *expr)
256 {
257         if (!expr)
258                 return isl_ast_op_error;
259         if (expr->type != isl_ast_expr_op)
260                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
261                         "expression not an operation", return isl_ast_op_error);
262         return expr->u.op.op;
263 }
264
265 int isl_ast_expr_get_op_n_arg(__isl_keep isl_ast_expr *expr)
266 {
267         if (!expr)
268                 return -1;
269         if (expr->type != isl_ast_expr_op)
270                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
271                         "expression not an operation", return -1);
272         return expr->u.op.n_arg;
273 }
274
275 __isl_give isl_ast_expr *isl_ast_expr_get_op_arg(__isl_keep isl_ast_expr *expr,
276         int pos)
277 {
278         if (!expr)
279                 return NULL;
280         if (expr->type != isl_ast_expr_op)
281                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
282                         "expression not an operation", return NULL);
283         if (pos < 0 || pos >= expr->u.op.n_arg)
284                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
285                         "index out of bounds", return NULL);
286
287         return isl_ast_expr_copy(expr->u.op.args[pos]);
288 }
289
290 /* Replace the argument at position "pos" of "expr" by "arg".
291  */
292 __isl_give isl_ast_expr *isl_ast_expr_set_op_arg(__isl_take isl_ast_expr *expr,
293         int pos, __isl_take isl_ast_expr *arg)
294 {
295         expr = isl_ast_expr_cow(expr);
296         if (!expr || !arg)
297                 goto error;
298         if (expr->type != isl_ast_expr_op)
299                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
300                         "expression not an operation", goto error);
301         if (pos < 0 || pos >= expr->u.op.n_arg)
302                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
303                         "index out of bounds", goto error);
304
305         isl_ast_expr_free(expr->u.op.args[pos]);
306         expr->u.op.args[pos] = arg;
307
308         return expr;
309 error:
310         isl_ast_expr_free(arg);
311         return isl_ast_expr_free(expr);
312 }
313
314 /* Create a new operation expression of operation type "op",
315  * with "n_arg" as yet unspecified arguments.
316  */
317 __isl_give isl_ast_expr *isl_ast_expr_alloc_op(isl_ctx *ctx,
318         enum isl_ast_op_type op, int n_arg)
319 {
320         isl_ast_expr *expr;
321
322         expr = isl_calloc_type(ctx, isl_ast_expr);
323         if (!expr)
324                 return NULL;
325
326         expr->ctx = ctx;
327         isl_ctx_ref(ctx);
328         expr->ref = 1;
329         expr->type = isl_ast_expr_op;
330         expr->u.op.op = op;
331         expr->u.op.n_arg = n_arg;
332         expr->u.op.args = isl_calloc_array(ctx, isl_ast_expr *, n_arg);
333
334         if (!expr->u.op.args)
335                 return isl_ast_expr_free(expr);
336
337         return expr;
338 }
339
340 /* Create a new id expression representing "id".
341  */
342 __isl_give isl_ast_expr *isl_ast_expr_from_id(__isl_take isl_id *id)
343 {
344         isl_ctx *ctx;
345         isl_ast_expr *expr;
346
347         if (!id)
348                 return NULL;
349
350         ctx = isl_id_get_ctx(id);
351         expr = isl_calloc_type(ctx, isl_ast_expr);
352         if (!expr)
353                 return isl_id_free(id);
354
355         expr->ctx = ctx;
356         isl_ctx_ref(ctx);
357         expr->ref = 1;
358         expr->type = isl_ast_expr_id;
359         expr->u.id = id;
360
361         return expr;
362 }
363
364 /* Create a new integer expression representing "i".
365  */
366 __isl_give isl_ast_expr *isl_ast_expr_alloc_int_si(isl_ctx *ctx, int i)
367 {
368         isl_ast_expr *expr;
369
370         expr = isl_calloc_type(ctx, isl_ast_expr);
371         if (!expr)
372                 return NULL;
373
374         expr->ctx = ctx;
375         isl_ctx_ref(ctx);
376         expr->ref = 1;
377         expr->type = isl_ast_expr_int;
378
379         isl_int_init(expr->u.i);
380         isl_int_set_si(expr->u.i, i);
381
382         return expr;
383 }
384
385 /* Create a new integer expression representing "i".
386  */
387 __isl_give isl_ast_expr *isl_ast_expr_alloc_int(isl_ctx *ctx, isl_int i)
388 {
389         isl_ast_expr *expr;
390
391         expr = isl_calloc_type(ctx, isl_ast_expr);
392         if (!expr)
393                 return NULL;
394
395         expr->ctx = ctx;
396         isl_ctx_ref(ctx);
397         expr->ref = 1;
398         expr->type = isl_ast_expr_int;
399
400         isl_int_init(expr->u.i);
401         isl_int_set(expr->u.i, i);
402
403         return expr;
404 }
405
406 /* Create an expression representing the negation of "arg".
407  */
408 __isl_give isl_ast_expr *isl_ast_expr_neg(__isl_take isl_ast_expr *arg)
409 {
410         isl_ctx *ctx;
411         isl_ast_expr *expr = NULL;
412
413         if (!arg)
414                 return NULL;
415
416         ctx = isl_ast_expr_get_ctx(arg);
417         expr = isl_ast_expr_alloc_op(ctx, isl_ast_op_minus, 1);
418         if (!expr)
419                 goto error;
420
421         expr->u.op.args[0] = arg;
422
423         return expr;
424 error:
425         isl_ast_expr_free(arg);
426         return NULL;
427 }
428
429 /* Create an expression representing the binary operation "type"
430  * applied to "expr1" and "expr2".
431  */
432 __isl_give isl_ast_expr *isl_ast_expr_alloc_binary(enum isl_ast_op_type type,
433         __isl_take isl_ast_expr *expr1, __isl_take isl_ast_expr *expr2)
434 {
435         isl_ctx *ctx;
436         isl_ast_expr *expr = NULL;
437
438         if (!expr1 || !expr2)
439                 goto error;
440
441         ctx = isl_ast_expr_get_ctx(expr1);
442         expr = isl_ast_expr_alloc_op(ctx, type, 2);
443         if (!expr)
444                 goto error;
445
446         expr->u.op.args[0] = expr1;
447         expr->u.op.args[1] = expr2;
448
449         return expr;
450 error:
451         isl_ast_expr_free(expr1);
452         isl_ast_expr_free(expr2);
453         return NULL;
454 }
455
456 /* Create an expression representing the sum of "expr1" and "expr2".
457  */
458 __isl_give isl_ast_expr *isl_ast_expr_add(__isl_take isl_ast_expr *expr1,
459         __isl_take isl_ast_expr *expr2)
460 {
461         return isl_ast_expr_alloc_binary(isl_ast_op_add, expr1, expr2);
462 }
463
464 /* Create an expression representing the difference of "expr1" and "expr2".
465  */
466 __isl_give isl_ast_expr *isl_ast_expr_sub(__isl_take isl_ast_expr *expr1,
467         __isl_take isl_ast_expr *expr2)
468 {
469         return isl_ast_expr_alloc_binary(isl_ast_op_sub, expr1, expr2);
470 }
471
472 /* Create an expression representing the product of "expr1" and "expr2".
473  */
474 __isl_give isl_ast_expr *isl_ast_expr_mul(__isl_take isl_ast_expr *expr1,
475         __isl_take isl_ast_expr *expr2)
476 {
477         return isl_ast_expr_alloc_binary(isl_ast_op_mul, expr1, expr2);
478 }
479
480 /* Create an expression representing the quotient of "expr1" and "expr2".
481  */
482 __isl_give isl_ast_expr *isl_ast_expr_div(__isl_take isl_ast_expr *expr1,
483         __isl_take isl_ast_expr *expr2)
484 {
485         return isl_ast_expr_alloc_binary(isl_ast_op_div, expr1, expr2);
486 }
487
488 /* Create an expression representing the conjunction of "expr1" and "expr2".
489  */
490 __isl_give isl_ast_expr *isl_ast_expr_and(__isl_take isl_ast_expr *expr1,
491         __isl_take isl_ast_expr *expr2)
492 {
493         return isl_ast_expr_alloc_binary(isl_ast_op_and, expr1, expr2);
494 }
495
496 /* Create an expression representing the disjunction of "expr1" and "expr2".
497  */
498 __isl_give isl_ast_expr *isl_ast_expr_or(__isl_take isl_ast_expr *expr1,
499         __isl_take isl_ast_expr *expr2)
500 {
501         return isl_ast_expr_alloc_binary(isl_ast_op_or, expr1, expr2);
502 }
503
504 isl_ctx *isl_ast_node_get_ctx(__isl_keep isl_ast_node *node)
505 {
506         return node ? node->ctx : NULL;
507 }
508
509 enum isl_ast_node_type isl_ast_node_get_type(__isl_keep isl_ast_node *node)
510 {
511         return node ? node->type : isl_ast_node_error;
512 }
513
514 __isl_give isl_ast_node *isl_ast_node_alloc(isl_ctx *ctx,
515         enum isl_ast_node_type type)
516 {
517         isl_ast_node *node;
518
519         node = isl_calloc_type(ctx, isl_ast_node);
520         if (!node)
521                 return NULL;
522
523         node->ctx = ctx;
524         isl_ctx_ref(ctx);
525         node->ref = 1;
526         node->type = type;
527
528         return node;
529 }
530
531 /* Create an if node with the given guard.
532  *
533  * The then body needs to be filled in later.
534  */
535 __isl_give isl_ast_node *isl_ast_node_alloc_if(__isl_take isl_ast_expr *guard)
536 {
537         isl_ast_node *node;
538
539         if (!guard)
540                 return NULL;
541
542         node = isl_ast_node_alloc(isl_ast_expr_get_ctx(guard), isl_ast_node_if);
543         if (!node)
544                 goto error;
545         node->u.i.guard = guard;
546
547         return node;
548 error:
549         isl_ast_expr_free(guard);
550         return NULL;
551 }
552
553 /* Create a for node with the given iterator.
554  *
555  * The remaining fields need to be filled in later.
556  */
557 __isl_give isl_ast_node *isl_ast_node_alloc_for(__isl_take isl_id *id)
558 {
559         isl_ast_node *node;
560         isl_ctx *ctx;
561
562         if (!id)
563                 return NULL;
564
565         ctx = isl_id_get_ctx(id);
566         node = isl_ast_node_alloc(ctx, isl_ast_node_for);
567         if (!node)
568                 return NULL;
569
570         node->u.f.iterator = isl_ast_expr_from_id(id);
571         if (!node->u.f.iterator)
572                 return isl_ast_node_free(node);
573
574         return node;
575 }
576
577 /* Create a user node evaluating "expr".
578  */
579 __isl_give isl_ast_node *isl_ast_node_alloc_user(__isl_take isl_ast_expr *expr)
580 {
581         isl_ctx *ctx;
582         isl_ast_node *node;
583
584         if (!expr)
585                 return NULL;
586
587         ctx = isl_ast_expr_get_ctx(expr);
588         node = isl_ast_node_alloc(ctx, isl_ast_node_user);
589         if (!node)
590                 goto error;
591
592         node->u.e.expr = expr;
593
594         return node;
595 error:
596         isl_ast_expr_free(expr);
597         return NULL;
598 }
599
600 /* Create a block node with the given children.
601  */
602 __isl_give isl_ast_node *isl_ast_node_alloc_block(
603         __isl_take isl_ast_node_list *list)
604 {
605         isl_ast_node *node;
606         isl_ctx *ctx;
607
608         if (!list)
609                 return NULL;
610
611         ctx = isl_ast_node_list_get_ctx(list);
612         node = isl_ast_node_alloc(ctx, isl_ast_node_block);
613         if (!node)
614                 goto error;
615
616         node->u.b.children = list;
617
618         return node;
619 error:
620         isl_ast_node_list_free(list);
621         return NULL;
622 }
623
624 /* Represent the given list of nodes as a single node, either by
625  * extract the node from a single element list or by creating
626  * a block node with the list of nodes as children.
627  */
628 __isl_give isl_ast_node *isl_ast_node_from_ast_node_list(
629         __isl_take isl_ast_node_list *list)
630 {
631         isl_ast_node *node;
632
633         if (isl_ast_node_list_n_ast_node(list) != 1)
634                 return isl_ast_node_alloc_block(list);
635
636         node = isl_ast_node_list_get_ast_node(list, 0);
637         isl_ast_node_list_free(list);
638
639         return node;
640 }
641
642 __isl_give isl_ast_node *isl_ast_node_copy(__isl_keep isl_ast_node *node)
643 {
644         if (!node)
645                 return NULL;
646
647         node->ref++;
648         return node;
649 }
650
651 __isl_give isl_ast_node *isl_ast_node_dup(__isl_keep isl_ast_node *node)
652 {
653         isl_ast_node *dup;
654
655         if (!node)
656                 return NULL;
657
658         dup = isl_ast_node_alloc(isl_ast_node_get_ctx(node), node->type);
659         if (!dup)
660                 return NULL;
661
662         switch (node->type) {
663         case isl_ast_node_if:
664                 dup->u.i.guard = isl_ast_expr_copy(node->u.i.guard);
665                 dup->u.i.then = isl_ast_node_copy(node->u.i.then);
666                 dup->u.i.else_node = isl_ast_node_copy(node->u.i.else_node);
667                 if (!dup->u.i.guard  || !dup->u.i.then ||
668                     (node->u.i.else_node && !dup->u.i.else_node))
669                         return isl_ast_node_free(dup);
670                 break;
671         case isl_ast_node_for:
672                 dup->u.f.iterator = isl_ast_expr_copy(node->u.f.iterator);
673                 dup->u.f.init = isl_ast_expr_copy(node->u.f.init);
674                 dup->u.f.cond = isl_ast_expr_copy(node->u.f.cond);
675                 dup->u.f.inc = isl_ast_expr_copy(node->u.f.inc);
676                 dup->u.f.body = isl_ast_node_copy(node->u.f.body);
677                 if (!dup->u.f.iterator || !dup->u.f.init || !dup->u.f.cond ||
678                     !dup->u.f.inc || !dup->u.f.body)
679                         return isl_ast_node_free(dup);
680                 break;
681         case isl_ast_node_block:
682                 dup->u.b.children = isl_ast_node_list_copy(node->u.b.children);
683                 if (!dup->u.b.children)
684                         return isl_ast_node_free(dup);
685                 break;
686         case isl_ast_node_user:
687                 dup->u.e.expr = isl_ast_expr_copy(node->u.e.expr);
688                 if (!dup->u.e.expr)
689                         return isl_ast_node_free(dup);
690                 break;
691         case isl_ast_node_error:
692                 break;
693         }
694
695         return dup;
696 }
697
698 __isl_give isl_ast_node *isl_ast_node_cow(__isl_take isl_ast_node *node)
699 {
700         if (!node)
701                 return NULL;
702
703         if (node->ref == 1)
704                 return node;
705         node->ref--;
706         return isl_ast_node_dup(node);
707 }
708
709 void *isl_ast_node_free(__isl_take isl_ast_node *node)
710 {
711         if (!node)
712                 return NULL;
713
714         if (--node->ref > 0)
715                 return NULL;
716
717         switch (node->type) {
718         case isl_ast_node_if:
719                 isl_ast_expr_free(node->u.i.guard);
720                 isl_ast_node_free(node->u.i.then);
721                 isl_ast_node_free(node->u.i.else_node);
722                 break;
723         case isl_ast_node_for:
724                 isl_ast_expr_free(node->u.f.iterator);
725                 isl_ast_expr_free(node->u.f.init);
726                 isl_ast_expr_free(node->u.f.cond);
727                 isl_ast_expr_free(node->u.f.inc);
728                 isl_ast_node_free(node->u.f.body);
729                 break;
730         case isl_ast_node_block:
731                 isl_ast_node_list_free(node->u.b.children);
732                 break;
733         case isl_ast_node_user:
734                 isl_ast_expr_free(node->u.e.expr);
735                 break;
736         case isl_ast_node_error:
737                 break;
738         }
739
740         isl_id_free(node->annotation);
741         isl_ctx_deref(node->ctx);
742         free(node);
743
744         return NULL;
745 }
746
747 /* Replace the body of the for node "node" by "body".
748  */
749 __isl_give isl_ast_node *isl_ast_node_for_set_body(
750         __isl_take isl_ast_node *node, __isl_take isl_ast_node *body)
751 {
752         node = isl_ast_node_cow(node);
753         if (!node || !body)
754                 goto error;
755         if (node->type != isl_ast_node_for)
756                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
757                         "not a for node", goto error);
758
759         isl_ast_node_free(node->u.f.body);
760         node->u.f.body = body;
761
762         return node;
763 error:
764         isl_ast_node_free(node);
765         isl_ast_node_free(body);
766         return NULL;
767 }
768
769 __isl_give isl_ast_node *isl_ast_node_for_get_body(
770         __isl_keep isl_ast_node *node)
771 {
772         if (!node)
773                 return NULL;
774         if (node->type != isl_ast_node_for)
775                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
776                         "not a for node", return NULL);
777         return isl_ast_node_copy(node->u.f.body);
778 }
779
780 /* Mark the given for node as being degenerate.
781  */
782 __isl_give isl_ast_node *isl_ast_node_for_mark_degenerate(
783         __isl_take isl_ast_node *node)
784 {
785         node = isl_ast_node_cow(node);
786         if (!node)
787                 return NULL;
788         node->u.f.degenerate = 1;
789         return node;
790 }
791
792 int isl_ast_node_for_is_degenerate(__isl_keep isl_ast_node *node)
793 {
794         if (!node)
795                 return -1;
796         if (node->type != isl_ast_node_for)
797                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
798                         "not a for node", return -1);
799         return node->u.f.degenerate;
800 }
801
802 __isl_give isl_ast_expr *isl_ast_node_for_get_iterator(
803         __isl_keep isl_ast_node *node)
804 {
805         if (!node)
806                 return NULL;
807         if (node->type != isl_ast_node_for)
808                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
809                         "not a for node", return NULL);
810         return isl_ast_expr_copy(node->u.f.iterator);
811 }
812
813 __isl_give isl_ast_expr *isl_ast_node_for_get_init(
814         __isl_keep isl_ast_node *node)
815 {
816         if (!node)
817                 return NULL;
818         if (node->type != isl_ast_node_for)
819                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
820                         "not a for node", return NULL);
821         return isl_ast_expr_copy(node->u.f.init);
822 }
823
824 /* Return the condition expression of the given for node.
825  *
826  * If the for node is degenerate, then the condition is not explicitly
827  * stored in the node.  Instead, it is constructed as
828  *
829  *      iterator <= init
830  */
831 __isl_give isl_ast_expr *isl_ast_node_for_get_cond(
832         __isl_keep isl_ast_node *node)
833 {
834         if (!node)
835                 return NULL;
836         if (node->type != isl_ast_node_for)
837                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
838                         "not a for node", return NULL);
839         if (!node->u.f.degenerate)
840                 return isl_ast_expr_copy(node->u.f.cond);
841
842         return isl_ast_expr_alloc_binary(isl_ast_op_le,
843                                 isl_ast_expr_copy(node->u.f.iterator),
844                                 isl_ast_expr_copy(node->u.f.init));
845 }
846
847 /* Return the increment of the given for node.
848  *
849  * If the for node is degenerate, then the increment is not explicitly
850  * stored in the node.  We simply return "1".
851  */
852 __isl_give isl_ast_expr *isl_ast_node_for_get_inc(
853         __isl_keep isl_ast_node *node)
854 {
855         if (!node)
856                 return NULL;
857         if (node->type != isl_ast_node_for)
858                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
859                         "not a for node", return NULL);
860         if (!node->u.f.degenerate)
861                 return isl_ast_expr_copy(node->u.f.inc);
862         return isl_ast_expr_alloc_int_si(isl_ast_node_get_ctx(node), 1);
863 }
864
865 /* Replace the then branch of the if node "node" by "child".
866  */
867 __isl_give isl_ast_node *isl_ast_node_if_set_then(
868         __isl_take isl_ast_node *node, __isl_take isl_ast_node *child)
869 {
870         node = isl_ast_node_cow(node);
871         if (!node || !child)
872                 goto error;
873         if (node->type != isl_ast_node_if)
874                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
875                         "not an if node", goto error);
876
877         isl_ast_node_free(node->u.i.then);
878         node->u.i.then = child;
879
880         return node;
881 error:
882         isl_ast_node_free(node);
883         isl_ast_node_free(child);
884         return NULL;
885 }
886
887 __isl_give isl_ast_node *isl_ast_node_if_get_then(
888         __isl_keep isl_ast_node *node)
889 {
890         if (!node)
891                 return NULL;
892         if (node->type != isl_ast_node_if)
893                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
894                         "not an if node", return NULL);
895         return isl_ast_node_copy(node->u.i.then);
896 }
897
898 int isl_ast_node_if_has_else(
899         __isl_keep isl_ast_node *node)
900 {
901         if (!node)
902                 return -1;
903         if (node->type != isl_ast_node_if)
904                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
905                         "not an if node", return -1);
906         return node->u.i.else_node != NULL;
907 }
908
909 __isl_give isl_ast_node *isl_ast_node_if_get_else(
910         __isl_keep isl_ast_node *node)
911 {
912         if (!node)
913                 return NULL;
914         if (node->type != isl_ast_node_if)
915                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
916                         "not an if node", return NULL);
917         return isl_ast_node_copy(node->u.i.else_node);
918 }
919
920 __isl_give isl_ast_expr *isl_ast_node_if_get_cond(
921         __isl_keep isl_ast_node *node)
922 {
923         if (!node)
924                 return NULL;
925         if (node->type != isl_ast_node_if)
926                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
927                         "not a guard node", return NULL);
928         return isl_ast_expr_copy(node->u.i.guard);
929 }
930
931 __isl_give isl_ast_node_list *isl_ast_node_block_get_children(
932         __isl_keep isl_ast_node *node)
933 {
934         if (!node)
935                 return NULL;
936         if (node->type != isl_ast_node_block)
937                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
938                         "not a block node", return NULL);
939         return isl_ast_node_list_copy(node->u.b.children);
940 }
941
942 __isl_give isl_ast_expr *isl_ast_node_user_get_expr(
943         __isl_keep isl_ast_node *node)
944 {
945         if (!node)
946                 return NULL;
947
948         return isl_ast_expr_copy(node->u.e.expr);
949 }
950
951 __isl_give isl_id *isl_ast_node_get_annotation(__isl_keep isl_ast_node *node)
952 {
953         return node ? isl_id_copy(node->annotation) : NULL;
954 }
955
956 /* Replace node->annotation by "annotation".
957  */
958 __isl_give isl_ast_node *isl_ast_node_set_annotation(
959         __isl_take isl_ast_node *node, __isl_take isl_id *annotation)
960 {
961         node = isl_ast_node_cow(node);
962         if (!node || !annotation)
963                 goto error;
964
965         isl_id_free(node->annotation);
966         node->annotation = annotation;
967
968         return node;
969 error:
970         isl_id_free(annotation);
971         return isl_ast_node_free(node);
972 }
973
974 /* Textual C representation of the various operators.
975  */
976 static char *op_str[] = {
977         [isl_ast_op_and] = "&&",
978         [isl_ast_op_and_then] = "&&",
979         [isl_ast_op_or] = "||",
980         [isl_ast_op_or_else] = "||",
981         [isl_ast_op_max] = "max",
982         [isl_ast_op_min] = "min",
983         [isl_ast_op_minus] = "-",
984         [isl_ast_op_add] = "+",
985         [isl_ast_op_sub] = "-",
986         [isl_ast_op_mul] = "*",
987         [isl_ast_op_pdiv_q] = "/",
988         [isl_ast_op_pdiv_r] = "%",
989         [isl_ast_op_div] = "/",
990         [isl_ast_op_eq] = "==",
991         [isl_ast_op_le] = "<=",
992         [isl_ast_op_ge] = ">=",
993         [isl_ast_op_lt] = "<",
994         [isl_ast_op_gt] = ">"
995 };
996
997 /* Precedence in C of the various operators.
998  * Based on http://en.wikipedia.org/wiki/Operators_in_C_and_C++
999  * Lowest value means highest precedence.
1000  */
1001 static int op_prec[] = {
1002         [isl_ast_op_and] = 13,
1003         [isl_ast_op_and_then] = 13,
1004         [isl_ast_op_or] = 14,
1005         [isl_ast_op_or_else] = 14,
1006         [isl_ast_op_max] = 2,
1007         [isl_ast_op_min] = 2,
1008         [isl_ast_op_minus] = 3,
1009         [isl_ast_op_add] = 6,
1010         [isl_ast_op_sub] = 6,
1011         [isl_ast_op_mul] = 5,
1012         [isl_ast_op_div] = 5,
1013         [isl_ast_op_fdiv_q] = 2,
1014         [isl_ast_op_pdiv_q] = 5,
1015         [isl_ast_op_pdiv_r] = 5,
1016         [isl_ast_op_cond] = 15,
1017         [isl_ast_op_select] = 15,
1018         [isl_ast_op_eq] = 9,
1019         [isl_ast_op_le] = 8,
1020         [isl_ast_op_ge] = 8,
1021         [isl_ast_op_lt] = 8,
1022         [isl_ast_op_gt] = 8,
1023         [isl_ast_op_call] = 2
1024 };
1025
1026 /* Is the operator left-to-right associative?
1027  */
1028 static int op_left[] = {
1029         [isl_ast_op_and] = 1,
1030         [isl_ast_op_and_then] = 1,
1031         [isl_ast_op_or] = 1,
1032         [isl_ast_op_or_else] = 1,
1033         [isl_ast_op_max] = 1,
1034         [isl_ast_op_min] = 1,
1035         [isl_ast_op_minus] = 0,
1036         [isl_ast_op_add] = 1,
1037         [isl_ast_op_sub] = 1,
1038         [isl_ast_op_mul] = 1,
1039         [isl_ast_op_div] = 1,
1040         [isl_ast_op_fdiv_q] = 1,
1041         [isl_ast_op_pdiv_q] = 1,
1042         [isl_ast_op_pdiv_r] = 1,
1043         [isl_ast_op_cond] = 0,
1044         [isl_ast_op_select] = 0,
1045         [isl_ast_op_eq] = 1,
1046         [isl_ast_op_le] = 1,
1047         [isl_ast_op_ge] = 1,
1048         [isl_ast_op_lt] = 1,
1049         [isl_ast_op_gt] = 1,
1050         [isl_ast_op_call] = 1
1051 };
1052
1053 static int is_and(enum isl_ast_op_type op)
1054 {
1055         return op == isl_ast_op_and || op == isl_ast_op_and_then;
1056 }
1057
1058 static int is_or(enum isl_ast_op_type op)
1059 {
1060         return op == isl_ast_op_or || op == isl_ast_op_or_else;
1061 }
1062
1063 static int is_add_sub(enum isl_ast_op_type op)
1064 {
1065         return op == isl_ast_op_add || op == isl_ast_op_sub;
1066 }
1067
1068 static int is_div_mod(enum isl_ast_op_type op)
1069 {
1070         return op == isl_ast_op_div || op == isl_ast_op_pdiv_r;
1071 }
1072
1073 /* Do we need/want parentheses around "expr" as a subexpression of
1074  * an "op" operation?  If "left" is set, then "expr" is the left-most
1075  * operand.
1076  *
1077  * We only need parentheses if "expr" represents an operation.
1078  *
1079  * If op has a higher precedence than expr->u.op.op, then we need
1080  * parentheses.
1081  * If op and expr->u.op.op have the same precedence, but the operations
1082  * are performed in an order that is different from the associativity,
1083  * then we need parentheses.
1084  *
1085  * An and inside an or technically does not require parentheses,
1086  * but some compilers complain about that, so we add them anyway.
1087  *
1088  * Computations such as "a / b * c" and "a % b + c" can be somewhat
1089  * difficult to read, so we add parentheses for those as well.
1090  */
1091 static int sub_expr_need_parens(enum isl_ast_op_type op,
1092         __isl_keep isl_ast_expr *expr, int left)
1093 {
1094         if (expr->type != isl_ast_expr_op)
1095                 return 0;
1096
1097         if (op_prec[expr->u.op.op] > op_prec[op])
1098                 return 1;
1099         if (op_prec[expr->u.op.op] == op_prec[op] && left != op_left[op])
1100                 return 1;
1101
1102         if (is_or(op) && is_and(expr->u.op.op))
1103                 return 1;
1104         if (op == isl_ast_op_mul && expr->u.op.op != isl_ast_op_mul &&
1105             op_prec[expr->u.op.op] == op_prec[op])
1106                 return 1;
1107         if (is_add_sub(op) && is_div_mod(expr->u.op.op))
1108                 return 1;
1109
1110         return 0;
1111 }
1112
1113 /* Print "expr" as a subexpression of an "op" operation.
1114  * If "left" is set, then "expr" is the left-most operand.
1115  */
1116 static __isl_give isl_printer *print_sub_expr(__isl_take isl_printer *p,
1117         enum isl_ast_op_type op, __isl_keep isl_ast_expr *expr, int left)
1118 {
1119         int need_parens;
1120
1121         need_parens = sub_expr_need_parens(op, expr, left);
1122
1123         if (need_parens)
1124                 p = isl_printer_print_str(p, "(");
1125         p = isl_printer_print_ast_expr(p, expr);
1126         if (need_parens)
1127                 p = isl_printer_print_str(p, ")");
1128         return p;
1129 }
1130
1131 /* Print a min or max reduction "expr".
1132  */
1133 static __isl_give isl_printer *print_min_max(__isl_take isl_printer *p,
1134         __isl_keep isl_ast_expr *expr)
1135 {
1136         int i = 0;
1137
1138         for (i = 1; i < expr->u.op.n_arg; ++i) {
1139                 p = isl_printer_print_str(p, op_str[expr->u.op.op]);
1140                 p = isl_printer_print_str(p, "(");
1141         }
1142         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1143         for (i = 1; i < expr->u.op.n_arg; ++i) {
1144                 p = isl_printer_print_str(p, ", ");
1145                 p = isl_printer_print_ast_expr(p, expr->u.op.args[i]);
1146                 p = isl_printer_print_str(p, ")");
1147         }
1148
1149         return p;
1150 }
1151
1152 /* Print a function call "expr".
1153  *
1154  * The first argument represents the function to be called.
1155  */
1156 static __isl_give isl_printer *print_call(__isl_take isl_printer *p,
1157         __isl_keep isl_ast_expr *expr)
1158 {
1159         int i = 0;
1160
1161         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1162         p = isl_printer_print_str(p, "(");
1163         for (i = 1; i < expr->u.op.n_arg; ++i) {
1164                 if (i != 1)
1165                         p = isl_printer_print_str(p, ", ");
1166                 p = isl_printer_print_ast_expr(p, expr->u.op.args[i]);
1167         }
1168         p = isl_printer_print_str(p, ")");
1169
1170         return p;
1171 }
1172
1173 /* Print "expr" to "p".
1174  *
1175  * If we are printing in isl format, then we also print an indication
1176  * of the size of the expression (if it was computed).
1177  */
1178 __isl_give isl_printer *isl_printer_print_ast_expr(__isl_take isl_printer *p,
1179         __isl_keep isl_ast_expr *expr)
1180 {
1181         if (!p)
1182                 return NULL;
1183         if (!expr)
1184                 return isl_printer_free(p);
1185
1186         switch (expr->type) {
1187         case isl_ast_expr_op:
1188                 if (expr->u.op.op == isl_ast_op_call) {
1189                         p = print_call(p, expr);
1190                         break;
1191                 }
1192                 if (expr->u.op.n_arg == 1) {
1193                         p = isl_printer_print_str(p, op_str[expr->u.op.op]);
1194                         p = print_sub_expr(p, expr->u.op.op,
1195                                                 expr->u.op.args[0], 0);
1196                         break;
1197                 }
1198                 if (expr->u.op.op == isl_ast_op_fdiv_q) {
1199                         p = isl_printer_print_str(p, "floord(");
1200                         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1201                         p = isl_printer_print_str(p, ", ");
1202                         p = isl_printer_print_ast_expr(p, expr->u.op.args[1]);
1203                         p = isl_printer_print_str(p, ")");
1204                         break;
1205                 }
1206                 if (expr->u.op.op == isl_ast_op_max ||
1207                     expr->u.op.op == isl_ast_op_min) {
1208                         p = print_min_max(p, expr);
1209                         break;
1210                 }
1211                 if (expr->u.op.op == isl_ast_op_cond ||
1212                     expr->u.op.op == isl_ast_op_select) {
1213                         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1214                         p = isl_printer_print_str(p, " ? ");
1215                         p = isl_printer_print_ast_expr(p, expr->u.op.args[1]);
1216                         p = isl_printer_print_str(p, " : ");
1217                         p = isl_printer_print_ast_expr(p, expr->u.op.args[2]);
1218                         break;
1219                 }
1220                 if (expr->u.op.n_arg != 2)
1221                         isl_die(isl_printer_get_ctx(p), isl_error_internal,
1222                                 "operation should have two arguments",
1223                                 goto error);
1224                 p = print_sub_expr(p, expr->u.op.op, expr->u.op.args[0], 1);
1225                 p = isl_printer_print_str(p, " ");
1226                 p = isl_printer_print_str(p, op_str[expr->u.op.op]);
1227                 p = isl_printer_print_str(p, " ");
1228                 p = print_sub_expr(p, expr->u.op.op, expr->u.op.args[1], 0);
1229                 break;
1230         case isl_ast_expr_id:
1231                 p = isl_printer_print_str(p, isl_id_get_name(expr->u.id));
1232                 break;
1233         case isl_ast_expr_int:
1234                 p = isl_printer_print_isl_int(p, expr->u.i);
1235                 break;
1236         case isl_ast_expr_error:
1237                 break;
1238         }
1239
1240         return p;
1241 error:
1242         isl_printer_free(p);
1243         return NULL;
1244 }
1245
1246 /* Print "node" to "p" in "isl format".
1247  */
1248 static __isl_give isl_printer *print_ast_node_isl(__isl_take isl_printer *p,
1249         __isl_keep isl_ast_node *node)
1250 {
1251         p = isl_printer_print_str(p, "(");
1252         switch (node->type) {
1253         case isl_ast_node_for:
1254                 if (node->u.f.degenerate) {
1255                         p = isl_printer_print_ast_expr(p, node->u.f.init);
1256                 } else {
1257                         p = isl_printer_print_str(p, "init: ");
1258                         p = isl_printer_print_ast_expr(p, node->u.f.init);
1259                         p = isl_printer_print_str(p, ", ");
1260                         p = isl_printer_print_str(p, "cond: ");
1261                         p = isl_printer_print_ast_expr(p, node->u.f.cond);
1262                         p = isl_printer_print_str(p, ", ");
1263                         p = isl_printer_print_str(p, "inc: ");
1264                         p = isl_printer_print_ast_expr(p, node->u.f.inc);
1265                 }
1266                 if (node->u.f.body) {
1267                         p = isl_printer_print_str(p, ", ");
1268                         p = isl_printer_print_str(p, "body: ");
1269                         p = isl_printer_print_ast_node(p, node->u.f.body);
1270                 }
1271                 break;
1272         case isl_ast_node_user:
1273                 p = isl_printer_print_ast_expr(p, node->u.e.expr);
1274                 break;
1275         case isl_ast_node_if:
1276                 p = isl_printer_print_str(p, "guard: ");
1277                 p = isl_printer_print_ast_expr(p, node->u.i.guard);
1278                 if (node->u.i.then) {
1279                         p = isl_printer_print_str(p, ", ");
1280                         p = isl_printer_print_str(p, "then: ");
1281                         p = isl_printer_print_ast_node(p, node->u.i.then);
1282                 }
1283                 if (node->u.i.else_node) {
1284                         p = isl_printer_print_str(p, ", ");
1285                         p = isl_printer_print_str(p, "else: ");
1286                         p = isl_printer_print_ast_node(p, node->u.i.else_node);
1287                 }
1288                 break;
1289         case isl_ast_node_block:
1290                 p = isl_printer_print_ast_node_list(p, node->u.b.children);
1291                 break;
1292         default:
1293                 break;
1294         }
1295         p = isl_printer_print_str(p, ")");
1296         return p;
1297 }
1298
1299 /* Do we need to print a block around the body "node" of a for or if node?
1300  *
1301  * If the node is a block, then we need to print a block.
1302  * Also if the node is a degenerate for then we will print it as
1303  * an assignment followed by the body of the for loop, so we need a block
1304  * as well.
1305  */
1306 static int need_block(__isl_keep isl_ast_node *node)
1307 {
1308         if (node->type == isl_ast_node_block)
1309                 return 1;
1310         if (node->type == isl_ast_node_for && node->u.f.degenerate)
1311                 return 1;
1312         return 0;
1313 }
1314
1315 static __isl_give isl_printer *print_ast_node_c(__isl_take isl_printer *p,
1316         __isl_keep isl_ast_node *node,
1317         __isl_keep isl_ast_print_options *options, int in_block, int in_list);
1318 static __isl_give isl_printer *print_if_c(__isl_take isl_printer *p,
1319         __isl_keep isl_ast_node *node,
1320         __isl_keep isl_ast_print_options *options, int new_line);
1321
1322 /* Print the body "node" of a for or if node.
1323  * If "else_node" is set, then it is printed as well.
1324  *
1325  * We first check if we need to print out a block.
1326  * We always print out a block if there is an else node to make
1327  * sure that the else node is matched to the correct if node.
1328  *
1329  * If the else node is itself an if, then we print it as
1330  *
1331  *      } else if (..)
1332  *
1333  * Otherwise the else node is printed as
1334  *
1335  *      } else
1336  *        node
1337  */
1338 static __isl_give isl_printer *print_body_c(__isl_take isl_printer *p,
1339         __isl_keep isl_ast_node *node, __isl_keep isl_ast_node *else_node,
1340         __isl_keep isl_ast_print_options *options)
1341 {
1342         if (!node)
1343                 return isl_printer_free(p);
1344
1345         if (!else_node && !need_block(node)) {
1346                 p = isl_printer_end_line(p);
1347                 p = isl_printer_indent(p, 2);
1348                 p = isl_ast_node_print(node, p,
1349                                         isl_ast_print_options_copy(options));
1350                 p = isl_printer_indent(p, -2);
1351                 return p;
1352         }
1353
1354         p = isl_printer_print_str(p, " {");
1355         p = isl_printer_end_line(p);
1356         p = isl_printer_indent(p, 2);
1357         p = print_ast_node_c(p, node, options, 1, 0);
1358         p = isl_printer_indent(p, -2);
1359         p = isl_printer_start_line(p);
1360         p = isl_printer_print_str(p, "}");
1361         if (else_node) {
1362                 if (else_node->type == isl_ast_node_if) {
1363                         p = isl_printer_print_str(p, " else ");
1364                         p = print_if_c(p, else_node, options, 0);
1365                 } else {
1366                         p = isl_printer_print_str(p, " else");
1367                         p = print_body_c(p, else_node, NULL, options);
1368                 }
1369         } else
1370                 p = isl_printer_end_line(p);
1371
1372         return p;
1373 }
1374
1375 /* Print the start of a compound statement.
1376  */
1377 static __isl_give isl_printer *start_block(__isl_take isl_printer *p)
1378 {
1379         p = isl_printer_start_line(p);
1380         p = isl_printer_print_str(p, "{");
1381         p = isl_printer_end_line(p);
1382         p = isl_printer_indent(p, 2);
1383
1384         return p;
1385 }
1386
1387 /* Print the end of a compound statement.
1388  */
1389 static __isl_give isl_printer *end_block(__isl_take isl_printer *p)
1390 {
1391         p = isl_printer_indent(p, -2);
1392         p = isl_printer_start_line(p);
1393         p = isl_printer_print_str(p, "}");
1394         p = isl_printer_end_line(p);
1395
1396         return p;
1397 }
1398
1399 /* Print the for node "node".
1400  *
1401  * If the for node is degenerate, it is printed as
1402  *
1403  *      type iterator = init;
1404  *      body
1405  *
1406  * Otherwise, it is printed as
1407  *
1408  *      for (type iterator = init; cond; iterator += inc)
1409  *              body
1410  *
1411  * "in_block" is set if we are currently inside a block.
1412  * "in_list" is set if the current node is not alone in the block.
1413  * If we are not in a block or if the current not is not alone in the block
1414  * then we print a block around a degenerate for loop such that the variable
1415  * declaration will not conflict with any potential other declaration
1416  * of the same variable.
1417  */
1418 static __isl_give isl_printer *print_for_c(__isl_take isl_printer *p,
1419         __isl_keep isl_ast_node *node,
1420         __isl_keep isl_ast_print_options *options, int in_block, int in_list)
1421 {
1422         isl_id *id;
1423         const char *name;
1424         const char *type;
1425
1426         type = isl_options_get_ast_iterator_type(isl_printer_get_ctx(p));
1427         if (!node->u.f.degenerate) {
1428                 id = isl_ast_expr_get_id(node->u.f.iterator);
1429                 name = isl_id_get_name(id);
1430                 isl_id_free(id);
1431                 p = isl_printer_start_line(p);
1432                 p = isl_printer_print_str(p, "for (");
1433                 p = isl_printer_print_str(p, type);
1434                 p = isl_printer_print_str(p, " ");
1435                 p = isl_printer_print_str(p, name);
1436                 p = isl_printer_print_str(p, " = ");
1437                 p = isl_printer_print_ast_expr(p, node->u.f.init);
1438                 p = isl_printer_print_str(p, "; ");
1439                 p = isl_printer_print_ast_expr(p, node->u.f.cond);
1440                 p = isl_printer_print_str(p, "; ");
1441                 p = isl_printer_print_str(p, name);
1442                 p = isl_printer_print_str(p, " += ");
1443                 p = isl_printer_print_ast_expr(p, node->u.f.inc);
1444                 p = isl_printer_print_str(p, ")");
1445                 p = print_body_c(p, node->u.f.body, NULL, options);
1446         } else {
1447                 id = isl_ast_expr_get_id(node->u.f.iterator);
1448                 name = isl_id_get_name(id);
1449                 isl_id_free(id);
1450                 if (!in_block || in_list)
1451                         p = start_block(p);
1452                 p = isl_printer_start_line(p);
1453                 p = isl_printer_print_str(p, type);
1454                 p = isl_printer_print_str(p, " ");
1455                 p = isl_printer_print_str(p, name);
1456                 p = isl_printer_print_str(p, " = ");
1457                 p = isl_printer_print_ast_expr(p, node->u.f.init);
1458                 p = isl_printer_print_str(p, ";");
1459                 p = isl_printer_end_line(p);
1460                 p = print_ast_node_c(p, node->u.f.body, options, 1, 0);
1461                 if (!in_block || in_list)
1462                         p = end_block(p);
1463         }
1464
1465         return p;
1466 }
1467
1468 /* Print the if node "node".
1469  * If "new_line" is set then the if node should be printed on a new line.
1470  */
1471 static __isl_give isl_printer *print_if_c(__isl_take isl_printer *p,
1472         __isl_keep isl_ast_node *node,
1473         __isl_keep isl_ast_print_options *options, int new_line)
1474 {
1475         if (new_line)
1476                 p = isl_printer_start_line(p);
1477         p = isl_printer_print_str(p, "if (");
1478         p = isl_printer_print_ast_expr(p, node->u.i.guard);
1479         p = isl_printer_print_str(p, ")");
1480         p = print_body_c(p, node->u.i.then, node->u.i.else_node, options);
1481
1482         return p;
1483 }
1484
1485 /* Print the "node" to "p".
1486  *
1487  * "in_block" is set if we are currently inside a block.
1488  * If so, we do not print a block around the children of a block node.
1489  * We do this to avoid an extra block around the body of a degenerate
1490  * for node.
1491  *
1492  * "in_list" is set if the current node is not alone in the block.
1493  */
1494 static __isl_give isl_printer *print_ast_node_c(__isl_take isl_printer *p,
1495         __isl_keep isl_ast_node *node,
1496         __isl_keep isl_ast_print_options *options, int in_block, int in_list)
1497 {
1498         switch (node->type) {
1499         case isl_ast_node_for:
1500                 if (options->print_for)
1501                         return options->print_for(p,
1502                                         isl_ast_print_options_copy(options),
1503                                         node, options->print_for_user);
1504                 p = print_for_c(p, node, options, in_block, in_list);
1505                 break;
1506         case isl_ast_node_if:
1507                 p = print_if_c(p, node, options, 1);
1508                 break;
1509         case isl_ast_node_block:
1510                 if (!in_block)
1511                         p = start_block(p);
1512                 p = isl_ast_node_list_print(node->u.b.children, p, options);
1513                 if (!in_block)
1514                         p = end_block(p);
1515                 break;
1516         case isl_ast_node_user:
1517                 if (options->print_user)
1518                         return options->print_user(p,
1519                                         isl_ast_print_options_copy(options),
1520                                         node, options->print_user_user);
1521                 p = isl_printer_start_line(p);
1522                 p = isl_printer_print_ast_expr(p, node->u.e.expr);
1523                 p = isl_printer_print_str(p, ";");
1524                 p = isl_printer_end_line(p);
1525                 break;
1526         case isl_ast_node_error:
1527                 break;
1528         }
1529         return p;
1530 }
1531
1532 /* Print the for node "node" to "p".
1533  */
1534 __isl_give isl_printer *isl_ast_node_for_print(__isl_keep isl_ast_node *node,
1535         __isl_take isl_printer *p, __isl_take isl_ast_print_options *options)
1536 {
1537         if (!node || !options)
1538                 goto error;
1539         if (node->type != isl_ast_node_for)
1540                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
1541                         "not a for node", goto error);
1542         p = print_for_c(p, node, options, 0, 0);
1543         isl_ast_print_options_free(options);
1544         return p;
1545 error:
1546         isl_ast_print_options_free(options);
1547         isl_printer_free(p);
1548         return NULL;
1549 }
1550
1551 /* Print the if node "node" to "p".
1552  */
1553 __isl_give isl_printer *isl_ast_node_if_print(__isl_keep isl_ast_node *node,
1554         __isl_take isl_printer *p, __isl_take isl_ast_print_options *options)
1555 {
1556         if (!node || !options)
1557                 goto error;
1558         if (node->type != isl_ast_node_if)
1559                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
1560                         "not an if node", goto error);
1561         p = print_if_c(p, node, options, 1);
1562         isl_ast_print_options_free(options);
1563         return p;
1564 error:
1565         isl_ast_print_options_free(options);
1566         isl_printer_free(p);
1567         return NULL;
1568 }
1569
1570 /* Print "node" to "p".
1571  */
1572 __isl_give isl_printer *isl_ast_node_print(__isl_keep isl_ast_node *node,
1573         __isl_take isl_printer *p, __isl_take isl_ast_print_options *options)
1574 {
1575         if (!options || !node)
1576                 goto error;
1577         p = print_ast_node_c(p, node, options, 0, 0);
1578         isl_ast_print_options_free(options);
1579         return p;
1580 error:
1581         isl_ast_print_options_free(options);
1582         isl_printer_free(p);
1583         return NULL;
1584 }
1585
1586 /* Print "node" to "p".
1587  */
1588 __isl_give isl_printer *isl_printer_print_ast_node(__isl_take isl_printer *p,
1589         __isl_keep isl_ast_node *node)
1590 {
1591         int format;
1592         isl_ast_print_options *options;
1593
1594         if (!p)
1595                 return NULL;
1596
1597         format = isl_printer_get_output_format(p);
1598         switch (format) {
1599         case ISL_FORMAT_ISL:
1600                 p = print_ast_node_isl(p, node);
1601                 break;
1602         case ISL_FORMAT_C:
1603                 options = isl_ast_print_options_alloc(isl_printer_get_ctx(p));
1604                 p = isl_ast_node_print(node, p, options);
1605                 break;
1606         default:
1607                 isl_die(isl_printer_get_ctx(p), isl_error_unsupported,
1608                         "output format not supported for ast_node",
1609                         return isl_printer_free(p));
1610         }
1611
1612         return p;
1613 }
1614
1615 /* Print the list of nodes "list" to "p".
1616  */
1617 __isl_give isl_printer *isl_ast_node_list_print(
1618         __isl_keep isl_ast_node_list *list, __isl_take isl_printer *p,
1619         __isl_keep isl_ast_print_options *options)
1620 {
1621         int i;
1622
1623         if (!p || !list || !options)
1624                 return isl_printer_free(p);
1625
1626         for (i = 0; i < list->n; ++i)
1627                 p = print_ast_node_c(p, list->p[i], options, 1, 1);
1628
1629         return p;
1630 }
1631
1632 #define ISL_AST_MACRO_FLOORD    (1 << 0)
1633 #define ISL_AST_MACRO_MIN       (1 << 1)
1634 #define ISL_AST_MACRO_MAX       (1 << 2)
1635 #define ISL_AST_MACRO_ALL       (ISL_AST_MACRO_FLOORD | \
1636                                  ISL_AST_MACRO_MIN | \
1637                                  ISL_AST_MACRO_MAX)
1638
1639 /* If "expr" contains an isl_ast_op_min, isl_ast_op_max or isl_ast_op_fdiv_q
1640  * then set the corresponding bit in "macros".
1641  */
1642 static int ast_expr_required_macros(__isl_keep isl_ast_expr *expr, int macros)
1643 {
1644         int i;
1645
1646         if (macros == ISL_AST_MACRO_ALL)
1647                 return macros;
1648
1649         if (expr->type != isl_ast_expr_op)
1650                 return macros;
1651
1652         if (expr->u.op.op == isl_ast_op_min)
1653                 macros |= ISL_AST_MACRO_MIN;
1654         if (expr->u.op.op == isl_ast_op_max)
1655                 macros |= ISL_AST_MACRO_MAX;
1656         if (expr->u.op.op == isl_ast_op_fdiv_q)
1657                 macros |= ISL_AST_MACRO_FLOORD;
1658
1659         for (i = 0; i < expr->u.op.n_arg; ++i)
1660                 macros = ast_expr_required_macros(expr->u.op.args[i], macros);
1661
1662         return macros;
1663 }
1664
1665 static int ast_node_list_required_macros(__isl_keep isl_ast_node_list *list,
1666         int macros);
1667
1668 /* If "node" contains an isl_ast_op_min, isl_ast_op_max or isl_ast_op_fdiv_q
1669  * then set the corresponding bit in "macros".
1670  */
1671 static int ast_node_required_macros(__isl_keep isl_ast_node *node, int macros)
1672 {
1673         if (macros == ISL_AST_MACRO_ALL)
1674                 return macros;
1675
1676         switch (node->type) {
1677         case isl_ast_node_for:
1678                 macros = ast_expr_required_macros(node->u.f.init, macros);
1679                 if (!node->u.f.degenerate) {
1680                         macros = ast_expr_required_macros(node->u.f.cond,
1681                                                                 macros);
1682                         macros = ast_expr_required_macros(node->u.f.inc,
1683                                                                 macros);
1684                 }
1685                 macros = ast_node_required_macros(node->u.f.body, macros);
1686                 break;
1687         case isl_ast_node_if:
1688                 macros = ast_expr_required_macros(node->u.i.guard, macros);
1689                 macros = ast_node_required_macros(node->u.i.then, macros);
1690                 if (node->u.i.else_node)
1691                         macros = ast_node_required_macros(node->u.i.else_node,
1692                                                                 macros);
1693                 break;
1694         case isl_ast_node_block:
1695                 macros = ast_node_list_required_macros(node->u.b.children,
1696                                                         macros);
1697                 break;
1698         case isl_ast_node_user:
1699                 macros = ast_expr_required_macros(node->u.e.expr, macros);
1700                 break;
1701         case isl_ast_node_error:
1702                 break;
1703         }
1704
1705         return macros;
1706 }
1707
1708 /* If "list" contains an isl_ast_op_min, isl_ast_op_max or isl_ast_op_fdiv_q
1709  * then set the corresponding bit in "macros".
1710  */
1711 static int ast_node_list_required_macros(__isl_keep isl_ast_node_list *list,
1712         int macros)
1713 {
1714         int i;
1715
1716         for (i = 0; i < list->n; ++i)
1717                 macros = ast_node_required_macros(list->p[i], macros);
1718
1719         return macros;
1720 }
1721
1722 /* Print a macro definition for the operator "type".
1723  */
1724 __isl_give isl_printer *isl_ast_op_type_print_macro(
1725         enum isl_ast_op_type type, __isl_take isl_printer *p)
1726 {
1727         switch (type) {
1728         case isl_ast_op_min:
1729                 p = isl_printer_start_line(p);
1730                 p = isl_printer_print_str(p,
1731                         "#define min(x,y)    ((x) < (y) ? (x) : (y))");
1732                 p = isl_printer_end_line(p);
1733                 break;
1734         case isl_ast_op_max:
1735                 p = isl_printer_start_line(p);
1736                 p = isl_printer_print_str(p,
1737                         "#define max(x,y)    ((x) > (y) ? (x) : (y))");
1738                 p = isl_printer_end_line(p);
1739                 break;
1740         case isl_ast_op_fdiv_q:
1741                 p = isl_printer_start_line(p);
1742                 p = isl_printer_print_str(p,
1743                         "#define floord(n,d) "
1744                         "(((n)<0) ? -((-(n)+(d)-1)/(d)) : (n)/(d))");
1745                 p = isl_printer_end_line(p);
1746                 break;
1747         default:
1748                 break;
1749         }
1750
1751         return p;
1752 }
1753
1754 /* Call "fn" for each type of operation that appears in "node"
1755  * and that requires a macro definition.
1756  */
1757 int isl_ast_node_foreach_ast_op_type(__isl_keep isl_ast_node *node,
1758         int (*fn)(enum isl_ast_op_type type, void *user), void *user)
1759 {
1760         int macros;
1761
1762         if (!node)
1763                 return -1;
1764
1765         macros = ast_node_required_macros(node, 0);
1766
1767         if (macros & ISL_AST_MACRO_MIN && fn(isl_ast_op_min, user) < 0)
1768                 return -1;
1769         if (macros & ISL_AST_MACRO_MAX && fn(isl_ast_op_max, user) < 0)
1770                 return -1;
1771         if (macros & ISL_AST_MACRO_FLOORD && fn(isl_ast_op_fdiv_q, user) < 0)
1772                 return -1;
1773
1774         return 0;
1775 }
1776
1777 static int ast_op_type_print_macro(enum isl_ast_op_type type, void *user)
1778 {
1779         isl_printer **p = user;
1780
1781         *p = isl_ast_op_type_print_macro(type, *p);
1782
1783         return 0;
1784 }
1785
1786 /* Print macro definitions for all the macros used in the result
1787  * of printing "node.
1788  */
1789 __isl_give isl_printer *isl_ast_node_print_macros(
1790         __isl_keep isl_ast_node *node, __isl_take isl_printer *p)
1791 {
1792         if (isl_ast_node_foreach_ast_op_type(node,
1793                                             &ast_op_type_print_macro, &p) < 0)
1794                 return isl_printer_free(p);
1795         return p;
1796 }