cleanup solv_pgpvrfy a bit
authorMichael Schroeder <mls@suse.de>
Fri, 26 Jul 2013 09:07:39 +0000 (11:07 +0200)
committerMichael Schroeder <mls@suse.de>
Fri, 26 Jul 2013 09:07:39 +0000 (11:07 +0200)
ext/solv_pgpvrfy.c

index 0da3131..3fe277e 100644 (file)
@@ -17,6 +17,7 @@
 typedef unsigned int mp_t;
 typedef unsigned long long mp2_t;
 #define MP_T_BYTES 4
+
 #define MP_T_BITS (MP_T_BYTES * 8)
 
 static inline void
@@ -31,6 +32,24 @@ mpnew(int len)
   return solv_calloc(len, MP_T_BYTES);
 }
 
+static inline void
+mpcpy(int len, mp_t *target, mp_t *source)
+{
+  memcpy(target, source, len * MP_T_BYTES);
+}
+
+#if 0
+static void mpdump(int l, mp_t *a, char *s)
+{
+  int i;
+  if (s)
+    fprintf(stderr, "%s", s);
+  for (i = l - 1; i >= 0; i--)
+    fprintf(stderr, "%0*x", MP_T_BYTES * 2, a[i]);
+  fprintf(stderr, "\n");
+}
+#endif
+
 /* target[len] = x, target = target % mod */
 static void
 mpdomod(int len, mp_t *target, mp2_t x, mp_t *mod)
@@ -102,7 +121,7 @@ mpmult_add_int(int len, mp_t *target, mp_t *src, mp2_t m, mp_t *mod)
   mpdomod(len, target, x, mod);
 }
 
