30354b2c56cebe688c9e80fb056979268dee3c92
[platform/framework/web/crosswalk.git] / src / third_party / tlslite / tlslite / utils / cryptomath.py
1 # Authors: 
2 #   Trevor Perrin
3 #   Martin von Loewis - python 3 port
4 #
5 # See the LICENSE file for legal information regarding use of this file.
6
7 """cryptomath module
8
9 This module has basic math/crypto code."""
10 from __future__ import print_function
11 import os
12 import math
13 import base64
14 import binascii
15
16 from .compat import *
17
18
19 # **************************************************************************
20 # Load Optional Modules
21 # **************************************************************************
22
23 # Try to load M2Crypto/OpenSSL
24 try:
25     from M2Crypto import m2
26     m2cryptoLoaded = True
27
28 except ImportError:
29     m2cryptoLoaded = False
30
31 #Try to load GMPY
32 try:
33     import gmpy
34     gmpyLoaded = True
35 except ImportError:
36     gmpyLoaded = False
37
38 #Try to load pycrypto
39 try:
40     import Crypto.Cipher.AES
41     pycryptoLoaded = True
42 except ImportError:
43     pycryptoLoaded = False
44
45
46 # **************************************************************************
47 # PRNG Functions
48 # **************************************************************************
49
50 # Check that os.urandom works
51 import zlib
52 length = len(zlib.compress(os.urandom(1000)))
53 assert(length > 900)
54
55 def getRandomBytes(howMany):
56     b = bytearray(os.urandom(howMany))
57     assert(len(b) == howMany)
58     return b
59
60 prngName = "os.urandom"
61
62 # **************************************************************************
63 # Simple hash functions
64 # **************************************************************************
65
66 import hmac
67 import hashlib
68
69 def MD5(b):
70     return bytearray(hashlib.md5(compat26Str(b)).digest())
71
72 def SHA1(b):
73     return bytearray(hashlib.sha1(compat26Str(b)).digest())
74
75 def HMAC_MD5(k, b):
76     k = compatHMAC(k)
77     b = compatHMAC(b)
78     return bytearray(hmac.new(k, b, hashlib.md5).digest())
79
80 def HMAC_SHA1(k, b):
81     k = compatHMAC(k)
82     b = compatHMAC(b)
83     return bytearray(hmac.new(k, b, hashlib.sha1).digest())
84
85
86 # **************************************************************************
87 # Converter Functions
88 # **************************************************************************
89
90 def bytesToNumber(b):
91     total = 0
92     multiplier = 1
93     for count in range(len(b)-1, -1, -1):
94         byte = b[count]
95         total += multiplier * byte
96         multiplier *= 256
97     # Force-cast to long to appease PyCrypto.
98     # https://github.com/trevp/tlslite/issues/15
99     return long(total)
100
101 def numberToByteArray(n, howManyBytes=None):
102     """Convert an integer into a bytearray, zero-pad to howManyBytes.
103
104     The returned bytearray may be smaller than howManyBytes, but will
105     not be larger.  The returned bytearray will contain a big-endian
106     encoding of the input integer (n).
107     """    
108     if howManyBytes == None:
109         howManyBytes = numBytes(n)
110     b = bytearray(howManyBytes)
111     for count in range(howManyBytes-1, -1, -1):
112         b[count] = int(n % 256)
113         n >>= 8
114     return b
115
116 def mpiToNumber(mpi): #mpi is an openssl-format bignum string
117     if (ord(mpi[4]) & 0x80) !=0: #Make sure this is a positive number
118         raise AssertionError()
119     b = bytearray(mpi[4:])
120     return bytesToNumber(b)
121
122 def numberToMPI(n):
123     b = numberToByteArray(n)
124     ext = 0
125     #If the high-order bit is going to be set,
126     #add an extra byte of zeros
127     if (numBits(n) & 0x7)==0:
128         ext = 1
129     length = numBytes(n) + ext
130     b = bytearray(4+ext) + b
131     b[0] = (length >> 24) & 0xFF
132     b[1] = (length >> 16) & 0xFF
133     b[2] = (length >> 8) & 0xFF
134     b[3] = length & 0xFF
135     return bytes(b)
136
137
138 # **************************************************************************
139 # Misc. Utility Functions
140 # **************************************************************************
141
142 def numBits(n):
143     if n==0:
144         return 0
145     s = "%x" % n
146     return ((len(s)-1)*4) + \
147     {'0':0, '1':1, '2':2, '3':2,
148      '4':3, '5':3, '6':3, '7':3,
149      '8':4, '9':4, 'a':4, 'b':4,
150      'c':4, 'd':4, 'e':4, 'f':4,
151      }[s[0]]
152     return int(math.floor(math.log(n, 2))+1)
153
154 def numBytes(n):
155     if n==0:
156         return 0
157     bits = numBits(n)
158     return int(math.ceil(bits / 8.0))
159
160 # **************************************************************************
161 # Big Number Math
162 # **************************************************************************
163
164 def getRandomNumber(low, high):
165     if low >= high:
166         raise AssertionError()
167     howManyBits = numBits(high)
168     howManyBytes = numBytes(high)
169     lastBits = howManyBits % 8
170     while 1:
171         bytes = getRandomBytes(howManyBytes)
172         if lastBits:
173             bytes[0] = bytes[0] % (1 << lastBits)
174         n = bytesToNumber(bytes)
175         if n >= low and n < high:
176             return n
177
178 def gcd(a,b):
179     a, b = max(a,b), min(a,b)
180     while b:
181         a, b = b, a % b
182     return a
183
184 def lcm(a, b):
185     return (a * b) // gcd(a, b)
186
187 #Returns inverse of a mod b, zero if none
188 #Uses Extended Euclidean Algorithm
189 def invMod(a, b):
190     c, d = a, b
191     uc, ud = 1, 0
192     while c != 0:
193         q = d // c
194         c, d = d-(q*c), c
195         uc, ud = ud - (q * uc), uc
196     if d == 1:
197         return ud % b
198     return 0
199
200
201 if gmpyLoaded:
202     def powMod(base, power, modulus):
203         base = gmpy.mpz(base)
204         power = gmpy.mpz(power)
205         modulus = gmpy.mpz(modulus)
206         result = pow(base, power, modulus)
207         return long(result)
208
209 else:
210     def powMod(base, power, modulus):
211         if power < 0:
212             result = pow(base, power*-1, modulus)
213             result = invMod(result, modulus)
214             return result
215         else:
216             return pow(base, power, modulus)
217
218 #Pre-calculate a sieve of the ~100 primes < 1000:
219 def makeSieve(n):
220     sieve = list(range(n))
221     for count in range(2, int(math.sqrt(n))):
222         if sieve[count] == 0:
223             continue
224         x = sieve[count] * 2
225         while x < len(sieve):
226             sieve[x] = 0
227             x += sieve[count]
228     sieve = [x for x in sieve[2:] if x]
229     return sieve
230
231 sieve = makeSieve(1000)
232
233 def isPrime(n, iterations=5, display=False):
234     #Trial division with sieve
235     for x in sieve:
236         if x >= n: return True
237         if n % x == 0: return False
238     #Passed trial division, proceed to Rabin-Miller
239     #Rabin-Miller implemented per Ferguson & Schneier
240     #Compute s, t for Rabin-Miller
241     if display: print("*", end=' ')
242     s, t = n-1, 0
243     while s % 2 == 0:
244         s, t = s//2, t+1
245     #Repeat Rabin-Miller x times
246     a = 2 #Use 2 as a base for first iteration speedup, per HAC
247     for count in range(iterations):
248         v = powMod(a, s, n)
249         if v==1:
250             continue
251         i = 0
252         while v != n-1:
253             if i == t-1:
254                 return False
255             else:
256                 v, i = powMod(v, 2, n), i+1
257         a = getRandomNumber(2, n)
258     return True
259
260 def getRandomPrime(bits, display=False):
261     if bits < 10:
262         raise AssertionError()
263     #The 1.5 ensures the 2 MSBs are set
264     #Thus, when used for p,q in RSA, n will have its MSB set
265     #
266     #Since 30 is lcm(2,3,5), we'll set our test numbers to
267     #29 % 30 and keep them there
268     low = ((2 ** (bits-1)) * 3) // 2
269     high = 2 ** bits - 30
270     p = getRandomNumber(low, high)
271     p += 29 - (p % 30)
272     while 1:
273         if display: print(".", end=' ')
274         p += 30
275         if p >= high:
276             p = getRandomNumber(low, high)
277             p += 29 - (p % 30)
278         if isPrime(p, display=display):
279             return p
280
281 #Unused at the moment...
282 def getRandomSafePrime(bits, display=False):
283     if bits < 10:
284         raise AssertionError()
285     #The 1.5 ensures the 2 MSBs are set
286     #Thus, when used for p,q in RSA, n will have its MSB set
287     #
288     #Since 30 is lcm(2,3,5), we'll set our test numbers to
289     #29 % 30 and keep them there
290     low = (2 ** (bits-2)) * 3//2
291     high = (2 ** (bits-1)) - 30
292     q = getRandomNumber(low, high)
293     q += 29 - (q % 30)
294     while 1:
295         if display: print(".", end=' ')
296         q += 30
297         if (q >= high):
298             q = getRandomNumber(low, high)
299             q += 29 - (q % 30)
300         #Ideas from Tom Wu's SRP code
301         #Do trial division on p and q before Rabin-Miller
302         if isPrime(q, 0, display=display):
303             p = (2 * q) + 1
304             if isPrime(p, display=display):
305                 if isPrime(q, display=display):
306                     return p