isl_ast_expr: add isl_ast_op_lt and isl_ast_op_gt operations
[platform/upstream/isl.git] / isl_ast.c
1 #include <isl_ast_private.h>
2 #include <isl_list_private.h>
3
4 #undef BASE
5 #define BASE ast_expr
6
7 #include <isl_list_templ.c>
8
9 #undef BASE
10 #define BASE ast_node
11
12 #include <isl_list_templ.c>
13
14 __isl_give isl_ast_print_options *isl_ast_print_options_alloc(isl_ctx *ctx)
15 {
16         return isl_calloc_type(ctx, isl_ast_print_options);
17 }
18
19 void *isl_ast_print_options_free(__isl_take isl_ast_print_options *options)
20 {
21         free(options);
22         return NULL;
23 }
24
25 /* Set the print_user callback of "options" to "print_user".
26  *
27  * If this callback is set, then it used to print user nodes in the AST.
28  * Otherwise, the expression associated to the user node is printed.
29  */
30 __isl_give isl_ast_print_options *isl_ast_print_options_set_print_user(
31         __isl_take isl_ast_print_options *options,
32         __isl_give isl_printer *(*print_user)(__isl_take isl_printer *p,
33                 __isl_keep isl_ast_node *node, void *user),
34         void *user)
35 {
36         if (!options)
37                 return NULL;
38
39         options->print_user = print_user;
40         options->print_user_user = user;
41
42         return options;
43 }
44
45 /* Set the print_for callback of "options" to "print_for".
46  *
47  * If this callback is set, then it used to print for nodes in the AST.
48  */
49 __isl_give isl_ast_print_options *isl_ast_print_options_set_print_for(
50         __isl_take isl_ast_print_options *options,
51         __isl_give isl_printer *(*print_for)(__isl_take isl_printer *p,
52                 __isl_keep isl_ast_node *node, void *user),
53         void *user)
54 {
55         if (!options)
56                 return NULL;
57
58         options->print_for = print_for;
59         options->print_for_user = user;
60
61         return options;
62 }
63
64 __isl_give isl_ast_expr *isl_ast_expr_copy(__isl_keep isl_ast_expr *expr)
65 {
66         if (!expr)
67                 return NULL;
68
69         expr->ref++;
70         return expr;
71 }
72
73 __isl_give isl_ast_expr *isl_ast_expr_dup(__isl_keep isl_ast_expr *expr)
74 {
75         int i;
76         isl_ctx *ctx;
77         isl_ast_expr *dup;
78
79         if (!expr)
80                 return NULL;
81
82         ctx = isl_ast_expr_get_ctx(expr);
83         switch (expr->type) {
84         case isl_ast_expr_int:
85                 dup = isl_ast_expr_alloc_int(ctx, expr->u.i);
86                 break;
87         case isl_ast_expr_id:
88                 dup = isl_ast_expr_from_id(isl_id_copy(expr->u.id));
89                 break;
90         case isl_ast_expr_op:
91                 dup = isl_ast_expr_alloc_op(ctx,
92                                             expr->u.op.op, expr->u.op.n_arg);
93                 if (!dup)
94                         return NULL;
95                 for (i = 0; i < expr->u.op.n_arg; ++i)
96                         dup->u.op.args[i] =
97                                 isl_ast_expr_copy(expr->u.op.args[i]);
98                 break;
99         case isl_ast_expr_error:
100                 dup = NULL;
101         }
102
103         if (!dup)
104                 return NULL;
105
106         return dup;
107 }
108
109 __isl_give isl_ast_expr *isl_ast_expr_cow(__isl_take isl_ast_expr *expr)
110 {
111         if (!expr)
112                 return NULL;
113
114         if (expr->ref == 1)
115                 return expr;
116         expr->ref--;
117         return isl_ast_expr_dup(expr);
118 }
119
120 void *isl_ast_expr_free(__isl_take isl_ast_expr *expr)
121 {
122         int i;
123
124         if (!expr)
125                 return NULL;
126
127         if (--expr->ref > 0)
128                 return NULL;
129
130         isl_ctx_deref(expr->ctx);
131
132         switch (expr->type) {
133         case isl_ast_expr_int:
134                 isl_int_clear(expr->u.i);
135                 break;
136         case isl_ast_expr_id:
137                 isl_id_free(expr->u.id);
138                 break;
139         case isl_ast_expr_op:
140                 for (i = 0; i < expr->u.op.n_arg; ++i)
141                         isl_ast_expr_free(expr->u.op.args[i]);
142                 free(expr->u.op.args);
143                 break;
144         case isl_ast_expr_error:
145                 break;
146         }
147
148         free(expr);
149         return NULL;
150 }
151
152 isl_ctx *isl_ast_expr_get_ctx(__isl_keep isl_ast_expr *expr)
153 {
154         return expr ? expr->ctx : NULL;
155 }
156
157 enum isl_ast_expr_type isl_ast_expr_get_type(__isl_keep isl_ast_expr *expr)
158 {
159         return expr ? expr->type : isl_ast_expr_error;
160 }
161
162 int isl_ast_expr_get_int(__isl_keep isl_ast_expr *expr, isl_int *v)
163 {
164         if (!expr)
165                 return -1;
166         if (expr->type != isl_ast_expr_int)
167                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
168                         "expression not an int", return -1);
169         isl_int_set(*v, expr->u.i);
170         return 0;
171 }
172
173 __isl_give isl_id *isl_ast_expr_get_id(__isl_keep isl_ast_expr *expr)
174 {
175         if (!expr)
176                 return NULL;
177         if (expr->type != isl_ast_expr_id)
178                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
179                         "expression not an identifier", return NULL);
180
181         return isl_id_copy(expr->u.id);
182 }
183
184 enum isl_ast_op_type isl_ast_expr_get_op_type(__isl_keep isl_ast_expr *expr)
185 {
186         if (!expr)
187                 return isl_ast_op_error;
188         if (expr->type != isl_ast_expr_op)
189                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
190                         "expression not an operation", return isl_ast_op_error);
191         return expr->u.op.op;
192 }
193
194 int isl_ast_expr_get_op_n_arg(__isl_keep isl_ast_expr *expr)
195 {
196         if (!expr)
197                 return -1;
198         if (expr->type != isl_ast_expr_op)
199                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
200                         "expression not an operation", return -1);
201         return expr->u.op.n_arg;
202 }
203
204 __isl_give isl_ast_expr *isl_ast_expr_get_op_arg(__isl_keep isl_ast_expr *expr,
205         int pos)
206 {
207         if (!expr)
208                 return NULL;
209         if (expr->type != isl_ast_expr_op)
210                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
211                         "expression not an operation", return NULL);
212         if (pos < 0 || pos >= expr->u.op.n_arg)
213                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
214                         "index out of bounds", return NULL);
215
216         return isl_ast_expr_copy(expr->u.op.args[pos]);
217 }
218
219 /* Replace the argument at position "pos" of "expr" by "arg".
220  */
221 __isl_give isl_ast_expr *isl_ast_expr_set_op_arg(__isl_take isl_ast_expr *expr,
222         int pos, __isl_take isl_ast_expr *arg)
223 {
224         expr = isl_ast_expr_cow(expr);
225         if (!expr || !arg)
226                 goto error;
227         if (expr->type != isl_ast_expr_op)
228                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
229                         "expression not an operation", goto error);
230         if (pos < 0 || pos >= expr->u.op.n_arg)
231                 isl_die(isl_ast_expr_get_ctx(expr), isl_error_invalid,
232                         "index out of bounds", goto error);
233
234         isl_ast_expr_free(expr->u.op.args[pos]);
235         expr->u.op.args[pos] = arg;
236
237         return expr;
238 error:
239         isl_ast_expr_free(arg);
240         return isl_ast_expr_free(expr);
241 }
242
243 /* Create a new operation expression of operation type "op",
244  * with "n_arg" as yet unspecified arguments.
245  */
246 __isl_give isl_ast_expr *isl_ast_expr_alloc_op(isl_ctx *ctx,
247         enum isl_ast_op_type op, int n_arg)
248 {
249         isl_ast_expr *expr;
250
251         expr = isl_calloc_type(ctx, isl_ast_expr);
252         if (!expr)
253                 return NULL;
254
255         expr->ctx = ctx;
256         isl_ctx_ref(ctx);
257         expr->ref = 1;
258         expr->type = isl_ast_expr_op;
259         expr->u.op.op = op;
260         expr->u.op.n_arg = n_arg;
261         expr->u.op.args = isl_calloc_array(ctx, isl_ast_expr *, n_arg);
262
263         if (!expr->u.op.args)
264                 return isl_ast_expr_free(expr);
265
266         return expr;
267 }
268
269 /* Create a new id expression representing "id".
270  */
271 __isl_give isl_ast_expr *isl_ast_expr_from_id(__isl_take isl_id *id)
272 {
273         isl_ctx *ctx;
274         isl_ast_expr *expr;
275
276         if (!id)
277                 return NULL;
278
279         ctx = isl_id_get_ctx(id);
280         expr = isl_calloc_type(ctx, isl_ast_expr);
281         if (!expr)
282                 return isl_id_free(id);
283
284         expr->ctx = ctx;
285         isl_ctx_ref(ctx);
286         expr->ref = 1;
287         expr->type = isl_ast_expr_id;
288         expr->u.id = id;
289
290         return expr;
291 }
292
293 /* Create a new integer expression representing "i".
294  */
295 __isl_give isl_ast_expr *isl_ast_expr_alloc_int_si(isl_ctx *ctx, int i)
296 {
297         isl_ast_expr *expr;
298
299         expr = isl_calloc_type(ctx, isl_ast_expr);
300         if (!expr)
301                 return NULL;
302
303         expr->ctx = ctx;
304         isl_ctx_ref(ctx);
305         expr->ref = 1;
306         expr->type = isl_ast_expr_int;
307
308         isl_int_init(expr->u.i);
309         isl_int_set_si(expr->u.i, i);
310
311         return expr;
312 }
313
314 /* Create a new integer expression representing "i".
315  */
316 __isl_give isl_ast_expr *isl_ast_expr_alloc_int(isl_ctx *ctx, isl_int i)
317 {
318         isl_ast_expr *expr;
319
320         expr = isl_calloc_type(ctx, isl_ast_expr);
321         if (!expr)
322                 return NULL;
323
324         expr->ctx = ctx;
325         isl_ctx_ref(ctx);
326         expr->ref = 1;
327         expr->type = isl_ast_expr_int;
328
329         isl_int_init(expr->u.i);
330         isl_int_set(expr->u.i, i);
331
332         return expr;
333 }
334
335 /* Create an expression representing the negation of "arg".
336  */
337 __isl_give isl_ast_expr *isl_ast_expr_neg(__isl_take isl_ast_expr *arg)
338 {
339         isl_ctx *ctx;
340         isl_ast_expr *expr = NULL;
341
342         if (!arg)
343                 return NULL;
344
345         ctx = isl_ast_expr_get_ctx(arg);
346         expr = isl_ast_expr_alloc_op(ctx, isl_ast_op_minus, 1);
347         if (!expr)
348                 goto error;
349
350         expr->u.op.args[0] = arg;
351
352         return expr;
353 error:
354         isl_ast_expr_free(arg);
355         return NULL;
356 }
357
358 /* Create an expression representing the binary operation "type"
359  * applied to "expr1" and "expr2".
360  */
361 __isl_give isl_ast_expr *isl_ast_expr_alloc_binary(enum isl_ast_op_type type,
362         __isl_take isl_ast_expr *expr1, __isl_take isl_ast_expr *expr2)
363 {
364         isl_ctx *ctx;
365         isl_ast_expr *expr = NULL;
366
367         if (!expr1 || !expr2)
368                 goto error;
369
370         ctx = isl_ast_expr_get_ctx(expr1);
371         expr = isl_ast_expr_alloc_op(ctx, type, 2);
372         if (!expr)
373                 goto error;
374
375         expr->u.op.args[0] = expr1;
376         expr->u.op.args[1] = expr2;
377
378         return expr;
379 error:
380         isl_ast_expr_free(expr1);
381         isl_ast_expr_free(expr2);
382         return NULL;
383 }
384
385 /* Create an expression representing the sum of "expr1" and "expr2".
386  */
387 __isl_give isl_ast_expr *isl_ast_expr_add(__isl_take isl_ast_expr *expr1,
388         __isl_take isl_ast_expr *expr2)
389 {
390         return isl_ast_expr_alloc_binary(isl_ast_op_add, expr1, expr2);
391 }
392
393 /* Create an expression representing the difference of "expr1" and "expr2".
394  */
395 __isl_give isl_ast_expr *isl_ast_expr_sub(__isl_take isl_ast_expr *expr1,
396         __isl_take isl_ast_expr *expr2)
397 {
398         return isl_ast_expr_alloc_binary(isl_ast_op_sub, expr1, expr2);
399 }
400
401 /* Create an expression representing the product of "expr1" and "expr2".
402  */
403 __isl_give isl_ast_expr *isl_ast_expr_mul(__isl_take isl_ast_expr *expr1,
404         __isl_take isl_ast_expr *expr2)
405 {
406         return isl_ast_expr_alloc_binary(isl_ast_op_mul, expr1, expr2);
407 }
408
409 /* Create an expression representing the quotient of "expr1" and "expr2".
410  */
411 __isl_give isl_ast_expr *isl_ast_expr_div(__isl_take isl_ast_expr *expr1,
412         __isl_take isl_ast_expr *expr2)
413 {
414         return isl_ast_expr_alloc_binary(isl_ast_op_div, expr1, expr2);
415 }
416
417 /* Create an expression representing the conjunction of "expr1" and "expr2".
418  */
419 __isl_give isl_ast_expr *isl_ast_expr_and(__isl_take isl_ast_expr *expr1,
420         __isl_take isl_ast_expr *expr2)
421 {
422         return isl_ast_expr_alloc_binary(isl_ast_op_and, expr1, expr2);
423 }
424
425 /* Create an expression representing the disjunction of "expr1" and "expr2".
426  */
427 __isl_give isl_ast_expr *isl_ast_expr_or(__isl_take isl_ast_expr *expr1,
428         __isl_take isl_ast_expr *expr2)
429 {
430         return isl_ast_expr_alloc_binary(isl_ast_op_or, expr1, expr2);
431 }
432
433 isl_ctx *isl_ast_node_get_ctx(__isl_keep isl_ast_node *node)
434 {
435         return node ? node->ctx : NULL;
436 }
437
438 enum isl_ast_node_type isl_ast_node_get_type(__isl_keep isl_ast_node *node)
439 {
440         return node ? node->type : isl_ast_node_error;
441 }
442
443 __isl_give isl_ast_node *isl_ast_node_alloc(isl_ctx *ctx,
444         enum isl_ast_node_type type)
445 {
446         isl_ast_node *node;
447
448         node = isl_calloc_type(ctx, isl_ast_node);
449         if (!node)
450                 return NULL;
451
452         node->ctx = ctx;
453         isl_ctx_ref(ctx);
454         node->ref = 1;
455         node->type = type;
456
457         return node;
458 }
459
460 /* Create an if node with the given guard.
461  *
462  * The then body needs to be filled in later.
463  */
464 __isl_give isl_ast_node *isl_ast_node_alloc_if(__isl_take isl_ast_expr *guard)
465 {
466         isl_ast_node *node;
467
468         if (!guard)
469                 return NULL;
470
471         node = isl_ast_node_alloc(isl_ast_expr_get_ctx(guard), isl_ast_node_if);
472         if (!node)
473                 goto error;
474         node->u.i.guard = guard;
475
476         return node;
477 error:
478         isl_ast_expr_free(guard);
479         return NULL;
480 }
481
482 /* Create a for node with the given iterator.
483  *
484  * The remaining fields need to be filled in later.
485  */
486 __isl_give isl_ast_node *isl_ast_node_alloc_for(__isl_take isl_id *id)
487 {
488         isl_ast_node *node;
489         isl_ctx *ctx;
490
491         if (!id)
492                 return NULL;
493
494         ctx = isl_id_get_ctx(id);
495         node = isl_ast_node_alloc(ctx, isl_ast_node_for);
496         if (!node)
497                 return NULL;
498
499         node->u.f.iterator = isl_ast_expr_from_id(id);
500         if (!node->u.f.iterator)
501                 return isl_ast_node_free(node);
502
503         return node;
504 }
505
506 /* Create a user node evaluating "expr".
507  */
508 __isl_give isl_ast_node *isl_ast_node_alloc_user(__isl_take isl_ast_expr *expr)
509 {
510         isl_ctx *ctx;
511         isl_ast_node *node;
512
513         if (!expr)
514                 return NULL;
515
516         ctx = isl_ast_expr_get_ctx(expr);
517         node = isl_ast_node_alloc(ctx, isl_ast_node_user);
518         if (!node)
519                 goto error;
520
521         node->u.e.expr = expr;
522
523         return node;
524 error:
525         isl_ast_expr_free(expr);
526         return NULL;
527 }
528
529 /* Create a block node with the given children.
530  */
531 __isl_give isl_ast_node *isl_ast_node_alloc_block(
532         __isl_take isl_ast_node_list *list)
533 {
534         isl_ast_node *node;
535         isl_ctx *ctx;
536
537         if (!list)
538                 return NULL;
539
540         ctx = isl_ast_node_list_get_ctx(list);
541         node = isl_ast_node_alloc(ctx, isl_ast_node_block);
542         if (!node)
543                 goto error;
544
545         node->u.b.children = list;
546
547         return node;
548 error:
549         isl_ast_node_list_free(list);
550         return NULL;
551 }
552
553 /* Represent the given list of nodes as a single node, either by
554  * extract the node from a single element list or by creating
555  * a block node with the list of nodes as children.
556  */
557 __isl_give isl_ast_node *isl_ast_node_from_ast_node_list(
558         __isl_take isl_ast_node_list *list)
559 {
560         isl_ast_node *node;
561
562         if (isl_ast_node_list_n_ast_node(list) != 1)
563                 return isl_ast_node_alloc_block(list);
564
565         node = isl_ast_node_list_get_ast_node(list, 0);
566         isl_ast_node_list_free(list);
567
568         return node;
569 }
570
571 __isl_give isl_ast_node *isl_ast_node_copy(__isl_keep isl_ast_node *node)
572 {
573         if (!node)
574                 return NULL;
575
576         node->ref++;
577         return node;
578 }
579
580 __isl_give isl_ast_node *isl_ast_node_dup(__isl_keep isl_ast_node *node)
581 {
582         isl_ast_node *dup;
583
584         if (!node)
585                 return NULL;
586
587         dup = isl_ast_node_alloc(isl_ast_node_get_ctx(node), node->type);
588         if (!dup)
589                 return NULL;
590
591         switch (node->type) {
592         case isl_ast_node_if:
593                 dup->u.i.guard = isl_ast_expr_copy(node->u.i.guard);
594                 dup->u.i.then = isl_ast_node_copy(node->u.i.then);
595                 dup->u.i.else_node = isl_ast_node_copy(node->u.i.else_node);
596                 if (!dup->u.i.guard  || !dup->u.i.then ||
597                     (node->u.i.else_node && !dup->u.i.else_node))
598                         return isl_ast_node_free(dup);
599                 break;
600         case isl_ast_node_for:
601                 dup->u.f.iterator = isl_ast_expr_copy(node->u.f.iterator);
602                 dup->u.f.init = isl_ast_expr_copy(node->u.f.init);
603                 dup->u.f.cond = isl_ast_expr_copy(node->u.f.cond);
604                 dup->u.f.inc = isl_ast_expr_copy(node->u.f.inc);
605                 dup->u.f.body = isl_ast_node_copy(node->u.f.body);
606                 if (!dup->u.f.iterator || !dup->u.f.init || !dup->u.f.cond ||
607                     !dup->u.f.inc || !dup->u.f.body)
608                         return isl_ast_node_free(dup);
609                 break;
610         case isl_ast_node_block:
611                 dup->u.b.children = isl_ast_node_list_copy(node->u.b.children);
612                 if (!dup->u.b.children)
613                         return isl_ast_node_free(dup);
614                 break;
615         case isl_ast_node_user:
616                 dup->u.e.expr = isl_ast_expr_copy(node->u.e.expr);
617                 if (!dup->u.e.expr)
618                         return isl_ast_node_free(dup);
619                 break;
620         case isl_ast_node_error:
621                 break;
622         }
623
624         return dup;
625 }
626
627 __isl_give isl_ast_node *isl_ast_node_cow(__isl_take isl_ast_node *node)
628 {
629         if (!node)
630                 return NULL;
631
632         if (node->ref == 1)
633                 return node;
634         node->ref--;
635         return isl_ast_node_dup(node);
636 }
637
638 void *isl_ast_node_free(__isl_take isl_ast_node *node)
639 {
640         if (!node)
641                 return NULL;
642
643         if (--node->ref > 0)
644                 return NULL;
645
646         switch (node->type) {
647         case isl_ast_node_if:
648                 isl_ast_expr_free(node->u.i.guard);
649                 isl_ast_node_free(node->u.i.then);
650                 isl_ast_node_free(node->u.i.else_node);
651                 break;
652         case isl_ast_node_for:
653                 isl_ast_expr_free(node->u.f.iterator);
654                 isl_ast_expr_free(node->u.f.init);
655                 isl_ast_expr_free(node->u.f.cond);
656                 isl_ast_expr_free(node->u.f.inc);
657                 isl_ast_node_free(node->u.f.body);
658                 break;
659         case isl_ast_node_block:
660                 isl_ast_node_list_free(node->u.b.children);
661                 break;
662         case isl_ast_node_user:
663                 isl_ast_expr_free(node->u.e.expr);
664                 break;
665         case isl_ast_node_error:
666                 break;
667         }
668
669         isl_id_free(node->annotation);
670         isl_ctx_deref(node->ctx);
671         free(node);
672
673         return NULL;
674 }
675
676 /* Replace the body of the for node "node" by "body".
677  */
678 __isl_give isl_ast_node *isl_ast_node_for_set_body(
679         __isl_take isl_ast_node *node, __isl_take isl_ast_node *body)
680 {
681         node = isl_ast_node_cow(node);
682         if (!node || !body)
683                 goto error;
684         if (node->type != isl_ast_node_for)
685                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
686                         "not a for node", goto error);
687
688         isl_ast_node_free(node->u.f.body);
689         node->u.f.body = body;
690
691         return node;
692 error:
693         isl_ast_node_free(node);
694         isl_ast_node_free(body);
695         return NULL;
696 }
697
698 __isl_give isl_ast_node *isl_ast_node_for_get_body(
699         __isl_keep isl_ast_node *node)
700 {
701         if (!node)
702                 return NULL;
703         if (node->type != isl_ast_node_for)
704                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
705                         "not a for node", return NULL);
706         return isl_ast_node_copy(node->u.f.body);
707 }
708
709 /* Mark the given for node as being degenerate.
710  */
711 __isl_give isl_ast_node *isl_ast_node_for_mark_degenerate(
712         __isl_take isl_ast_node *node)
713 {
714         node = isl_ast_node_cow(node);
715         if (!node)
716                 return NULL;
717         node->u.f.degenerate = 1;
718         return node;
719 }
720
721 int isl_ast_node_for_is_degenerate(__isl_keep isl_ast_node *node)
722 {
723         if (!node)
724                 return -1;
725         if (node->type != isl_ast_node_for)
726                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
727                         "not a for node", return -1);
728         return node->u.f.degenerate;
729 }
730
731 __isl_give isl_ast_expr *isl_ast_node_for_get_iterator(
732         __isl_keep isl_ast_node *node)
733 {
734         if (!node)
735                 return NULL;
736         if (node->type != isl_ast_node_for)
737                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
738                         "not a for node", return NULL);
739         return isl_ast_expr_copy(node->u.f.iterator);
740 }
741
742 __isl_give isl_ast_expr *isl_ast_node_for_get_init(
743         __isl_keep isl_ast_node *node)
744 {
745         if (!node)
746                 return NULL;
747         if (node->type != isl_ast_node_for)
748                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
749                         "not a for node", return NULL);
750         return isl_ast_expr_copy(node->u.f.init);
751 }
752
753 /* Return the condition expression of the given for node.
754  *
755  * If the for node is degenerate, then the condition is not explicitly
756  * stored in the node.  Instead, it is constructed as
757  *
758  *      iterator <= init
759  */
760 __isl_give isl_ast_expr *isl_ast_node_for_get_cond(
761         __isl_keep isl_ast_node *node)
762 {
763         if (!node)
764                 return NULL;
765         if (node->type != isl_ast_node_for)
766                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
767                         "not a for node", return NULL);
768         if (!node->u.f.degenerate)
769                 return isl_ast_expr_copy(node->u.f.cond);
770
771         return isl_ast_expr_alloc_binary(isl_ast_op_le,
772                                 isl_ast_expr_copy(node->u.f.iterator),
773                                 isl_ast_expr_copy(node->u.f.init));
774 }
775
776 /* Return the increment of the given for node.
777  *
778  * If the for node is degenerate, then the increment is not explicitly
779  * stored in the node.  We simply return "1".
780  */
781 __isl_give isl_ast_expr *isl_ast_node_for_get_inc(
782         __isl_keep isl_ast_node *node)
783 {
784         if (!node)
785                 return NULL;
786         if (node->type != isl_ast_node_for)
787                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
788                         "not a for node", return NULL);
789         if (!node->u.f.degenerate)
790                 return isl_ast_expr_copy(node->u.f.inc);
791         return isl_ast_expr_alloc_int_si(isl_ast_node_get_ctx(node), 1);
792 }
793
794 /* Replace the then branch of the if node "node" by "child".
795  */
796 __isl_give isl_ast_node *isl_ast_node_if_set_then(
797         __isl_take isl_ast_node *node, __isl_take isl_ast_node *child)
798 {
799         node = isl_ast_node_cow(node);
800         if (!node || !child)
801                 goto error;
802         if (node->type != isl_ast_node_if)
803                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
804                         "not an if node", goto error);
805
806         isl_ast_node_free(node->u.i.then);
807         node->u.i.then = child;
808
809         return node;
810 error:
811         isl_ast_node_free(node);
812         isl_ast_node_free(child);
813         return NULL;
814 }
815
816 __isl_give isl_ast_node *isl_ast_node_if_get_then(
817         __isl_keep isl_ast_node *node)
818 {
819         if (!node)
820                 return NULL;
821         if (node->type != isl_ast_node_if)
822                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
823                         "not an if node", return NULL);
824         return isl_ast_node_copy(node->u.i.then);
825 }
826
827 int isl_ast_node_if_has_else(
828         __isl_keep isl_ast_node *node)
829 {
830         if (!node)
831                 return -1;
832         if (node->type != isl_ast_node_if)
833                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
834                         "not an if node", return -1);
835         return node->u.i.else_node != NULL;
836 }
837
838 __isl_give isl_ast_node *isl_ast_node_if_get_else(
839         __isl_keep isl_ast_node *node)
840 {
841         if (!node)
842                 return NULL;
843         if (node->type != isl_ast_node_if)
844                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
845                         "not an if node", return NULL);
846         return isl_ast_node_copy(node->u.i.else_node);
847 }
848
849 __isl_give isl_ast_expr *isl_ast_node_if_get_cond(
850         __isl_keep isl_ast_node *node)
851 {
852         if (!node)
853                 return NULL;
854         if (node->type != isl_ast_node_if)
855                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
856                         "not a guard node", return NULL);
857         return isl_ast_expr_copy(node->u.i.guard);
858 }
859
860 __isl_give isl_ast_node_list *isl_ast_node_block_get_children(
861         __isl_keep isl_ast_node *node)
862 {
863         if (!node)
864                 return NULL;
865         if (node->type != isl_ast_node_block)
866                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
867                         "not a block node", return NULL);
868         return isl_ast_node_list_copy(node->u.b.children);
869 }
870
871 __isl_give isl_ast_expr *isl_ast_node_user_get_expr(
872         __isl_keep isl_ast_node *node)
873 {
874         if (!node)
875                 return NULL;
876
877         return isl_ast_expr_copy(node->u.e.expr);
878 }
879
880 __isl_give isl_id *isl_ast_node_get_annotation(__isl_keep isl_ast_node *node)
881 {
882         return node ? isl_id_copy(node->annotation) : NULL;
883 }
884
885 /* Replace node->annotation by "annotation".
886  */
887 __isl_give isl_ast_node *isl_ast_node_set_annotation(
888         __isl_take isl_ast_node *node, __isl_take isl_id *annotation)
889 {
890         node = isl_ast_node_cow(node);
891         if (!node || !annotation)
892                 goto error;
893
894         isl_id_free(node->annotation);
895         node->annotation = annotation;
896
897         return node;
898 error:
899         isl_id_free(annotation);
900         return isl_ast_node_free(node);
901 }
902
903 /* Textual C representation of the various operators.
904  */
905 static char *op_str[] = {
906         [isl_ast_op_and] = "&&",
907         [isl_ast_op_and_then] = "&&",
908         [isl_ast_op_or] = "||",
909         [isl_ast_op_or_else] = "||",
910         [isl_ast_op_max] = "max",
911         [isl_ast_op_min] = "min",
912         [isl_ast_op_minus] = "-",
913         [isl_ast_op_add] = "+",
914         [isl_ast_op_sub] = "-",
915         [isl_ast_op_mul] = "*",
916         [isl_ast_op_pdiv_q] = "/",
917         [isl_ast_op_pdiv_r] = "%",
918         [isl_ast_op_div] = "/",
919         [isl_ast_op_eq] = "==",
920         [isl_ast_op_le] = "<=",
921         [isl_ast_op_ge] = ">=",
922         [isl_ast_op_lt] = "<",
923         [isl_ast_op_gt] = ">"
924 };
925
926 /* Precedence in C of the various operators.
927  * Based on http://en.wikipedia.org/wiki/Operators_in_C_and_C++
928  * Lowest value means highest precedence.
929  */
930 static int op_prec[] = {
931         [isl_ast_op_and] = 13,
932         [isl_ast_op_and_then] = 13,
933         [isl_ast_op_or] = 14,
934         [isl_ast_op_or_else] = 14,
935         [isl_ast_op_max] = 2,
936         [isl_ast_op_min] = 2,
937         [isl_ast_op_minus] = 3,
938         [isl_ast_op_add] = 6,
939         [isl_ast_op_sub] = 6,
940         [isl_ast_op_mul] = 5,
941         [isl_ast_op_div] = 5,
942         [isl_ast_op_fdiv_q] = 2,
943         [isl_ast_op_pdiv_q] = 5,
944         [isl_ast_op_pdiv_r] = 5,
945         [isl_ast_op_cond] = 15,
946         [isl_ast_op_select] = 15,
947         [isl_ast_op_eq] = 9,
948         [isl_ast_op_le] = 8,
949         [isl_ast_op_ge] = 8,
950         [isl_ast_op_lt] = 8,
951         [isl_ast_op_gt] = 8,
952         [isl_ast_op_call] = 2
953 };
954
955 /* Is the operator left-to-right associative?
956  */
957 static int op_left[] = {
958         [isl_ast_op_and] = 1,
959         [isl_ast_op_and_then] = 1,
960         [isl_ast_op_or] = 1,
961         [isl_ast_op_or_else] = 1,
962         [isl_ast_op_max] = 1,
963         [isl_ast_op_min] = 1,
964         [isl_ast_op_minus] = 0,
965         [isl_ast_op_add] = 1,
966         [isl_ast_op_sub] = 1,
967         [isl_ast_op_mul] = 1,
968         [isl_ast_op_div] = 1,
969         [isl_ast_op_fdiv_q] = 1,
970         [isl_ast_op_pdiv_q] = 1,
971         [isl_ast_op_pdiv_r] = 1,
972         [isl_ast_op_cond] = 0,
973         [isl_ast_op_select] = 0,
974         [isl_ast_op_eq] = 1,
975         [isl_ast_op_le] = 1,
976         [isl_ast_op_ge] = 1,
977         [isl_ast_op_lt] = 1,
978         [isl_ast_op_gt] = 1,
979         [isl_ast_op_call] = 1
980 };
981
982 static int is_and(enum isl_ast_op_type op)
983 {
984         return op == isl_ast_op_and || op == isl_ast_op_and_then;
985 }
986
987 static int is_or(enum isl_ast_op_type op)
988 {
989         return op == isl_ast_op_or || op == isl_ast_op_or_else;
990 }
991
992 static int is_add_sub(enum isl_ast_op_type op)
993 {
994         return op == isl_ast_op_add || op == isl_ast_op_sub;
995 }
996
997 static int is_div_mod(enum isl_ast_op_type op)
998 {
999         return op == isl_ast_op_div || op == isl_ast_op_pdiv_r;
1000 }
1001
1002 /* Do we need/want parentheses around "expr" as a subexpression of
1003  * an "op" operation?  If "left" is set, then "expr" is the left-most
1004  * operand.
1005  *
1006  * We only need parentheses if "expr" represents an operation.
1007  *
1008  * If op has a higher precedence than expr->u.op.op, then we need
1009  * parentheses.
1010  * If op and expr->u.op.op have the same precedence, but the operations
1011  * are performed in an order that is different from the associativity,
1012  * then we need parentheses.
1013  *
1014  * An and inside an or technically does not require parentheses,
1015  * but some compilers complain about that, so we add them anyway.
1016  *
1017  * Computations such as "a / b * c" and "a % b + c" can be somewhat
1018  * difficult to read, so we add parentheses for those as well.
1019  */
1020 static int sub_expr_need_parens(enum isl_ast_op_type op,
1021         __isl_keep isl_ast_expr *expr, int left)
1022 {
1023         if (expr->type != isl_ast_expr_op)
1024                 return 0;
1025
1026         if (op_prec[expr->u.op.op] > op_prec[op])
1027                 return 1;
1028         if (op_prec[expr->u.op.op] == op_prec[op] && left != op_left[op])
1029                 return 1;
1030
1031         if (is_or(op) && is_and(expr->u.op.op))
1032                 return 1;
1033         if (op == isl_ast_op_mul && expr->u.op.op != isl_ast_op_mul &&
1034             op_prec[expr->u.op.op] == op_prec[op])
1035                 return 1;
1036         if (is_add_sub(op) && is_div_mod(expr->u.op.op))
1037                 return 1;
1038
1039         return 0;
1040 }
1041
1042 /* Print "expr" as a subexpression of an "op" operation.
1043  * If "left" is set, then "expr" is the left-most operand.
1044  */
1045 static __isl_give isl_printer *print_sub_expr(__isl_take isl_printer *p,
1046         enum isl_ast_op_type op, __isl_keep isl_ast_expr *expr, int left)
1047 {
1048         int need_parens;
1049
1050         need_parens = sub_expr_need_parens(op, expr, left);
1051
1052         if (need_parens)
1053                 p = isl_printer_print_str(p, "(");
1054         p = isl_printer_print_ast_expr(p, expr);
1055         if (need_parens)
1056                 p = isl_printer_print_str(p, ")");
1057         return p;
1058 }
1059
1060 /* Print a min or max reduction "expr".
1061  */
1062 static __isl_give isl_printer *print_min_max(__isl_take isl_printer *p,
1063         __isl_keep isl_ast_expr *expr)
1064 {
1065         int i = 0;
1066
1067         for (i = 1; i < expr->u.op.n_arg; ++i) {
1068                 p = isl_printer_print_str(p, op_str[expr->u.op.op]);
1069                 p = isl_printer_print_str(p, "(");
1070         }
1071         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1072         for (i = 1; i < expr->u.op.n_arg; ++i) {
1073                 p = isl_printer_print_str(p, ", ");
1074                 p = isl_printer_print_ast_expr(p, expr->u.op.args[i]);
1075                 p = isl_printer_print_str(p, ")");
1076         }
1077
1078         return p;
1079 }
1080
1081 /* Print a function call "expr".
1082  *
1083  * The first argument represents the function to be called.
1084  */
1085 static __isl_give isl_printer *print_call(__isl_take isl_printer *p,
1086         __isl_keep isl_ast_expr *expr)
1087 {
1088         int i = 0;
1089
1090         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1091         p = isl_printer_print_str(p, "(");
1092         for (i = 1; i < expr->u.op.n_arg; ++i) {
1093                 if (i != 1)
1094                         p = isl_printer_print_str(p, ", ");
1095                 p = isl_printer_print_ast_expr(p, expr->u.op.args[i]);
1096         }
1097         p = isl_printer_print_str(p, ")");
1098
1099         return p;
1100 }
1101
1102 /* Print "expr" to "p".
1103  *
1104  * If we are printing in isl format, then we also print an indication
1105  * of the size of the expression (if it was computed).
1106  */
1107 __isl_give isl_printer *isl_printer_print_ast_expr(__isl_take isl_printer *p,
1108         __isl_keep isl_ast_expr *expr)
1109 {
1110         if (!p)
1111                 return NULL;
1112         if (!expr)
1113                 return isl_printer_free(p);
1114
1115         switch (expr->type) {
1116         case isl_ast_expr_op:
1117                 if (expr->u.op.op == isl_ast_op_call) {
1118                         p = print_call(p, expr);
1119                         break;
1120                 }
1121                 if (expr->u.op.n_arg == 1) {
1122                         p = isl_printer_print_str(p, op_str[expr->u.op.op]);
1123                         p = print_sub_expr(p, expr->u.op.op,
1124                                                 expr->u.op.args[0], 0);
1125                         break;
1126                 }
1127                 if (expr->u.op.op == isl_ast_op_fdiv_q) {
1128                         p = isl_printer_print_str(p, "floord(");
1129                         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1130                         p = isl_printer_print_str(p, ", ");
1131                         p = isl_printer_print_ast_expr(p, expr->u.op.args[1]);
1132                         p = isl_printer_print_str(p, ")");
1133                         break;
1134                 }
1135                 if (expr->u.op.op == isl_ast_op_max ||
1136                     expr->u.op.op == isl_ast_op_min) {
1137                         p = print_min_max(p, expr);
1138                         break;
1139                 }
1140                 if (expr->u.op.op == isl_ast_op_cond ||
1141                     expr->u.op.op == isl_ast_op_select) {
1142                         p = isl_printer_print_ast_expr(p, expr->u.op.args[0]);
1143                         p = isl_printer_print_str(p, " ? ");
1144                         p = isl_printer_print_ast_expr(p, expr->u.op.args[1]);
1145                         p = isl_printer_print_str(p, " : ");
1146                         p = isl_printer_print_ast_expr(p, expr->u.op.args[2]);
1147                         break;
1148                 }
1149                 if (expr->u.op.n_arg != 2)
1150                         isl_die(isl_printer_get_ctx(p), isl_error_internal,
1151                                 "operation should have two arguments",
1152                                 goto error);
1153                 p = print_sub_expr(p, expr->u.op.op, expr->u.op.args[0], 1);
1154                 p = isl_printer_print_str(p, " ");
1155                 p = isl_printer_print_str(p, op_str[expr->u.op.op]);
1156                 p = isl_printer_print_str(p, " ");
1157                 p = print_sub_expr(p, expr->u.op.op, expr->u.op.args[1], 0);
1158                 break;
1159         case isl_ast_expr_id:
1160                 p = isl_printer_print_str(p, isl_id_get_name(expr->u.id));
1161                 break;
1162         case isl_ast_expr_int:
1163                 p = isl_printer_print_isl_int(p, expr->u.i);
1164                 break;
1165         case isl_ast_expr_error:
1166                 break;
1167         }
1168
1169         return p;
1170 error:
1171         isl_printer_free(p);
1172         return NULL;
1173 }
1174
1175 /* Print "node" to "p" in "isl format".
1176  */
1177 static __isl_give isl_printer *print_ast_node_isl(__isl_take isl_printer *p,
1178         __isl_keep isl_ast_node *node)
1179 {
1180         p = isl_printer_print_str(p, "(");
1181         switch (node->type) {
1182         case isl_ast_node_for:
1183                 if (node->u.f.degenerate) {
1184                         p = isl_printer_print_ast_expr(p, node->u.f.init);
1185                 } else {
1186                         p = isl_printer_print_str(p, "init: ");
1187                         p = isl_printer_print_ast_expr(p, node->u.f.init);
1188                         p = isl_printer_print_str(p, ", ");
1189                         p = isl_printer_print_str(p, "cond: ");
1190                         p = isl_printer_print_ast_expr(p, node->u.f.cond);
1191                         p = isl_printer_print_str(p, ", ");
1192                         p = isl_printer_print_str(p, "inc: ");
1193                         p = isl_printer_print_ast_expr(p, node->u.f.inc);
1194                 }
1195                 if (node->u.f.body) {
1196                         p = isl_printer_print_str(p, ", ");
1197                         p = isl_printer_print_str(p, "body: ");
1198                         p = isl_printer_print_ast_node(p, node->u.f.body);
1199                 }
1200                 break;
1201         case isl_ast_node_user:
1202                 p = isl_printer_print_ast_expr(p, node->u.e.expr);
1203                 break;
1204         case isl_ast_node_if:
1205                 p = isl_printer_print_str(p, "guard: ");
1206                 p = isl_printer_print_ast_expr(p, node->u.i.guard);
1207                 if (node->u.i.then) {
1208                         p = isl_printer_print_str(p, ", ");
1209                         p = isl_printer_print_str(p, "then: ");
1210                         p = isl_printer_print_ast_node(p, node->u.i.then);
1211                 }
1212                 if (node->u.i.else_node) {
1213                         p = isl_printer_print_str(p, ", ");
1214                         p = isl_printer_print_str(p, "else: ");
1215                         p = isl_printer_print_ast_node(p, node->u.i.else_node);
1216                 }
1217                 break;
1218         case isl_ast_node_block:
1219                 p = isl_printer_print_ast_node_list(p, node->u.b.children);
1220                 break;
1221         default:
1222                 break;
1223         }
1224         p = isl_printer_print_str(p, ")");
1225         return p;
1226 }
1227
1228 /* Do we need to print a block around the body "node" of a for or if node?
1229  *
1230  * If the node is a block, then we need to print a block.
1231  * Also if the node is a degenerate for then we will print it as
1232  * an assignment followed by the body of the for loop, so we need a block
1233  * as well.
1234  */
1235 static int need_block(__isl_keep isl_ast_node *node)
1236 {
1237         if (node->type == isl_ast_node_block)
1238                 return 1;
1239         if (node->type == isl_ast_node_for && node->u.f.degenerate)
1240                 return 1;
1241         return 0;
1242 }
1243
1244 static __isl_give isl_printer *print_ast_node_c(__isl_take isl_printer *p,
1245         __isl_keep isl_ast_node *node,
1246         __isl_keep isl_ast_print_options *options, int in_block);
1247 static __isl_give isl_printer *print_if_c(__isl_take isl_printer *p,
1248         __isl_keep isl_ast_node *node,
1249         __isl_keep isl_ast_print_options *options, int new_line);
1250
1251 /* Print the body "node" of a for or if node.
1252  * If "else_node" is set, then it is printed as well.
1253  *
1254  * We first check if we need to print out a block.
1255  * We always print out a block if there is an else node to make
1256  * sure that the else node is matched to the correct if node.
1257  *
1258  * If the else node is itself an if, then we print it as
1259  *
1260  *      } else if (..)
1261  *
1262  * Otherwise the else node is printed as
1263  *
1264  *      } else
1265  *        node
1266  */
1267 static __isl_give isl_printer *print_body_c(__isl_take isl_printer *p,
1268         __isl_keep isl_ast_node *node, __isl_keep isl_ast_node *else_node,
1269         __isl_keep isl_ast_print_options *options)
1270 {
1271         if (!node)
1272                 return isl_printer_free(p);
1273
1274         if (!else_node && !need_block(node)) {
1275                 p = isl_printer_end_line(p);
1276                 p = isl_printer_indent(p, 2);
1277                 p = isl_ast_node_print(node, p, options);
1278                 p = isl_printer_indent(p, -2);
1279                 return p;
1280         }
1281
1282         p = isl_printer_print_str(p, " {");
1283         p = isl_printer_end_line(p);
1284         p = isl_printer_indent(p, 2);
1285         p = print_ast_node_c(p, node, options, 1);
1286         p = isl_printer_indent(p, -2);
1287         p = isl_printer_start_line(p);
1288         p = isl_printer_print_str(p, "}");
1289         if (else_node) {
1290                 if (else_node->type == isl_ast_node_if) {
1291                         p = isl_printer_print_str(p, " else ");
1292                         p = print_if_c(p, else_node, options, 0);
1293                 } else {
1294                         p = isl_printer_print_str(p, " else");
1295                         p = print_body_c(p, else_node, NULL, options);
1296                 }
1297         } else
1298                 p = isl_printer_end_line(p);
1299
1300         return p;
1301 }
1302
1303 /* Print the for node "node".
1304  *
1305  * If the for node is degenerate, it is printed as
1306  *
1307  *      type iterator = init;
1308  *      body
1309  *
1310  * Otherwise, it is printed as
1311  *
1312  *      for (type iterator = init; cond; iterator += inc)
1313  *              body
1314  *
1315  * "in_block" is set if we are currently inside a block.
1316  * We simply pass it along to print_ast_node_c in case of a degenerate
1317  * for loop.
1318  */
1319 static __isl_give isl_printer *print_for_c(__isl_take isl_printer *p,
1320         __isl_keep isl_ast_node *node,
1321         __isl_keep isl_ast_print_options *options, int in_block)
1322 {
1323         isl_id *id;
1324         const char *name;
1325         const char *type;
1326
1327         type = isl_options_get_ast_iterator_type(isl_printer_get_ctx(p));
1328         if (!node->u.f.degenerate) {
1329                 id = isl_ast_expr_get_id(node->u.f.iterator);
1330                 name = isl_id_get_name(id);
1331                 isl_id_free(id);
1332                 p = isl_printer_start_line(p);
1333                 p = isl_printer_print_str(p, "for (");
1334                 p = isl_printer_print_str(p, type);
1335                 p = isl_printer_print_str(p, " ");
1336                 p = isl_printer_print_str(p, name);
1337                 p = isl_printer_print_str(p, " = ");
1338                 p = isl_printer_print_ast_expr(p, node->u.f.init);
1339                 p = isl_printer_print_str(p, "; ");
1340                 p = isl_printer_print_ast_expr(p, node->u.f.cond);
1341                 p = isl_printer_print_str(p, "; ");
1342                 p = isl_printer_print_str(p, name);
1343                 p = isl_printer_print_str(p, " += ");
1344                 p = isl_printer_print_ast_expr(p, node->u.f.inc);
1345                 p = isl_printer_print_str(p, ")");
1346                 p = print_body_c(p, node->u.f.body, NULL, options);
1347         } else {
1348                 id = isl_ast_expr_get_id(node->u.f.iterator);
1349                 name = isl_id_get_name(id);
1350                 isl_id_free(id);
1351                 p = isl_printer_start_line(p);
1352                 p = isl_printer_print_str(p, type);
1353                 p = isl_printer_print_str(p, " ");
1354                 p = isl_printer_print_str(p, name);
1355                 p = isl_printer_print_str(p, " = ");
1356                 p = isl_printer_print_ast_expr(p, node->u.f.init);
1357                 p = isl_printer_print_str(p, ";");
1358                 p = isl_printer_end_line(p);
1359                 p = print_ast_node_c(p, node->u.f.body, options, in_block);
1360         }
1361
1362         return p;
1363 }
1364
1365 /* Print the if node "node".
1366  * If "new_line" is set then the if node should be printed on a new line.
1367  */
1368 static __isl_give isl_printer *print_if_c(__isl_take isl_printer *p,
1369         __isl_keep isl_ast_node *node,
1370         __isl_keep isl_ast_print_options *options, int new_line)
1371 {
1372         if (new_line)
1373                 p = isl_printer_start_line(p);
1374         p = isl_printer_print_str(p, "if (");
1375         p = isl_printer_print_ast_expr(p, node->u.i.guard);
1376         p = isl_printer_print_str(p, ")");
1377         p = print_body_c(p, node->u.i.then, node->u.i.else_node, options);
1378
1379         return p;
1380 }
1381
1382 /* Print the "node" to "p".
1383  *
1384  * "in_block" is set if we are currently inside a block.
1385  * If so, we do not print a block around the children of a block node.
1386  * We do this to avoid an extra block around the body of a degenerate
1387  * for node.
1388  */
1389 static __isl_give isl_printer *print_ast_node_c(__isl_take isl_printer *p,
1390         __isl_keep isl_ast_node *node,
1391         __isl_keep isl_ast_print_options *options, int in_block)
1392 {
1393         switch (node->type) {
1394         case isl_ast_node_for:
1395                 if (options->print_for)
1396                         return options->print_for(p, node,
1397                                                     options->print_for_user);
1398                 p = print_for_c(p, node, options, in_block);
1399                 break;
1400         case isl_ast_node_if:
1401                 p = print_if_c(p, node, options, 1);
1402                 break;
1403         case isl_ast_node_block:
1404                 if (!in_block) {
1405                         p = isl_printer_start_line(p);
1406                         p = isl_printer_print_str(p, "{");
1407                         p = isl_printer_end_line(p);
1408                         p = isl_printer_indent(p, 2);
1409                 }
1410                 p = isl_ast_node_list_print(node->u.b.children, p, options);
1411                 if (!in_block) {
1412                         p = isl_printer_indent(p, -2);
1413                         p = isl_printer_start_line(p);
1414                         p = isl_printer_print_str(p, "}");
1415                         p = isl_printer_end_line(p);
1416                 }
1417                 break;
1418         case isl_ast_node_user:
1419                 if (options->print_user)
1420                         return options->print_user(p, node,
1421                                                     options->print_user_user);
1422                 p = isl_printer_start_line(p);
1423                 p = isl_printer_print_ast_expr(p, node->u.e.expr);
1424                 p = isl_printer_print_str(p, ";");
1425                 p = isl_printer_end_line(p);
1426                 break;
1427         case isl_ast_node_error:
1428                 break;
1429         }
1430         return p;
1431 }
1432
1433 /* Print the for node "node" to "p".
1434  */
1435 __isl_give isl_printer *isl_ast_node_for_print(__isl_keep isl_ast_node *node,
1436         __isl_take isl_printer *p, __isl_keep isl_ast_print_options *options)
1437 {
1438         if (!node || !options)
1439                 goto error;
1440         if (node->type != isl_ast_node_for)
1441                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
1442                         "not a for node", goto error);
1443         return print_for_c(p, node, options, 0);
1444 error:
1445         isl_printer_free(p);
1446         return NULL;
1447 }
1448
1449 /* Print the if node "node" to "p".
1450  */
1451 __isl_give isl_printer *isl_ast_node_if_print(__isl_keep isl_ast_node *node,
1452         __isl_take isl_printer *p, __isl_keep isl_ast_print_options *options)
1453 {
1454         if (!node || !options)
1455                 goto error;
1456         if (node->type != isl_ast_node_if)
1457                 isl_die(isl_ast_node_get_ctx(node), isl_error_invalid,
1458                         "not an if node", goto error);
1459         return print_if_c(p, node, options, 1);
1460 error:
1461         isl_printer_free(p);
1462         return NULL;
1463 }
1464
1465 /* Print "node" to "p".
1466  */
1467 __isl_give isl_printer *isl_ast_node_print(__isl_keep isl_ast_node *node,
1468         __isl_take isl_printer *p, __isl_keep isl_ast_print_options *options)
1469 {
1470         if (!options || !node)
1471                 goto error;
1472         return print_ast_node_c(p, node, options, 0);
1473 error:
1474         isl_printer_free(p);
1475         return NULL;
1476 }
1477
1478 /* Print "node" to "p".
1479  */
1480 __isl_give isl_printer *isl_printer_print_ast_node(__isl_take isl_printer *p,
1481         __isl_keep isl_ast_node *node)
1482 {
1483         int format;
1484         isl_ast_print_options *options;
1485
1486         if (!p)
1487                 return NULL;
1488
1489         format = isl_printer_get_output_format(p);
1490         switch (format) {
1491         case ISL_FORMAT_ISL:
1492                 p = print_ast_node_isl(p, node);
1493                 break;
1494         case ISL_FORMAT_C:
1495                 options = isl_ast_print_options_alloc(isl_printer_get_ctx(p));
1496                 p = isl_ast_node_print(node, p, options);
1497                 isl_ast_print_options_free(options);
1498                 break;
1499         default:
1500                 isl_die(isl_printer_get_ctx(p), isl_error_unsupported,
1501                         "output format not supported for ast_node",
1502                         return isl_printer_free(p));
1503         }
1504
1505         return p;
1506 }
1507
1508 /* Print the list of nodes "list" to "p".
1509  */
1510 __isl_give isl_printer *isl_ast_node_list_print(
1511         __isl_keep isl_ast_node_list *list, __isl_take isl_printer *p,
1512         __isl_keep isl_ast_print_options *options)
1513 {
1514         int i;
1515
1516         if (!p || !list || !options)
1517                 return isl_printer_free(p);
1518
1519         for (i = 0; i < list->n; ++i)
1520                 p = print_ast_node_c(p, list->p[i], options, 1);
1521
1522         return p;
1523 }
1524
1525 #define ISL_AST_MACRO_FLOORD    (1 << 0)
1526 #define ISL_AST_MACRO_MIN       (1 << 1)
1527 #define ISL_AST_MACRO_MAX       (1 << 2)
1528 #define ISL_AST_MACRO_ALL       (ISL_AST_MACRO_FLOORD | \
1529                                  ISL_AST_MACRO_MIN | \
1530                                  ISL_AST_MACRO_MAX)
1531
1532 /* If "expr" contains an isl_ast_op_min, isl_ast_op_max or isl_ast_op_fdiv_q
1533  * then set the corresponding bit in "macros".
1534  */
1535 static int ast_expr_required_macros(__isl_keep isl_ast_expr *expr, int macros)
1536 {
1537         int i;
1538
1539         if (macros == ISL_AST_MACRO_ALL)
1540                 return macros;
1541
1542         if (expr->type != isl_ast_expr_op)
1543                 return macros;
1544
1545         if (expr->u.op.op == isl_ast_op_min)
1546                 macros |= ISL_AST_MACRO_MIN;
1547         if (expr->u.op.op == isl_ast_op_max)
1548                 macros |= ISL_AST_MACRO_MAX;
1549         if (expr->u.op.op == isl_ast_op_fdiv_q)
1550                 macros |= ISL_AST_MACRO_FLOORD;
1551
1552         for (i = 0; i < expr->u.op.n_arg; ++i)
1553                 macros = ast_expr_required_macros(expr->u.op.args[i], macros);
1554
1555         return macros;
1556 }
1557
1558 static int ast_node_list_required_macros(__isl_keep isl_ast_node_list *list,
1559         int macros);
1560
1561 /* If "node" contains an isl_ast_op_min, isl_ast_op_max or isl_ast_op_fdiv_q
1562  * then set the corresponding bit in "macros".
1563  */
1564 static int ast_node_required_macros(__isl_keep isl_ast_node *node, int macros)
1565 {
1566         if (macros == ISL_AST_MACRO_ALL)
1567                 return macros;
1568
1569         switch (node->type) {
1570         case isl_ast_node_for:
1571                 macros = ast_expr_required_macros(node->u.f.init, macros);
1572                 if (!node->u.f.degenerate) {
1573                         macros = ast_expr_required_macros(node->u.f.cond,
1574                                                                 macros);
1575                         macros = ast_expr_required_macros(node->u.f.inc,
1576                                                                 macros);
1577                 }
1578                 macros = ast_node_required_macros(node->u.f.body, macros);
1579                 break;
1580         case isl_ast_node_if:
1581                 macros = ast_expr_required_macros(node->u.i.guard, macros);
1582                 macros = ast_node_required_macros(node->u.i.then, macros);
1583                 if (node->u.i.else_node)
1584                         macros = ast_node_required_macros(node->u.i.else_node,
1585                                                                 macros);
1586                 break;
1587         case isl_ast_node_block:
1588                 macros = ast_node_list_required_macros(node->u.b.children,
1589                                                         macros);
1590                 break;
1591         case isl_ast_node_user:
1592                 macros = ast_expr_required_macros(node->u.e.expr, macros);
1593                 break;
1594         case isl_ast_node_error:
1595                 break;
1596         }
1597
1598         return macros;
1599 }
1600
1601 /* If "list" contains an isl_ast_op_min, isl_ast_op_max or isl_ast_op_fdiv_q
1602  * then set the corresponding bit in "macros".
1603  */
1604 static int ast_node_list_required_macros(__isl_keep isl_ast_node_list *list,
1605         int macros)
1606 {
1607         int i;
1608
1609         for (i = 0; i < list->n; ++i)
1610                 macros = ast_node_required_macros(list->p[i], macros);
1611
1612         return macros;
1613 }
1614
1615 /* Print a macro definition for the operator "type".
1616  */
1617 __isl_give isl_printer *isl_ast_op_type_print_macro(
1618         enum isl_ast_op_type type, __isl_take isl_printer *p)
1619 {
1620         switch (type) {
1621         case isl_ast_op_min:
1622                 p = isl_printer_start_line(p);
1623                 p = isl_printer_print_str(p,
1624                         "#define min(x,y)    ((x) < (y) ? (x) : (y))");
1625                 p = isl_printer_end_line(p);
1626                 break;
1627         case isl_ast_op_max:
1628                 p = isl_printer_start_line(p);
1629                 p = isl_printer_print_str(p,
1630                         "#define max(x,y)    ((x) > (y) ? (x) : (y))");
1631                 p = isl_printer_end_line(p);
1632                 break;
1633         case isl_ast_op_fdiv_q:
1634                 p = isl_printer_start_line(p);
1635                 p = isl_printer_print_str(p,
1636                         "#define floord(n,d) "
1637                         "(((n)<0) ? -((-(n)+(d)-1)/(d)) : (n)/(d))");
1638                 p = isl_printer_end_line(p);
1639                 break;
1640         default:
1641                 break;
1642         }
1643
1644         return p;
1645 }
1646
1647 /* Call "fn" for each type of operation that appears in "node"
1648  * and that requires a macro definition.
1649  */
1650 int isl_ast_node_foreach_ast_op_type(__isl_keep isl_ast_node *node,
1651         int (*fn)(enum isl_ast_op_type type, void *user), void *user)
1652 {
1653         int macros;
1654
1655         if (!node)
1656                 return -1;
1657
1658         macros = ast_node_required_macros(node, 0);
1659
1660         if (macros & ISL_AST_MACRO_MIN && fn(isl_ast_op_min, user) < 0)
1661                 return -1;
1662         if (macros & ISL_AST_MACRO_MAX && fn(isl_ast_op_max, user) < 0)
1663                 return -1;
1664         if (macros & ISL_AST_MACRO_FLOORD && fn(isl_ast_op_fdiv_q, user) < 0)
1665                 return -1;
1666
1667         return 0;
1668 }
1669
1670 static int ast_op_type_print_macro(enum isl_ast_op_type type, void *user)
1671 {
1672         isl_printer **p = user;
1673
1674         *p = isl_ast_op_type_print_macro(type, *p);
1675
1676         return 0;
1677 }
1678
1679 /* Print macro definitions for all the macros used in the result
1680  * of printing "node.
1681  */
1682 __isl_give isl_printer *isl_ast_node_print_macros(
1683         __isl_keep isl_ast_node *node, __isl_take isl_printer *p)
1684 {
1685         if (isl_ast_node_foreach_ast_op_type(node,
1686                                             &ast_op_type_print_macro, &p) < 0)
1687                 return isl_printer_free(p);
1688         return p;
1689 }