add slow but small dsa/rsa implementation
[platform/upstream/libsolv.git] / ext / solv_pgpvrfy.c
1 /*
2  * Copyright (c) 2013, SUSE Inc.
3  *
4  * This program is licensed under the BSD license, read LICENSE.BSD
5  * for further information
6  */
7
8 /* simple and slow rsa/dsa verification code. */
9
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <string.h>
13
14 #include "util.h"
15 #include "solv_pgpvrfy.h"
16
17 typedef unsigned int mp_t;
18 typedef unsigned long long mp2_t;
19 #define MP_T_BYTES 4
20 #define MP_T_BITS (MP_T_BYTES * 8)
21
22 static inline void
23 mpzero(int len, mp_t *target)
24 {
25   memset(target, 0, MP_T_BYTES * len);
26 }
27
28 static inline mp_t *
29 mpnew(int len)
30 {
31   return solv_calloc(len, MP_T_BYTES);
32 }
33
34 /* target[len] = x, target = target % mod */
35 static void
36 mpdomod(int len, mp_t *target, mp2_t x, mp_t *mod)
37 {
38   int i, j;
39   /* assumes that x does not overflow, i.e. target is not much bigger than mod! */
40   for (i = len - 1; i >= 0; i--)
41     {
42       x = (x << MP_T_BITS) | target[i];
43       target[i] = 0;
44       if (mod[i])
45         break;
46     }
47   if (i < 0)
48     return;
49   while (x >= 2 * (mp2_t)mod[i])
50     {
51       /* reduce */
52       mp2_t z = x / ((mp2_t)mod[i] + 1);
53       mp2_t n = 0;
54       for (j = 0; j < i; j++)
55         {
56           mp_t n2;
57           n += mod[j] * z;
58           n2 = (mp_t)n;
59           n >>= MP_T_BITS;
60           if (n2 > target[j])
61             n++;
62           target[j] -= n2;
63         }
64       n += mod[j] * z;
65       x -= n;
66     }
67   target[i] = x;
68   if (x >= mod[i])
69     {
70       mp_t n;
71       if (x == mod[i])
72         {
73           for (j = i - 1; j >= 0; j--)
74             if (target[j] < mod[j])
75               return;
76             else if (target[j] > mod[j])
77               break;
78         }
79       /* target >= mod, subtract mod */
80       n = 0;
81       for (j = 0; j <= i; j++)
82         {
83           mp2_t n2 = mod[j] + n;
84           n = n2 > target[j] ? 1 : 0;
85           target[j] -= (mp_t)n2;
86         }
87     }
88 }
89
90 /* target += src * m */
91 static void
92 mpmult_add_int(int len, mp_t *target, mp_t *src, mp2_t m, mp_t *mod)
93 {
94   int i;
95   mp2_t x = 0;
96   for (i = 0; i < len; i++)
97     {
98       x += src[i] * m + target[i];
99       target[i] = x;
100       x >>= MP_T_BITS;
101     }
102   mpdomod(len, target, x, mod);
103 }
104
105 /* target = target * 2^MP_T_BITS */
106 static void
107 mpshift(int len, mp_t *target, mp_t *mod)
108 {
109   mp_t x;
110   if (len <= 0)
111     return;
112   x = target[len - 1];
113   if (len > 1)
114     memmove(target + 1, target, (len - 1) * MP_T_BYTES);
115   target[0] = 0;
116   mpdomod(len, target, x, mod);
117 }
118
119 /* target += m1 * m2 */
120 static void
121 mpmult_add(int len, mp_t *target, mp_t *m1, int m2len, mp_t *m2, mp_t *mod)
122 {
123   int i, j;
124   mp_t *t;
125   for (j = m2len - 1; j >= 0; j--)
126     if (m2[j])
127       break;
128   if (j < 0)
129     return;
130   t = mpnew(len);
131   memcpy(t, m1, len * MP_T_BYTES);
132   for (i = 0; i < j; i++)
133     {
134       if (m2[i])
135         mpmult_add_int(len, target, t, m2[i], mod);
136       mpshift(len, t, mod);
137     }
138   if (m2[i])
139     mpmult_add_int(len, target, t, m2[i], mod);
140   free(t);
141 }
142
143 /* target = target * m */
144 static void
145 mpmult_inplace(int len, mp_t *target, mp_t *m, mp_t *tmp, mp_t *mod)
146 {
147   mpzero(len, tmp);
148   mpmult_add(len, tmp, target, len, m, mod);
149   memcpy(target, tmp, len * MP_T_BYTES);
150 }
151
152 /* target = target ^ (16 + e) */
153 static void
154 mppow_int(int len, mp_t *target, mp_t *t, mp_t *mod, int e)
155 {
156   mpmult_inplace(len, target, target, t, mod);
157   mpmult_inplace(len, target, target, t, mod);
158   mpmult_inplace(len, target, target, t, mod);
159   mpmult_inplace(len, target, target, t, mod);
160   if (e)
161     mpmult_inplace(len, target, t + len * e, t, mod);
162 }
163
164 /* target = b ^ e */
165 static void
166 mppow(int len, mp_t *target, mp_t *b, int elen, mp_t *e, mp_t *mod)
167 {
168   int i, j;
169   mp_t *t;
170   mpzero(len, target);
171   target[0] = 1;
172   for (i = elen - 1; i >= 0; i--)
173     if (e[i])
174       break;
175   if (i < 0)
176     return;
177   t = mpnew(len * 16);
178   memcpy(t + len, b, len * MP_T_BYTES);
179   for (j = 2; j < 16; j++)
180     mpmult_add(len, t + len * j, b, len, t + len * j - len, mod);
181   for (; i >= 0; i--)
182     {
183 #if MP_T_BYTES == 4
184       mppow_int(len, target, t, mod, (e[i] >> 28) & 0x0f);
185       mppow_int(len, target, t, mod, (e[i] >> 24) & 0x0f);
186       mppow_int(len, target, t, mod, (e[i] >> 20) & 0x0f);
187       mppow_int(len, target, t, mod, (e[i] >> 16) & 0x0f);
188       mppow_int(len, target, t, mod, (e[i] >> 12) & 0x0f);
189       mppow_int(len, target, t, mod, (e[i] >>  8) & 0x0f);
190       mppow_int(len, target, t, mod, (e[i] >>  4) & 0x0f);
191       mppow_int(len, target, t, mod,  e[i]        & 0x0f);
192 #elif MP_T_BYTES == 1
193       mppow_int(len, target, t, mod, (e[i] >>  4) & 0x0f);
194       mppow_int(len, target, t, mod,  e[i]        & 0x0f);
195 #endif
196     }
197   free(t);
198 }
199
200 static int
201 mpisless(int len, mp_t *a, mp_t *b)
202 {
203   int i;
204   for (i = len - 1; i >= 0; i--)
205     if (a[i] < b[i])
206       return 1;
207     else if (a[i] > b[i])
208       return 0;
209   return 0;
210 }
211
212 static int
213 mpiszero(int len, mp_t *a)
214 {
215   int i;
216   for (i = 0; i < len; i++)
217     if (a[i])
218       return 0;
219   return 1;
220 }
221
222 static void
223 mpdec(int len, mp_t *a)
224 {
225   int i;
226   for (i = 0; i < len; i++)
227     if (a[i]--)
228       return;
229     else
230       a[i] = -(mp_t)1;
231 }
232
233 #if 0
234 static void mpdump(int l, mp_t *a, char *s)
235 {
236   int i;
237   if (s)
238     fprintf(stderr, "%s", s);
239   for (i = l - 1; i >= 0; i--)
240     fprintf(stderr, "%08x", a[i]);
241   fprintf(stderr, "\n");
242 }
243 #endif
244
245 static int
246 mpdsa(int pl, mp_t *p, int ql, mp_t *q, mp_t *g, mp_t *y, mp_t *r, mp_t *s, int hl, mp_t *h)
247 {
248   mp_t *w;
249   mp_t *tmp;
250   mp_t *u1, *u2;
251   mp_t *gu1, *yu2;
252 #if 0
253   mpdump(pl, p, "p = ");
254   mpdump(ql, q, "q = ");
255   mpdump(pl, g, "g = ");
256   mpdump(pl, y, "y = ");
257   mpdump(ql, r, "r = ");
258   mpdump(ql, s, "s = ");
259   mpdump(hl, h, "h = ");
260 #endif
261   if (pl < ql || !mpisless(pl, g, p) || !mpisless(pl, y, p))
262     return 0;                           /* hmm, bad pubkey? */
263   if (!mpisless(ql, r, q) || mpiszero(ql, r))
264     return 0;
265   if (!mpisless(ql, s, q) || mpiszero(ql, s))
266     return 0;
267   tmp = mpnew(pl);                      /* note pl! */
268   memcpy(tmp, q, ql * MP_T_BYTES);      /* tmp = q */
269   mpdec(ql, tmp);                       /* tmp-- */
270   mpdec(ql, tmp);                       /* tmp-- */
271   w = mpnew(ql);
272   mppow(ql, w, s, ql, tmp, q);          /* w = s ^ tmp (s ^ -1) */
273   u1 = mpnew(pl);                       /* u1 = 0 */
274   /* order is important here: h can be >= q */
275   mpmult_add(ql, u1, w, hl, h, q);      /* u1 += w * h */
276   u2 = mpnew(pl);                       /* u2 = 0 */
277   mpmult_add(ql, u2, w, ql, r, q);      /* u2 += w * r */
278   free(w);
279   gu1 = mpnew(pl);
280   yu2 = mpnew(pl);
281   mppow(pl, gu1, g, pl, u1, p);         /* gu1 = g ^ u1 */
282   mppow(pl, yu2, y, pl, u2, p);         /* yu2 = y ^ u2 */
283   mpzero(pl, u1);                       /* u1 = 0 */
284   mpmult_add(pl, u1, gu1, pl, yu2, p);  /* u1 += gu1 * yu2 */
285   free(gu1);
286   free(yu2);
287   mpzero(ql, u2);
288   u2[0] = 1;                            /* u2 = 1 */
289   mpzero(ql, tmp);                      /* tmp = 0 */
290   mpmult_add(ql, tmp, u2, pl, u1, q);   /* tmp += u2 * u1 */
291   free(u1);
292   free(u2);
293 #if 0
294   mpdump(ql, tmp, "res = ");
295 #endif
296   if (memcmp(tmp, r, ql * MP_T_BYTES) != 0)
297     {
298       free(tmp);
299       return 0;
300     }
301   free(tmp);
302   return 1;
303 }
304
305 static int 
306 mprsa(int nl, mp_t *n, int el, mp_t *e, mp_t *m, mp_t *c)
307 {
308   mp_t *tmp;
309 #if 0
310   mpdump(nl, n, "n = ");
311   mpdump(el, e, "e = ");
312   mpdump(nl, m, "m = ");
313   mpdump(nl, c, "c = ");
314 #endif
315   if (!mpisless(nl, m, n))
316     return 0;
317   if (!mpisless(nl, c, n))
318     return 0;
319   tmp = mpnew(nl);
320   mppow(nl, tmp, m, el, e, n);          /* tmp = m ^ e */
321 #if 0
322   mpdump(nl, tmp, "res = ");
323 #endif
324   if (memcmp(tmp, c, nl * MP_T_BYTES) != 0)
325     {
326       free(tmp);
327       return 0;
328     }
329   free(tmp);
330   return 1;
331 }
332
333 /* create mp with size tbits from data with size dbits */
334 static mp_t *
335 mpbuild(unsigned char *d, int dbits, int tbits, int *mplp)
336 {
337   int l = (tbits + MP_T_BITS - 1) / MP_T_BITS;
338   int dl, i;
339
340   mp_t *out = mpnew(l ? l : 1);
341   if (mplp)
342     *mplp = l;
343   dl = (dbits + 7) / 8;
344   d += dl;
345   if (dbits > tbits)
346     dl = (tbits + 7) / 8;
347   for (i = 0; dl > 0; dl--, i++)
348     {
349       int x = *--d;
350       out[i / MP_T_BYTES] |= x << (8 * (i % MP_T_BYTES));
351     }
352   return out;
353 }
354
355 static unsigned char *
356 findmpi(unsigned char **mpip, int *mpilp, int maxbits, int *outlen)
357 {
358   int mpil = *mpilp;
359   unsigned char *mpi = *mpip;
360   unsigned char *out = 0;
361   int bits, l;
362
363   if (mpil < 2)
364     return 0;
365   bits = mpi[0] << 8 | mpi[1];
366   l = 2 + (bits + 7) / 8;
367   if (bits > maxbits || mpil < l)
368     *mpilp = 0;
369   else
370     {
371       out = mpi + 2;
372       *outlen = bits;
373       *mpilp = mpil - l;
374       *mpip = mpi + l;
375     }
376   return out;
377 }
378
379 int
380 solv_pgpvrfy(unsigned char *pub, int publ, unsigned char *sig, int sigl)
381 {
382   int hashl;
383   unsigned char *oid = 0;
384   unsigned char *mpi;
385   int mpil;
386   int res = 0;
387
388   if (!pub || !sig || publ < 1 || sigl < 2)
389     return 0;
390   if (pub[0] != sig[0])
391     return 0;           /* key algo mismatch */
392   switch(sig[1])
393     {
394     case 1:
395       hashl = 16;       /* MD5 */
396       oid = (unsigned char *)"\022\060\040\060\014\006\010\052\206\110\206\367\015\002\005\005\000\004\020";
397       break;
398     case 2:
399       hashl = 20;       /* SHA-1 */
400       oid = (unsigned char *)"\017\060\041\060\011\006\005\053\016\003\002\032\005\000\004\024";
401       break;
402     case 8:
403       hashl = 32;       /* SHA-256 */
404       oid = (unsigned char *)"\023\060\061\060\015\006\011\140\206\110\001\145\003\004\002\001\005\000\004\040";
405       break;
406     case 10:
407       hashl = 64;       /* SHA-512 */
408       oid = (unsigned char *)"\023\060\121\060\015\006\011\140\206\110\001\145\003\004\002\003\005\000\004\100";
409       break;
410     default:
411       return 0;         /* unsupported hash algo */
412     }
413   if (sigl < 2 + hashl)
414     return 0;
415   switch (pub[0])
416     {
417     case 1:             /* RSA */
418       {
419         unsigned char *n, *e, *m, *c;
420         int nlen, elen, mlen, clen;
421         mp_t *nx, *ex, *mx, *cx;
422         int nxl, exl;
423
424         mpi = pub + 1;
425         mpil = publ - 1;
426         n = findmpi(&mpi, &mpil, 8192, &nlen);
427         e = findmpi(&mpi, &mpil, 1024, &elen);
428         mpi = sig + 2 + hashl;
429         mpil = sigl - (2 + hashl);
430         m = findmpi(&mpi, &mpil, nlen, &mlen);
431         if (!n || !e || !m || !nlen || !elen)
432           return 0;
433         /* build padding block */
434         clen = (nlen - 1) / 8;
435         if (hashl + *oid + 2 > clen)
436           return 0;
437         c = solv_malloc(clen);
438         memset(c, 0xff, clen);
439         c[0] = 1;
440         memcpy(c + clen - hashl, sig + 2, hashl);
441         memcpy(c + clen - hashl - *oid, oid + 1, *oid);
442         c[clen - hashl - *oid - 1] = 0;
443         clen = clen * 8 - 7;    /* always <= nlen */
444         nx = mpbuild(n, nlen, nlen, &nxl);
445         ex = mpbuild(e, elen, elen, &exl);
446         mx = mpbuild(m, mlen, nlen, 0);
447         cx = mpbuild(c, clen, nlen, 0);
448         free(c);
449         res = mprsa(nxl, nx, exl, ex, mx, cx);
450         free(nx);
451         free(ex);
452         free(mx);
453         free(cx);
454         break;
455       }
456     case 17:            /* DSA */
457       {
458         unsigned char *p, *q, *g, *y, *r, *s;
459         int plen, qlen, glen, ylen, rlen, slen, hlen;
460         mp_t *px, *qx, *gx, *yx, *rx, *sx, *hx;
461         int pxl, qxl, hxl;
462
463         mpi = pub + 1;
464         mpil = publ - 1;
465         p = findmpi(&mpi, &mpil, 8192, &plen);
466         q = findmpi(&mpi, &mpil, 1024, &qlen);
467         g = findmpi(&mpi, &mpil, plen, &glen);
468         y = findmpi(&mpi, &mpil, plen, &ylen);
469         mpi = sig + 2 + hashl;
470         mpil = sigl - (2 + hashl);
471         r = findmpi(&mpi, &mpil, qlen, &rlen);
472         s = findmpi(&mpi, &mpil, qlen, &slen);
473         if (!p || !q || !g || !y || !r || !s || !plen || !qlen)
474           return 0;
475         hlen = (qlen + 7) & ~7;
476         if (hlen > hashl * 8)
477           return 0;
478         px = mpbuild(p, plen, plen, &pxl);
479         qx = mpbuild(q, qlen, qlen, &qxl);
480         gx = mpbuild(g, glen, plen, 0);
481         yx = mpbuild(y, ylen, plen, 0);
482         rx = mpbuild(r, rlen, qlen, 0);
483         sx = mpbuild(s, slen, qlen, 0);
484         hx = mpbuild(sig + 2, hlen, hlen, &hxl);
485         res = mpdsa(pxl, px, qxl, qx, gx, yx, rx, sx, hxl, hx);
486         free(px);
487         free(qx);
488         free(gx);
489         free(yx);
490         free(rx);
491         free(sx);
492         free(hx);
493         break;
494       }
495     default:
496       return 0;         /* unsupported pubkey algo */
497     }
498   return res;
499 }
500