Tizen 2.1 base
[external/gmp.git] / demos / pexpr.c
1 /* Program for computing integer expressions using the GNU Multiple Precision
2    Arithmetic Library.
3
4 Copyright 1997, 1999, 2000, 2001, 2002, 2005 Free Software Foundation, Inc.
5
6 This program is free software; you can redistribute it and/or modify it under
7 the terms of the GNU General Public License as published by the Free Software
8 Foundation; either version 3 of the License, or (at your option) any later
9 version.
10
11 This program is distributed in the hope that it will be useful, but WITHOUT ANY
12 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
13 PARTICULAR PURPOSE.  See the GNU General Public License for more details.
14
15 You should have received a copy of the GNU General Public License along with
16 this program.  If not, see http://www.gnu.org/licenses/.  */
17
18
19 /* This expressions evaluator works by building an expression tree (using a
20    recursive descent parser) which is then evaluated.  The expression tree is
21    useful since we want to optimize certain expressions (like a^b % c).
22
23    Usage: pexpr [options] expr ...
24    (Assuming you called the executable `pexpr' of course.)
25
26    Command line options:
27
28    -b        print output in binary
29    -o        print output in octal
30    -d        print output in decimal (the default)
31    -x        print output in hexadecimal
32    -b<NUM>   print output in base NUM
33    -t        print timing information
34    -html     output html
35    -wml      output wml
36    -split    split long lines each 80th digit
37 */
38
39 /* Define LIMIT_RESOURCE_USAGE if you want to make sure the program doesn't
40    use up extensive resources (cpu, memory).  Useful for the GMP demo on the
41    GMP web site, since we cannot load the server too much.  */
42
43 #include "pexpr-config.h"
44
45 #include <string.h>
46 #include <stdio.h>
47 #include <stdlib.h>
48 #include <setjmp.h>
49 #include <signal.h>
50 #include <ctype.h>
51
52 #include <time.h>
53 #include <sys/types.h>
54 #include <sys/time.h>
55 #if HAVE_SYS_RESOURCE_H
56 #include <sys/resource.h>
57 #endif
58
59 #include "gmp.h"
60
61 /* SunOS 4 and HPUX 9 don't define a canonical SIGSTKSZ, use a default. */
62 #ifndef SIGSTKSZ
63 #define SIGSTKSZ  4096
64 #endif
65
66
67 #define TIME(t,func)                                                    \
68   do { int __t0, __tmp;                                                 \
69     __t0 = cputime ();                                                  \
70     {func;}                                                             \
71     __tmp = cputime () - __t0;                                          \
72     (t) = __tmp;                                                        \
73   } while (0)
74
75 /* GMP version 1.x compatibility.  */
76 #if ! (__GNU_MP_VERSION >= 2)
77 typedef MP_INT __mpz_struct;
78 typedef __mpz_struct mpz_t[1];
79 typedef __mpz_struct *mpz_ptr;
80 #define mpz_fdiv_q      mpz_div
81 #define mpz_fdiv_r      mpz_mod
82 #define mpz_tdiv_q_2exp mpz_div_2exp
83 #define mpz_sgn(Z) ((Z)->size < 0 ? -1 : (Z)->size > 0)
84 #endif
85
86 /* GMP version 2.0 compatibility.  */
87 #if ! (__GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1)
88 #define mpz_swap(a,b) \
89   do { __mpz_struct __t; __t = *a; *a = *b; *b = __t;} while (0)
90 #endif
91
92 jmp_buf errjmpbuf;
93
94 enum op_t {NOP, LIT, NEG, NOT, PLUS, MINUS, MULT, DIV, MOD, REM, INVMOD, POW,
95            AND, IOR, XOR, SLL, SRA, POPCNT, HAMDIST, GCD, LCM, SQRT, ROOT, FAC,
96            LOG, LOG2, FERMAT, MERSENNE, FIBONACCI, RANDOM, NEXTPRIME, BINOM,
97            TIMING};
98
99 /* Type for the expression tree.  */
100 struct expr
101 {
102   enum op_t op;
103   union
104   {
105     struct {struct expr *lhs, *rhs;} ops;
106     mpz_t val;
107   } operands;
108 };
109
110 typedef struct expr *expr_t;
111
112 void cleanup_and_exit __GMP_PROTO ((int));
113
114 char *skipspace __GMP_PROTO ((char *));
115 void makeexp __GMP_PROTO ((expr_t *, enum op_t, expr_t, expr_t));
116 void free_expr __GMP_PROTO ((expr_t));
117 char *expr __GMP_PROTO ((char *, expr_t *));
118 char *term __GMP_PROTO ((char *, expr_t *));
119 char *power __GMP_PROTO ((char *, expr_t *));
120 char *factor __GMP_PROTO ((char *, expr_t *));
121 int match __GMP_PROTO ((char *, char *));
122 int matchp __GMP_PROTO ((char *, char *));
123 int cputime __GMP_PROTO ((void));
124
125 void mpz_eval_expr __GMP_PROTO ((mpz_ptr, expr_t));
126 void mpz_eval_mod_expr __GMP_PROTO ((mpz_ptr, expr_t, mpz_ptr));
127
128 char *error;
129 int flag_print = 1;
130 int print_timing = 0;
131 int flag_html = 0;
132 int flag_wml = 0;
133 int flag_splitup_output = 0;
134 char *newline = "";
135 gmp_randstate_t rstate;
136
137
138
139 /* cputime() returns user CPU time measured in milliseconds.  */
140 #if ! HAVE_CPUTIME
141 #if HAVE_GETRUSAGE
142 int
143 cputime (void)
144 {
145   struct rusage rus;
146
147   getrusage (0, &rus);
148   return rus.ru_utime.tv_sec * 1000 + rus.ru_utime.tv_usec / 1000;
149 }
150 #else
151 #if HAVE_CLOCK
152 int
153 cputime (void)
154 {
155   if (CLOCKS_PER_SEC < 100000)
156     return clock () * 1000 / CLOCKS_PER_SEC;
157   return clock () / (CLOCKS_PER_SEC / 1000);
158 }
159 #else
160 int
161 cputime (void)
162 {
163   return 0;
164 }
165 #endif
166 #endif
167 #endif
168
169
170 int
171 stack_downwards_helper (char *xp)
172 {
173   char  y;
174   return &y < xp;
175 }
176 int
177 stack_downwards_p (void)
178 {
179   char  x;
180   return stack_downwards_helper (&x);
181 }
182
183
184 void
185 setup_error_handler (void)
186 {
187 #if HAVE_SIGACTION
188   struct sigaction act;
189   act.sa_handler = cleanup_and_exit;
190   sigemptyset (&(act.sa_mask));
191 #define SIGNAL(sig)  sigaction (sig, &act, NULL)
192 #else
193   struct { int sa_flags; } act;
194 #define SIGNAL(sig)  signal (sig, cleanup_and_exit)
195 #endif
196   act.sa_flags = 0;
197
198   /* Set up a stack for signal handling.  A typical cause of error is stack
199      overflow, and in such situation a signal can not be delivered on the
200      overflown stack.  */
201 #if HAVE_SIGALTSTACK
202   {
203     /* AIX uses stack_t, MacOS uses struct sigaltstack, various other
204        systems have both. */
205 #if HAVE_STACK_T
206     stack_t s;
207 #else
208     struct sigaltstack s;
209 #endif
210     s.ss_sp = malloc (SIGSTKSZ);
211     s.ss_size = SIGSTKSZ;
212     s.ss_flags = 0;
213     if (sigaltstack (&s, NULL) != 0)
214       perror("sigaltstack");
215     act.sa_flags = SA_ONSTACK;
216   }
217 #else
218 #if HAVE_SIGSTACK
219   {
220     struct sigstack s;
221     s.ss_sp = malloc (SIGSTKSZ);
222     if (stack_downwards_p ())
223       s.ss_sp += SIGSTKSZ;
224     s.ss_onstack = 0;
225     if (sigstack (&s, NULL) != 0)
226       perror("sigstack");
227     act.sa_flags = SA_ONSTACK;
228   }
229 #else
230 #endif
231 #endif
232
233 #ifdef LIMIT_RESOURCE_USAGE
234   {
235     struct rlimit limit;
236
237     limit.rlim_cur = limit.rlim_max = 0;
238     setrlimit (RLIMIT_CORE, &limit);
239
240     limit.rlim_cur = 3;
241     limit.rlim_max = 4;
242     setrlimit (RLIMIT_CPU, &limit);
243
244     limit.rlim_cur = limit.rlim_max = 16 * 1024 * 1024;
245     setrlimit (RLIMIT_DATA, &limit);
246
247     getrlimit (RLIMIT_STACK, &limit);
248     limit.rlim_cur = 4 * 1024 * 1024;
249     setrlimit (RLIMIT_STACK, &limit);
250
251     SIGNAL (SIGXCPU);
252   }
253 #endif /* LIMIT_RESOURCE_USAGE */
254
255   SIGNAL (SIGILL);
256   SIGNAL (SIGSEGV);
257 #ifdef SIGBUS /* not in mingw */
258   SIGNAL (SIGBUS);
259 #endif
260   SIGNAL (SIGFPE);
261   SIGNAL (SIGABRT);
262 }
263
264 int
265 main (int argc, char **argv)
266 {
267   struct expr *e;
268   int i;
269   mpz_t r;
270   int errcode = 0;
271   char *str;
272   int base = 10;
273
274   setup_error_handler ();
275
276   gmp_randinit (rstate, GMP_RAND_ALG_LC, 128);
277
278   {
279 #if HAVE_GETTIMEOFDAY
280     struct timeval tv;
281     gettimeofday (&tv, NULL);
282     gmp_randseed_ui (rstate, tv.tv_sec + tv.tv_usec);
283 #else
284     time_t t;
285     time (&t);
286     gmp_randseed_ui (rstate, t);
287 #endif
288   }
289
290   mpz_init (r);
291
292   while (argc > 1 && argv[1][0] == '-')
293     {
294       char *arg = argv[1];
295
296       if (arg[1] >= '0' && arg[1] <= '9')
297         break;
298
299       if (arg[1] == 't')
300         print_timing = 1;
301       else if (arg[1] == 'b' && arg[2] >= '0' && arg[2] <= '9')
302         {
303           base = atoi (arg + 2);
304           if (base < 2 || base > 62)
305             {
306               fprintf (stderr, "error: invalid output base\n");
307               exit (-1);
308             }
309         }
310       else if (arg[1] == 'b' && arg[2] == 0)
311         base = 2;
312       else if (arg[1] == 'x' && arg[2] == 0)
313         base = 16;
314       else if (arg[1] == 'X' && arg[2] == 0)
315         base = -16;
316       else if (arg[1] == 'o' && arg[2] == 0)
317         base = 8;
318       else if (arg[1] == 'd' && arg[2] == 0)
319         base = 10;
320       else if (arg[1] == 'v' && arg[2] == 0)
321         {
322           printf ("pexpr linked to gmp %s\n", __gmp_version);
323         }
324       else if (strcmp (arg, "-html") == 0)
325         {
326           flag_html = 1;
327           newline = "<br>";
328         }
329       else if (strcmp (arg, "-wml") == 0)
330         {
331           flag_wml = 1;
332           newline = "<br/>";
333         }
334       else if (strcmp (arg, "-split") == 0)
335         {
336           flag_splitup_output = 1;
337         }
338       else if (strcmp (arg, "-noprint") == 0)
339         {
340           flag_print = 0;
341         }
342       else
343         {
344           fprintf (stderr, "error: unknown option `%s'\n", arg);
345           exit (-1);
346         }
347       argv++;
348       argc--;
349     }
350
351   for (i = 1; i < argc; i++)
352     {
353       int s;
354       int jmpval;
355
356       /* Set up error handler for parsing expression.  */
357       jmpval = setjmp (errjmpbuf);
358       if (jmpval != 0)
359         {
360           fprintf (stderr, "error: %s%s\n", error, newline);
361           fprintf (stderr, "       %s%s\n", argv[i], newline);
362           if (! flag_html)
363             {
364               /* ??? Dunno how to align expression position with arrow in
365                  HTML ??? */
366               fprintf (stderr, "       ");
367               for (s = jmpval - (long) argv[i]; --s >= 0; )
368                 putc (' ', stderr);
369               fprintf (stderr, "^\n");
370             }
371
372           errcode |= 1;
373           continue;
374         }
375
376       str = expr (argv[i], &e);
377
378       if (str[0] != 0)
379         {
380           fprintf (stderr,
381                    "error: garbage where end of expression expected%s\n",
382                    newline);
383           fprintf (stderr, "       %s%s\n", argv[i], newline);
384           if (! flag_html)
385             {
386               /* ??? Dunno how to align expression position with arrow in
387                  HTML ??? */
388               fprintf (stderr, "        ");
389               for (s = str - argv[i]; --s; )
390                 putc (' ', stderr);
391               fprintf (stderr, "^\n");
392             }
393
394           errcode |= 1;
395           free_expr (e);
396           continue;
397         }
398
399       /* Set up error handler for evaluating expression.  */
400       if (setjmp (errjmpbuf))
401         {
402           fprintf (stderr, "error: %s%s\n", error, newline);
403           fprintf (stderr, "       %s%s\n", argv[i], newline);
404           if (! flag_html)
405             {
406               /* ??? Dunno how to align expression position with arrow in
407                  HTML ??? */
408               fprintf (stderr, "       ");
409               for (s = str - argv[i]; --s >= 0; )
410                 putc (' ', stderr);
411               fprintf (stderr, "^\n");
412             }
413
414           errcode |= 2;
415           continue;
416         }
417
418       if (print_timing)
419         {
420           int t;
421           TIME (t, mpz_eval_expr (r, e));
422           printf ("computation took %d ms%s\n", t, newline);
423         }
424       else
425         mpz_eval_expr (r, e);
426
427       if (flag_print)
428         {
429           size_t out_len;
430           char *tmp, *s;
431
432           out_len = mpz_sizeinbase (r, base >= 0 ? base : -base) + 2;
433 #ifdef LIMIT_RESOURCE_USAGE
434           if (out_len > 100000)
435             {
436               printf ("result is about %ld digits, not printing it%s\n",
437                       (long) out_len - 3, newline);
438               exit (-2);
439             }
440 #endif
441           tmp = malloc (out_len);
442
443           if (print_timing)
444             {
445               int t;
446               printf ("output conversion ");
447               TIME (t, mpz_get_str (tmp, base, r));
448               printf ("took %d ms%s\n", t, newline);
449             }
450           else
451             mpz_get_str (tmp, base, r);
452
453           out_len = strlen (tmp);
454           if (flag_splitup_output)
455             {
456               for (s = tmp; out_len > 80; s += 80)
457                 {
458                   fwrite (s, 1, 80, stdout);
459                   printf ("%s\n", newline);
460                   out_len -= 80;
461                 }
462
463               fwrite (s, 1, out_len, stdout);
464             }
465           else
466             {
467               fwrite (tmp, 1, out_len, stdout);
468             }
469
470           free (tmp);
471           printf ("%s\n", newline);
472         }
473       else
474         {
475           printf ("result is approximately %ld digits%s\n",
476                   (long) mpz_sizeinbase (r, base >= 0 ? base : -base),
477                   newline);
478         }
479
480       free_expr (e);
481     }
482
483   exit (errcode);
484 }
485
486 char *
487 expr (char *str, expr_t *e)
488 {
489   expr_t e2;
490
491   str = skipspace (str);
492   if (str[0] == '+')
493     {
494       str = term (str + 1, e);
495     }
496   else if (str[0] == '-')
497     {
498       str = term (str + 1, e);
499       makeexp (e, NEG, *e, NULL);
500     }
501   else if (str[0] == '~')
502     {
503       str = term (str + 1, e);
504       makeexp (e, NOT, *e, NULL);
505     }
506   else
507     {
508       str = term (str, e);
509     }
510
511   for (;;)
512     {
513       str = skipspace (str);
514       switch (str[0])
515         {
516         case 'p':
517           if (match ("plus", str))
518             {
519               str = term (str + 4, &e2);
520               makeexp (e, PLUS, *e, e2);
521             }
522           else
523             return str;
524           break;
525         case 'm':
526           if (match ("minus", str))
527             {
528               str = term (str + 5, &e2);
529               makeexp (e, MINUS, *e, e2);
530             }
531           else
532             return str;
533           break;
534         case '+':
535           str = term (str + 1, &e2);
536           makeexp (e, PLUS, *e, e2);
537           break;
538         case '-':
539           str = term (str + 1, &e2);
540           makeexp (e, MINUS, *e, e2);
541           break;
542         default:
543           return str;
544         }
545     }
546 }
547
548 char *
549 term (char *str, expr_t *e)
550 {
551   expr_t e2;
552
553   str = power (str, e);
554   for (;;)
555     {
556       str = skipspace (str);
557       switch (str[0])
558         {
559         case 'm':
560           if (match ("mul", str))
561             {
562               str = power (str + 3, &e2);
563               makeexp (e, MULT, *e, e2);
564               break;
565             }
566           if (match ("mod", str))
567             {
568               str = power (str + 3, &e2);
569               makeexp (e, MOD, *e, e2);
570               break;
571             }
572           return str;
573         case 'd':
574           if (match ("div", str))
575             {
576               str = power (str + 3, &e2);
577               makeexp (e, DIV, *e, e2);
578               break;
579             }
580           return str;
581         case 'r':
582           if (match ("rem", str))
583             {
584               str = power (str + 3, &e2);
585               makeexp (e, REM, *e, e2);
586               break;
587             }
588           return str;
589         case 'i':
590           if (match ("invmod", str))
591             {
592               str = power (str + 6, &e2);
593               makeexp (e, REM, *e, e2);
594               break;
595             }
596           return str;
597         case 't':
598           if (match ("times", str))
599             {
600               str = power (str + 5, &e2);
601               makeexp (e, MULT, *e, e2);
602               break;
603             }
604           if (match ("thru", str))
605             {
606               str = power (str + 4, &e2);
607               makeexp (e, DIV, *e, e2);
608               break;
609             }
610           if (match ("through", str))
611             {
612               str = power (str + 7, &e2);
613               makeexp (e, DIV, *e, e2);
614               break;
615             }
616           return str;
617         case '*':
618           str = power (str + 1, &e2);
619           makeexp (e, MULT, *e, e2);
620           break;
621         case '/':
622           str = power (str + 1, &e2);
623           makeexp (e, DIV, *e, e2);
624           break;
625         case '%':
626           str = power (str + 1, &e2);
627           makeexp (e, MOD, *e, e2);
628           break;
629         default:
630           return str;
631         }
632     }
633 }
634
635 char *
636 power (char *str, expr_t *e)
637 {
638   expr_t e2;
639
640   str = factor (str, e);
641   while (str[0] == '!')
642     {
643       str++;
644       makeexp (e, FAC, *e, NULL);
645     }
646   str = skipspace (str);
647   if (str[0] == '^')
648     {
649       str = power (str + 1, &e2);
650       makeexp (e, POW, *e, e2);
651     }
652   return str;
653 }
654
655 int
656 match (char *s, char *str)
657 {
658   char *ostr = str;
659   int i;
660
661   for (i = 0; s[i] != 0; i++)
662     {
663       if (str[i] != s[i])
664         return 0;
665     }
666   str = skipspace (str + i);
667   return str - ostr;
668 }
669
670 int
671 matchp (char *s, char *str)
672 {
673   char *ostr = str;
674   int i;
675
676   for (i = 0; s[i] != 0; i++)
677     {
678       if (str[i] != s[i])
679         return 0;
680     }
681   str = skipspace (str + i);
682   if (str[0] == '(')
683     return str - ostr + 1;
684   return 0;
685 }
686
687 struct functions
688 {
689   char *spelling;
690   enum op_t op;
691   int arity; /* 1 or 2 means real arity; 0 means arbitrary.  */
692 };
693
694 struct functions fns[] =
695 {
696   {"sqrt", SQRT, 1},
697 #if __GNU_MP_VERSION >= 2
698   {"root", ROOT, 2},
699   {"popc", POPCNT, 1},
700   {"hamdist", HAMDIST, 2},
701 #endif
702   {"gcd", GCD, 0},
703 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
704   {"lcm", LCM, 0},
705 #endif
706   {"and", AND, 0},
707   {"ior", IOR, 0},
708 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
709   {"xor", XOR, 0},
710 #endif
711   {"plus", PLUS, 0},
712   {"pow", POW, 2},
713   {"minus", MINUS, 2},
714   {"mul", MULT, 0},
715   {"div", DIV, 2},
716   {"mod", MOD, 2},
717   {"rem", REM, 2},
718 #if __GNU_MP_VERSION >= 2
719   {"invmod", INVMOD, 2},
720 #endif
721   {"log", LOG, 2},
722   {"log2", LOG2, 1},
723   {"F", FERMAT, 1},
724   {"M", MERSENNE, 1},
725   {"fib", FIBONACCI, 1},
726   {"Fib", FIBONACCI, 1},
727   {"random", RANDOM, 1},
728   {"nextprime", NEXTPRIME, 1},
729   {"binom", BINOM, 2},
730   {"binomial", BINOM, 2},
731   {"fac", FAC, 1},
732   {"fact", FAC, 1},
733   {"factorial", FAC, 1},
734   {"time", TIMING, 1},
735   {"", NOP, 0}
736 };
737
738 char *
739 factor (char *str, expr_t *e)
740 {
741   expr_t e1, e2;
742
743   str = skipspace (str);
744
745   if (isalpha (str[0]))
746     {
747       int i;
748       int cnt;
749
750       for (i = 0; fns[i].op != NOP; i++)
751         {
752           if (fns[i].arity == 1)
753             {
754               cnt = matchp (fns[i].spelling, str);
755               if (cnt != 0)
756                 {
757                   str = expr (str + cnt, &e1);
758                   str = skipspace (str);
759                   if (str[0] != ')')
760                     {
761                       error = "expected `)'";
762                       longjmp (errjmpbuf, (int) (long) str);
763                     }
764                   makeexp (e, fns[i].op, e1, NULL);
765                   return str + 1;
766                 }
767             }
768         }
769
770       for (i = 0; fns[i].op != NOP; i++)
771         {
772           if (fns[i].arity != 1)
773             {
774               cnt = matchp (fns[i].spelling, str);
775               if (cnt != 0)
776                 {
777                   str = expr (str + cnt, &e1);
778                   str = skipspace (str);
779
780                   if (str[0] != ',')
781                     {
782                       error = "expected `,' and another operand";
783                       longjmp (errjmpbuf, (int) (long) str);
784                     }
785
786                   str = skipspace (str + 1);
787                   str = expr (str, &e2);
788                   str = skipspace (str);
789
790                   if (fns[i].arity == 0)
791                     {
792                       while (str[0] == ',')
793                         {
794                           makeexp (&e1, fns[i].op, e1, e2);
795                           str = skipspace (str + 1);
796                           str = expr (str, &e2);
797                           str = skipspace (str);
798                         }
799                     }
800
801                   if (str[0] != ')')
802                     {
803                       error = "expected `)'";
804                       longjmp (errjmpbuf, (int) (long) str);
805                     }
806
807                   makeexp (e, fns[i].op, e1, e2);
808                   return str + 1;
809                 }
810             }
811         }
812     }
813
814   if (str[0] == '(')
815     {
816       str = expr (str + 1, e);
817       str = skipspace (str);
818       if (str[0] != ')')
819         {
820           error = "expected `)'";
821           longjmp (errjmpbuf, (int) (long) str);
822         }
823       str++;
824     }
825   else if (str[0] >= '0' && str[0] <= '9')
826     {
827       expr_t res;
828       char *s, *sc;
829
830       res = malloc (sizeof (struct expr));
831       res -> op = LIT;
832       mpz_init (res->operands.val);
833
834       s = str;
835       while (isalnum (str[0]))
836         str++;
837       sc = malloc (str - s + 1);
838       memcpy (sc, s, str - s);
839       sc[str - s] = 0;
840
841       mpz_set_str (res->operands.val, sc, 0);
842       *e = res;
843       free (sc);
844     }
845   else
846     {
847       error = "operand expected";
848       longjmp (errjmpbuf, (int) (long) str);
849     }
850   return str;
851 }
852
853 char *
854 skipspace (char *str)
855 {
856   while (str[0] == ' ')
857     str++;
858   return str;
859 }
860
861 /* Make a new expression with operation OP and right hand side
862    RHS and left hand side lhs.  Put the result in R.  */
863 void
864 makeexp (expr_t *r, enum op_t op, expr_t lhs, expr_t rhs)
865 {
866   expr_t res;
867   res = malloc (sizeof (struct expr));
868   res -> op = op;
869   res -> operands.ops.lhs = lhs;
870   res -> operands.ops.rhs = rhs;
871   *r = res;
872   return;
873 }
874
875 /* Free the memory used by expression E.  */
876 void
877 free_expr (expr_t e)
878 {
879   if (e->op != LIT)
880     {
881       free_expr (e->operands.ops.lhs);
882       if (e->operands.ops.rhs != NULL)
883         free_expr (e->operands.ops.rhs);
884     }
885   else
886     {
887       mpz_clear (e->operands.val);
888     }
889 }
890
891 /* Evaluate the expression E and put the result in R.  */
892 void
893 mpz_eval_expr (mpz_ptr r, expr_t e)
894 {
895   mpz_t lhs, rhs;
896
897   switch (e->op)
898     {
899     case LIT:
900       mpz_set (r, e->operands.val);
901       return;
902     case PLUS:
903       mpz_init (lhs); mpz_init (rhs);
904       mpz_eval_expr (lhs, e->operands.ops.lhs);
905       mpz_eval_expr (rhs, e->operands.ops.rhs);
906       mpz_add (r, lhs, rhs);
907       mpz_clear (lhs); mpz_clear (rhs);
908       return;
909     case MINUS:
910       mpz_init (lhs); mpz_init (rhs);
911       mpz_eval_expr (lhs, e->operands.ops.lhs);
912       mpz_eval_expr (rhs, e->operands.ops.rhs);
913       mpz_sub (r, lhs, rhs);
914       mpz_clear (lhs); mpz_clear (rhs);
915       return;
916     case MULT:
917       mpz_init (lhs); mpz_init (rhs);
918       mpz_eval_expr (lhs, e->operands.ops.lhs);
919       mpz_eval_expr (rhs, e->operands.ops.rhs);
920       mpz_mul (r, lhs, rhs);
921       mpz_clear (lhs); mpz_clear (rhs);
922       return;
923     case DIV:
924       mpz_init (lhs); mpz_init (rhs);
925       mpz_eval_expr (lhs, e->operands.ops.lhs);
926       mpz_eval_expr (rhs, e->operands.ops.rhs);
927       mpz_fdiv_q (r, lhs, rhs);
928       mpz_clear (lhs); mpz_clear (rhs);
929       return;
930     case MOD:
931       mpz_init (rhs);
932       mpz_eval_expr (rhs, e->operands.ops.rhs);
933       mpz_abs (rhs, rhs);
934       mpz_eval_mod_expr (r, e->operands.ops.lhs, rhs);
935       mpz_clear (rhs);
936       return;
937     case REM:
938       /* Check if lhs operand is POW expression and optimize for that case.  */
939       if (e->operands.ops.lhs->op == POW)
940         {
941           mpz_t powlhs, powrhs;
942           mpz_init (powlhs);
943           mpz_init (powrhs);
944           mpz_init (rhs);
945           mpz_eval_expr (powlhs, e->operands.ops.lhs->operands.ops.lhs);
946           mpz_eval_expr (powrhs, e->operands.ops.lhs->operands.ops.rhs);
947           mpz_eval_expr (rhs, e->operands.ops.rhs);
948           mpz_powm (r, powlhs, powrhs, rhs);
949           if (mpz_cmp_si (rhs, 0L) < 0)
950             mpz_neg (r, r);
951           mpz_clear (powlhs);
952           mpz_clear (powrhs);
953           mpz_clear (rhs);
954           return;
955         }
956
957       mpz_init (lhs); mpz_init (rhs);
958       mpz_eval_expr (lhs, e->operands.ops.lhs);
959       mpz_eval_expr (rhs, e->operands.ops.rhs);
960       mpz_fdiv_r (r, lhs, rhs);
961       mpz_clear (lhs); mpz_clear (rhs);
962       return;
963 #if __GNU_MP_VERSION >= 2
964     case INVMOD:
965       mpz_init (lhs); mpz_init (rhs);
966       mpz_eval_expr (lhs, e->operands.ops.lhs);
967       mpz_eval_expr (rhs, e->operands.ops.rhs);
968       mpz_invert (r, lhs, rhs);
969       mpz_clear (lhs); mpz_clear (rhs);
970       return;
971 #endif
972     case POW:
973       mpz_init (lhs); mpz_init (rhs);
974       mpz_eval_expr (lhs, e->operands.ops.lhs);
975       if (mpz_cmpabs_ui (lhs, 1) <= 0)
976         {
977           /* For 0^rhs and 1^rhs, we just need to verify that
978              rhs is well-defined.  For (-1)^rhs we need to
979              determine (rhs mod 2).  For simplicity, compute
980              (rhs mod 2) for all three cases.  */
981           expr_t two, et;
982           two = malloc (sizeof (struct expr));
983           two -> op = LIT;
984           mpz_init_set_ui (two->operands.val, 2L);
985           makeexp (&et, MOD, e->operands.ops.rhs, two);
986           e->operands.ops.rhs = et;
987         }
988
989       mpz_eval_expr (rhs, e->operands.ops.rhs);
990       if (mpz_cmp_si (rhs, 0L) == 0)
991         /* x^0 is 1 */
992         mpz_set_ui (r, 1L);
993       else if (mpz_cmp_si (lhs, 0L) == 0)
994         /* 0^y (where y != 0) is 0 */
995         mpz_set_ui (r, 0L);
996       else if (mpz_cmp_ui (lhs, 1L) == 0)
997         /* 1^y is 1 */
998         mpz_set_ui (r, 1L);
999       else if (mpz_cmp_si (lhs, -1L) == 0)
1000         /* (-1)^y just depends on whether y is even or odd */
1001         mpz_set_si (r, (mpz_get_ui (rhs) & 1) ? -1L : 1L);
1002       else if (mpz_cmp_si (rhs, 0L) < 0)
1003         /* x^(-n) is 0 */
1004         mpz_set_ui (r, 0L);
1005       else
1006         {
1007           unsigned long int cnt;
1008           unsigned long int y;
1009           /* error if exponent does not fit into an unsigned long int.  */
1010           if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1011             goto pow_err;
1012
1013           y = mpz_get_ui (rhs);
1014           /* x^y == (x/(2^c))^y * 2^(c*y) */
1015 #if __GNU_MP_VERSION >= 2
1016           cnt = mpz_scan1 (lhs, 0);
1017 #else
1018           cnt = 0;
1019 #endif
1020           if (cnt != 0)
1021             {
1022               if (y * cnt / cnt != y)
1023                 goto pow_err;
1024               mpz_tdiv_q_2exp (lhs, lhs, cnt);
1025               mpz_pow_ui (r, lhs, y);
1026               mpz_mul_2exp (r, r, y * cnt);
1027             }
1028           else
1029             mpz_pow_ui (r, lhs, y);
1030         }
1031       mpz_clear (lhs); mpz_clear (rhs);
1032       return;
1033     pow_err:
1034       error = "result of `pow' operator too large";
1035       mpz_clear (lhs); mpz_clear (rhs);
1036       longjmp (errjmpbuf, 1);
1037     case GCD:
1038       mpz_init (lhs); mpz_init (rhs);
1039       mpz_eval_expr (lhs, e->operands.ops.lhs);
1040       mpz_eval_expr (rhs, e->operands.ops.rhs);
1041       mpz_gcd (r, lhs, rhs);
1042       mpz_clear (lhs); mpz_clear (rhs);
1043       return;
1044 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1045     case LCM:
1046       mpz_init (lhs); mpz_init (rhs);
1047       mpz_eval_expr (lhs, e->operands.ops.lhs);
1048       mpz_eval_expr (rhs, e->operands.ops.rhs);
1049       mpz_lcm (r, lhs, rhs);
1050       mpz_clear (lhs); mpz_clear (rhs);
1051       return;
1052 #endif
1053     case AND:
1054       mpz_init (lhs); mpz_init (rhs);
1055       mpz_eval_expr (lhs, e->operands.ops.lhs);
1056       mpz_eval_expr (rhs, e->operands.ops.rhs);
1057       mpz_and (r, lhs, rhs);
1058       mpz_clear (lhs); mpz_clear (rhs);
1059       return;
1060     case IOR:
1061       mpz_init (lhs); mpz_init (rhs);
1062       mpz_eval_expr (lhs, e->operands.ops.lhs);
1063       mpz_eval_expr (rhs, e->operands.ops.rhs);
1064       mpz_ior (r, lhs, rhs);
1065       mpz_clear (lhs); mpz_clear (rhs);
1066       return;
1067 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1068     case XOR:
1069       mpz_init (lhs); mpz_init (rhs);
1070       mpz_eval_expr (lhs, e->operands.ops.lhs);
1071       mpz_eval_expr (rhs, e->operands.ops.rhs);
1072       mpz_xor (r, lhs, rhs);
1073       mpz_clear (lhs); mpz_clear (rhs);
1074       return;
1075 #endif
1076     case NEG:
1077       mpz_eval_expr (r, e->operands.ops.lhs);
1078       mpz_neg (r, r);
1079       return;
1080     case NOT:
1081       mpz_eval_expr (r, e->operands.ops.lhs);
1082       mpz_com (r, r);
1083       return;
1084     case SQRT:
1085       mpz_init (lhs);
1086       mpz_eval_expr (lhs, e->operands.ops.lhs);
1087       if (mpz_sgn (lhs) < 0)
1088         {
1089           error = "cannot take square root of negative numbers";
1090           mpz_clear (lhs);
1091           longjmp (errjmpbuf, 1);
1092         }
1093       mpz_sqrt (r, lhs);
1094       return;
1095 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1096     case ROOT:
1097       mpz_init (lhs); mpz_init (rhs);
1098       mpz_eval_expr (lhs, e->operands.ops.lhs);
1099       mpz_eval_expr (rhs, e->operands.ops.rhs);
1100       if (mpz_sgn (rhs) <= 0)
1101         {
1102           error = "cannot take non-positive root orders";
1103           mpz_clear (lhs); mpz_clear (rhs);
1104           longjmp (errjmpbuf, 1);
1105         }
1106       if (mpz_sgn (lhs) < 0 && (mpz_get_ui (rhs) & 1) == 0)
1107         {
1108           error = "cannot take even root orders of negative numbers";
1109           mpz_clear (lhs); mpz_clear (rhs);
1110           longjmp (errjmpbuf, 1);
1111         }
1112
1113       {
1114         unsigned long int nth = mpz_get_ui (rhs);
1115         if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1116           {
1117             /* If we are asked to take an awfully large root order, cheat and
1118                ask for the largest order we can pass to mpz_root.  This saves
1119                some error prone special cases.  */
1120             nth = ~(unsigned long int) 0;
1121           }
1122         mpz_root (r, lhs, nth);
1123       }
1124       mpz_clear (lhs); mpz_clear (rhs);
1125       return;
1126 #endif
1127     case FAC:
1128       mpz_eval_expr (r, e->operands.ops.lhs);
1129       if (mpz_size (r) > 1)
1130         {
1131           error = "result of `!' operator too large";
1132           longjmp (errjmpbuf, 1);
1133         }
1134       mpz_fac_ui (r, mpz_get_ui (r));
1135       return;
1136 #if __GNU_MP_VERSION >= 2
1137     case POPCNT:
1138       mpz_eval_expr (r, e->operands.ops.lhs);
1139       { long int cnt;
1140         cnt = mpz_popcount (r);
1141         mpz_set_si (r, cnt);
1142       }
1143       return;
1144     case HAMDIST:
1145       { long int cnt;
1146         mpz_init (lhs); mpz_init (rhs);
1147         mpz_eval_expr (lhs, e->operands.ops.lhs);
1148         mpz_eval_expr (rhs, e->operands.ops.rhs);
1149         cnt = mpz_hamdist (lhs, rhs);
1150         mpz_clear (lhs); mpz_clear (rhs);
1151         mpz_set_si (r, cnt);
1152       }
1153       return;
1154 #endif
1155     case LOG2:
1156       mpz_eval_expr (r, e->operands.ops.lhs);
1157       { unsigned long int cnt;
1158         if (mpz_sgn (r) <= 0)
1159           {
1160             error = "logarithm of non-positive number";
1161             longjmp (errjmpbuf, 1);
1162           }
1163         cnt = mpz_sizeinbase (r, 2);
1164         mpz_set_ui (r, cnt - 1);
1165       }
1166       return;
1167     case LOG:
1168       { unsigned long int cnt;
1169         mpz_init (lhs); mpz_init (rhs);
1170         mpz_eval_expr (lhs, e->operands.ops.lhs);
1171         mpz_eval_expr (rhs, e->operands.ops.rhs);
1172         if (mpz_sgn (lhs) <= 0)
1173           {
1174             error = "logarithm of non-positive number";
1175             mpz_clear (lhs); mpz_clear (rhs);
1176             longjmp (errjmpbuf, 1);
1177           }
1178         if (mpz_cmp_ui (rhs, 256) >= 0)
1179           {
1180             error = "logarithm base too large";
1181             mpz_clear (lhs); mpz_clear (rhs);
1182             longjmp (errjmpbuf, 1);
1183           }
1184         cnt = mpz_sizeinbase (lhs, mpz_get_ui (rhs));
1185         mpz_set_ui (r, cnt - 1);
1186         mpz_clear (lhs); mpz_clear (rhs);
1187       }
1188       return;
1189     case FERMAT:
1190       {
1191         unsigned long int t;
1192         mpz_init (lhs);
1193         mpz_eval_expr (lhs, e->operands.ops.lhs);
1194         t = (unsigned long int) 1 << mpz_get_ui (lhs);
1195         if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0 || t == 0)
1196           {
1197             error = "too large Mersenne number index";
1198             mpz_clear (lhs);
1199             longjmp (errjmpbuf, 1);
1200           }
1201         mpz_set_ui (r, 1);
1202         mpz_mul_2exp (r, r, t);
1203         mpz_add_ui (r, r, 1);
1204         mpz_clear (lhs);
1205       }
1206       return;
1207     case MERSENNE:
1208       mpz_init (lhs);
1209       mpz_eval_expr (lhs, e->operands.ops.lhs);
1210       if (mpz_cmp_ui (lhs, ~(unsigned long int) 0) > 0)
1211         {
1212           error = "too large Mersenne number index";
1213           mpz_clear (lhs);
1214           longjmp (errjmpbuf, 1);
1215         }
1216       mpz_set_ui (r, 1);
1217       mpz_mul_2exp (r, r, mpz_get_ui (lhs));
1218       mpz_sub_ui (r, r, 1);
1219       mpz_clear (lhs);
1220       return;
1221     case FIBONACCI:
1222       { mpz_t t;
1223         unsigned long int n, i;
1224         mpz_init (lhs);
1225         mpz_eval_expr (lhs, e->operands.ops.lhs);
1226         if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1227           {
1228             error = "Fibonacci index out of range";
1229             mpz_clear (lhs);
1230             longjmp (errjmpbuf, 1);
1231           }
1232         n = mpz_get_ui (lhs);
1233         mpz_clear (lhs);
1234
1235 #if __GNU_MP_VERSION > 2 || __GNU_MP_VERSION_MINOR >= 1
1236         mpz_fib_ui (r, n);
1237 #else
1238         mpz_init_set_ui (t, 1);
1239         mpz_set_ui (r, 1);
1240
1241         if (n <= 2)
1242           mpz_set_ui (r, 1);
1243         else
1244           {
1245             for (i = 3; i <= n; i++)
1246               {
1247                 mpz_add (t, t, r);
1248                 mpz_swap (t, r);
1249               }
1250           }
1251         mpz_clear (t);
1252 #endif
1253       }
1254       return;
1255     case RANDOM:
1256       {
1257         unsigned long int n;
1258         mpz_init (lhs);
1259         mpz_eval_expr (lhs, e->operands.ops.lhs);
1260         if (mpz_sgn (lhs) <= 0 || mpz_cmp_si (lhs, 1000000000) > 0)
1261           {
1262             error = "random number size out of range";
1263             mpz_clear (lhs);
1264             longjmp (errjmpbuf, 1);
1265           }
1266         n = mpz_get_ui (lhs);
1267         mpz_clear (lhs);
1268         mpz_urandomb (r, rstate, n);
1269       }
1270       return;
1271     case NEXTPRIME:
1272       {
1273         mpz_eval_expr (r, e->operands.ops.lhs);
1274         mpz_nextprime (r, r);
1275       }
1276       return;
1277     case BINOM:
1278       mpz_init (lhs); mpz_init (rhs);
1279       mpz_eval_expr (lhs, e->operands.ops.lhs);
1280       mpz_eval_expr (rhs, e->operands.ops.rhs);
1281       {
1282         unsigned long int k;
1283         if (mpz_cmp_ui (rhs, ~(unsigned long int) 0) > 0)
1284           {
1285             error = "k too large in (n over k) expression";
1286             mpz_clear (lhs); mpz_clear (rhs);
1287             longjmp (errjmpbuf, 1);
1288           }
1289         k = mpz_get_ui (rhs);
1290         mpz_bin_ui (r, lhs, k);
1291       }
1292       mpz_clear (lhs); mpz_clear (rhs);
1293       return;
1294     case TIMING:
1295       {
1296         int t0;
1297         t0 = cputime ();
1298         mpz_eval_expr (r, e->operands.ops.lhs);
1299         printf ("time: %d\n", cputime () - t0);
1300       }
1301       return;
1302     default:
1303       abort ();
1304     }
1305 }
1306
1307 /* Evaluate the expression E modulo MOD and put the result in R.  */
1308 void
1309 mpz_eval_mod_expr (mpz_ptr r, expr_t e, mpz_ptr mod)
1310 {
1311   mpz_t lhs, rhs;
1312
1313   switch (e->op)
1314     {
1315       case POW:
1316         mpz_init (lhs); mpz_init (rhs);
1317         mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1318         mpz_eval_expr (rhs, e->operands.ops.rhs);
1319         mpz_powm (r, lhs, rhs, mod);
1320         mpz_clear (lhs); mpz_clear (rhs);
1321         return;
1322       case PLUS:
1323         mpz_init (lhs); mpz_init (rhs);
1324         mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1325         mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1326         mpz_add (r, lhs, rhs);
1327         if (mpz_cmp_si (r, 0L) < 0)
1328           mpz_add (r, r, mod);
1329         else if (mpz_cmp (r, mod) >= 0)
1330           mpz_sub (r, r, mod);
1331         mpz_clear (lhs); mpz_clear (rhs);
1332         return;
1333       case MINUS:
1334         mpz_init (lhs); mpz_init (rhs);
1335         mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1336         mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1337         mpz_sub (r, lhs, rhs);
1338         if (mpz_cmp_si (r, 0L) < 0)
1339           mpz_add (r, r, mod);
1340         else if (mpz_cmp (r, mod) >= 0)
1341           mpz_sub (r, r, mod);
1342         mpz_clear (lhs); mpz_clear (rhs);
1343         return;
1344       case MULT:
1345         mpz_init (lhs); mpz_init (rhs);
1346         mpz_eval_mod_expr (lhs, e->operands.ops.lhs, mod);
1347         mpz_eval_mod_expr (rhs, e->operands.ops.rhs, mod);
1348         mpz_mul (r, lhs, rhs);
1349         mpz_mod (r, r, mod);
1350         mpz_clear (lhs); mpz_clear (rhs);
1351         return;
1352       default:
1353         mpz_init (lhs);
1354         mpz_eval_expr (lhs, e);
1355         mpz_mod (r, lhs, mod);
1356         mpz_clear (lhs);
1357         return;
1358     }
1359 }
1360
1361 void
1362 cleanup_and_exit (int sig)
1363 {
1364   switch (sig) {
1365 #ifdef LIMIT_RESOURCE_USAGE
1366   case SIGXCPU:
1367     printf ("expression took too long to evaluate%s\n", newline);
1368     break;
1369 #endif
1370   case SIGFPE:
1371     printf ("divide by zero%s\n", newline);
1372     break;
1373   default:
1374     printf ("expression required too much memory to evaluate%s\n", newline);
1375     break;
1376   }
1377   exit (-2);
1378 }