Imported Upstream version 0.6.24
[platform/upstream/libsolv.git] / ext / solv_pgpvrfy.c
1 /*
2  * Copyright (c) 2013, SUSE Inc.
3  *
4  * This program is licensed under the BSD license, read LICENSE.BSD
5  * for further information
6  */
7
8 /* simple and slow rsa/dsa verification code. */
9
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <string.h>
13
14 #include "util.h"
15 #include "solv_pgpvrfy.h"
16
17 typedef unsigned int mp_t;
18 typedef unsigned long long mp2_t;
19 #define MP_T_BYTES 4
20
21 #define MP_T_BITS (MP_T_BYTES * 8)
22
23 static inline void
24 mpzero(int len, mp_t *target)
25 {
26   memset(target, 0, MP_T_BYTES * len);
27 }
28
29 static inline mp_t *
30 mpnew(int len)
31 {
32   return solv_calloc(len, MP_T_BYTES);
33 }
34
35 static inline void
36 mpcpy(int len, mp_t *target, mp_t *source)
37 {
38   memcpy(target, source, len * MP_T_BYTES);
39 }
40
41 #if 0
42 static void mpdump(int l, mp_t *a, char *s)
43 {
44   int i;
45   if (s)
46     fprintf(stderr, "%s", s);
47   for (i = l - 1; i >= 0; i--)
48     fprintf(stderr, "%0*x", MP_T_BYTES * 2, a[i]);
49   fprintf(stderr, "\n");
50 }
51 #endif
52
53 /* target[len] = x, target = target % mod
54  * assumes that target < (mod << MP_T_BITS)! */
55 static void
56 mpdomod(int len, mp_t *target, mp2_t x, mp_t *mod)
57 {
58   int i, j;
59   for (i = len - 1; i >= 0; i--)
60     {
61       x = (x << MP_T_BITS) | target[i];
62       target[i] = 0;
63       if (mod[i])
64         break;
65     }
66   if (i < 0)
67     return;
68   while (x >= 2 * (mp2_t)mod[i])
69     {
70       /* reduce */
71       mp2_t z = x / ((mp2_t)mod[i] + 1);
72       mp2_t n = 0;
73       if ((z >> MP_T_BITS) != 0)
74         z = (mp2_t)1 << MP_T_BITS;      /* just in case... */
75       for (j = 0; j < i; j++)
76         {
77           mp_t n2;
78           n += mod[j] * z;
79           n2 = (mp_t)n;
80           n >>= MP_T_BITS;
81           if (n2 > target[j])
82             n++;
83           target[j] -= n2;
84         }
85       n += mod[j] * z;
86       x -= n;
87     }
88   target[i] = x;
89   if (x >= mod[i])
90     {
91       mp_t n;
92       if (x == mod[i])
93         {
94           for (j = i - 1; j >= 0; j--)
95             if (target[j] < mod[j])
96               return;
97             else if (target[j] > mod[j])
98               break;
99         }
100       /* target >= mod, subtract mod */
101       n = 0;
102       for (j = 0; j <= i; j++)
103         {
104           mp2_t n2 = mod[j] + n;
105           n = n2 > target[j] ? 1 : 0;
106           target[j] -= (mp_t)n2;
107         }
108     }
109 }
110
111 /* target += src * m */
112 static void
113 mpmult_add_int(int len, mp_t *target, mp_t *src, mp2_t m, mp_t *mod)
114 {
115   int i;
116   mp2_t x = 0;
117   for (i = 0; i < len; i++)
118     {
119       x += src[i] * m + target[i];
120       target[i] = x;
121       x >>= MP_T_BITS;
122     }
123   mpdomod(len, target, x, mod);
124 }
125
126 /* target = target << MP_T_BITS */
127 static void
128 mpshift(int len, mp_t *target, mp_t *mod)
129 {
130   mp_t x;
131   if (len <= 0)
132     return;
133   x = target[len - 1];
134   if (len > 1)
135     memmove(target + 1, target, (len - 1) * MP_T_BYTES);
136   target[0] = 0;
137   mpdomod(len, target, x, mod);
138 }
139
140 /* target += m1 * m2 */
141 static void
142 mpmult_add(int len, mp_t *target, mp_t *m1, int m2len, mp_t *m2, mp_t *tmp, mp_t *mod)
143 {
144   int i, j;
145   for (j = m2len - 1; j >= 0; j--)
146     if (m2[j])
147       break;
148   if (j < 0)
149     return;
150   mpcpy(len, tmp, m1);
151   for (i = 0; i < j; i++)
152     {
153       if (m2[i])
154         mpmult_add_int(len, target, tmp, m2[i], mod);
155       mpshift(len, tmp, mod);
156     }
157   if (m2[i])
158     mpmult_add_int(len, target, tmp, m2[i], mod);
159 }
160
161 /* target = target * m */
162 static void
163 mpmult_inplace(int len, mp_t *target, mp_t *m, mp_t *tmp1, mp_t *tmp2, mp_t *mod)
164 {
165   mpzero(len, tmp1);
166   mpmult_add(len, tmp1, target, len, m, tmp2, mod);
167   mpcpy(len, target, tmp1);
168 }
169
170 /* target = target ^ 16 * b ^ e */
171 static void
172 mppow_int(int len, mp_t *target, mp_t *t, mp_t *mod, int e)
173 {
174   mp_t *t2 = t + len * 16;
175   mpmult_inplace(len, target, target, t, t2, mod);
176   mpmult_inplace(len, target, target, t, t2, mod);
177   mpmult_inplace(len, target, target, t, t2, mod);
178   mpmult_inplace(len, target, target, t, t2, mod);
179   if (e)
180     mpmult_inplace(len, target, t + len * e, t, t2, mod);
181 }
182
183 /* target = b ^ e (b has to be < mod) */
184 static void
185 mppow(int len, mp_t *target, mp_t *b, int elen, mp_t *e, mp_t *mod)
186 {
187   int i, j;
188   mp_t *t;
189   mpzero(len, target);
190   target[0] = 1;
191   for (i = elen - 1; i >= 0; i--)
192     if (e[i])
193       break;
194   if (i < 0)
195     return;
196   t = mpnew(len * 17);
197   mpcpy(len, t + len, b);
198   for (j = 2; j < 16; j++)
199     mpmult_add(len, t + len * j, b, len, t + len * j - len, t + len * 16, mod);
200   for (; i >= 0; i--)
201     {
202 #if MP_T_BYTES == 4
203       mppow_int(len, target, t, mod, (e[i] >> 28) & 0x0f);
204       mppow_int(len, target, t, mod, (e[i] >> 24) & 0x0f);
205       mppow_int(len, target, t, mod, (e[i] >> 20) & 0x0f);
206       mppow_int(len, target, t, mod, (e[i] >> 16) & 0x0f);
207       mppow_int(len, target, t, mod, (e[i] >> 12) & 0x0f);
208       mppow_int(len, target, t, mod, (e[i] >>  8) & 0x0f);
209       mppow_int(len, target, t, mod, (e[i] >>  4) & 0x0f);
210       mppow_int(len, target, t, mod,  e[i]        & 0x0f);
211 #elif MP_T_BYTES == 1
212       mppow_int(len, target, t, mod, (e[i] >>  4) & 0x0f);
213       mppow_int(len, target, t, mod,  e[i]        & 0x0f);
214 #endif
215     }
216   free(t);
217 }
218
219 /* target = m1 * m2 (m1 has to be < mod) */
220 static void
221 mpmult(int len, mp_t *target, mp_t *m1, int m2len, mp_t *m2, mp_t *mod)
222 {
223   mp_t *tmp = mpnew(len);
224   mpzero(len, target);
225   mpmult_add(len, target, m1, m2len, m2, tmp, mod);
226   free(tmp);
227 }
228
229 static int
230 mpisless(int len, mp_t *a, mp_t *b)
231 {
232   int i;
233   for (i = len - 1; i >= 0; i--)
234     if (a[i] < b[i])
235       return 1;
236     else if (a[i] > b[i])
237       return 0;
238   return 0;
239 }
240
241 static int
242 mpiszero(int len, mp_t *a)
243 {
244   int i;
245   for (i = 0; i < len; i++)
246     if (a[i])
247       return 0;
248   return 1;
249 }
250
251 static void
252 mpdec(int len, mp_t *a)
253 {
254   int i;
255   for (i = 0; i < len; i++)
256     if (a[i]--)
257       return;
258     else
259       a[i] = -(mp_t)1;
260 }
261
262 static int
263 mpdsa(int pl, mp_t *p, int ql, mp_t *q, mp_t *g, mp_t *y, mp_t *r, mp_t *s, int hl, mp_t *h)
264 {
265   mp_t *w;
266   mp_t *tmp;
267   mp_t *u1, *u2;
268   mp_t *gu1, *yu2;
269 #if 0
270   mpdump(pl, p, "p = ");
271   mpdump(ql, q, "q = ");
272   mpdump(pl, g, "g = ");
273   mpdump(pl, y, "y = ");
274   mpdump(ql, r, "r = ");
275   mpdump(ql, s, "s = ");
276   mpdump(hl, h, "h = ");
277 #endif
278   if (pl < ql || !mpisless(pl, g, p) || !mpisless(pl, y, p))
279     return 0;                           /* hmm, bad pubkey? */
280   if (!mpisless(ql, r, q) || mpiszero(ql, r))
281     return 0;
282   if (!mpisless(ql, s, q) || mpiszero(ql, s))
283     return 0;
284   tmp = mpnew(pl);                      /* note pl */
285   mpcpy(ql, tmp, q);                    /* tmp = q */
286   mpdec(ql, tmp);                       /* tmp-- */
287   mpdec(ql, tmp);                       /* tmp-- */
288   w = mpnew(ql);
289   mppow(ql, w, s, ql, tmp, q);          /* w = s ^ tmp (s ^ -1) */
290   u1 = mpnew(pl);                       /* note pl */
291   /* order is important here: h can be >= q */
292   mpmult(ql, u1, w, hl, h, q);          /* u1 = w * h */
293   u2 = mpnew(ql);                       /* u2 = 0 */
294   mpmult(ql, u2, w, ql, r, q);          /* u2 = w * r */
295   free(w);
296   gu1 = mpnew(pl);
297   yu2 = mpnew(pl);
298   mppow(pl, gu1, g, ql, u1, p);         /* gu1 = g ^ u1 */
299   mppow(pl, yu2, y, ql, u2, p);         /* yu2 = y ^ u2 */
300   mpmult(pl, u1, gu1, pl, yu2, p);      /* u1 = gu1 * yu2 */
301   free(gu1);
302   free(yu2);
303   mpzero(ql, u2);
304   u2[0] = 1;                            /* u2 = 1 */
305   mpmult(ql, tmp, u2, pl, u1, q);       /* tmp = u2 * u1 */
306   free(u1);
307   free(u2);
308 #if 0
309   mpdump(ql, tmp, "res = ");
310 #endif
311   if (memcmp(tmp, r, ql * MP_T_BYTES) != 0)
312     {
313       free(tmp);
314       return 0;
315     }
316   free(tmp);
317   return 1;
318 }
319
320 static int
321 mprsa(int nl, mp_t *n, int el, mp_t *e, mp_t *m, mp_t *c)
322 {
323   mp_t *tmp;
324 #if 0
325   mpdump(nl, n, "n = ");
326   mpdump(el, e, "e = ");
327   mpdump(nl, m, "m = ");
328   mpdump(nl, c, "c = ");
329 #endif
330   if (!mpisless(nl, m, n))
331     return 0;
332   if (!mpisless(nl, c, n))
333     return 0;
334   tmp = mpnew(nl);
335   mppow(nl, tmp, m, el, e, n);          /* tmp = m ^ e */
336 #if 0
337   mpdump(nl, tmp, "res = ");
338 #endif
339   if (memcmp(tmp, c, nl * MP_T_BYTES) != 0)
340     {
341       free(tmp);
342       return 0;
343     }
344   free(tmp);
345   return 1;
346 }
347
348 /* create mp with size tbits from data with size dbits */
349 static mp_t *
350 mpbuild(const unsigned char *d, int dbits, int tbits, int *mplp)
351 {
352   int l = (tbits + MP_T_BITS - 1) / MP_T_BITS;
353   int dl, i;
354
355   mp_t *out = mpnew(l ? l : 1);
356   if (mplp)
357     *mplp = l;
358   dl = (dbits + 7) / 8;
359   d += dl;
360   if (dbits > tbits)
361     dl = (tbits + 7) / 8;
362   for (i = 0; dl > 0; dl--, i++)
363     {
364       int x = *--d;
365       out[i / MP_T_BYTES] |= x << (8 * (i % MP_T_BYTES));
366     }
367   return out;
368 }
369
370 static const unsigned char *
371 findmpi(const unsigned char **mpip, int *mpilp, int maxbits, int *outlen)
372 {
373   int mpil = *mpilp;
374   const unsigned char *mpi = *mpip;
375   int bits, l;
376
377   *outlen = 0;
378   if (mpil < 2)
379     return 0;
380   bits = mpi[0] << 8 | mpi[1];
381   l = 2 + (bits + 7) / 8;
382   if (bits > maxbits || mpil < l || (bits && !mpi[2]))
383     {
384       *mpilp = 0;
385       return 0;
386     }
387   *outlen = bits;
388   *mpilp = mpil - l;
389   *mpip = mpi + l;
390   return mpi + 2;
391 }
392
393 int
394 solv_pgpvrfy(const unsigned char *pub, int publ, const unsigned char *sig, int sigl)
395 {
396   int hashl;
397   unsigned char *oid = 0;
398   const unsigned char *mpi;
399   int mpil;
400   int res = 0;
401
402   if (!pub || !sig || publ < 1 || sigl < 2)
403     return 0;
404   if (pub[0] != sig[0])
405     return 0;           /* key algo mismatch */
406   switch(sig[1])
407     {
408     case 1:
409       hashl = 16;       /* MD5 */
410       oid = (unsigned char *)"\022\060\040\060\014\006\010\052\206\110\206\367\015\002\005\005\000\004\020";
411       break;
412     case 2:
413       hashl = 20;       /* SHA-1 */
414       oid = (unsigned char *)"\017\060\041\060\011\006\005\053\016\003\002\032\005\000\004\024";
415       break;
416     case 8:
417       hashl = 32;       /* SHA-256 */
418       oid = (unsigned char *)"\023\060\061\060\015\006\011\140\206\110\001\145\003\004\002\001\005\000\004\040";
419       break;
420     case 9:
421       hashl = 48;       /* SHA-384 */
422       oid = (unsigned char *)"\023\060\101\060\015\006\011\140\206\110\001\145\003\004\002\002\005\000\004\060";
423       break;
424     case 10:
425       hashl = 64;       /* SHA-512 */
426       oid = (unsigned char *)"\023\060\121\060\015\006\011\140\206\110\001\145\003\004\002\003\005\000\004\100";
427       break;
428     case 11:
429       hashl = 28;       /* SHA-224 */
430       oid = (unsigned char *)"\023\060\061\060\015\006\011\140\206\110\001\145\003\004\002\004\005\000\004\034";
431       break;
432     default:
433       return 0;         /* unsupported hash algo */
434     }
435   if (sigl < 2 + hashl)
436     return 0;
437   switch (pub[0])
438     {
439     case 1:             /* RSA */
440       {
441         const unsigned char *n, *e, *m;
442         unsigned char *c;
443         int nlen, elen, mlen, clen;
444         mp_t *nx, *ex, *mx, *cx;
445         int nxl, exl;
446
447         mpi = pub + 1;
448         mpil = publ - 1;
449         n = findmpi(&mpi, &mpil, 8192, &nlen);
450         e = findmpi(&mpi, &mpil, 1024, &elen);
451         mpi = sig + 2 + hashl;
452         mpil = sigl - (2 + hashl);
453         m = findmpi(&mpi, &mpil, nlen, &mlen);
454         if (!n || !e || !m || !nlen || !elen)
455           return 0;
456         /* build padding block */
457         clen = (nlen - 1) / 8;
458         if (hashl + *oid + 2 > clen)
459           return 0;
460         c = solv_malloc(clen);
461         memset(c, 0xff, clen);
462         c[0] = 1;
463         memcpy(c + clen - hashl, sig + 2, hashl);
464         memcpy(c + clen - hashl - *oid, oid + 1, *oid);
465         c[clen - hashl - *oid - 1] = 0;
466         clen = clen * 8 - 7;    /* always <= nlen */
467         nx = mpbuild(n, nlen, nlen, &nxl);
468         ex = mpbuild(e, elen, elen, &exl);
469         mx = mpbuild(m, mlen, nlen, 0);
470         cx = mpbuild(c, clen, nlen, 0);
471         free(c);
472         res = mprsa(nxl, nx, exl, ex, mx, cx);
473         free(nx);
474         free(ex);
475         free(mx);
476         free(cx);
477         break;
478       }
479     case 17:            /* DSA */
480       {
481         const unsigned char *p, *q, *g, *y, *r, *s;
482         int plen, qlen, glen, ylen, rlen, slen, hlen;
483         mp_t *px, *qx, *gx, *yx, *rx, *sx, *hx;
484         int pxl, qxl, hxl;
485
486         mpi = pub + 1;
487         mpil = publ - 1;
488         p = findmpi(&mpi, &mpil, 8192, &plen);
489         q = findmpi(&mpi, &mpil, 1024, &qlen);
490         g = findmpi(&mpi, &mpil, plen, &glen);
491         y = findmpi(&mpi, &mpil, plen, &ylen);
492         mpi = sig + 2 + hashl;
493         mpil = sigl - (2 + hashl);
494         r = findmpi(&mpi, &mpil, qlen, &rlen);
495         s = findmpi(&mpi, &mpil, qlen, &slen);
496         if (!p || !q || !g || !y || !r || !s || !plen || !qlen)
497           return 0;
498         hlen = (qlen + 7) & ~7;
499         if (hlen > hashl * 8)
500           return 0;
501         px = mpbuild(p, plen, plen, &pxl);
502         qx = mpbuild(q, qlen, qlen, &qxl);
503         gx = mpbuild(g, glen, plen, 0);
504         yx = mpbuild(y, ylen, plen, 0);
505         rx = mpbuild(r, rlen, qlen, 0);
506         sx = mpbuild(s, slen, qlen, 0);
507         hx = mpbuild(sig + 2, hlen, hlen, &hxl);
508         res = mpdsa(pxl, px, qxl, qx, gx, yx, rx, sx, hxl, hx);
509         free(px);
510         free(qx);
511         free(gx);
512         free(yx);
513         free(rx);
514         free(sx);
515         free(hx);
516         break;
517       }
518     default:
519       return 0;         /* unsupported pubkey algo */
520     }
521   return res;
522 }
523