-/* target = target * 2^MP_T_BITS */
+/* target = target * 2 ^ MP_T_BITS */
 static void
 mpshift(int len, mp_t *target, mp_t *mod)
 {
@@ -118,50 +137,48 @@ mpshift(int len, mp_t *target, mp_t *mod)
 
 /* target += m1 * m2 */
 static void
-mpmult_add(int len, mp_t *target, mp_t *m1, int m2len, mp_t *m2, mp_t *mod)
+mpmult_add(int len, mp_t *target, mp_t *m1, int m2len, mp_t *m2, mp_t *tmp, mp_t *mod)
 {
   int i, j;
-  mp_t *t;
   for (j = m2len - 1; j >= 0; j--)
     if (m2[j])
       break;
   if (j < 0)
     return;
-  t = mpnew(len);
-  memcpy(t, m1, len * MP_T_BYTES);
+  mpcpy(len, tmp, m1);
   for (i = 0; i < j; i++)
     {
       if (m2[i])
-        mpmult_add_int(len, target, t, m2[i], mod);
-      mpshift(len, t, mod);
+        mpmult_add_int(len, target, tmp, m2[i], mod);
+      mpshift(len, tmp, mod);
     }
   if (m2[i])
-    mpmult_add_int(len, target, t, m2[i], mod);
-  free(t);
+    mpmult_add_int(len, target, tmp, m2[i], mod);
 }
 
 /* target = target * m */
 static void
-mpmult_inplace(int len, mp_t *target, mp_t *m, mp_t *tmp, mp_t *mod)
+mpmult_inplace(int len, mp_t *target, mp_t *m, mp_t *tmp1, mp_t *tmp2, mp_t *mod)
 {
-  mpzero(len, tmp);
-  mpmult_add(len, tmp, target, len, m, mod);
-  memcpy(target, tmp, len * MP_T_BYTES);
+  mpzero(len, tmp1);
+  mpmult_add(len, tmp1, target, len, m, tmp2, mod);
+  mpcpy(len, target, tmp1);
 }
 
-/* target = target ^ (16 + e) */
+/* target = target ^ 16 * b ^ e */
 static void
 mppow_int(int len, mp_t *target, mp_t *t, mp_t *mod, int e)
 {
-  mpmult_inplace(len, target, target, t, mod);
-  mpmult_inplace(len, target, target, t, mod);
-  mpmult_inplace(len, target, target, t, mod);
-  mpmult_inplace(len, target, target, t, mod);
+  mp_t *t2 = t + len * 16;
+  mpmult_inplace(len, target, target, t, t2, mod);
+  mpmult_inplace(len, target, target, t, t2, mod);
+  mpmult_inplace(len, target, target, t, t2, mod);
+  mpmult_inplace(len, target, target, t, t2, mod);
   if (e)
-    mpmult_inplace(len, target, t + len * e, t, mod);
+    mpmult_inplace(len, target, t + len * e, t, t2, mod);
 }
 
-/* target = b ^ e */
+/* target = b ^ e (b has to be < mod) */
 static void
 mppow(int len, mp_t *target, mp_t *b, int elen, mp_t *e, mp_t *mod)
 {
@@ -174,10 +191,10 @@ mppow(int len, mp_t *target, mp_t *b, int elen, mp_t *e, mp_t *mod)
       break;
   if (i < 0)
     return;
-  t = mpnew(len * 16);
-  memcpy(t + len, b, len * MP_T_BYTES);
+  t = mpnew(len * 17);
+  mpcpy(len, t + len, b);
   for (j = 2; j < 16; j++)
-    mpmult_add(len, t + len * j, b, len, t + len * j - len, mod);
+    mpmult_add(len, t + len * j, b, len, t + len * j - len, t + len * 16, mod);
   for (; i >= 0; i--)
     {
 #if MP_T_BYTES == 4
@@ -197,6 +214,16 @@ mppow(int len, mp_t *target, mp_t *b, int elen, mp_t *e, mp_t *mod)
   free(t);
 }
 
+/* target = m1 * m2 (m1 has to be < mod) */
+static void
+mpmult(int len, mp_t *target, mp_t *m1, int m2len, mp_t *m2, mp_t *mod)
+{
+  mp_t *tmp = mpnew(len);
+  mpzero(len, target);
+  mpmult_add(len, target, m1, m2len, m2, tmp, mod);
+  free(tmp);
+}
+
 static int
 mpisless(int len, mp_t *a, mp_t *b)
 {
@@ -230,18 +257,6 @@ mpdec(int len, mp_t *a)
       a[i] = -(mp_t)1;
 }
 
-#if 0
-static void mpdump(int l, mp_t *a, char *s)
-{
-  int i;
-  if (s)
-    fprintf(stderr, "%s", s);
-  for (i = l - 1; i >= 0; i--)
-    fprintf(stderr, "%08x", a[i]);
-  fprintf(stderr, "\n");
-}
-#endif
-
 static int
 mpdsa(int pl, mp_t *p, int ql, mp_t *q, mp_t *g, mp_t *y, mp_t *r, mp_t *s, int hl, mp_t *h)
 {
@@ -264,30 +279,28 @@ mpdsa(int pl, mp_t *p, int ql, mp_t *q, mp_t *g, mp_t *y, mp_t *r, mp_t *s, int
     return 0;
   if (!mpisless(ql, s, q) || mpiszero(ql, s))
     return 0;
-  tmp = mpnew(pl);                     /* note pl! */
-  memcpy(tmp, q, ql * MP_T_BYTES);     /* tmp = q */
+  tmp = mpnew(pl);                     /* note pl */
+  mpcpy(ql, tmp, q);                   /* tmp = q */
   mpdec(ql, tmp);                      /* tmp-- */
   mpdec(ql, tmp);                      /* tmp-- */
   w = mpnew(ql);
   mppow(ql, w, s, ql, tmp, q);         /* w = s ^ tmp (s ^ -1) */
-  u1 = mpnew(pl);                      /* u1 = 0 */
+  u1 = mpnew(pl);                      /* note pl */
   /* order is important here: h can be >= q */
-  mpmult_add(ql, u1, w, hl, h, q);     /* u1 += w * h */
-  u2 = mpnew(pl);                      /* u2 = 0 */
-  mpmult_add(ql, u2, w, ql, r, q);     /* u2 += w * r */
+  mpmult(ql, u1, w, hl, h, q);         /* u1 = w * h */
+  u2 = mpnew(ql);                      /* u2 = 0 */
+  mpmult(ql, u2, w, ql, r, q);         /* u2 = w * r */
   free(w);
   gu1 = mpnew(pl);
   yu2 = mpnew(pl);
-  mppow(pl, gu1, g, pl, u1, p);                /* gu1 = g ^ u1 */
-  mppow(pl, yu2, y, pl, u2, p);                /* yu2 = y ^ u2 */
-  mpzero(pl, u1);                      /* u1 = 0 */
-  mpmult_add(pl, u1, gu1, pl, yu2, p); /* u1 += gu1 * yu2 */
+  mppow(pl, gu1, g, ql, u1, p);                /* gu1 = g ^ u1 */
+  mppow(pl, yu2, y, ql, u2, p);                /* yu2 = y ^ u2 */
+  mpmult(pl, u1, gu1, pl, yu2, p);     /* u1 = gu1 * yu2 */
   free(gu1);
   free(yu2);
   mpzero(ql, u2);
   u2[0] = 1;                           /* u2 = 1 */
-  mpzero(ql, tmp);                     /* tmp = 0 */
-  mpmult_add(ql, tmp, u2, pl, u1, q);  /* tmp += u2 * u1 */
+  mpmult(ql, tmp, u2, pl, u1, q);      /* tmp = u2 * u1 */
   free(u1);
   free(u2);
 #if 0
@@ -357,7 +370,6 @@ findmpi(unsigned char **mpip, int *mpilp, int maxbits, int *outlen)
 {
   int mpil = *mpilp;
   unsigned char *mpi = *mpip;
-  unsigned char *out = 0;
   int bits, l;
 
   *outlen = 0;
@@ -366,15 +378,14 @@ findmpi(unsigned char **mpip, int *mpilp, int maxbits, int *outlen)
   bits = mpi[0] << 8 | mpi[1];
   l = 2 + (bits + 7) / 8;
   if (bits > maxbits || mpil < l)
-    *mpilp = 0;
-  else
     {
-      out = mpi + 2;
-      *outlen = bits;
-      *mpilp = mpil - l;
-      *mpip = mpi + l;
+      *mpilp = 0;
+      return 0;
     }
-  return out;
+  *outlen = bits;
+  *mpilp = mpil - l;
+  *mpip = mpi + l;
+  return mpi + 2;
 }
 
 int